| | |
| | |
| |
|
| | |
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| |
|
| | import os |
| | import sys |
| | import json |
| | import argparse |
| | import numpy as np |
| | import math |
| | from einops import rearrange |
| | import time |
| | import random |
| | import string |
| | import h5py |
| | from tqdm import tqdm |
| |
|
| | import webdataset as wds |
| | import gc |
| |
|
| | import matplotlib.pyplot as plt |
| | import torch |
| | import torch.nn as nn |
| | from torchvision import transforms |
| |
|
| | from accelerate import Accelerator, DeepSpeedPlugin |
| |
|
| | |
| | torch.backends.cuda.matmul.allow_tf32 = True |
| |
|
| | |
| | import utils |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
# Rank of this process as exported through the RANK environment variable;
# an unset variable means rank 0.
local_rank = int(os.environ.get('RANK', 0))
print("LOCAL RANK ", local_rank)

# Number of visible CUDA devices; a machine without any counts as one device.
num_devices = torch.cuda.device_count() or 1
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
# Decide the global and per-device batch sizes.
if num_devices <= 1 and utils.is_interactive():
    # Single-device interactive session: use a fixed batch size and export the
    # environment variables a distributed launcher would normally provide.
    global_batch_size = batch_size = 128
    print(f"Setting batch_size to {batch_size}")

    os.environ["MASTER_ADDR"] = "localhost"
    # Random port in [9000, 19000).
    os.environ["MASTER_PORT"] = str(np.random.randint(10000)+9000)
    os.environ["RANK"] = "0"
    os.environ["LOCAL_RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"
    os.environ["GLOBAL_BATCH_SIZE"] = str(global_batch_size)
else:
    # GLOBAL_BATCH_SIZE must be exported by the launcher (KeyError otherwise).
    # Convert it once so global_batch_size is an int in both branches; it used
    # to stay a str here while being an int in the interactive branch.
    global_batch_size = int(os.environ["GLOBAL_BATCH_SIZE"])
    batch_size = global_batch_size // num_devices
| |
|
| | |
# Rank 0 rewrites deepspeed_config_stage2.json in place so its batch sizes
# match this run and fp16 (not bf16) is enabled; every other rank just waits.
if local_rank == 0:
    with open('deepspeed_config_stage2.json', 'r') as file:
        config = json.load(file)
    config['train_batch_size'] = int(os.environ["GLOBAL_BATCH_SIZE"])
    config['train_micro_batch_size_per_gpu'] = batch_size
    config['bf16'] = {'enabled': False}
    config['fp16'] = {'enabled': True}
    with open('deepspeed_config_stage2.json', 'w') as file:
        json.dump(config, file)
else:
    # Fixed 10 s pause; presumably gives rank 0 time to finish rewriting the
    # config file — there is no explicit synchronisation.
    time.sleep(10)
# NOTE(review): the plugin is built from "deepspeed_config_stage2_cpuoffload.json",
# which is not the file rewritten above — confirm this is intended.
deepspeed_plugin = DeepSpeedPlugin("deepspeed_config_stage2_cpuoffload.json")
accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)
| |
|
| |
|
| | |
| |
|
| |
|
# Report process / device information and cache a few accelerator properties.
print("PID of this process =", os.getpid())
device = accelerator.device
print("device:", device)
num_workers = num_devices
print(accelerator.state)
world_size = accelerator.state.num_processes
distributed = not (accelerator.state.distributed_type == 'NO')

# Tensor dtype matching the accelerator's mixed-precision mode
# (anything other than bf16 / fp16 falls back to float32).
data_type = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
}.get(accelerator.mixed_precision, torch.float32)

print(
    "distributed =", distributed,
    "num_devices =", num_devices,
    "local rank =", local_rank,
    "world size =", world_size,
    "data_type =", data_type,
)
# From here on, `print` is the accelerator's own print function.
print = accelerator.print
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| |
|
| | |
# Interactive (notebook) sessions have no command line, so build the argument
# list that is handed to the argument parser in place of sys.argv.
if utils.is_interactive():
    # Random 10-character run name, tagged as an interactive run.
    model_name = ''.join(random.choices(string.ascii_letters + string.digits, k=10))
    model_name = model_name + "_interactive"
    print("model_name:", model_name)

    # The trailing backslashes continue the f-string across source lines; the
    # extra whitespace this leaves in the string is removed by .split() below.
    jupyter_args = f"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \
                    --model_name={model_name} \
                    --subj=1 --batch_size={batch_size} --no-blurry_recon --no-depth_recon --hidden_dim=4096 \
                    --clip_scale=1. --blur_scale=100. --depth_scale=100. \
                    --max_lr=3e-4 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug --no-ckpt_saving"

    jupyter_args = jupyter_args.split()
    print(jupyter_args)

    # Enable IPython's autoreload extension in mode 2.
    from IPython.display import clear_output
    get_ipython().run_line_magic('load_ext', 'autoreload')
    get_ipython().run_line_magic('autoreload', '2')
| |
|
| |
|
| | |
| |
|
| |
|
parser = argparse.ArgumentParser(description="Model Training Configuration")

# Every command-line option in registration order, as
# (flag, keyword arguments for parser.add_argument).
_toggle = argparse.BooleanOptionalAction
_option_table = [
    ("--model_name", dict(
        type=str, default="testing",
        help="name of model, used for ckpt saving and wandb logging (if enabled)")),
    ("--data_path", dict(
        type=str, default="/fsx/proj-fmri/shared/natural-scenes-dataset",
        help="Path to where NSD data is stored / where to download it to")),
    ("--subj", dict(type=int, default=1, choices=[1,2,5,7])),
    ("--batch_size", dict(
        type=int, default=32,
        help="Batch size can be increased by 10x if only training v2c and not diffusion diffuser")),
    ("--wandb_log", dict(
        action=_toggle, default=False,
        help="whether to log to wandb")),
    ("--resume_from_ckpt", dict(
        action=_toggle, default=False,
        help="if not using wandb and want to resume from a ckpt")),
    ("--wandb_project", dict(
        type=str, default="stability",
        help="wandb project name")),
    ("--mixup_pct", dict(
        type=float, default=.33,
        help="proportion of way through training when to switch from BiMixCo to SoftCLIP")),
    ("--blurry_recon", dict(
        action=_toggle, default=True,
        help="whether to output blurry reconstructions")),
    ("--depth_recon", dict(
        action=_toggle, default=True,
        help="whether to output depth reconstructions")),
    ("--blur_scale", dict(
        type=float, default=100.,
        help="multiply loss from blurry recons by this number")),
    ("--depth_scale", dict(
        type=float, default=100.,
        help="multiply loss from depth recons by this number")),
    ("--clip_scale", dict(
        type=float, default=1.,
        help="multiply contrastive loss by this number")),
    ("--use_image_aug", dict(
        action=_toggle, default=True,
        help="whether to use image augmentation")),
    ("--num_epochs", dict(
        type=int, default=120,
        help="number of epochs of training")),
    ("--hidden_dim", dict(type=int, default=4096)),
    ("--lr_scheduler_type", dict(type=str, default='cycle', choices=['cycle','linear'])),
    ("--ckpt_saving", dict(action=_toggle, default=True)),
    ("--ckpt_interval", dict(
        type=int, default=5,
        help="save backup ckpt and reconstruct every x epochs")),
    ("--seed", dict(type=int, default=42)),
    ("--max_lr", dict(type=float, default=3e-4)),
]
for _flag, _kwargs in _option_table:
    parser.add_argument(_flag, **_kwargs)
# Drop the construction helpers so only `parser` remains at module level.
del _toggle, _option_table, _flag, _kwargs
| |
|
# Notebook sessions parse the synthetic argument list; scripts read sys.argv.
args = parser.parse_args(jupyter_args) if utils.is_interactive() else parser.parse_args()

# Publish every parsed option as a module-level variable of the same name.
for attribute_name, attribute_value in vars(args).items():
    globals()[attribute_name] = attribute_value
| |
|
| |
|
| | |
| |
|
| |
|
# Directory for this run's logs/checkpoints; created only when checkpoint
# saving is enabled and the path does not exist yet.
outdir = os.path.abspath(f'../train_logs/{model_name}')
if ckpt_saving and not os.path.exists(outdir):
    os.makedirs(outdir, exist_ok=True)

# Optional image augmentation pipeline (kornia is imported only when needed).
if use_image_aug:
    import kornia
    from kornia.augmentation.container import AugmentationSequential

    img_augment = AugmentationSequential(
        kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),
        kornia.augmentation.Resize((224, 224)),
        kornia.augmentation.RandomHorizontalFlip(p=0.3),
        kornia.augmentation.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.2,
            hue=0.1,
            p=0.3,
        ),
        kornia.augmentation.RandomGrayscale(p=0.3),
        same_on_batch=False,
        data_keys=["input"],
    )
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| | |
| |
|
| |
|
# Sample counts are hard-coded and only known for subject 1. Fail with a clear
# message for the other subjects the CLI accepts, rather than with a NameError
# on `num_test` a line later.
if subj==1:
    num_train = 24958
    num_test = 2770
else:
    raise ValueError(
        f"num_train / num_test are only defined for subj=1 (got subj={subj})"
    )
# The whole test set is loaded as a single batch.
test_batch_size = num_test

def my_split_by_node(urls): return urls


def _make_behav_dataset(url):
    """Build the WebDataset pipeline shared by the train and test loaders.

    Samples are shuffled with a fixed-seed RNG, decoded with the "torch"
    decoder and returned as (behav, past_behav, future_behav, olds_behav).
    """
    return (
        wds.WebDataset(url, resampled=False, nodesplitter=my_split_by_node)
        .shuffle(750, initial=1500, rng=random.Random(42))
        .decode("torch")
        .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")
        .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
    )


# Training shards 0.tar ... 36.tar (brace expansion is left to webdataset).
train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..36}.tar"
print(train_url)

train_data = _make_behav_dataset(train_url)
train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)

test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar"
print(test_url)

test_data = _make_behav_dataset(test_url)
test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=True, pin_memory=True)
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| |
|
# Walk the test loader once and collect, for every sample, entry [0, 5] and
# entry [0, 0] of its behav array (elsewhere in this file those columns are
# used to index `voxels` and `images` respectively).
# NOTE(review): indentation was lost in this export; the cast and the prints
# are assumed to run after the loop — confirm.
test_vox_indices = []
test_73k_images = []
for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
    test_vox_indices = np.append(test_vox_indices, behav[:,0,5].cpu().numpy())
    test_73k_images = np.append(test_73k_images, behav[:,0,0].cpu().numpy())
# NOTE(review): int16 tops out at 32767 — confirm all indices fit.
test_vox_indices = test_vox_indices.astype(np.int16)
print(test_i, (test_i+1) * test_batch_size, len(test_vox_indices))
print("---\n")

# Same collection pass over the training loader.
train_vox_indices = []
train_73k_images = []
for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
    train_vox_indices = np.append(train_vox_indices, behav[:,0,5].long().cpu().numpy())
    train_73k_images = np.append(train_73k_images, behav[:,0,0].cpu().numpy())
train_vox_indices = train_vox_indices.astype(np.int16)
print(train_i, (train_i+1) * batch_size, len(train_vox_indices))
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| |
|
| | |
# Load the subject's betas fully into memory. `[:]` copies the dataset into a
# NumPy array, so the HDF5 file can be closed as soon as the read finishes
# (the handle used to be left open).
with h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r') as f:
    voxels = f['betas'][:]
print(f"subj0{subj} betas loaded into memory")
voxels = torch.Tensor(voxels).to("cpu").to(data_type)
print("voxels", voxels.shape)
num_voxels = voxels.shape[-1]

# Load all images into memory the same way, closing the file afterwards.
with h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r') as f:
    images = f['images'][:]
images = torch.Tensor(images).to("cpu").to(data_type)
print("images", images.shape)
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| | |
| |
|
| |
|
from models import Clipper

# CLIP image embedder placed on the CUDA device indexed by this process's rank.
clip_model = Clipper(
    "ViT-L/14",
    device=torch.device(f"cuda:{local_rank}"),
    hidden_state=True,
    norm_embs=True,
)
# Sequence length and embedding width used for the CLIP targets.
clip_seq_dim, clip_emb_dim = 257, 768

seq_len = 1
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
# Frozen VQ autoencoder, loaded only when blurry reconstruction is enabled.
if blurry_recon:
    from diffusers import VQModel

    autoenc = VQModel.from_pretrained(
        "/fsx/proj-fmri/shared/cache/models--microsoft--vq-diffusion-ithq/snapshots/3f796fb49ee559370dc638dea1d8116af131d993/vqvae",
        torch_dtype=data_type,
    )
    # Inference only: eval mode, no gradients, moved to the training device.
    autoenc.eval()
    autoenc.requires_grad_(False)
    autoenc.to(device)
    utils.count_params(autoenc)
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| |
|
# Sanity check of the blurry-target pipeline on a single image (index 30):
# shrink to 8x8, blow back up to 128x128, encode with the autoencoder and,
# in interactive sessions, display the decoded round trip.
if blurry_recon:
    if utils.is_interactive(): display(utils.torch_to_Image(images[[30]]))

    input_batch = images[[30]].to(device)
    print(input_batch.shape)

    downsampled_image = nn.functional.interpolate(input_batch, size=(8, 8), mode='bilinear', align_corners=False)
    re_upsampled_image = nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest')
    # Inputs are rescaled with 2*x-1 before encoding; latents are scaled by 0.18215.
    re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215
    print(re_upsampled_enc.shape)

    # Undo the latent scaling, decode, and map the output back with /2 + 0.5.
    if utils.is_interactive(): display(utils.torch_to_Image((autoenc.decode(re_upsampled_enc/0.18215).sample / 2 + 0.5).clamp(0,1)))
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| |
|
# Frozen MiDaS depth estimator, loaded only when depth reconstruction is enabled.
if depth_recon:
    from controlnet_aux.midas import MidasDetector

    midas_depth = MidasDetector.from_pretrained(
        "valhalla/t2iadapter-aux-models",
        filename="dpt_large_384.pt",
        model_type="dpt_large",
        cache_dir="/fsx/proj-fmri/shared/cache",
    ).to(device)
    # Inference only: eval mode, no gradients, underlying model on the device.
    midas_depth.model.eval()
    midas_depth.model.requires_grad_(False)
    midas_depth.model.to(device)
| |
|
| |
|
| | |
| |
|
| |
|
# Sanity check of the depth-target pipeline on two images (indices 30 and 31).
# NOTE(review): indentation was lost in this export; the nested blurry_recon
# branch is assumed to sit inside the depth_recon branch because it reuses
# midas_emb — confirm.
if depth_recon:
    if utils.is_interactive(): display(utils.torch_to_Image(images[[30]]))

    input_batch = images[[30,31]].float().to(device)
    print(input_batch.shape)

    # Add a dimension at position 1 to the depth model's output.
    midas_emb = midas_depth.model(input_batch).unsqueeze(1)
    print(midas_emb.shape)

    prediction = utils.resize(midas_emb, 32)
    print(prediction.shape)

    # Divide each sample by its own maximum, then cast to half precision.
    prediction = (prediction / prediction.view(prediction.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(prediction)).half()
    midas_emb_size = prediction.flatten(1).shape[1]
    print("midas_emb", prediction.shape, prediction.min(), prediction.max())
    print("midas_emb_size", midas_emb_size)

    if utils.is_interactive(): display(utils.torch_to_Image(utils.resize(prediction, 224)))

    if blurry_recon:
        # Repeat the depth map three times along dim 1, normalise it per sample,
        # and push it through the autoencoder.
        prediction = utils.resize(midas_emb, 128).half().repeat(1,3,1,1)
        prediction = (prediction / prediction.view(prediction.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(prediction)).half()
        prediction_enc = autoenc.encode(2*prediction-1).latents * 0.18215
        print("vae midas_emb", prediction_enc.shape, prediction_enc.min(), prediction_enc.max())

        if utils.is_interactive(): display(utils.torch_to_Image((autoenc.decode(prediction_enc/0.18215).sample / 2 + 0.5).clamp(0,1)))
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| |
|
class MindEyeModule(nn.Module):
    """Empty container module; sub-networks are attached to it as attributes."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the input unchanged."""
        return x
| | |
# Container that the ridge, time-embedding and backbone sub-modules are
# attached to further down.
model = MindEyeModule()
# Bare expression: a no-op in a script, shows the repr as notebook cell output.
model

# Width of the time embedding that is concatenated to each voxel vector before
# the ridge layer.
time_embedding_dim = 512
| |
|
class RidgeRegression(torch.nn.Module):
    """A single linear layer mapping `input_size` features to `out_features`."""

    def __init__(self, input_size, out_features):
        super().__init__()
        self.out_features = out_features
        self.linear = torch.nn.Linear(
            in_features=input_size,
            out_features=out_features,
        )

    def forward(self, x):
        """Apply the linear projection to the last dimension of `x`."""
        projected = self.linear(x)
        return projected
| | |
# Ridge layer input = one voxel vector concatenated with one time embedding.
model.ridge = RidgeRegression(voxels.shape[1] + time_embedding_dim, out_features=hidden_dim)
utils.count_params(model.ridge)
utils.count_params(model)

# Smoke test: push a random (2, 1, voxels + time-embedding) tensor through the
# ridge layer and print the input / output shapes.
b = torch.randn((2,1,voxels.shape[1]))
time_emb_test = torch.randn((2,1,time_embedding_dim))
print(b.shape, model.ridge(torch.cat((b,time_emb_test),dim=-1)).shape)

num_past_voxels = 15
# Overrides the earlier seq_len = 1; set again to 4 later in the file.
seq_len = 1 + 1
| |
|
| |
|
| | |
| |
|
| |
|
class BrainNetwork(nn.Module):
    """Mixer-style backbone producing CLIP, blurry-image and depth outputs.

    The blurry and depth heads are created and evaluated only when the
    module-level globals `blurry_recon` / `depth_recon` are truthy; the class
    reads those globals directly instead of taking them as arguments.

    Args:
        out_dim: output width of the linear layer feeding the CLIP projector.
        in_dim: not used anywhere in this class body.
        seq_len: sequence length of the input (second dimension).
        h: feature width of the input (last dimension).
        n_blocks: number of (feature-mixing, sequence-mixing) block pairs.
        drop: dropout probability inside each mixer MLP.
        clip_size: CLIP embedding width; the CLIP output is reshaped to
            (batch, -1, clip_size).
    """
    def __init__(self, out_dim=768, in_dim=15724, seq_len=2, h=4096, n_blocks=4, drop=.15, clip_size=768):
        super().__init__()
        self.seq_len = seq_len
        self.h = h
        self.clip_size = clip_size

        # Blocks of the first list act on the feature axis (width h), blocks
        # of the second on the sequence axis (width seq_len).
        self.mixer_blocks1 = nn.ModuleList([
            self.mixer_block1(h, drop) for _ in range(n_blocks)
        ])
        self.mixer_blocks2 = nn.ModuleList([
            self.mixer_block2(seq_len, drop) for _ in range(n_blocks)
        ])

        # Flattened (seq_len * h) features -> CLIP output of width out_dim.
        self.clin1 = nn.Linear(h * seq_len, out_dim, bias=True)

        # MLP projector applied to each clip_size-wide token.
        self.clip_proj = nn.Sequential(
            nn.LayerNorm(clip_size),
            nn.GELU(),
            nn.Linear(clip_size, 2048),
            nn.LayerNorm(2048),
            nn.GELU(),
            nn.Linear(2048, 2048),
            nn.LayerNorm(2048),
            nn.GELU(),
            nn.Linear(2048, clip_size)
        )

        # NOTE(review): `Decoder` is not imported in the visible part of this
        # file — confirm it is in scope before enabling either head.
        if blurry_recon:
            # 4096 features are reshaped to (256, 4, 4) in forward().
            self.blin1 = nn.Linear(h*seq_len, 4096)
            self.bgroupnorm = nn.GroupNorm(1, 256)
            self.bupsampler = Decoder(
                in_channels=256,
                out_channels=128,
                up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
                block_out_channels=[32, 64, 128],
                layers_per_block=1,
            )

        if depth_recon:
            # Same layout as the blurry head, with a single output channel.
            self.dlin1 = nn.Linear(h*seq_len, 4096)
            self.dgroupnorm = nn.GroupNorm(1, 256)
            self.dupsampler = Decoder(
                in_channels=256,
                out_channels=1,
                up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
                block_out_channels=[32, 64, 128, 256],
                layers_per_block=1,
            )

    def mixer_block1(self, h, drop):
        """Return a LayerNorm + MLP block over a last dimension of width `h`."""
        return nn.Sequential(
            nn.LayerNorm(h),
            self.mlp(h, h, drop),
        )

    def mixer_block2(self, seq_len, drop):
        """Return a LayerNorm + MLP block over a last dimension of width `seq_len`."""
        return nn.Sequential(
            nn.LayerNorm(seq_len),
            self.mlp(seq_len, seq_len, drop)
        )

    def mlp(self, in_dim, out_dim, drop):
        """Return a Linear -> GELU -> Dropout -> Linear stack."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, x):
        """Run the mixer blocks and the enabled heads.

        Args:
            x: tensor laid out as (batch, seq_len, h) — the LayerNorm widths
               and the flatten before `clin1` require exactly this layout.

        Returns:
            Tuple (c, b, d): the CLIP output, the blurry-head output and the
            depth-head output. A disabled head returns the placeholder
            torch.Tensor([0.]) instead.
        """
        # Placeholders returned for heads that are switched off.
        b,d = torch.Tensor([0.]), torch.Tensor([0.])

        # Two residual streams: one in (batch, seq, h) layout, one transposed.
        residual1 = x
        residual2 = x.permute(0,2,1)
        for block1, block2 in zip(self.mixer_blocks1,self.mixer_blocks2):
            x = block1(x) + residual1
            residual1 = x
            x = x.permute(0,2,1)

            x = block2(x) + residual2
            residual2 = x
            x = x.permute(0,2,1)

        # Flatten the sequence and feature axes into one vector per sample.
        x = x.reshape(x.size(0), -1)

        c = self.clin1(x)

        # Split into tokens of width clip_size before the projector.
        c = self.clip_proj(c.reshape(len(c), -1, self.clip_size))

        if blurry_recon:
            b = self.blin1(x)
            b = b.reshape(len(b), 256, 4, 4)
            b = self.bgroupnorm(b)
            b = self.bupsampler(b)

        if depth_recon:
            d = self.dlin1(x)
            d = d.reshape(len(d), 256, 4, 4)
            d = self.dgroupnorm(d)
            d = self.dupsampler(d)

        return c, b, d
| |
|
| |
|
class TimeEmbedding(nn.Module):
    """Learned lookup table turning integer time indices into vectors."""

    def __init__(self, embedding_time_dim=512, num_past_voxels=15):
        super().__init__()
        self.num_past_voxels = num_past_voxels
        self.embedding_time_dim = embedding_time_dim
        # One row of width `embedding_time_dim` per index in [0, num_past_voxels).
        self.embedding_time = nn.Embedding(num_past_voxels, embedding_time_dim)

    def forward(self, time):
        """Cast `time` to int64 and return the matching embedding rows."""
        return self.embedding_time(time.long())
| | |
| |
|
| | |
# Attach the time-embedding table and the backbone to the container model.
model.time_embedding = TimeEmbedding(embedding_time_dim=512, num_past_voxels=15)

# NOTE(review): the backbone is built with h=1024 and seq_len=4, while the
# ridge layer outputs `hidden_dim` features and later code reshapes to
# (..., seq_len, hidden_dim) — confirm these agree for the configuration in use.
model.backbone = BrainNetwork(h=1024, in_dim=1024, seq_len=4, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim)
utils.count_params(model.backbone)
utils.count_params(model)

# Smoke test with a random (256, 4, 1024) batch matching the backbone above.
b = torch.randn((256,4,1024))
print("b.shape",b.shape)
with torch.no_grad():
    clip_, blur_, depth_ = model.backbone(b)
print(clip_.shape, blur_.shape, depth_.shape)

# Second smoke test using the module-level seq_len / hidden_dim.
# NOTE(review): at this point seq_len was last set to 2 and the view needs
# hidden_dim == 4096, which does not match the backbone built above — this
# looks like leftover scratch code; confirm.
voxel_ridge = torch.randn(512,4096)
voxel_ridge = voxel_ridge.view(int(voxel_ridge.shape[0]/seq_len), seq_len, hidden_dim)
print("b.shape",voxel_ridge.shape)
with torch.no_grad():
    clip_, blur_, depth_ = model.backbone(voxel_ridge)
print(clip_.shape, blur_.shape, depth_.shape)
| |
|
| |
|
| | |
| |
|
| |
|
# Parameter-name fragments that are excluded from weight decay in the backbone.
# NOTE(review): the backbone's LayerNorms live inside nn.Sequential containers,
# so their parameter names are numeric; the 'LayerNorm.*' patterns probably
# never match and only 'bias' takes effect — confirm.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
opt_grouped_parameters = [
    {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},
    {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
    {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    # The time embedding is used in the training forward pass but was missing
    # from every group, so it was never updated. It is appended as the last
    # group so the indices of the existing groups do not change. The decay
    # value mirrors the ridge group.
    {'params': list(model.time_embedding.parameters()), 'weight_decay': 1e-2},
]

optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)

# Total optimizer steps across all epochs (same formula for both schedulers).
total_steps = int(np.floor(num_epochs*(num_train/num_devices/batch_size)))

if lr_scheduler_type == 'linear':
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        total_iters=total_steps,
        last_epoch=-1
    )
elif lr_scheduler_type == 'cycle':
    print("total_steps", total_steps)
    # One-cycle schedule whose warm-up phase covers 2/num_epochs of the steps.
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=max_lr,
        total_steps=total_steps,
        final_div_factor=1000,
        last_epoch=-1, pct_start=2/num_epochs
    )
| | |
def save_ckpt(tag):
    """Save a training checkpoint to `<outdir>/<tag>.pth`.

    The checkpoint holds the current epoch, the unwrapped model's state dict,
    the optimizer and LR-scheduler state, and the loss / learning-rate
    histories, all read from module-level variables.

    Saving is best-effort: an ordinary failure is reported and training goes
    on. KeyboardInterrupt and SystemExit are no longer swallowed.

    Args:
        tag: file stem of the checkpoint (for example "last").
    """
    ckpt_path = outdir+f'/{tag}.pth'
    print(f'saving {ckpt_path}',flush=True)
    unwrapped_model = accelerator.unwrap_model(model)
    try:
        torch.save({
            'epoch': epoch,
            'model_state_dict': unwrapped_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'train_losses': losses,
            'test_losses': test_losses,
            'lrs': lrs,
            }, ckpt_path)
    except Exception as err:
        # Include the reason so a failed save can be diagnosed from the log.
        print(f"Couldn't save ({err!r})... moving on to prevent crashing.")
    del unwrapped_model
| | |
print("\nDone with model preparations!")
utils.count_params(model)

# Sequence length used from here on (replaces the earlier values 1 and 2).
seq_len = 4

# Scratch check of the reshape applied to the ridge output: 512 rows become
# (512 / seq_len, seq_len, hidden_dim).
# NOTE(review): the view only succeeds when hidden_dim == 4096 — confirm.
voxel_ridge = torch.randn(512,4096)
voxel_ridge = voxel_ridge.view(int(voxel_ridge.shape[0]/seq_len), seq_len, hidden_dim)

# Bare expression: a no-op in a script, shows the shape as notebook cell output.
voxel_ridge.shape
| |
|
| |
|
| | |
| |
|
| |
|
# Debug pass over a single training batch: fetch the past voxels, zero out the
# entries flagged in past_behav, print the intermediate values and stop.
# NOTE(review): indentation was lost in this export; the nesting below is the
# most plausible reconstruction — confirm.
pp = None
for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
    with torch.cuda.amp.autocast(dtype=data_type):

        voxel = voxels[behav[:,0,5].cpu().long()]
        image = images[behav[:,0,0].cpu().long()].float()

        # Voxel vectors of the seq_len-1 preceding entries.
        past_15_voxels = voxels[past_behav[:,:seq_len-1,5].cpu().long()]
        # NOTE(review): this uses range(seq_len) while the training loop uses
        # range(seq_len-1) — confirm which is intended.
        past_15_times = torch.Tensor([i for i in range(seq_len)])

        print(past_15_times)

        # The string literal below is disabled code kept as a bare expression;
        # it has no effect at runtime.
        """
        #if blurry_recon:
        # blurry_image_enc = autoenc.encode(2*utils.resize(image,128)-1).latent_dist.mode() * 0.18215
        blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215

        if depth_recon:
        # depth_images = utils.resize(midas_depth.model(image).unsqueeze(1).repeat(1,3,1,1), 128)
        depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32)
        depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half()
        depth_image_enc = depth_images # autoenc.encode(2*depth_images-1).latents * 0.18215

        if use_image_aug:
        image = img_augment(image)

        clip_target = clip_model.embed_image(image)
        assert not torch.any(torch.isnan(clip_target))

        if epoch < int(mixup_pct * num_epochs):
        voxel, perm, betas, select = utils.mixco(voxel)
        past_voxel, _, _, _ = utils.mixco(voxel, perm=perm, betas=betas, select=select)
        """
        for p in range(seq_len-1):
            print(past_behav.shape)
            print(past_behav[:,p,-1])
            print(past_15_voxels.shape)
            # Rows whose last past_behav column equals 1 get their past voxel
            # vector replaced by zeros.
            mask = past_behav[:,p,-1] == torch.ones_like(past_behav[:,p,-1])
            print(mask)
            past_15_voxels[mask, p, :] = torch.zeros_like(past_15_voxels[0, p, :])
            print(past_15_voxels)
        pp = past_15_voxels

    break

# Bare expression: a no-op in a script, shows one vector as notebook cell output.
pp[20, 0, :]
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| |
|
# Weights & Biases logging is set up on rank 0 only, and only when requested;
# in every other case wandb_log is forced to False.
if local_rank==0 and wandb_log:
    import wandb

    # NOTE(review): this replaces whatever --wandb_project was given on the
    # command line — confirm that is intended.
    wandb_project = 'mindeyev2'
    wandb_run = model_name
    wandb_notes = ''

    print(f"wandb {wandb_project} run {wandb_run}")
    wandb.login(host='https://stability.wandb.io')

    # Run configuration recorded alongside the metrics.
    wandb_config = dict(
        model_name=model_name,
        global_batch_size=global_batch_size,
        batch_size=batch_size,
        num_epochs=num_epochs,
        clip_scale=clip_scale,
        blur_scale=blur_scale,
        use_image_aug=use_image_aug,
        max_lr=max_lr,
        mixup_pct=mixup_pct,
        num_train=num_train,
        num_test=num_test,
        ckpt_interval=ckpt_interval,
        ckpt_saving=ckpt_saving,
        seed=seed,
        distributed=distributed,
        num_devices=num_devices,
        world_size=world_size,
        train_url=train_url,
        test_url=test_url,
    )
    print("wandb_config:\n",wandb_config)

    # The run id is the model name and resume is "allow".
    print("wandb_id:",model_name)
    wandb.init(
        id=model_name,
        project=wandb_project,
        name=wandb_run,
        config=wandb_config,
        notes=wandb_notes,
        resume="allow",
    )
else:
    wandb_log = False
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| |
|
# Fresh-run training state; replaced below when resuming from a checkpoint.
epoch = 0
losses, test_losses, lrs = [], [], []
best_test_loss = 1e9
# Temperature schedule covering the epochs after the mixup phase.
soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))


def _restore_last_ckpt():
    """Load last.pth (falling back to last_backup.pth) and return its epoch.

    Restores the optimizer, LR-scheduler and model state in place. Only
    ordinary exceptions trigger the fallback; if the backup also fails to
    load, the error propagates.
    """
    print("\n---resuming from last.pth ckpt---\n")
    try:
        checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
    except Exception as err:
        print(f'last.pth failed ({err!r})... trying last_backup.pth')
        checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
    resumed_epoch = checkpoint['epoch']
    print("Epoch",resumed_epoch)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    model.load_state_dict(checkpoint['model_state_dict'])
    return resumed_epoch


# Resume when asked explicitly, or when wandb reports the run as resumed.
# wandb is only touched when wandb_log is still True.
if resume_from_ckpt or (wandb_log and wandb.run.resumed):
    epoch = _restore_last_ckpt()
torch.cuda.empty_cache()
| |
|
| |
|
| | |
| |
|
| |
|
# Hand the model, optimizer, training loader and scheduler to the accelerator
# and continue with the objects it returns. The test loader is not passed in.
model, optimizer, train_dl, lr_scheduler = accelerator.prepare(
model, optimizer, train_dl, lr_scheduler
)
| | |
| |
|
| |
|
| | |
| |
|
| |
|
def add_saturation(image, alpha=2):
    """Scale the colour saturation of a batch of RGB images.

    Each pixel is extrapolated away from its grayscale value by `alpha`
    (alpha=1 returns the clamped input), then clamped to [0, 1].

    Args:
        image: 4-D tensor indexed as (batch, channel, height, width) with the
            red, green and blue planes at channel indices 0, 1 and 2.
        alpha: saturation factor.

    Returns:
        Tensor of the same shape as `image`, clamped to [0, 1].
    """
    red = image[:, 0, :, :]
    green = image[:, 1, :, :]
    blue = image[:, 2, :, :]
    # Weighted sum of the three colour planes gives the grayscale image.
    luminance = 0.2989 * red + 0.5870 * green + 0.1140 * blue
    luminance = luminance.unsqueeze(1).expand_as(image)
    boosted = alpha * image + (1 - alpha) * luminance
    return boosted.clamp(0, 1)
| |
|
| |
|
| | |
| |
|
| |
|
print(f"{model_name} starting with epoch {epoch} / {num_epochs}")
# Epoch iterator starting at the (possibly resumed) epoch; the bar is shown
# on rank 0 only.
progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))
# Test images / voxels are filled in lazily by the evaluation code on first use.
test_image, test_voxel = None, None
# Loss modules.
mse = nn.MSELoss()
l1 = nn.L1Loss()
|
| | for epoch in progress_bar: |
| | model.train() |
| | |
| | fwd_percent_correct = 0. |
| | bwd_percent_correct = 0. |
| | test_fwd_percent_correct = 0. |
| | test_bwd_percent_correct = 0. |
| |
|
| | loss_clip_total = 0. |
| | loss_blurry_total = 0. |
| | loss_depth_total = 0. |
| | test_loss_clip_total = 0. |
| | test_loss_blurry_total = 0. |
| | test_loss_depth_total = 0. |
| |
|
| | blurry_pixcorr = 0. |
| | test_blurry_pixcorr = 0. |
| | |
| | for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl): |
| | with torch.cuda.amp.autocast(dtype=data_type): |
| | optimizer.zero_grad() |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | voxel = voxels[behav[:,0,5].cpu().long()].to(device) |
| | image = images[behav[:,0,0].cpu().long()].to(device).float() |
| |
|
| | past_15_voxels = voxels[past_behav[:,:seq_len-1,5].cpu().long()].to(device) |
| | past_15_times = torch.Tensor([i for i in range(seq_len-1)]).to(device) |
| | |
| | |
| | |
| | if blurry_recon: |
| | |
| | blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215 |
| |
|
| | if depth_recon: |
| | |
| | depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32) |
| | depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half() |
| | depth_image_enc = depth_images |
| | |
| | if use_image_aug: |
| | image = img_augment(image) |
| | |
| | clip_target = clip_model.embed_image(image) |
| | assert not torch.any(torch.isnan(clip_target)) |
| | |
| | if epoch < int(mixup_pct * num_epochs): |
| | voxel, perm, betas, select = utils.mixco(voxel) |
| | past_voxel, _, _, _ = utils.mixco(voxel, perm=perm, betas=betas, select=select) |
| | |
| | for p in range(seq_len-1): |
| | |
| | |
| | |
| | mask = past_behav[:,p,-1] == torch.ones_like(past_behav[:,p,-1]) |
| | |
| | past_15_voxels[mask, p, :] = torch.zeros_like(past_15_voxels[0, p, :]) |
| | |
| | |
| | past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1]) |
| | past_15_times = past_15_times.repeat(voxel.shape[0], 1) |
| | past_15_times = past_15_times.reshape(-1) |
| | time_embeddings = model.time_embedding(past_15_times) |
| | past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1) |
| | |
| | positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device) |
| | voxel = torch.cat((voxel, positional_current_voxel), dim=-1) |
| | voxel_ridge = model.ridge(torch.cat((voxel, past_info_full), dim=-2)) |
| | voxel_ridge = voxel_ridge.view(int(voxel_ridge.shape[0]/seq_len), seq_len, hidden_dim) |
| | |
| | |
| | |
| | |
| | print(voxel_ridge.shape) |
| | |
| | clip_voxels, blurry_image_enc_, depth_image_enc_ = model.backbone(voxel_ridge) |
| | |
| | clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1) |
| | clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1) |
| | |
| | if epoch < int(mixup_pct * num_epochs): |
| | loss_clip = utils.mixco_nce( |
| | clip_voxels_norm, |
| | clip_target_norm, |
| | temp=.006, |
| | perm=perm, betas=betas, select=select) |
| | else: |
| | epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)] |
| | loss_clip = utils.soft_clip_loss( |
| | clip_voxels_norm, |
| | clip_target_norm, |
| | temp=epoch_temp) |
| |
|
| | loss_clip_total += loss_clip.item() |
| | loss_clip *= clip_scale |
| | loss = loss_clip |
| | |
| | if blurry_recon: |
| | downsampled_image = nn.functional.interpolate(image, size=(8, 8), mode='bilinear', align_corners=False) |
| | re_upsampled_image = add_saturation(nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest')) |
| | re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215 |
| | |
| | loss_blurry = (l1(blurry_image_enc_, blurry_image_enc) + l1(blurry_image_enc_, re_upsampled_enc)) |
| | loss_blurry += l1(torch.var(blurry_image_enc), torch.var(blurry_image_enc_)) |
| | loss_blurry_total += loss_blurry.item() |
| | loss_blurry *= blur_scale |
| | loss += loss_blurry |
| |
|
| | if depth_recon: |
| | loss_depth = l1(depth_image_enc_, depth_image_enc) |
| | |
| | loss_depth_total += loss_depth.item() |
| | loss_depth *= depth_scale |
| | loss += loss_depth |
| | |
| | |
| | labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device) |
| | fwd_percent_correct += utils.topk(torch.abs(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm)), labels, k=1).item() |
| | bwd_percent_correct += utils.topk(torch.abs(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm)), labels, k=1).item() |
| | |
| | if blurry_recon: |
| | with torch.no_grad(): |
| | |
| | random_samps = np.random.choice(np.arange(len(voxel)), size=batch_size//5, replace=False) |
| | |
| | blurry_recon_images = (autoenc.decode(blurry_image_enc_[random_samps]/0.18215).sample/ 2 + 0.5).clamp(0,1) |
| | |
| | pixcorr = utils.pixcorr_origsize_nanmean(image[random_samps], blurry_recon_images) |
| | |
| | |
| | blurry_pixcorr += pixcorr.item() |
| | |
| |
|
| | utils.check_loss(loss) |
| | accelerator.backward(loss) |
| | optimizer.step() |
| | |
| | losses.append(loss.item()) |
| | lrs.append(optimizer.param_groups[0]['lr']) |
| | |
| | if lr_scheduler_type is not None: |
| | lr_scheduler.step() |
| |
|
| | model.eval() |
| | if local_rank==0: |
| | with torch.no_grad(), torch.cuda.amp.autocast(dtype=data_type): |
| | for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl): |
| | |
| | assert len(behav) == num_test |
| | |
| | |
| | if test_image is None: |
| | voxel = voxels[behav[:,0,5].cpu().long()] |
| | image = behav[:,0,0].cpu().long() |
| | |
| | unique_image, sort_indices = torch.unique(image, return_inverse=True) |
| | for im in unique_image: |
| | locs = torch.where(im == image)[0] |
| | if test_image is None: |
| | test_image = images[im][None] |
| | test_voxel = torch.mean(voxel[locs],axis=0)[None] |
| | else: |
| | test_image = torch.vstack((test_image, images[im][None])) |
| | test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None])) |
| | |
| | |
| | random_indices = torch.arange(len(test_voxel))[:300] |
| | voxel = test_voxel[random_indices].to(device) |
| | image = test_image[random_indices].to(device) |
| | assert len(image) == 300 |
| | |
| | current_past_behav = past_behav[random_indices] |
| |
|
| | past_15_voxels = voxels[current_past_behav[:,:seq_len-1,5].cpu().long()].to(device) |
| | past_15_times = torch.Tensor([i for i in range(seq_len-1)]).to(device) |
| |
|
| | if blurry_recon: |
| | |
| | blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215 |
| |
|
| | if depth_recon: |
| | |
| | depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32) |
| | depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half() |
| | depth_image_enc = depth_images |
| | |
| | clip_target = clip_model.embed_image(image.float()) |
| | |
| |
|
| | past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1]) |
| | past_15_times = past_15_times.repeat(voxel.shape[0], 1) |
| | past_15_times = past_15_times.reshape(-1) |
| | time_embeddings = model.time_embedding(past_15_times) |
| | past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1) |
| |
|
| | positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device) |
| | voxel = torch.cat((voxel, positional_current_voxel), dim=-1) |
| | voxel_ridge = model.ridge(torch.cat((voxel, past_info_full), dim=-2)).unsqueeze(1) |
| | |
| | |
| |
|
| | |
| | |
| | clip_voxels, blurry_image_enc_, depth_image_enc_ = model.backbone(voxel_ridge) |
| | |
| | clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1) |
| | clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1) |
| | |
| | loss_clip = utils.soft_clip_loss( |
| | clip_voxels_norm, |
| | clip_target_norm, |
| | temp=.006) |
| | test_loss_clip_total += loss_clip.item() |
| | loss_clip = loss_clip * clip_scale |
| | loss = loss_clip |
| |
|
| | if blurry_recon: |
| | downsampled_image = nn.functional.interpolate(image, size=(8, 8), mode='bilinear', align_corners=False) |
| | re_upsampled_image = add_saturation(nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest')) |
| | re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215 |
| | |
| | loss_blurry = (l1(blurry_image_enc_, blurry_image_enc) + l1(blurry_image_enc_, re_upsampled_enc)) |
| | loss_blurry += l1(torch.var(blurry_image_enc), torch.var(blurry_image_enc_)) |
| | test_loss_blurry_total += loss_blurry.item() |
| | loss_blurry *= blur_scale |
| | loss += loss_blurry |
| | |
| | |
| | blurry_recon_images = (autoenc.decode(blurry_image_enc_[:len(voxel)//2]/0.18215).sample / 2 + 0.5).clamp(0,1) |
| | blurry_recon_images = torch.vstack((blurry_recon_images, (autoenc.decode(blurry_image_enc_[len(voxel)//2:]/0.18215).sample / 2 + 0.5).clamp(0,1))) |
| | pixcorr = utils.pixcorr(image, blurry_recon_images) |
| | loss += (1 - pixcorr) |
| | test_blurry_pixcorr += pixcorr.item() |
| |
|
| | if depth_recon: |
| | loss_depth = l1(depth_image_enc_, depth_image_enc) |
| | |
| | test_loss_depth_total += loss_depth.item() |
| | loss_depth *= depth_scale |
| | loss += loss_depth |
| | |
| | |
| | labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device) |
| | test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1).item() |
| | test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1).item() |
| |
|
| | utils.check_loss(loss) |
| | test_losses.append(loss.item()) |
| |
|
| | |
| | print("---") |
| | |
| | assert (test_i+1) == 1 |
| | logs = {"train/loss": np.mean(losses[-(train_i+1):]), |
| | "test/loss": np.mean(test_losses[-(test_i+1):]), |
| | "train/lr": lrs[-1], |
| | "train/num_steps": len(losses), |
| | "test/num_steps": len(test_losses), |
| | "train/fwd_pct_correct": fwd_percent_correct / (train_i + 1), |
| | "train/bwd_pct_correct": bwd_percent_correct / (train_i + 1), |
| | "test/test_fwd_pct_correct": test_fwd_percent_correct / (test_i + 1), |
| | "test/test_bwd_pct_correct": test_bwd_percent_correct / (test_i + 1), |
| | "train/loss_clip_total": loss_clip_total / (train_i + 1), |
| | "train/loss_blurry_total": loss_blurry_total / (train_i + 1), |
| | "test/loss_clip_total": test_loss_clip_total / (test_i + 1), |
| | "test/loss_blurry_total": test_loss_blurry_total / (test_i + 1), |
| | "train/blurry_pixcorr": blurry_pixcorr / (train_i + 1), |
| | "test/blurry_pixcorr": test_blurry_pixcorr / (test_i + 1), |
| | "train/loss_depth_total": loss_depth_total / (train_i + 1), |
| | "test/loss_depth_total": test_loss_depth_total / (test_i + 1), |
| | } |
| | |
| | if blurry_recon: |
| | |
| | fig, axes = plt.subplots(1, 8, figsize=(10, 4)) |
| | jj=-1 |
| | for j in [0,1,2,3]: |
| | jj+=1 |
| | axes[jj].imshow(utils.torch_to_Image((autoenc.decode(blurry_image_enc[[j]]/0.18215).sample / 2 + 0.5).clamp(0,1))) |
| | axes[jj].axis('off') |
| | jj+=1 |
| | axes[jj].imshow(utils.torch_to_Image((autoenc.decode(blurry_image_enc_[[j]]/0.18215).sample / 2 + 0.5).clamp(0,1))) |
| | axes[jj].axis('off') |
| | |
| | if wandb_log: |
| | logs[f"test/recons"] = wandb.Image(fig, caption=f"epoch{epoch:03d}") |
| | plt.close() |
| | else: |
| | plt.show() |
| |
|
| | if depth_recon: |
| | |
| | fig, axes = plt.subplots(1, 8, figsize=(10, 4)) |
| | |
| | |
| | jj=-1 |
| | for j in [0,1,2,3]: |
| | jj+=1 |
| | axes[jj].imshow(utils.torch_to_Image(utils.resize(depth_image_enc[[j]].view(1,1,32,32).clamp(0,1), 224))) |
| | axes[jj].axis('off') |
| | jj+=1 |
| | axes[jj].imshow(utils.torch_to_Image(utils.resize(depth_image_enc_[[j]].view(1,1,32,32).clamp(0,1), 224))) |
| | axes[jj].axis('off') |
| | if wandb_log: |
| | logs[f"test/depth_recons"] = wandb.Image(fig, caption=f"epoch{epoch:03d}") |
| | plt.close() |
| | else: |
| | plt.show() |
| | |
| | progress_bar.set_postfix(**logs) |
| | |
| | |
| | if epoch % ckpt_interval == 0: |
| | if not utils.is_interactive(): |
| | save_ckpt(f'last') |
| | |
| | if wandb_log: wandb.log(logs) |
| |
|
| | |
| | accelerator.wait_for_everyone() |
| | torch.cuda.empty_cache() |
| | gc.collect() |
| |
|
# End of training: announce completion, optionally persist the final weights,
# and stop the process when utils.is_interactive() reports a non-interactive run.
print("\n===Finished!===\n")

if ckpt_saving:
    save_ckpt('last')

running_as_script = not utils.is_interactive()
if running_as_script:
    sys.exit(0)
| |
|
| |
|
| | |
| |
|
| |
|
# Draw the recorded training-loss history and the test-loss history,
# each in its own figure.
for history in (losses, test_losses):
    plt.plot(history)
    plt.show()
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| |
|
# ---------------------------------------------------------------------------
# Caption retrieval: find the training images whose CLIP embeddings are closest
# to the embedding predicted from one test sample's brain activity, and use the
# caption of the best match as the predicted caption.
# ---------------------------------------------------------------------------

# Caption annotations; indexed below by the ids stored in train_73k_images.
annots = np.load("/fsx/proj-fmri/shared/mindeyev2_dataset/COCO_73k_annots_curated.npy")

ii = 2  # which test sample to inspect
all_indices = np.unique(train_73k_images)

# Embed every unique training image in chunks of 512. Chunks are gathered in a
# list and stacked once, instead of re-stacking (and recopying) the growing
# result on every iteration.
with torch.no_grad(), torch.cuda.amp.autocast():
    target_chunks = [
        clip_model.embed_image(images[all_indices[start:start + 512]]).cpu()
        for start in tqdm(range(0, len(all_indices), 512))
    ]
clip_target = torch.vstack(target_chunks)
clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)

voxel = test_voxel[[ii]].to(device)
image = test_image[[ii]].to(device)

print("Original Image (test set)")
display(utils.torch_to_Image(image))

# Predict the CLIP embedding from the voxels. This is inference only, so no
# autograd graph is recorded.
with torch.no_grad():
    voxel_ridge = model.ridge(voxel).unsqueeze(1)
    clip_voxels, _, _ = model.backbone(voxel_ridge)

# The query has to be the brain-predicted embedding. Substituting the seen
# image's own CLIP embedding here would turn the lookup into an image-to-image
# search that ignores the model output entirely.
clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)

print("clip_voxels_norm", clip_voxels_norm.shape)
print("clip_target_norm", clip_target_norm.shape)

# Compare on CPU in float32 so both operands share a dtype regardless of what
# precision the model and the autocast-ed CLIP encoder produced.
similarity = utils.batchwise_cosine_similarity(clip_voxels_norm.cpu().float(),
                                               clip_target_norm.float())
# Ascending argsort flipped -> most similar training image first.
sortt = torch.argsort(similarity.flatten()).flip(0)
picks = all_indices[sortt[:5]]

print("\nNearest neighbors in training set")
for p in picks:
    display(utils.torch_to_Image(images[[p]]))

# Caption of the single best match. Computed outside the loop so the name is
# always bound (an empty `picks` fails here, not at the print below).
predicted_caption = utils.select_annotations([annots[int(picks[0])]])[0]

print("\n=====\npredicted_caption:\n", predicted_caption)
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| |
|
from diffusers import StableDiffusionXLPipeline

# Load the Stable Diffusion XL pipeline from a local snapshot directory using
# the fp16 safetensors weights, then move it to the GPU.
sdxl_snapshot_dir = "/fsx/proj-fmri/shared/cache/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/f898a3e026e802f68796b95e9702464bac78d76f"
pipe = StableDiffusionXLPipeline.from_pretrained(
    sdxl_snapshot_dir,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
pipe.to("cuda")
# Presumably here so the call above is not the cell's trailing expression
# (which a notebook would echo) -- kept as in the original.
pass
| |
|
| |
|
| | |
| |
|
| |
|
# Generate one image with the SDXL pipeline, conditioned only on the caption
# retrieved for the test sample; keep the first returned image.
prompt = predicted_caption
sdxl_output = pipe(prompt=prompt)
recon = sdxl_output.images[0]
| |
|
| |
|
| | |
| |
|
| |
|
# Show the image the subject saw next to the caption-driven reconstruction.
print("Seen image")
display(utils.torch_to_Image(image))

print("Reconstruction")
# Convert the generated image to a tensor and resize it to 224 before display.
recon_tensor = utils.resize(transforms.ToTensor()(recon), 224)
# Display explicitly, like the seen image above. A bare expression is only
# rendered when it happens to be the last statement of a notebook cell and is
# silently discarded anywhere else.
display(utils.torch_to_Image(recon_tensor))
| |
|
| |
|