# Ihssane123's picture
# Initial commit
# 3b6d764
import random
import gradio as gr
import pandas as pd
import numpy as np
from Src.Processing import load_data
from Src.Processing import process_data
from Src.Inference import load_model
from Src.NST_Inference import save_style
import torch
import time
import os
import mne
import matplotlib.pyplot as plt
import io
import matplotlib.cm as cm
import gradio as gr
# Placeholder rows for the emotion bar plot, shown (hidden) before any PSD
# file has been analysed: one (abbreviation, demo value) record per emotion.
dummy_emotion_data = pd.DataFrame(
    [
        ("sad", 0.8),
        ("dis", 0.6),
        ("fear", 0.1),
        ("neu", 0.4),
        ("joy", 0.7),
        ("ten", 0.2),
        ("ins", 0.3),
    ],
    columns=["Emotion", "Value"],
)
# Classifier class index (as produced by argmax over the model logits)
# -> emotion abbreviation. Seven entries, one per model output.
int_to_emotion = dict(enumerate(["sad", "dis", "fear", "neu", "joy", "ten", "ins"]))

# Emotion abbreviation -> full name shown to the user. The full name is also
# joined onto EMOTION_BASE_DIR to find content images, so the casing here
# (note "Tenderness") presumably mirrors the on-disk folder names — verify
# before changing it.
abr_to_emotion = dict(
    sad="sadness",
    dis="disgust",
    fear="fear",
    neu="neutral",
    joy="joy",
    ten="Tenderness",
    ins="inspiration",
)
# --- Filesystem layout ---
PAINTERS_BASE_DIR = "Painters"  # one subfolder of style paintings per painter
EMOTION_BASE_DIR = "Emotions"   # one subfolder of content images per full emotion name
output_dir = "outputs"          # directory handed to save_style for its results
Base_Dir = "Datasets"           # folder holding the predefined PSD .npy files

# --- LSTM classifier hyper-parameters ---
# Presumably these must match the checkpoint under models/ — TODO confirm.
# output_size equals the number of entries in int_to_emotion.
input_size, hidden_size, output_size, num_layers = 320, 50, 7, 1

# --- Choices offered in the UI ---
painters = ["Pablo Picasso", "Vincent van Gogh", "Salvador Dalí"]
predefined_psd_files = [f"task-emotion_psd_{i}.npy" for i in range(1, 4)]
# Static catalogue of paintings per painter as (file name, display title)
# pairs; each file name is the title plus ".png".
# NOTE(review): appears unused in the code shown here (update_paintings lists
# the painter folders on disk instead) — confirm before removing.
PAINTER_PLACEHOLDER_DATA = {
    painter: [(f"{title}.png", title) for title in titles]
    for painter, titles in {
        "Pablo Picasso": (
            "Dora Maar with Cat (1941)",
            "The Weeping Woman (1937)",
            "Three Musicians (1921)",
        ),
        "Vincent van Gogh": (
            "Sunflowers (1888)",
            "The Starry Night (1889)",
            "The Potato Eaters (1885)",
        ),
        "Salvador Dalí": (
            "Persistence of Memory (1931)",
            "Swans Reflecting Elephants (1937)",
            "Sleep (1937)",
        ),
    }.items()
}
def _hidden_emotion_outputs(label="Emotion Distribution"):
    """Return the (plot, dataframe, dominant-emotion textbox) triple used when there is no result.

    The "Analyze PSD File" button is wired to three outputs, so every exit
    path of ``upload_psd_file`` must return exactly three values.
    """
    return (
        gr.BarPlot(dummy_emotion_data, x="Emotion", y="Value", label=label, visible=False),
        pd.DataFrame(),
        # Clear any dominant emotion left from a previous analysis so the
        # generate step cannot silently reuse a stale value.
        gr.Textbox(value="", visible=False),
    )


def upload_psd_file(selected_file_name):
    """Analyse a predefined PSD file and summarise the model's emotion predictions.

    Loads ``selected_file_name`` from ``Base_Dir`` with ``load_data``,
    preprocesses it with ``process_data``, runs the LSTM loaded from
    ``models/lstm_emotion_model_state.pth`` and counts how often each
    emotion index wins the argmax over dimension 2 of the logits.

    Args:
        selected_file_name: File name chosen in the radio selector, or ``None``.

    Returns:
        A 3-tuple ``(bar_plot, emotion_df, dominant_emotion_textbox)``.
        ``emotion_df`` has columns ``Emotion`` (abbreviation) and
        ``Frequency``, sorted alphabetically by emotion. The textbox holds the
        full name (via ``abr_to_emotion``) of the most frequent emotion.
        When no file is selected or the file is missing, the plot and textbox
        are hidden, the textbox is cleared and the dataframe is empty.
    """
    # Kept as a module-level global as in the original implementation.
    global np_data
    if selected_file_name is None:
        return _hidden_emotion_outputs()
    psd_file_path = os.path.join(Base_Dir, selected_file_name).replace(os.sep, '/')
    try:
        np_data = load_data(psd_file_path)
    except FileNotFoundError:
        print(f"Error: PSD file not found at {psd_file_path}")
        return _hidden_emotion_outputs("Emotion Distribution (Error: File not found)")
    print(f"np data orig {np_data.shape}")

    final_data = process_data(np_data)
    # unsqueeze(0) adds a leading batch dimension of size 1.
    torch_data = torch.tensor(final_data, dtype=torch.float32).unsqueeze(0)
    model_path = os.path.join("models", "lstm_emotion_model_state.pth")
    loaded_model = load_model(model_path, input_size, hidden_size, output_size, num_layers)
    loaded_model.eval()
    with torch.no_grad():
        predicted_logits, _ = loaded_model(torch_data)

    # Winning class per position; assumes logits are (batch, seq, output_size)
    # — TODO confirm against the model definition.
    all_predicted_indices = torch.argmax(predicted_logits, dim=2).view(-1)
    # minlength guarantees at least output_size bins, so indexing below is safe.
    values_count = torch.bincount(all_predicted_indices, minlength=output_size)
    print(f"Raw bincount: {values_count}")
    emotions_count = {
        int_to_emotion[idx]: values_count[idx].item() for idx in range(output_size)
    }
    # Ties resolve to the first emotion in int_to_emotion order.
    dom_emotion = max(emotions_count, key=emotions_count.get)

    emotion_data = pd.DataFrame({
        "Emotion": list(emotions_count.keys()),
        "Frequency": list(emotions_count.values()),
    })
    emotion_data = emotion_data.sort_values(by="Emotion").reset_index(drop=True)
    print(f"Final emotion_data DataFrame:\n{emotion_data}")
    return gr.BarPlot(
        emotion_data,
        x="Emotion",
        y="Frequency",
        label="Emotion Distribution",
        visible=True,
        y_title="Frequency"
    ), emotion_data, gr.Textbox(abr_to_emotion[dom_emotion], visible=True)
def update_paintings(painter_name, base_dir=None):
    """Return the selected painter's images as Gallery items.

    Files are read from the painter's subdirectory; only ``.png``, ``.jpg``,
    ``.jpeg`` and ``.gif`` files (case-insensitive) are included.

    Args:
        painter_name: Painter subdirectory name. A falsy value (e.g. a cleared
            dropdown sends ``None``) yields an empty list instead of crashing.
        base_dir: Root folder of painter subdirectories; defaults to
            ``PAINTERS_BASE_DIR``.

    Returns:
        A list of ``(file_path, title)`` tuples sorted by file name.
        ``file_path`` uses forward slashes and ``title`` is the file name
        without its extension. The list is empty if the directory is missing.
    """
    if not painter_name:
        return []
    if base_dir is None:
        base_dir = PAINTERS_BASE_DIR
    painter_dir = os.path.join(base_dir, painter_name).replace(os.sep, '/')
    artist_paintings_for_gallery = []
    if os.path.isdir(painter_dir):
        artist_paintings_for_gallery = [
            (
                os.path.join(painter_dir, filename).replace(os.sep, '/'),
                os.path.splitext(filename)[0],  # title without extension
            )
            for filename in sorted(os.listdir(painter_dir))
            if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.gif'))
        ]
    print(f"Loaded paintings for {painter_name}: {artist_paintings_for_gallery}")
    return artist_paintings_for_gallery
def generate_my_art(painter, chosen_painting, dom_emotion):
    """Stylise a random image of the dominant emotion using the chosen painting's style.

    This function is a generator: Gradio streams each yielded tuple to the
    outputs. User-facing messages must therefore be *yielded*. A
    ``return value`` inside a generator is swallowed by ``StopIteration``
    and never reaches the UI.

    Args:
        painter: Painter name, a subdirectory of ``PAINTERS_BASE_DIR``.
        chosen_painting: File name of the style painting in that directory.
        dom_emotion: Full emotion name from the PSD analysis, a subdirectory of
            ``EMOTION_BASE_DIR`` containing candidate content images.

    Yields:
        ``(status_message, content_image_path, stylized_image_path)``. The
        image paths are ``None`` when generation cannot proceed.
    """
    if not painter or not chosen_painting:
        yield "Please select a painter and a painting.", None, None
        return
    if not dom_emotion:
        # Without this guard the Emotions root itself would be listed and a
        # sub-folder picked as the "image".
        yield "Please analyze a PSD file first to detect your dominant emotion.", None, None
        return

    img_style_pth = os.path.join(PAINTERS_BASE_DIR, painter, chosen_painting)
    print(f"img_stype_path: {img_style_pth}")

    # Content image: a random regular file from the dominant emotion's folder.
    emotion_pth = os.path.join(EMOTION_BASE_DIR, dom_emotion)
    candidates = []
    if os.path.isdir(emotion_pth):
        candidates = [
            name for name in os.listdir(emotion_pth)
            if os.path.isfile(os.path.join(emotion_pth, name))
        ]
    if not candidates:
        yield f"No images found for emotion '{dom_emotion}'.", None, None
        return
    original_image_pth = os.path.join(emotion_pth, random.choice(candidates))
    print(f"original img _path: {original_image_pth}")

    # Report progress while the (potentially slow) style transfer runs.
    yield "Generating your artwork...", None, None

    final_message = f"Art generated based on {painter}'s {chosen_painting} style!"
    ## Neural Style Transfer
    stylized_img_path = save_style(output_dir, original_image_pth, img_style_pth)
    yield gr.Textbox(final_message), original_image_pth, stylized_img_path
# --- Gradio Interface Definition ---
with gr.Blocks(css=".gradio-container { max-width: 2000px; margin: auto; }") as demo:
    # Holds the emotion-frequency DataFrame produced by the last PSD analysis.
    current_emotion_df_state = gr.State(value=pd.DataFrame())

    # Header Section
    gr.Markdown(
        """
<h1 style="text-align: center;font-size: 5em; padding: 20px; font-weight: bold;">Brain Emotion Decoder 🧠🎨</h1>
<p style="text-align: center; font-size: 1.5em; color: #555;font-weight: bold;">
Imagine seeing your deepest feelings transform into art. We decode the underlying emotions from your brain activity,
generating a personalized artwork. Discover the art of your inner self.
</p>
"""
    )

    with gr.Row():
        # Left Column: PSD selection and emotion analysis
        with gr.Column(scale=1):
            # Fixed markup: the size goes in a style attribute and the heading is closed with </h2>.
            gr.Markdown('<h2 style="font-size: 2em;">1. Choose a PSD file</h2>')
            psd_file_selection = gr.Radio(
                choices=predefined_psd_files,
                label="Select a PSD file for analysis",
                value=predefined_psd_files[0],
                interactive=True
            )
            analyze_psd_button = gr.Button("Analyze PSD File", variant="secondary")
            gr.Markdown('<h2 style="font-size: 2em;">2. Emotion Distribution</h2>')
            emotion_distribution_plot = gr.BarPlot(
                dummy_emotion_data,
                x="Emotion",
                y="Value",
                label="Emotion Distribution",
                height=300,
                x_title="Emotion Type",
                y_title="Frequency",
                visible=False
            )
            # Hidden holder for the dominant emotion; it is fed to the generate handler.
            dom_emotion = gr.Textbox(label="dominant emotion", visible=False)

        # Right Column: Art Museum and Generation
        with gr.Column(scale=1):
            gr.Markdown("<h3>Your Art Museum</h3>")
            gr.Markdown("<h3>3. Choose your favourite painter</h3>")
            painter_dropdown = gr.Dropdown(
                choices=painters,
                value="Pablo Picasso",  # Default selection
                label="Select a Painter"
            )
            gr.Markdown("<h3>4. Choose your favourite painting</h3>")
            painting_gallery = gr.Gallery(
                value=update_paintings("Pablo Picasso"),  # Initial load for Picasso's paintings
                label="Select a Painting",
                height=300,
                columns=3,
                rows=1,
                object_fit="contain",
                preview=True,
                interactive=True,
                elem_id="painting_gallery",
                visible=True,
            )
            # Hidden holder for the file name of the painting clicked in the gallery.
            selected_painting_name = gr.Textbox(visible=False)
            generate_button = gr.Button("Generate My Art", variant="primary", scale=0)
            status_message = gr.Textbox(
                value="Click 'Generate My Art' to begin.",
                label="Generation Status",
                interactive=False,
                show_label=False,
                lines=1
            )

    gr.Markdown(
        """
<h1 style="text-align: center;">Your Generated Artwork</h1>
<p style="text-align: center; color: #555;">
Once your brain's emotional data is processed, we pinpoint the <b>dominant emotion</b>. This single feeling inspires a <b>personalized artwork</b>. You can then <b>download</b> this unique visual representation of your inner self.
</p>
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("<h3>Generated Image</h3>")
            generated_image_output = gr.Image(label="Generated Image", show_label=False, height=300)
            gr.Markdown("<h3>Blended Style Image</h3>")
            blended_image_output = gr.Image(label="Blended Style Image", show_label=False, height=300)

    # --- Event Listeners ---
    analyze_psd_button.click(
        upload_psd_file,
        inputs=[psd_file_selection],
        outputs=[emotion_distribution_plot, current_emotion_df_state, dom_emotion]
    )
    painter_dropdown.change(
        update_paintings,
        inputs=[painter_dropdown],
        outputs=[painting_gallery]
    )
    # A painting chosen for the previous painter lives in that painter's
    # folder. Clear it on painter change so generate_my_art's empty-selection
    # check asks for a new pick instead of building a non-existent path.
    painter_dropdown.change(
        lambda _painter: "",
        inputs=[painter_dropdown],
        outputs=[selected_painting_name]
    )

    def on_select(evt: gr.SelectData):
        """Return the original file name of the gallery image the user clicked."""
        print("this function started")
        print(f"Image index: {evt.index}\nImage value: {evt.value['image']['orig_name']}")
        return evt.value['image']['orig_name']

    painting_gallery.select(
        on_select,
        outputs=[selected_painting_name]
    )
    generate_button.click(
        generate_my_art,
        inputs=[painter_dropdown, selected_painting_name, dom_emotion],
        outputs=[status_message, generated_image_output, blended_image_output]
    )
# Launch the Gradio app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()