{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "f4d95fac-ac1d-473c-ab96-650f76e6aaf5", "metadata": { "tags": [] }, "outputs": [], "source": [ "# # Code to convert this notebook to .py if you want to run it via command line or with Slurm\n", "# from subprocess import call\n", "# command = \"jupyter nbconvert Train.ipynb --to python\"\n", "# call(command,shell=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "0e012513-880e-4f88-9680-013397af1c8f", "metadata": { "tags": [] }, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "b0f0f4f3", "metadata": { "tags": [] }, "source": [ "# Import packages & functions" ] }, { "cell_type": "code", "execution_count": 2, "id": "5bad764b-45c1-45ce-a716-8d055e09821a", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[2023-09-05 13:05:25,854] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" ] } ], "source": [ "import os\n", "import sys\n", "import json\n", "import argparse\n", "import numpy as np\n", "import time\n", "import random\n", "import h5py\n", "from tqdm import tqdm\n", "\n", "import webdataset as wds\n", "import gc\n", "\n", "import matplotlib.pyplot as plt\n", "import torch\n", "import torch.nn as nn\n", "from torchvision import transforms\n", "from flash_attn import flash_attn_qkvpacked_func, flash_attn_func\n", "\n", "# tf32 data type is faster than standard float32\n", "torch.backends.cuda.matmul.allow_tf32 = True\n", "\n", "# custom functions #\n", "import utils" ] }, { "cell_type": "code", "execution_count": 3, "id": "cc5d2e32-6027-4a19-bef4-5ca068db35bb", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LOCAL RANK 0\n", "[2023-09-05 13:05:34,712] [WARNING] [comm.py:152:init_deepspeed_backend] NCCL backend in DeepSpeed not yet implemented\n", "[2023-09-05 13:05:34,712] [INFO] [comm.py:594:init_distributed] cdb=None\n", "[2023-09-05 13:05:34,713] [INFO] [comm.py:625:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl\n" ] } ], "source": [ "local_rank = os.getenv('RANK')\n", "if local_rank is None: \n", " local_rank = 0\n", "else:\n", " local_rank = int(local_rank)\n", "print(\"LOCAL RANK \", local_rank)\n", "\n", "### Single-GPU config ###\n", "## Feel free to uncomment the below 4 lines and comment out all the multi-gpu config code to simplify things for single-gpu\n", "# from accelerate import Accelerator\n", "# num_devices = torch.cuda.device_count()\n", "# if num_devices==0: num_devices = 1\n", "# accelerator = Accelerator(split_batches=False)\n", "# 
global_batch_size = 128\n", " \n", "### Multi-GPU config ###\n", "from accelerate import Accelerator, DeepSpeedPlugin\n", "num_devices = torch.cuda.device_count()\n", "if num_devices==0: num_devices = 1\n", "if num_devices <= 1 and utils.is_interactive():\n", " # can emulate a distributed environment for deepspeed to work in jupyter notebook\n", " os.environ[\"MASTER_ADDR\"] = \"localhost\"\n", " os.environ[\"MASTER_PORT\"] = str(np.random.randint(10000)+9000)\n", " os.environ[\"RANK\"] = \"0\"\n", " os.environ[\"LOCAL_RANK\"] = \"0\"\n", " os.environ[\"WORLD_SIZE\"] = \"1\"\n", " os.environ[\"GLOBAL_BATCH_SIZE\"] = \"128\" # set this to your batch size!\n", " global_batch_size = os.environ[\"GLOBAL_BATCH_SIZE\"]\n", "\n", "# alter the deepspeed config according to your global and local batch size\n", "if local_rank == 0:\n", " with open('deepspeed_config_stage2.json', 'r') as file:\n", " config = json.load(file)\n", " config['train_batch_size'] = int(os.environ[\"GLOBAL_BATCH_SIZE\"])\n", " config['train_micro_batch_size_per_gpu'] = int(os.environ[\"GLOBAL_BATCH_SIZE\"]) // num_devices\n", " with open('deepspeed_config_stage2.json', 'w') as file:\n", " json.dump(config, file)\n", "else:\n", " # give some time for the local_rank=0 gpu to prep new deepspeed config file\n", " time.sleep(10)\n", "deepspeed_plugin = DeepSpeedPlugin(\"deepspeed_config_stage2.json\")\n", "accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)" ] }, { "cell_type": "code", "execution_count": 4, "id": "b767ab6f-d4a9-47a5-b3bf-f56bf6760c0c", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PID of this process = 804611\n", "device: cuda:0\n", "Distributed environment: DEEPSPEED Backend: nccl\n", "Num processes: 1\n", "Process index: 0\n", "Local process index: 0\n", "Device: cuda:0\n", "\n", "Mixed precision type: fp16\n", "ds_config: {'bf16': {'enabled': False}, 'fp16': {'enabled': True}, 'zero_optimization': 
{'stage': 2, 'contiguous_gradients': True, 'stage3_gather_16bit_weights_on_model_save': True, 'stage3_max_live_parameters': 1000000000.0, 'stage3_max_reuse_distance': 1000000000.0, 'stage3_prefetch_bucket_size': 10000000.0, 'stage3_param_persistence_threshold': 100000.0, 'reduce_bucket_size': 10000000.0, 'sub_group_size': 1000000000.0, 'offload_optimizer': {'device': 'none', 'nvme_path': '/scratch', 'pin_memory': True}, 'offload_param': {'device': 'none', 'nvme_path': '/scratch', 'buffer_size': 4000000000.0, 'pin_memory': True}}, 'aio': {'block_size': 26214400, 'queue_depth': 32, 'thread_count': 1, 'single_submit': False, 'overlap_events': True}, 'gradient_accumulation_steps': 1, 'gradient_clipping': 1.0, 'steps_per_print': inf, 'train_batch_size': 128, 'train_micro_batch_size_per_gpu': 128, 'wall_clock_breakdown': False, 'zero_allow_untested_optimizer': True}\n", "\n", "distributed = True num_devices = 1 local rank = 0 world size = 1\n" ] } ], "source": [ "print(\"PID of this process =\",os.getpid())\n", "device = accelerator.device\n", "print(\"device:\",device)\n", "num_workers = num_devices\n", "print(accelerator.state)\n", "world_size = accelerator.state.num_processes\n", "distributed = not accelerator.state.distributed_type == 'NO'\n", "print(\"distributed =\",distributed, \"num_devices =\", num_devices, \"local rank =\", local_rank, \"world size =\", world_size)\n", "print = accelerator.print # only print if local_rank=0" ] }, { "cell_type": "markdown", "id": "9018b82b-c054-4463-9527-4b0c2a75bda6", "metadata": { "tags": [] }, "source": [ "# Configurations" ] }, { "cell_type": "code", "execution_count": 5, "id": "2b61fec7-72a0-4b67-86da-1375f1d9fbd3", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset', '--model_name=test', '--subj=1', '--batch_size=128', '--n_samples_save=0', '--max_lr=5e-3', '--mixup_pct=.66', '--num_epochs=12', '--ckpt_interval=999', 
'--no-use_image_aug']\n" ] } ], "source": [ "# if running this interactively, can specify jupyter_args here for argparser to use\n", "if utils.is_interactive():\n", " # Example use\n", " jupyter_args = f\"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \\\n", " --model_name=test \\\n", " --subj=1 --batch_size={global_batch_size} --n_samples_save=0 \\\n", " --max_lr=5e-3 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug\"\n", "\n", " jupyter_args = jupyter_args.split()\n", " print(jupyter_args)\n", " \n", " from IPython.display import clear_output # function to clear print outputs in cell\n", " %load_ext autoreload \n", " # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions\n", " %autoreload 2 " ] }, { "cell_type": "code", "execution_count": 6, "id": "2028bdf0-2f41-46d9-b6e7-86b870dbf16c", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "global batch_size 128\n", "batch_size 128\n" ] } ], "source": [ "parser = argparse.ArgumentParser(description=\"Model Training Configuration\")\n", "parser.add_argument(\n", " \"--model_name\", type=str, default=\"testing\",\n", " help=\"name of model, used for ckpt saving and wandb logging (if enabled)\",\n", ")\n", "parser.add_argument(\n", " \"--data_path\", type=str, default=\"/fsx/proj-fmri/shared/natural-scenes-dataset\",\n", " help=\"Path to where NSD data is stored / where to download it to\",\n", ")\n", "parser.add_argument(\n", " \"--subj\",type=int, default=1, choices=[1,2,5,7],\n", ")\n", "parser.add_argument(\n", " \"--batch_size\", type=int, default=32,\n", " help=\"Batch size can be increased by 10x if only training v2c and not diffusion prior\",\n", ")\n", "parser.add_argument(\n", " \"--wandb_log\",action=argparse.BooleanOptionalAction,default=False,\n", " help=\"whether to log to wandb\",\n", ")\n", "parser.add_argument(\n", " 
\"--resume_from_ckpt\",action=argparse.BooleanOptionalAction,default=False,\n", " help=\"if not using wandb and want to resume from a ckpt\",\n", ")\n", "parser.add_argument(\n", " \"--wandb_project\",type=str,default=\"stability\",\n", " help=\"wandb project name\",\n", ")\n", "parser.add_argument(\n", " \"--mixup_pct\",type=float,default=.33,\n", " help=\"proportion of way through training when to switch from BiMixCo to SoftCLIP\",\n", ")\n", "parser.add_argument(\n", " \"--use_image_aug\",action=argparse.BooleanOptionalAction,default=True,\n", " help=\"whether to use image augmentation\",\n", ")\n", "parser.add_argument(\n", " \"--num_epochs\",type=int,default=240,\n", " help=\"number of epochs of training\",\n", ")\n", "parser.add_argument(\n", " \"--lr_scheduler_type\",type=str,default='cycle',choices=['cycle','linear'],\n", ")\n", "parser.add_argument(\n", " \"--ckpt_saving\",action=argparse.BooleanOptionalAction,default=True,\n", ")\n", "parser.add_argument(\n", " \"--ckpt_interval\",type=int,default=5,\n", " help=\"save backup ckpt and reconstruct every x epochs\",\n", ")\n", "parser.add_argument(\n", " \"--seed\",type=int,default=42,\n", ")\n", "parser.add_argument(\n", " \"--max_lr\",type=float,default=3e-4,\n", ")\n", "parser.add_argument(\n", " \"--n_samples_save\",type=int,default=0,choices=[0,1],\n", " help=\"Number of reconstructions for monitoring progress, 0 will speed up training\",\n", ")\n", "\n", "if utils.is_interactive():\n", " args = parser.parse_args(jupyter_args)\n", "else:\n", " args = parser.parse_args()\n", "\n", "# create global variables without the args prefix\n", "for attribute_name in vars(args).keys():\n", " globals()[attribute_name] = getattr(args, attribute_name)\n", "\n", "print(\"global batch_size\", batch_size)\n", "batch_size = int(batch_size / num_devices)\n", "print(\"batch_size\", batch_size)" ] }, { "cell_type": "code", "execution_count": 7, "id": "60cd7f2c-37fd-426b-a0c6-633e51bc4c4d", "metadata": { "tags": [] }, 
"outputs": [], "source": [ "outdir = os.path.abspath(f'../train_logs/{model_name}')\n", "if not os.path.exists(outdir):\n", " os.makedirs(outdir,exist_ok=True)\n", "if use_image_aug:\n", " import kornia\n", " from kornia.augmentation.container import AugmentationSequential\n", " img_augment = AugmentationSequential(\n", " kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),\n", " kornia.augmentation.Resize((224, 224)),\n", " kornia.augmentation.RandomHorizontalFlip(p=0.3),\n", " kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),\n", " kornia.augmentation.RandomGrayscale(p=0.3),\n", " same_on_batch=False,\n", " data_keys=[\"input\"],\n", " )" ] }, { "cell_type": "markdown", "id": "42d13c25-1369-4c49-81d4-83d713586096", "metadata": { "tags": [] }, "source": [ "# Prep data, models, and dataloaders" ] }, { "cell_type": "markdown", "id": "1c023f24-5233-4a15-a2f5-78487b3a8546", "metadata": {}, "source": [ "## Dataloader" ] }, { "cell_type": "code", "execution_count": 8, "id": "81084834-035f-4465-ad59-59e6b806a2f5", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/{0..36}.tar\n", "/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/test/0.tar\n" ] } ], "source": [ "# sample counts are only known for subj01; fail early with a clear message instead of a NameError on num_test\n", "if subj==1:\n", " num_train = 24958\n", " num_test = 2770\n", "else:\n", " raise NotImplementedError(f\"num_train/num_test are not set for subj0{subj}\")\n", "test_batch_size = num_test\n", "\n", "def my_split_by_node(urls): return urls\n", " \n", "train_url = f\"{data_path}/wds/subj0{subj}/train/\" + \"{0..36}.tar\"\n", "print(train_url)\n", "\n", "train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\\\n", " .shuffle(750, initial=1500, rng=random.Random(seed))\\\n", " .decode(\"torch\")\\\n", " .rename(behav=\"behav.npy\", past_behav=\"past_behav.npy\", future_behav=\"future_behav.npy\", olds_behav=\"olds_behav.npy\")\\\n", " .to_tuple(*[\"behav\", \"past_behav\", \"future_behav\", \"olds_behav\"])\n", "train_dl = 
torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True)\n", "\n", "test_url = f\"{data_path}/wds/subj0{subj}/test/\" + \"0.tar\"\n", "print(test_url)\n", "\n", "test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\\\n", " .shuffle(750, initial=1500, rng=random.Random(42))\\\n", " .decode(\"torch\")\\\n", " .rename(behav=\"behav.npy\", past_behav=\"past_behav.npy\", future_behav=\"future_behav.npy\", olds_behav=\"olds_behav.npy\")\\\n", " .to_tuple(*[\"behav\", \"past_behav\", \"future_behav\", \"olds_behav\"])\n", "test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=False, pin_memory=True)" ] }, { "cell_type": "markdown", "id": "203b060a-2dd2-4c35-929b-c576be82eb52", "metadata": {}, "source": [ "### check dataloaders are working" ] }, { "cell_type": "code", "execution_count": 9, "id": "e7a9c68c-c3c9-4080-bd99-067c4486dc37", "metadata": { "tags": [] }, "outputs": [], "source": [ "# test_indices = []\n", "# test_images = []\n", "# for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):\n", "# test_indices = np.append(test_indices, behav[:,0,5].numpy())\n", "# test_images = np.append(test_images, behav[:,0,0].numpy())\n", "# test_indices = test_indices.astype(np.int16)\n", "# print(test_i, (test_i+1) * test_batch_size, len(test_indices))\n", "# print(\"---\\n\")\n", "\n", "# train_indices = []\n", "# train_images = []\n", "# for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):\n", "# train_indices = np.append(train_indices, behav[:,0,5].long().numpy())\n", "# train_images = np.append(train_images, behav[:,0,0].numpy())\n", "# train_indices = train_indices.astype(np.int16)\n", "# print(train_i, (train_i+1) * batch_size, len(train_indices))" ] }, { "cell_type": "markdown", "id": "45fad12c-f9fb-4408-8fd4-9bca324ad634", "metadata": {}, "source": [ "## Load voxel betas, K-means 
clustering model, and images" ] }, { "cell_type": "code", "execution_count": 10, "id": "039dd330-7339-4f88-8f00-45f95e47baa0", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "subj01 betas loaded into memory\n", "voxels torch.Size([27750, 15729])\n", "images torch.Size([73000, 3, 224, 224])\n" ] } ], "source": [ "# load betas (the context manager closes the HDF5 handle once the array is in memory)\n", "with h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r') as f:\n", " voxels = f['betas'][:]\n", "print(f\"subj0{subj} betas loaded into memory\")\n", "voxels = torch.Tensor(voxels).to(\"cpu\").half()\n", "if subj==1:\n", " # NOTE(review): torch.zeros defaults to float32 and torch.hstack type-promotes, so this likely leaves voxels as float32 rather than float16 -- confirm which dtype downstream code expects\n", " voxels = torch.hstack((voxels, torch.zeros((len(voxels), 5))))\n", "print(\"voxels\", voxels.shape)\n", "num_voxels = voxels.shape[-1]\n", "\n", "# load orig images\n", "with h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r') as f:\n", " images = f['images'][:]\n", "# torch.from_numpy wraps the array without copying; torch.Tensor(...) would first materialize a float32 copy before .half()\n", "images = torch.from_numpy(images).to(\"cpu\").half()\n", "print(\"images\", images.shape)" ] }, { "cell_type": "code", "execution_count": 11, "id": "b0420dc0-199e-4c1a-857d-b1747058b467", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ViT-L/14 cuda:0\n" ] } ], "source": [ "from models import Clipper\n", "eva02_model = Clipper(\"ViT-L/14\", device=torch.device(f\"cuda:{local_rank}\"), hidden_state=True, norm_embs=True)\n", "\n", "clip_seq_dim = 257\n", "clip_emb_dim = 768\n", "hidden_dim = 4096" ] }, { "cell_type": "code", "execution_count": 12, "id": "c44c271b-173f-472e-b059-a2eda0f4c4c5", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "MindEyeModule()" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class MindEyeModule(nn.Module):\n", " def __init__(self):\n", " super(MindEyeModule, self).__init__()\n", " def forward(self, x):\n", " return x\n", " \n", "model = MindEyeModule()\n", "model" ] }, { "cell_type": "code", "execution_count": 13, "id": "038a5d61-4769-40b9-a004-f4e7b5b38bb0", "metadata": { 
"tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "param counts:\n", "64,430,080 total\n", "64,430,080 trainable\n", "param counts:\n", "64,430,080 total\n", "64,430,080 trainable\n", "torch.Size([2, 15729]) torch.Size([2, 4096])\n" ] } ], "source": [ "class RidgeRegression(torch.nn.Module):\n", " # make sure to add weight_decay when initializing optimizer\n", " def __init__(self, input_size, out_features): \n", " super(RidgeRegression, self).__init__()\n", " self.linear = torch.nn.Linear(input_size, out_features)\n", " def forward(self, x):\n", " return self.linear(x)\n", " \n", "model.ridge = RidgeRegression(voxels.shape[1], out_features=hidden_dim)\n", "utils.count_params(model.ridge)\n", "utils.count_params(model)\n", "\n", "b = torch.randn((2,voxels.shape[1]))\n", "print(b.shape, model.ridge(b).shape)" ] }, { "cell_type": "code", "execution_count": 14, "id": "8c7f2d47-08a4-40d9-ba63-a6b11c559d42", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "\"from functools import partial\\nclass BrainNetwork(nn.Module):\\n def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, use_projector=True, drop1=.5, drop2=.15):\\n super().__init__()\\n norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)\\n act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU\\n act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)\\n self.mlp = nn.ModuleList([\\n nn.Sequential(\\n nn.Linear(h, h),\\n *[item() for item in act_and_norm],\\n nn.Dropout(drop2)\\n ) for _ in range(n_blocks)\\n ])\\n self.lin1 = nn.Linear(h, out_dim, bias=True)\\n self.n_blocks = n_blocks\\n self.clip_size = clip_size\\n self.use_projector = use_projector\\n if use_projector:\\n self.projector = nn.Sequential(\\n nn.LayerNorm(clip_size),\\n nn.GELU(),\\n nn.Linear(clip_size, 2048),\\n 
nn.LayerNorm(2048),\\n nn.GELU(),\\n nn.Linear(2048, 2048),\\n nn.LayerNorm(2048),\\n nn.GELU(),\\n nn.Linear(2048, clip_size)\\n )\\n \\n def forward(self, x):\\n residual = x\\n for res_block in range(self.n_blocks):\\n x = self.mlp[res_block](x)\\n x += residual\\n residual = x\\n print(x.shape)\\n x = x.reshape(len(x), -1)\\n print(x.shape)\\n x = self.lin1(x)\\n print(x.shape)\\n if self.use_projector:\\n return self.projector(x.reshape(len(x), -1, self.clip_size))\\n return x\\n\\nmodel.backbone = BrainNetwork(in_dim=hidden_dim, clip_size=clip_emb_dim, out_dim=clip_seq_dim*clip_emb_dim, use_projector=True)\\nutils.count_params(model.backbone)\\nutils.count_params(model)\\n\\nb = torch.randn((2,hidden_dim))\\nprint(b.shape, model.backbone(b).shape)\"" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\"\"\"from functools import partial\n", "class BrainNetwork(nn.Module):\n", " def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, use_projector=True, drop1=.5, drop2=.15):\n", " super().__init__()\n", " norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)\n", " act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU\n", " act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)\n", " self.mlp = nn.ModuleList([\n", " nn.Sequential(\n", " nn.Linear(h, h),\n", " *[item() for item in act_and_norm],\n", " nn.Dropout(drop2)\n", " ) for _ in range(n_blocks)\n", " ])\n", " self.lin1 = nn.Linear(h, out_dim, bias=True)\n", " self.n_blocks = n_blocks\n", " self.clip_size = clip_size\n", " self.use_projector = use_projector\n", " if use_projector:\n", " self.projector = nn.Sequential(\n", " nn.LayerNorm(clip_size),\n", " nn.GELU(),\n", " nn.Linear(clip_size, 2048),\n", " nn.LayerNorm(2048),\n", " nn.GELU(),\n", " nn.Linear(2048, 2048),\n", " nn.LayerNorm(2048),\n", " 
nn.GELU(),\n", " nn.Linear(2048, clip_size)\n", " )\n", " \n", " def forward(self, x):\n", " residual = x\n", " for res_block in range(self.n_blocks):\n", " x = self.mlp[res_block](x)\n", " x += residual\n", " residual = x\n", " print(x.shape)\n", " x = x.reshape(len(x), -1)\n", " print(x.shape)\n", " x = self.lin1(x)\n", " print(x.shape)\n", " if self.use_projector:\n", " return self.projector(x.reshape(len(x), -1, self.clip_size))\n", " return x\n", "\n", "model.backbone = BrainNetwork(in_dim=hidden_dim, clip_size=clip_emb_dim, out_dim=clip_seq_dim*clip_emb_dim, use_projector=True)\n", "utils.count_params(model.backbone)\n", "utils.count_params(model)\n", "\n", "b = torch.randn((2,hidden_dim))\n", "print(b.shape, model.backbone(b).shape)\"\"\"" ] }, { "cell_type": "code", "execution_count": 15, "id": "863fcb22-f588-480f-ad1c-14bcda9130ef", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "param counts:\n", "883,253,376 total\n", "883,253,376 trainable\n", "param counts:\n", "947,683,456 total\n", "947,683,456 trainable\n" ] } ], "source": [ "from functools import partial\n", "class BrainNetwork(nn.Module):\n", " def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, use_projector=True, drop1=.5, drop2=.15):\n", " super().__init__()\n", " norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)\n", " act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU\n", " act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)\n", " self.mlp = nn.ModuleList([\n", " nn.Sequential(\n", " nn.Linear(h, h),\n", " *[item() for item in act_and_norm],\n", " nn.Dropout(drop2)\n", " ) for _ in range(n_blocks)\n", " ])\n", " self.lin1 = nn.Linear(h, out_dim, bias=True)\n", " self.n_blocks = n_blocks\n", " self.clip_size = clip_size\n", " self.use_projector = use_projector\n", " if 
use_projector:\n", " self.projector = nn.Sequential(\n", " nn.LayerNorm(clip_size),\n", " nn.GELU(),\n", " nn.Linear(clip_size, 2048),\n", " nn.LayerNorm(2048),\n", " nn.GELU(),\n", " nn.Linear(2048, 2048),\n", " nn.LayerNorm(2048),\n", " nn.GELU(),\n", " nn.Linear(2048, clip_size)\n", " )\n", " \n", " def forward(self, x):\n", " residual = x\n", " for res_block in range(self.n_blocks):\n", " x = self.mlp[res_block](x)\n", " x += residual\n", " residual = x\n", " x = x.reshape(len(x), -1)\n", " x = self.lin1(x)\n", " if self.use_projector:\n", " return self.projector(x.reshape(len(x), -1, self.clip_size))\n", " return x\n", "\n", "from flash_attn import flash_attn_qkvpacked_func, flash_attn_func\n", "from einops import rearrange\n", "\n", "\n", "class FeedForward(nn.Module):\n", " def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.):\n", " super().__init__()\n", " self.net = nn.Sequential(\n", " nn.Linear(input_dim, hidden_dim),\n", " nn.GELU(),\n", " nn.Dropout(dropout),\n", " nn.Linear(hidden_dim, output_dim),\n", " nn.Dropout(dropout)\n", " )\n", " def forward(self, x):\n", " return self.net(x)\n", "\n", "class TransformerBlock(nn.Module):\n", " def __init__(self, input_embedding_dim = 16, sequence_length = 4096, num_heads = 8, dropout_attention = 0.3, dropout_residual = 0.3, dropout_pre_norm = 0.3, use_pre_norm = True, voxels_embeddings = False, embedding_in = False):\n", " super().__init__()\n", " self.attention = flash_attn_qkvpacked_func\n", " self.feed_forward = FeedForward(input_embedding_dim, input_embedding_dim*4, input_embedding_dim)\n", " self.norm1 = nn.LayerNorm(input_embedding_dim)\n", " self.norm2 = nn.LayerNorm(input_embedding_dim)\n", "\n", " self.dropout_attention = nn.Dropout(dropout_attention)\n", " self.dropout_residual = nn.Dropout(dropout_residual)\n", " self.dropout_pre_norm = nn.Dropout(dropout_pre_norm)\n", "\n", " self.use_pre_norm = use_pre_norm\n", " self.voxels_embeddings = voxels_embeddings\n", "\n", " if 
self.voxels_embeddings:\n", " self.voxels_embeddings_projection = nn.Linear(1, input_embedding_dim)\n", "\n", " self.num_heads = num_heads\n", " self.input_embedding_dim = input_embedding_dim\n", " self.sequence_length = sequence_length\n", " self.embedding_in = embedding_in\n", "\n", " #query, key, value projection for each head\n", " self.qkv_projection = nn.ModuleList([\n", " nn.Linear(input_embedding_dim, input_embedding_dim*3//self.num_heads, bias=False) for _ in range(num_heads)\n", " ])\n", "\n", " def forward(self, voxels):\n", " # x: (batch_size, voxels_shape)\n", " if not self.embedding_in:\n", " if self.voxels_embeddings:\n", " voxels = self.voxels_embeddings_projection(voxels.unsqueeze(-1)) # (batch_size, voxels_shape, input_embedding_dim)\n", " else:\n", " # reshape voxels to (batch_size, voxels_shape//input_embedding_dim, input_embedding_dim)\n", " voxels = rearrange(voxels, 'b (s i) -> b s i', s=self.sequence_length//self.input_embedding_dim, i=self.input_embedding_dim)\n", " # voxels: (batch_size, sequence_length, input_embedding_dim)\n", " voxels = self.dropout_pre_norm(voxels)\n", " voxels = self.norm1(voxels)\n", " voxels = self.dropout_attention(voxels)\n", " qkv = torch.stack([proj(voxels) for proj in self.qkv_projection], dim=3) # (batch_size, sequence_length, 3*head_dim, num_heads)\n", " # flash-attn needs fp16/bf16 inputs on the GPU; keep them on the same device as the activations\n", " qkv = rearrange(qkv, 'b t (hd kqv) h -> b t kqv h hd', kqv = 3).type(torch.float16).to(voxels.device)\n", " # attention dropout must be disabled outside training (model.eval()), matching nn.Dropout semantics\n", " attn_dropout_p = self.dropout_attention.p if self.training else 0.0\n", " qkv = self.attention(qkv, attn_dropout_p)\n", " qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], -1) # (batch_size, sequence_length, num_heads*head_dim)\n", " qkv = self.dropout_residual(qkv)\n", " voxels = voxels + qkv\n", " voxels = self.norm2(voxels)\n", " voxels = self.dropout_residual(voxels)\n", " voxels = self.feed_forward(voxels)\n", " return voxels\n", " \n", "\n", " \n", "\n", "# create model\n", "# TransformerBlockTest = TransformerBlock(input_embedding_dim=16, sequence_length=4096, num_heads=8, dropout_attention=0.3, 
dropout_residual=0.3, dropout_pre_norm=0.3, use_pre_norm=True, voxels_embeddings=True).to('cuda')\n", "# inputTest = torch.randn((2,4096)).to('cuda')\n", "# print(inputTest.shape, TransformerBlockTest(inputTest).shape)\n", "\n", "\n", "\n", "class BrainTransformerNetwork(nn.Module):\n", " def __init__(self, out_dim=768, in_dim=15724, embed_dim=16, n_blocks=4, num_heads=8, dropout_attention=0.3, dropout_residual=0.3, dropout_pre_norm=0.3, use_pre_norm=True, voxels_embeddings=True):\n", " super().__init__()\n", " self.out_dim = out_dim\n", " self.in_dim = in_dim\n", " self.embed_dim = embed_dim\n", " self.n_blocks = n_blocks\n", " self.num_heads = num_heads\n", " self.dropout_attention = dropout_attention\n", " self.dropout_residual = dropout_residual\n", " self.dropout_pre_norm = dropout_pre_norm\n", " self.use_pre_norm = use_pre_norm\n", " self.voxels_embeddings = voxels_embeddings\n", " self.transformer = nn.Sequential()\n", " for i in range(n_blocks):\n", " if voxels_embeddings and i == 0:\n", " self.transformer.add_module(f\"transformer_block_{i}\", TransformerBlock(input_embedding_dim=embed_dim, sequence_length=in_dim, num_heads=num_heads, dropout_attention=dropout_attention, dropout_residual=dropout_residual, dropout_pre_norm=dropout_pre_norm, use_pre_norm=use_pre_norm, voxels_embeddings=False, embedding_in = False))\n", " else:\n", " self.transformer.add_module(f\"transformer_block_{i}\", TransformerBlock(input_embedding_dim=embed_dim, sequence_length=in_dim//embed_dim, num_heads=num_heads, dropout_attention=dropout_attention, dropout_residual=dropout_residual, dropout_pre_norm=dropout_pre_norm, use_pre_norm=use_pre_norm, voxels_embeddings=False, embedding_in = True))\n", " \n", " #self.pre_head_lin = nn.Linear(in_dim//embed_dim * embed_dim, in_dim, bias = False)\n", " #self.head_lin = nn.Linear(in_dim, out_dim, bias = False)\n", " self.gelu = nn.GELU()\n", " self.BrainNetwork = BrainNetwork(in_dim=hidden_dim, clip_size=clip_emb_dim, 
out_dim=clip_seq_dim*clip_emb_dim, use_projector=True)\n", "\n", " def forward(self, x):\n", " x = self.transformer(x)\n", " x = rearrange(x, 'b s i -> b (s i)')\n", " x = self.gelu(x)\n", " x = self.BrainNetwork(x)\n", " return x\n", " \n", "import math\n", "# create model\n", "# TestBrainTransformerNetwork = BrainTransformerNetwork(out_dim=768, in_dim=4096, embed_dim=16, n_blocks=4, num_heads=8, dropout_attention=0.3, dropout_residual=0.3, dropout_pre_norm=0.3, use_pre_norm=True, voxels_embeddings=True).to('cuda')\n", "# inputTest = torch.randn((2,4096)).to('cuda')\n", "# print(inputTest.shape, TestBrainTransformerNetwork(inputTest).shape)\n", "\n", "#model.backbone = BrainNetwork(in_dim=hidden_dim, clip_size=clip_emb_dim, out_dim=clip_seq_dim*clip_emb_dim, use_projector=True)\n", "#utils.count_params(model.backbone)\n", "#utils.count_params(model)\n", "\n", "#b = torch.randn((2,hidden_dim))\n", "#print(b.shape, model.backbone(b).shape)\n", "\n", "model.backbone = BrainTransformerNetwork(out_dim=clip_seq_dim*clip_emb_dim, in_dim=4096, embed_dim=64, n_blocks=2, num_heads=16, dropout_attention=0.5, dropout_residual=0.4, dropout_pre_norm=0.4, use_pre_norm=True, voxels_embeddings=True).to('cuda')\n", "utils.count_params(model.backbone)\n", "utils.count_params(model)\n", "\n", "#b = torch.randn((2,hidden_dim)).to('cuda')\n", "#print(b.shape, model.backbone(b).shape)\n", "b = None" ] }, { "cell_type": "code", "execution_count": 16, "id": "1ce49f9e-2e43-42fb-8072-c991abfcce79", "metadata": { "tags": [] }, "outputs": [], "source": [ "#b = torch.randn((2,hidden_dim)).to('cuda')\n", "#print(b.shape, model.backbone(b).shape)" ] }, { "cell_type": "code", "execution_count": 17, "id": "e14d0482-dc42-43b9-9ce1-953c32f2c9c1", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Done with model preparations!\n" ] } ], "source": [ "no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n", "opt_grouped_parameters = [\n", " 
{'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},\n", " # match no_decay against '<ModuleClass>.<param>' (e.g. 'LayerNorm.weight'); raw named_parameters() names like 'norm1.weight' never contain 'LayerNorm'\n", " {'params': [p for m in model.backbone.modules() for pn, p in m.named_parameters(recurse=False) if not any(nd in f'{type(m).__name__}.{pn}' for nd in no_decay)], 'weight_decay': 2e-2},\n", " {'params': [p for m in model.backbone.modules() for pn, p in m.named_parameters(recurse=False) if any(nd in f'{type(m).__name__}.{pn}' for nd in no_decay)], 'weight_decay': 0.0},\n", "]\n", "\n", "optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)\n", "\n", "if lr_scheduler_type == 'linear':\n", " lr_scheduler = torch.optim.lr_scheduler.LinearLR(\n", " optimizer,\n", " total_iters=int(num_epochs*(num_train*num_devices//batch_size)),\n", " last_epoch=-1\n", " )\n", "elif lr_scheduler_type == 'cycle':\n", " # NOTE(review): train_dl uses drop_last=False, so one epoch yields ceil(num_train/batch_size) batches per process; confirm this count matches how often lr_scheduler.step() is called, since OneCycleLR raises once stepped past total_steps\n", " total_steps=int(num_epochs*(num_train*num_devices//batch_size))\n", " lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(\n", " optimizer, \n", " max_lr=max_lr,\n", " total_steps=total_steps,\n", " final_div_factor=10,\n", " last_epoch=-1, pct_start=2/num_epochs\n", " )\n", " \n", "def save_ckpt(tag): \n", " ckpt_path = outdir+f'/{tag}.pth'\n", " print(f'saving {ckpt_path}',flush=True)\n", " unwrapped_model = accelerator.unwrap_model(model)\n", " try:\n", " torch.save({\n", " 'epoch': epoch,\n", " 'model_state_dict': unwrapped_model.state_dict(),\n", " 'optimizer_state_dict': optimizer.state_dict(),\n", " 'lr_scheduler': lr_scheduler.state_dict(),\n", " 'train_losses': losses,\n", " 'test_losses': test_losses,\n", " 'lrs': lrs,\n", " }, ckpt_path)\n", " except Exception as err:\n", " print(f\"save_ckpt failed: {err!r}\")\n", " print(\"Couldn't save... 
moving on to prevent crashing.\")\n", " del unwrapped_model\n", " \n", "print(\"\\nDone with model preparations!\")" ] }, { "cell_type": "code", "execution_count": 18, "id": "62a7f9f3-aedb-4c9e-925b-64a3642b8c43", "metadata": { "tags": [] }, "outputs": [], "source": [ "wandb_log = True" ] }, { "cell_type": "markdown", "id": "983f458b-35b8-49f2-b6db-80296cece730", "metadata": {}, "source": [ "# Weights and Biases" ] }, { "cell_type": "code", "execution_count": 19, "id": "0a25a662-daa8-4de9-9233-8364800fcb6b", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "wandb mindeye2 run transformer test run\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mckadirt\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "wandb_config:\n", " {'model_name': 'test', 'batch_size': 128, 'num_epochs': 12, 'use_image_aug': False, 'max_lr': 0.005, 'lr_scheduler_type': 'cycle', 'mixup_pct': 0.66, 'num_train': 24958, 'num_test': 2770, 'seed': 42, 'distributed': True, 'num_devices': 1, 'world_size': 1}\n" ] }, { "data": { "text/html": [ "wandb version 0.15.9 is available! To upgrade, please run:\n", " $ pip install wandb --upgrade" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Tracking run with wandb version 0.15.5" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /fsx/proj-fmri/ckadirt/MindEyeV2/src/wandb/run-20230905_130625-6wdt860v" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run transformer test run to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://stability.wandb.io/ckadirt/mindeye2" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://stability.wandb.io/ckadirt/mindeye2/runs/6wdt860v" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "\n",
"# params for wandb\n",
"if local_rank==0 and wandb_log: # only use main process for wandb logging\n",
"    import wandb\n",
"    \n",
"    wandb_project = 'mindeye2'\n",
"    wandb_run = 'transformer test run'\n",
"    wandb_notes = ''\n",
"    # set True to pass id=model_name and resume=\"allow\" to wandb.init\n",
"    wandb_auto_resume = False\n",
"    \n",
"    print(f\"wandb {wandb_project} run {wandb_run}\")\n",
"    wandb.login(host='https://stability.wandb.io')#, relogin=True)\n",
"    wandb_config = {\n",
"        \"model_name\": model_name,\n",
"        \"batch_size\": batch_size,\n",
"        \"num_epochs\": num_epochs,\n",
"        \"use_image_aug\": use_image_aug,\n",
"        \"max_lr\": max_lr,\n",
"        \"lr_scheduler_type\": lr_scheduler_type,\n",
"        \"mixup_pct\": mixup_pct,\n",
"        \"num_train\": num_train,\n",
"        \"num_test\": num_test,\n",
"        \"seed\": seed,\n",
"        \"distributed\": distributed,\n",
"        \"num_devices\": num_devices,\n",
"        \"world_size\": world_size,\n",
"    }\n",
"    print(\"wandb_config:\\n\",wandb_config)\n",
"    # both modes share the same init arguments; auto-resume only adds a fixed run id\n",
"    wandb_init_kwargs = dict(\n",
"        project=wandb_project,\n",
"        name=wandb_run,\n",
"        config=wandb_config,\n",
"        notes=wandb_notes,\n",
"    )\n",
"    if wandb_auto_resume:\n",
"        print(\"wandb_id:\",model_name)\n",
"        wandb_init_kwargs.update(id=model_name, resume=\"allow\")\n",
"    wandb.init(**wandb_init_kwargs)\n",
"else:\n",
"    wandb_log = False" ] }, { "cell_type": "markdown", "id": "d5690151-2131-4918-b750-e869cbd1a8a8", "metadata": {}, "source": [ "# Main" ] }, { "cell_type": "code", "execution_count": 20, "id": "12de6387-6e18-4e4b-b5ce-a847d625330a", "metadata": { "tags": [] }, "outputs": [], 
"source": [ "epoch = 0\n",
"losses, test_losses, lrs = [], [], []\n",
"best_test_loss = 1e9\n",
"soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))\n",
"\n",
"resume_from_ckpt = False\n",
"\n",
"def load_ckpt():\n",
"    # Restore the state written by save_ckpt() from outdir/last.pth (falling back to\n",
"    # outdir/last_backup.pth) and return (epoch, train_losses, test_losses, lrs).\n",
"    # Weights are loaded into `model` itself: accelerator.prepare() has not run yet at\n",
"    # this point, and save_ckpt() stores the unwrapped model's state_dict.\n",
"    try:\n",
"        checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')\n",
"    except Exception:\n",
"        print('last.pth failed... trying last_backup.pth')\n",
"        checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')\n",
"    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
"    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n",
"    model.load_state_dict(checkpoint['model_state_dict'])\n",
"    return (checkpoint['epoch'],\n",
"            checkpoint.get('train_losses', []),\n",
"            checkpoint.get('test_losses', []),\n",
"            checkpoint.get('lrs', []))\n",
"\n",
"# Optionally resume from checkpoint #\n",
"if resume_from_ckpt:\n",
"    print(\"\\n---resuming from last.pth ckpt---\\n\")\n",
"    epoch, losses, test_losses, lrs = load_ckpt()\n",
"    print(\"Epoch\",epoch)\n",
"elif False: #wandb_log:\n",
"    if wandb.run.resumed:\n",
"        print(\"\\n---resuming from last.pth ckpt---\\n\")\n",
"        epoch, losses, test_losses, lrs = load_ckpt()\n",
"        print(\"Epoch\",epoch)\n",
"torch.cuda.empty_cache()" ] }, { "cell_type": "code", "execution_count": 21, "id": "99f09f76-4481-4133-b09a-a22b10dbc0c4", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2023-09-05 13:06:35,335] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed info: version=0.9.5, git-hash=unknown, git-branch=unknown\n", "[2023-09-05 13:06:35,733] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Flops Profiler Enabled: False\n", "[2023-09-05 13:06:35,735] [INFO] [logging.py:96:log_dist] [Rank 0] Removing param_group that has no 'params' in the client Optimizer\n", "[2023-09-05 13:06:35,735] [INFO] [logging.py:96:log_dist] [Rank 0] Using client Optimizer as basic optimizer\n", "[2023-09-05 13:06:35,737] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Basic Optimizer = AdamW\n", "[2023-09-05 13:06:35,737] [INFO] [utils.py:54:is_zero_supported_optimizer] Checking ZeRO support for optimizer=AdamW type=\n", "[2023-09-05 13:06:35,738] [INFO] [logging.py:96:log_dist] [Rank 0] Creating torch.float16 ZeRO stage 2 optimizer\n", "[2023-09-05 13:06:35,738] [INFO] [stage_1_and_2.py:133:__init__] Reduce bucket size 10000000\n", "[2023-09-05 13:06:35,738] [INFO] [stage_1_and_2.py:134:__init__] Allgather bucket size 500,000,000\n", "[2023-09-05 13:06:35,739] [INFO] [stage_1_and_2.py:135:__init__] CPU Offload: False\n", "[2023-09-05 13:06:35,739] [INFO] [stage_1_and_2.py:136:__init__] Round robin gradient partitioning: False\n", "Rank: 0 partition count [1, 1, 1] and sizes[(64430080, False), (883012608, False), (240768, False)] \n", "[2023-09-05 
13:06:37,941] [INFO] [utils.py:785:see_memory_usage] Before initializing optimizer states\n", "[2023-09-05 13:06:37,942] [INFO] [utils.py:786:see_memory_usage] MA 7.31 GB Max_MA 7.31 GB CA 7.33 GB Max_CA 7 GB \n", "[2023-09-05 13:06:37,943] [INFO] [utils.py:793:see_memory_usage] CPU Virtual Memory: used = 96.0 GB, percent = 8.6%\n", "[2023-09-05 13:06:38,117] [INFO] [utils.py:785:see_memory_usage] After initializing optimizer states\n", "[2023-09-05 13:06:38,118] [INFO] [utils.py:786:see_memory_usage] MA 14.37 GB Max_MA 24.48 GB CA 24.99 GB Max_CA 25 GB \n", "[2023-09-05 13:06:38,119] [INFO] [utils.py:793:see_memory_usage] CPU Virtual Memory: used = 96.0 GB, percent = 8.6%\n", "[2023-09-05 13:06:38,120] [INFO] [stage_1_and_2.py:488:__init__] optimizer state initialized\n", "[2023-09-05 13:06:38,267] [INFO] [utils.py:785:see_memory_usage] After initializing ZeRO optimizer\n", "[2023-09-05 13:06:38,268] [INFO] [utils.py:786:see_memory_usage] MA 14.37 GB Max_MA 14.37 GB CA 24.99 GB Max_CA 25 GB \n", "[2023-09-05 13:06:38,269] [INFO] [utils.py:793:see_memory_usage] CPU Virtual Memory: used = 96.0 GB, percent = 8.6%\n", "[2023-09-05 13:06:38,272] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Final Optimizer = AdamW\n", "[2023-09-05 13:06:38,273] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed using client LR scheduler\n", "[2023-09-05 13:06:38,273] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed LR Scheduler = None\n", "[2023-09-05 13:06:38,274] [INFO] [logging.py:96:log_dist] [Rank 0] step=0, skipped=0, lr=[0.00019999999999999966, 0.00019999999999999966, 0.00019999999999999966], mom=[(0.95, 0.999), (0.95, 0.999), (0.95, 0.999)]\n", "[2023-09-05 13:06:38,275] [INFO] [config.py:960:print] DeepSpeedEngine configuration:\n", "[2023-09-05 13:06:38,276] [INFO] [config.py:964:print] activation_checkpointing_config {\n", " \"partition_activations\": false, \n", " \"contiguous_memory_optimization\": false, \n", " \"cpu_checkpointing\": false, \n", " 
\"number_checkpoints\": null, \n", " \"synchronize_checkpoint_boundary\": false, \n", " \"profile\": false\n", "}\n", "[2023-09-05 13:06:38,276] [INFO] [config.py:964:print] aio_config ................... {'block_size': 26214400, 'queue_depth': 32, 'thread_count': 1, 'single_submit': False, 'overlap_events': True}\n", "[2023-09-05 13:06:38,277] [INFO] [config.py:964:print] amp_enabled .................. False\n", "[2023-09-05 13:06:38,277] [INFO] [config.py:964:print] amp_params ................... False\n", "[2023-09-05 13:06:38,278] [INFO] [config.py:964:print] autotuning_config ............ {\n", " \"enabled\": false, \n", " \"start_step\": null, \n", " \"end_step\": null, \n", " \"metric_path\": null, \n", " \"arg_mappings\": null, \n", " \"metric\": \"throughput\", \n", " \"model_info\": null, \n", " \"results_dir\": \"autotuning_results\", \n", " \"exps_dir\": \"autotuning_exps\", \n", " \"overwrite\": true, \n", " \"fast\": true, \n", " \"start_profile_step\": 3, \n", " \"end_profile_step\": 5, \n", " \"tuner_type\": \"gridsearch\", \n", " \"tuner_early_stopping\": 5, \n", " \"tuner_num_trials\": 50, \n", " \"model_info_path\": null, \n", " \"mp_size\": 1, \n", " \"max_train_batch_size\": null, \n", " \"min_train_batch_size\": 1, \n", " \"max_train_micro_batch_size_per_gpu\": 1.024000e+03, \n", " \"min_train_micro_batch_size_per_gpu\": 1, \n", " \"num_tuning_micro_batch_sizes\": 3\n", "}\n", "[2023-09-05 13:06:38,279] [INFO] [config.py:964:print] bfloat16_enabled ............. False\n", "[2023-09-05 13:06:38,279] [INFO] [config.py:964:print] checkpoint_parallel_write_pipeline False\n", "[2023-09-05 13:06:38,280] [INFO] [config.py:964:print] checkpoint_tag_validation_enabled True\n", "[2023-09-05 13:06:38,280] [INFO] [config.py:964:print] checkpoint_tag_validation_fail False\n", "[2023-09-05 13:06:38,281] [INFO] [config.py:964:print] comms_config ................. \n", "[2023-09-05 13:06:38,282] [INFO] [config.py:964:print] communication_data_type ...... 
None\n", "[2023-09-05 13:06:38,282] [INFO] [config.py:964:print] compression_config ........... {'weight_quantization': {'shared_parameters': {'enabled': False, 'quantizer_kernel': False, 'schedule_offset': 0, 'quantize_groups': 1, 'quantize_verbose': False, 'quantization_type': 'symmetric', 'quantize_weight_in_forward': False, 'rounding': 'nearest', 'fp16_mixed_quantize': False, 'quantize_change_ratio': 0.001}, 'different_groups': {}}, 'activation_quantization': {'shared_parameters': {'enabled': False, 'quantization_type': 'symmetric', 'range_calibration': 'dynamic', 'schedule_offset': 1000}, 'different_groups': {}}, 'sparse_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'row_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'head_pruning': {'shared_parameters': {'enabled': False, 'method': 'topk', 'schedule_offset': 1000}, 'different_groups': {}}, 'channel_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'layer_reduction': {'enabled': False}}\n", "[2023-09-05 13:06:38,283] [INFO] [config.py:964:print] curriculum_enabled_legacy .... False\n", "[2023-09-05 13:06:38,283] [INFO] [config.py:964:print] curriculum_params_legacy ..... False\n", "[2023-09-05 13:06:38,284] [INFO] [config.py:964:print] data_efficiency_config ....... {'enabled': False, 'seed': 1234, 'data_sampling': {'enabled': False, 'num_epochs': 1000, 'num_workers': 0, 'curriculum_learning': {'enabled': False}}, 'data_routing': {'enabled': False, 'random_ltd': {'enabled': False, 'layer_token_lr_schedule': {'enabled': False}}}}\n", "[2023-09-05 13:06:38,284] [INFO] [config.py:964:print] data_efficiency_enabled ...... False\n", "[2023-09-05 13:06:38,285] [INFO] [config.py:964:print] dataloader_drop_last ......... 
False\n", "[2023-09-05 13:06:38,285] [INFO] [config.py:964:print] disable_allgather ............ False\n", "[2023-09-05 13:06:38,286] [INFO] [config.py:964:print] dump_state ................... False\n", "[2023-09-05 13:06:38,287] [INFO] [config.py:964:print] dynamic_loss_scale_args ...... None\n", "[2023-09-05 13:06:38,287] [INFO] [config.py:964:print] eigenvalue_enabled ........... False\n", "[2023-09-05 13:06:38,288] [INFO] [config.py:964:print] eigenvalue_gas_boundary_resolution 1\n", "[2023-09-05 13:06:38,288] [INFO] [config.py:964:print] eigenvalue_layer_name ........ bert.encoder.layer\n", "[2023-09-05 13:06:38,289] [INFO] [config.py:964:print] eigenvalue_layer_num ......... 0\n", "[2023-09-05 13:06:38,289] [INFO] [config.py:964:print] eigenvalue_max_iter .......... 100\n", "[2023-09-05 13:06:38,290] [INFO] [config.py:964:print] eigenvalue_stability ......... 1e-06\n", "[2023-09-05 13:06:38,290] [INFO] [config.py:964:print] eigenvalue_tol ............... 0.01\n", "[2023-09-05 13:06:38,291] [INFO] [config.py:964:print] eigenvalue_verbose ........... False\n", "[2023-09-05 13:06:38,291] [INFO] [config.py:964:print] elasticity_enabled ........... False\n", "[2023-09-05 13:06:38,292] [INFO] [config.py:964:print] flops_profiler_config ........ {\n", " \"enabled\": false, \n", " \"recompute_fwd_factor\": 0.0, \n", " \"profile_step\": 1, \n", " \"module_depth\": -1, \n", " \"top_modules\": 1, \n", " \"detailed\": true, \n", " \"output_file\": null\n", "}\n", "[2023-09-05 13:06:38,293] [INFO] [config.py:964:print] fp16_auto_cast ............... False\n", "[2023-09-05 13:06:38,293] [INFO] [config.py:964:print] fp16_enabled ................. True\n", "[2023-09-05 13:06:38,294] [INFO] [config.py:964:print] fp16_master_weights_and_gradients False\n", "[2023-09-05 13:06:38,294] [INFO] [config.py:964:print] global_rank .................. 0\n", "[2023-09-05 13:06:38,295] [INFO] [config.py:964:print] grad_accum_dtype ............. 
None\n", "[2023-09-05 13:06:38,295] [INFO] [config.py:964:print] gradient_accumulation_steps .. 1\n", "[2023-09-05 13:06:38,296] [INFO] [config.py:964:print] gradient_clipping ............ 1.0\n", "[2023-09-05 13:06:38,296] [INFO] [config.py:964:print] gradient_predivide_factor .... 1.0\n", "[2023-09-05 13:06:38,297] [INFO] [config.py:964:print] hybrid_engine ................ enabled=False max_out_tokens=512 inference_tp_size=1 release_inference_cache=False pin_parameters=True tp_gather_partition_size=8\n", "[2023-09-05 13:06:38,297] [INFO] [config.py:964:print] initial_dynamic_scale ........ 65536\n", "[2023-09-05 13:06:38,298] [INFO] [config.py:964:print] load_universal_checkpoint .... False\n", "[2023-09-05 13:06:38,299] [INFO] [config.py:964:print] loss_scale ................... 0\n", "[2023-09-05 13:06:38,299] [INFO] [config.py:964:print] memory_breakdown ............. False\n", "[2023-09-05 13:06:38,300] [INFO] [config.py:964:print] mics_hierarchial_params_gather False\n", "[2023-09-05 13:06:38,300] [INFO] [config.py:964:print] mics_shard_size .............. -1\n", "[2023-09-05 13:06:38,301] [INFO] [config.py:964:print] monitor_config ............... tensorboard=TensorBoardConfig(enabled=False, output_path='', job_name='DeepSpeedJobName') wandb=WandbConfig(enabled=False, group=None, team=None, project='deepspeed') csv_monitor=CSVConfig(enabled=False, output_path='', job_name='DeepSpeedJobName') enabled=False\n", "[2023-09-05 13:06:38,302] [INFO] [config.py:964:print] nebula_config ................ {\n", " \"enabled\": false, \n", " \"persistent_storage_path\": null, \n", " \"persistent_time_interval\": 100, \n", " \"num_of_version_in_retention\": 2, \n", " \"enable_nebula_load\": true, \n", " \"load_path\": null\n", "}\n", "[2023-09-05 13:06:38,302] [INFO] [config.py:964:print] optimizer_legacy_fusion ...... False\n", "[2023-09-05 13:06:38,303] [INFO] [config.py:964:print] optimizer_name ............... 
None\n", "[2023-09-05 13:06:38,303] [INFO] [config.py:964:print] optimizer_params ............. None\n", "[2023-09-05 13:06:38,304] [INFO] [config.py:964:print] pipeline ..................... {'stages': 'auto', 'partition': 'best', 'seed_layers': False, 'activation_checkpoint_interval': 0}\n", "[2023-09-05 13:06:38,304] [INFO] [config.py:964:print] pld_enabled .................. False\n", "[2023-09-05 13:06:38,305] [INFO] [config.py:964:print] pld_params ................... False\n", "[2023-09-05 13:06:38,305] [INFO] [config.py:964:print] prescale_gradients ........... False\n", "[2023-09-05 13:06:38,306] [INFO] [config.py:964:print] scheduler_name ............... None\n", "[2023-09-05 13:06:38,307] [INFO] [config.py:964:print] scheduler_params ............. None\n", "[2023-09-05 13:06:38,307] [INFO] [config.py:964:print] sparse_attention ............. None\n", "[2023-09-05 13:06:38,308] [INFO] [config.py:964:print] sparse_gradients_enabled ..... False\n", "[2023-09-05 13:06:38,308] [INFO] [config.py:964:print] steps_per_print .............. inf\n", "[2023-09-05 13:06:38,309] [INFO] [config.py:964:print] train_batch_size ............. 128\n", "[2023-09-05 13:06:38,309] [INFO] [config.py:964:print] train_micro_batch_size_per_gpu 128\n", "[2023-09-05 13:06:38,310] [INFO] [config.py:964:print] use_node_local_storage ....... False\n", "[2023-09-05 13:06:38,310] [INFO] [config.py:964:print] wall_clock_breakdown ......... False\n", "[2023-09-05 13:06:38,311] [INFO] [config.py:964:print] world_size ................... 1\n", "[2023-09-05 13:06:38,312] [INFO] [config.py:964:print] zero_allow_untested_optimizer True\n", "[2023-09-05 13:06:38,312] [INFO] [config.py:964:print] zero_config .................. 
stage=2 contiguous_gradients=True reduce_scatter=True reduce_bucket_size=10000000 allgather_partitions=True allgather_bucket_size=500,000,000 overlap_comm=False load_from_fp32_weights=True elastic_checkpoint=False offload_param=DeepSpeedZeroOffloadParamConfig(device='none', nvme_path=PosixPath('/scratch'), buffer_count=5, buffer_size=4000000000, max_in_cpu=1,000,000,000, pin_memory=True) offload_optimizer=DeepSpeedZeroOffloadOptimizerConfig(device='none', nvme_path=PosixPath('/scratch'), buffer_count=4, pin_memory=True, pipeline=False, pipeline_read=False, pipeline_write=False, fast_init=False) sub_group_size=1000000000 cpu_offload_param=None cpu_offload_use_pin_memory=None cpu_offload=None prefetch_bucket_size=10000000 param_persistence_threshold=100000 model_persistence_threshold=sys.maxsize max_live_parameters=1000000000 max_reuse_distance=1000000000 gather_16bit_weights_on_model_save=True stage3_gather_fp16_weights_on_model_save=False ignore_unused_parameters=True legacy_stage1=False round_robin_gradients=False mics_shard_size=-1 mics_hierarchical_params_gather=False memory_efficient_linear=True\n", "[2023-09-05 13:06:38,313] [INFO] [config.py:964:print] zero_enabled ................. True\n", "[2023-09-05 13:06:38,313] [INFO] [config.py:964:print] zero_force_ds_cpu_optimizer .. True\n", "[2023-09-05 13:06:38,314] [INFO] [config.py:964:print] zero_optimization_stage ...... 
2\n", "[2023-09-05 13:06:38,314] [INFO] [config.py:950:print_user_config] json = {\n", " \"bf16\": {\n", " \"enabled\": false\n", " }, \n", " \"fp16\": {\n", " \"enabled\": true\n", " }, \n", " \"zero_optimization\": {\n", " \"stage\": 2, \n", " \"contiguous_gradients\": true, \n", " \"stage3_gather_16bit_weights_on_model_save\": true, \n", " \"stage3_max_live_parameters\": 1.000000e+09, \n", " \"stage3_max_reuse_distance\": 1.000000e+09, \n", " \"stage3_prefetch_bucket_size\": 1.000000e+07, \n", " \"stage3_param_persistence_threshold\": 1.000000e+05, \n", " \"reduce_bucket_size\": 1.000000e+07, \n", " \"sub_group_size\": 1.000000e+09, \n", " \"offload_optimizer\": {\n", " \"device\": \"none\", \n", " \"nvme_path\": \"/scratch\", \n", " \"pin_memory\": true\n", " }, \n", " \"offload_param\": {\n", " \"device\": \"none\", \n", " \"nvme_path\": \"/scratch\", \n", " \"buffer_size\": 4.000000e+09, \n", " \"pin_memory\": true\n", " }\n", " }, \n", " \"aio\": {\n", " \"block_size\": 2.621440e+07, \n", " \"queue_depth\": 32, \n", " \"thread_count\": 1, \n", " \"single_submit\": false, \n", " \"overlap_events\": true\n", " }, \n", " \"gradient_accumulation_steps\": 1, \n", " \"gradient_clipping\": 1.0, \n", " \"steps_per_print\": inf, \n", " \"train_batch_size\": 128, \n", " \"train_micro_batch_size_per_gpu\": 128, \n", " \"wall_clock_breakdown\": false, \n", " \"zero_allow_untested_optimizer\": true\n", "}\n" ] } ], "source": [ "model, optimizer, train_dl, test_dl, lr_scheduler = accelerator.prepare(\n", "model, optimizer, train_dl, test_dl, lr_scheduler\n", ")" ] }, { "cell_type": "code", "execution_count": 22, "id": "b4a3368c-e6ce-49cc-b970-ee3dba12dfcd", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "test starting with epoch 0 / 12\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/12 [00:00