{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "d004931d-1ff4-4d85-9d23-1bb1eab2111e", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fri Sep 8 05:45:08 2023 \n", "+-----------------------------------------------------------------------------+\n", "| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n", "|-------------------------------+----------------------+----------------------+\n", "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", "| | | MIG M. |\n", "|===============================+======================+======================|\n", "| 0 NVIDIA A100-SXM... On | 00000000:10:1C.0 Off | 0 |\n", "| N/A 30C P0 51W / 400W | 3MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 1 NVIDIA A100-SXM... On | 00000000:10:1D.0 Off | 0 |\n", "| N/A 38C P0 157W / 400W | 40431MiB / 40960MiB | 100% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 2 NVIDIA A100-SXM... On | 00000000:20:1C.0 Off | 0 |\n", "| N/A 43C P0 176W / 400W | 28503MiB / 40960MiB | 100% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 3 NVIDIA A100-SXM... On | 00000000:20:1D.0 Off | 0 |\n", "| N/A 38C P0 166W / 400W | 38049MiB / 40960MiB | 100% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 4 NVIDIA A100-SXM... On | 00000000:90:1C.0 Off | 0 |\n", "| N/A 40C P0 178W / 400W | 37159MiB / 40960MiB | 100% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 5 NVIDIA A100-SXM... On | 00000000:90:1D.0 Off | 0 |\n", "| N/A 39C P0 158W / 400W | 37641MiB / 40960MiB | 100% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 6 NVIDIA A100-SXM... On | 00000000:A0:1C.0 Off | 0 |\n", "| N/A 39C P0 167W / 400W | 38395MiB / 40960MiB | 98% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 7 NVIDIA A100-SXM... On | 00000000:A0:1D.0 Off | 0 |\n", "| N/A 48C P0 381W / 400W | 35381MiB / 40960MiB | 100% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", " \n", "+-----------------------------------------------------------------------------+\n", "| Processes: |\n", "| GPU GI CI PID Type Process name GPU Memory |\n", "| ID ID Usage |\n", "|=============================================================================|\n", "| 1 N/A N/A 919504 C ...ari/llama_env/bin/python3 40428MiB |\n", "| 2 N/A N/A 919505 C ...ari/llama_env/bin/python3 28500MiB |\n", "| 3 N/A N/A 919506 C ...ari/llama_env/bin/python3 38046MiB |\n", "| 4 N/A N/A 919507 C ...ari/llama_env/bin/python3 37156MiB |\n", "| 5 N/A N/A 919508 C ...ari/llama_env/bin/python3 37638MiB |\n", "| 6 N/A N/A 919509 C ...ari/llama_env/bin/python3 38392MiB |\n", "| 7 N/A N/A 919510 C ...ari/llama_env/bin/python3 35378MiB |\n", "+-----------------------------------------------------------------------------+\n" ] } ], "source": [ "!nvidia-smi" ] }, { "cell_type": "code", "execution_count": 1, "id": "ec021849-d426-4450-a140-f8647a764d2e", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[2023-09-08 05:54:30,058] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" ] } ], "source": [ "from functools import partial\n", "from einops import rearrange\n", "from transformers import MusicgenForConditionalGeneration\n", "import pytorch_lightning as pl\n", "import os, torch, torch.nn as nn, torch.utils.data as data, torchvision as tv\n", "import lightning as L\n", "import numpy as np, pandas as pd, matplotlib.pyplot as plt\n", "from pytorch_lightning.loggers import WandbLogger\n", "import wandb" ] }, { "cell_type": "code", "execution_count": 2, "id": "5f1916fe-163f-4cac-ac0c-c16a74bcae89", "metadata": { "tags": [] }, "outputs": [], "source": [ "# create the datasets and dataloaders\n", "train_voxels_path = '/fsx/proj-fmri/ckadirt/b2m/data/sub-001_Resp_Training.npy' # path to training voxels 65000 * 4800 \n", "test_voxels_path = '/fsx/proj-fmri/ckadirt/b2m/data/sub-001_Resp_Test_Mean.npy' # path to test voxels 65000 * 600\n", "\n", "train_embeddings_path = '/fsx/proj-fmri/ckadirt/b2m/data/encodec32khz_training_embeds_sorted.npy' # path to training embeddings 480 * 2 * 1125\n", "test_embeddings_path = '/fsx/proj-fmri/ckadirt/b2m/data/encodec32khz_testing_embeds_sorted.npy' # path to test embeddings 600 * 2 * 1125\n", "\n", "class VoxelsDataset(data.Dataset):\n", " def __init__(self, voxels_path, embeddings_path):\n", " # transpose the two dimensions of the voxels data to match the embeddings data\n", " self.voxels = torch.from_numpy(np.load(voxels_path)).float().transpose(0, 1)\n", " self.embeddings = torch.from_numpy(np.load(embeddings_path))\n", " # as each stimulus has been exposed for 15 seconds and the fMRI data is sampled every 1.5 seconds, we take 10 samples per stimulus\n", " self.len = len(self.voxels) // 10\n", " print(\"The len is \", self.len )\n", "\n", " def __getitem__(self, index):\n", " # as each stimulus has been exposed for 15 seconds and the fMRI data is sampled every 1.5 seconds, we take 10 samples per stimulus\n", " voxels = self.voxels[index*10:(index+1)*10]\n", " embeddings = self.embeddings[index]\n", " return voxels, embeddings\n", "\n", " def __len__(self):\n", " return self.len\n", " \n", "class VoxelsEmbeddinsEncodecDataModule(pl.LightningDataModule):\n", " def __init__(self, train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=8):\n", " super().__init__()\n", " self.train_voxels_path = train_voxels_path\n", " self.train_embeddings_path = train_embeddings_path\n", " self.test_voxels_path = test_voxels_path\n", " self.test_embeddings_path = test_embeddings_path\n", " self.batch_size = batch_size\n", "\n", " def setup(self, stage=None):\n", " self.train_dataset = VoxelsDataset(self.train_voxels_path, self.train_embeddings_path)\n", " self.test_dataset = VoxelsDataset(self.test_voxels_path, self.test_embeddings_path)\n", "\n", " def train_dataloader(self):\n", " return data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)\n", "\n", " def val_dataloader(self):\n", " return data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "39fa231d-7e28-4813-8f5d-edf8fbce1774", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "'bn = BrainNetwork(in_dim=4096)\\nrr = RidgeRegression(60784, 4096)\\n\\ntest_input = torch.randn(3, 60784)\\nout1 = rr(test_input)\\nout2 = bn(out1)\\nprint(out2.shape, out1.shape)\\n '" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class RidgeRegression(torch.nn.Module):\n", " # make sure to add weight_decay when initializing optimizer\n", " def __init__(self, input_size, out_features): \n", " super(RidgeRegression, self).__init__()\n", " self.linear = torch.nn.Linear(input_size, out_features)\n", " def forward(self, x):\n", " return self.linear(x)\n", " \n", "class BrainNetwork(nn.Module):\n", " def __init__(self, out_dim=768*128, in_dim=60784, clip_size=128, h=4096, n_blocks=4, norm_type='ln', act_first=False, use_projector=True, drop1=.5, drop2=.15):\n", " super().__init__()\n", " norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)\n", " act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU\n", " act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)\n", " self.mlp = nn.ModuleList([\n", " nn.Sequential(\n", " nn.Linear(h, h),\n", " *[item() for item in act_and_norm],\n", " nn.Dropout(drop2)\n", " ) for _ in range(n_blocks)\n", " ])\n", " self.lin1 = nn.Linear(h, out_dim, bias=True)\n", " self.n_blocks = n_blocks\n", " self.clip_size = clip_size\n", " self.use_projector = use_projector\n", " if use_projector:\n", " self.projector = nn.Sequential(\n", " nn.LayerNorm(clip_size),\n", " nn.GELU(),\n", " nn.Linear(clip_size, 2048),\n", " nn.LayerNorm(2048),\n", " nn.GELU(),\n", " nn.Linear(2048, 2048),\n", " nn.LayerNorm(2048),\n", " nn.GELU(),\n", " nn.Linear(2048, clip_size)\n", " )\n", " \n", " def forward(self, x):\n", " residual = x\n", " for res_block in range(self.n_blocks):\n", " x = self.mlp[res_block](x)\n", " x += residual\n", " residual = x\n", " x = x.reshape(len(x), -1)\n", " x = self.lin1(x)\n", " if self.use_projector:\n", " x = self.projector(x.reshape(len(x), -1, self.clip_size))\n", " x = rearrange(x, 'b e t -> b t e')\n", " return x\n", " return x\n", " \n", "\"\"\"bn = BrainNetwork(in_dim=4096)\n", "rr = RidgeRegression(60784, 4096)\n", "\n", "test_input = torch.randn(3, 60784)\n", "out1 = rr(test_input)\n", "out2 = bn(out1)\n", "print(out2.shape, out1.shape)\n", " \"\"\"" ] }, { "cell_type": "code", "execution_count": 4, "id": "495e8e78-c44f-4410-b3c2-61faa6beec77", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "\"b2m_test = B2M().to('cuda')\\ntest_b2m_input = torch.randn(4, 60784).to('cuda')\\naudio_codes = np.load('/fsx/proj-fmri/ckadirt/b2m/data/encodec32khz_training_embeds_sorted.npy')\\naudio_codes = torch.from_numpy(audio_codes)\\naudio_codes = audio_codes[:4, :, :]\\naudio_codes = rearrange(audio_codes, 'b c t -> (b c) t').to('cuda').long()\\ntest_b2m_output = b2m_test(test_b2m_input, audio_codes)\\nprint(test_b2m_output.shape)\"" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class B2M(pl.LightningModule):\n", " def __init__(self, input_size = 60784, mapping_size = 4096, num_codebooks = 4):\n", " # sizes is a list of the sizes of the layers ej: [4800, 1000, 1000, 1000, 1000, 1000, 1000, 600]\n", " # residual_conections is a list with the same length as sizes, each element is a list of the indexes of the layers that will recieve the output of the layer as input, 0 means that the layer will recieve the x inputs ej. [[0], [1], [2,1], [3], [4,3], [5], [6,5], [7]]\n", " # dropout is a list with the same length as sizes, each element is the dropout probability of the layer ej. [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]\n", " super().__init__()\n", " self.brain_network = BrainNetwork(h = mapping_size)\n", " self.ridge_regression = RidgeRegression(input_size=input_size, out_features=mapping_size)\n", " self.loss = nn.CrossEntropyLoss()\n", " self.pseudo_text_encoder = nn.Sequential(\n", " self.ridge_regression,\n", " self.brain_network\n", " )\n", " self.musicgen_decoder = MusicgenForConditionalGeneration.from_pretrained(\"facebook/musicgen-small\")\n", " self.pad_token_id = self.musicgen_decoder.generation_config.pad_token_id\n", " self.num_codebooks = num_codebooks\n", " self.test_outptus = []\n", " self.train_outptus = []\n", "\n", " for param in self.musicgen_decoder.parameters():\n", " param.requires_grad = False\n", "\n", " def forward(self, x, decoder_input_ids=None):\n", " # x is [batch_size, 60784]\n", " # decoder input ids is [batch_size * num_codebooks, 750] 750 is the length of the audiocodes for 15 seconds of audio\n", " # first we pass the voxels through the pseudo text encoder\n", " pseudo_encoded_fmri = self.pseudo_text_encoder(x)\n", " # x is [batch_size, 128, 768]\n", " # now we pass the output through the musicgen projector to get [batch_size, 128, 1024]\n", " projected_pseudo_encoded_fmri = self.musicgen_decoder.enc_to_dec_proj(pseudo_encoded_fmri)\n", "\n", " if decoder_input_ids is None:\n", " # if no decoder input ids are given, we create a tensor of the size [batch_size * num_codebooks, 1] filled with the pad token id\n", " decoder_input_ids = (\n", " torch.ones((x.shape[0] * self.num_codebooks, 1), dtype=torch.long)\n", " * self.pad_token_id\n", " )\n", " \n", " # now we pass the projected pseudo encoded fmri through the musicgen decoder\n", " logits = self.musicgen_decoder.decoder(\n", " input_ids = decoder_input_ids,\n", " encoder_hidden_states = projected_pseudo_encoded_fmri,\n", " ).logits\n", "\n", " return logits\n", "\n", " \n", " def training_step(self, batch, batch_idx):\n", " voxels, embeddings = batch # the sizes are [batch_size, 10, 65000] and [batch_size, 4, 750]\n", " # take the last scan from the voxels\n", " voxels = voxels[:, -1, :]\n", " # convert the embeddings to long and combine the batch and codebook dimensions\n", " embeddings = rearrange(embeddings, 'b c t -> (b c) t').long()\n", "\n", "\n", " #take just the first 200 embeddings\n", " #embeddings = embeddings[:, :200]\n", " # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus\n", " #voxels = voxels[:, 0:2, :]\n", " #voxels = voxels.mean(dim=1)\n", " #voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n", "\n", "\n", " # use the decoder input ids to get the logits\n", " decoder_input_ids = embeddings[:, :-1]\n", " logits = self(voxels, decoder_input_ids)\n", "\n", " # get the loss\n", " loss = self.loss(rearrange(logits, '(b c) t d -> (b c t) d', c=self.num_codebooks), rearrange(embeddings[:, 1:], '(b c) t -> (b c t)', c=self.num_codebooks))\n", "\n", "\n", " acuracy = self.tokens_accuracy(logits, embeddings[:,1:])\n", " self.log('train_loss', loss, sync_dist=True)\n", " self.log('train_accuracy', acuracy, sync_dist=True)\n", " discrete_outputs = logits.argmax(dim=2)\n", " self.train_outptus.append(discrete_outputs)\n", " return loss\n", " \n", " def tokens_accuracy(self, outputs, embeddings):\n", " # outputs is [batch_size, 750, 2048]\n", " # embeddings is [batch_size, 750]\n", " # we need to get the index of the maximum value of each token\n", " outputs = outputs.argmax(dim=2)\n", " # now we need to compare the outputs with the embeddings\n", " return (outputs == embeddings).float().mean()\n", " \n", " def on_train_epoch_end(self):\n", " self.train_outptus = torch.cat(self.train_outptus)\n", " # save the outputs with the current epoch name\n", " torch.save(self.train_outptus, 'outputs_train'+str(self.current_epoch)+'.pt')\n", " self.train_outptus = []\n", " \n", " def on_validation_epoch_end(self):\n", " self.test_outptus = torch.cat(self.test_outptus)\n", " # save the outputs with the current epoch name\n", " torch.save(self.test_outptus, 'outputs_validation'+str(self.current_epoch)+'.pt')\n", " self.test_outptus = []\n", "\n", " \n", " def validation_step(self, batch, batch_idx):\n", " voxels, embeddings = batch\n", " # take the last scan from the voxels\n", " voxels = voxels[:, -1, :]\n", " # convert the embeddings to long and combine the batch and codebook dimensions\n", " embeddings = rearrange(embeddings, 'b c t -> (b c) t').long()\n", "\n", " # take just the first 200 embeddings\n", " #embeddings = embeddings[:, :200]\n", " # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus\n", " #voxels = voxels[:, 0:2, :]\n", " #voxels = voxels.mean(dim=1)\n", " #voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n", "\n", " # use the decoder input ids to get the logits\n", " decoder_input_ids = embeddings[:, :-1]\n", " logits = self(voxels, decoder_input_ids)\n", "\n", " # get the loss\n", " loss = self.loss(rearrange(logits, '(b c) t d -> (b c t) d', c=self.num_codebooks), rearrange(embeddings[:, 1:], '(b c) t -> (b c t)', c=self.num_codebooks))\n", "\n", " acuracy = self.tokens_accuracy(logits, embeddings[:,1:])\n", " self.log('val_loss', loss, sync_dist=True)\n", " self.log('val_accuracy', acuracy, sync_dist=True)\n", " discrete_outputs = logits.argmax(dim=2)\n", " self.test_outptus.append(discrete_outputs)\n", " return loss\n", " \n", " \n", "\n", " def configure_optimizers(self):\n", " # we just want to train the pseudo text encoder, but we need to zero the gradients of the musicgen decoder\n", " optimizer = torch.optim.AdamW(\n", " [\n", " {'params': self.pseudo_text_encoder.parameters(), 'lr': 3e-6, 'weight_decay': 1e-4},\n", " {'params': self.musicgen_decoder.parameters(), 'lr': 0},\n", " ],\n", " )\n", " return optimizer\n", "\n", "\n", "# create the model\n", "\"\"\"b2m_test = B2M().to('cuda')\n", "test_b2m_input = torch.randn(4, 60784).to('cuda')\n", "audio_codes = np.load('/fsx/proj-fmri/ckadirt/b2m/data/encodec32khz_training_embeds_sorted.npy')\n", "audio_codes = torch.from_numpy(audio_codes)\n", "audio_codes = audio_codes[:4, :, :]\n", "audio_codes = rearrange(audio_codes, 'b c t -> (b c) t').to('cuda').long()\n", "test_b2m_output = b2m_test(test_b2m_input, audio_codes)\n", "print(test_b2m_output.shape)\"\"\"" ] }, { "cell_type": "code", "execution_count": null, "id": "0646c340-c3df-42c0-ad22-2b677fc85da6", "metadata": { "tags": [] }, "outputs": [], "source": [ "b2m = B2M()\n", "data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=4)\n", "\n", "wandb.finish()\n", "\n", "wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')\n", "\n", "# define the trainer\n", "trainer = pl.Trainer(devices=1, accelerator=\"gpu\", max_epochs=400, logger=wandb_logger, precision='16-mixed', log_every_n_steps=10)\n", "\n", "# train the model\n", "trainer.fit(b2m, datamodule=data_module)" ] }, { "cell_type": "code", "execution_count": 21, "id": "4f458adf-1e89-4b8b-9514-08bcc9e8ef56", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fri Sep 8 05:08:00 2023 \n", "+-----------------------------------------------------------------------------+\n", "| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n", "|-------------------------------+----------------------+----------------------+\n", "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", "| | | MIG M. |\n", "|===============================+======================+======================|\n", "| 0 NVIDIA A100-SXM... On | 00000000:10:1C.0 Off | 0 |\n", "| N/A 31C P0 71W / 400W | 29323MiB / 40960MiB | 0% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 1 NVIDIA A100-SXM... On | 00000000:10:1D.0 Off | 0 |\n", "| N/A 40C P0 204W / 400W | 40217MiB / 40960MiB | 100% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 2 NVIDIA A100-SXM... On | 00000000:20:1C.0 Off | 0 |\n", "| N/A 47C P0 243W / 400W | 40421MiB / 40960MiB | 99% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 3 NVIDIA A100-SXM... On | 00000000:20:1D.0 Off | 0 |\n", "| N/A 42C P0 194W / 400W | 37547MiB / 40960MiB | 100% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 4 NVIDIA A100-SXM... On | 00000000:90:1C.0 Off | 0 |\n", "| N/A 45C P0 282W / 400W | 29223MiB / 40960MiB | 100% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 5 NVIDIA A100-SXM... On | 00000000:90:1D.0 Off | 0 |\n", "| N/A 39C P0 179W / 400W | 24387MiB / 40960MiB | 100% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 6 NVIDIA A100-SXM... On | 00000000:A0:1C.0 Off | 0 |\n", "| N/A 43C P0 232W / 400W | 29819MiB / 40960MiB | 88% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", "| 7 NVIDIA A100-SXM... On | 00000000:A0:1D.0 Off | 0 |\n", "| N/A 38C P0 166W / 400W | 28583MiB / 40960MiB | 100% Default |\n", "| | | Disabled |\n", "+-------------------------------+----------------------+----------------------+\n", " \n", "+-----------------------------------------------------------------------------+\n", "| Processes: |\n", "| GPU GI CI PID Type Process name GPU Memory |\n", "| ID ID Usage |\n", "|=============================================================================|\n", "| 0 N/A N/A 895711 C ...3/envs/mindeye/bin/python 29320MiB |\n", "| 1 N/A N/A 899054 C ...ari/llama_env/bin/python3 40214MiB |\n", "| 2 N/A N/A 899055 C ...ari/llama_env/bin/python3 40418MiB |\n", "| 3 N/A N/A 899056 C ...ari/llama_env/bin/python3 37544MiB |\n", "| 4 N/A N/A 899057 C ...ari/llama_env/bin/python3 29220MiB |\n", "| 5 N/A N/A 899058 C ...ari/llama_env/bin/python3 24384MiB |\n", "| 6 N/A N/A 899059 C ...ari/llama_env/bin/python3 29816MiB |\n", "| 7 N/A N/A 899060 C ...ari/llama_env/bin/python3 28580MiB |\n", "+-----------------------------------------------------------------------------+\n" ] } ], "source": [ "!nvidia-smi" ] }, { "cell_type": "code", "execution_count": 49, "id": "e4c9f9f5-a984-4fcf-8938-861f5bfb38d6", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([240, 749])\n" ] } ], "source": [ "# read this file reb2m/src/outputs_train39.pt\n", "train_outputs = torch.load('/fsx/proj-fmri/ckadirt/b2m/src/outputs_validation39.pt')\n", "print(train_outputs.shape)\n", "example1 = train_outputs[0:4].unsqueeze(0).unsqueeze(0)\n", "\n" ] }, { "cell_type": "code", "execution_count": 55, "id": "80e45d62-27d4-4ebe-a86c-56a9cc4ac0de", "metadata": { "tags": [] }, "outputs": [], "source": [ "data = np.load('/fsx/proj-fmri/ckadirt/b2m/data/encodec32khz_training_embeds_sorted.npy')" ] }, { "cell_type": "code", "execution_count": 58, "id": "ce72b901-e670-48ca-8c92-2221b46f6145", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "array([ 753., 609., 248., 248., 248., 1271., 248., 1350., 1350.,\n", " 1350., 1350., 367., 297., 1251., 702., 1103., 1106., 1600.,\n", " 1106., 457., 1638., 903., 1103., 1673., 1103., 902., 149.,\n", " 1544., 1458., 1544., 773., 1470., 1470., 890., 1457., 1038.,\n", " 2008., 1506., 457., 1126., 1047., 1103., 1933., 3., 560.,\n", " 714., 271., 1442., 1710., 949., 1508., 957., 685., 399.,\n", " 1103., 1667., 1555., 1529., 494., 1436., 1883., 29., 225.,\n", " 846., 773., 569., 677., 71., 888., 1693., 1401., 888.,\n", " 1016., 792., 569., 1590., 71., 1600., 314., 272., 1756.,\n", " 1917., 1264., 917., 1021., 178., 1205., 974., 457., 457.,\n", " 1106., 569., 1562., 271., 1977., 367., 345., 893., 1842.,\n", " 1401., 1152., 1152., 1152., 1152., 50., 1418., 1748., 1188.,\n", " 2043., 1666., 1796., 1512., 457., 812., 1600., 1764., 879.,\n", " 1194., 1457., 1842., 1883., 104., 666., 352., 612., 1710.,\n", " 1458., 315., 1990., 741., 1047., 675., 514., 1051., 1132.,\n", " 1115., 315., 1347., 1670., 1875., 1194., 71., 1786., 196.,\n", " 2043., 1818., 1477., 996., 1083., 967., 128., 1629., 1562.,\n", " 1875., 237., 712., 1279., 29., 675., 1207., 1303., 1622.,\n", " 1622., 1622., 1622., 1557., 675., 1303., 261., 675., 1245.,\n", " 1245., 675., 675., 714., 378., 1673., 1145., 1673., 1673.,\n", " 2013., 1990., 974., 457., 1124., 1562., 3., 1721., 846.,\n", " 378., 271., 271., 675., 1782., 876., 1918., 483., 1419.,\n", " 1693., 1562., 50., 1198., 1786., 1492., 1670., 1292., 104.,\n", " 1286., 620., 776., 828., 1629., 762., 71., 1095., 367.,\n", " 2043., 1501., 1152., 320., 1271., 801., 1671., 1418., 1213.,\n", " 71., 40., 1419., 1179., 178., 1106., 1016., 1652., 902.,\n", " 1074., 366., 1484., 42., 1457., 1453., 1241., 849., 1888.,\n", " 1775., 1888., 82., 1671., 836., 82., 82., 472., 247.,\n", " 1883., 29., 398., 642., 1904., 1436., 2002., 71., 71.,\n", " 71., 71., 938., 1122., 1106., 736., 367., 1531., 778.,\n", " 1388., 1949., 207., 1418., 721., 1130., 989., 1303., 1402.,\n", " 1047., 569., 569., 569., 272., 320., 548., 976., 314.,\n", " 1095., 1559., 1378., 828., 1742., 1378., 620., 1271., 776.,\n", " 801., 1213., 1358., 1883., 1157., 104., 1744., 648., 1271.,\n", " 1373., 1956., 685., 290., 938., 965., 208., 823., 1287.,\n", " 893., 1470., 775., 1158., 775., 1990., 986., 1152., 1358.,\n", " 312., 1111., 1564., 50., 237., 1646., 937., 1052., 1917.,\n", " 1742., 1702., 297., 259., 1734., 54., 1933., 71., 1875.,\n", " 1875., 1702., 1875., 237., 237., 237., 164., 1602., 220.,\n", " 457., 1798., 457., 1798., 1798., 1514., 548., 523., 949.,\n", " 1599., 714., 1673., 893., 714., 1016., 1016., 1638., 1562.,\n", " 714., 1016., 1103., 871., 569., 1047., 1047., 612., 1646.,\n", " 1875., 747., 714., 718., 1562., 1673., 1555., 1763., 1127.,\n", " 793., 1817., 657., 1106., 457., 1106., 458., 1599., 1106.,\n", " 29., 1863., 1103., 1599., 714., 812., 1194., 1508., 1194.,\n", " 1562., 986., 714., 1308., 1097., 1207., 1095., 1763., 1127.,\n", " 642., 1419., 290., 42., 1246., 400., 1462., 1194., 773.,\n", " 741., 832., 1368., 2042., 937., 71., 1541., 42., 2008.,\n", " 920., 872., 1473., 890., 234., 234., 1781., 1742., 1742.,\n", " 71., 508., 237., 1790., 1428., 347., 1907., 1642., 457.,\n", " 1106., 987., 1106., 1907., 1562., 1378., 1638., 1106., 1106.,\n", " 399., 1106., 1106., 569., 1194., 695., 1286., 237., 2043.,\n", " 1891., 1933., 1194., 1445., 1670., 1933., 1426., 1198., 937.,\n", " 1977., 677., 1907., 1907., 1492., 1514., 1798., 1562., 1823.,\n", " 644., 1629., 1842., 4., 712., 949., 686., 607., 1343.,\n", " 1907., 1907., 1106., 1973., 1051., 1974., 974., 1869., 913.,\n", " 1499., 272., 1322., 789., 457., 1869., 354., 45., 1869.,\n", " 846., 1103., 1393., 1798., 1424., 1766., 457., 1126., 1642.,\n", " 1775., 1775., 1144., 84., 84., 1142., 1549., 1992., 1989.,\n", " 1047., 500., 637., 1499., 1451., 1732., 297., 261., 1124.,\n", " 29., 866., 457., 149., 1608., 1047., 1343., 1433., 1047.,\n", " 1044., 1837., 1984., 1807., 869., 1598., 1188., 50., 569.,\n", " 272., 1401., 1670., 367., 744., 1817., 178., 1904., 1562.,\n", " 352., 71., 1670., 980., 1629., 736., 1130., 328., 302.,\n", " 893., 1378., 1652., 1469., 1786., 1742., 328., 677., 1652.,\n", " 1652., 1188., 2008., 1733., 1002., 1106., 1103., 464., 1052.,\n", " 1782., 1508., 1106., 677., 1481., 1666., 1921., 548., 1106.,\n", " 902., 1106., 893., 1106., 893., 1933., 893., 237., 1051.,\n", " 626., 919., 1106., 1484., 1288., 1103., 902., 986., 162.,\n", " 421., 1733., 281., 2031., 366., 494., 1594., 825., 730.,\n", " 1702., 642., 1478., 1366., 869., 1106., 1106., 1047., 320.,\n", " 1380., 569., 569., 569., 1380., 104., 1324., 1600., 1132.,\n", " 1337., 1047., 612., 1555., 2001., 2027., 2027., 71., 800.,\n", " 30., 2003., 1817., 1047., 1764., 714., 1457., 840., 1246.,\n", " 1978., 297., 1798., 1127., 1106., 948., 1047., 457., 1158.,\n", " 2001., 1308., 1316., 634., 1481., 1047., 762., 964., 1913.,\n", " 1638., 1562., 1582., 1095., 936., 658., 1978., 1512., 1753.,\n", " 1974., 1564., 2007., 890., 1562., 890., 1801., 1801., 1378.,\n", " 71., 1481., 1863., 921., 1600., 121., 648., 1132., 1132.,\n", " 1512., 1529., 263., 1529., 1564., 1484., 642., 642., 1629.,\n", " 40., 1444., 1078., 1078., 1436., 118., 118., 1522., 1904.,\n", " 237., 658., 658., 1711., 1095., 778., 121., 248., 999.,\n", " 648., 648., 247., 648., 648., 8., 8., 1453., 609.,\n", " 609., 83., 166.], dtype=float32)" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data[0][0]" ] }, { "cell_type": "code", "execution_count": 54, "id": "cf8382ae-ed92-43bb-a65c-7537161956d5", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "tensor([609, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237, 237,\n", " 237, 237, 237, 237, 237, 237, 237], device='cuda:0')" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "example1[0][0][0]" ] }, { "cell_type": "code", "execution_count": 23, "id": "f5842cd4-17b7-47c6-8939-31634f3db3cc", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/html": [ "
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n", "│ in <module>:1 │\n", "│ │\n", "│ ❱ 1 sampled.unsqueeze(0).unsqueeze(0).detach().clone() │\n", "│ 2 │\n", "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n", "RuntimeError: CUDA error: device-side assert triggered\n", "CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be \n", "incorrect.\n", "For debugging consider passing CUDA_LAUNCH_BLOCKING=1.\n", "Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n", "\n", "\n" ], "text/plain": [ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", "\u001b[31m│\u001b[0m in \u001b[92m
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n", "│ in <module>:1 │\n", "│ │\n", "│ ❱ 1 decoded = b2m.musicgen_decoder.audio_encoder.decode(audio_codes = sampled.unsqueeze(0).u │\n", "│ 2 │\n", "│ │\n", "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc │\n", "│ odec/modeling_encodec.py:742 in decode │\n", "│ │\n", "│ 739 │ │ if chunk_length is None: │\n", "│ 740 │ │ │ if len(audio_codes) != 1: │\n", "│ 741 │ │ │ │ raise ValueError(f\"Expected one frame, got {len(audio_codes)}\") │\n", "│ ❱ 742 │ │ │ audio_values = self._decode_frame(audio_codes[0], audio_scales[0]) │\n", "│ 743 │ │ else: │\n", "│ 744 │ │ │ decoded_frames = [] │\n", "│ 745 │\n", "│ │\n", "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc │\n", "│ odec/modeling_encodec.py:707 in _decode_frame │\n", "│ │\n", "│ 704 │ def _decode_frame(self, codes: torch.Tensor, scale: Optional[torch.Tensor] = None) - │\n", "│ 705 │ │ codes = codes.transpose(0, 1) │\n", "│ 706 │ │ embeddings = self.quantizer.decode(codes) │\n", "│ ❱ 707 │ │ outputs = self.decoder(embeddings) │\n", "│ 708 │ │ if scale is not None: │\n", "│ 709 │ │ │ outputs = outputs * scale.view(-1, 1, 1) │\n", "│ 710 │ │ return outputs │\n", "│ │\n", "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/module │\n", "│ .py:1501 in _call_impl │\n", "│ │\n", "│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │\n", "│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │\n", "│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │\n", "│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │\n", "│ 1502 │ │ # Do not call functions when jit is used │\n", "│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │\n", "│ 1504 │ │ backward_pre_hooks = [] │\n", "│ │\n", "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc │\n", "│ odec/modeling_encodec.py:336 in forward │\n", "│ │\n", "│ 333 │ │\n", "│ 334 │ def forward(self, hidden_states): │\n", "│ 335 │ │ for layer in self.layers: │\n", "│ ❱ 336 │ │ │ hidden_states = layer(hidden_states) │\n", "│ 337 │ │ return hidden_states │\n", "│ 338 │\n", "│ 339 │\n", "│ │\n", "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/module │\n", "│ .py:1501 in _call_impl │\n", "│ │\n", "│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │\n", "│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │\n", "│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │\n", "│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │\n", "│ 1502 │ │ # Do not call functions when jit is used │\n", "│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │\n", "│ 1504 │ │ backward_pre_hooks = [] │\n", "│ │\n", "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc │\n", "│ odec/modeling_encodec.py:162 in forward │\n", "│ │\n", "│ 159 │ │ │ # Asymmetric padding required for odd strides │\n", "│ 160 │ │ │ padding_right = padding_total // 2 │\n", "│ 161 │ │ │ padding_left = padding_total - padding_right │\n", "│ ❱ 162 │ │ │ hidden_states = self._pad1d( │\n", "│ 163 │ │ │ │ hidden_states, (padding_left, padding_right + extra_padding), mode=self. │\n", "│ 164 │ │ │ ) │\n", "│ 165 │\n", "│ │\n", "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc │\n", "│ odec/modeling_encodec.py:143 in _pad1d │\n", "│ │\n", "│ 140 │ │ if length <= max_pad: │\n", "│ 141 │ │ │ extra_pad = max_pad - length + 1 │\n", "│ 142 │ │ │ hidden_states = nn.functional.pad(hidden_states, (0, extra_pad)) │\n", "│ ❱ 143 │ │ padded = nn.functional.pad(hidden_states, paddings, mode, value) │\n", "│ 144 │ │ end = padded.shape[-1] - extra_pad │\n", "│ 145 │ │ return padded[..., :end] │\n", "│ 146 │\n", "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n", "RuntimeError: CUDA error: device-side assert triggered\n", "CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be \n", "incorrect.\n", "For debugging consider passing CUDA_LAUNCH_BLOCKING=1.\n", "Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n", "\n", "\n" ], "text/plain": [ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", "\u001b[31m│\u001b[0m in \u001b[92m
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n", "│ in <module>:2 │\n", "│ │\n", "│ 1 projected_pseudo_encoded_fmri = torch.rand((1,15,1024)) │\n", "│ ❱ 2 gepe = b2m.musicgen_decoder.generate( │\n", "│ 3 │ │ │ encoder_hidden_states = projected_pseudo_encoded_fmri, │\n", "│ 4 │ │ │ max_length = 752 │\n", "│ 5 │ │ ) │\n", "│ │\n", "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/utils/_contextlib │\n", "│ .py:115 in decorate_context │\n", "│ │\n", "│ 112 │ @functools.wraps(func) │\n", "│ 113 │ def decorate_context(*args, **kwargs): │\n", "│ 114 │ │ with ctx_factory(): │\n", "│ ❱ 115 │ │ │ return func(*args, **kwargs) │\n", "│ 116 │ │\n", "│ 117 │ return decorate_context │\n", "│ 118 │\n", "│ │\n", "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/mus │\n", "│ icgen/modeling_musicgen.py:2261 in generate │\n", "│ │\n", "│ 2258 │ │ generation_config = copy.deepcopy(generation_config) │\n", "│ 2259 │ │ model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be m │\n", "│ 2260 │ │ generation_config.validate() │\n", "│ ❱ 2261 │ │ self._validate_model_kwargs(model_kwargs.copy()) │\n", "│ 2262 │ │ │\n", "│ 2263 │ │ if model_kwargs.get(\"encoder_outputs\") is not None and type(model_kwargs[\"encode │\n", "│ 2264 │ │ │ # wrap the unconditional outputs as a BaseModelOutput for compatibility with │\n", "│ │\n", "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/generation │\n", "│ /utils.py:1249 in _validate_model_kwargs │\n", "│ │\n", "│ 1246 │ │ │ │ unused_model_args.append(key) │\n", "│ 1247 │ │ │\n", "│ 1248 │ │ if unused_model_args: │\n", "│ ❱ 1249 │ │ │ raise ValueError( │\n", "│ 1250 │ │ │ │ f\"The following `model_kwargs` are not used by the model: {unused_model_ │\n", "│ 1251 │ │ │ │ \" generate arguments will also show up in this list)\" │\n", "│ 1252 │ │ │ ) │\n", "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n", "ValueError: The following `model_kwargs` are not used by the model: ['encoder_hidden_states'] (note: typos in the \n", "generate arguments will also show up in this list)\n", "\n" ], "text/plain": [ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", "\u001b[31m│\u001b[0m in \u001b[92m
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n", "│ in <module>:1 │\n", "│ │\n", "│ ❱ 1 sampled.to('cpu') │\n", "│ 2 │\n", "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n", "RuntimeError: CUDA error: device-side assert triggered\n", "CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be \n", "incorrect.\n", "For debugging consider passing CUDA_LAUNCH_BLOCKING=1.\n", "Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n", "\n", "\n" ], "text/plain": [ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", "\u001b[31m│\u001b[0m in \u001b[92m
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n", "│ in <module>:4 │\n", "│ │\n", "│ 1 b2m = b2m.to('cpu') │\n", "│ 2 ss = torch.load('./samplet.pt').to('cpu') │\n", "│ 3 with torch.no_grad(): │\n", "│ ❱ 4 │ decoded = b2m.musicgen_decoder.audio_encoder.decode(audio_codes = ss.unsqueeze(0).un │\n", "│ 5 │\n", "│ │\n", "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc │\n", "│ odec/modeling_encodec.py:742 in decode │\n", "│ │\n", "│ 739 │ │ if chunk_length is None: │\n", "│ 740 │ │ │ if len(audio_codes) != 1: │\n", "│ 741 │ │ │ │ raise ValueError(f\"Expected one frame, got {len(audio_codes)}\") │\n", "│ ❱ 742 │ │ │ audio_values = self._decode_frame(audio_codes[0], audio_scales[0]) │\n", "│ 743 │ │ else: │\n", "│ 744 │ │ │ decoded_frames = [] │\n", "│ 745 │\n", "│ │\n", "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc │\n", "│ odec/modeling_encodec.py:706 in _decode_frame │\n", "│ │\n", "│ 703 │ │\n", "│ 704 │ def _decode_frame(self, codes: torch.Tensor, scale: Optional[torch.Tensor] = None) - │\n", "│ 705 │ │ codes = codes.transpose(0, 1) │\n", "│ ❱ 706 │ │ embeddings = self.quantizer.decode(codes) │\n", "│ 707 │ │ outputs = self.decoder(embeddings) │\n", "│ 708 │ │ if scale is not None: │\n", "│ 709 │ │ │ outputs = outputs * scale.view(-1, 1, 1) │\n", "│ │\n", "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc │\n", "│ odec/modeling_encodec.py:435 in decode │\n", "│ │\n", "│ 432 │ │ quantized_out = torch.tensor(0.0, device=codes.device) │\n", "│ 433 │ │ for i, indices in enumerate(codes): │\n", "│ 434 │ │ │ layer = self.layers[i] │\n", "│ ❱ 435 │ │ │ quantized = layer.decode(indices) │\n", "│ 436 │ │ │ quantized_out = quantized_out + quantized │\n", "│ 437 │ │ return quantized_out │\n", "│ 438 │\n", "│ │\n", "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc │\n", "│ odec/modeling_encodec.py:391 in decode │\n", "│ │\n", "│ 388 │ │ return embed_in │\n", "│ 389 │ │\n", "│ 390 │ def decode(self, embed_ind): │\n", "│ ❱ 391 │ │ quantize = self.codebook.decode(embed_ind) │\n", "│ 392 │ │ quantize = quantize.permute(0, 2, 1) │\n", "│ 393 │ │ return quantize │\n", "│ 394 │\n", "│ │\n", "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/enc │\n", "│ odec/modeling_encodec.py:372 in decode │\n", "│ │\n", "│ 369 │ │ return embed_ind │\n", "│ 370 │ │\n", "│ 371 │ def decode(self, embed_ind): │\n", "│ ❱ 372 │ │ quantize = nn.functional.embedding(embed_ind, self.embed) │\n", "│ 373 │ │ return quantize │\n", "│ 374 │\n", "│ 375 │\n", "│ │\n", "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/functional.py: │\n", "│ 2210 in embedding │\n", "│ │\n", "│ 2207 │ │ # torch.embedding_renorm_ │\n", "│ 2208 │ │ # remove once script supports set_grad_enabled │\n", "│ 2209 │ │ _no_grad_embedding_renorm_(weight, input, max_norm, norm_type) │\n", "│ ❱ 2210 │ return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse) │\n", "│ 2211 │\n", "│ 2212 │\n", "│ 2213 def embedding_bag( │\n", "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n", "IndexError: index out of range in self\n", "\n" ], "text/plain": [ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", "\u001b[31m│\u001b[0m in \u001b[92m