diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..f6b1f326ca4ab7cf0c8798856f8fe0020ff82d58 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,36 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text diff --git a/Datasets/task-emotion_psd_1.npy b/Datasets/task-emotion_psd_1.npy new file mode 100644 index 0000000000000000000000000000000000000000..c2438ea68c0686320ae30c8d881bde539c50093f --- 
/dev/null +++ b/Datasets/task-emotion_psd_1.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:575550dab5e146dcd5d6b0d3ad3f349cf05c443e99932b494b34a299d485014d +size 1613182 diff --git a/Datasets/task-emotion_psd_2.npy b/Datasets/task-emotion_psd_2.npy new file mode 100644 index 0000000000000000000000000000000000000000..a18941e92925e2c9aaaa0c31a88fb72232d0f06d --- /dev/null +++ b/Datasets/task-emotion_psd_2.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9d143bcca4ddabfc9b56ab360dc844595bc230d5833a2bcf804fc6a10bc07fbb +size 1613182 diff --git a/Datasets/task-emotion_psd_3.npy b/Datasets/task-emotion_psd_3.npy new file mode 100644 index 0000000000000000000000000000000000000000..7bf77d28174f012e0b07d4357f7633938b856d65 --- /dev/null +++ b/Datasets/task-emotion_psd_3.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d22fe179a31082a87e27ff5846129e43eb93f5cf14ec1f191ce6a38c78f0c525 +size 1613182 diff --git a/Emotions/disgust/dis_1.png b/Emotions/disgust/dis_1.png new file mode 100644 index 0000000000000000000000000000000000000000..c2df1ec130a88aef63fd25d78f52c9b34e694153 --- /dev/null +++ b/Emotions/disgust/dis_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5ac7ab9a76de1bd20f2e97d1222abdd7acca7ec126d31dfe61830611441aeadb +size 2468922 diff --git a/Emotions/disgust/dis_2.png b/Emotions/disgust/dis_2.png new file mode 100644 index 0000000000000000000000000000000000000000..8d174b5cdacd0929aead1990d206d4506bb9344f --- /dev/null +++ b/Emotions/disgust/dis_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ea0cb48344fb0cd177e31b98df3b770f74ed58c93e586fa83c957d6bde6d08d4 +size 2503002 diff --git a/Emotions/disgust/dis_3.png b/Emotions/disgust/dis_3.png new file mode 100644 index 0000000000000000000000000000000000000000..49cbc56c1ec109c5d377a6ff52874b74a6124118 --- /dev/null +++ b/Emotions/disgust/dis_3.png @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:a36650f62f7094189db0235d922624c88139e670dff49b076cf6dcaaebfd5b2c +size 3293453 diff --git a/Emotions/fear/fear_1.png b/Emotions/fear/fear_1.png new file mode 100644 index 0000000000000000000000000000000000000000..3aab98201edd2df2728c84716c6d3cc7564af692 --- /dev/null +++ b/Emotions/fear/fear_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d203031b3c11766f56bc6d87ae021d061e448fa0a1f863423f4d55a24d99c7db +size 2510004 diff --git a/Emotions/fear/fear_2.png b/Emotions/fear/fear_2.png new file mode 100644 index 0000000000000000000000000000000000000000..d783e2a6db89381dc3253911b86d451493609545 --- /dev/null +++ b/Emotions/fear/fear_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1f6387173a0495b89d15adab188a9e5c3e68c2536f7ba09e5a39d495a63b3954 +size 2741651 diff --git a/Emotions/fear/fear_3.png b/Emotions/fear/fear_3.png new file mode 100644 index 0000000000000000000000000000000000000000..09b0205cf2f98e88ea42ea97e1d3e61b4ab46fd5 --- /dev/null +++ b/Emotions/fear/fear_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a55d264eb54f9114413f85eed9d4227c01b2bffeb5ab067983592df253287491 +size 3162079 diff --git a/Emotions/joy/joy_1.png b/Emotions/joy/joy_1.png new file mode 100644 index 0000000000000000000000000000000000000000..78e2855d2d6e13731cacffc87cc133b8a0988cbb --- /dev/null +++ b/Emotions/joy/joy_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bdde85b1578148b43e157aae6e943cf332a79bf965e0614baaf9db2489e60fb1 +size 2067834 diff --git a/Emotions/joy/joy_2.png b/Emotions/joy/joy_2.png new file mode 100644 index 0000000000000000000000000000000000000000..7b9116882e055b63ad2da24ab32539b417e81729 --- /dev/null +++ b/Emotions/joy/joy_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:50ba3d1a26c80fa27c0836b903e2e93397d7adb5398f0e35b0c8f46ad03bd3a2 +size 1319517 diff --git 
a/Emotions/joy/joy_3.png b/Emotions/joy/joy_3.png new file mode 100644 index 0000000000000000000000000000000000000000..d6ab6c893eba24e46ad137a51dd04d7fa7ba2409 --- /dev/null +++ b/Emotions/joy/joy_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:abc5cc933536f5d5a0287b3fc48433947ab9147e3fb22b863ed2c35cf37f1255 +size 3240746 diff --git a/Emotions/sad/sad_1.png b/Emotions/sad/sad_1.png new file mode 100644 index 0000000000000000000000000000000000000000..62a89a951d385dfd38fe66cf9506445e0eccc5fc --- /dev/null +++ b/Emotions/sad/sad_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2f8d8afd794483426c28bfdc6cbac0d503b42d4486e9230770c632fead5b6eff +size 4719413 diff --git a/Emotions/sad/sad_2.png b/Emotions/sad/sad_2.png new file mode 100644 index 0000000000000000000000000000000000000000..4736a1d60de66d418aeb042f3715c57483a38262 --- /dev/null +++ b/Emotions/sad/sad_2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e1bc9975cfff5e5c6530dcd4f7c5e836a535081c7060dc4c5c3cbc79bb7fddbb +size 4022274 diff --git a/Emotions/sad/sad_3.png b/Emotions/sad/sad_3.png new file mode 100644 index 0000000000000000000000000000000000000000..6bdd36ee46fdc5345590c8466e508894c78e0dec --- /dev/null +++ b/Emotions/sad/sad_3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:86f57852a4cf0c67e9bb608fe7984aa0f818a10dec245a28dc2d110f705df51f +size 19520 diff --git a/Models_Class/LSTMModel.py b/Models_Class/LSTMModel.py new file mode 100644 index 0000000000000000000000000000000000000000..337aae0e0310e914014e5aa387791703f55ddcd0 --- /dev/null +++ b/Models_Class/LSTMModel.py @@ -0,0 +1,24 @@ +import torch.nn as nn +import torch + +class LSTMModel(nn.Module): + ## constructor + def __init__(self, input_size, hidden_size, output_size, num_layers): + super(LSTMModel, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.output_size = output_size + self.num_layers = 
num_layers + self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True) + self.fc = nn.Linear(self.hidden_size, self.output_size) + + def forward(self,x, h0=None, c0=None): + # hidden and state vectors h0 and c0 + if h0 is None or c0 is None: + h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size) + c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size) + + out, (hn, cn) = self.lstm(x, (h0, c0)) + out = self.fc(out) + return out, (hn, cn) + \ No newline at end of file diff --git a/Models_Class/__pycache__/LSTMModel.cpython-311.pyc b/Models_Class/__pycache__/LSTMModel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..baacc9ab931ca371e0651025d608e3369f197c2b Binary files /dev/null and b/Models_Class/__pycache__/LSTMModel.cpython-311.pyc differ diff --git a/Models_Class/__pycache__/LSTMModel.cpython-312.pyc b/Models_Class/__pycache__/LSTMModel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40b25584934091d3ab9560e05a6c6d81f78bdd41 Binary files /dev/null and b/Models_Class/__pycache__/LSTMModel.cpython-312.pyc differ diff --git a/Models_Class/__pycache__/NST_class.cpython-311.pyc b/Models_Class/__pycache__/NST_class.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9689a149d7e9cb671220bb5a68e23475944dd52 Binary files /dev/null and b/Models_Class/__pycache__/NST_class.cpython-311.pyc differ diff --git a/Painters/Pablo Picasso/Dora Maar with Cat (1941).png b/Painters/Pablo Picasso/Dora Maar with Cat (1941).png new file mode 100644 index 0000000000000000000000000000000000000000..b541cbb933fe60f8147155fc431d4bda9426b377 --- /dev/null +++ b/Painters/Pablo Picasso/Dora Maar with Cat (1941).png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e3d9a1c358f10e2d5a078fd0b8c6e7360f4de921aaace30d6162524ac1189602 +size 25292 diff --git a/Painters/Pablo Picasso/The Weeping Woman (1937).png b/Painters/Pablo 
Picasso/The Weeping Woman (1937).png new file mode 100644 index 0000000000000000000000000000000000000000..7c5bf2839d5dc15e89b795a2d4c876af1cc782d1 --- /dev/null +++ b/Painters/Pablo Picasso/The Weeping Woman (1937).png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b10f3b39125ef7c096dfd165b77faa1079774201abe1ec45b563744a4d4b8827 +size 32746 diff --git a/Painters/Pablo Picasso/Three Musicians (1921).png b/Painters/Pablo Picasso/Three Musicians (1921).png new file mode 100644 index 0000000000000000000000000000000000000000..89a143a058523a852fcfb6faf94d85e4dcc9dc79 --- /dev/null +++ b/Painters/Pablo Picasso/Three Musicians (1921).png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f042345d98f8128fbfd8d84b3c84660fa7326418b63599efa7cfe429fd3bb16a +size 521103 diff --git "a/Painters/Salvador Dal\303\255/Sleep (1937).png" "b/Painters/Salvador Dal\303\255/Sleep (1937).png" new file mode 100644 index 0000000000000000000000000000000000000000..7db6250078ab5171e016f2962bab768c8b1e073b --- /dev/null +++ "b/Painters/Salvador Dal\303\255/Sleep (1937).png" @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:987d628b6c0ac7367b5c1b24ede65664ba32baba4f7a94264dc42cbd221c7139 +size 60452 diff --git "a/Painters/Salvador Dal\303\255/Swans Reflecting Elephants (1937).png" "b/Painters/Salvador Dal\303\255/Swans Reflecting Elephants (1937).png" new file mode 100644 index 0000000000000000000000000000000000000000..f799552c3219194f7d6d4c49f6aaf1d35e351f2a --- /dev/null +++ "b/Painters/Salvador Dal\303\255/Swans Reflecting Elephants (1937).png" @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f71df5184f15adcdcaed509336b02edc37dd0b2daeb7049142dd61075e126148 +size 27230 diff --git "a/Painters/Salvador Dal\303\255/The Persistence of Memory (1931).png" "b/Painters/Salvador Dal\303\255/The Persistence of Memory (1931).png" new file mode 100644 index 
0000000000000000000000000000000000000000..2353c5eae9abf4464425c3a151382d323d8da2d9 --- /dev/null +++ "b/Painters/Salvador Dal\303\255/The Persistence of Memory (1931).png" @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f306a83ccc3d3f0b2e7894b92ceef114f6d5361750cb982d793acf613a0d107e +size 62011 diff --git a/Painters/Vincent van Gogh/Sunflowers (1888).png b/Painters/Vincent van Gogh/Sunflowers (1888).png new file mode 100644 index 0000000000000000000000000000000000000000..89eb7e48334342845b26cb2af1fb069bee585a01 --- /dev/null +++ b/Painters/Vincent van Gogh/Sunflowers (1888).png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c8cd0e84094378a927dab0432c416e8174204bb98af394ceb757f369d8ef4a2c +size 93950 diff --git a/Painters/Vincent van Gogh/The Potato Eaters (1885).png b/Painters/Vincent van Gogh/The Potato Eaters (1885).png new file mode 100644 index 0000000000000000000000000000000000000000..5a12df25b5e4b8bc82f1c51c40ab4190d581a0f0 --- /dev/null +++ b/Painters/Vincent van Gogh/The Potato Eaters (1885).png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d92cb55929ba1f20a552f828625e19277278161cff287f7f2f0fc6448c7eb2b5 +size 135374 diff --git a/Painters/Vincent van Gogh/The Starry Night (1889).png b/Painters/Vincent van Gogh/The Starry Night (1889).png new file mode 100644 index 0000000000000000000000000000000000000000..d8cb2e1e320ff2c79bb9c22a167fbffa31688f23 --- /dev/null +++ b/Painters/Vincent van Gogh/The Starry Night (1889).png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f74a2599ced6f8437e2b29a786bdde2a451ff8f01d87418a14c6bbe2ce0ad6a +size 541477 diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3af9e36959c25360ee379655e25e88147d02e31b --- /dev/null +++ b/README.md @@ -0,0 +1,13 @@ +--- +title: Brain Emotion Decoder +emoji: 🔥 +colorFrom: purple +colorTo: red +sdk: gradio +sdk_version: 5.45.0 +app_file: app.py 
+pinned: false +short_description: Decoding Emotions through Brain Signals +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/Src/Inference.py b/Src/Inference.py new file mode 100644 index 0000000000000000000000000000000000000000..a04bab5bbfedd51c6a3eb4577d3cb75a92d94f76 --- /dev/null +++ b/Src/Inference.py @@ -0,0 +1,8 @@ +## Start here with the inference procedure +import torch +from Models_Class.LSTMModel import LSTMModel + +def load_model(model_path, input_size, hidden_size, output_size, num_layers): + loaded_model = LSTMModel(input_size, hidden_size, output_size, num_layers) + loaded_model.load_state_dict(torch.load(model_path)) + return loaded_model \ No newline at end of file diff --git a/Src/NST_Inference.py b/Src/NST_Inference.py new file mode 100644 index 0000000000000000000000000000000000000000..694a3dfb378afe8df61246ba906f3aff7ba0caac --- /dev/null +++ b/Src/NST_Inference.py @@ -0,0 +1,81 @@ +import argparse +from pathlib import Path + +import torch +import torch.nn as nn +from PIL import Image +from torchvision import transforms +from torchvision.utils import save_image + +from . 
import net +from .function import adaptive_instance_normalization + + +def test_transform(size, crop): + transform_list = [] + if size != 0: + transform_list.append(transforms.Resize(size)) + if crop: + transform_list.append(transforms.CenterCrop(size)) + transform_list.append(transforms.ToTensor()) + transform = transforms.Compose(transform_list) + return transform + + +def style_transfer(vgg, decoder, content, style, alpha=1.0, + interpolation_weights=None): + assert (0.0 <= alpha <= 1.0) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + content_f = vgg(content) + style_f = vgg(style) + if interpolation_weights: + _, C, H, W = content_f.size() + feat = torch.FloatTensor(1, C, H, W).zero_().to(device) + base_feat = adaptive_instance_normalization(content_f, style_f) + for i, w in enumerate(interpolation_weights): + feat = feat + w * base_feat[i:i + 1] + content_f = content_f[0:1] + else: + feat = adaptive_instance_normalization(content_f, style_f) + feat = feat * alpha + content_f * (1 - alpha) + return decoder(feat) + + +def save_style(output_dir, content_path, style_path): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + decoder_pth = Path("models/decoder.pth") + vgg_pth = Path("models/vgg_normalised.pth") + output_dir = Path("output") + output_dir.mkdir(exist_ok=True, parents=True) + content_path = Path(content_path) + style_paths = [Path(style_path)] + + decoder = net.decoder + vgg = net.vgg + + decoder.eval() + vgg.eval() + + decoder.load_state_dict(torch.load(decoder_pth)) + vgg.load_state_dict(torch.load(vgg_pth)) + vgg = nn.Sequential(*list(vgg.children())[:31]) + + vgg.to(device) + decoder.to(device) + + content_tf = test_transform(512, True) + style_tf = test_transform(512, True) + style = torch.stack([style_tf(Image.open(str(p))) for p in style_paths]) + content = content_tf(Image.open(str(content_path))) \ + .unsqueeze(0).expand_as(style) + style = style.to(device) + content = content.to(device) + with 
torch.no_grad(): + output = style_transfer(vgg, decoder, content, style, + 1, '') + output = output.cpu() + output_name = output_dir / 'stylized_output.jpg' + save_image(output, str(output_name)) + return output_name + + \ No newline at end of file diff --git a/Src/Processing.py b/Src/Processing.py new file mode 100644 index 0000000000000000000000000000000000000000..e816955f0833d0dd5c5019d7f0943ff205a5d533 --- /dev/null +++ b/Src/Processing.py @@ -0,0 +1,17 @@ +import numpy as np + +emotion_list = ["0", "1", "2", "3", "4", "5", "6"] + + +def load_data(psd_file_pth): + np_data = np.load(psd_file_pth, allow_pickle=True).item()["psd"] + return np_data + + +def process_data(np_data): + #Swap axes + swapped_data = np.swapaxes(np_data, 0, 1) + ## reshape data + reshape_data = swapped_data.reshape(630, 320) + return reshape_data + \ No newline at end of file diff --git a/Src/Processing_img.py b/Src/Processing_img.py new file mode 100644 index 0000000000000000000000000000000000000000..4293ad71742e80850e6e301759bf11c6a014a29a --- /dev/null +++ b/Src/Processing_img.py @@ -0,0 +1,110 @@ +from PIL import Image +import torch +import torch.optim as optim +from torchvision import transforms +import torch.nn as nn +from Models_Class.NST_class import ( + ContentLoss, + Normalization, + StyleLoss, +) + +import copy + +style_weight = 1e8 +content_weight = 1e1 +def image_loader(image_path, loader, device): + image = Image.open(image_path).convert('RGB') + image = loader(image).unsqueeze(0) + return image.to(device, torch.float) + +def save_image(tensor, path="output.png"): + image = tensor.cpu().clone() + image = image.squeeze(0) + image = transforms.ToPILImage()(image) + image.save(path) + +def gram_matrix(input): + a, b, c, d = input.size() + features = input.view(a * b, c * d) + G = torch.mm(features, features.t()) + return G.div(a * b * c * d) + + +def get_style_model_and_losses(cnn, normalization_mean, normalization_std, + style_img, content_img, content_layers, style_layers, 
device): + cnn = copy.deepcopy(cnn) + normalization = Normalization(normalization_mean, normalization_std).to(device) + content_losses = [] + style_losses = [] + model = nn.Sequential(normalization) + + i = 0 + for layer in cnn.children(): + if isinstance(layer, nn.Conv2d): + i += 1 + name = f'conv_{i}' + elif isinstance(layer, nn.ReLU): + name = f'relu_{i}' + layer = nn.ReLU(inplace=False) + elif isinstance(layer, nn.MaxPool2d): + name = f'pool_{i}' + elif isinstance(layer, nn.BatchNorm2d): + name = f'bn_{i}' + else: + raise RuntimeError(f'Unrecognized layer: {layer.__class__.__name__}') + + model.add_module(name, layer) + + if name in content_layers: + target = model(content_img).detach() + content_loss = ContentLoss(target) + model.add_module(f"content_loss_{i}", content_loss) + content_losses.append(content_loss) + + if name in style_layers: + target_feature = model(style_img).detach() + style_loss = StyleLoss(target_feature) + model.add_module(f"style_loss_{i}", style_loss) + style_losses.append(style_loss) + + for i in range(len(model) - 1, -1, -1): + if isinstance(model[i], (ContentLoss, StyleLoss)): + break + model = model[:i+1] + return model, style_losses, content_losses + + + +def run_style_transfer(cnn, normalization_mean, normalization_std, + content_img, style_img, input_img,content_layers, style_layers, device, num_steps=300): + print("Building the style transfer model..") + model, style_losses, content_losses = get_style_model_and_losses(cnn, normalization_mean, normalization_std, + style_img, content_img,content_layers, style_layers, device ) + optimizer = optim.LBFGS([input_img.requires_grad_()]) + + print("Optimizing..") + run = [0] + while run[0] <= num_steps: + def closure(): + input_img.data.clamp_(0, 1) + optimizer.zero_grad() + model(input_img) + style_score = sum(sl.loss for sl in style_losses) + content_score = sum(cl.loss for cl in content_losses) + loss = style_weight * style_score + content_weight * content_score + loss.backward() + + 
if run[0] % 50 == 0: + print(f"Step {run[0]}:") + print(f" Style Loss: {style_score.item():.4f}") + print(f" Content Loss: {content_score.item():.4f}") + print(f" Total Loss: {loss.item():.4f}\n") + + run[0] += 1 + return loss + + optimizer.step(closure) + + input_img.data.clamp_(0, 1) + return input_img \ No newline at end of file diff --git a/Src/__init__.py b/Src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Src/__pycache__/Inference.cpython-311.pyc b/Src/__pycache__/Inference.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0904aab2a95bca30d87a9491f8a0ebc3f18aeb4 Binary files /dev/null and b/Src/__pycache__/Inference.cpython-311.pyc differ diff --git a/Src/__pycache__/Inference.cpython-312.pyc b/Src/__pycache__/Inference.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f4f916e2c577e112a1f21189a8cb0ca7ccc5fa0 Binary files /dev/null and b/Src/__pycache__/Inference.cpython-312.pyc differ diff --git a/Src/__pycache__/NST_Inference.cpython-311.pyc b/Src/__pycache__/NST_Inference.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65eb4784e653bedf210c19f9d6aa9a47c1c7fcfc Binary files /dev/null and b/Src/__pycache__/NST_Inference.cpython-311.pyc differ diff --git a/Src/__pycache__/Processing.cpython-311.pyc b/Src/__pycache__/Processing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfbeaf1328d37f20257eb9947e2e30a0d30c9a6f Binary files /dev/null and b/Src/__pycache__/Processing.cpython-311.pyc differ diff --git a/Src/__pycache__/Processing.cpython-312.pyc b/Src/__pycache__/Processing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..277023357d1c62f6d1289c93da5460e5f60c0b36 Binary files /dev/null and b/Src/__pycache__/Processing.cpython-312.pyc differ diff --git 
a/Src/__pycache__/Processing_img.cpython-311.pyc b/Src/__pycache__/Processing_img.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f60f360af02748d8fb159a7bee6d1d223312cb0 Binary files /dev/null and b/Src/__pycache__/Processing_img.cpython-311.pyc differ diff --git a/Src/__pycache__/__init__.cpython-311.pyc b/Src/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b513225f7bb75478b91038913045ae140e6600d4 Binary files /dev/null and b/Src/__pycache__/__init__.cpython-311.pyc differ diff --git a/Src/__pycache__/__init__.cpython-312.pyc b/Src/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eac7c0487c5d38892430c2337d40cfdc13f354b6 Binary files /dev/null and b/Src/__pycache__/__init__.cpython-312.pyc differ diff --git a/Src/__pycache__/function.cpython-311.pyc b/Src/__pycache__/function.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..901a4fb7129b52489a65e43bb1d1a5c823c57bc9 Binary files /dev/null and b/Src/__pycache__/function.cpython-311.pyc differ diff --git a/Src/__pycache__/net.cpython-311.pyc b/Src/__pycache__/net.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af5c47909d3e4b9a95949d3b787a3664eec5d4e6 Binary files /dev/null and b/Src/__pycache__/net.cpython-311.pyc differ diff --git a/Src/function.py b/Src/function.py new file mode 100644 index 0000000000000000000000000000000000000000..090f7417ce67fd08932f9c11fc6a6c060f1c5ce1 --- /dev/null +++ b/Src/function.py @@ -0,0 +1,67 @@ +import torch + + +def calc_mean_std(feat, eps=1e-5): + # eps is a small value added to the variance to avoid divide-by-zero. 
+ size = feat.size() + assert (len(size) == 4) + N, C = size[:2] + feat_var = feat.view(N, C, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().view(N, C, 1, 1) + feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) + return feat_mean, feat_std + + +def adaptive_instance_normalization(content_feat, style_feat): + assert (content_feat.size()[:2] == style_feat.size()[:2]) + size = content_feat.size() + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + + normalized_feat = (content_feat - content_mean.expand( + size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + + +def _calc_feat_flatten_mean_std(feat): + # takes 3D feat (C, H, W), return mean and std of array within channels + assert (feat.size()[0] == 3) + assert (isinstance(feat, torch.FloatTensor)) + feat_flatten = feat.view(3, -1) + mean = feat_flatten.mean(dim=-1, keepdim=True) + std = feat_flatten.std(dim=-1, keepdim=True) + return feat_flatten, mean, std + + +def _mat_sqrt(x): + U, D, V = torch.svd(x) + return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t()) + + +def coral(source, target): + # assume both source and target are 3D array (C, H, W) + # Note: flatten -> f + + source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source) + source_f_norm = (source_f - source_f_mean.expand_as( + source_f)) / source_f_std.expand_as(source_f) + source_f_cov_eye = \ + torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3) + + target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target) + target_f_norm = (target_f - target_f_mean.expand_as( + target_f)) / target_f_std.expand_as(target_f) + target_f_cov_eye = \ + torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3) + + source_f_norm_transfer = torch.mm( + _mat_sqrt(target_f_cov_eye), + torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)), + source_f_norm) + ) + + source_f_transfer = source_f_norm_transfer * \ + 
target_f_std.expand_as(source_f_norm) + \ + target_f_mean.expand_as(source_f_norm) + + return source_f_transfer.view(source.size()) diff --git a/Src/net.py b/Src/net.py new file mode 100644 index 0000000000000000000000000000000000000000..4576334f52f1ff332933362d477dc52122541fc8 --- /dev/null +++ b/Src/net.py @@ -0,0 +1,152 @@ +import torch.nn as nn + +from .function import adaptive_instance_normalization as adain +from .function import calc_mean_std + +decoder = nn.Sequential( + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(512, 256, (3, 3)), + nn.ReLU(), + nn.Upsample(scale_factor=2, mode='nearest'), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(256, 256, (3, 3)), + nn.ReLU(), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(256, 256, (3, 3)), + nn.ReLU(), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(256, 256, (3, 3)), + nn.ReLU(), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(256, 128, (3, 3)), + nn.ReLU(), + nn.Upsample(scale_factor=2, mode='nearest'), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(128, 128, (3, 3)), + nn.ReLU(), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(128, 64, (3, 3)), + nn.ReLU(), + nn.Upsample(scale_factor=2, mode='nearest'), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(64, 64, (3, 3)), + nn.ReLU(), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(64, 3, (3, 3)), +) + +vgg = nn.Sequential( + nn.Conv2d(3, 3, (1, 1)), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(3, 64, (3, 3)), + nn.ReLU(), # relu1-1 + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(64, 64, (3, 3)), + nn.ReLU(), # relu1-2 + nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(64, 128, (3, 3)), + nn.ReLU(), # relu2-1 + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(128, 128, (3, 3)), + nn.ReLU(), # relu2-2 + nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(128, 256, (3, 3)), + nn.ReLU(), # relu3-1 + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(256, 256, (3, 
3)), + nn.ReLU(), # relu3-2 + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(256, 256, (3, 3)), + nn.ReLU(), # relu3-3 + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(256, 256, (3, 3)), + nn.ReLU(), # relu3-4 + nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(256, 512, (3, 3)), + nn.ReLU(), # relu4-1, this is the last layer used + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(512, 512, (3, 3)), + nn.ReLU(), # relu4-2 + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(512, 512, (3, 3)), + nn.ReLU(), # relu4-3 + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(512, 512, (3, 3)), + nn.ReLU(), # relu4-4 + nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(512, 512, (3, 3)), + nn.ReLU(), # relu5-1 + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(512, 512, (3, 3)), + nn.ReLU(), # relu5-2 + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(512, 512, (3, 3)), + nn.ReLU(), # relu5-3 + nn.ReflectionPad2d((1, 1, 1, 1)), + nn.Conv2d(512, 512, (3, 3)), + nn.ReLU() # relu5-4 +) + + +class Net(nn.Module): + def __init__(self, encoder, decoder): + super(Net, self).__init__() + enc_layers = list(encoder.children()) + self.enc_1 = nn.Sequential(*enc_layers[:4]) # input -> relu1_1 + self.enc_2 = nn.Sequential(*enc_layers[4:11]) # relu1_1 -> relu2_1 + self.enc_3 = nn.Sequential(*enc_layers[11:18]) # relu2_1 -> relu3_1 + self.enc_4 = nn.Sequential(*enc_layers[18:31]) # relu3_1 -> relu4_1 + self.decoder = decoder + self.mse_loss = nn.MSELoss() + + # fix the encoder + for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']: + for param in getattr(self, name).parameters(): + param.requires_grad = False + + # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image + def encode_with_intermediate(self, input): + results = [input] + for i in range(4): + func = getattr(self, 'enc_{:d}'.format(i + 1)) + results.append(func(results[-1])) + return results[1:] + + # extract relu4_1 from input image + def 
encode(self, input): + for i in range(4): + input = getattr(self, 'enc_{:d}'.format(i + 1))(input) + return input + + def calc_content_loss(self, input, target): + assert (input.size() == target.size()) + assert (target.requires_grad is False) + return self.mse_loss(input, target) + + def calc_style_loss(self, input, target): + assert (input.size() == target.size()) + assert (target.requires_grad is False) + input_mean, input_std = calc_mean_std(input) + target_mean, target_std = calc_mean_std(target) + return self.mse_loss(input_mean, target_mean) + \ + self.mse_loss(input_std, target_std) + + def forward(self, content, style, alpha=1.0): + assert 0 <= alpha <= 1 + style_feats = self.encode_with_intermediate(style) + content_feat = self.encode(content) + t = adain(content_feat, style_feats[-1]) + t = alpha * t + (1 - alpha) * content_feat + + g_t = self.decoder(t) + g_t_feats = self.encode_with_intermediate(g_t) + + loss_c = self.calc_content_loss(g_t_feats[-1], t) + loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0]) + for i in range(1, 4): + loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i]) + return loss_c, loss_s diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..8f7a67589bc1c7e02f0586db4d21ed969ed8e330 --- /dev/null +++ b/app.py @@ -0,0 +1,279 @@ +import random +import gradio as gr +import pandas as pd +import numpy as np +from Src.Processing import load_data +from Src.Processing import process_data +from Src.Inference import load_model +from Src.NST_Inference import save_style +import torch +import time +import os +import mne +import matplotlib.pyplot as plt +import io +import matplotlib.cm as cm +import gradio as gr + + +dummy_emotion_data = pd.DataFrame({ + 'Emotion': ['sad', 'dis', 'fear', 'neu', 'joy', 'ten', 'ins'], + 'Value': [0.8, 0.6, 0.1, 0.4, 0.7, 0.2, 0.3] +}) + +int_to_emotion = { + 0: 'sad', + 1: 'dis', + 2: 'fear', + 3: 'neu', + 4: 'joy', + 5: 'ten', + 6: 'ins' +} + 
# Abbreviated emotion label -> human-readable name shown in the UI.
abr_to_emotion = {
    'sad': "sadness",
    'dis': "disgust",
    'fear': "fear",
    'neu': "neutral",
    'joy': "joy",
    'ten': 'Tenderness',
    'ins': "inspiration"
}

# Filesystem layout and LSTM model hyperparameters.
PAINTERS_BASE_DIR = "Painters"
EMOTION_BASE_DIR = "Emotions"
output_dir = "outputs"
input_size = 320
hidden_size = 50
output_size = 7
num_layers = 1

# NOTE(review): "Salvador DalÃ" looks like mojibake for "Dalí", but it is kept
# byte-identical because it must match the on-disk directory name — verify.
painters = ["Pablo Picasso", "Vincent van Gogh", "Salvador DalÃ"]
predefined_psd_files = ["task-emotion_psd_1.npy", "task-emotion_psd_2.npy", "task-emotion_psd_3.npy"]
Base_Dir = "Datasets"

# (painting file name, display title) pairs per painter, used as gallery data.
PAINTER_PLACEHOLDER_DATA = {
    "Pablo Picasso": [
        ("Dora Maar with Cat (1941).png", "Dora Maar with Cat (1941)"),
        ("The Weeping Woman (1937).png", "The Weeping Woman (1937)"),
        ("Three Musicians (1921).png", "Three Musicians (1921)"),
    ],
    "Vincent van Gogh": [
        ("Sunflowers (1888).png", "Sunflowers (1888)"),
        ("The Starry Night (1889).png", "The Starry Night (1889)"),
        ("The Potato Eaters (1885).png", "The Potato Eaters (1885)"),
    ],
    "Salvador DalÃ": [
        ("Persistence of Memory (1931).png", "Persistence of Memory (1931)"),
        ("Swans Reflecting Elephants (1937).png", "Swans Reflecting Elephants (1937)"),
        ("Sleep (1937).png", "Sleep (1937)"),
    ],
}

def upload_psd_file(selected_file_name):
    """
    Process a selected PSD file, run LSTM emotion inference, and prepare
    the emotion-distribution data for display.

    Returns a 3-tuple matching the Gradio callback outputs:
    (BarPlot update, emotion DataFrame, Textbox update with dominant emotion).

    BUG FIX: the error paths previously returned only 2 values while the
    success path returned 3, so the Gradio output arity mismatched whenever
    no file was selected or the file was missing.  All paths now return 3.
    """
    if selected_file_name is None:
        return (gr.BarPlot(dummy_emotion_data, x="Emotion", y="Value",
                           label="Emotion Distribution", visible=False),
                pd.DataFrame(),
                gr.Textbox(visible=False))

    psd_file_path = os.path.join(Base_Dir, selected_file_name).replace(os.sep, '/')

    try:
        # NOTE(review): kept global in case other callbacks read np_data — verify.
        global np_data
        np_data = load_data(psd_file_path)
        print(f"np data orig {np_data.shape}")
    except FileNotFoundError:
        print(f"Error: PSD file not found at {psd_file_path}")
        return (gr.BarPlot(dummy_emotion_data, x="Emotion", y="Value",
                           label="Emotion Distribution (Error: File not found)",
                           visible=False),
                pd.DataFrame(),
                gr.Textbox(visible=False))

    final_data = process_data(np_data)
    # Add a batch dimension: model expects (batch, seq, features).
    torch_data = torch.tensor(final_data, dtype=torch.float32).unsqueeze(0)
    absolute_model_path = os.path.join("models", "lstm_emotion_model_state.pth")
    loaded_model = load_model(absolute_model_path, input_size, hidden_size,
                              output_size, num_layers)
    loaded_model.eval()
    with torch.no_grad():
        predicted_logits, _ = loaded_model(torch_data)
        # Per-timestep argmax over the emotion logits, flattened to 1-D.
        final_output_indices = torch.argmax(predicted_logits, dim=2)
        all_predicted_indices = final_output_indices.view(-1)

    # Count occurrences of each predicted emotion index.
    values_count = torch.bincount(all_predicted_indices, minlength=output_size)
    print(f"Raw bincount: {values_count}")
    emotions_count = {int_to_emotion[i]: 0 for i in range(output_size)}
    for idx, count in enumerate(values_count):
        if idx < output_size:
            emotions_count[int_to_emotion[idx]] = count.item()
    # Dominant emotion = the label predicted most often across timesteps.
    dom_emotion = max(emotions_count, key=emotions_count.get)
    emotion_data = pd.DataFrame({
        "Emotion": list(emotions_count.keys()),
        "Frequency": list(emotions_count.values())
    })
    emotion_data = emotion_data.sort_values(by="Emotion").reset_index(drop=True)
    print(f"Final emotion_data DataFrame:\n{emotion_data}")

    return gr.BarPlot(
        emotion_data,
        x="Emotion",
        y="Frequency",
        label="Emotion Distribution",
        visible=True,
        y_title="Frequency"
    ), emotion_data, gr.Textbox(abr_to_emotion[dom_emotion], visible=True)
def update_paintings(painter_name):
    """
    Return (file_path, title) pairs for every image in the selected painter's
    directory, for display in the gallery.

    Lists files dynamically rather than using PAINTER_PLACEHOLDER_DATA, so new
    paintings appear without code changes.  Returns an empty list when the
    painter directory does not exist.
    """
    painter_dir = os.path.join(PAINTERS_BASE_DIR, painter_name).replace(os.sep, '/')
    artist_paintings_for_gallery = []
    if os.path.isdir(painter_dir):
        for filename in sorted(os.listdir(painter_dir)):
            if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.gif')):
                file_path = os.path.join(painter_dir, filename).replace(os.sep, '/')
                print(file_path)
                title = os.path.splitext(filename)[0]  # title without extension
                artist_paintings_for_gallery.append((file_path, title))
    print(f"Loaded paintings for {painter_name}: {artist_paintings_for_gallery}")
    return artist_paintings_for_gallery


def generate_my_art(painter, chosen_painting, dom_emotion):
    """
    Generator callback: pick a random base image for the dominant emotion,
    apply neural style transfer using the chosen painting, and yield
    (status message, original image path, stylized image path).

    BUG FIX: this function contains `yield`, so it is a generator — the old
    early `return "Please select...", None, None` was silently discarded and
    Gradio never showed the validation message.  Yield it instead.
    """
    if not painter or not chosen_painting:
        yield "Please select a painter and a painting.", None, None
        return
    img_style_pth = os.path.join(PAINTERS_BASE_DIR, painter, chosen_painting)
    print(f"img_style_path: {img_style_pth}")  # fixed typo: was "img_stype_path"
    # NOTE(review): artificial delay, presumably for UX pacing — confirm needed.
    time.sleep(3)
    # Pick a random content image from the dominant-emotion directory
    # (single listdir + random.choice instead of double listdir + randint).
    emotion_pth = os.path.join(EMOTION_BASE_DIR, dom_emotion)
    image_name = random.choice(os.listdir(emotion_pth))
    original_image_pth = os.path.join(emotion_pth, image_name)
    print(f"original img _path: {original_image_pth}")
    final_message = f"Art generated based on {painter}'s {chosen_painting} style!"
    ## Neural Style Transfer
    stylized_img_path = save_style(output_dir, original_image_pth, img_style_pth)
    yield gr.Textbox(final_message), original_image_pth, stylized_img_path
+ Imagine seeing your deepest feelings transform into art. We decode the underlying emotions from your brain activity, + generating a personalized artwork. Discover the art of your inner self. +
+ """ + ) + + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown("+ Once your brain's emotional data is processed, we pinpoint the dominant emotion. This single feeling inspires a personalized artwork. You can then download this unique visual representation of your inner self. +
+ """ + ) + + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown("