Spaces:
Sleeping
Sleeping
| import random | |
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| from Src.Processing import load_data | |
| from Src.Processing import process_data | |
| from Src.Inference import load_model | |
| from Src.NST_Inference import save_style | |
| import torch | |
| import time | |
| import os | |
| import mne | |
| import matplotlib.pyplot as plt | |
| import io | |
| import matplotlib.cm as cm | |
| import gradio as gr | |
# Emotion abbreviations in model class-index order (index i <-> class i of the model output).
_EMOTION_CODES = ('sad', 'dis', 'fear', 'neu', 'joy', 'ten', 'ins')

# Placeholder data for the bar plot before any PSD file has been analyzed.
dummy_emotion_data = pd.DataFrame({
    'Emotion': list(_EMOTION_CODES),
    'Value': [0.8, 0.6, 0.1, 0.4, 0.7, 0.2, 0.3],
})

# Class index -> emotion abbreviation.
int_to_emotion = dict(enumerate(_EMOTION_CODES))

# Emotion abbreviation -> display name.
# NOTE(review): 'Tenderness' is capitalized unlike the other names. These values are shown in the
# dominant-emotion textbox and generate_my_art joins them onto EMOTION_BASE_DIR as folder names,
# so check the folder name on disk before normalizing the casing.
abr_to_emotion = dict(zip(
    _EMOTION_CODES,
    ("sadness", "disgust", "fear", "neutral", "joy", 'Tenderness', "inspiration"),
))
# --- Paths ---
PAINTERS_BASE_DIR = "Painters"  # one sub-folder per painter, holding the style images
EMOTION_BASE_DIR = "Emotions"   # one sub-folder per emotion display name, holding source images
output_dir = "outputs"          # output directory handed to save_style
Base_Dir = "Datasets"           # folder containing the selectable PSD .npy files

# --- LSTM hyper-parameters passed to load_model ---
input_size = 320
hidden_size = 50
output_size = 7
num_layers = 1

# --- UI choices ---
painters = ["Pablo Picasso", "Vincent van Gogh", "Salvador Dalí"]
predefined_psd_files = [f"task-emotion_psd_{i}.npy" for i in range(1, 4)]

# Painting titles per painter; each entry below becomes (file name, caption).
# Not referenced by the rest of the code visible in this file.
_PLACEHOLDER_TITLES = {
    "Pablo Picasso": (
        "Dora Maar with Cat (1941)",
        "The Weeping Woman (1937)",
        "Three Musicians (1921)",
    ),
    "Vincent van Gogh": (
        "Sunflowers (1888)",
        "The Starry Night (1889)",
        "The Potato Eaters (1885)",
    ),
    "Salvador Dalí": (
        "Persistence of Memory (1931)",
        "Swans Reflecting Elephants (1937)",
        "Sleep (1937)",
    ),
}
PAINTER_PLACEHOLDER_DATA = {
    artist: [(f"{title}.png", title) for title in titles]
    for artist, titles in _PLACEHOLDER_TITLES.items()
}
def _hidden_analysis_outputs(label="Emotion Distribution"):
    """Return the (plot, state, dominant-emotion) outputs used when no analysis result exists.

    The "Analyze PSD File" button is wired to three outputs, so every exit path of
    ``upload_psd_file`` must return exactly three values. The dominant-emotion textbox
    is cleared so "Generate My Art" cannot reuse an emotion from an earlier run.
    """
    return (
        gr.BarPlot(dummy_emotion_data, x="Emotion", y="Value", label=label, visible=False),
        pd.DataFrame(),
        gr.Textbox(value="", visible=False),
    )


def _count_predicted_emotions(predicted_logits):
    """Count how often each emotion class wins the argmax over all predicted steps.

    Returns a dict mapping every abbreviation in ``int_to_emotion`` (in class-index
    order) to its count, including classes that were never predicted (count 0).
    """
    # NOTE(review): argmax over dim=2 assumes logits shaped (batch, steps, classes) -- confirm against the model.
    predicted_indices = torch.argmax(predicted_logits, dim=2).view(-1)
    # minlength guarantees one bin per class even if some class never appears.
    class_counts = torch.bincount(predicted_indices, minlength=output_size)
    print(f"Raw bincount: {class_counts}")
    return {int_to_emotion[i].strip(): int(class_counts[i]) for i in range(output_size)}


def upload_psd_file(selected_file_name):
    """
    Processes a selected PSD file, performs inference, and prepares emotion distribution data.

    Args:
        selected_file_name: File name (relative to ``Base_Dir``) picked in the radio list,
            or ``None`` when nothing is selected.

    Returns:
        A 3-tuple matching the analyze button's outputs:
        (bar-plot update, emotion-frequency DataFrame for the session state,
        dominant-emotion textbox update). When nothing is selected or the file is
        missing, the plot and textbox are hidden, the textbox is cleared and the
        DataFrame is empty.
    """
    # Kept from the original implementation: the loaded array is also published as a
    # module-level ``np_data`` (not read anywhere else in the visible file).
    global np_data
    if selected_file_name is None:
        return _hidden_analysis_outputs()
    psd_file_path = os.path.join(Base_Dir, selected_file_name).replace(os.sep, '/')
    try:
        np_data = load_data(psd_file_path)
    except FileNotFoundError:
        print(f"Error: PSD file not found at {psd_file_path}")
        return _hidden_analysis_outputs("Emotion Distribution (Error: File not found)")
    print(f"np data orig {np_data.shape}")

    final_data = process_data(np_data)
    # Prepend a batch dimension of size 1 before feeding the model.
    torch_data = torch.tensor(final_data, dtype=torch.float32).unsqueeze(0)
    model_path = os.path.join("models", "lstm_emotion_model_state.pth")
    loaded_model = load_model(model_path, input_size, hidden_size, output_size, num_layers)
    loaded_model.eval()
    with torch.no_grad():
        predicted_logits, _ = loaded_model(torch_data)

    emotions_count = _count_predicted_emotions(predicted_logits)
    # On ties, max() keeps the first key, i.e. the lowest class index.
    dom_emotion = max(emotions_count, key=emotions_count.get)
    emotion_data = (
        pd.DataFrame({
            "Emotion": list(emotions_count.keys()),
            "Frequency": list(emotions_count.values()),
        })
        .sort_values(by="Emotion")
        .reset_index(drop=True)
    )
    print(f"Final emotion_data DataFrame:\n{emotion_data}")
    return gr.BarPlot(
        emotion_data,
        x="Emotion",
        y="Frequency",
        label="Emotion Distribution",
        visible=True,
        y_title="Frequency"
    ), emotion_data, gr.Textbox(abr_to_emotion[dom_emotion], visible=True)
def update_paintings(painter_name):
    """
    Build the gallery entries for one painter.

    Lists ``PAINTERS_BASE_DIR/<painter_name>`` non-recursively, in sorted file-name
    order, keeps only .png/.jpg/.jpeg/.gif files (case-insensitive), and returns a
    list of ``(path, title)`` tuples with forward-slash paths, where ``title`` is the
    file name without its extension. A missing directory yields an empty list.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', '.gif')
    folder = os.path.join(PAINTERS_BASE_DIR, painter_name).replace(os.sep, '/')
    entries = sorted(os.listdir(folder)) if os.path.isdir(folder) else []
    gallery_items = []
    for entry in entries:
        if not entry.lower().endswith(image_suffixes):
            continue
        entry_path = os.path.join(folder, entry).replace(os.sep, '/')
        print(entry_path)
        title, _extension = os.path.splitext(entry)
        gallery_items.append((entry_path, title))
    print(f"Loaded paintings for {painter_name}: {gallery_items}")
    return gallery_items
def generate_my_art(painter, chosen_painting, dom_emotion):
    """
    Stream the status message and images for the "Generate My Art" button.

    Picks a random file from ``EMOTION_BASE_DIR/<dom_emotion>`` as the content image and
    passes it to ``save_style`` together with ``PAINTERS_BASE_DIR/<painter>/<chosen_painting>``
    as the style image.

    Args:
        painter: Painter name from the dropdown.
        chosen_painting: File name of the painting selected in the gallery.
        dom_emotion: Dominant-emotion name written by the PSD analysis step.

    Yields:
        ``(status, original_image_path, stylized_image_path)``. When generation cannot
        proceed, a single tuple with an explanatory message and ``None`` images is yielded.
    """
    # This function is a generator (it yields below), so ``return value`` would be
    # silently discarded and the user would see nothing: yield the message, then stop.
    if not painter or not chosen_painting:
        yield "Please select a painter and a painting.", None, None
        return
    # Without an analysis result the path would point at EMOTION_BASE_DIR itself.
    if not dom_emotion:
        yield "Please analyze a PSD file first to detect your dominant emotion.", None, None
        return
    img_style_pth = os.path.join(PAINTERS_BASE_DIR, painter, chosen_painting)
    print(f"img_stype_path: {img_style_pth}")

    ## original image
    emotion_pth = os.path.join(EMOTION_BASE_DIR, dom_emotion)
    try:
        candidate_images = os.listdir(emotion_pth)
    except (FileNotFoundError, NotADirectoryError):
        candidate_images = []
    if not candidate_images:
        yield f"No images found for emotion '{dom_emotion}'.", None, None
        return
    original_image_pth = os.path.join(emotion_pth, random.choice(candidate_images))
    print(f"original img _path: {original_image_pth}")

    final_message = f"Art generated based on {painter}'s {chosen_painting} style!"
    ## Neural Style Transfer
    stylized_img_path = save_style(output_dir, original_image_pth, img_style_pth)
    yield gr.Textbox(final_message), original_image_pth, stylized_img_path
# --- Gradio Interface Definition ---
with gr.Blocks(css=".gradio-container { max-width: 2000px; margin: auto; }") as demo:
    # Session state holding the emotion-frequency DataFrame from the latest analysis.
    current_emotion_df_state = gr.State(value=pd.DataFrame())

    # Header Section
    gr.Markdown(
        """
        <h1 style="text-align: center;font-size: 5em; padding: 20px; font-weight: bold;">Brain Emotion Decoder 🧠🎨</h1>
        <p style="text-align: center; font-size: 1.5em; color: #555;font-weight: bold;">
        Imagine seeing your deepest feelings transform into art. We decode the underlying emotions from your brain activity,
        generating a personalized artwork. Discover the art of your inner self.
        </p>
        """
    )

    with gr.Row():
        # Left Column: PSD selection and emotion analysis
        with gr.Column(scale=1):
            # Valid HTML: the size goes in a style attribute and the heading is properly closed.
            gr.Markdown('<h2 style="font-size: 2em;">1. Choose a PSD file</h2>')
            psd_file_selection = gr.Radio(
                choices=predefined_psd_files,
                label="Select a PSD file for analysis",
                value=predefined_psd_files[0],
                interactive=True
            )
            analyze_psd_button = gr.Button("Analyze PSD File", variant="secondary")
            gr.Markdown('<h2 style="font-size: 2em;">2. Emotion Distribution</h2>')
            emotion_distribution_plot = gr.BarPlot(
                dummy_emotion_data,
                x="Emotion",
                y="Value",
                label="Emotion Distribution",
                height=300,
                x_title="Emotion Type",
                y_title="Frequency",
                visible=False
            )
            # Filled by upload_psd_file; read by generate_my_art.
            dom_emotion = gr.Textbox(label="dominant emotion", visible=False)

        # Right Column: Art Museum and Generation
        with gr.Column(scale=1):
            gr.Markdown("<h3>Your Art Museum</h3>")
            gr.Markdown("<h3>3. Choose your favourite painter</h3>")
            painter_dropdown = gr.Dropdown(
                choices=painters,
                value="Pablo Picasso",  # Default selection
                label="Select a Painter"
            )
            gr.Markdown("<h3>4. Choose your favourite painting</h3>")
            painting_gallery = gr.Gallery(
                value=update_paintings("Pablo Picasso"),  # Initial load for Picasso's paintings
                label="Select a Painting",
                height=300,
                columns=3,
                rows=1,
                object_fit="contain",
                preview=True,
                interactive=True,
                elem_id="painting_gallery",
                visible=True,
            )
            # File name of the clicked painting, set by on_select.
            selected_painting_name = gr.Textbox(visible=False)
            generate_button = gr.Button("Generate My Art", variant="primary", scale=0)
            status_message = gr.Textbox(
                value="Click 'Generate My Art' to begin.",
                label="Generation Status",
                interactive=False,
                show_label=False,
                lines=1
            )

    gr.Markdown(
        """
        <h1 style="text-align: center;">Your Generated Artwork</h1>
        <p style="text-align: center; color: #555;">
        Once your brain's emotional data is processed, we pinpoint the <b>dominant emotion</b>. This single feeling inspires a <b>personalized artwork</b>. You can then <b>download</b> this unique visual representation of your inner self.
        </p>
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("<h3>Generated Image</h3>")
            generated_image_output = gr.Image(label="Generated Image", show_label=False, height=300)
            gr.Markdown("<h3>Blended Style Image</h3>")
            blended_image_output = gr.Image(label="Blended Style Image", show_label=False, height=300)

    # --- Event Listeners ---
    analyze_psd_button.click(
        upload_psd_file,
        inputs=[psd_file_selection],
        outputs=[emotion_distribution_plot, current_emotion_df_state, dom_emotion]
    )
    painter_dropdown.change(
        update_paintings,
        inputs=[painter_dropdown],
        outputs=[painting_gallery]
    )
    # A painting picked for the previous painter lives in a different folder, so
    # forget the selection whenever the painter changes.
    painter_dropdown.change(
        lambda: "",
        outputs=[selected_painting_name]
    )

    def on_select(evt: gr.SelectData):
        """Return the file name of the gallery image the user clicked."""
        print("this function started")
        image_info = evt.value['image']
        # Fall back to the served file's basename if Gradio did not record an original name.
        painting_name = image_info.get('orig_name') or os.path.basename(image_info['path'])
        print(f"Image index: {evt.index}\nImage value: {painting_name}")
        return painting_name

    painting_gallery.select(
        on_select,
        outputs=[selected_painting_name]
    )
    generate_button.click(
        generate_my_art,
        inputs=[painter_dropdown, selected_painting_name, dom_emotion],
        outputs=[status_message, generated_image_output, blended_image_output]
    )


if __name__ == "__main__":
    demo.launch()