diff --git a/.gitattributes b/.gitattributes index 6ecdafb819e5f2473034facb62862b3c25234ba4..fda315d816776633c29e95601c04cff35f665aad 100644 --- a/.gitattributes +++ b/.gitattributes @@ -135,3 +135,4 @@ Vaani/output_image2.png filter=lfs diff=lfs merge=lfs -text Vaani/sampleJSON.csv filter=lfs diff=lfs merge=lfs -text Vaani/sampleJSON.json filter=lfs diff=lfs merge=lfs -text tools/__pycache__/pynvml.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text +Vaani/VaaniLDM/samplesH/x0_0.png filter=lfs diff=lfs merge=lfs -text diff --git a/Vaani/SDFT/_2.ipynb b/Vaani/SDFT/_2.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..84d7154a385ff2bdab676a8b3c8a688c1b0aa39c --- /dev/null +++ b/Vaani/SDFT/_2.ipynb @@ -0,0 +1,675 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "aab59bea", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'cuda'" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "import torch.optim as optim\n", + "from torch.utils.data import Dataset, DataLoader\n", + "from torchvision import transforms\n", + "from torchvision.transforms import v2\n", + "from PIL import Image\n", + "from diffusers import StableDiffusionPipeline\n", + "from diffusers.optimization import get_scheduler\n", + "from torch import nn\n", + "import torch.nn.functional as F\n", + "import os\n", + "import pandas as pd\n", + "from tqdm import trange, tqdm\n", + "\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "device" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8f13b66f", + "metadata": {}, + "outputs": [], + "source": [ + "# import torch\n", + "# import torch.nn as nn\n", + "# import torch.nn.functional as F\n", + "\n", + "# audio_embed_dim = 1280\n", + "# output_dim = 768\n", + "# device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "\n", + "# context_projector = nn.Sequential(\n", + "# nn.Linear(audio_embed_dim, 320),\n", + "# nn.SiLU(),\n", + "# nn.Linear(320, output_dim)\n", + "# ).to(device).half()\n", + "\n", + "# # Dummy input\n", + "# audio_embedding = dummy_audio = torch.zeros(10, 1500, 1280, device=device, dtype=torch.float16)\n", + "# print(audio_embedding.shape) # [10, 1500, 1280]\n", + "\n", + "# # Project audio to [10, 1500, 768]\n", + "# projected = context_projector(audio_embedding)\n", + "# print(projected.shape) # [10, 1500, 768]\n", + "\n", + "# # Compute attention scores: reduce feature dim to scalar per time step\n", + "# attn_scores = projected.mean(dim=2) # [10, 1500]\n", + "# attn_weights = F.softmax(attn_scores, dim=1) # [10, 1500]\n", + "# attn_weights = attn_weights.unsqueeze(2) # [10, 1500, 1]\n", + "\n", + "# # Weighted average\n", + "# pooled = (projected * attn_weights).sum(dim=1, keepdim=True) # [10, 1, 768]\n", + "# print(pooled.shape) # Final shape: [10, 1, 768]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d32b7d9d", + "metadata": {}, + "outputs": [], + "source": [ + "# === Helpers ===\n", + "def walkDIR(folder_path, include=None):\n", + " file_list = []\n", + " for root, _, files in os.walk(folder_path):\n", + " for file in files:\n", + " if include is None or any(file.endswith(ext) for ext in include):\n", + " file_list.append(os.path.join(root, file))\n", + " print(\"Files found:\", len(file_list))\n", + " return file_list\n", + "\n", + "# === Dataset Class ===\n", + "class 
VaaniDataset(torch.utils.data.Dataset):\n", + " def __init__(self, files_paths, im_size):\n", + " self.files_paths = files_paths\n", + " self.im_size = im_size\n", + "\n", + " def __len__(self):\n", + " return len(self.files_paths)\n", + "\n", + " def __getitem__(self, idx):\n", + " # image = tv.io.read_image(self.files_paths[idx], mode=tv.io.ImageReadMode.RGB)\n", + " image = Image.open(self.files_paths[idx]).convert(\"RGB\")\n", + " image = v2.ToImage()(image)\n", + " # image = tv.io.decode_image(self.files_paths[idx], mode=tv.io.ImageReadMode.RGB)\n", + " image = v2.Resize((self.im_size, self.im_size))(image)\n", + " image = v2.ToDtype(torch.float32, scale=True)(image)\n", + " # image = 2*image - 1\n", + " return image\n", + "\n", + "\n", + "def create_dataloader(dataset, batch_size, debug=False, val_split=0.1, num_workers=4):\n", + " if debug:\n", + " s = 0.001\n", + " dataset, _ = torch.utils.data.random_split(dataset, [s, 1-s], torch.manual_seed(42))\n", + " print(\"Length of Train dataset:\", len(dataset))\n", + "\n", + " train_dataloader = DataLoader(\n", + " dataset, \n", + " batch_size=batch_size, \n", + " shuffle=True, \n", + " num_workers=num_workers,\n", + " pin_memory=True,\n", + " drop_last=True,\n", + " persistent_workers=True\n", + " )\n", + " \n", + " images = next(iter(train_dataloader))\n", + " print('Total Batches:', len(train_dataloader))\n", + " print('BATCH SHAPE:', images.shape)\n", + " return train_dataloader\n", + "\n", + "# === Audio Context Projector ===\n", + "# class AudioContextProjector(nn.Module):\n", + "# def __init__(self, audio_embed_dim):\n", + "# super().__init__()\n", + "# self.audio_embed_dim = audio_embed_dim\n", + "# self.context_projector = nn.Sequential(\n", + "# nn.Linear(audio_embed_dim, 320),\n", + "# nn.SiLU(),\n", + "# nn.Linear(320, 1)\n", + "# )\n", + "\n", + "# def forward(self, audio_embedding):\n", + "# if audio_embedding.size(-1) != self.audio_embed_dim:\n", + "# raise ValueError(f\"Expected audio embedding dim {self.audio_embed_dim}, got {audio_embedding.size(-1)}\")\n", + "# weights = self.context_projector(audio_embedding) # [B, T, 1]\n", + "# weights = torch.softmax(weights, dim=1) # [B, T, 1]\n", + "# pooled = (audio_embedding * weights).sum(dim=1) # [B, 1280]\n", + "# return pooled.unsqueeze(1) # [B, 1, 1280]\n", + "# class AudioContextProjector(nn.Module):\n", + "# def __init__(self, audio_embed_dim=1280, output_dim=768): # Add output_dim for flexibility\n", + "# super().__init__()\n", + "# self.audio_embed_dim = audio_embed_dim\n", + "# self.context_projector = nn.Sequential(\n", + "# nn.Linear(audio_embed_dim, 320),\n", + "# nn.SiLU(),\n", + "# nn.Linear(320, output_dim) # Output 768 to match UNet's expectation\n", + "# )\n", + "\n", + "# def forward(self, audio_embedding):\n", + "# if audio_embedding.size(-1) != self.audio_embed_dim:\n", + "# raise ValueError(f\"Expected audio embedding dim {self.audio_embed_dim}, got {audio_embedding.size(-1)}\")\n", + "# weights = self.context_projector(audio_embedding) # [B, T, 768]\n", + "# weights = torch.softmax(pooled, dim=1) # [B, T, 768]\n", + "# pooled = (audio_embedding * weights).sum(dim=1) # [B, 768]\n", + "# return pooled.unsqueeze(1) # [B, 1, 768]\n", + "class AudioContextProjector(nn.Module):\n", + " def __init__(self, audio_embed_dim=1280, output_dim=768):\n", + " super().__init__()\n", + " self.audio_embed_dim = audio_embed_dim\n", + " self.output_dim = output_dim\n", + " self.context_projector = nn.Sequential(\n", + " nn.Linear(audio_embed_dim, 320),\n", + " nn.SiLU(),\n", 
+ " nn.Linear(320, output_dim) # Output 768 to match UNet's expectation\n", + " )\n", + "\n", + " def forward(self, audio_embedding):\n", + " if audio_embedding.size(-1) != self.audio_embed_dim:\n", + " raise ValueError(f\"Expected audio embedding dim {self.audio_embed_dim}, got {audio_embedding.size(-1)}\")\n", + "\n", + " # Project to [B, T, 768]\n", + " projected = self.context_projector(audio_embedding) # [B, T, 768]\n", + "\n", + " # Compute scalar attention scores per timestep\n", + " attn_scores = projected.mean(dim=2) # [B, T]\n", + " attn_weights = F.softmax(attn_scores, dim=1) # [B, T]\n", + " attn_weights = attn_weights.unsqueeze(2) # [B, T, 1]\n", + "\n", + " # Apply attention to the projected embeddings\n", + " pooled = (projected * attn_weights).sum(dim=1, keepdim=True) # [B, 1, 768]\n", + " return pooled\n", + "\n", + "\n", + "\n", + "# === Inference Function ===\n", + "def run_inference(pipe, unet, vae, device, context_hidden_states, save_path=\"inference_output.png\"):\n", + " pipe.unet = unet\n", + " pipe.vae = vae\n", + " pipe.to(device)\n", + "\n", + " batch_size = 1\n", + " latents = torch.randn((batch_size, pipe.unet.in_channels, 64, 64), device=device, dtype=torch.float16)\n", + " # latents = torch.randn((batch_size, pipe.unet.config.in_channels, 64, 64), device=device, dtype=torch.float16)\n", + " pipe.scheduler.set_timesteps(50)\n", + " latents = latents * pipe.scheduler.init_noise_sigma\n", + " \n", + " expected_shape = (batch_size, 1, 768) # Adjust based on model\n", + " if context_hidden_states.shape != expected_shape:\n", + " raise ValueError(f\"Expected context_hidden_states shape {expected_shape}, got {context_hidden_states.shape}\")\n", + " \n", + " for t in pipe.scheduler.timesteps:\n", + " with torch.no_grad():\n", + " noise_pred = pipe.unet(latents, t, encoder_hidden_states=context_hidden_states).sample\n", + " latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample\n", + "\n", + " # latents = 1 / 0.18215 * latents\n", + " latents = 1 / pipe.vae.config.scaling_factor * latents\n", + " with torch.no_grad():\n", + " image = pipe.vae.decode(latents).sample\n", + "\n", + " image = (image / 2 + 0.5).clamp(0, 1)\n", + " image = image.cpu().permute(0, 2, 3, 1).numpy()[0]\n", + " image = Image.fromarray((image * 255).astype(\"uint8\"))\n", + " image.save(save_path)\n", + " print(f\"Inference image saved to {save_path}\")\n", + "\n", + "\n", + "# === Load Pipeline ===\n", + "def load_pipeline(model_id, device):\n", + " pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)\n", + " unet = pipe.unet\n", + " vae = pipe.vae\n", + " return pipe, unet, vae\n", + "\n", + "# === Freeze Layers Function ===\n", + "def freeze_vae_layers(vae):\n", + " vae.encoder.requires_grad_(False)\n", + " vae.quant_conv.requires_grad_(False)\n", + " vae.decoder.requires_grad_(True)\n", + " vae.post_quant_conv.requires_grad_(True)\n", + "\n", + "def freeze_unet_layers(unet):\n", + " for name, param in unet.named_parameters():\n", + " if \"attn2\" in name or \"conv2\" in name:\n", + " param.requires_grad = True\n", + " else:\n", + " param.requires_grad = False\n", + "\n", + "# === Optimizer Setup ===\n", + "def setup_optimizer(vae, unet, projector, lr):\n", + " params_to_optimize = list(filter(lambda p: p.requires_grad, vae.parameters())) + \\\n", + " list(filter(lambda p: p.requires_grad, unet.parameters())) + \\\n", + " list(filter(lambda p: p.requires_grad, projector.parameters()))\n", + " optimizer = optim.AdamW(params_to_optimize, 
lr=lr)\n", + " return optimizer\n", + "\n", + "\n", + "# === Gradient Accumulation Function ===\n", + "def accumulate_gradients(optimizer, loss, gradient_accumulation_steps, step, dataloader):\n", + " loss = loss / gradient_accumulation_steps\n", + " loss.backward()\n", + " if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(dataloader):\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + "\n", + "# === Save Checkpoint Function ===\n", + "def save_checkpoint(epoch, unet, vae, projector, optimizer, checkpoint_path):\n", + " # checkpoint_path = f\"{save_dir}/checkpoint.pth\"\n", + " torch.save({\n", + " 'epoch': epoch,\n", + " 'unet': unet.state_dict(),\n", + " 'vae': vae.state_dict(),\n", + " 'projector': projector.state_dict(),\n", + " 'optimizer': optimizer.state_dict(),\n", + " }, checkpoint_path)\n", + " print(f\"Checkpoint saved to {checkpoint_path}\")\n", + "\n", + "# === Resume from Checkpoint Function ===\n", + "def resume_from_checkpoint(checkpoint_path, unet, vae, projector, optimizer):\n", + " if os.path.exists(checkpoint_path):\n", + " checkpoint = torch.load(checkpoint_path, map_location='cpu')\n", + " unet.load_state_dict(checkpoint['unet'])\n", + " vae.load_state_dict(checkpoint['vae'])\n", + " projector.load_state_dict(checkpoint['projector'])\n", + " optimizer.load_state_dict(checkpoint['optimizer'])\n", + " start_epoch = checkpoint['epoch'] + 1\n", + " print(f\"Resuming training from epoch {start_epoch}...\")\n", + " return start_epoch\n", + " else:\n", + " print(\"No checkpoint found, starting from scratch.\")\n", + " return 0\n", + "\n", + "\n", + "# === Training Loop Function ===\n", + "def train_loop(dataloader, unet, vae, optimizer, gradient_accumulation_steps, device, num_epochs, samples_path, checkpoint_path, pipe, projector):\n", + " start_epoch = resume_from_checkpoint(checkpoint_path, unet, vae, projector, optimizer)\n", + "\n", + " for epoch in trange(start_epoch, num_epochs, colour='red', desc=f'{device}-training', ncols=100):\n", + " unet.train()\n", + " vae.train()\n", + " projector.train()\n", + " total_loss = 0\n", + " step = 0\n", + " \n", + " for image in tqdm(dataloader, colour='green', desc=f'{device}-batch', ncols=100):\n", + " # print(\"step:\", step)\n", + " image = image.to(device, dtype=torch.float16)\n", + "\n", + " latents = vae.encode(image).latent_dist.sample() * 0.18215\n", + " noise = torch.randn_like(latents)\n", + " # timesteps = torch.randint(0, 1000, (latents.shape[0],), device=device).long()\n", + " timesteps = torch.randint(0, pipe.scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()\n", + "\n", + " # === Use dummy audio embedding ===\n", + " dummy_audio = torch.zeros(image.size(0), 1500, 1280, device=device, dtype=torch.float16)\n", + " context_hidden_states = projector(dummy_audio)\n", + "\n", + " # print(\"Model IP\")\n", + " noise_pred = unet(latents + noise, timesteps, encoder_hidden_states=context_hidden_states).sample\n", + " # print(\"Model OP\")\n", + "\n", + " loss = nn.MSELoss()(noise_pred, noise)\n", + " total_loss += loss.item()\n", + "\n", + " step += 1\n", + " accumulate_gradients(optimizer, loss, gradient_accumulation_steps, step, dataloader)\n", + "\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"Epoch {epoch + 1} | Avg Loss: {avg_loss:.6f}\")\n", + "\n", + " save_checkpoint(epoch, unet, vae, projector, optimizer, checkpoint_path)\n", + " run_inference(pipe, unet, vae, device, context_hidden_states, save_path=f\"{samples_path}/inference_epoch{epoch 
+ 1}.png\")\n", + "\n", + " print(\"\\n✅ Fine-tuning complete.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9ad5f6a3", + "metadata": {}, + "outputs": [], + "source": [ + "# === Main Function ===\n", + "def main():\n", + " model_id = \"runwayml/stable-diffusion-v1-5\"\n", + " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + " lr = 1e-5\n", + " num_epochs = 10\n", + " batch_size = 16\n", + " debug = False\n", + " gradient_accumulation_steps = 1\n", + " \n", + " os.makedirs(f\"./checkpoints\", exist_ok=True)\n", + " os.makedirs(f\"./samples\", exist_ok=True)\n", + " checkpoint_path = f\"./checkpoints/checkpoint.pth\"\n", + " samples_path = f\"./samples\"\n", + " image_dir = \"/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Images\"\n", + "\n", + " pipe, unet, vae = load_pipeline(model_id, device)\n", + " freeze_vae_layers(vae)\n", + " freeze_unet_layers(unet)\n", + " projector = AudioContextProjector(audio_embed_dim=1280, output_dim=768).to(device).half()\n", + " optimizer = setup_optimizer(vae, unet, projector, lr)\n", + "\n", + " # === Dataset & Dataloader ===\n", + " files = walkDIR(image_dir, include=['.png', '.jpeg', '.jpg'])\n", + " dataset = VaaniDataset(files_paths=files, im_size=256)\n", + " image = dataset[2]\n", + " print('IMAGE SHAPE:', image.shape, \"Dataset len:\", len(dataset))\n", + " dataloader = create_dataloader(dataset, batch_size, debug=debug)\n", + "\n", + " train_loop(dataloader, unet, vae, optimizer, gradient_accumulation_steps, device, num_epochs, samples_path, checkpoint_path, pipe, projector)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e71b4ba9", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Couldn't connect to the Hub: (MaxRetryError('HTTPSConnectionPool(host=\\'huggingface.co\\', port=443): Max retries exceeded with url: /api/models/runwayml/stable-diffusion-v1-5 (Caused by NameResolutionError(\": Failed to resolve \\'huggingface.co\\' ([Errno -2] Name or service not known)\"))'), '(Request ID: bcd4fcc3-8634-4bfe-8454-3b4dbdcc1222)').\n", + "Will try to load from local cache.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1014662fa9c44a00b0e9e6b3d1e9747d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading pipeline components...: 0%| | 0/7 [00:00 \u001b[39m\u001b[32m2\u001b[39m \u001b[43mmain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 30\u001b[39m, in \u001b[36mmain\u001b[39m\u001b[34m()\u001b[39m\n\u001b[32m 27\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m'\u001b[39m\u001b[33mIMAGE SHAPE:\u001b[39m\u001b[33m'\u001b[39m, image.shape, \u001b[33m\"\u001b[39m\u001b[33mDataset len:\u001b[39m\u001b[33m\"\u001b[39m, \u001b[38;5;28mlen\u001b[39m(dataset))\n\u001b[32m 28\u001b[39m dataloader = create_dataloader(dataset, batch_size, debug=debug)\n\u001b[32m---> \u001b[39m\u001b[32m30\u001b[39m \u001b[43mtrain_loop\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43munet\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvae\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient_accumulation_steps\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mnum_epochs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msamples_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcheckpoint_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpipe\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprojector\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 228\u001b[39m, in \u001b[36mtrain_loop\u001b[39m\u001b[34m(dataloader, unet, vae, optimizer, gradient_accumulation_steps, device, num_epochs, samples_path, checkpoint_path, pipe, projector)\u001b[39m\n\u001b[32m 226\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m image \u001b[38;5;129;01min\u001b[39;00m tqdm(dataloader, colour=\u001b[33m'\u001b[39m\u001b[33mgreen\u001b[39m\u001b[33m'\u001b[39m, desc=\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdevice\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m-batch\u001b[39m\u001b[33m'\u001b[39m, ncols=\u001b[32m100\u001b[39m):\n\u001b[32m 227\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[33mstep:\u001b[39m\u001b[33m\"\u001b[39m, step)\n\u001b[32m--> \u001b[39m\u001b[32m228\u001b[39m image = \u001b[43mimage\u001b[49m\u001b[43m.\u001b[49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtorch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mfloat16\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 230\u001b[39m latents = vae.encode(image).latent_dist.sample() * \u001b[32m0.18215\u001b[39m\n\u001b[32m 231\u001b[39m noise = torch.randn_like(latents)\n", + "\u001b[31mKeyboardInterrupt\u001b[39m: " + ] + } + ], + "source": [ + "if __name__ == \"__main__\":\n", + " main()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/Vaani/SDFT/_2_.py b/Vaani/SDFT/_2_.py new file mode 100644 index 0000000000000000000000000000000000000000..cf32f91cdd212495994c78095ed0caff5f4373f3 --- /dev/null +++ b/Vaani/SDFT/_2_.py @@ -0,0 +1,345 @@ +import torch +import torch.optim as optim +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms +from torchvision.transforms import v2 +from PIL import Image +from diffusers import StableDiffusionPipeline +from diffusers.optimization import get_scheduler +from torch import nn +import torch.nn.functional as F +import os +import pandas as pd +from tqdm import trange, tqdm + +os.environ["CUDA_VISIBLE_DEVICES"] = "1" +device = "cuda" if torch.cuda.is_available() else "cpu" +device + + +# import torch +# import torch.nn as nn +# import torch.nn.functional as F + +# audio_embed_dim = 1280 +# output_dim = 768 +# device = 'cuda' if torch.cuda.is_available() else 'cpu' + +# context_projector = nn.Sequential( +# nn.Linear(audio_embed_dim, 320), +# nn.SiLU(), +# nn.Linear(320, output_dim) +# ).to(device).half() + +# # Dummy input +# audio_embedding = dummy_audio = torch.zeros(10, 1500, 1280, device=device, dtype=torch.float16) +# print(audio_embedding.shape) # [10, 1500, 1280] + +# # Project audio to [10, 1500, 768] +# projected = 
context_projector(audio_embedding) +# print(projected.shape) # [10, 1500, 768] + +# # Compute attention scores: reduce feature dim to scalar per time step +# attn_scores = projected.mean(dim=2) # [10, 1500] +# attn_weights = F.softmax(attn_scores, dim=1) # [10, 1500] +# attn_weights = attn_weights.unsqueeze(2) # [10, 1500, 1] + +# # Weighted average +# pooled = (projected * attn_weights).sum(dim=1, keepdim=True) # [10, 1, 768] +# print(pooled.shape) # Final shape: [10, 1, 768] + + + +# === Helpers === +def walkDIR(folder_path, include=None): + file_list = [] + for root, _, files in os.walk(folder_path): + for file in files: + if include is None or any(file.endswith(ext) for ext in include): + file_list.append(os.path.join(root, file)) + print("Files found:", len(file_list)) + return file_list + +# === Dataset Class === +class VaaniDataset(torch.utils.data.Dataset): + def __init__(self, files_paths, im_size): + self.files_paths = files_paths + self.im_size = im_size + + def __len__(self): + return len(self.files_paths) + + def __getitem__(self, idx): + # image = tv.io.read_image(self.files_paths[idx], mode=tv.io.ImageReadMode.RGB) + image = Image.open(self.files_paths[idx]).convert("RGB") + image = v2.ToImage()(image) + # image = tv.io.decode_image(self.files_paths[idx], mode=tv.io.ImageReadMode.RGB) + image = v2.Resize((self.im_size, self.im_size))(image) + image = v2.ToDtype(torch.float32, scale=True)(image) + # image = 2*image - 1 + return image + + +def create_dataloader(dataset, batch_size, debug=False, val_split=0.1, num_workers=4): + if debug: + s = 0.001 + dataset, _ = torch.utils.data.random_split(dataset, [s, 1-s], torch.manual_seed(42)) + print("Length of Train dataset:", len(dataset)) + + train_dataloader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=True, + drop_last=True, + persistent_workers=True + ) + + images = next(iter(train_dataloader)) + print('Total Batches:', len(train_dataloader)) + print('BATCH SHAPE:', images.shape) + return train_dataloader + +# === Audio Context Projector === +# class AudioContextProjector(nn.Module): +# def __init__(self, audio_embed_dim): +# super().__init__() +# self.audio_embed_dim = audio_embed_dim +# self.context_projector = nn.Sequential( +# nn.Linear(audio_embed_dim, 320), +# nn.SiLU(), +# nn.Linear(320, 1) +# ) + +# def forward(self, audio_embedding): +# if audio_embedding.size(-1) != self.audio_embed_dim: +# raise ValueError(f"Expected audio embedding dim {self.audio_embed_dim}, got {audio_embedding.size(-1)}") +# weights = self.context_projector(audio_embedding) # [B, T, 1] +# weights = torch.softmax(weights, dim=1) # [B, T, 1] +# pooled = (audio_embedding * weights).sum(dim=1) # [B, 1280] +# return pooled.unsqueeze(1) # [B, 1, 1280] +# class AudioContextProjector(nn.Module): +# def __init__(self, audio_embed_dim=1280, output_dim=768): # Add output_dim for flexibility +# super().__init__() +# self.audio_embed_dim = audio_embed_dim +# self.context_projector = nn.Sequential( +# nn.Linear(audio_embed_dim, 320), +# nn.SiLU(), +# nn.Linear(320, output_dim) # Output 768 to match UNet's expectation +# ) + +# def forward(self, audio_embedding): +# if audio_embedding.size(-1) != self.audio_embed_dim: +# raise ValueError(f"Expected audio embedding dim {self.audio_embed_dim}, got {audio_embedding.size(-1)}") +# weights = self.context_projector(audio_embedding) # [B, T, 768] +# weights = torch.softmax(pooled, dim=1) # [B, T, 768] +# pooled = (audio_embedding * weights).sum(dim=1) # [B, 
768] +# return pooled.unsqueeze(1) # [B, 1, 768] +class AudioContextProjector(nn.Module): + def __init__(self, audio_embed_dim=1280, output_dim=768): + super().__init__() + self.audio_embed_dim = audio_embed_dim + self.output_dim = output_dim + self.context_projector = nn.Sequential( + nn.Linear(audio_embed_dim, 320), + nn.SiLU(), + nn.Linear(320, output_dim) # Output 768 to match UNet's expectation + ) + + def forward(self, audio_embedding): + if audio_embedding.size(-1) != self.audio_embed_dim: + raise ValueError(f"Expected audio embedding dim {self.audio_embed_dim}, got {audio_embedding.size(-1)}") + + # Project to [B, T, 768] + projected = self.context_projector(audio_embedding) # [B, T, 768] + + # Compute scalar attention scores per timestep + attn_scores = projected.mean(dim=2) # [B, T] + attn_weights = F.softmax(attn_scores, dim=1) # [B, T] + attn_weights = attn_weights.unsqueeze(2) # [B, T, 1] + + # Apply attention to the projected embeddings + pooled = (projected * attn_weights).sum(dim=1, keepdim=True) # [B, 1, 768] + return pooled + + + +# === Inference Function === +def run_inference(pipe, unet, vae, device, context_hidden_states, save_path="inference_output.png"): + pipe.unet = unet + pipe.vae = vae + pipe.to(device) + + batch_size = 1 + latents = torch.randn((batch_size, pipe.unet.in_channels, 64, 64), device=device, dtype=torch.float16) + # latents = torch.randn((batch_size, pipe.unet.config.in_channels, 64, 64), device=device, dtype=torch.float16) + pipe.scheduler.set_timesteps(50) + latents = latents * pipe.scheduler.init_noise_sigma + + expected_shape = (batch_size, 1, 768) # Adjust based on model + if context_hidden_states.shape != expected_shape: + raise ValueError(f"Expected context_hidden_states shape {expected_shape}, got {context_hidden_states.shape}") + + for t in pipe.scheduler.timesteps: + with torch.no_grad(): + noise_pred = pipe.unet(latents, t, encoder_hidden_states=context_hidden_states).sample + latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample + + # latents = 1 / 0.18215 * latents + latents = 1 / pipe.vae.config.scaling_factor * latents + with torch.no_grad(): + image = pipe.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy()[0] + image = Image.fromarray((image * 255).astype("uint8")) + image.save(save_path) + print(f"Inference image saved to {save_path}") + + +# === Load Pipeline === +def load_pipeline(model_id, device): + pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device) + unet = pipe.unet + vae = pipe.vae + return pipe, unet, vae + +# === Freeze Layers Function === +def freeze_vae_layers(vae): + vae.encoder.requires_grad_(False) + vae.quant_conv.requires_grad_(False) + vae.decoder.requires_grad_(True) + vae.post_quant_conv.requires_grad_(True) + +def freeze_unet_layers(unet): + for name, param in unet.named_parameters(): + if "attn2" in name or "conv2" in name: + param.requires_grad = True + else: + param.requires_grad = False + +# === Optimizer Setup === +def setup_optimizer(vae, unet, projector, lr): + params_to_optimize = list(filter(lambda p: p.requires_grad, vae.parameters())) + \ + list(filter(lambda p: p.requires_grad, unet.parameters())) + \ + list(filter(lambda p: p.requires_grad, projector.parameters())) + optimizer = optim.AdamW(params_to_optimize, lr=lr) + return optimizer + + +# === Gradient Accumulation Function === +def accumulate_gradients(optimizer, loss, gradient_accumulation_steps, step, dataloader): + loss = 
loss / gradient_accumulation_steps + loss.backward() + if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(dataloader): + optimizer.step() + optimizer.zero_grad() + +# === Save Checkpoint Function === +def save_checkpoint(epoch, unet, vae, projector, optimizer, checkpoint_path): + # checkpoint_path = f"{save_dir}/checkpoint.pth" + torch.save({ + 'epoch': epoch, + 'unet': unet.state_dict(), + 'vae': vae.state_dict(), + 'projector': projector.state_dict(), + 'optimizer': optimizer.state_dict(), + }, checkpoint_path) + print(f"Checkpoint saved to {checkpoint_path}") + +# === Resume from Checkpoint Function === +def resume_from_checkpoint(checkpoint_path, unet, vae, projector, optimizer): + if os.path.exists(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + unet.load_state_dict(checkpoint['unet']) + vae.load_state_dict(checkpoint['vae']) + projector.load_state_dict(checkpoint['projector']) + optimizer.load_state_dict(checkpoint['optimizer']) + start_epoch = checkpoint['epoch'] + 1 + print(f"Resuming training from epoch {start_epoch}...") + return start_epoch + else: + print("No checkpoint found, starting from scratch.") + return 0 + + +# === Training Loop Function === +def train_loop(dataloader, unet, vae, optimizer, gradient_accumulation_steps, device, num_epochs, samples_path, checkpoint_path, pipe, projector): + start_epoch = resume_from_checkpoint(checkpoint_path, unet, vae, projector, optimizer) + + for epoch in trange(start_epoch, num_epochs, colour='red', desc=f'{device}-training', dynamic_ncols=True): + unet.train() + vae.train() + projector.train() + total_loss = 0 + step = 0 + + for image in tqdm(dataloader, colour='green', desc=f'{device}-batch', dynamic_ncols=True): + # print("step:", step) + image = image.to(device, dtype=torch.float16) + + latents = vae.encode(image).latent_dist.sample() * 0.18215 + noise = torch.randn_like(latents) + # timesteps = torch.randint(0, 1000, (latents.shape[0],), device=device).long() + timesteps = torch.randint(0, pipe.scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long() + + # === Use dummy audio embedding === + dummy_audio = torch.zeros(image.size(0), 1500, 1280, device=device, dtype=torch.float16) + context_hidden_states = projector(dummy_audio) + + # print("Model IP") + noise_pred = unet(latents + noise, timesteps, encoder_hidden_states=context_hidden_states).sample + # print("Model OP") + + loss = nn.MSELoss()(noise_pred, noise) + total_loss += loss.item() + + step += 1 + accumulate_gradients(optimizer, loss, gradient_accumulation_steps, step, dataloader) + + avg_loss = total_loss / len(dataloader) + print(f"Epoch {epoch + 1} | Avg Loss: {avg_loss:.6f}") + + save_checkpoint(epoch, unet, vae, projector, optimizer, checkpoint_path) + run_inference(pipe, unet, vae, device, context_hidden_states, save_path=f"{samples_path}/inference_epoch{epoch + 1}.png") + + print("\n✅ Fine-tuning complete.") + + +# === Main Function === +def main(): + model_id = "runwayml/stable-diffusion-v1-5" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + lr = 1e-5 + num_epochs = 10 + batch_size = 16 + debug = False + gradient_accumulation_steps = 1 + + os.makedirs(f"./checkpoints", exist_ok=True) + os.makedirs(f"./samples", exist_ok=True) + checkpoint_path = f"./checkpoints/checkpoint.pth" + samples_path = f"./samples" + image_dir = "/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Images" + + pipe, unet, vae = load_pipeline(model_id, device) + freeze_vae_layers(vae) + 
freeze_unet_layers(unet) + projector = AudioContextProjector(audio_embed_dim=1280, output_dim=768).to(device).half() + optimizer = setup_optimizer(vae, unet, projector, lr) + + # === Dataset & Dataloader === + files = walkDIR(image_dir, include=['.png', '.jpeg', '.jpg']) + dataset = VaaniDataset(files_paths=files, im_size=256) + image = dataset[2] + print('IMAGE SHAPE:', image.shape, "Dataset len:", len(dataset)) + dataloader = create_dataloader(dataset, batch_size, debug=debug) + + train_loop(dataloader, unet, vae, optimizer, gradient_accumulation_steps, device, num_epochs, samples_path, checkpoint_path, pipe, projector) + + +if __name__ == "__main__": + main() + + diff --git a/Vaani/SDFT/_2_DDP.py b/Vaani/SDFT/_2_DDP.py new file mode 100644 index 0000000000000000000000000000000000000000..cef1b872bef45db8f849d95abdeb171cb4bc1990 --- /dev/null +++ b/Vaani/SDFT/_2_DDP.py @@ -0,0 +1,316 @@ +import torch +import torch.optim as optim +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms +from torchvision.transforms import v2 +from PIL import Image +from diffusers import StableDiffusionPipeline +from diffusers.optimization import get_scheduler +from torch import nn +import torch.nn.functional as F +import os +import pandas as pd +from tqdm import trange, tqdm +# DDP Imports +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data.distributed import DistributedSampler +import torch.multiprocessing as mp + +# Set CUDA_VISIBLE_DEVICES +# os.environ["CUDA_VISIBLE_DEVICES"] = "1" +device = "cuda" if torch.cuda.is_available() else "cpu" + +# === Helpers === +def walkDIR(folder_path, include=None): + file_list = [] + for root, _, files in os.walk(folder_path): + for file in files: + if include is None or any(file.endswith(ext) for ext in include): + file_list.append(os.path.join(root, file)) + print("Files found:", len(file_list)) + return file_list + +# === Dataset Class === +class VaaniDataset(torch.utils.data.Dataset): + def __init__(self, files_paths, im_size): + self.files_paths = files_paths + self.im_size = im_size + + def __len__(self): + return len(self.files_paths) + + def __getitem__(self, idx): + image = Image.open(self.files_paths[idx]).convert("RGB") + image = v2.ToImage()(image) + image = v2.Resize((self.im_size, self.im_size))(image) + image = v2.ToDtype(torch.float32, scale=True)(image) + return image + +# === Modified create_dataloader for DDP and single GPU === +def create_dataloader(dataset, batch_size, debug=False, val_split=0.1, num_workers=4, rank=None, is_distributed=False): + if debug: + s = 0.001 + dataset, _ = torch.utils.data.random_split(dataset, [s, 1-s], torch.manual_seed(42)) + print(f"{'Rank ' + str(rank) + ': ' if rank is not None else ''}Length of Train dataset: {len(dataset)}") + + # Use DistributedSampler only if DDP is active + sampler = DistributedSampler(dataset, shuffle=True) if is_distributed else None + train_dataloader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=(sampler is None), + sampler=sampler, + num_workers=num_workers, + pin_memory=True, + drop_last=True, + persistent_workers=True + ) + + images = next(iter(train_dataloader)) + if rank is not None: + print(f"Rank {rank}: Total Batches: {len(train_dataloader)}") + print(f"Rank {rank}: BATCH SHAPE: {images.shape}") + else: + print(f"Total Batches: {len(train_dataloader)}") + print(f"BATCH SHAPE: {images.shape}") + return train_dataloader + +# === Audio Context Projector === +class 
AudioContextProjector(nn.Module): + def __init__(self, audio_embed_dim=1280, output_dim=768): + super().__init__() + self.audio_embed_dim = audio_embed_dim + self.output_dim = output_dim + self.context_projector = nn.Sequential( + nn.Linear(audio_embed_dim, 320), + nn.SiLU(), + nn.Linear(320, output_dim) + ) + + def forward(self, audio_embedding): + if audio_embedding.size(-1) != self.audio_embed_dim: + raise ValueError(f"Expected audio embedding dim {self.audio_embed_dim}, got {audio_embedding.size(-1)}") + projected = self.context_projector(audio_embedding) + attn_scores = projected.mean(dim=2) + attn_weights = F.softmax(attn_scores, dim=1) + attn_weights = attn_weights.unsqueeze(2) + pooled = (projected * attn_weights).sum(dim=1, keepdim=True) + return pooled + +# === Inference Function === +def run_inference(pipe, unet, vae, device, context_hidden_states, save_path="inference_output.png", rank=0): + if rank != 0: # Only rank-0 or single-GPU process runs inference + return + pipe.unet = unet.module if isinstance(unet, DDP) else unet + pipe.vae = vae.module if isinstance(vae, DDP) else vae + pipe.to(device) + + batch_size = 1 + latents = torch.randn((batch_size, pipe.unet.in_channels, 64, 64), device=device, dtype=torch.float16) + pipe.scheduler.set_timesteps(50) + latents = latents * pipe.scheduler.init_noise_sigma + + expected_shape = (batch_size, 1, 768) + if context_hidden_states.shape != expected_shape: + raise ValueError(f"Expected context_hidden_states shape {expected_shape}, got {context_hidden_states.shape}") + + for t in pipe.scheduler.timesteps: + with torch.no_grad(): + noise_pred = pipe.unet(latents, t, encoder_hidden_states=context_hidden_states).sample + latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample + + latents = 1 / pipe.vae.config.scaling_factor * latents + with torch.no_grad(): + image = pipe.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy()[0] + image = Image.fromarray((image * 255).astype("uint8")) + image.save(save_path) + print(f"{'Rank ' + str(rank) + ': ' if rank != 0 else ''}Inference image saved to {save_path}") + +# === Load Pipeline === +def load_pipeline(model_id, device): + pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device) + unet = pipe.unet + vae = pipe.vae + return pipe, unet, vae + +# === Freeze Layers Function === +def freeze_vae_layers(vae): + vae.encoder.requires_grad_(False) + vae.quant_conv.requires_grad_(False) + vae.decoder.requires_grad_(True) + vae.post_quant_conv.requires_grad_(True) + +def freeze_unet_layers(unet): + for name, param in unet.named_parameters(): + if "attn2" in name or "conv2" in name: + param.requires_grad = True + else: + param.requires_grad = False + +# === Optimizer Setup === +def setup_optimizer(vae, unet, projector, lr): + params_to_optimize = list(filter(lambda p: p.requires_grad, vae.parameters())) + \ + list(filter(lambda p: p.requires_grad, unet.parameters())) + \ + list(filter(lambda p: p.requires_grad, projector.parameters())) + optimizer = optim.AdamW(params_to_optimize, lr=lr) + return optimizer + +# === Gradient Accumulation Function === +def accumulate_gradients(optimizer, loss, gradient_accumulation_steps, step, dataloader): + loss = loss / gradient_accumulation_steps + loss.backward() + if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(dataloader): + optimizer.step() + optimizer.zero_grad() + +# === Save Checkpoint Function === +def save_checkpoint(epoch, unet, 
vae, projector, optimizer, checkpoint_path, rank=0): + if rank != 0: # Only rank-0 or single-GPU process saves checkpoint + return + torch.save({ + 'epoch': epoch, + 'unet': unet.module.state_dict() if isinstance(unet, DDP) else unet.state_dict(), + 'vae': vae.module.state_dict() if isinstance(vae, DDP) else vae.state_dict(), + 'projector': projector.module.state_dict() if isinstance(projector, DDP) else projector.state_dict(), + 'optimizer': optimizer.state_dict(), + }, checkpoint_path) + print(f"{'Rank ' + str(rank) + ': ' if rank != 0 else ''}Checkpoint saved to {checkpoint_path}") + +# === Resume from Checkpoint Function === +def resume_from_checkpoint(checkpoint_path, unet, vae, projector, optimizer, rank=0): + if rank != 0: # Only rank-0 or single-GPU process loads checkpoint + return 0 + if os.path.exists(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + unet.load_state_dict(checkpoint['unet']) + vae.load_state_dict(checkpoint['vae']) + projector.load_state_dict(checkpoint['projector']) + optimizer.load_state_dict(checkpoint['optimizer']) + start_epoch = checkpoint['epoch'] + 1 + print(f"{'Rank ' + str(rank) + ': ' if rank != 0 else ''}Resuming training from epoch {start_epoch}...") + return start_epoch + else: + print(f"{'Rank ' + str(rank) + ': ' if rank != 0 else ''}No checkpoint found, starting from scratch.") + return 0 + +# === Training Loop Function === +def train_loop(dataloader, unet, vae, optimizer, gradient_accumulation_steps, device, num_epochs, samples_path, checkpoint_path, pipe, projector, rank=0, is_distributed=False): + start_epoch = resume_from_checkpoint(checkpoint_path, unet, vae, projector, optimizer, rank) + + for epoch in trange(start_epoch, num_epochs, colour='red', desc=f"{'Rank ' + str(rank) + ' ' if rank != 0 else ''}{device}-training", dynamic_ncols=True): + unet.train() + vae.train() + projector.train() + total_loss = 0 + step = 0 + + # Reset sampler for each epoch if using DistributedSampler + if is_distributed and isinstance(dataloader.sampler, DistributedSampler): + dataloader.sampler.set_epoch(epoch) + + for image in tqdm(dataloader, colour='green', desc=f"{'Rank ' + str(rank) + ' ' if rank != 0 else ''}{device}-batch", dynamic_ncols=True): + image = image.to(device, dtype=torch.float16) + + latents = vae.encode(image).latent_dist.sample() * 0.18215 + noise = torch.randn_like(latents) + timesteps = torch.randint(0, pipe.scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long() + + dummy_audio = torch.zeros(image.size(0), 1500, 1280, device=device, dtype=torch.float16) + context_hidden_states = projector(dummy_audio) + + noise_pred = unet(latents + noise, timesteps, encoder_hidden_states=context_hidden_states).sample + loss = nn.MSELoss()(noise_pred, noise) + total_loss += loss.item() + + step += 1 + accumulate_gradients(optimizer, loss, gradient_accumulation_steps, step, dataloader) + + # Aggregate loss for DDP + if is_distributed: + total_loss_tensor = torch.tensor(total_loss, device=device) + dist.all_reduce(total_loss_tensor, op=dist.ReduceOp.SUM) + avg_loss = total_loss_tensor.item() / (len(dataloader) * dist.get_world_size()) + else: + avg_loss = total_loss / len(dataloader) + + if rank == 0: + print(f"{'Rank ' + str(rank) + ': ' if rank != 0 else ''}Epoch {epoch + 1} | Avg Loss: {avg_loss:.6f}") + + save_checkpoint(epoch, unet, vae, projector, optimizer, checkpoint_path, rank) + run_inference(pipe, unet, vae, device, context_hidden_states, + save_path=f"{samples_path}/inference_epoch{epoch 
+ 1}{'_rank' + str(rank) if rank != 0 else ''}.png", + rank=rank) + + if rank == 0: + print(f"{'Rank ' + str(rank) + ': ' if rank != 0 else ''}✅ Fine-tuning complete.") + +# === DDP Setup Function === +def setup_ddp(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + dist.init_process_group("nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + +# === Main Function === +def main(rank=0, world_size=1, is_distributed=False): + if is_distributed: + setup_ddp(rank, world_size) + device = torch.device(f"cuda:{rank}") + else: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + model_id = "runwayml/stable-diffusion-v1-5" + lr = 1e-5 + num_epochs = 10 + batch_size = 16 + debug = False + gradient_accumulation_steps = 1 + + if rank == 0: + os.makedirs(f"./checkpoints", exist_ok=True) + os.makedirs(f"./samples", exist_ok=True) + if is_distributed: + dist.barrier() + + checkpoint_path = f"./checkpoints/checkpoint{'_rank' + str(rank) if is_distributed else ''}.pth" + samples_path = f"./samples" + image_dir = "/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Images" + + pipe, unet, vae = load_pipeline(model_id, device) + freeze_vae_layers(vae) + freeze_unet_layers(unet) + projector = AudioContextProjector(audio_embed_dim=1280, output_dim=768).to(device).half() + + if is_distributed: + unet = DDP(unet, device_ids=[rank]) + vae = DDP(vae, device_ids=[rank]) + projector = DDP(projector, device_ids=[rank]) + + optimizer = setup_optimizer(vae, unet, projector, lr) + + files = walkDIR(image_dir, include=['.png', '.jpeg', '.jpg']) + dataset = VaaniDataset(files_paths=files, im_size=256) + if rank == 0: + image = dataset[2] + print(f"{'Rank ' + str(rank) + ': ' if rank != 0 else ''}IMAGE SHAPE: {image.shape}, Dataset len: {len(dataset)}") + + dataloader = create_dataloader(dataset, batch_size, debug=debug, rank=rank, is_distributed=is_distributed) + + train_loop(dataloader, unet, vae, optimizer, gradient_accumulation_steps, device, num_epochs, + samples_path, checkpoint_path, pipe, projector, rank, is_distributed) + + if is_distributed: + dist.destroy_process_group() + +# === Entry Point === +if __name__ == "__main__": + world_size = torch.cuda.device_count() + print(f"Detected {world_size} GPU(s)") + if world_size > 1: + mp.spawn(main, args=(world_size, True), nprocs=world_size, join=True) + else: + main(rank=0, world_size=1, is_distributed=False) \ No newline at end of file diff --git a/Vaani/SDFT/checkpoints/checkpoint.pth b/Vaani/SDFT/checkpoints/checkpoint.pth new file mode 100644 index 0000000000000000000000000000000000000000..f7c3f77f5cce17d3129e22bf8a482b77f2ded7d5 --- /dev/null +++ b/Vaani/SDFT/checkpoints/checkpoint.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:79469e5ae61b7894df2b96cdb09b873f9d0e2282f8b85d4195c5dbd16e182891 +size 2866661866 diff --git a/Vaani/SDFT/download_model.py b/Vaani/SDFT/download_model.py new file mode 100644 index 0000000000000000000000000000000000000000..15195c4c0f7205d681fe696b773f8e2be2d610f2 --- /dev/null +++ b/Vaani/SDFT/download_model.py @@ -0,0 +1,13 @@ +import torch +from diffusers import StableDiffusionPipeline, UNet2DConditionModel, StableDiffusion3Pipeline + +device = "cuda" if torch.cuda.is_available() else "cpu" +print("device:", device) + +# pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) +# pipe +# del pipe + +pipe = 
StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.bfloat16) +pipe +# del pipe \ No newline at end of file diff --git a/Vaani/SDFT/vaani-stablediffusion-finetune-kaggle.ipynb b/Vaani/SDFT/vaani-stablediffusion-finetune-kaggle.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..74464eba5adb681c979ddc76ae429d04c932bb7a --- /dev/null +++ b/Vaani/SDFT/vaani-stablediffusion-finetune-kaggle.ipynb @@ -0,0 +1,650 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "trusted": true + }, + "outputs": [], + "source": [ + "import torch\n", + "from torch import nn\n", + "from torch.utils.data import Dataset, DataLoader\n", + "from transformers import CLIPTextModel, CLIPTokenizer\n", + "from diffusers import StableDiffusionPipeline, UNet2DConditionModel\n", + "from diffusers.optimization import get_scheduler\n", + "from accelerate import Accelerator\n", + "import torchaudio" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "trusted": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Couldn't connect to the Hub: (MaxRetryError('HTTPSConnectionPool(host=\\'huggingface.co\\', port=443): Max retries exceeded with url: /api/models/runwayml/stable-diffusion-v1-5 (Caused by NameResolutionError(\": Failed to resolve \\'huggingface.co\\' ([Errno -2] Name or service not known)\"))'), '(Request ID: 85a7f948-b1d1-4bb4-be97-0eaea2bfd0f8)').\n", + "Will try to load from local cache.\n", + "Loading pipeline components...: 100%|██████████| 7/7 [00:43<00:00, 6.22s/it]\n" + ] + } + ], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "pipe = StableDiffusionPipeline.from_pretrained(\"runwayml/stable-diffusion-v1-5\", torch_dtype=torch.float16).to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "trusted": true + }, + "outputs": [], + "source": [ + "unet = pipe.unet\n", + "vae = pipe.vae\n", + "tokenizer = pipe.tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "trusted": true + }, + "outputs": [], + "source": [ + "# Your text prompt\n", + "prompt = \"a photo of an astronaut riding a horse on mars\"\n", + "\n", + "# Generate image\n", + "with torch.autocast(\"cuda\"):\n", + " image = pipe(prompt).images[0]\n", + "\n", + "# Show or save the result\n", + "image.show() # Opens in default image viewer\n", + "image.save(\"astronaut_horse_mars.png\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "execution": { + "iopub.execute_input": "2025-05-14T14:30:58.653987Z", + "iopub.status.busy": "2025-05-14T14:30:58.653745Z", + "iopub.status.idle": "2025-05-14T14:30:58.658276Z", + "shell.execute_reply": "2025-05-14T14:30:58.657649Z", + "shell.execute_reply.started": "2025-05-14T14:30:58.653970Z" + }, + "trusted": true + }, + "outputs": [], + "source": [ + "import torch\n", + "from torch.utils.data import Dataset, DataLoader\n", + "from torchvision import transforms\n", + "from torchvision.transforms import v2\n", + "from PIL import Image\n", + "from diffusers import StableDiffusionPipeline\n", + "from diffusers.optimization import get_scheduler\n", + "from accelerate import Accelerator\n", + "from torch import nn\n", + "import os\n", + "import pandas as pd\n", + "from tqdm import trange, tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "execution": { + "iopub.execute_input": "2025-05-14T14:30:58.659588Z", + "iopub.status.busy": "2025-05-14T14:30:58.658976Z", + "iopub.status.idle": "2025-05-14T14:31:23.063776Z", + "shell.execute_reply": "2025-05-14T14:31:23.063145Z", + "shell.execute_reply.started": "2025-05-14T14:30:58.659571Z" + }, + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Files found: 128807\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
image_path
0/kaggle/input/vaani-images-tar/Images/IISc_Vaa...
1/kaggle/input/vaani-images-tar/Images/IISc_Vaa...
2/kaggle/input/vaani-images-tar/Images/IISc_Vaa...
3/kaggle/input/vaani-images-tar/Images/IISc_Vaa...
4/kaggle/input/vaani-images-tar/Images/IISc_Vaa...
......
128802/kaggle/input/vaani-images-tar/Images/IISc_Vaa...
128803/kaggle/input/vaani-images-tar/Images/IISc_Vaa...
128804/kaggle/input/vaani-images-tar/Images/IISc_Vaa...
128805/kaggle/input/vaani-images-tar/Images/IISc_Vaa...
128806/kaggle/input/vaani-images-tar/Images/IISc_Vaa...
\n", + "

128807 rows × 1 columns

\n", + "
" + ], + "text/plain": [ + " image_path\n", + "0 /kaggle/input/vaani-images-tar/Images/IISc_Vaa...\n", + "1 /kaggle/input/vaani-images-tar/Images/IISc_Vaa...\n", + "2 /kaggle/input/vaani-images-tar/Images/IISc_Vaa...\n", + "3 /kaggle/input/vaani-images-tar/Images/IISc_Vaa...\n", + "4 /kaggle/input/vaani-images-tar/Images/IISc_Vaa...\n", + "... ...\n", + "128802 /kaggle/input/vaani-images-tar/Images/IISc_Vaa...\n", + "128803 /kaggle/input/vaani-images-tar/Images/IISc_Vaa...\n", + "128804 /kaggle/input/vaani-images-tar/Images/IISc_Vaa...\n", + "128805 /kaggle/input/vaani-images-tar/Images/IISc_Vaa...\n", + "128806 /kaggle/input/vaani-images-tar/Images/IISc_Vaa...\n", + "\n", + "[128807 rows x 1 columns]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "IMAGES_PATH = r\"/kaggle/input/vaani-images-tar/Images\"\n", + "\n", + "def walkDIR(folder_path, include=None):\n", + " file_list = []\n", + " for root, _, files in os.walk(folder_path):\n", + " for file in files:\n", + " if include is None or any(file.endswith(ext) for ext in include):\n", + " file_list.append(os.path.join(root, file))\n", + " print(\"Files found:\", len(file_list))\n", + " return file_list\n", + "\n", + "files = walkDIR(IMAGES_PATH, include=['.png', '.jpeg', '.jpg'])\n", + "df = pd.DataFrame(files, columns=['image_path'])\n", + "df" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "execution": { + "iopub.execute_input": "2025-05-14T14:31:23.065017Z", + "iopub.status.busy": "2025-05-14T14:31:23.064553Z", + "iopub.status.idle": "2025-05-14T14:31:23.086417Z", + "shell.execute_reply": "2025-05-14T14:31:23.085628Z", + "shell.execute_reply.started": "2025-05-14T14:31:23.064991Z" + }, + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "IMAGE SHAPE: torch.Size([3, 256, 256])\n" + ] + }, + { + "data": { + "text/plain": [ + "128807" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class VaaniDataset(torch.utils.data.Dataset):\n", + " def __init__(self, files_paths, im_size):\n", + " self.files_paths = files_paths\n", + " self.im_size = im_size\n", + "\n", + " def __len__(self):\n", + " return len(self.files_paths)\n", + "\n", + " def __getitem__(self, idx):\n", + " # image = tv.io.read_image(self.files_paths[idx], mode=tv.io.ImageReadMode.RGB)\n", + " image = Image.open(self.files_paths[idx]).convert(\"RGB\")\n", + " image = v2.ToImage()(image)\n", + " # image = tv.io.decode_image(self.files_paths[idx], mode=tv.io.ImageReadMode.RGB)\n", + " image = v2.Resize((self.im_size, self.im_size))(image)\n", + " image = v2.ToDtype(torch.float32, scale=True)(image)\n", + " # image = 2*image - 1\n", + " return image\n", + "\n", + "dataset = VaaniDataset(files_paths=files, im_size=256)\n", + "image = dataset[2]\n", + "print('IMAGE SHAPE:', image.shape)\n", + "len(dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "execution": { + "iopub.execute_input": "2025-05-14T14:31:23.087483Z", + "iopub.status.busy": "2025-05-14T14:31:23.087211Z", + "iopub.status.idle": "2025-05-14T14:31:23.468810Z", + "shell.execute_reply": "2025-05-14T14:31:23.465992Z", + "shell.execute_reply.started": "2025-05-14T14:31:23.087458Z" + }, + "trusted": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Length of Train dataset: 129\n", + "BATCH SHAPE: torch.Size([2, 3, 256, 256])\n" + ] + } + ], + 
"source": [ + "debug = True\n", + "\n", + "if debug:\n", + " s = 0.001\n", + " dataset, _ = torch.utils.data.random_split(dataset, [s, 1-s], torch.manual_seed(42))\n", + " print(\"Length of Train dataset:\", len(dataset))\n", + "\n", + "BATCH_SIZE = 2\n", + "\n", + "dataloader = torch.utils.data.DataLoader(\n", + " dataset, \n", + " batch_size=BATCH_SIZE, \n", + " shuffle=True, \n", + " num_workers=4,\n", + " pin_memory=True,\n", + " drop_last=True,\n", + " persistent_workers=True\n", + ")\n", + "\n", + "images = next(iter(dataloader))\n", + "print('BATCH SHAPE:', images.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "execution": { + "iopub.execute_input": "2025-05-14T14:31:59.796334Z", + "iopub.status.busy": "2025-05-14T14:31:59.795660Z", + "iopub.status.idle": "2025-05-14T14:31:59.800889Z", + "shell.execute_reply": "2025-05-14T14:31:59.800295Z", + "shell.execute_reply.started": "2025-05-14T14:31:59.796311Z" + }, + "trusted": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "64" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(dataloader)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "execution": { + "iopub.execute_input": "2025-05-14T14:31:23.470858Z", + "iopub.status.busy": "2025-05-14T14:31:23.470503Z", + "iopub.status.idle": "2025-05-14T14:31:28.213003Z", + "shell.execute_reply": "2025-05-14T14:31:28.212168Z", + "shell.execute_reply.started": "2025-05-14T14:31:23.470801Z" + }, + "scrolled": true, + "trusted": true + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "28c0c220b2cf45968b4abdecf3936bc9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading pipeline components...: 0%| | 0/7 [00:00 None: + self._temp = temp + self._key = key + self._parent = parent + + def __eq__(self, other): + if not isinstance(other, DotDict): + return NotImplemented + return vars(self) == vars(other) + + def __getattr__(self, __name: str) -> Any: + if __name not in self.__dict__ and not self._temp: + self.__dict__[__name] = DotDict(temp=True, key=__name, parent=self) + else: + del self._parent.__dict__[self._key] + raise AttributeError("No attribute '%s'" % __name) + return self.__dict__[__name] + + def __repr__(self) -> str: + item_keys = [k for k in self.__dict__ if not k.startswith("_")] + + if len(item_keys) == 0: + return "DotDict()" + elif len(item_keys) == 1: + key = item_keys[0] + val = self.__dict__[key] + return "DotDict(%s=%s)" % (key, repr(val)) + else: + return "DotDict(%s)" % ", ".join( + "%s=%s" % (key, repr(val)) for key, val in self.__dict__.items() + ) + + @classmethod + def from_dict(cls, original: typing.Mapping[str, any]) -> "DotDict": + """Create a DotDict from a (possibly nested) dict `original`. + Warning: this method should not be used on very deeply nested inputs, + since it's recursively traversing the nested dictionary values. 
+ """ + dd = DotDict() + for key, value in original.items(): + if isinstance(value, typing.Mapping): + value = cls.from_dict(value) + setattr(dd, key, value) + return dd + + +# ================================================================== +# L P I P S +# ================================================================== +class vgg16(nn.Module): + def __init__(self): + super(vgg16, self).__init__() + vgg_pretrained_features = tv.models.vgg16( + weights=tv.models.VGG16_Weights.IMAGENET1K_V1 + ).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + + self.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h1 = self.slice1(X) + h2 = self.slice2(h1) + h3 = self.slice3(h2) + h4 = self.slice4(h3) + h5 = self.slice5(h4) + vgg_outputs = namedtuple("VggOutputs", ['h1', 'h2', 'h3', 'h4', 'h5']) + out = vgg_outputs(h1, h2, h3, h4, h5) + return out + + +def _spatial_average(in_tens, keepdim=True): + return in_tens.mean([2, 3], keepdim=keepdim) + + +def _normalize_tensor(in_feat, eps= 1e-8): + norm_factor = torch.sqrt(eps + torch.sum(in_feat**2, dim=1, keepdim=True)) + return in_feat / norm_factor + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + # Imagnet normalization for (0-1) + # mean = [0.485, 0.456, 0.406] + # std = [0.229, 0.224, 0.225] + + self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + ''' A single linear layer which does a 1x1 conv ''' + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + + +class LPIPS(nn.Module): + def __init__(self, net='vgg', version='0.1', use_dropout=True): + super(LPIPS, self).__init__() + self.version = version + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] + self.L = len(self.chns) + self.net = vgg16() + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.lins = nn.ModuleList([self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]) + + # --- Orignal url -------------------- + # weights_url = f"https://github.com/richzhang/PerceptualSimilarity/raw/master/lpips/weights/v{version}/{net}.pth" + + # --- Orignal Forked url ------------- + weights_url = 
f"https://github.com/akuresonite/PerceptualSimilarity-Forked/raw/master/lpips/weights/v{version}/{net}.pth" + + # --- Orignal torchmetric url -------- + # weights_url = "https://github.com/Lightning-AI/torchmetrics/raw/master/src/torchmetrics/functional/image/lpips_models/vgg.pth" + + state_dict = torch.hub.load_state_dict_from_url(weights_url, map_location='cpu') + self.load_state_dict(state_dict, strict=False) + + self.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, in0, in1, normalize=False): + # Scale the inputs to -1 to +1 range if input in [0,1] + if normalize: + in0 = 2 * in0 - 1 + in1 = 2 * in1 - 1 + + in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1) + # in0_input, in1_input = in0, in1 + + outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) + + diffs = {} + for kk in range(self.L): + feats0 = _normalize_tensor(outs0[kk]) + feats1 = _normalize_tensor(outs1[kk]) + diffs[kk] = (feats0 - feats1) ** 2 + + res = [_spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] + val = sum(res) + return val.reshape(-1) + + +# ================================================================== +# P A T C H - G A N - D I S C R I M I N A T O R +# ================================================================== +class Discriminator(nn.Module): + r""" + PatchGAN Discriminator. + Rather than taking IMG_CHANNELSxIMG_HxIMG_W all the way to + 1 scalar value , we instead predict grid of values. + Where each grid is prediction of how likely + the discriminator thinks that the image patch corresponding + to the grid cell is real + """ + + def __init__( + self, + im_channels=3, + conv_channels=[64, 128, 256], + kernels=[4, 4, 4, 4], + strides=[2, 2, 2, 1], + paddings=[1, 1, 1, 1], + ): + super().__init__() + self.im_channels = im_channels + activation = nn.LeakyReLU(0.2) + layers_dim = [self.im_channels] + conv_channels + [1] + self.layers = nn.ModuleList( + [ + nn.Sequential( + nn.Conv2d( + layers_dim[i], + layers_dim[i + 1], + kernel_size=kernels[i], + stride=strides[i], + padding=paddings[i], + bias=False if i != 0 else True, + ), + ( + nn.BatchNorm2d(layers_dim[i + 1]) + if i != len(layers_dim) - 2 and i != 0 + else nn.Identity() + ), + activation if i != len(layers_dim) - 2 else nn.Identity(), + ) + for i in range(len(layers_dim) - 1) + ] + ) + + def forward(self, x): + out = x + for layer in self.layers: + out = layer(out) + return out + + + +# ================================================================== +# D O W E - B L O C K +# ================================================================== +class DownBlock(nn.Module): + r""" + Down conv block with attention. + Sequence of following block + 1. Resnet block with time embedding + 2. Attention block + 3. 
Downsample + """ + + def __init__( + self, + in_channels, + out_channels, + t_emb_dim, + down_sample, + num_heads, + num_layers, + attn, + norm_channels, + cross_attn=False, + context_dim=None, + ): + super().__init__() + self.num_layers = num_layers + self.down_sample = down_sample + self.attn = attn + self.context_dim = context_dim + self.cross_attn = cross_attn + self.t_emb_dim = t_emb_dim + self.resnet_conv_first = nn.ModuleList( + [ + nn.Sequential( + nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), + nn.SiLU(), + nn.Conv2d( + in_channels if i == 0 else out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ), + ) + for i in range(num_layers) + ] + ) + if self.t_emb_dim is not None: + self.t_emb_layers = nn.ModuleList( + [ + nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, out_channels)) + for _ in range(num_layers) + ] + ) + self.resnet_conv_second = nn.ModuleList( + [ + nn.Sequential( + nn.GroupNorm(norm_channels, out_channels), + nn.SiLU(), + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), + ) + for _ in range(num_layers) + ] + ) + + if self.attn: + self.attention_norms = nn.ModuleList( + [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)] + ) + + self.attentions = nn.ModuleList( + [ + nn.MultiheadAttention(out_channels, num_heads, batch_first=True) + for _ in range(num_layers) + ] + ) + if self.cross_attn: + assert context_dim is not None, "Context Dimension must be passed for cross attention" + self.cross_attention_norms = nn.ModuleList( + [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)] + ) + self.cross_attentions = nn.ModuleList( + [ + nn.MultiheadAttention(out_channels, num_heads, batch_first=True) + for _ in range(num_layers) + ] + ) + self.context_proj = nn.ModuleList( + [nn.Linear(context_dim, out_channels) for _ in range(num_layers)] + ) + self.residual_input_conv = nn.ModuleList( + [ + nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) + for i in range(num_layers) + ] + ) + self.down_sample_conv = ( + nn.Conv2d(out_channels, out_channels, 4, 2, 1) if self.down_sample else nn.Identity() + ) + + def forward(self, x, t_emb=None, context=None): + out = x + for i in range(self.num_layers): + # Resnet block of Unet + + resnet_input = out + out = self.resnet_conv_first[i](out) + if self.t_emb_dim is not None: + out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] + out = self.resnet_conv_second[i](out) + out = out + self.residual_input_conv[i](resnet_input) + + if self.attn: + # Attention block of Unet + + batch_size, channels, h, w = out.shape + in_attn = out.reshape(batch_size, channels, h * w) + in_attn = self.attention_norms[i](in_attn) + in_attn = in_attn.transpose(1, 2) + out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) + out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) + out = out + out_attn + if self.cross_attn: + assert ( + context is not None + ), "context cannot be None if cross attention layers are used" + batch_size, channels, h, w = out.shape + in_attn = out.reshape(batch_size, channels, h * w) + in_attn = self.cross_attention_norms[i](in_attn) + in_attn = in_attn.transpose(1, 2) + assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim + context_proj = self.context_proj[i](context) + out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj) + out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) + out = out + out_attn + # 
Downsample + + out = self.down_sample_conv(out) + return out + + + +# ================================================================== +# M I D - B L O C K +# ================================================================== +class MidBlock(nn.Module): + r""" + Mid conv block with attention. + Sequence of following blocks + 1. Resnet block with time embedding + 2. Attention block + 3. Resnet block with time embedding + """ + + def __init__( + self, + in_channels, + out_channels, + t_emb_dim, + num_heads, + num_layers, + norm_channels, + cross_attn=None, + context_dim=None, + ): + super().__init__() + self.num_layers = num_layers + self.t_emb_dim = t_emb_dim + self.context_dim = context_dim + self.cross_attn = cross_attn + self.resnet_conv_first = nn.ModuleList( + [ + nn.Sequential( + nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), + nn.SiLU(), + nn.Conv2d( + in_channels if i == 0 else out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ), + ) + for i in range(num_layers + 1) + ] + ) + + if self.t_emb_dim is not None: + self.t_emb_layers = nn.ModuleList( + [ + nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels)) + for _ in range(num_layers + 1) + ] + ) + self.resnet_conv_second = nn.ModuleList( + [ + nn.Sequential( + nn.GroupNorm(norm_channels, out_channels), + nn.SiLU(), + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), + ) + for _ in range(num_layers + 1) + ] + ) + + self.attention_norms = nn.ModuleList( + [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)] + ) + + self.attentions = nn.ModuleList( + [ + nn.MultiheadAttention(out_channels, num_heads, batch_first=True) + for _ in range(num_layers) + ] + ) + if self.cross_attn: + assert context_dim is not None, "Context Dimension must be passed for cross attention" + self.cross_attention_norms = nn.ModuleList( + [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)] + ) + self.cross_attentions = nn.ModuleList( + [ + nn.MultiheadAttention(out_channels, num_heads, batch_first=True) + for _ in range(num_layers) + ] + ) + self.context_proj = nn.ModuleList( + [nn.Linear(context_dim, out_channels) for _ in range(num_layers)] + ) + self.residual_input_conv = nn.ModuleList( + [ + nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) + for i in range(num_layers + 1) + ] + ) + + def forward(self, x, t_emb=None, context=None): + out = x + + # First resnet block + resnet_input = out + out = self.resnet_conv_first[0](out) + if self.t_emb_dim is not None: + out = out + self.t_emb_layers[0](t_emb)[:, :, None, None] + out = self.resnet_conv_second[0](out) + out = out + self.residual_input_conv[0](resnet_input) + + for i in range(self.num_layers): + # Attention Block + batch_size, channels, h, w = out.shape + in_attn = out.reshape(batch_size, channels, h * w) + in_attn = self.attention_norms[i](in_attn) + in_attn = in_attn.transpose(1, 2) + out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) + out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) + out = out + out_attn + + if self.cross_attn: + assert ( + context is not None + ), "context cannot be None if cross attention layers are used" + batch_size, channels, h, w = out.shape + in_attn = out.reshape(batch_size, channels, h * w) + in_attn = self.cross_attention_norms[i](in_attn) + in_attn = in_attn.transpose(1, 2) + assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim + context_proj = self.context_proj[i](context) 
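+                # Cross-attention step: queries are the flattened spatial tokens [B, H*W, C];
+                # keys/values are the conditioning context projected from context_dim down to C.
+                # The attended output is reshaped back to [B, C, H, W] and added residually.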
+ out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj) + out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) + out = out + out_attn + + # Resnet Block + resnet_input = out + out = self.resnet_conv_first[i + 1](out) + if self.t_emb_dim is not None: + out = out + self.t_emb_layers[i + 1](t_emb)[:, :, None, None] + out = self.resnet_conv_second[i + 1](out) + out = out + self.residual_input_conv[i + 1](resnet_input) + return out + + +# ================================================================== +# U P - B L O C K +# ================================================================== +class UpBlock(nn.Module): + r""" + Up conv block with attention. + Sequence of following blocks + 1. Upsample + 1. Concatenate Down block output + 2. Resnet block with time embedding + 3. Attention Block + """ + + def __init__( + self, + in_channels, + out_channels, + t_emb_dim, + up_sample, + num_heads, + num_layers, + attn, + norm_channels, + ): + super().__init__() + self.num_layers = num_layers + self.up_sample = up_sample + self.t_emb_dim = t_emb_dim + self.attn = attn + self.resnet_conv_first = nn.ModuleList( + [ + nn.Sequential( + nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), + nn.SiLU(), + nn.Conv2d( + in_channels if i == 0 else out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ), + ) + for i in range(num_layers) + ] + ) + + if self.t_emb_dim is not None: + self.t_emb_layers = nn.ModuleList( + [ + nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels)) + for _ in range(num_layers) + ] + ) + self.resnet_conv_second = nn.ModuleList( + [ + nn.Sequential( + nn.GroupNorm(norm_channels, out_channels), + nn.SiLU(), + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), + ) + for _ in range(num_layers) + ] + ) + if self.attn: + self.attention_norms = nn.ModuleList( + [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)] + ) + + self.attentions = nn.ModuleList( + [ + nn.MultiheadAttention(out_channels, num_heads, batch_first=True) + for _ in range(num_layers) + ] + ) + self.residual_input_conv = nn.ModuleList( + [ + nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) + for i in range(num_layers) + ] + ) + self.up_sample_conv = ( + nn.ConvTranspose2d(in_channels, in_channels, 4, 2, 1) + if self.up_sample + else nn.Identity() + ) + + def forward(self, x, out_down=None, t_emb=None): + # Upsample + + x = self.up_sample_conv(x) + + # Concat with Downblock output + + if out_down is not None: + x = torch.cat([x, out_down], dim=1) + out = x + for i in range(self.num_layers): + # Resnet Block + + resnet_input = out + out = self.resnet_conv_first[i](out) + if self.t_emb_dim is not None: + out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] + out = self.resnet_conv_second[i](out) + out = out + self.residual_input_conv[i](resnet_input) + + # Self Attention + + if self.attn: + batch_size, channels, h, w = out.shape + in_attn = out.reshape(batch_size, channels, h * w) + in_attn = self.attention_norms[i](in_attn) + in_attn = in_attn.transpose(1, 2) + out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) + out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) + out = out + out_attn + return out + + +# ================================================================== +# V Q - V A E +# ================================================================== +class VQVAE(nn.Module): + def __init__(self, im_channels, model_config): + 
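+        # High-level layout (matches the modules built below):
+        #   Encoder : conv_in -> DownBlocks -> MidBlocks -> GroupNorm/SiLU -> conv_out -> pre_quant_conv
+        #   Quantize: nearest-codebook lookup with straight-through gradients (see quantize())
+        #   Decoder : post_quant_conv -> conv_in -> MidBlocks -> UpBlocks -> GroupNorm/SiLU -> conv_out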
super().__init__() + self.down_channels = model_config.down_channels + self.mid_channels = model_config.mid_channels + self.down_sample = model_config.down_sample + self.num_down_layers = model_config.num_down_layers + self.num_mid_layers = model_config.num_mid_layers + self.num_up_layers = model_config.num_up_layers + + # To disable attention in Downblock of Encoder and Upblock of Decoder + self.attns = model_config.attn_down + + # Latent Dimension + self.z_channels = model_config.z_channels + self.codebook_size = model_config.codebook_size + self.norm_channels = model_config.norm_channels + self.num_heads = model_config.num_heads + + # Assertion to validate the channel information + assert self.mid_channels[0] == self.down_channels[-1] + assert self.mid_channels[-1] == self.down_channels[-1] + assert len(self.down_sample) == len(self.down_channels) - 1 + assert len(self.attns) == len(self.down_channels) - 1 + + # Wherever we use downsampling in encoder correspondingly use + # upsampling in decoder + self.up_sample = list(reversed(self.down_sample)) + + ##################### Encoder ###################### + self.encoder_conv_in = nn.Conv2d( + im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1) + ) + + # Downblock + Midblock + self.encoder_layers = nn.ModuleList([]) + for i in range(len(self.down_channels) - 1): + self.encoder_layers.append( + DownBlock( + self.down_channels[i], + self.down_channels[i + 1], + t_emb_dim=None, + down_sample=self.down_sample[i], + num_heads=self.num_heads, + num_layers=self.num_down_layers, + attn=self.attns[i], + norm_channels=self.norm_channels, + ) + ) + self.encoder_mids = nn.ModuleList([]) + for i in range(len(self.mid_channels) - 1): + self.encoder_mids.append( + MidBlock( + self.mid_channels[i], + self.mid_channels[i + 1], + t_emb_dim=None, + num_heads=self.num_heads, + num_layers=self.num_mid_layers, + norm_channels=self.norm_channels, + ) + ) + self.encoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[-1]) + self.encoder_conv_out = nn.Conv2d( + self.down_channels[-1], self.z_channels, kernel_size=3, padding=1 + ) + + # Pre Quantization Convolution + self.pre_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1) + + # Codebook + self.embedding = nn.Embedding(self.codebook_size, self.z_channels) + #################################################### + + ##################### Decoder ###################### + + # Post Quantization Convolution + self.post_quant_conv = nn.Conv2d(self.z_channels, self.z_channels, kernel_size=1) + self.decoder_conv_in = nn.Conv2d( + self.z_channels, self.mid_channels[-1], kernel_size=3, padding=(1, 1) + ) + + # Midblock + Upblock + self.decoder_mids = nn.ModuleList([]) + for i in reversed(range(1, len(self.mid_channels))): + self.decoder_mids.append( + MidBlock( + self.mid_channels[i], + self.mid_channels[i - 1], + t_emb_dim=None, + num_heads=self.num_heads, + num_layers=self.num_mid_layers, + norm_channels=self.norm_channels, + ) + ) + self.decoder_layers = nn.ModuleList([]) + for i in reversed(range(1, len(self.down_channels))): + self.decoder_layers.append( + UpBlock( + self.down_channels[i], + self.down_channels[i - 1], + t_emb_dim=None, + up_sample=self.down_sample[i - 1], + num_heads=self.num_heads, + num_layers=self.num_up_layers, + attn=self.attns[i - 1], + norm_channels=self.norm_channels, + ) + ) + self.decoder_norm_out = nn.GroupNorm(self.norm_channels, self.down_channels[0]) + self.decoder_conv_out = nn.Conv2d( + self.down_channels[0], im_channels, kernel_size=3, 
padding=1 + ) + + def quantize(self, x): + B, C, H, W = x.shape + + # B, C, H, W -> B, H, W, C + x = x.permute(0, 2, 3, 1) + + # B, H, W, C -> B, H*W, C + x = x.reshape(x.size(0), -1, x.size(-1)) + + # Find nearest embedding/codebook vector + # dist between (B, H*W, C) and (B, K, C) -> (B, H*W, K) + dist = torch.cdist(x, self.embedding.weight[None, :].repeat((x.size(0), 1, 1))) + # (B, H*W) + min_encoding_indices = torch.argmin(dist, dim=-1) + + # Replace encoder output with nearest codebook + # quant_out -> B*H*W, C + quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1)) + + # x -> B*H*W, C + x = x.reshape((-1, x.size(-1))) + commmitment_loss = torch.mean((quant_out.detach() - x) ** 2) + codebook_loss = torch.mean((quant_out - x.detach()) ** 2) + quantize_losses = {"codebook_loss": codebook_loss, "commitment_loss": commmitment_loss} + # Straight through estimation + quant_out = x + (quant_out - x).detach() + + # quant_out -> B, C, H, W + quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2) + min_encoding_indices = min_encoding_indices.reshape( + (-1, quant_out.size(-2), quant_out.size(-1)) + ) + return quant_out, quantize_losses, min_encoding_indices + + def encode(self, x): + out = self.encoder_conv_in(x) + for idx, down in enumerate(self.encoder_layers): + out = down(out) + for mid in self.encoder_mids: + out = mid(out) + out = self.encoder_norm_out(out) + out = nn.SiLU()(out) + out = self.encoder_conv_out(out) + out = self.pre_quant_conv(out) + out, quant_losses, _ = self.quantize(out) + return out, quant_losses + + def decode(self, z): + out = z + out = self.post_quant_conv(out) + out = self.decoder_conv_in(out) + for mid in self.decoder_mids: + out = mid(out) + for idx, up in enumerate(self.decoder_layers): + out = up(out) + out = self.decoder_norm_out(out) + out = nn.SiLU()(out) + out = self.decoder_conv_out(out) + return out + + def forward(self, x): + '''out: [B, 3, 256, 256] + z: [B, 3, 64, 64] + quant_losses: { + codebook_loss: 0.0681, + commitment_loss: 0.0681 + } + ''' + z, quant_losses = self.encode(x) + out = self.decode(z) + return out, z, quant_losses + + +# ================================================================== +# C O N F I G U R A T I O N +# ================================================================== +import pprint +config_path = "/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/config-LDM-High-Pre.yaml" +# config_path = sys.argv[1] +with open(config_path, 'r') as file: + Config = yaml.safe_load(file) + pprint.pprint(Config, width=120) + +Config = DotDict.from_dict(Config) +dataset_config = Config.dataset_params +diffusion_config = Config.diffusion_params +model_config = Config.model_params +train_config = Config.train_params +paths = Config.paths + + +# ================================================================== +# V A A N I - D A T A S E T +# ================================================================== +IMAGES_PATH = paths.images_dir + +def walkDIR(folder_path, include=None): + file_list = [] + for root, _, files in os.walk(folder_path): + for file in files: + if include is None or any(file.endswith(ext) for ext in include): + file_list.append(os.path.join(root, file)) + print("Files found:", len(file_list)) + return file_list + +files = walkDIR(IMAGES_PATH, include=['.png', '.jpeg', '.jpg']) +df = pd.DataFrame(files, columns=['image_path']) + +class VaaniDataset(torch.utils.data.Dataset): + def __init__(self, files_paths, im_size): + self.files_paths = files_paths + self.im_size = im_size + + 
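+    # __getitem__ reads an RGB image with torchvision, resizes it to im_size x im_size and
+    # scales it to float32 in [0, 1]; the [-1, 1] rescale is intentionally left commented out.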
def __len__(self): + return len(self.files_paths) + + def __getitem__(self, idx): + image = tv.io.read_image(self.files_paths[idx], mode=tv.io.ImageReadMode.RGB) + # image = tv.io.decode_image(self.files_paths[idx], mode=tv.io.ImageReadMode.RGB) + image = v2.Resize((self.im_size,self.im_size))(image) + image = v2.ToDtype(torch.float32, scale=True)(image) + # image = 2*image - 1 + return image + +dataset = VaaniDataset(files_paths=files, im_size=dataset_config.im_size) +image = dataset[2] +print('IMAGE SHAPE:', image.shape) +if train_config.debug: + s = 0.001 + dataset, _ = torch.utils.data.random_split(dataset, [s, 1-s], torch.manual_seed(42)) + print("Length of Train dataset:", len(dataset)) + +if sys.argv[1] == "train_vae": + BATCH_SIZE = train_config.autoencoder_batch_size +elif sys.argv[1] == "train_ldm": + BATCH_SIZE = train_config.ldm_batch_size + +dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + num_workers=48, + pin_memory=True, + drop_last=True, + persistent_workers=True +) + +images = next(iter(dataloader)) +print('BATCH SHAPE:', images.shape) + + +# ================================================================== +# M O D E L - I N I T I L I Z A T I O N +# ================================================================== +dataset_config = Config.dataset_params +autoencoder_config = Config.autoencoder_params +train_config = Config.train_params + +# model = VQVAE(im_channels=dataset_config.im_channels, +# model_config=autoencoder_config).to(device) + +# model_output = model(images.to(device)) +# print('MODEL OUTPUT:') +# print(model_output[0].shape, model_output[1].shape, model_output[2]) + +# from torchinfo import summary +# summary(model=model, +# input_data=images.to(device), +# # input_size = (1, 3, config.IMAGE_SIZE, config.IMAGE_SIZE), +# col_names = ["input_size", "output_size", "num_params", "trainable", "params_percent"], +# col_width=20, +# row_settings=["var_names"], +# depth = 6, +# # device=device +# ) +# exit() + + +# ================================================================== +# V Q - V A E - T R A I N I N G +# ================================================================== +# python your_script.py 2>&1 > training.log +import time + +def format_time(t1, t2): + elapsed_time = t2 - t1 + if elapsed_time < 60: + return f"{elapsed_time:.2f} seconds" + elif elapsed_time < 3600: + minutes = elapsed_time // 60 + seconds = elapsed_time % 60 + return f"{minutes:.0f} minutes {seconds:.2f} seconds" + elif elapsed_time < 86400: + hours = elapsed_time // 3600 + remainder = elapsed_time % 3600 + minutes = remainder // 60 + seconds = remainder % 60 + return f"{hours:.0f} hours {minutes:.0f} minutes {seconds:.2f} seconds" + else: + days = elapsed_time // 86400 + remainder = elapsed_time % 86400 + hours = remainder // 3600 + remainder = remainder % 3600 + minutes = remainder // 60 + seconds = remainder % 60 + return f"{days:.0f} days {hours:.0f} hours {minutes:.0f} minutes {seconds:.2f} seconds" + + +def find_checkpoints(checkpoint_path): + directory = os.path.dirname(checkpoint_path) + prefix = os.path.basename(checkpoint_path) + pattern = re.compile(rf"{re.escape(prefix)}_epoch(\d+)\.pt$") + + try: + files = os.listdir(directory) + except FileNotFoundError: + return [] + + return [ + os.path.join(directory, f) + for f in files if pattern.match(f) + ] + + +def save_vae_checkpoint( + total_steps, epoch, model, discriminator, optimizer_d, + optimizer_g, metrics, checkpoint_path, logs, total_training_time +): + checkpoint = { 
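+        # Everything needed to resume training in place: step/epoch counters, VQ-VAE and
+        # discriminator weights, both optimizer states, plus metrics, logs and total wall-clock time.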
+ "total_steps": total_steps, + "epoch": epoch, + "model_state_dict": model.state_dict(), + "discriminator_state_dict": discriminator.state_dict(), + "optimizer_d_state_dict": optimizer_d.state_dict(), + "optimizer_g_state_dict": optimizer_g.state_dict(), + "metrics": metrics, + "logs": logs, + "total_training_time": total_training_time + } + checkpoint_file = f"{checkpoint_path}_epoch{epoch}.pt" + torch.save(checkpoint, checkpoint_file) + print(f"VQVAE Checkpoint saved at {checkpoint_file}") + all_ckpts = glob.glob(f"{checkpoint_path}_epoch*.pt") + # all_ckpts = find_checkpoints(checkpoint_path) + def extract_epoch(filename): + match = re.search(r"_epoch(\d+)\.pt", filename) + return int(match.group(1)) if match else -1 + all_ckpts = sorted(all_ckpts, key=extract_epoch) + for old_ckpt in all_ckpts[:-2]: + os.remove(old_ckpt) + print(f"Removed old VQVAE checkpoint: {old_ckpt}") + + + +def load_vae_checkpoint(checkpoint_path, model, discriminator, optimizer_d, optimizer_g, device=device): + all_ckpts = glob.glob(f"{checkpoint_path}_epoch*.pt") + # all_ckpts = find_checkpoints(checkpoint_path) + if not all_ckpts: + print("No VQVAE checkpoint found. Starting from scratch.") + return 0, 0, None, [], 0 + def extract_epoch(filename): + match = re.search(r"_epoch(\d+)\.pt", filename) + return int(match.group(1)) if match else -1 + all_ckpts = sorted(all_ckpts, key=extract_epoch) + latest_ckpt = all_ckpts[-1] + if os.path.exists(latest_ckpt): + checkpoint = torch.load(latest_ckpt, map_location=device) + model.load_state_dict(checkpoint["model_state_dict"]) + discriminator.load_state_dict(checkpoint["discriminator_state_dict"]) + optimizer_d.load_state_dict(checkpoint["optimizer_d_state_dict"]) + optimizer_g.load_state_dict(checkpoint["optimizer_g_state_dict"]) + total_steps = checkpoint["total_steps"] + epoch = checkpoint["epoch"] + metrics = checkpoint["metrics"] + logs = checkpoint.get("logs", []) + total_training_time = checkpoint.get("total_training_time", 0) + print(f"VQVAE Checkpoint loaded from {latest_ckpt}. Resuming from epoch {epoch + 1}, step {total_steps}") + return total_steps, epoch + 1, metrics, logs, total_training_time + else: + print("No VQVAE checkpoint found. 
Starting from scratch.") + return 0, 0, None, [], 0 + +from PIL import Image +def inference(model, dataset, save_path, epoch, device="cuda", sample_size=8): + if not os.path.exists(save_path): + os.makedirs(save_path) + + image_tensors = [] + for i in range(sample_size): + image_tensors.append(dataset[i].unsqueeze(0)) + + image_tensors = torch.cat(image_tensors, dim=0).to(device) + with torch.no_grad(): + outputs, _, _ = model(image_tensors) + + save_input = image_tensors.detach().cpu() + save_output = outputs.detach().cpu() + + grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=sample_size) + + np_img = (grid * 255).byte().numpy().transpose(1, 2, 0) + combined_image = Image.fromarray(np_img) + combined_image.save("output_image.png") + + # combined_image = tv.transforms.ToPILImage()(grid) + combined_image.save(os.path.join(save_path, f"reconstructed_images_EP-{epoch}_{sample_size}.png")) + + print(f"Reconstructed images saved at: {save_path}") + + +def trainVAE(Config, dataloader): + dataset_config = Config.dataset_params + autoencoder_config = Config.autoencoder_params + train_config = Config.train_params + paths = Config.paths + + seed = train_config.seed + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + if device == "cuda": + torch.cuda.manual_seed_all(seed) + + model = VQVAE(im_channels=dataset_config.im_channels, model_config=autoencoder_config).to(device) + discriminator = Discriminator(im_channels=dataset_config.im_channels).to(device) + + optimizer_d = torch.optim.AdamW(discriminator.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999)) + optimizer_g = torch.optim.AdamW(model.parameters(), lr=train_config.autoencoder_lr, betas=(0.5, 0.999)) + + checkpoint_path = os.path.join(train_config.task_name, "vqvae_ckpt.pth") + (total_steps, start_epoch, + metrics, logs, total_training_time) = load_vae_checkpoint(checkpoint_path, + model, discriminator, + optimizer_d, optimizer_g) + + if not os.path.exists(train_config.task_name): + os.mkdir(train_config.task_name) + + num_epochs = train_config.autoencoder_epochs + recon_criterion = torch.nn.MSELoss() + disc_criterion = torch.nn.MSELoss() + lpips_model = LPIPS().eval().to(device) + + acc_steps = train_config.autoencoder_acc_steps + disc_step_start = train_config.disc_start + + start_time_total = time.time() - total_training_time + + for epoch_idx in trange(start_epoch, num_epochs, colour='red', dynamic_ncols=True): + start_time_epoch = time.time() + epoch_log = [] + + for images in tqdm(dataloader, colour='green', dynamic_ncols=True): + batch_start_time = time.time() + total_steps += 1 + + images = images.to(device) + model_output = model(images) + output, z, quantize_losses = model_output + + recon_loss = recon_criterion(output, images) / acc_steps + + g_loss = ( + recon_loss + + (train_config.codebook_weight * quantize_losses["codebook_loss"] / acc_steps) + + (train_config.commitment_beta * quantize_losses["commitment_loss"] / acc_steps) + ) + + if total_steps > disc_step_start: + disc_fake_pred = discriminator(output) + disc_fake_loss = disc_criterion(disc_fake_pred, torch.ones_like(disc_fake_pred)) + g_loss += train_config.disc_weight * disc_fake_loss / acc_steps + + lpips_loss = torch.mean(lpips_model(output, images)) / acc_steps + g_loss += train_config.perceptual_weight * lpips_loss + + g_loss.backward() + + if total_steps % acc_steps == 0: + optimizer_g.step() + optimizer_g.zero_grad() + + if total_steps > disc_step_start: + disc_fake_pred = discriminator(output.detach()) + 
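+                # LSGAN-style discriminator update (MSE targets): reconstructions -> 0, real images -> 1.
+                # output.detach() keeps these gradients from flowing back into the VQ-VAE generator.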
disc_real_pred = discriminator(images) + + # disc_loss = (disc_criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred)) + + # disc_criterion(disc_real_pred, torch.ones_like(disc_real_pred))) / 2 / acc_steps + + disc_fake_loss = disc_criterion(disc_fake_pred, torch.zeros(disc_fake_pred.shape, device=disc_fake_pred.device)) + disc_real_loss = disc_criterion(disc_real_pred, torch.ones(disc_real_pred.shape, device=disc_real_pred.device)) + + disc_loss = train_config.disc_weight * (disc_fake_loss + disc_real_loss) / 2 / acc_steps + + disc_loss.backward() + + if total_steps % acc_steps == 0: + optimizer_d.step() + optimizer_d.zero_grad() + + if total_steps % acc_steps == 0: + optimizer_g.step() + optimizer_g.zero_grad() + + batch_time = time.time() - batch_start_time + epoch_log.append(format_time(0, batch_time)) + + optimizer_d.step() + optimizer_d.zero_grad() + optimizer_g.step() + optimizer_g.zero_grad() + + epoch_time = time.time() - start_time_epoch + logs.append({"epoch": epoch_idx + 1, "epoch_time": format_time(0, epoch_time), "batch_times": epoch_log}) + + total_training_time = time.time() - start_time_total + + save_vae_checkpoint(total_steps, epoch_idx + 1, model, discriminator, optimizer_d, optimizer_g, metrics, checkpoint_path, logs, total_training_time) + recon_save_path = os.path.join(train_config.task_name, 'vqvae_recon') + inference(model, dataset, recon_save_path, epoch=epoch_idx, device=device, sample_size=16) + + print("Training completed.") + + + + +# ================================================================== +# S T A R T I N G - V Q - V A E - T R A I N I N G +# ================================================================== +# trainVAE(Config, dataloader) + +# python Vaani-VQVAE-Main.py | tee AE-training.log +# python Vaani-VQVAE-Main.py > AE-training.log 2>&1 + + + + + +# ================================================================== +# L I N E A R - N O I S E - S C H E D U L E R +# ================================================================== +class LinearNoiseScheduler: + r""" + Class for the linear noise scheduler that is used in DDPM. + """ + + def __init__(self, num_timesteps, beta_start, beta_end): + + self.num_timesteps = num_timesteps + self.beta_start = beta_start + self.beta_end = beta_end + # Mimicking how compvis repo creates schedule + self.betas = ( + torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_timesteps) ** 2 + ) + self.alphas = 1. 
- self.betas + self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0) + self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod) + self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod) + + def add_noise(self, original, noise, t): + r""" + Forward method for diffusion + :param original: Image on which noise is to be applied + :param noise: Random Noise Tensor (from normal dist) + :param t: timestep of the forward process of shape -> (B,) + :return: + """ + original_shape = original.shape + batch_size = original_shape[0] + + sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(batch_size) + sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(original.device)[t].reshape(batch_size) + + # Reshape till (B,) becomes (B,1,1,1) if image is (B,C,H,W) + for _ in range(len(original_shape) - 1): + sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1) + for _ in range(len(original_shape) - 1): + sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1) + + # Apply and Return Forward process equation + return (sqrt_alpha_cum_prod.to(original.device) * original + + sqrt_one_minus_alpha_cum_prod.to(original.device) * noise) + + def sample_prev_timestep(self, xt, noise_pred, t): + r""" + Use the noise prediction by model to get + xt-1 using xt and the nosie predicted + :param xt: current timestep sample + :param noise_pred: model noise prediction + :param t: current timestep we are at + :return: + """ + x0 = ((xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)) / + torch.sqrt(self.alpha_cum_prod.to(xt.device)[t])) + x0 = torch.clamp(x0, -1., 1.) + + mean = xt - ((self.betas.to(xt.device)[t]) * noise_pred) / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t]) + mean = mean / torch.sqrt(self.alphas.to(xt.device)[t]) + + if t == 0: + return mean, x0 + else: + variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (1.0 - self.alpha_cum_prod.to(xt.device)[t]) + variance = variance * self.betas.to(xt.device)[t] + sigma = variance ** 0.5 + z = torch.randn(xt.shape).to(xt.device) + + # OR + # variance = self.betas[t] + # sigma = variance ** 0.5 + # z = torch.randn(xt.shape).to(xt.device) + return mean + sigma * z, x0 + + + +# ================================================================== +# T I M E - E M B E D D I N G +# ================================================================== +def get_time_embedding(time_steps, temb_dim): + r""" + Convert time steps tensor into an embedding using the + sinusoidal time embedding formula + :param time_steps: 1D tensor of length batch size + :param temb_dim: Dimension of the embedding + :return: BxD embedding representation of B time steps + """ + assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2" + + # factor = 10000^(2i/d_model) + factor = 10000 ** ((torch.arange( + start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2)) + ) + + # pos / factor + # timesteps B -> B, 1 -> B, temb_dim + t_emb = time_steps[:, None].repeat(1, temb_dim // 2) / factor + t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1) + return t_emb + + + + + +# ================================================================== +# L D M - U N E T - U P - B L O C K +# ================================================================== +class UpBlockUnet(nn.Module): + r""" + Up conv block with attention. + Sequence of following blocks + 1. Upsample + 1. Concatenate Down block output + 2. 
Resnet block with time embedding + 3. Attention Block + """ + + def __init__(self, in_channels, out_channels, t_emb_dim, up_sample, + num_heads, num_layers, norm_channels, cross_attn=False, context_dim=None): + super().__init__() + self.num_layers = num_layers + self.up_sample = up_sample + self.t_emb_dim = t_emb_dim + self.cross_attn = cross_attn + self.context_dim = context_dim + self.resnet_conv_first = nn.ModuleList( + [ + nn.Sequential( + nn.GroupNorm(norm_channels, in_channels if i == 0 else out_channels), + nn.SiLU(), + nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1, + padding=1), + ) + for i in range(num_layers) + ] + ) + + if self.t_emb_dim is not None: + self.t_emb_layers = nn.ModuleList([ + nn.Sequential( + nn.SiLU(), + nn.Linear(t_emb_dim, out_channels) + ) + for _ in range(num_layers) + ]) + + self.resnet_conv_second = nn.ModuleList( + [ + nn.Sequential( + nn.GroupNorm(norm_channels, out_channels), + nn.SiLU(), + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), + ) + for _ in range(num_layers) + ] + ) + + self.attention_norms = nn.ModuleList( + [ + nn.GroupNorm(norm_channels, out_channels) + for _ in range(num_layers) + ] + ) + + self.attentions = nn.ModuleList( + [ + nn.MultiheadAttention(out_channels, num_heads, batch_first=True) + for _ in range(num_layers) + ] + ) + + if self.cross_attn: + assert context_dim is not None, "Context Dimension must be passed for cross attention" + self.cross_attention_norms = nn.ModuleList( + [nn.GroupNorm(norm_channels, out_channels) + for _ in range(num_layers)] + ) + self.cross_attentions = nn.ModuleList( + [nn.MultiheadAttention(out_channels, num_heads, batch_first=True) + for _ in range(num_layers)] + ) + self.context_proj = nn.ModuleList( + [nn.Linear(context_dim, out_channels) + for _ in range(num_layers)] + ) + self.residual_input_conv = nn.ModuleList( + [ + nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1) + for i in range(num_layers) + ] + ) + self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, + 4, 2, 1) \ + if self.up_sample else nn.Identity() + + def forward(self, x, out_down=None, t_emb=None, context=None): + x = self.up_sample_conv(x) + if out_down is not None: + x = torch.cat([x, out_down], dim=1) + + out = x + for i in range(self.num_layers): + # --- Resnet -------------------- + resnet_input = out + out = self.resnet_conv_first[i](out) + if self.t_emb_dim is not None: + out = out + self.t_emb_layers[i](t_emb)[:, :, None, None] + out = self.resnet_conv_second[i](out) + out = out + self.residual_input_conv[i](resnet_input) + + # --- Self Attention ------------ + batch_size, channels, h, w = out.shape + in_attn = out.reshape(batch_size, channels, h * w) + in_attn = self.attention_norms[i](in_attn) + in_attn = in_attn.transpose(1, 2) + out_attn, _ = self.attentions[i](in_attn, in_attn, in_attn) + out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) + out = out + out_attn + + # --- Cross Attention ----------- + if self.cross_attn: + assert context is not None, "context cannot be None if cross attention layers are used" + batch_size, channels, h, w = out.shape + in_attn = out.reshape(batch_size, channels, h * w) + in_attn = self.cross_attention_norms[i](in_attn) + in_attn = in_attn.transpose(1, 2) + assert len(context.shape) == 3, \ + "Context shape does not match B,_,CONTEXT_DIM" + assert context.shape[0] == x.shape[0] and context.shape[-1] == self.context_dim,\ + "Context shape does not 
match B,_,CONTEXT_DIM" + context_proj = self.context_proj[i](context) + out_attn, _ = self.cross_attentions[i](in_attn, context_proj, context_proj) + out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w) + out = out + out_attn + + return out + + + +# ================================================================== +# L D M - U N E T +# ================================================================== +class Unet(nn.Module): + r""" + Unet model comprising + Down blocks, Midblocks and Uplocks + """ + + def __init__(self, im_channels, model_config): + super().__init__() + self.down_channels = model_config.down_channels + self.mid_channels = model_config.mid_channels + self.t_emb_dim = model_config.time_emb_dim + self.down_sample = model_config.down_sample + self.num_down_layers = model_config.num_down_layers + self.num_mid_layers = model_config.num_mid_layers + self.num_up_layers = model_config.num_up_layers + self.attns = model_config.attn_down + self.norm_channels = model_config.norm_channels + self.num_heads = model_config.num_heads + self.conv_out_channels = model_config.conv_out_channels + + assert self.mid_channels[0] == self.down_channels[-1] + assert self.mid_channels[-1] == self.down_channels[-2] + assert len(self.down_sample) == len(self.down_channels) - 1 + assert len(self.attns) == len(self.down_channels) - 1 + + self.condition_config = model_config.condition_config + self.cond = condition_types = self.condition_config.condition_types + if 'audio' in condition_types: + self.audio_cond = True + self.audio_embed_dim = self.condition_config.audio_condition_config.audio_embed_dim + + # Initial projection from sinusoidal time embedding + self.t_proj = nn.Sequential( + nn.Linear(self.t_emb_dim, self.t_emb_dim), + nn.SiLU(), + nn.Linear(self.t_emb_dim, self.t_emb_dim), + ) + + # Context projection for whisper Encoder last hidden state + # [B, 1500, 1280] -> [B, 1280] + self.context_projector = nn.Sequential( + nn.Linear(self.audio_embed_dim, 320), + nn.SiLU(), + nn.Linear(320, 1) + ) + + self.up_sample = list(reversed(self.down_sample)) + self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=1) + + # --::----- D O W N - B l O C K S ----------------::--------------::---------------- + self.downs = nn.ModuleList([]) + for i in range(len(self.down_channels) - 1): + # Cross Attention and Context Dim only needed if text or audio condition is present + self.downs.append( + DownBlock( + self.down_channels[i], + self.down_channels[i + 1], + self.t_emb_dim, + down_sample=self.down_sample[i], + num_heads=self.num_heads, + num_layers=self.num_down_layers, + attn=self.attns[i], + norm_channels=self.norm_channels, + cross_attn=self.audio_cond, + context_dim=self.audio_embed_dim + ) + ) + + # --::----- M I D - B l O C K S ----------------::--------------::---------------- + self.mids = nn.ModuleList([]) + for i in range(len(self.mid_channels) - 1): + self.mids.append( + MidBlock( + self.mid_channels[i], + self.mid_channels[i + 1], + self.t_emb_dim, + num_heads=self.num_heads, + num_layers=self.num_mid_layers, + norm_channels=self.norm_channels, + cross_attn=self.audio_cond, + context_dim=self.audio_embed_dim + ) + ) + + # --::----- U P - B l O C K S ----------------::--------------::---------------- + self.ups = nn.ModuleList([]) + for i in reversed(range(len(self.down_channels) - 1)): + self.ups.append( + UpBlockUnet( + self.down_channels[i] * 2, + self.down_channels[i - 1] if i != 0 else self.conv_out_channels, + self.t_emb_dim, + 
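+                    # in_channels above is down_channels[i] * 2 because the up path concatenates
+                    # the skip connection from the matching down block before convolving.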
up_sample=self.down_sample[i], + num_heads=self.num_heads, + num_layers=self.num_up_layers, + norm_channels=self.norm_channels, + cross_attn=self.audio_cond, + context_dim=self.audio_embed_dim + ) + ) + + self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels) + self.conv_out = nn.Conv2d(self.conv_out_channels, im_channels, kernel_size=3, padding=1) + + def forward(self, x, t, cond_input=None): + # Shapes assuming downblocks are [C1, C2, C3, C4] + # Shapes assuming midblocks are [C4, C4, C3] + # Shapes assuming downsamples are [True, True, False] + # B x C x H x W + out = self.conv_in(x) + # B x C1 x H x W + + # t_emb -> B x t_emb_dim + t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim) + t_emb = self.t_proj(t_emb) + + # --- Conditioning --------------- + if self.audio_cond: + # context_hidden_states = cond_input + # print(self.audio_cond, cond_input.shape) + + last_hidden_state = cond_input + weights = self.context_projector(last_hidden_state) + weights = torch.softmax(weights, dim=1) # Normalize across time + pooled_embedding = (last_hidden_state * weights).sum(dim=1) # [1, 512] + context_hidden_states = pooled_embedding.unsqueeze(1) + + # print(context_hidden_states.shape) + # exit() + + # --- Down Pass ------------------ + down_outs = [] + for idx, down in enumerate(self.downs): + down_outs.append(out) + out = down(out, t_emb, context_hidden_states) + # down_outs [B x C1 x H x W, B x C2 x H/2 x W/2, B x C3 x H/4 x W/4] + # out B x C4 x H/4 x W/4 + + # --- Mid Pass ------------------ + for mid in self.mids: + out = mid(out, t_emb, context_hidden_states) + # out B x C3 x H/4 x W/4 + + # --- Up Pass ------------------ + for up in self.ups: + down_out = down_outs.pop() + out = up(out, down_out, t_emb, context_hidden_states) + # out [B x C2 x H/4 x W/4, B x C1 x H/2 x W/2, B x 16 x H x W] + + out = self.norm_out(out) + out = nn.SiLU()(out) + out = self.conv_out(out) + # out B x C x H x W + return out + + + +# ================================================================== +# L D M - T R A I N I N G +# ================================================================== +def find_checkpoints(checkpoint_path): + directory = os.path.dirname(checkpoint_path) + prefix = os.path.basename(checkpoint_path) + pattern = re.compile(rf"{re.escape(prefix)}_epoch(\d+)\.pt$") + + try: + files = os.listdir(directory) + except FileNotFoundError: + return [] + + return [ + os.path.join(directory, f) + for f in files if pattern.match(f) + ] + +def save_ldm_checkpoint(checkpoint_path, + total_steps, epoch, model, optimizer, + metrics, logs, total_training_time + ): + checkpoint = { + "total_steps": total_steps, + "epoch": epoch, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "metrics": metrics, + "logs": logs, + "total_training_time": total_training_time + } + checkpoint_file = f"{checkpoint_path}_epoch{epoch}.pt" + torch.save(checkpoint, checkpoint_file) + print(f"LDM Checkpoint saved at {checkpoint_file}") + all_ckpts = glob.glob(f"{checkpoint_path}_epoch*.pt") + # all_ckpts = find_checkpoints(checkpoint_path) + def extract_epoch(filename): + match = re.search(r"_epoch(\d+)\.pt", filename) + return int(match.group(1)) if match else -1 + all_ckpts = sorted(all_ckpts, key=extract_epoch) + for old_ckpt in all_ckpts[:-2]: + os.remove(old_ckpt) + print(f"Removed old LDM checkpoint: {old_ckpt}") + +def load_ldm_checkpoint(checkpoint_path, model, optimizer): + all_ckpts = glob.glob(f"{checkpoint_path}_epoch*.pt") + # all_ckpts = 
find_checkpoints(checkpoint_path) + if not all_ckpts: + print("No LDM checkpoint found. Starting from scratch.") + return 0, 0, None, [], 0 + def extract_epoch(filename): + match = re.search(r"_epoch(\d+)\.pt", filename) + return int(match.group(1)) if match else -1 + all_ckpts = sorted(all_ckpts, key=extract_epoch) + latest_ckpt = all_ckpts[-1] + if os.path.exists(latest_ckpt): + checkpoint = torch.load(latest_ckpt, map_location=device) + model.load_state_dict(checkpoint["model_state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + total_steps = checkpoint["total_steps"] + epoch = checkpoint["epoch"] + metrics = checkpoint["metrics"] + logs = checkpoint.get("logs", []) + total_training_time = checkpoint.get("total_training_time", 0) + print(f"LDM Checkpoint loaded from {latest_ckpt}. Resuming from epoch {epoch + 1}, step {total_steps}") + return total_steps, epoch + 1, metrics, logs, total_training_time + else: + print("No LDM checkpoint found. Starting from scratch.") + return 0, 0, None, [], 0 + +def load_ldm_vae_checkpoint(checkpoint_path, vae, device=device): + # all_ckpts = glob.glob(f"{checkpoint_path}_epoch*.pt") + all_ckpts = find_checkpoints(checkpoint_path) + if not all_ckpts: + print("No VQVAE checkpoint found.") + return 0, 0, None, [], 0 + def extract_epoch(filename): + match = re.search(r"_epoch(\d+)\.pt", filename) + return int(match.group(1)) if match else -1 + all_ckpts = sorted(all_ckpts, key=extract_epoch) + latest_ckpt = all_ckpts[-1] + if os.path.exists(latest_ckpt): + checkpoint = torch.load(latest_ckpt, map_location=device) + vae.load_state_dict(checkpoint["model_state_dict"]) + total_steps = checkpoint["total_steps"] + epoch = checkpoint["epoch"] + print(f"VQVAE Checkpoint loaded from {latest_ckpt} at epoch {epoch + 1} & step {total_steps}") + +def trainLDM(Config, dataloader): + diffusion_config = Config.diffusion_params + dataset_config = Config.dataset_params + diffusion_model_config = Config.ldm_params + autoencoder_model_config = Config.autoencoder_params + train_config = Config.train_params + condition_config = diffusion_model_config.condition_config + + seed = train_config.seed + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + if device == "cuda": + torch.cuda.manual_seed_all(seed) + + vqvae_device = "cuda:1" + ldm_device = "cuda:0" + # ldm_device = vqvae_device = device + + # scheduler = CosineNoiseScheduler(diffusion_config.num_timesteps) + scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config.num_timesteps, + beta_start=diffusion_config.beta_start, + beta_end=diffusion_config.beta_end) + + if not train_config.ldm_pretraining: + if condition_config is not None: + condition_types = condition_config.condition_types + if 'audio' in condition_types: + from msclap import CLAP # type: ignore + audio_model = CLAP(version = '2023', use_cuda=(True if "cuda" in device else False)) + + model = Unet(im_channels=autoencoder_model_config.z_channels, model_config=diffusion_model_config).to(device) + optimizer = torch.optim.AdamW(model.parameters(), lr=train_config.ldm_lr) + criterion = torch.nn.MSELoss() + num_epochs = train_config.ldm_epochs + + checkpoint_path = os.path.join(os.getcwd(), train_config.task_name, "ldmH_ckpt") + (total_steps, start_epoch, metrics, logs, + total_training_time) = load_ldm_checkpoint(checkpoint_path, model, optimizer) + + vae = VQVAE(im_channels=dataset_config.im_channels, model_config=autoencoder_model_config).eval().to(vqvae_device) + vae_checkpoint_path = os.path.join(os.getcwd(), 
train_config.task_name, "vqvae_ckpt") + load_ldm_vae_checkpoint(vae_checkpoint_path, vae, vqvae_device) + for param in vae.parameters(): + param.requires_grad = False + vae.eval() + + if not os.path.exists(train_config.task_name): + os.makedirs(train_config.task_name, exist_ok=True) + + acc_steps = train_config.ldm_acc_steps + disc_step_start = train_config.disc_start + start_time_total = time.time() - total_training_time + + model.train() + optimizer.zero_grad() + for epoch_idx in trange(start_epoch, num_epochs, desc=f"{device}-LDM Epoch", colour='red', dynamic_ncols=True): + start_time_epoch = time.time() + losses = [] + epoch_log = [] + + # Load latest vqvae checkpoints + vae_checkpoint_path = os.path.join(os.getcwd(), train_config.task_name, "vqvae_ckpt") + load_ldm_vae_checkpoint(vae_checkpoint_path, vae, vqvae_device) + for param in vae.parameters(): + param.requires_grad = False + vae.eval() + + # for images, cond_input in tqdm(dataloader, colour='green', dynamic_ncols=True): + for images in tqdm(dataloader, colour='green', dynamic_ncols=True): + cond_input = None + batch_start_time = time.time() + total_steps += 1 + batch_size = images.shape[0] + + # images = images.to(device) + with torch.no_grad(): + images, _ = vae.encode(images.to(vqvae_device)) + images = images.to(ldm_device) + + # Conditional Input + audio_embed_dim = condition_config.audio_condition_config.audio_embed_dim + # empty_audio_embedding = torch.zeros(audio_embed_dim, device=device).float().unsqueeze(0).repeat(batch_size, 1).unsqueeze(1) + # empty_audio_embedding = torch.zeros((1500,1280), device=device).float().unsqueeze(0).repeat(batch_size, 1).unsqueeze(1) + empty_audio_embedding = torch.zeros((batch_size, 1500, 1280), device=device).float() + if not train_config.ldm_pretraining: + if 'audio' in condition_types: + with torch.no_grad(): + audio_embeddings = audio_model.get_audio_embeddings(cond_input) + text_drop_prob = condition_config.audio_condition_config.cond_drop_prob + text_drop_mask = torch.zeros((images.shape[0]), device=images.device).float().uniform_(0, 1) < text_drop_prob + audio_embeddings[text_drop_mask, :, :] = empty_audio_embedding[0] + else: + audio_embeddings = empty_audio_embedding + + # Sample random noise + noise = torch.randn_like(images).to(device) + + # Sample timestep + t = torch.randint(0, diffusion_config.num_timesteps, (images.shape[0],)).to(device) + + # Add noise to images according to timestep + noisy_images = scheduler.add_noise(images, noise, t) + noise_pred = model(noisy_images, t, cond_input=audio_embeddings) + + loss = criterion(noise_pred, noise) + losses.append(loss.item()) + loss = loss / acc_steps + loss.backward() + + if total_steps % acc_steps == 0: + optimizer.step() + optimizer.zero_grad() + + if total_steps % acc_steps == 0: + optimizer.step() + optimizer.zero_grad() + + print(f'Finished epoch:{epoch_idx + 1}/{num_epochs} | Loss : {np.mean(losses):.4f}') + + epoch_time = time.time() - start_time_epoch + logs.append({"epoch": epoch_idx + 1, "epoch_time": format_time(0, epoch_time), "batch_times": epoch_log}) + + total_training_time = time.time() - start_time_total + save_ldm_checkpoint(checkpoint_path, total_steps, epoch_idx + 1, model, optimizer, metrics, logs, total_training_time) + + + infer(Config) + + # Checking to conntinue training + train_continue = DotDict.from_dict(yaml.safe_load(open(config_path, 'r'))) + if train_continue.training.continue_ldm == False: + print('LDM Training Stoped ...') + break + + print('Done Training ...') + + + +# 
================================================================== +# L D M - S A M P L I N G +# ================================================================== +def sample(model, scheduler, train_config, diffusion_model_config, + autoencoder_model_config, diffusion_config, dataset_config, + vae, audio_model + ): + r""" + Sample stepwise by going backward one timestep at a time. + We save the x0 predictions + """ + im_size = dataset_config.im_size // 2**sum(autoencoder_model_config.down_sample) + xt = torch.randn((train_config.num_samples, + autoencoder_model_config.z_channels, + im_size, + im_size)).to(device) + + audio_embed_dim = diffusion_model_config.condition_config.audio_condition_config.audio_embed_dim + # empty_audio_embedding = torch.zeros(audio_embed_dim, device=device).float() + # empty_audio_embedding = torch.zeros(audio_embed_dim, device=device).float().unsqueeze(0) + # empty_audio_embedding = empty_audio_embedding.repeat(train_config.num_samples, 1).unsqueeze(1) + empty_audio_embedding = torch.zeros((train_config.num_samples, 1500, 1280), device=device).float() + if not train_config.ldm_pretraining: + # Create Conditional input + pass + else: + audio_embeddings = empty_audio_embedding + + uncond_input = empty_audio_embedding + cond_input = audio_embeddings + + save_count = 0 + for i in tqdm(reversed(range(diffusion_config.num_timesteps)), + total=diffusion_config.num_timesteps, + colour='blue', desc="Sampling", dynamic_ncols=True): + # Get prediction of noise + t = (torch.ones((xt.shape[0],)) * i).long().to(device) + # t = torch.as_tensor(i).unsqueeze(0).to(device) + noise_pred_cond = model(xt, t, cond_input) + + cf_guidance_scale = train_config.cf_guidance_scale + if cf_guidance_scale > 1: + noise_pred_uncond = model(xt, t, uncond_input) + noise_pred = noise_pred_uncond + cf_guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = noise_pred_cond + + # Use scheduler to get x0 and xt-1 + xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device)) + + # Save x0 + #ims = torch.clamp(xt, -1., 1.).detach().cpu() + if i == 0: + # Decode ONLY the final image to save time + ims = vae.decode(xt) + else: + # ims = xt + ims = x0_pred + + ims = torch.clamp(ims, -1., 1.).detach().cpu() + ims = (ims + 1) / 2 + grid = make_grid(ims, nrow=train_config.num_grid_rows) + # img = tv.transforms.ToPILImage()(grid) + np_img = (grid * 255).byte().numpy().transpose(1, 2, 0) + img = Image.fromarray(np_img) + + if not os.path.exists(os.path.join(train_config.task_name, 'samplesH')): + os.makedirs(os.path.join(train_config.task_name, 'samplesH'), exist_ok=True) + + img.save(os.path.join(train_config.task_name, 'samplesH', 'x0_{}.png'.format(i))) + img.close() + + +def infer(Config): + diffusion_config = Config.diffusion_params + dataset_config = Config.dataset_params + diffusion_model_config = Config.ldm_params + autoencoder_model_config = Config.autoencoder_params + train_config = Config.train_params + + # scheduler = CosineNoiseScheduler(diffusion_config.num_timesteps) + scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config.num_timesteps, + beta_start=diffusion_config.beta_start, + beta_end=diffusion_config.beta_end) + + model = Unet(im_channels=autoencoder_model_config.z_channels, + model_config=diffusion_model_config).eval().to(device) + vae = VQVAE(im_channels=dataset_config.im_channels, + model_config=autoencoder_model_config).eval().to(device) + + if os.path.exists(os.path.join(train_config.task_name, 
train_config.ldm_ckpt_name)): + checkpoint_path = os.path.join(train_config.task_name, train_config.ldm_ckpt_name) + checkpoint = torch.load(checkpoint_path, map_location=device) + model.load_state_dict(checkpoint["model_state_dict"]) + vae.load_state_dict(checkpoint["vae_state_dict"]) + print('Loaded unet & vae checkpoint') + + # Create output directories + if not os.path.exists(train_config.task_name): + os.makedirs(train_config.task_name, exist_ok=True) + + with torch.no_grad(): + sample(model, scheduler, train_config, diffusion_model_config, + autoencoder_model_config, diffusion_config, dataset_config, vae, None) + + + + +# ================================================================== +# S T A R T I N G - L D M - T R A I N I N G +# ================================================================== +# trainLDM(Config, dataloader) + + + +if sys.argv[1] == 'train_vae': + trainVAE(Config, dataloader) +elif sys.argv[1] == 'train_ldm': + trainLDM(Config, dataloader) +else: + infer(Config) + + + +# git add . && git commit -m "LDM" && git push -u origin master +# huggingface-cli upload alpha31476/Vaani-Audio2Img-LDM . --commit-message "SDFT" \ No newline at end of file diff --git a/Vaani/_7_Vaani-LDM-Pre.py b/Vaani/_7_Vaani-LDM-Pre.py index d21606c6b09a85dd7e11aecd71e71117ad123283..df74121f3c2068678d2906b951690d90113ed5da 100644 --- a/Vaani/_7_Vaani-LDM-Pre.py +++ b/Vaani/_7_Vaani-LDM-Pre.py @@ -42,7 +42,7 @@ from torch.utils.data import Dataset, DataLoader from torchvision.utils import make_grid print("TIME:", datetime.datetime.now()) -# os.environ["CUDA_VISIBLE_DEVICES"] = f"{sys.argv[2]}" +os.environ["CUDA_VISIBLE_DEVICES"] = f"{sys.argv[2]}" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print("DEVICE:", device) @@ -1730,9 +1730,9 @@ def trainLDM(Config, dataloader): if device == "cuda": torch.cuda.manual_seed_all(seed) - vqvae_device = "cuda:1" - ldm_device = "cuda:0" - # ldm_device = vqvae_device = device + vqvae_device = "cuda:0" + ldm_device = "cuda:1" + ldm_device = vqvae_device = device # scheduler = CosineNoiseScheduler(diffusion_config.num_timesteps) scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config.num_timesteps, diff --git a/Vaani/config-LDM-High-Pre.yaml b/Vaani/config-LDM-High-Pre.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f1b6668c606e6b1eb02080f5586e39284e610944 --- /dev/null +++ b/Vaani/config-LDM-High-Pre.yaml @@ -0,0 +1,77 @@ +dataset_params: + im_channels: 3 + im_size: 128 + +paths: + images_dir: "/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Images" + +diffusion_params: + num_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + +ldm_params: + down_channels: [ 128, 384, 512, 768 ] + mid_channels: [ 768, 512 ] + down_sample: [ True, True, True ] + attn_down: [ True, True, True ] + time_emb_dim: 512 + norm_channels: 32 + num_heads: 16 + conv_out_channels: 128 + num_down_layers: 4 + num_mid_layers: 4 + num_up_layers: 4 + condition_config: + condition_types: [ 'audio' ] + audio_condition_config: + text_embed_model: 'whisper' + train_audio_embed_model: False + audio_embed_dim: 1280 + cond_drop_prob: 0.1 + + +autoencoder_params: + z_channels: 3 + codebook_size: 20 + down_channels: [ 32, 64, 128 ] + mid_channels: [ 128, 128 ] + down_sample: [ True, True ] + attn_down: [ False, False ] + norm_channels: 32 + num_heads: 16 + num_down_layers: 4 + num_mid_layers: 4 + num_up_layers: 4 + +train_params: + seed: 4422 + task_name: 'VaaniLDM' + ldm_batch_size: 32 + ldm_pretraining: True + debug: False + 
autoencoder_batch_size: 80 + disc_start: 1000 + disc_weight: 0.5 + codebook_weight: 1 + commitment_beta: 0.2 + perceptual_weight: 1 + kl_weight: 0.000005 + ldm_epochs: 50 + autoencoder_epochs: 30 + num_samples: 9 + num_grid_rows: 3 + ldm_lr: 0.00001 + autoencoder_lr: 0.0001 + ldm_acc_steps: 1 + autoencoder_acc_steps: 1 + autoencoder_img_save_steps: 8 + save_latents: True + cf_guidance_scale : 1.0 + vqvae_latent_dir_name: 'vqvae_latents' + ldm_ckpt_name: 'ddpm_ckpt.pth' + vqvae_ckpt_name: 'vqvae_ckpt.pth' + +training: + continue_vqvae: True + continue_ldm: True diff --git a/Vaani/config-LDM-Pre.yaml b/Vaani/config-LDM-Pre.yaml index 90f684f7a553adccb8c4e6e449b4b888a6767a45..d845e6b18ba89146ba37996d99fdb44ae2a39a6c 100644 --- a/Vaani/config-LDM-Pre.yaml +++ b/Vaani/config-LDM-Pre.yaml @@ -47,7 +47,7 @@ autoencoder_params: train_params: seed: 4422 task_name: 'VaaniLDM' - ldm_batch_size: 32 + ldm_batch_size: 16 ldm_pretraining: True debug: False autoencoder_batch_size: 80 @@ -57,7 +57,7 @@ train_params: commitment_beta: 0.2 perceptual_weight: 1 kl_weight: 0.000005 - ldm_epochs: 50 + ldm_epochs: 60 autoencoder_epochs: 30 num_samples: 9 num_grid_rows: 3
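The conditioning dropout applied during LDM training above (the classifier-free-guidance style drop with cond_drop_prob: 0.1) can be summarized in a small sketch. It assumes Whisper-style audio embeddings of shape [batch, 1500, 1280], matching audio_embed_dim: 1280 in the config; each sample's embedding is replaced by the all-zero "empty" embedding with probability cond_drop_prob, so the UNet also learns an unconditional prediction for use at sampling time. The tensors below are dummies, not the repo's actual encoder output.

import torch

batch_size, cond_drop_prob = 4, 0.1
audio_embeddings = torch.randn(batch_size, 1500, 1280)       # stand-in for audio_model.get_audio_embeddings(...)
empty_audio_embedding = torch.zeros(batch_size, 1500, 1280)  # the "null" condition
drop_mask = torch.rand(batch_size) < cond_drop_prob          # True marks samples whose condition is dropped
audio_embeddings[drop_mask] = empty_audio_embedding[0]       # broadcast the zero embedding over dropped samples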
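At sampling time, the latent grid that sample() draws noise over follows directly from these config values. A worked example with the numbers from config-LDM-High-Pre.yaml (im_size 128, two downsampling stages in the VQVAE, z_channels 3, num_samples 9):

import torch

# im_size = dataset_params.im_size // 2**sum(autoencoder_params.down_sample)
#         = 128 // 2**sum([True, True]) = 128 // 4 = 32
latent_size = 128 // 2**sum([True, True])
xt = torch.randn((9, 3, latent_size, latent_size))  # [num_samples, z_channels, 32, 32]
print(xt.shape)                                      # torch.Size([9, 3, 32, 32])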
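Finally, the training: continue_ldm flag is re-read at the end of every LDM epoch (the yaml.safe_load call in trainLDM), so a running job can be stopped cleanly by editing the YAML. A minimal sketch of that check, using a small illustrative stand-in for the repo's DotDict helper, whose real implementation is not shown in this diff:

import yaml

class DotDict(dict):
    # Illustrative stand-in only: exposes nested dict keys as attributes.
    @classmethod
    def from_dict(cls, d):
        return cls({k: cls.from_dict(v) if isinstance(v, dict) else v for k, v in d.items()})
    def __getattr__(self, name):
        return self[name]

config_path = 'Vaani/config-LDM-High-Pre.yaml'  # hypothetical path for this sketch
train_continue = DotDict.from_dict(yaml.safe_load(open(config_path, 'r')))
if not train_continue.training.continue_ldm:     # mirrors the continue_ldm check in trainLDM
    print('LDM Training Stopped ...')            # the epoch loop breaks at this point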