File size: 10,788 Bytes
3b6d764
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
import random
import gradio as gr
import pandas as pd
import numpy as np
from Src.Processing import load_data
from Src.Processing import process_data
from Src.Inference import load_model
from Src.NST_Inference import save_style
import torch
import time 
import os 
import mne
import matplotlib.pyplot as plt
import io
import matplotlib.cm as cm
import gradio as gr


# Placeholder chart data: used to initialise the bar plot before any
# PSD file has been analysed.
dummy_emotion_data = pd.DataFrame(
    [
        ('sad', 0.8),
        ('dis', 0.6),
        ('fear', 0.1),
        ('neu', 0.4),
        ('joy', 0.7),
        ('ten', 0.2),
        ('ins', 0.3),
    ],
    columns=['Emotion', 'Value'],
)

# Class index -> emotion abbreviation.
int_to_emotion = dict(enumerate(['sad', 'dis', 'fear', 'neu', 'joy', 'ten', 'ins']))

# Emotion abbreviation -> full emotion name.
abr_to_emotion = dict(
    sad="sadness",
    dis="disgust",
    fear="fear",
    neu="neutral",
    joy="joy",
    ten='Tenderness',
    ins="inspiration",
)

# Directory layout.
PAINTERS_BASE_DIR = "Painters"
EMOTION_BASE_DIR = "Emotions"
output_dir = "outputs"

# Arguments handed to load_model().
input_size, hidden_size, output_size, num_layers = 320, 50, 7, 1

painters = ["Pablo Picasso", "Vincent van Gogh", "Salvador Dalí"]
predefined_psd_files = [f"task-emotion_psd_{n}.npy" for n in (1, 2, 3)]
Base_Dir = "Datasets"

# Painter -> list of (image filename, display title) pairs.
PAINTER_PLACEHOLDER_DATA = {
    artist: [(f"{title}.png", title) for title in titles]
    for artist, titles in {
        "Pablo Picasso": (
            "Dora Maar with Cat (1941)",
            "The Weeping Woman (1937)",
            "Three Musicians (1921)",
        ),
        "Vincent van Gogh": (
            "Sunflowers (1888)",
            "The Starry Night (1889)",
            "The Potato Eaters (1885)",
        ),
        "Salvador Dalí": (
            "Persistence of Memory (1931)",
            "Swans Reflecting Elephants (1937)",
            "Sleep (1937)",
        ),
    }.items()
}

def _hidden_emotion_outputs(plot_label="Emotion Distribution"):
    """Return the (plot, dataframe, textbox) triple used when there is no result to show."""
    hidden_plot = gr.BarPlot(dummy_emotion_data, x="Emotion", y="Value", label=plot_label, visible=False)
    # An empty, hidden textbox also clears any dominant emotion left over
    # from a previous successful analysis.
    return hidden_plot, pd.DataFrame(), gr.Textbox(value="", visible=False)


def upload_psd_file(selected_file_name):
    """
    Processes a selected PSD file, performs inference, and prepares emotion distribution data.

    Args:
        selected_file_name: Name of a file inside ``Base_Dir``, or ``None``
            when nothing is selected.

    Returns:
        A 3-tuple matching the three outputs wired to the "Analyze" button:
        ``(bar plot, emotion-frequency DataFrame, dominant-emotion textbox)``.
        When no file is selected or the file is missing, the plot and textbox
        are hidden and the DataFrame is empty.
    """
    # NOTE(review): the loaded array is still published as a module-level
    # global to keep the existing side effect; nothing in this function needs it.
    global np_data

    if selected_file_name is None:
        return _hidden_emotion_outputs()

    psd_file_path = os.path.join(Base_Dir, selected_file_name).replace(os.sep, '/')

    try:
        np_data = load_data(psd_file_path)
    except FileNotFoundError:
        print(f"Error: PSD file not found at {psd_file_path}")
        return _hidden_emotion_outputs("Emotion Distribution (Error: File not found)")
    print(f"np data orig {np_data.shape}")

    final_data = process_data(np_data)
    # unsqueeze(0) adds a leading batch dimension before calling the model.
    torch_data = torch.tensor(final_data, dtype=torch.float32).unsqueeze(0)
    model_path = os.path.join("models", "lstm_emotion_model_state.pth")
    loaded_model = load_model(model_path, input_size, hidden_size, output_size, num_layers)
    loaded_model.eval()
    with torch.no_grad():
        predicted_logits, _ = loaded_model(torch_data)

    # Most likely class per position along dim 2, flattened to one index vector.
    all_predicted_indices = torch.argmax(predicted_logits, dim=2).view(-1)

    # Count occurrences of each predicted emotion index; minlength guarantees
    # an entry for every class even if it was never predicted.
    values_count = torch.bincount(all_predicted_indices, minlength=output_size)
    print(f"Raw bincount: {values_count}")
    emotions_count = {
        int_to_emotion[idx].strip(): count
        for idx, count in enumerate(values_count[:output_size].tolist())
    }
    # On ties, max() keeps the first emotion in index order.
    dom_emotion = max(emotions_count, key=emotions_count.get)

    emotion_data = pd.DataFrame({
        "Emotion": list(emotions_count.keys()),
        "Frequency": list(emotions_count.values())
    })
    emotion_data = emotion_data.sort_values(by="Emotion").reset_index(drop=True)
    print(f"Final emotion_data DataFrame:\n{emotion_data}")

    return gr.BarPlot(
        emotion_data,
        x="Emotion",
        y="Frequency",
        label="Emotion Distribution",
        visible=True,
        y_title="Frequency"
    ), emotion_data, gr.Textbox(abr_to_emotion[dom_emotion], visible=True)


def update_paintings(painter_name):
    """
    Build the gallery entries for the selected painter.

    Scans the painter's sub-directory of ``PAINTERS_BASE_DIR`` and returns a
    list of ``(file_path, title)`` tuples, one per image file
    (.png/.jpg/.jpeg/.gif, case-insensitive), ordered by filename. The title
    is the filename without its extension. A missing directory yields an
    empty list.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', '.gif')
    folder = os.path.join(PAINTERS_BASE_DIR, painter_name).replace(os.sep, '/')
    gallery_items = []

    entries = sorted(os.listdir(folder)) if os.path.isdir(folder) else []
    for entry in entries:
        if not entry.lower().endswith(image_suffixes):
            continue
        # Forward slashes keep the path format the same on every platform.
        image_path = os.path.join(folder, entry).replace(os.sep, '/')
        print(image_path)
        caption, _extension = os.path.splitext(entry)
        gallery_items.append((image_path, caption))

    print(f"Loaded paintings for {painter_name}: {gallery_items}")
    return gallery_items


def generate_my_art(painter, chosen_painting, dom_emotion):
    """
    Pick a base image for the dominant emotion and restyle it after a painting.

    This is a generator: every item it yields is a
    ``(status, original_image_path, stylized_image_path)`` triple, matching
    the three outputs wired to the "Generate My Art" button.

    Args:
        painter: Painter name; also the sub-directory of ``PAINTERS_BASE_DIR``.
        chosen_painting: Filename of the style painting inside that directory.
        dom_emotion: Emotion name; also the sub-directory of
            ``EMOTION_BASE_DIR`` that holds candidate base images.

    Yields:
        On invalid input or when no base image is available, a single triple
        with an explanatory message and ``None`` for both images. Otherwise a
        single triple with the status textbox, the chosen base image path and
        the path returned by ``save_style``.
    """
    # Because this function contains ``yield``, a ``return <value>`` would be
    # swallowed by the generator protocol and never reach the UI, so every
    # early exit yields its message first and then returns.
    if not painter or not chosen_painting:
        yield "Please select a painter and a painting.", None, None
        return
    if not dom_emotion:
        # Without an emotion the join below would point at the base directory itself.
        yield "Please analyze a PSD file first to detect the dominant emotion.", None, None
        return

    img_style_pth = os.path.join(PAINTERS_BASE_DIR, painter, chosen_painting)
    print(f"img_stype_path: {img_style_pth}")

    ##original image
    emotion_pth = os.path.join(EMOTION_BASE_DIR, dom_emotion)
    try:
        candidate_images = os.listdir(emotion_pth)
    except (FileNotFoundError, NotADirectoryError):
        candidate_images = []
    if not candidate_images:
        yield f"No base images found for emotion '{dom_emotion}'.", None, None
        return
    image_name = random.choice(candidate_images)
    original_image_pth = os.path.join(emotion_pth, image_name)
    print(f"original img _path: {original_image_pth}")

    final_message = f"Art generated based on {painter}'s {chosen_painting} style!"
    ## Neural Style Transfer
    stylized_img_path = save_style(output_dir, original_image_pth, img_style_pth)
    yield gr.Textbox(final_message), original_image_pth, stylized_img_path

# --- Gradio Interface Definition ---

with gr.Blocks(css=".gradio-container { max-width: 2000px; margin: auto; }") as demo:
    # Receives the emotion-frequency table returned by upload_psd_file.
    current_emotion_df_state = gr.State(value=pd.DataFrame())
    # Header Section
    gr.Markdown(
        """
        <h1 style="text-align: center;font-size: 5em; padding: 20px;  font-weight: bold;">Brain Emotion Decoder 🧠🎨</h1>
        <p style="text-align: center; font-size: 1.5em; color: #555;font-weight: bold;">
        Imagine seeing your deepest feelings transform into art. We decode the underlying emotions from your brain activity,
        generating a personalized artwork. Discover the art of your inner self.
        </p>
        """
    )

    with gr.Row():
        # Left Column: PSD selection and emotion analysis
        with gr.Column(scale=1):
            gr.Markdown('<h2 style="font-size: 2em;">1. Choose a PSD file</h2>')
            psd_file_selection = gr.Radio(
                choices=predefined_psd_files,
                label="Select a PSD file for analysis",
                value=predefined_psd_files[0],
                interactive=True
            )

            analyze_psd_button = gr.Button("Analyze PSD File", variant="secondary")

            gr.Markdown('<h2 style="font-size: 2em;">2. Emotion Distribution</h2>')

            # Hidden until an analysis succeeds; upload_psd_file toggles visibility.
            emotion_distribution_plot = gr.BarPlot(
                dummy_emotion_data,
                x="Emotion",
                y="Value",
                label="Emotion Distribution",
                height=300,
                x_title="Emotion Type",
                y_title="Frequency",
                visible=False
            )
            dom_emotion = gr.Textbox(label = "dominant emotion", visible=False)

        # Right Column: Art Museum and Generation
        with gr.Column(scale=1):
            gr.Markdown("<h3>Your Art Museum</h3>")
            gr.Markdown("<h3>3. Choose your favourite painter</h3>")
            painter_dropdown = gr.Dropdown(
                choices=painters,
                value="Pablo Picasso", # Default selection
                label="Select a Painter"
            )
            gr.Markdown("<h3>4. Choose your favourite painting</h3>")
            painting_gallery = gr.Gallery(
                value=update_paintings("Pablo Picasso"), # Initial load for Picasso's paintings
                label="Select a Painting",
                height=300,
                columns=3,
                rows=1,
                object_fit="contain",
                preview=True,
                interactive=True,
                elem_id="painting_gallery",
                visible=True,
            )
            # Hidden field carrying the filename of the painting picked in the gallery.
            selected_painting_name = gr.Textbox(visible=False)
            generate_button = gr.Button("Generate My Art", variant="primary", scale=0)
            status_message = gr.Textbox(
                value="Click 'Generate My Art' to begin.",
                label="Generation Status",
                interactive=False,
                show_label=False,
                lines=1
            )

    gr.Markdown(
        """
        <h1 style="text-align: center;">Your Generated Artwork</h1>
        <p style="text-align: center; color: #555;">
        Once your brain's emotional data is processed, we pinpoint the <b>dominant emotion</b>. This single feeling inspires a <b>personalized artwork</b>. You can then <b>download</b> this unique visual representation of your inner self.
        </p>
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("<h3>Generated Image</h3>")
            generated_image_output = gr.Image(label="Generated Image", show_label=False, height=300)
            gr.Markdown("<h3>Blended Style Image</h3>")
            blended_image_output = gr.Image(label="Blended Style Image", show_label=False, height=300)

    # --- Event Listeners ---
    analyze_psd_button.click(
        upload_psd_file,
        inputs=[psd_file_selection],
        outputs=[emotion_distribution_plot, current_emotion_df_state, dom_emotion]
    )

    def on_painter_change(painter_name):
        """Reload the gallery for the new painter and forget the previous painting choice."""
        # Clearing the hidden field prevents a filename from the previous
        # painter being combined with the new painter's directory.
        return update_paintings(painter_name), ""

    painter_dropdown.change(
        on_painter_change,
        inputs=[painter_dropdown],
        outputs=[painting_gallery, selected_painting_name]
    )

    def on_select(evt: gr.SelectData):
        """Return the original filename of the gallery item that was clicked."""
        print("this function started")
        print(f"Image index: {evt.index}\nImage value: {evt.value['image']['orig_name']}")
        return evt.value['image']['orig_name']
    painting_gallery.select(
        on_select,
        outputs=[selected_painting_name]
    )

    generate_button.click(
        generate_my_art,
        inputs=[painter_dropdown, selected_painting_name, dom_emotion],
        outputs=[status_message, generated_image_output, blended_image_output]
    )
if __name__ == "__main__":
    demo.launch()