"""Gradio app: decode emotions from a PSD file and render them as stylized art.

The flow is: pick a PSD file -> run the LSTM emotion model -> show the emotion
distribution and dominant emotion -> pick a painter and painting -> run neural
style transfer between an emotion image and the chosen painting.
"""

import io
import os
import random
import time

import gradio as gr
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
import torch

from Src.Inference import load_model
from Src.NST_Inference import save_style
from Src.Processing import load_data
from Src.Processing import process_data

# Placeholder data shown by the (initially hidden) bar plot before any analysis.
dummy_emotion_data = pd.DataFrame({
    'Emotion': ['sad', 'dis', 'fear', 'neu', 'joy', 'ten', 'ins'],
    'Value': [0.8, 0.6, 0.1, 0.4, 0.7, 0.2, 0.3],
})

# Model class index -> abbreviated emotion label.
int_to_emotion = {
    0: 'sad',
    1: 'dis',
    2: 'fear',
    3: 'neu',
    4: 'joy',
    5: 'ten',
    6: 'ins',
}

# Abbreviated emotion label -> name written to the "dominant emotion" textbox.
abr_to_emotion = {
    'sad': "sadness",
    'dis': "disgust",
    'fear': "fear",
    'neu': "neutral",
    'joy': "joy",
    'ten': 'Tenderness',
    'ins': "inspiration",
}

PAINTERS_BASE_DIR = "Painters"
EMOTION_BASE_DIR = "Emotions"
output_dir = "outputs"

# Arguments forwarded to load_model() when the model is restored.
input_size = 320
hidden_size = 50
output_size = 7
num_layers = 1

painters = ["Pablo Picasso", "Vincent van Gogh", "Salvador Dalí"]
predefined_psd_files = [
    "task-emotion_psd_1.npy",
    "task-emotion_psd_2.npy",
    "task-emotion_psd_3.npy",
]
Base_Dir = "Datasets"

PAINTER_PLACEHOLDER_DATA = {
    "Pablo Picasso": [
        ("Dora Maar with Cat (1941).png", "Dora Maar with Cat (1941)"),
        ("The Weeping Woman (1937).png", "The Weeping Woman (1937)"),
        ("Three Musicians (1921).png", "Three Musicians (1921)"),
    ],
    "Vincent van Gogh": [
        ("Sunflowers (1888).png", "Sunflowers (1888)"),
        ("The Starry Night (1889).png", "The Starry Night (1889)"),
        ("The Potato Eaters (1885).png", "The Potato Eaters (1885)"),
    ],
    "Salvador Dalí": [
        ("Persistence of Memory (1931).png", "Persistence of Memory (1931)"),
        ("Swans Reflecting Elephants (1937).png", "Swans Reflecting Elephants (1937)"),
        ("Sleep (1937).png", "Sleep (1937)"),
    ],
}


def _hidden_analysis_outputs(label="Emotion Distribution"):
    """Build the three analysis outputs for the "nothing to show" case.

    The analyze listener is wired to three output components, so every exit
    path of ``upload_psd_file`` has to supply exactly three values.

    Args:
        label: Label given to the hidden bar plot.

    Returns:
        A ``(BarPlot, DataFrame, Textbox)`` tuple: a hidden placeholder plot,
        an empty frame for the state, and a hidden, emptied emotion textbox.
    """
    return (
        gr.BarPlot(dummy_emotion_data, x="Emotion", y="Value", label=label, visible=False),
        pd.DataFrame(),
        # Clearing the value prevents a stale emotion from a previous analysis
        # being used by the art generation step.
        gr.Textbox(value="", visible=False),
    )


def upload_psd_file(selected_file_name):
    """Run emotion inference on a selected PSD file and build the UI outputs.

    Args:
        selected_file_name: File name, relative to ``Base_Dir``, of the PSD
            array to analyse, or ``None`` when nothing is selected.

    Returns:
        A ``(BarPlot, DataFrame, Textbox)`` tuple. On success the plot shows
        the per-emotion prediction counts, the frame holds the same counts and
        the textbox holds the full name of the most frequent emotion. When no
        file is selected or the file is missing, all three are the hidden /
        empty placeholders.
    """
    if selected_file_name is None:
        return _hidden_analysis_outputs()

    psd_file_path = os.path.join(Base_Dir, selected_file_name).replace(os.sep, '/')
    try:
        np_data = load_data(psd_file_path)
    except FileNotFoundError:
        print(f"Error: PSD file not found at {psd_file_path}")
        return _hidden_analysis_outputs("Emotion Distribution (Error: File not found)")
    print(f"np data orig {np_data.shape}")

    final_data = process_data(np_data)
    # unsqueeze(0) adds a leading dimension of size 1 before the model call.
    torch_data = torch.tensor(final_data, dtype=torch.float32).unsqueeze(0)

    model_path = os.path.join("models", "lstm_emotion_model_state.pth")
    loaded_model = load_model(model_path, input_size, hidden_size, output_size, num_layers)
    loaded_model.eval()
    with torch.no_grad():
        predicted_logits, _ = loaded_model(torch_data)

    # One predicted class per position along the last axis, flattened to 1-D.
    predicted_indices = torch.argmax(predicted_logits, dim=2).view(-1)

    # minlength guarantees one slot per known emotion even if it never occurs.
    values_count = torch.bincount(predicted_indices, minlength=output_size)
    print(f"Raw bincount: {values_count}")

    counts = values_count.tolist()
    emotions_count = {
        int_to_emotion[idx].strip(): counts[idx] for idx in range(output_size)
    }
    dominant = max(emotions_count, key=emotions_count.get)

    emotion_data = pd.DataFrame({
        "Emotion": list(emotions_count.keys()),
        "Frequency": list(emotions_count.values()),
    })
    emotion_data = emotion_data.sort_values(by="Emotion").reset_index(drop=True)
    print(f"Final emotion_data DataFrame:\n{emotion_data}")

    return (
        gr.BarPlot(
            emotion_data,
            x="Emotion",
            y="Frequency",
            label="Emotion Distribution",
            visible=True,
            y_title="Frequency",
        ),
        emotion_data,
        gr.Textbox(abr_to_emotion[dominant], visible=True),
    )


def update_paintings(painter_name):
    """List the paintings available for a painter as gallery items.

    Args:
        painter_name: Name of the sub-directory of ``PAINTERS_BASE_DIR`` to
            scan.

    Returns:
        A list of ``(file_path, title)`` tuples, sorted by file name, for
        every ``.png``/``.jpg``/``.jpeg``/``.gif`` file found. The list is
        empty when no painter is given or the directory does not exist.
    """
    if not painter_name:
        return []

    painter_dir = os.path.join(PAINTERS_BASE_DIR, painter_name).replace(os.sep, '/')
    artist_paintings_for_gallery = []

    if os.path.isdir(painter_dir):
        for filename in sorted(os.listdir(painter_dir)):
            if not filename.lower().endswith(('.png', '.jpg', '.jpeg', '.gif')):
                continue
            file_path = os.path.join(painter_dir, filename).replace(os.sep, '/')
            print(file_path)
            # The caption is the file name without its extension.
            title = os.path.splitext(filename)[0]
            artist_paintings_for_gallery.append((file_path, title))

    print(f"Loaded paintings for {painter_name}: {artist_paintings_for_gallery}")
    return artist_paintings_for_gallery


def generate_my_art(painter, chosen_painting, dom_emotion):
    """Stylize a random image of the dominant emotion with the chosen painting.

    This is a generator: every exit path yields exactly one
    ``(status, original_image, stylized_image)`` triple. A plain
    ``return value`` inside a generator yields nothing, so validation
    messages are yielded before returning.

    Args:
        painter: Painter name; sub-directory of ``PAINTERS_BASE_DIR``.
        chosen_painting: File name of the style painting in that directory.
        dom_emotion: Emotion name; sub-directory of ``EMOTION_BASE_DIR``
            from which the content image is drawn.

    Yields:
        ``(status, original_image_path, stylized_image_path)``; both paths are
        ``None`` when the inputs are incomplete or no content image exists.
    """
    if not painter or not chosen_painting:
        yield "Please select a painter and a painting.", None, None
        return
    if not dom_emotion:
        # Without an emotion the join below would resolve to the emotion root
        # and a directory would be picked as the "image".
        yield "Please analyze a PSD file first.", None, None
        return

    img_style_pth = os.path.join(PAINTERS_BASE_DIR, painter, chosen_painting)
    print(f"img_stype_path: {img_style_pth}")

    # Original (content) image: a random file from the emotion's directory.
    emotion_pth = os.path.join(EMOTION_BASE_DIR, dom_emotion)
    try:
        candidates = os.listdir(emotion_pth)
    except OSError:
        candidates = []
    if not candidates:
        yield f"No images available for emotion '{dom_emotion}'.", None, None
        return

    image_name = random.choice(candidates)
    original_image_pth = os.path.join(emotion_pth, image_name)
    print(f"original img _path: {original_image_pth}")

    final_message = f"Art generated based on {painter}'s {chosen_painting} style!"

    # Neural Style Transfer
    stylized_img_path = save_style(output_dir, original_image_pth, img_style_pth)

    yield gr.Textbox(final_message), original_image_pth, stylized_img_path


# --- Gradio Interface Definition ---
# NOTE(review): component nesting below is reconstructed from call order —
# confirm against the intended layout. The heading strings also look like they
# have lost their markup; their text is kept unchanged.
with gr.Blocks(css=".gradio-container { max-width: 2000px; margin: auto; }") as demo:
    current_emotion_df_state = gr.State(value=pd.DataFrame())

    # Header Section
    gr.Markdown(
        "\n\nBrain Emotion Decoder 🧠🎨\n\n"
        "Imagine seeing your deepest feelings transform into art. We decode the "
        "underlying emotions from your brain activity, generating a personalized "
        "artwork. Discover the art of your inner self.\n\n"
    )

    with gr.Row():
        # Left Column: PSD selection and emotion analysis
        with gr.Column(scale=1):
            gr.Markdown("\n\n1. Choose a PSD file\n\n")
            psd_file_selection = gr.Radio(
                choices=predefined_psd_files,
                label="Select a PSD file for analysis",
                value=predefined_psd_files[0],
                interactive=True,
            )
            analyze_psd_button = gr.Button("Analyze PSD File", variant="secondary")

            gr.Markdown("\n\n2. Emotion Distribution\n\n")
            emotion_distribution_plot = gr.BarPlot(
                dummy_emotion_data,
                x="Emotion",
                y="Value",
                label="Emotion Distribution",
                height=300,
                x_title="Emotion Type",
                y_title="Frequency",
                visible=False,
            )
            dom_emotion = gr.Textbox(label="dominant emotion", visible=False)

        # Right Column: Art Museum and Generation
        with gr.Column(scale=1):
            # NOTE(review): "Mesum" looks like a typo for "Museum" — confirm.
            gr.Markdown("\n\nYour Art Mesum\n\n")  # Kept original heading

            gr.Markdown("\n\n3. Choose your favourite painter\n\n")
            painter_dropdown = gr.Dropdown(
                choices=painters,
                value="Pablo Picasso",  # Default selection
                label="Select a Painter",
            )

            gr.Markdown("\n\n4. Choose your favourite painting\n\n")
            painting_gallery = gr.Gallery(
                value=update_paintings("Pablo Picasso"),  # Initial load for Picasso's paintings
                label="Select a Painting",
                height=300,
                columns=3,
                rows=1,
                object_fit="contain",
                preview=True,
                interactive=True,
                elem_id="painting_gallery",
                visible=True,
            )
            # Hidden holder for the file name of the painting picked in the gallery.
            selected_painting_name = gr.Textbox(visible=False)

            generate_button = gr.Button("Generate My Art", variant="primary", scale=0)
            status_message = gr.Textbox(
                value="Click 'Generate My Art' to begin.",
                label="Generation Status",
                interactive=False,
                show_label=False,
                lines=1,
            )

    gr.Markdown(
        "\n\nYour Generated Artwork\n\n"
        "Once your brain's emotional data is processed, we pinpoint the dominant "
        "emotion. This single feeling inspires a personalized artwork. You can "
        "then download this unique visual representation of your inner self.\n\n"
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("\n\nGenerated Image\n\n")
            generated_image_output = gr.Image(
                label="Generated Image", show_label=False, height=300
            )
            gr.Markdown("\n\nBlended Style Image\n\n")
            blended_image_output = gr.Image(
                label="Blended Style Image", show_label=False, height=300
            )

    # --- Event Listeners ---
    analyze_psd_button.click(
        upload_psd_file,
        inputs=[psd_file_selection],
        outputs=[emotion_distribution_plot, current_emotion_df_state, dom_emotion],
    )

    painter_dropdown.change(
        update_paintings,
        inputs=[painter_dropdown],
        outputs=[painting_gallery],
    )
    # A painting picked for the previous painter does not exist in the new
    # painter's directory, so forget it whenever the painter changes.
    painter_dropdown.change(
        lambda: "",
        outputs=[selected_painting_name],
    )

    def on_select(evt: gr.SelectData):
        """Return the original file name of the gallery image that was clicked."""
        print("this function started")
        print(f"Image index: {evt.index}\nImage value: {evt.value['image']['orig_name']}")
        return evt.value['image']['orig_name']

    painting_gallery.select(
        on_select,
        outputs=[selected_painting_name],
    )

    generate_button.click(
        generate_my_art,
        inputs=[painter_dropdown, selected_painting_name, dom_emotion],
        outputs=[status_message, generated_image_output, blended_image_output],
    )


if __name__ == "__main__":
    demo.launch()