Ihssane123 commited on
Commit
3b6d764
·
0 Parent(s):

Initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +36 -0
  2. Datasets/task-emotion_psd_1.npy +3 -0
  3. Datasets/task-emotion_psd_2.npy +3 -0
  4. Datasets/task-emotion_psd_3.npy +3 -0
  5. Emotions/disgust/dis_1.png +3 -0
  6. Emotions/disgust/dis_2.png +3 -0
  7. Emotions/disgust/dis_3.png +3 -0
  8. Emotions/fear/fear_1.png +3 -0
  9. Emotions/fear/fear_2.png +3 -0
  10. Emotions/fear/fear_3.png +3 -0
  11. Emotions/joy/joy_1.png +3 -0
  12. Emotions/joy/joy_2.png +3 -0
  13. Emotions/joy/joy_3.png +3 -0
  14. Emotions/sad/sad_1.png +3 -0
  15. Emotions/sad/sad_2.png +3 -0
  16. Emotions/sad/sad_3.png +3 -0
  17. Models_Class/LSTMModel.py +24 -0
  18. Models_Class/__pycache__/LSTMModel.cpython-311.pyc +0 -0
  19. Models_Class/__pycache__/LSTMModel.cpython-312.pyc +0 -0
  20. Models_Class/__pycache__/NST_class.cpython-311.pyc +0 -0
  21. Painters/Pablo Picasso/Dora Maar with Cat (1941).png +3 -0
  22. Painters/Pablo Picasso/The Weeping Woman (1937).png +3 -0
  23. Painters/Pablo Picasso/Three Musicians (1921).png +3 -0
  24. Painters/Salvador Dalí/Sleep (1937).png +3 -0
  25. Painters/Salvador Dalí/Swans Reflecting Elephants (1937).png +3 -0
  26. Painters/Salvador Dalí/The Persistence of Memory (1931).png +3 -0
  27. Painters/Vincent van Gogh/Sunflowers (1888).png +3 -0
  28. Painters/Vincent van Gogh/The Potato Eaters (1885).png +3 -0
  29. Painters/Vincent van Gogh/The Starry Night (1889).png +3 -0
  30. README.md +13 -0
  31. Src/Inference.py +8 -0
  32. Src/NST_Inference.py +81 -0
  33. Src/Processing.py +17 -0
  34. Src/Processing_img.py +110 -0
  35. Src/__init__.py +0 -0
  36. Src/__pycache__/Inference.cpython-311.pyc +0 -0
  37. Src/__pycache__/Inference.cpython-312.pyc +0 -0
  38. Src/__pycache__/NST_Inference.cpython-311.pyc +0 -0
  39. Src/__pycache__/Processing.cpython-311.pyc +0 -0
  40. Src/__pycache__/Processing.cpython-312.pyc +0 -0
  41. Src/__pycache__/Processing_img.cpython-311.pyc +0 -0
  42. Src/__pycache__/__init__.cpython-311.pyc +0 -0
  43. Src/__pycache__/__init__.cpython-312.pyc +0 -0
  44. Src/__pycache__/function.cpython-311.pyc +0 -0
  45. Src/__pycache__/net.cpython-311.pyc +0 -0
  46. Src/function.py +67 -0
  47. Src/net.py +152 -0
  48. app.py +279 -0
  49. models/decoder.pth +3 -0
  50. models/lstm_emotion_model_state.pth +3 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
Datasets/task-emotion_psd_1.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:575550dab5e146dcd5d6b0d3ad3f349cf05c443e99932b494b34a299d485014d
3
+ size 1613182
Datasets/task-emotion_psd_2.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d143bcca4ddabfc9b56ab360dc844595bc230d5833a2bcf804fc6a10bc07fbb
3
+ size 1613182
Datasets/task-emotion_psd_3.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d22fe179a31082a87e27ff5846129e43eb93f5cf14ec1f191ce6a38c78f0c525
3
+ size 1613182
Emotions/disgust/dis_1.png ADDED

Git LFS Details

  • SHA256: 5ac7ab9a76de1bd20f2e97d1222abdd7acca7ec126d31dfe61830611441aeadb
  • Pointer size: 132 Bytes
  • Size of remote file: 2.47 MB
Emotions/disgust/dis_2.png ADDED

Git LFS Details

  • SHA256: ea0cb48344fb0cd177e31b98df3b770f74ed58c93e586fa83c957d6bde6d08d4
  • Pointer size: 132 Bytes
  • Size of remote file: 2.5 MB
Emotions/disgust/dis_3.png ADDED

Git LFS Details

  • SHA256: a36650f62f7094189db0235d922624c88139e670dff49b076cf6dcaaebfd5b2c
  • Pointer size: 132 Bytes
  • Size of remote file: 3.29 MB
Emotions/fear/fear_1.png ADDED

Git LFS Details

  • SHA256: d203031b3c11766f56bc6d87ae021d061e448fa0a1f863423f4d55a24d99c7db
  • Pointer size: 132 Bytes
  • Size of remote file: 2.51 MB
Emotions/fear/fear_2.png ADDED

Git LFS Details

  • SHA256: 1f6387173a0495b89d15adab188a9e5c3e68c2536f7ba09e5a39d495a63b3954
  • Pointer size: 132 Bytes
  • Size of remote file: 2.74 MB
Emotions/fear/fear_3.png ADDED

Git LFS Details

  • SHA256: a55d264eb54f9114413f85eed9d4227c01b2bffeb5ab067983592df253287491
  • Pointer size: 132 Bytes
  • Size of remote file: 3.16 MB
Emotions/joy/joy_1.png ADDED

Git LFS Details

  • SHA256: bdde85b1578148b43e157aae6e943cf332a79bf965e0614baaf9db2489e60fb1
  • Pointer size: 132 Bytes
  • Size of remote file: 2.07 MB
Emotions/joy/joy_2.png ADDED

Git LFS Details

  • SHA256: 50ba3d1a26c80fa27c0836b903e2e93397d7adb5398f0e35b0c8f46ad03bd3a2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.32 MB
Emotions/joy/joy_3.png ADDED

Git LFS Details

  • SHA256: abc5cc933536f5d5a0287b3fc48433947ab9147e3fb22b863ed2c35cf37f1255
  • Pointer size: 132 Bytes
  • Size of remote file: 3.24 MB
Emotions/sad/sad_1.png ADDED

Git LFS Details

  • SHA256: 2f8d8afd794483426c28bfdc6cbac0d503b42d4486e9230770c632fead5b6eff
  • Pointer size: 132 Bytes
  • Size of remote file: 4.72 MB
Emotions/sad/sad_2.png ADDED

Git LFS Details

  • SHA256: e1bc9975cfff5e5c6530dcd4f7c5e836a535081c7060dc4c5c3cbc79bb7fddbb
  • Pointer size: 132 Bytes
  • Size of remote file: 4.02 MB
Emotions/sad/sad_3.png ADDED

Git LFS Details

  • SHA256: 86f57852a4cf0c67e9bb608fe7984aa0f818a10dec245a28dc2d110f705df51f
  • Pointer size: 130 Bytes
  • Size of remote file: 19.5 kB
Models_Class/LSTMModel.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+
4
+ class LSTMModel(nn.Module):
5
+ ## constructor
6
+ def __init__(self, input_size, hidden_size, output_size, num_layers):
7
+ super(LSTMModel, self).__init__()
8
+ self.input_size = input_size
9
+ self.hidden_size = hidden_size
10
+ self.output_size = output_size
11
+ self.num_layers = num_layers
12
+ self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
13
+ self.fc = nn.Linear(self.hidden_size, self.output_size)
14
+
15
+ def forward(self,x, h0=None, c0=None):
16
+ # hidden and state vectors h0 and c0
17
+ if h0 is None or c0 is None:
18
+ h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
19
+ c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
20
+
21
+ out, (hn, cn) = self.lstm(x, (h0, c0))
22
+ out = self.fc(out)
23
+ return out, (hn, cn)
24
+
Models_Class/__pycache__/LSTMModel.cpython-311.pyc ADDED
Binary file (1.93 kB). View file
 
Models_Class/__pycache__/LSTMModel.cpython-312.pyc ADDED
Binary file (1.81 kB). View file
 
Models_Class/__pycache__/NST_class.cpython-311.pyc ADDED
Binary file (3.49 kB). View file
 
Painters/Pablo Picasso/Dora Maar with Cat (1941).png ADDED

Git LFS Details

  • SHA256: e3d9a1c358f10e2d5a078fd0b8c6e7360f4de921aaace30d6162524ac1189602
  • Pointer size: 130 Bytes
  • Size of remote file: 25.3 kB
Painters/Pablo Picasso/The Weeping Woman (1937).png ADDED

Git LFS Details

  • SHA256: b10f3b39125ef7c096dfd165b77faa1079774201abe1ec45b563744a4d4b8827
  • Pointer size: 130 Bytes
  • Size of remote file: 32.7 kB
Painters/Pablo Picasso/Three Musicians (1921).png ADDED

Git LFS Details

  • SHA256: f042345d98f8128fbfd8d84b3c84660fa7326418b63599efa7cfe429fd3bb16a
  • Pointer size: 131 Bytes
  • Size of remote file: 521 kB
Painters/Salvador Dalí/Sleep (1937).png ADDED

Git LFS Details

  • SHA256: 987d628b6c0ac7367b5c1b24ede65664ba32baba4f7a94264dc42cbd221c7139
  • Pointer size: 130 Bytes
  • Size of remote file: 60.5 kB
Painters/Salvador Dalí/Swans Reflecting Elephants (1937).png ADDED

Git LFS Details

  • SHA256: f71df5184f15adcdcaed509336b02edc37dd0b2daeb7049142dd61075e126148
  • Pointer size: 130 Bytes
  • Size of remote file: 27.2 kB
Painters/Salvador Dalí/The Persistence of Memory (1931).png ADDED

Git LFS Details

  • SHA256: f306a83ccc3d3f0b2e7894b92ceef114f6d5361750cb982d793acf613a0d107e
  • Pointer size: 130 Bytes
  • Size of remote file: 62 kB
Painters/Vincent van Gogh/Sunflowers (1888).png ADDED

Git LFS Details

  • SHA256: c8cd0e84094378a927dab0432c416e8174204bb98af394ceb757f369d8ef4a2c
  • Pointer size: 130 Bytes
  • Size of remote file: 94 kB
Painters/Vincent van Gogh/The Potato Eaters (1885).png ADDED

Git LFS Details

  • SHA256: d92cb55929ba1f20a552f828625e19277278161cff287f7f2f0fc6448c7eb2b5
  • Pointer size: 131 Bytes
  • Size of remote file: 135 kB
Painters/Vincent van Gogh/The Starry Night (1889).png ADDED

Git LFS Details

  • SHA256: 5f74a2599ced6f8437e2b29a786bdde2a451ff8f01d87418a14c6bbe2ce0ad6a
  • Pointer size: 131 Bytes
  • Size of remote file: 541 kB
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Brain Emotion Decoder
3
+ emoji: 🔥
4
+ colorFrom: purple
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 5.45.0
8
+ app_file: app.py
9
+ pinned: false
10
+ short_description: Decoding Emotions through Brain Signals
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
Src/Inference.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ ## Start here with the inference procedure
2
+ import torch
3
+ from Models_Class.LSTMModel import LSTMModel
4
+
5
+ def load_model(model_path, input_size, hidden_size, output_size, num_layers):
6
+ loaded_model = LSTMModel(input_size, hidden_size, output_size, num_layers)
7
+ loaded_model.load_state_dict(torch.load(model_path))
8
+ return loaded_model
Src/NST_Inference.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+ from torchvision.utils import save_image
9
+
10
+ from . import net
11
+ from .function import adaptive_instance_normalization
12
+
13
+
14
+ def test_transform(size, crop):
15
+ transform_list = []
16
+ if size != 0:
17
+ transform_list.append(transforms.Resize(size))
18
+ if crop:
19
+ transform_list.append(transforms.CenterCrop(size))
20
+ transform_list.append(transforms.ToTensor())
21
+ transform = transforms.Compose(transform_list)
22
+ return transform
23
+
24
+
25
+ def style_transfer(vgg, decoder, content, style, alpha=1.0,
26
+ interpolation_weights=None):
27
+ assert (0.0 <= alpha <= 1.0)
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+ content_f = vgg(content)
30
+ style_f = vgg(style)
31
+ if interpolation_weights:
32
+ _, C, H, W = content_f.size()
33
+ feat = torch.FloatTensor(1, C, H, W).zero_().to(device)
34
+ base_feat = adaptive_instance_normalization(content_f, style_f)
35
+ for i, w in enumerate(interpolation_weights):
36
+ feat = feat + w * base_feat[i:i + 1]
37
+ content_f = content_f[0:1]
38
+ else:
39
+ feat = adaptive_instance_normalization(content_f, style_f)
40
+ feat = feat * alpha + content_f * (1 - alpha)
41
+ return decoder(feat)
42
+
43
+
44
+ def save_style(output_dir, content_path, style_path):
45
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
+ decoder_pth = Path("models/decoder.pth")
47
+ vgg_pth = Path("models/vgg_normalised.pth")
48
+ output_dir = Path("output")
49
+ output_dir.mkdir(exist_ok=True, parents=True)
50
+ content_path = Path(content_path)
51
+ style_paths = [Path(style_path)]
52
+
53
+ decoder = net.decoder
54
+ vgg = net.vgg
55
+
56
+ decoder.eval()
57
+ vgg.eval()
58
+
59
+ decoder.load_state_dict(torch.load(decoder_pth))
60
+ vgg.load_state_dict(torch.load(vgg_pth))
61
+ vgg = nn.Sequential(*list(vgg.children())[:31])
62
+
63
+ vgg.to(device)
64
+ decoder.to(device)
65
+
66
+ content_tf = test_transform(512, True)
67
+ style_tf = test_transform(512, True)
68
+ style = torch.stack([style_tf(Image.open(str(p))) for p in style_paths])
69
+ content = content_tf(Image.open(str(content_path))) \
70
+ .unsqueeze(0).expand_as(style)
71
+ style = style.to(device)
72
+ content = content.to(device)
73
+ with torch.no_grad():
74
+ output = style_transfer(vgg, decoder, content, style,
75
+ 1, '')
76
+ output = output.cpu()
77
+ output_name = output_dir / 'stylized_output.jpg'
78
+ save_image(output, str(output_name))
79
+ return output_name
80
+
81
+
Src/Processing.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ emotion_list = ["0", "1", "2", "3", "4", "5", "6"]
4
+
5
+
6
+ def load_data(psd_file_pth):
7
+ np_data = np.load(psd_file_pth, allow_pickle=True).item()["psd"]
8
+ return np_data
9
+
10
+
11
+ def process_data(np_data):
12
+ #Swap axes
13
+ swapped_data = np.swapaxes(np_data, 0, 1)
14
+ ## reshape data
15
+ reshape_data = swapped_data.reshape(630, 320)
16
+ return reshape_data
17
+
Src/Processing_img.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch
3
+ import torch.optim as optim
4
+ from torchvision import transforms
5
+ import torch.nn as nn
6
+ from Models_Class.NST_class import (
7
+ ContentLoss,
8
+ Normalization,
9
+ StyleLoss,
10
+ )
11
+
12
+ import copy
13
+
14
+ style_weight = 1e8
15
+ content_weight = 1e1
16
+ def image_loader(image_path, loader, device):
17
+ image = Image.open(image_path).convert('RGB')
18
+ image = loader(image).unsqueeze(0)
19
+ return image.to(device, torch.float)
20
+
21
+ def save_image(tensor, path="output.png"):
22
+ image = tensor.cpu().clone()
23
+ image = image.squeeze(0)
24
+ image = transforms.ToPILImage()(image)
25
+ image.save(path)
26
+
27
+ def gram_matrix(input):
28
+ a, b, c, d = input.size()
29
+ features = input.view(a * b, c * d)
30
+ G = torch.mm(features, features.t())
31
+ return G.div(a * b * c * d)
32
+
33
+
34
+ def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
35
+ style_img, content_img, content_layers, style_layers, device):
36
+ cnn = copy.deepcopy(cnn)
37
+ normalization = Normalization(normalization_mean, normalization_std).to(device)
38
+ content_losses = []
39
+ style_losses = []
40
+ model = nn.Sequential(normalization)
41
+
42
+ i = 0
43
+ for layer in cnn.children():
44
+ if isinstance(layer, nn.Conv2d):
45
+ i += 1
46
+ name = f'conv_{i}'
47
+ elif isinstance(layer, nn.ReLU):
48
+ name = f'relu_{i}'
49
+ layer = nn.ReLU(inplace=False)
50
+ elif isinstance(layer, nn.MaxPool2d):
51
+ name = f'pool_{i}'
52
+ elif isinstance(layer, nn.BatchNorm2d):
53
+ name = f'bn_{i}'
54
+ else:
55
+ raise RuntimeError(f'Unrecognized layer: {layer.__class__.__name__}')
56
+
57
+ model.add_module(name, layer)
58
+
59
+ if name in content_layers:
60
+ target = model(content_img).detach()
61
+ content_loss = ContentLoss(target)
62
+ model.add_module(f"content_loss_{i}", content_loss)
63
+ content_losses.append(content_loss)
64
+
65
+ if name in style_layers:
66
+ target_feature = model(style_img).detach()
67
+ style_loss = StyleLoss(target_feature)
68
+ model.add_module(f"style_loss_{i}", style_loss)
69
+ style_losses.append(style_loss)
70
+
71
+ for i in range(len(model) - 1, -1, -1):
72
+ if isinstance(model[i], (ContentLoss, StyleLoss)):
73
+ break
74
+ model = model[:i+1]
75
+ return model, style_losses, content_losses
76
+
77
+
78
+
79
+ def run_style_transfer(cnn, normalization_mean, normalization_std,
80
+ content_img, style_img, input_img,content_layers, style_layers, device, num_steps=300):
81
+ print("Building the style transfer model..")
82
+ model, style_losses, content_losses = get_style_model_and_losses(cnn, normalization_mean, normalization_std,
83
+ style_img, content_img,content_layers, style_layers, device )
84
+ optimizer = optim.LBFGS([input_img.requires_grad_()])
85
+
86
+ print("Optimizing..")
87
+ run = [0]
88
+ while run[0] <= num_steps:
89
+ def closure():
90
+ input_img.data.clamp_(0, 1)
91
+ optimizer.zero_grad()
92
+ model(input_img)
93
+ style_score = sum(sl.loss for sl in style_losses)
94
+ content_score = sum(cl.loss for cl in content_losses)
95
+ loss = style_weight * style_score + content_weight * content_score
96
+ loss.backward()
97
+
98
+ if run[0] % 50 == 0:
99
+ print(f"Step {run[0]}:")
100
+ print(f" Style Loss: {style_score.item():.4f}")
101
+ print(f" Content Loss: {content_score.item():.4f}")
102
+ print(f" Total Loss: {loss.item():.4f}\n")
103
+
104
+ run[0] += 1
105
+ return loss
106
+
107
+ optimizer.step(closure)
108
+
109
+ input_img.data.clamp_(0, 1)
110
+ return input_img
Src/__init__.py ADDED
File without changes
Src/__pycache__/Inference.cpython-311.pyc ADDED
Binary file (672 Bytes). View file
 
Src/__pycache__/Inference.cpython-312.pyc ADDED
Binary file (605 Bytes). View file
 
Src/__pycache__/NST_Inference.cpython-311.pyc ADDED
Binary file (5.49 kB). View file
 
Src/__pycache__/Processing.cpython-311.pyc ADDED
Binary file (901 Bytes). View file
 
Src/__pycache__/Processing.cpython-312.pyc ADDED
Binary file (825 Bytes). View file
 
Src/__pycache__/Processing_img.cpython-311.pyc ADDED
Binary file (7.34 kB). View file
 
Src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (181 Bytes). View file
 
Src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (169 Bytes). View file
 
Src/__pycache__/function.cpython-311.pyc ADDED
Binary file (4.68 kB). View file
 
Src/__pycache__/net.cpython-311.pyc ADDED
Binary file (9.5 kB). View file
 
Src/function.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def calc_mean_std(feat, eps=1e-5):
5
+ # eps is a small value added to the variance to avoid divide-by-zero.
6
+ size = feat.size()
7
+ assert (len(size) == 4)
8
+ N, C = size[:2]
9
+ feat_var = feat.view(N, C, -1).var(dim=2) + eps
10
+ feat_std = feat_var.sqrt().view(N, C, 1, 1)
11
+ feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
12
+ return feat_mean, feat_std
13
+
14
+
15
+ def adaptive_instance_normalization(content_feat, style_feat):
16
+ assert (content_feat.size()[:2] == style_feat.size()[:2])
17
+ size = content_feat.size()
18
+ style_mean, style_std = calc_mean_std(style_feat)
19
+ content_mean, content_std = calc_mean_std(content_feat)
20
+
21
+ normalized_feat = (content_feat - content_mean.expand(
22
+ size)) / content_std.expand(size)
23
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
24
+
25
+
26
+ def _calc_feat_flatten_mean_std(feat):
27
+ # takes 3D feat (C, H, W), return mean and std of array within channels
28
+ assert (feat.size()[0] == 3)
29
+ assert (isinstance(feat, torch.FloatTensor))
30
+ feat_flatten = feat.view(3, -1)
31
+ mean = feat_flatten.mean(dim=-1, keepdim=True)
32
+ std = feat_flatten.std(dim=-1, keepdim=True)
33
+ return feat_flatten, mean, std
34
+
35
+
36
+ def _mat_sqrt(x):
37
+ U, D, V = torch.svd(x)
38
+ return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t())
39
+
40
+
41
+ def coral(source, target):
42
+ # assume both source and target are 3D array (C, H, W)
43
+ # Note: flatten -> f
44
+
45
+ source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source)
46
+ source_f_norm = (source_f - source_f_mean.expand_as(
47
+ source_f)) / source_f_std.expand_as(source_f)
48
+ source_f_cov_eye = \
49
+ torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3)
50
+
51
+ target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target)
52
+ target_f_norm = (target_f - target_f_mean.expand_as(
53
+ target_f)) / target_f_std.expand_as(target_f)
54
+ target_f_cov_eye = \
55
+ torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3)
56
+
57
+ source_f_norm_transfer = torch.mm(
58
+ _mat_sqrt(target_f_cov_eye),
59
+ torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)),
60
+ source_f_norm)
61
+ )
62
+
63
+ source_f_transfer = source_f_norm_transfer * \
64
+ target_f_std.expand_as(source_f_norm) + \
65
+ target_f_mean.expand_as(source_f_norm)
66
+
67
+ return source_f_transfer.view(source.size())
Src/net.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ from .function import adaptive_instance_normalization as adain
4
+ from .function import calc_mean_std
5
+
6
+ decoder = nn.Sequential(
7
+ nn.ReflectionPad2d((1, 1, 1, 1)),
8
+ nn.Conv2d(512, 256, (3, 3)),
9
+ nn.ReLU(),
10
+ nn.Upsample(scale_factor=2, mode='nearest'),
11
+ nn.ReflectionPad2d((1, 1, 1, 1)),
12
+ nn.Conv2d(256, 256, (3, 3)),
13
+ nn.ReLU(),
14
+ nn.ReflectionPad2d((1, 1, 1, 1)),
15
+ nn.Conv2d(256, 256, (3, 3)),
16
+ nn.ReLU(),
17
+ nn.ReflectionPad2d((1, 1, 1, 1)),
18
+ nn.Conv2d(256, 256, (3, 3)),
19
+ nn.ReLU(),
20
+ nn.ReflectionPad2d((1, 1, 1, 1)),
21
+ nn.Conv2d(256, 128, (3, 3)),
22
+ nn.ReLU(),
23
+ nn.Upsample(scale_factor=2, mode='nearest'),
24
+ nn.ReflectionPad2d((1, 1, 1, 1)),
25
+ nn.Conv2d(128, 128, (3, 3)),
26
+ nn.ReLU(),
27
+ nn.ReflectionPad2d((1, 1, 1, 1)),
28
+ nn.Conv2d(128, 64, (3, 3)),
29
+ nn.ReLU(),
30
+ nn.Upsample(scale_factor=2, mode='nearest'),
31
+ nn.ReflectionPad2d((1, 1, 1, 1)),
32
+ nn.Conv2d(64, 64, (3, 3)),
33
+ nn.ReLU(),
34
+ nn.ReflectionPad2d((1, 1, 1, 1)),
35
+ nn.Conv2d(64, 3, (3, 3)),
36
+ )
37
+
38
+ vgg = nn.Sequential(
39
+ nn.Conv2d(3, 3, (1, 1)),
40
+ nn.ReflectionPad2d((1, 1, 1, 1)),
41
+ nn.Conv2d(3, 64, (3, 3)),
42
+ nn.ReLU(), # relu1-1
43
+ nn.ReflectionPad2d((1, 1, 1, 1)),
44
+ nn.Conv2d(64, 64, (3, 3)),
45
+ nn.ReLU(), # relu1-2
46
+ nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
47
+ nn.ReflectionPad2d((1, 1, 1, 1)),
48
+ nn.Conv2d(64, 128, (3, 3)),
49
+ nn.ReLU(), # relu2-1
50
+ nn.ReflectionPad2d((1, 1, 1, 1)),
51
+ nn.Conv2d(128, 128, (3, 3)),
52
+ nn.ReLU(), # relu2-2
53
+ nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
54
+ nn.ReflectionPad2d((1, 1, 1, 1)),
55
+ nn.Conv2d(128, 256, (3, 3)),
56
+ nn.ReLU(), # relu3-1
57
+ nn.ReflectionPad2d((1, 1, 1, 1)),
58
+ nn.Conv2d(256, 256, (3, 3)),
59
+ nn.ReLU(), # relu3-2
60
+ nn.ReflectionPad2d((1, 1, 1, 1)),
61
+ nn.Conv2d(256, 256, (3, 3)),
62
+ nn.ReLU(), # relu3-3
63
+ nn.ReflectionPad2d((1, 1, 1, 1)),
64
+ nn.Conv2d(256, 256, (3, 3)),
65
+ nn.ReLU(), # relu3-4
66
+ nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
67
+ nn.ReflectionPad2d((1, 1, 1, 1)),
68
+ nn.Conv2d(256, 512, (3, 3)),
69
+ nn.ReLU(), # relu4-1, this is the last layer used
70
+ nn.ReflectionPad2d((1, 1, 1, 1)),
71
+ nn.Conv2d(512, 512, (3, 3)),
72
+ nn.ReLU(), # relu4-2
73
+ nn.ReflectionPad2d((1, 1, 1, 1)),
74
+ nn.Conv2d(512, 512, (3, 3)),
75
+ nn.ReLU(), # relu4-3
76
+ nn.ReflectionPad2d((1, 1, 1, 1)),
77
+ nn.Conv2d(512, 512, (3, 3)),
78
+ nn.ReLU(), # relu4-4
79
+ nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
80
+ nn.ReflectionPad2d((1, 1, 1, 1)),
81
+ nn.Conv2d(512, 512, (3, 3)),
82
+ nn.ReLU(), # relu5-1
83
+ nn.ReflectionPad2d((1, 1, 1, 1)),
84
+ nn.Conv2d(512, 512, (3, 3)),
85
+ nn.ReLU(), # relu5-2
86
+ nn.ReflectionPad2d((1, 1, 1, 1)),
87
+ nn.Conv2d(512, 512, (3, 3)),
88
+ nn.ReLU(), # relu5-3
89
+ nn.ReflectionPad2d((1, 1, 1, 1)),
90
+ nn.Conv2d(512, 512, (3, 3)),
91
+ nn.ReLU() # relu5-4
92
+ )
93
+
94
+
95
+ class Net(nn.Module):
96
+ def __init__(self, encoder, decoder):
97
+ super(Net, self).__init__()
98
+ enc_layers = list(encoder.children())
99
+ self.enc_1 = nn.Sequential(*enc_layers[:4]) # input -> relu1_1
100
+ self.enc_2 = nn.Sequential(*enc_layers[4:11]) # relu1_1 -> relu2_1
101
+ self.enc_3 = nn.Sequential(*enc_layers[11:18]) # relu2_1 -> relu3_1
102
+ self.enc_4 = nn.Sequential(*enc_layers[18:31]) # relu3_1 -> relu4_1
103
+ self.decoder = decoder
104
+ self.mse_loss = nn.MSELoss()
105
+
106
+ # fix the encoder
107
+ for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']:
108
+ for param in getattr(self, name).parameters():
109
+ param.requires_grad = False
110
+
111
+ # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image
112
+ def encode_with_intermediate(self, input):
113
+ results = [input]
114
+ for i in range(4):
115
+ func = getattr(self, 'enc_{:d}'.format(i + 1))
116
+ results.append(func(results[-1]))
117
+ return results[1:]
118
+
119
+ # extract relu4_1 from input image
120
+ def encode(self, input):
121
+ for i in range(4):
122
+ input = getattr(self, 'enc_{:d}'.format(i + 1))(input)
123
+ return input
124
+
125
+ def calc_content_loss(self, input, target):
126
+ assert (input.size() == target.size())
127
+ assert (target.requires_grad is False)
128
+ return self.mse_loss(input, target)
129
+
130
+ def calc_style_loss(self, input, target):
131
+ assert (input.size() == target.size())
132
+ assert (target.requires_grad is False)
133
+ input_mean, input_std = calc_mean_std(input)
134
+ target_mean, target_std = calc_mean_std(target)
135
+ return self.mse_loss(input_mean, target_mean) + \
136
+ self.mse_loss(input_std, target_std)
137
+
138
+ def forward(self, content, style, alpha=1.0):
139
+ assert 0 <= alpha <= 1
140
+ style_feats = self.encode_with_intermediate(style)
141
+ content_feat = self.encode(content)
142
+ t = adain(content_feat, style_feats[-1])
143
+ t = alpha * t + (1 - alpha) * content_feat
144
+
145
+ g_t = self.decoder(t)
146
+ g_t_feats = self.encode_with_intermediate(g_t)
147
+
148
+ loss_c = self.calc_content_loss(g_t_feats[-1], t)
149
+ loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
150
+ for i in range(1, 4):
151
+ loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i])
152
+ return loss_c, loss_s
app.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import gradio as gr
3
+ import pandas as pd
4
+ import numpy as np
5
+ from Src.Processing import load_data
6
+ from Src.Processing import process_data
7
+ from Src.Inference import load_model
8
+ from Src.NST_Inference import save_style
9
+ import torch
10
+ import time
11
+ import os
12
+ import mne
13
+ import matplotlib.pyplot as plt
14
+ import io
15
+ import matplotlib.cm as cm
16
+ import gradio as gr
17
+
18
+
19
+ dummy_emotion_data = pd.DataFrame({
20
+ 'Emotion': ['sad', 'dis', 'fear', 'neu', 'joy', 'ten', 'ins'],
21
+ 'Value': [0.8, 0.6, 0.1, 0.4, 0.7, 0.2, 0.3]
22
+ })
23
+
24
+ int_to_emotion = {
25
+ 0: 'sad',
26
+ 1: 'dis',
27
+ 2: 'fear',
28
+ 3: 'neu',
29
+ 4: 'joy',
30
+ 5: 'ten',
31
+ 6: 'ins'
32
+ }
33
+
34
+ abr_to_emotion = {
35
+ 'sad': "sadness",
36
+ 'dis': "disgust",
37
+ 'fear': "fear",
38
+ 'neu': "neutral",
39
+ 'joy': "joy",
40
+ 'ten': 'Tenderness',
41
+ 'ins': "inspiration"
42
+ }
43
+
44
+ PAINTERS_BASE_DIR = "Painters"
45
+ EMOTION_BASE_DIR = "Emotions"
46
+ output_dir = "outputs"
47
+ input_size = 320
48
+ hidden_size=50
49
+ output_size = 7
50
+ num_layers=1
51
+
52
+ painters = ["Pablo Picasso", "Vincent van Gogh", "Salvador Dalí"]
53
+ predefined_psd_files = ["task-emotion_psd_1.npy", "task-emotion_psd_2.npy", "task-emotion_psd_3.npy"]
54
+ Base_Dir = "Datasets"
55
+
56
+ PAINTER_PLACEHOLDER_DATA = {
57
+ "Pablo Picasso": [
58
+ ("Dora Maar with Cat (1941).png", "Dora Maar with Cat (1941)"),
59
+ ("The Weeping Woman (1937).png", "The Weeping Woman (1937)"),
60
+ ("Three Musicians (1921).png", "Three Musicians (1921)"),
61
+ ],
62
+ "Vincent van Gogh": [
63
+ ("Sunflowers (1888).png", "Sunflowers (1888)"),
64
+ ("The Starry Night (1889).png", "The Starry Night (1889)"),
65
+ ("The Potato Eaters (1885).png", "The Potato Eaters (1885)"),
66
+ ],
67
+ "Salvador Dalí": [
68
+ ("Persistence of Memory (1931).png", "Persistence of Memory (1931)"),
69
+ ("Swans Reflecting Elephants (1937).png", "Swans Reflecting Elephants (1937)"),
70
+ ("Sleep (1937).png", "Sleep (1937)"),
71
+ ],
72
+ }
73
+
74
+ def upload_psd_file(selected_file_name):
75
+ """
76
+ Processes a selected PSD file, performs inference, and prepares emotion distribution data.
77
+ """
78
+ if selected_file_name is None:
79
+ return gr.BarPlot(dummy_emotion_data, x="Emotion", y="Value", label="Emotion Distribution", visible=False), pd.DataFrame()
80
+
81
+ psd_file_path = os.path.join(Base_Dir, selected_file_name).replace(os.sep, '/')
82
+
83
+ try:
84
+ global np_data
85
+ np_data = load_data(psd_file_path)
86
+ print(f"np data orig {np_data.shape}")
87
+ except FileNotFoundError:
88
+ print(f"Error: PSD file not found at {psd_file_path}")
89
+ # Return a plot with error message or just hide it
90
+ return gr.BarPlot(dummy_emotion_data, x="Emotion", y="Value", label="Emotion Distribution (Error: File not found)", visible=False), pd.DataFrame()
91
+
92
+
93
+ final_data = process_data(np_data)
94
+ torch_data = torch.tensor(final_data, dtype=torch.float32).unsqueeze(0)
95
+ absolute_model_path = os.path.join("models", "lstm_emotion_model_state.pth")
96
+ loaded_model = load_model(absolute_model_path, input_size, hidden_size, output_size, num_layers)
97
+ loaded_model.eval()
98
+ with torch.no_grad():
99
+ predicted_logits, _ = loaded_model(torch_data)
100
+ final_output_indices = torch.argmax(predicted_logits, dim=2)
101
+ all_predicted_indices = final_output_indices.view(-1)
102
+
103
+ # Count occurrences of each predicted emotion index
104
+ values_count = torch.bincount(all_predicted_indices, minlength=output_size)
105
+ print(f"Raw bincount: {values_count}")
106
+ emotions_count = {int_to_emotion[i].strip(): 0 for i in range(output_size)}
107
+ for idx, count in enumerate(values_count):
108
+ if idx < output_size:
109
+ emotions_count[int_to_emotion[idx].strip()] = count.item()
110
+ dom_emotion = max(emotions_count, key=emotions_count.get)
111
+ emotion_data = pd.DataFrame({
112
+ "Emotion": list(emotions_count.keys()),
113
+ "Frequency": list(emotions_count.values())
114
+ })
115
+ emotion_data = emotion_data.sort_values(by="Emotion").reset_index(drop=True)
116
+ print(f"Final emotion_data DataFrame:\n{emotion_data}")
117
+
118
+ return gr.BarPlot(
119
+ emotion_data,
120
+ x="Emotion",
121
+ y="Frequency",
122
+ label="Emotion Distribution",
123
+ visible=True,
124
+ y_title="Frequency"
125
+ ), emotion_data, gr.Textbox(abr_to_emotion[dom_emotion], visible=True)
126
+
127
+
128
+ def update_paintings(painter_name):
129
+ """
130
+ Updates the gallery with paintings specific to the selected painter by
131
+ dynamically listing files in the painter's directory.
132
+ """
133
+ painter_dir = os.path.join(PAINTERS_BASE_DIR, painter_name).replace(os.sep, '/')
134
+ artist_paintings_for_gallery = []
135
+ if os.path.isdir(painter_dir):
136
+ for filename in sorted(os.listdir(painter_dir)):
137
+ if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.gif')):
138
+ file_path = os.path.join(painter_dir, filename).replace(os.sep, '/')
139
+ print(file_path)
140
+ title_with_ext = os.path.splitext(filename)[0]
141
+ artist_paintings_for_gallery.append((file_path, title_with_ext))
142
+ print(f"Loaded paintings for {painter_name}: {artist_paintings_for_gallery}")
143
+ return artist_paintings_for_gallery
144
+
145
+
146
+ def generate_my_art(painter, chosen_painting, dom_emotion):
147
+ if not painter or not chosen_painting:
148
+ return "Please select a painter and a painting.", None, None
149
+ img_style_pth = os.path.join(PAINTERS_BASE_DIR, painter, chosen_painting)
150
+ print(f"img_stype_path: {img_style_pth}")
151
+ time.sleep(3)
152
+ ##original image
153
+ emotion_pth = os.path.join(EMOTION_BASE_DIR, dom_emotion)
154
+ image_name = list(os.listdir(emotion_pth))[random.randint(0, len(os.listdir(emotion_pth)) -1)]
155
+ original_image_pth = os.path.join(emotion_pth, image_name)
156
+ print(f"original img _path: {original_image_pth}")
157
+ final_message = f"Art generated based on {painter}'s {chosen_painting} style!"
158
+ ## Neural Style Transfer
159
+ stylized_img_path = save_style(output_dir, original_image_pth, img_style_pth)
160
+ yield gr.Textbox(final_message), original_image_pth, stylized_img_path
161
+
162
+ # --- Gradio Interface Definition ---
163
+
164
+ with gr.Blocks(css=".gradio-container { max-width: 2000px; margin: auto; }") as demo:
165
+ current_emotion_df_state = gr.State(value=pd.DataFrame())
166
+ # Header Section
167
+ gr.Markdown(
168
+ """
169
+ <h1 style="text-align: center;font-size: 5em; padding: 20px; font-weight: bold;">Brain Emotion Decoder 🧠🎨</h1>
170
+ <p style="text-align: center; font-size: 1.5em; color: #555;font-weight: bold;">
171
+ Imagine seeing your deepest feelings transform into art. We decode the underlying emotions from your brain activity,
172
+ generating a personalized artwork. Discover the art of your inner self.
173
+ </p>
174
+ """
175
+ )
176
+
177
+ with gr.Row():
178
+ with gr.Column(scale=1):
179
+ gr.Markdown("<h2 font-size: 2em;>1. Choose a PSD file<h2>")
180
+ psd_file_selection = gr.Radio(
181
+ choices=predefined_psd_files,
182
+ label="Select a PSD file for analysis",
183
+ value=predefined_psd_files[0],
184
+ interactive=True
185
+ )
186
+
187
+ analyze_psd_button = gr.Button("Analyze PSD File", variant="secondary")
188
+
189
+ gr.Markdown("<h2 font-size: 2em;>2. Emotion Distribution<h2>")
190
+
191
+ emotion_distribution_plot = gr.BarPlot(
192
+ dummy_emotion_data,
193
+ x="Emotion",
194
+ y="Value",
195
+ label="Emotion Distribution",
196
+ height=300,
197
+ x_title="Emotion Type",
198
+ y_title="Frequency",
199
+ visible=False
200
+ )
201
+ dom_emotion = gr.Textbox(label = "dominant emotion", visible=False)
202
+
203
+ # Right Column: Art Museum and Generation
204
+ with gr.Column(scale=1):
205
+ gr.Markdown("<h3>Your Art Mesum</h3>") # Kept original heading
206
+ gr.Markdown("<h3>3. Choose your favourite painter</h3>")
207
+ painter_dropdown = gr.Dropdown(
208
+ choices=painters,
209
+ value="Pablo Picasso", # Default selection
210
+ label="Select a Painter"
211
+ )
212
+ gr.Markdown("<h3>4. Choose your favourite painting</h3>")
213
+ painting_gallery = gr.Gallery(
214
+ value=update_paintings("Pablo Picasso"), # Initial load for Picasso's paintings
215
+ label="Select a Painting",
216
+ height=300,
217
+ columns=3,
218
+ rows=1,
219
+ object_fit="contain",
220
+ preview=True,
221
+ interactive=True,
222
+ elem_id="painting_gallery",
223
+ visible=True,
224
+ )
225
+ selected_painting_name = gr.Textbox(visible=False)
226
+ generate_button = gr.Button("Generate My Art", variant="primary", scale=0)
227
+ status_message = gr.Textbox(
228
+ value="Click 'Generate My Art' to begin.",
229
+ label="Generation Status",
230
+ interactive=False,
231
+ show_label=False,
232
+ lines=1
233
+ )
234
+
235
+ gr.Markdown(
236
+ """
237
+ <h1 style="text-align: center;">Your Generated Artwork</h1>
238
+ <p style="text-align: center; color: #555;">
239
+ Once your brain's emotional data is processed, we pinpoint the <b>dominant emotion</b>. This single feeling inspires a <b>personalized artwork</b>. You can then <b>download</b> this unique visual representation of your inner self.
240
+ </p>
241
+ """
242
+ )
243
+
244
+ with gr.Row():
245
+ with gr.Column(scale=1):
246
+ gr.Markdown("<h3>Generated Image</h3>")
247
+ generated_image_output = gr.Image(label="Generated Image", show_label=False, height=300)
248
+ gr.Markdown("<h3>Blended Style Image</h3>")
249
+ blended_image_output = gr.Image(label="Blended Style Image", show_label=False, height=300)
250
+
251
+ # --- Event Listeners ---
252
+ analyze_psd_button.click(
253
+ upload_psd_file,
254
+ inputs=[psd_file_selection],
255
+ outputs=[emotion_distribution_plot, current_emotion_df_state, dom_emotion]
256
+ )
257
+
258
+ painter_dropdown.change(
259
+ update_paintings,
260
+ inputs=[painter_dropdown],
261
+ outputs=[painting_gallery]
262
+ )
263
+
264
+ def on_select(evt: gr.SelectData):
265
+ print("this function started")
266
+ print(f"Image index: {evt.index}\nImage value: {evt.value['image']['orig_name']}")
267
+ return evt.value['image']['orig_name']
268
+ painting_gallery.select(
269
+ on_select,
270
+ outputs=[selected_painting_name]
271
+ )
272
+
273
+ generate_button.click(
274
+ generate_my_art,
275
+ inputs=[painter_dropdown, selected_painting_name, dom_emotion],
276
+ outputs=[status_message, generated_image_output, blended_image_output]
277
+ )
278
+ if __name__ == "__main__":
279
+ demo.launch()
models/decoder.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:379ca41d59f3a37eed3599bbbc2560c19da5c458870a5ffd3a9dd41aa88f9472
3
+ size 14023458
models/lstm_emotion_model_state.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec28421b6bb8cec2fbe2d5059f02045eada4c0bcf6765a8d9be2fd69976b202a
3
+ size 301810