{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# brain2music: decoding EnCodec codes from fMRI\n",
    "\n",
    "Trains MLP decoders that map sub-001's fMRI voxel responses to the EnCodec codes of the corresponding audio stimulus, logging every run to Weights & Biases."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, warnings\n",
    "import torch, torch.nn as nn, torch.utils.data as data, torchvision as tv\n",
    "import lightning as L\n",
    "import numpy as np, pandas as pd, matplotlib.pyplot as plt\n",
    "import wandb\n",
    "# Take the logger from the same `lightning` package as the Trainer rather than mixing in the\n",
    "# standalone `pytorch_lightning` distribution.\n",
    "from lightning.pytorch.loggers import WandbLogger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocessed data for subject 001. Voxel arrays are stored as [n_voxels, n_TRs]\n",
    "# (transposed on load); embedding arrays hold one entry of EnCodec codes per stimulus.\n",
    "DATA_ROOT = '/home/ckadirt/brain2music'\n",
    "train_voxels_path = os.path.join(DATA_ROOT, 'dataset/preproc/sub-001_Resp_Training.npy')\n",
    "test_voxels_path = os.path.join(DATA_ROOT, 'dataset/preproc/sub-001_Resp_Test_Mean.npy')\n",
    "train_embeddings_path = os.path.join(DATA_ROOT, 'encodec_training_embeds_150.npy')\n",
    "test_embeddings_path = os.path.join(DATA_ROOT, 'encodec_test_embeds_150.npy')\n",
    "\n",
    "SEED = 42\n",
    "L.seed_everything(SEED);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data\n",
    "\n",
    "Each stimulus was presented for 15 s and the fMRI was sampled every 1.5 s, so every stimulus spans 10 consecutive TRs. The test split doubles as the validation set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VoxelsDataset(data.Dataset):\n",
    "    \"\"\"Pairs each stimulus' window of fMRI TRs with its EnCodec codes.\n",
    "\n",
    "    Stimulus i corresponds to TRs [i * tr_per_stimulus, (i + 1) * tr_per_stimulus).\n",
    "\n",
    "    Args:\n",
    "        voxels_path: .npy array shaped [n_voxels, n_TRs]; kept transposed as [n_TRs, n_voxels].\n",
    "        embeddings_path: .npy array with one entry of codes per stimulus (presumably [2, 150]).\n",
    "        tr_per_stimulus: consecutive TRs per stimulus (10 = 15 s / 1.5 s).\n",
    "\n",
    "    Items are (voxels [tr_per_stimulus, n_voxels] float32, codes of that stimulus).\n",
    "    Indices outside [-len, len) raise IndexError instead of returning an empty window.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, voxels_path, embeddings_path, tr_per_stimulus=10):\n",
    "        self.voxels = torch.from_numpy(np.load(voxels_path)).float().transpose(0, 1)\n",
    "        self.embeddings = torch.from_numpy(np.load(embeddings_path))\n",
    "        self.tr_per_stimulus = tr_per_stimulus\n",
    "        n_voxel_stimuli = len(self.voxels) // tr_per_stimulus\n",
    "        # Only stimuli present in both files can be paired; a mismatch usually means a wrong file pair.\n",
    "        self.len = min(n_voxel_stimuli, len(self.embeddings))\n",
    "        if n_voxel_stimuli != len(self.embeddings):\n",
    "            warnings.warn(f'{voxels_path} covers {n_voxel_stimuli} stimuli but {embeddings_path} '\n",
    "                          f'has {len(self.embeddings)} entries; using the first {self.len}.')\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        if index < 0:\n",
    "            index += self.len\n",
    "        if not 0 <= index < self.len:\n",
    "            raise IndexError(f'index out of range for a dataset of {self.len} stimuli')\n",
    "        start = index * self.tr_per_stimulus\n",
    "        return self.voxels[start:start + self.tr_per_stimulus], self.embeddings[index]\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.len\n",
    "\n",
    "\n",
    "class VoxelsEmbeddinsEncodecDataModule(L.LightningDataModule):\n",
    "    \"\"\"Serves the training VoxelsDataset and, as the validation set, the test VoxelsDataset.\"\"\"\n",
    "\n",
    "    def __init__(self, train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path,\n",
    "                 batch_size=4, num_workers=0):\n",
    "        super().__init__()\n",
    "        self.train_voxels_path = train_voxels_path\n",
    "        self.train_embeddings_path = train_embeddings_path\n",
    "        self.test_voxels_path = test_voxels_path\n",
    "        self.test_embeddings_path = test_embeddings_path\n",
    "        self.batch_size = batch_size\n",
    "        self.num_workers = num_workers\n",
    "\n",
    "    def setup(self, stage=None):\n",
    "        # setup() may be called once per stage; avoid reloading the .npy files each time.\n",
    "        if not hasattr(self, 'train_dataset'):\n",
    "            self.train_dataset = VoxelsDataset(self.train_voxels_path, self.train_embeddings_path)\n",
    "            self.test_dataset = VoxelsDataset(self.test_voxels_path, self.test_embeddings_path)\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        return data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,\n",
    "                               num_workers=self.num_workers)\n",
    "\n",
    "    def val_dataloader(self):\n",
    "        return data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False,\n",
    "                               num_workers=self.num_workers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sanity check\n",
    "\n",
    "Build the data module outside the Trainer and inspect one validation item before training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_module_example = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path,\n",
    "                                                       test_voxels_path, test_embeddings_path, batch_size=4)\n",
    "data_module_example.setup()\n",
    "train_dataloader = data_module_example.train_dataloader()\n",
    "val_dataloader = data_module_example.val_dataloader()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_voxels, sample_embeddings = val_dataloader.dataset[0]\n",
    "assert sample_voxels.shape[0] == val_dataloader.dataset.tr_per_stimulus, 'incomplete TR window'\n",
    "len(train_dataloader.dataset), len(val_dataloader.dataset), sample_voxels.shape, sample_embeddings.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(L.LightningModule):\n",
    "    \"\"\"Feed-forward decoder from fMRI voxels to EnCodec codes.\n",
    "\n",
    "    Batches are (voxels, embeddings) pairs from VoxelsDataset. The TRs selected by `tr_slice`\n",
    "    are averaged into one voxel vector per stimulus; the embeddings are flattened and truncated\n",
    "    to their first `n_tokens` values, which are the targets.\n",
    "\n",
    "    Args:\n",
    "        sizes: layer widths, e.g. [60784, 1000, 1000, 200 * 1024]. sizes[0] is the number of\n",
    "            voxels; sizes[-1] must equal n_tokens * codebook_size for loss='ce' or n_tokens for\n",
    "            loss='mse' (checked here, so a mismatch fails at construction, not mid-training).\n",
    "        residual_conections: stored for API compatibility only; residual connections are not\n",
    "            implemented and the layers are applied strictly in sequence.\n",
    "        dropout: dropout[i] is the dropout probability after hidden layer i. The output layer\n",
    "            gets neither ReLU nor dropout.\n",
    "        n_tokens: number of leading flattened codes to predict.\n",
    "        codebook_size: number of classes per code; predicted codes lie in [0, codebook_size).\n",
    "        tr_slice: slice over the TR axis of each stimulus window; the selected TRs are averaged.\n",
    "        loss: 'ce' to classify every code, 'mse' to regress the raw code values.\n",
    "        lr: Adam learning rate.\n",
    "        save_predictions: if True, save each epoch's predicted codes to outputs_train{epoch}.pt and\n",
    "            outputs_validation{epoch}.pt, with a _rank{r} suffix when several processes train.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, sizes, residual_conections, dropout, n_tokens=200, codebook_size=1024,\n",
    "                 tr_slice=slice(None), loss='ce', lr=1e-5, save_predictions=False):\n",
    "        super().__init__()\n",
    "        if loss not in ('ce', 'mse'):\n",
    "            raise ValueError(f'loss must be ce or mse, got {loss!r}')\n",
    "        n_outputs = n_tokens * codebook_size if loss == 'ce' else n_tokens\n",
    "        if sizes[-1] != n_outputs:\n",
    "            raise ValueError(f'sizes[-1] is {sizes[-1]} but {n_tokens} tokens with loss={loss!r} '\n",
    "                             f'need {n_outputs} outputs')\n",
    "        if len(dropout) < len(sizes) - 2:\n",
    "            raise ValueError('dropout needs one probability per hidden layer')\n",
    "        self.sizes = sizes\n",
    "        self.residual_conections = residual_conections\n",
    "        self.dropout = dropout\n",
    "        self.n_tokens = n_tokens\n",
    "        self.codebook_size = codebook_size\n",
    "        self.tr_slice = tr_slice\n",
    "        self.loss_name = loss\n",
    "        self.lr = lr\n",
    "        self.save_predictions = save_predictions\n",
    "\n",
    "        self.layers = nn.Sequential()\n",
    "        n_linear = len(sizes) - 1\n",
    "        for i in range(n_linear):\n",
    "            self.layers.add_module(f'linear{i}', nn.Linear(sizes[i], sizes[i + 1]))\n",
    "            # The output layer must stay unclipped and deterministic (logits / regression values).\n",
    "            if i < n_linear - 1:\n",
    "                self.layers.add_module(f'relu{i}', nn.ReLU())\n",
    "                self.layers.add_module(f'dropout{i}', nn.Dropout(dropout[i]))\n",
    "\n",
    "        self.loss = nn.CrossEntropyLoss() if loss == 'ce' else nn.MSELoss()\n",
    "        # Per-epoch predicted codes, only collected when save_predictions is True.\n",
    "        self.train_outputs = []\n",
    "        self.val_outputs = []\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.layers(x)\n",
    "\n",
    "    def _select_voxels(self, voxels):\n",
    "        \"\"\"Average the TRs chosen by tr_slice: [batch, n_TRs, n_voxels] -> [batch, n_voxels].\"\"\"\n",
    "        return voxels[:, self.tr_slice, :].mean(dim=1)\n",
    "\n",
    "    def _shape_outputs(self, outputs):\n",
    "        \"\"\"For 'ce', view flat logits as [batch, codebook_size, n_tokens] (class dim second).\"\"\"\n",
    "        if self.loss_name == 'ce':\n",
    "            # Explicit batch dim: a width mismatch raises instead of silently changing the batch size.\n",
    "            return outputs.reshape(outputs.shape[0], self.codebook_size, self.n_tokens)\n",
    "        return outputs\n",
    "\n",
    "    def _decode(self, outputs):\n",
    "        \"\"\"Turn shaped outputs into integer codes of shape [batch, n_tokens].\"\"\"\n",
    "        if self.loss_name == 'ce':\n",
    "            return outputs.argmax(dim=1)\n",
    "        return outputs.round().clamp(0, self.codebook_size - 1).long()\n",
    "\n",
    "    def predict_tokens(self, voxels):\n",
    "        \"\"\"Predict integer codes [batch, n_tokens] from stimulus windows [batch, n_TRs, n_voxels].\"\"\"\n",
    "        return self._decode(self._shape_outputs(self(self._select_voxels(voxels))))\n",
    "\n",
    "    def tokens_accuracy(self, outputs, embeddings):\n",
    "        \"\"\"Fraction of decoded codes equal to the target codes.\"\"\"\n",
    "        return (self._decode(outputs) == embeddings.long()).float().mean()\n",
    "\n",
    "    def _shared_step(self, batch, stage, store):\n",
    "        voxels, embeddings = batch\n",
    "        targets = embeddings.flatten(start_dim=1)[:, :self.n_tokens]\n",
    "        targets = targets.long() if self.loss_name == 'ce' else targets.float()\n",
    "        outputs = self._shape_outputs(self(self._select_voxels(voxels)))\n",
    "        loss = self.loss(outputs, targets)\n",
    "        self.log(f'{stage}_loss', loss, sync_dist=True)\n",
    "        self.log(f'{stage}_accuracy', self.tokens_accuracy(outputs, targets), sync_dist=True)\n",
    "        if self.save_predictions:\n",
    "            store.append(self._decode(outputs).detach().cpu())\n",
    "        return loss\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        return self._shared_step(batch, 'train', self.train_outputs)\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        return self._shared_step(batch, 'val', self.val_outputs)\n",
    "\n",
    "    def _save_epoch_outputs(self, store, prefix):\n",
    "        # Skip the sanity-check pass so it does not write a bogus epoch-0 file.\n",
    "        if store and not self.trainer.sanity_checking:\n",
    "            # One file per rank under multi-device training; otherwise ranks overwrite each other.\n",
    "            suffix = f'_rank{self.global_rank}' if self.trainer.world_size > 1 else ''\n",
    "            torch.save(torch.cat(store), f'{prefix}{self.current_epoch}{suffix}.pt')\n",
    "        store.clear()\n",
    "\n",
    "    def on_train_epoch_end(self):\n",
    "        self._save_epoch_outputs(self.train_outputs, 'outputs_train')\n",
    "\n",
    "    def on_validation_epoch_end(self):\n",
    "        self._save_epoch_outputs(self.val_outputs, 'outputs_validation')\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        return torch.optim.Adam(self.parameters(), lr=self.lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_VOXELS = 60784  # width of the voxel axis = model input size\n",
    "WANDB_PROJECT, WANDB_ENTITY = 'brain2music', 'ckadirt'\n",
    "\n",
    "\n",
    "def run_experiment(model, batch_size, max_epochs=400, devices=2):\n",
    "    \"\"\"Fit `model` on the voxel/EnCodec data module, logging to a fresh W&B run; returns the Trainer.\"\"\"\n",
    "    wandb.finish()  # close any run left open by a previous experiment\n",
    "    data_module = VoxelsEmbeddinsEncodecDataModule(\n",
    "        train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path,\n",
    "        batch_size=batch_size)\n",
    "    wandb_logger = WandbLogger(project=WANDB_PROJECT, entity=WANDB_ENTITY)\n",
    "    trainer = L.Trainer(devices=devices, accelerator='gpu', max_epochs=max_epochs,\n",
    "                        logger=wandb_logger, precision='16-mixed', log_every_n_steps=10)\n",
    "    trainer.fit(model, datamodule=data_module)\n",
    "    return trainer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experiments\n",
    "\n",
    "Each experiment trains one `MLP` configuration. The configurations differ in:\n",
    "\n",
    "- which TRs of the 10-TR window are averaged,\n",
    "- how many codes are predicted,\n",
    "- the loss, dropout, learning rate and batch size.\n",
    "\n",
    "`run_experiment` opens a new W&B run for each one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment 1: average the first two TRs and classify the first 150 codes.\n",
    "# NOTE(review): the output width is 150 * 1024, so 150 targets are predicted (presumably the first\n",
    "# codebook row if codes are [2, 150]); TODO confirm 150 rather than 200 was intended.\n",
    "model_first2_trs = MLP([N_VOXELS, 500, 500, 150 * 1024], [[0], [1], [2], [3]], [0.3, 0.3, 0.3, 0.3],\n",
    "                       n_tokens=150, tr_slice=slice(0, 2), lr=1e-6, save_predictions=True)\n",
    "run_experiment(model_first2_trs, batch_size=4);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment 2: average all TRs and regress the raw code values with MSE.\n",
    "# The target is the whole flattened code array; 2 * 150 assumes [2, 150] codes per stimulus.\n",
    "model_mse = MLP([N_VOXELS, 1000, 1000, 2 * 150], [[0], [1], [2], [3]], [0.5, 0.5, 0.5, 0.5],\n",
    "                n_tokens=2 * 150, loss='mse', lr=1e-5, save_predictions=True)\n",
    "run_experiment(model_mse, batch_size=32);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment 3: average all TRs and classify the first 200 codes.\n",
    "model_all_trs = MLP([N_VOXELS, 1000, 1000, 200 * 1024], [[0], [1], [2], [3]], [0.5, 0.5, 0.5, 0.5],\n",
    "                    n_tokens=200, lr=1e-5)\n",
    "run_experiment(model_all_trs, batch_size=2);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment 4: use only the second TR of each window and classify the first 200 codes.\n",
    "model_second_tr = MLP([N_VOXELS, 1000, 1000, 200 * 1024], [[0], [1], [2], [3]], [0.2, 0.2, 0.2, 0.2],\n",
    "                      n_tokens=200, tr_slice=slice(1, 2), lr=1e-6)\n",
    "run_experiment(model_second_tr, batch_size=4);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test-set inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_dataset(model, dataset, batch_size=2):\n",
    "    \"\"\"Return the predicted codes for every stimulus in `dataset`, shaped [len(dataset), model.n_tokens].\"\"\"\n",
    "    model.eval()\n",
    "    predictions = []\n",
    "    with torch.no_grad():\n",
    "        for voxels, _ in data.DataLoader(dataset, batch_size=batch_size, shuffle=False):\n",
    "            predictions.append(model.predict_tokens(voxels.to(model.device)).cpu())\n",
    "    if not predictions:\n",
    "        return torch.empty(0, model.n_tokens, dtype=torch.long)\n",
    "    return torch.cat(predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the test set with the second-TR model (experiment 4).\n",
    "test_predictions = predict_dataset(model_second_tr, VoxelsDataset(test_voxels_path, test_embeddings_path))\n",
    "torch.save(test_predictions, 'outputs.pt')  # written to the current working directory\n",
    "test_predictions.shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}