alpha31476 commited on
Commit
cb656a6
·
verified ·
1 Parent(s): 87ef7b5
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. Vaani/SDFT/_2.ipynb +675 -0
  3. Vaani/SDFT/_2_.py +345 -0
  4. Vaani/SDFT/_2_DDP.py +316 -0
  5. Vaani/SDFT/checkpoints/checkpoint.pth +3 -0
  6. Vaani/SDFT/download_model.py +13 -0
  7. Vaani/SDFT/vaani-stablediffusion-finetune-kaggle.ipynb +650 -0
  8. Vaani/VaaniLDM/ddpm_ckpt_epoch31.pt +3 -0
  9. Vaani/VaaniLDM/ddpm_ckpt_epoch32.pt +3 -0
  10. Vaani/VaaniLDM/ldmH_ckpt_epoch24.pt +3 -0
  11. Vaani/VaaniLDM/ldmH_ckpt_epoch25.pt +3 -0
  12. Vaani/VaaniLDM/samples/x0_0.png +2 -2
  13. Vaani/VaaniLDM/samples/x0_1.png +0 -0
  14. Vaani/VaaniLDM/samples/x0_10.png +0 -0
  15. Vaani/VaaniLDM/samples/x0_100.png +0 -0
  16. Vaani/VaaniLDM/samples/x0_101.png +0 -0
  17. Vaani/VaaniLDM/samples/x0_102.png +0 -0
  18. Vaani/VaaniLDM/samples/x0_103.png +0 -0
  19. Vaani/VaaniLDM/samples/x0_104.png +0 -0
  20. Vaani/VaaniLDM/samples/x0_105.png +0 -0
  21. Vaani/VaaniLDM/samples/x0_106.png +0 -0
  22. Vaani/VaaniLDM/samples/x0_107.png +0 -0
  23. Vaani/VaaniLDM/samples/x0_108.png +0 -0
  24. Vaani/VaaniLDM/samples/x0_109.png +0 -0
  25. Vaani/VaaniLDM/samples/x0_11.png +0 -0
  26. Vaani/VaaniLDM/samples/x0_110.png +0 -0
  27. Vaani/VaaniLDM/samples/x0_111.png +0 -0
  28. Vaani/VaaniLDM/samples/x0_112.png +0 -0
  29. Vaani/VaaniLDM/samples/x0_113.png +0 -0
  30. Vaani/VaaniLDM/samples/x0_114.png +0 -0
  31. Vaani/VaaniLDM/samples/x0_115.png +0 -0
  32. Vaani/VaaniLDM/samples/x0_116.png +0 -0
  33. Vaani/VaaniLDM/samples/x0_117.png +0 -0
  34. Vaani/VaaniLDM/samples/x0_118.png +0 -0
  35. Vaani/VaaniLDM/samples/x0_119.png +0 -0
  36. Vaani/VaaniLDM/samples/x0_12.png +0 -0
  37. Vaani/VaaniLDM/samples/x0_120.png +0 -0
  38. Vaani/VaaniLDM/samples/x0_121.png +0 -0
  39. Vaani/VaaniLDM/samples/x0_122.png +0 -0
  40. Vaani/VaaniLDM/samples/x0_123.png +0 -0
  41. Vaani/VaaniLDM/samples/x0_124.png +0 -0
  42. Vaani/VaaniLDM/samples/x0_125.png +0 -0
  43. Vaani/VaaniLDM/samples/x0_126.png +0 -0
  44. Vaani/VaaniLDM/samples/x0_127.png +0 -0
  45. Vaani/VaaniLDM/samples/x0_128.png +0 -0
  46. Vaani/VaaniLDM/samples/x0_129.png +0 -0
  47. Vaani/VaaniLDM/samples/x0_13.png +0 -0
  48. Vaani/VaaniLDM/samples/x0_130.png +0 -0
  49. Vaani/VaaniLDM/samples/x0_131.png +0 -0
  50. Vaani/VaaniLDM/samples/x0_132.png +0 -0
.gitattributes CHANGED
@@ -135,3 +135,4 @@ Vaani/output_image2.png filter=lfs diff=lfs merge=lfs -text
135
  Vaani/sampleJSON.csv filter=lfs diff=lfs merge=lfs -text
136
  Vaani/sampleJSON.json filter=lfs diff=lfs merge=lfs -text
137
  tools/__pycache__/pynvml.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
 
 
135
  Vaani/sampleJSON.csv filter=lfs diff=lfs merge=lfs -text
136
  Vaani/sampleJSON.json filter=lfs diff=lfs merge=lfs -text
137
  tools/__pycache__/pynvml.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
138
+ Vaani/VaaniLDM/samplesH/x0_0.png filter=lfs diff=lfs merge=lfs -text
Vaani/SDFT/_2.ipynb ADDED
@@ -0,0 +1,675 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "aab59bea",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "data": {
11
+ "text/plain": [
12
+ "'cuda'"
13
+ ]
14
+ },
15
+ "execution_count": 1,
16
+ "metadata": {},
17
+ "output_type": "execute_result"
18
+ }
19
+ ],
20
+ "source": [
21
+ "import torch\n",
22
+ "import torch.optim as optim\n",
23
+ "from torch.utils.data import Dataset, DataLoader\n",
24
+ "from torchvision import transforms\n",
25
+ "from torchvision.transforms import v2\n",
26
+ "from PIL import Image\n",
27
+ "from diffusers import StableDiffusionPipeline\n",
28
+ "from diffusers.optimization import get_scheduler\n",
29
+ "from torch import nn\n",
30
+ "import torch.nn.functional as F\n",
31
+ "import os\n",
32
+ "import pandas as pd\n",
33
+ "from tqdm import trange, tqdm\n",
34
+ "\n",
35
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
36
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
37
+ "device"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": 2,
43
+ "id": "8f13b66f",
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "# import torch\n",
48
+ "# import torch.nn as nn\n",
49
+ "# import torch.nn.functional as F\n",
50
+ "\n",
51
+ "# audio_embed_dim = 1280\n",
52
+ "# output_dim = 768\n",
53
+ "# device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
54
+ "\n",
55
+ "# context_projector = nn.Sequential(\n",
56
+ "# nn.Linear(audio_embed_dim, 320),\n",
57
+ "# nn.SiLU(),\n",
58
+ "# nn.Linear(320, output_dim)\n",
59
+ "# ).to(device).half()\n",
60
+ "\n",
61
+ "# # Dummy input\n",
62
+ "# audio_embedding = dummy_audio = torch.zeros(10, 1500, 1280, device=device, dtype=torch.float16)\n",
63
+ "# print(audio_embedding.shape) # [10, 1500, 1280]\n",
64
+ "\n",
65
+ "# # Project audio to [10, 1500, 768]\n",
66
+ "# projected = context_projector(audio_embedding)\n",
67
+ "# print(projected.shape) # [10, 1500, 768]\n",
68
+ "\n",
69
+ "# # Compute attention scores: reduce feature dim to scalar per time step\n",
70
+ "# attn_scores = projected.mean(dim=2) # [10, 1500]\n",
71
+ "# attn_weights = F.softmax(attn_scores, dim=1) # [10, 1500]\n",
72
+ "# attn_weights = attn_weights.unsqueeze(2) # [10, 1500, 1]\n",
73
+ "\n",
74
+ "# # Weighted average\n",
75
+ "# pooled = (projected * attn_weights).sum(dim=1, keepdim=True) # [10, 1, 768]\n",
76
+ "# print(pooled.shape) # Final shape: [10, 1, 768]\n"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": null,
82
+ "id": "d32b7d9d",
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "# === Helpers ===\n",
87
+ "def walkDIR(folder_path, include=None):\n",
88
+ " file_list = []\n",
89
+ " for root, _, files in os.walk(folder_path):\n",
90
+ " for file in files:\n",
91
+ " if include is None or any(file.endswith(ext) for ext in include):\n",
92
+ " file_list.append(os.path.join(root, file))\n",
93
+ " print(\"Files found:\", len(file_list))\n",
94
+ " return file_list\n",
95
+ "\n",
96
+ "# === Dataset Class ===\n",
97
+ "class VaaniDataset(torch.utils.data.Dataset):\n",
98
+ " def __init__(self, files_paths, im_size):\n",
99
+ " self.files_paths = files_paths\n",
100
+ " self.im_size = im_size\n",
101
+ "\n",
102
+ " def __len__(self):\n",
103
+ " return len(self.files_paths)\n",
104
+ "\n",
105
+ " def __getitem__(self, idx):\n",
106
+ " # image = tv.io.read_image(self.files_paths[idx], mode=tv.io.ImageReadMode.RGB)\n",
107
+ " image = Image.open(self.files_paths[idx]).convert(\"RGB\")\n",
108
+ " image = v2.ToImage()(image)\n",
109
+ " # image = tv.io.decode_image(self.files_paths[idx], mode=tv.io.ImageReadMode.RGB)\n",
110
+ " image = v2.Resize((self.im_size, self.im_size))(image)\n",
111
+ " image = v2.ToDtype(torch.float32, scale=True)(image)\n",
112
+ " # image = 2*image - 1\n",
113
+ " return image\n",
114
+ "\n",
115
+ "\n",
116
+ "def create_dataloader(dataset, batch_size, debug=False, val_split=0.1, num_workers=4):\n",
117
+ " if debug:\n",
118
+ " s = 0.001\n",
119
+ " dataset, _ = torch.utils.data.random_split(dataset, [s, 1-s], torch.manual_seed(42))\n",
120
+ " print(\"Length of Train dataset:\", len(dataset))\n",
121
+ "\n",
122
+ " train_dataloader = DataLoader(\n",
123
+ " dataset, \n",
124
+ " batch_size=batch_size, \n",
125
+ " shuffle=True, \n",
126
+ " num_workers=num_workers,\n",
127
+ " pin_memory=True,\n",
128
+ " drop_last=True,\n",
129
+ " persistent_workers=True\n",
130
+ " )\n",
131
+ " \n",
132
+ " images = next(iter(train_dataloader))\n",
133
+ " print('Total Batches:', len(train_dataloader))\n",
134
+ " print('BATCH SHAPE:', images.shape)\n",
135
+ " return train_dataloader\n",
136
+ "\n",
137
+ "# === Audio Context Projector ===\n",
138
+ "# class AudioContextProjector(nn.Module):\n",
139
+ "# def __init__(self, audio_embed_dim):\n",
140
+ "# super().__init__()\n",
141
+ "# self.audio_embed_dim = audio_embed_dim\n",
142
+ "# self.context_projector = nn.Sequential(\n",
143
+ "# nn.Linear(audio_embed_dim, 320),\n",
144
+ "# nn.SiLU(),\n",
145
+ "# nn.Linear(320, 1)\n",
146
+ "# )\n",
147
+ "\n",
148
+ "# def forward(self, audio_embedding):\n",
149
+ "# if audio_embedding.size(-1) != self.audio_embed_dim:\n",
150
+ "# raise ValueError(f\"Expected audio embedding dim {self.audio_embed_dim}, got {audio_embedding.size(-1)}\")\n",
151
+ "# weights = self.context_projector(audio_embedding) # [B, T, 1]\n",
152
+ "# weights = torch.softmax(weights, dim=1) # [B, T, 1]\n",
153
+ "# pooled = (audio_embedding * weights).sum(dim=1) # [B, 1280]\n",
154
+ "# return pooled.unsqueeze(1) # [B, 1, 1280]\n",
155
+ "# class AudioContextProjector(nn.Module):\n",
156
+ "# def __init__(self, audio_embed_dim=1280, output_dim=768): # Add output_dim for flexibility\n",
157
+ "# super().__init__()\n",
158
+ "# self.audio_embed_dim = audio_embed_dim\n",
159
+ "# self.context_projector = nn.Sequential(\n",
160
+ "# nn.Linear(audio_embed_dim, 320),\n",
161
+ "# nn.SiLU(),\n",
162
+ "# nn.Linear(320, output_dim) # Output 768 to match UNet's expectation\n",
163
+ "# )\n",
164
+ "\n",
165
+ "# def forward(self, audio_embedding):\n",
166
+ "# if audio_embedding.size(-1) != self.audio_embed_dim:\n",
167
+ "# raise ValueError(f\"Expected audio embedding dim {self.audio_embed_dim}, got {audio_embedding.size(-1)}\")\n",
168
+ "# weights = self.context_projector(audio_embedding) # [B, T, 768]\n",
169
+ "# weights = torch.softmax(pooled, dim=1) # [B, T, 768]\n",
170
+ "# pooled = (audio_embedding * weights).sum(dim=1) # [B, 768]\n",
171
+ "# return pooled.unsqueeze(1) # [B, 1, 768]\n",
172
+ "class AudioContextProjector(nn.Module):\n",
173
+ " def __init__(self, audio_embed_dim=1280, output_dim=768):\n",
174
+ " super().__init__()\n",
175
+ " self.audio_embed_dim = audio_embed_dim\n",
176
+ " self.output_dim = output_dim\n",
177
+ " self.context_projector = nn.Sequential(\n",
178
+ " nn.Linear(audio_embed_dim, 320),\n",
179
+ " nn.SiLU(),\n",
180
+ " nn.Linear(320, output_dim) # Output 768 to match UNet's expectation\n",
181
+ " )\n",
182
+ "\n",
183
+ " def forward(self, audio_embedding):\n",
184
+ " if audio_embedding.size(-1) != self.audio_embed_dim:\n",
185
+ " raise ValueError(f\"Expected audio embedding dim {self.audio_embed_dim}, got {audio_embedding.size(-1)}\")\n",
186
+ "\n",
187
+ " # Project to [B, T, 768]\n",
188
+ " projected = self.context_projector(audio_embedding) # [B, T, 768]\n",
189
+ "\n",
190
+ " # Compute scalar attention scores per timestep\n",
191
+ " attn_scores = projected.mean(dim=2) # [B, T]\n",
192
+ " attn_weights = F.softmax(attn_scores, dim=1) # [B, T]\n",
193
+ " attn_weights = attn_weights.unsqueeze(2) # [B, T, 1]\n",
194
+ "\n",
195
+ " # Apply attention to the projected embeddings\n",
196
+ " pooled = (projected * attn_weights).sum(dim=1, keepdim=True) # [B, 1, 768]\n",
197
+ " return pooled\n",
198
+ "\n",
199
+ "\n",
200
+ "\n",
201
+ "# === Inference Function ===\n",
202
+ "def run_inference(pipe, unet, vae, device, context_hidden_states, save_path=\"inference_output.png\"):\n",
203
+ " pipe.unet = unet\n",
204
+ " pipe.vae = vae\n",
205
+ " pipe.to(device)\n",
206
+ "\n",
207
+ " batch_size = 1\n",
208
+ " latents = torch.randn((batch_size, pipe.unet.in_channels, 64, 64), device=device, dtype=torch.float16)\n",
209
+ " # latents = torch.randn((batch_size, pipe.unet.config.in_channels, 64, 64), device=device, dtype=torch.float16)\n",
210
+ " pipe.scheduler.set_timesteps(50)\n",
211
+ " latents = latents * pipe.scheduler.init_noise_sigma\n",
212
+ " \n",
213
+ " expected_shape = (batch_size, 1, 768) # Adjust based on model\n",
214
+ " if context_hidden_states.shape != expected_shape:\n",
215
+ " raise ValueError(f\"Expected context_hidden_states shape {expected_shape}, got {context_hidden_states.shape}\")\n",
216
+ " \n",
217
+ " for t in pipe.scheduler.timesteps:\n",
218
+ " with torch.no_grad():\n",
219
+ " noise_pred = pipe.unet(latents, t, encoder_hidden_states=context_hidden_states).sample\n",
220
+ " latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample\n",
221
+ "\n",
222
+ " # latents = 1 / 0.18215 * latents\n",
223
+ " latents = 1 / pipe.vae.config.scaling_factor * latents\n",
224
+ " with torch.no_grad():\n",
225
+ " image = pipe.vae.decode(latents).sample\n",
226
+ "\n",
227
+ " image = (image / 2 + 0.5).clamp(0, 1)\n",
228
+ " image = image.cpu().permute(0, 2, 3, 1).numpy()[0]\n",
229
+ " image = Image.fromarray((image * 255).astype(\"uint8\"))\n",
230
+ " image.save(save_path)\n",
231
+ " print(f\"Inference image saved to {save_path}\")\n",
232
+ "\n",
233
+ "\n",
234
+ "# === Load Pipeline ===\n",
235
+ "def load_pipeline(model_id, device):\n",
236
+ " pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)\n",
237
+ " unet = pipe.unet\n",
238
+ " vae = pipe.vae\n",
239
+ " return pipe, unet, vae\n",
240
+ "\n",
241
+ "# === Freeze Layers Function ===\n",
242
+ "def freeze_vae_layers(vae):\n",
243
+ " vae.encoder.requires_grad_(False)\n",
244
+ " vae.quant_conv.requires_grad_(False)\n",
245
+ " vae.decoder.requires_grad_(True)\n",
246
+ " vae.post_quant_conv.requires_grad_(True)\n",
247
+ "\n",
248
+ "def freeze_unet_layers(unet):\n",
249
+ " for name, param in unet.named_parameters():\n",
250
+ " if \"attn2\" in name or \"conv2\" in name:\n",
251
+ " param.requires_grad = True\n",
252
+ " else:\n",
253
+ " param.requires_grad = False\n",
254
+ "\n",
255
+ "# === Optimizer Setup ===\n",
256
+ "def setup_optimizer(vae, unet, projector, lr):\n",
257
+ " params_to_optimize = list(filter(lambda p: p.requires_grad, vae.parameters())) + \\\n",
258
+ " list(filter(lambda p: p.requires_grad, unet.parameters())) + \\\n",
259
+ " list(filter(lambda p: p.requires_grad, projector.parameters()))\n",
260
+ " optimizer = optim.AdamW(params_to_optimize, lr=lr)\n",
261
+ " return optimizer\n",
262
+ "\n",
263
+ "\n",
264
+ "# === Gradient Accumulation Function ===\n",
265
+ "def accumulate_gradients(optimizer, loss, gradient_accumulation_steps, step, dataloader):\n",
266
+ " loss = loss / gradient_accumulation_steps\n",
267
+ " loss.backward()\n",
268
+ " if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(dataloader):\n",
269
+ " optimizer.step()\n",
270
+ " optimizer.zero_grad()\n",
271
+ "\n",
272
+ "# === Save Checkpoint Function ===\n",
273
+ "def save_checkpoint(epoch, unet, vae, projector, optimizer, checkpoint_path):\n",
274
+ " # checkpoint_path = f\"{save_dir}/checkpoint.pth\"\n",
275
+ " torch.save({\n",
276
+ " 'epoch': epoch,\n",
277
+ " 'unet': unet.state_dict(),\n",
278
+ " 'vae': vae.state_dict(),\n",
279
+ " 'projector': projector.state_dict(),\n",
280
+ " 'optimizer': optimizer.state_dict(),\n",
281
+ " }, checkpoint_path)\n",
282
+ " print(f\"Checkpoint saved to {checkpoint_path}\")\n",
283
+ "\n",
284
+ "# === Resume from Checkpoint Function ===\n",
285
+ "def resume_from_checkpoint(checkpoint_path, unet, vae, projector, optimizer):\n",
286
+ " if os.path.exists(checkpoint_path):\n",
287
+ " checkpoint = torch.load(checkpoint_path, map_location='cpu')\n",
288
+ " unet.load_state_dict(checkpoint['unet'])\n",
289
+ " vae.load_state_dict(checkpoint['vae'])\n",
290
+ " projector.load_state_dict(checkpoint['projector'])\n",
291
+ " optimizer.load_state_dict(checkpoint['optimizer'])\n",
292
+ " start_epoch = checkpoint['epoch'] + 1\n",
293
+ " print(f\"Resuming training from epoch {start_epoch}...\")\n",
294
+ " return start_epoch\n",
295
+ " else:\n",
296
+ " print(\"No checkpoint found, starting from scratch.\")\n",
297
+ " return 0\n",
298
+ "\n",
299
+ "\n",
300
+ "# === Training Loop Function ===\n",
301
+ "def train_loop(dataloader, unet, vae, optimizer, gradient_accumulation_steps, device, num_epochs, samples_path, checkpoint_path, pipe, projector):\n",
302
+ " start_epoch = resume_from_checkpoint(checkpoint_path, unet, vae, projector, optimizer)\n",
303
+ "\n",
304
+ " for epoch in trange(start_epoch, num_epochs, colour='red', desc=f'{device}-training', ncols=100):\n",
305
+ " unet.train()\n",
306
+ " vae.train()\n",
307
+ " projector.train()\n",
308
+ " total_loss = 0\n",
309
+ " step = 0\n",
310
+ " \n",
311
+ " for image in tqdm(dataloader, colour='green', desc=f'{device}-batch', ncols=100):\n",
312
+ " # print(\"step:\", step)\n",
313
+ " image = image.to(device, dtype=torch.float16)\n",
314
+ "\n",
315
+ " latents = vae.encode(image).latent_dist.sample() * 0.18215\n",
316
+ " noise = torch.randn_like(latents)\n",
317
+ " # timesteps = torch.randint(0, 1000, (latents.shape[0],), device=device).long()\n",
318
+ " timesteps = torch.randint(0, pipe.scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()\n",
319
+ "\n",
320
+ " # === Use dummy audio embedding ===\n",
321
+ " dummy_audio = torch.zeros(image.size(0), 1500, 1280, device=device, dtype=torch.float16)\n",
322
+ " context_hidden_states = projector(dummy_audio)\n",
323
+ "\n",
324
+ " # print(\"Model IP\")\n",
325
+ " noise_pred = unet(latents + noise, timesteps, encoder_hidden_states=context_hidden_states).sample\n",
326
+ " # print(\"Model OP\")\n",
327
+ "\n",
328
+ " loss = nn.MSELoss()(noise_pred, noise)\n",
329
+ " total_loss += loss.item()\n",
330
+ "\n",
331
+ " step += 1\n",
332
+ " accumulate_gradients(optimizer, loss, gradient_accumulation_steps, step, dataloader)\n",
333
+ "\n",
334
+ " avg_loss = total_loss / len(dataloader)\n",
335
+ " print(f\"Epoch {epoch + 1} | Avg Loss: {avg_loss:.6f}\")\n",
336
+ "\n",
337
+ " save_checkpoint(epoch, unet, vae, projector, optimizer, checkpoint_path)\n",
338
+ " run_inference(pipe, unet, vae, device, context_hidden_states, save_path=f\"{samples_path}/inference_epoch{epoch + 1}.png\")\n",
339
+ "\n",
340
+ " print(\"\\n✅ Fine-tuning complete.\")"
341
+ ]
342
+ },
343
+ {
344
+ "cell_type": "code",
345
+ "execution_count": 4,
346
+ "id": "9ad5f6a3",
347
+ "metadata": {},
348
+ "outputs": [],
349
+ "source": [
350
+ "# === Main Function ===\n",
351
+ "def main():\n",
352
+ " model_id = \"runwayml/stable-diffusion-v1-5\"\n",
353
+ " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
354
+ " lr = 1e-5\n",
355
+ " num_epochs = 10\n",
356
+ " batch_size = 16\n",
357
+ " debug = False\n",
358
+ " gradient_accumulation_steps = 1\n",
359
+ " \n",
360
+ " os.makedirs(f\"./checkpoints\", exist_ok=True)\n",
361
+ " os.makedirs(f\"./samples\", exist_ok=True)\n",
362
+ " checkpoint_path = f\"./checkpoints/checkpoint.pth\"\n",
363
+ " samples_path = f\"./samples\"\n",
364
+ " image_dir = \"/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Images\"\n",
365
+ "\n",
366
+ " pipe, unet, vae = load_pipeline(model_id, device)\n",
367
+ " freeze_vae_layers(vae)\n",
368
+ " freeze_unet_layers(unet)\n",
369
+ " projector = AudioContextProjector(audio_embed_dim=1280, output_dim=768).to(device).half()\n",
370
+ " optimizer = setup_optimizer(vae, unet, projector, lr)\n",
371
+ "\n",
372
+ " # === Dataset & Dataloader ===\n",
373
+ " files = walkDIR(image_dir, include=['.png', '.jpeg', '.jpg'])\n",
374
+ " dataset = VaaniDataset(files_paths=files, im_size=256)\n",
375
+ " image = dataset[2]\n",
376
+ " print('IMAGE SHAPE:', image.shape, \"Dataset len:\", len(dataset))\n",
377
+ " dataloader = create_dataloader(dataset, batch_size, debug=debug)\n",
378
+ "\n",
379
+ " train_loop(dataloader, unet, vae, optimizer, gradient_accumulation_steps, device, num_epochs, samples_path, checkpoint_path, pipe, projector)"
380
+ ]
381
+ },
382
+ {
383
+ "cell_type": "code",
384
+ "execution_count": 5,
385
+ "id": "e71b4ba9",
386
+ "metadata": {},
387
+ "outputs": [
388
+ {
389
+ "name": "stderr",
390
+ "output_type": "stream",
391
+ "text": [
392
+ "Couldn't connect to the Hub: (MaxRetryError('HTTPSConnectionPool(host=\\'huggingface.co\\', port=443): Max retries exceeded with url: /api/models/runwayml/stable-diffusion-v1-5 (Caused by NameResolutionError(\"<urllib3.connection.HTTPSConnection object at 0x7fd9a9445c40>: Failed to resolve \\'huggingface.co\\' ([Errno -2] Name or service not known)\"))'), '(Request ID: bcd4fcc3-8634-4bfe-8454-3b4dbdcc1222)').\n",
393
+ "Will try to load from local cache.\n"
394
+ ]
395
+ },
396
+ {
397
+ "data": {
398
+ "application/vnd.jupyter.widget-view+json": {
399
+ "model_id": "1014662fa9c44a00b0e9e6b3d1e9747d",
400
+ "version_major": 2,
401
+ "version_minor": 0
402
+ },
403
+ "text/plain": [
404
+ "Loading pipeline components...: 0%| | 0/7 [00:00<?, ?it/s]"
405
+ ]
406
+ },
407
+ "metadata": {},
408
+ "output_type": "display_data"
409
+ },
410
+ {
411
+ "name": "stdout",
412
+ "output_type": "stream",
413
+ "text": [
414
+ "Files found: 128807\n",
415
+ "IMAGE SHAPE: torch.Size([3, 256, 256]) Dataset len: 128807\n",
416
+ "Total Batches: 8050\n",
417
+ "BATCH SHAPE: torch.Size([16, 3, 256, 256])\n",
418
+ "No checkpoint found, starting from scratch.\n"
419
+ ]
420
+ },
421
+ {
422
+ "name": "stderr",
423
+ "output_type": "stream",
424
+ "text": [
425
+ "cuda-training: 0%|\u001b[31m \u001b[0m| 0/10 [00:00<?, ?it/s]\u001b[0m"
426
+ ]
427
+ },
428
+ {
429
+ "name": "stdout",
430
+ "output_type": "stream",
431
+ "text": [
432
+ "step: 0\n",
433
+ "Model IP\n",
434
+ "Model OP\n"
435
+ ]
436
+ },
437
+ {
438
+ "name": "stderr",
439
+ "output_type": "stream",
440
+ "text": []
441
+ },
442
+ {
443
+ "name": "stdout",
444
+ "output_type": "stream",
445
+ "text": [
446
+ "step: 1\n",
447
+ "Model IP\n",
448
+ "Model OP\n"
449
+ ]
450
+ },
451
+ {
452
+ "name": "stderr",
453
+ "output_type": "stream",
454
+ "text": []
455
+ },
456
+ {
457
+ "name": "stdout",
458
+ "output_type": "stream",
459
+ "text": [
460
+ "step: 2\n",
461
+ "Model IP\n"
462
+ ]
463
+ },
464
+ {
465
+ "name": "stderr",
466
+ "output_type": "stream",
467
+ "text": []
468
+ },
469
+ {
470
+ "name": "stdout",
471
+ "output_type": "stream",
472
+ "text": [
473
+ "Model OP\n",
474
+ "step: 3\n",
475
+ "Model IP\n",
476
+ "Model OP\n"
477
+ ]
478
+ },
479
+ {
480
+ "name": "stderr",
481
+ "output_type": "stream",
482
+ "text": []
483
+ },
484
+ {
485
+ "name": "stdout",
486
+ "output_type": "stream",
487
+ "text": [
488
+ "step: 4\n",
489
+ "Model IP\n",
490
+ "Model OP\n"
491
+ ]
492
+ },
493
+ {
494
+ "name": "stderr",
495
+ "output_type": "stream",
496
+ "text": []
497
+ },
498
+ {
499
+ "name": "stdout",
500
+ "output_type": "stream",
501
+ "text": [
502
+ "step: 5\n",
503
+ "Model IP\n"
504
+ ]
505
+ },
506
+ {
507
+ "name": "stderr",
508
+ "output_type": "stream",
509
+ "text": []
510
+ },
511
+ {
512
+ "name": "stdout",
513
+ "output_type": "stream",
514
+ "text": [
515
+ "Model OP\n",
516
+ "step: 6\n",
517
+ "Model IP\n"
518
+ ]
519
+ },
520
+ {
521
+ "name": "stderr",
522
+ "output_type": "stream",
523
+ "text": []
524
+ },
525
+ {
526
+ "name": "stdout",
527
+ "output_type": "stream",
528
+ "text": [
529
+ "Model OP\n",
530
+ "step: 7\n",
531
+ "Model IP\n"
532
+ ]
533
+ },
534
+ {
535
+ "name": "stderr",
536
+ "output_type": "stream",
537
+ "text": []
538
+ },
539
+ {
540
+ "name": "stdout",
541
+ "output_type": "stream",
542
+ "text": [
543
+ "Model OP\n",
544
+ "step: 8\n",
545
+ "Model IP\n",
546
+ "Model OP\n"
547
+ ]
548
+ },
549
+ {
550
+ "name": "stderr",
551
+ "output_type": "stream",
552
+ "text": []
553
+ },
554
+ {
555
+ "name": "stdout",
556
+ "output_type": "stream",
557
+ "text": [
558
+ "step: 9\n",
559
+ "Model IP\n",
560
+ "Model OP\n"
561
+ ]
562
+ },
563
+ {
564
+ "name": "stderr",
565
+ "output_type": "stream",
566
+ "text": []
567
+ },
568
+ {
569
+ "name": "stdout",
570
+ "output_type": "stream",
571
+ "text": [
572
+ "step: 10\n",
573
+ "Model IP\n",
574
+ "Model OP\n"
575
+ ]
576
+ },
577
+ {
578
+ "name": "stderr",
579
+ "output_type": "stream",
580
+ "text": []
581
+ },
582
+ {
583
+ "name": "stdout",
584
+ "output_type": "stream",
585
+ "text": [
586
+ "step: 11\n",
587
+ "Model IP\n"
588
+ ]
589
+ },
590
+ {
591
+ "name": "stderr",
592
+ "output_type": "stream",
593
+ "text": []
594
+ },
595
+ {
596
+ "name": "stdout",
597
+ "output_type": "stream",
598
+ "text": [
599
+ "Model OP\n",
600
+ "step: 12\n",
601
+ "Model IP\n"
602
+ ]
603
+ },
604
+ {
605
+ "name": "stderr",
606
+ "output_type": "stream",
607
+ "text": []
608
+ },
609
+ {
610
+ "name": "stdout",
611
+ "output_type": "stream",
612
+ "text": [
613
+ "Model OP\n",
614
+ "step: 13\n",
615
+ "Model IP\n",
616
+ "Model OP\n"
617
+ ]
618
+ },
619
+ {
620
+ "name": "stderr",
621
+ "output_type": "stream",
622
+ "text": [
623
+ "cuda-batch: 0%|\u001b[32m \u001b[0m| 14/8050 [00:06<1:02:22, 2.15it/s]\u001b[0m\n",
624
+ "cuda-training: 0%|\u001b[31m \u001b[0m| 0/10 [00:06<?, ?it/s]\u001b[0m\n"
625
+ ]
626
+ },
627
+ {
628
+ "name": "stdout",
629
+ "output_type": "stream",
630
+ "text": [
631
+ "step: 14\n"
632
+ ]
633
+ },
634
+ {
635
+ "ename": "KeyboardInterrupt",
636
+ "evalue": "",
637
+ "output_type": "error",
638
+ "traceback": [
639
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
640
+ "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
641
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[34m__name__\u001b[39m == \u001b[33m\"\u001b[39m\u001b[33m__main__\u001b[39m\u001b[33m\"\u001b[39m:\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m \u001b[43mmain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
642
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 30\u001b[39m, in \u001b[36mmain\u001b[39m\u001b[34m()\u001b[39m\n\u001b[32m 27\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m'\u001b[39m\u001b[33mIMAGE SHAPE:\u001b[39m\u001b[33m'\u001b[39m, image.shape, \u001b[33m\"\u001b[39m\u001b[33mDataset len:\u001b[39m\u001b[33m\"\u001b[39m, \u001b[38;5;28mlen\u001b[39m(dataset))\n\u001b[32m 28\u001b[39m dataloader = create_dataloader(dataset, batch_size, debug=debug)\n\u001b[32m---> \u001b[39m\u001b[32m30\u001b[39m \u001b[43mtrain_loop\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43munet\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvae\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient_accumulation_steps\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_epochs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msamples_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcheckpoint_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpipe\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprojector\u001b[49m\u001b[43m)\u001b[49m\n",
643
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 228\u001b[39m, in \u001b[36mtrain_loop\u001b[39m\u001b[34m(dataloader, unet, vae, optimizer, gradient_accumulation_steps, device, num_epochs, samples_path, checkpoint_path, pipe, projector)\u001b[39m\n\u001b[32m 226\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m image \u001b[38;5;129;01min\u001b[39;00m tqdm(dataloader, colour=\u001b[33m'\u001b[39m\u001b[33mgreen\u001b[39m\u001b[33m'\u001b[39m, desc=\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdevice\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m-batch\u001b[39m\u001b[33m'\u001b[39m, ncols=\u001b[32m100\u001b[39m):\n\u001b[32m 227\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[33mstep:\u001b[39m\u001b[33m\"\u001b[39m, step)\n\u001b[32m--> \u001b[39m\u001b[32m228\u001b[39m image = \u001b[43mimage\u001b[49m\u001b[43m.\u001b[49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtorch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mfloat16\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 230\u001b[39m latents = vae.encode(image).latent_dist.sample() * \u001b[32m0.18215\u001b[39m\n\u001b[32m 231\u001b[39m noise = torch.randn_like(latents)\n",
644
+ "\u001b[31mKeyboardInterrupt\u001b[39m: "
645
+ ]
646
+ }
647
+ ],
648
+ "source": [
649
+ "if __name__ == \"__main__\":\n",
650
+ " main()"
651
+ ]
652
+ }
653
+ ],
654
+ "metadata": {
655
+ "kernelspec": {
656
+ "display_name": "Python 3",
657
+ "language": "python",
658
+ "name": "python3"
659
+ },
660
+ "language_info": {
661
+ "codemirror_mode": {
662
+ "name": "ipython",
663
+ "version": 3
664
+ },
665
+ "file_extension": ".py",
666
+ "mimetype": "text/x-python",
667
+ "name": "python",
668
+ "nbconvert_exporter": "python",
669
+ "pygments_lexer": "ipython3",
670
+ "version": "3.12.2"
671
+ }
672
+ },
673
+ "nbformat": 4,
674
+ "nbformat_minor": 5
675
+ }
Vaani/SDFT/_2_.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.optim as optim
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from torchvision import transforms
5
+ from torchvision.transforms import v2
6
+ from PIL import Image
7
+ from diffusers import StableDiffusionPipeline
8
+ from diffusers.optimization import get_scheduler
9
+ from torch import nn
10
+ import torch.nn.functional as F
11
+ import os
12
+ import pandas as pd
13
+ from tqdm import trange, tqdm
14
+
15
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1"
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ device
18
+
19
+
20
+ # import torch
21
+ # import torch.nn as nn
22
+ # import torch.nn.functional as F
23
+
24
+ # audio_embed_dim = 1280
25
+ # output_dim = 768
26
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
27
+
28
+ # context_projector = nn.Sequential(
29
+ # nn.Linear(audio_embed_dim, 320),
30
+ # nn.SiLU(),
31
+ # nn.Linear(320, output_dim)
32
+ # ).to(device).half()
33
+
34
+ # # Dummy input
35
+ # audio_embedding = dummy_audio = torch.zeros(10, 1500, 1280, device=device, dtype=torch.float16)
36
+ # print(audio_embedding.shape) # [10, 1500, 1280]
37
+
38
+ # # Project audio to [10, 1500, 768]
39
+ # projected = context_projector(audio_embedding)
40
+ # print(projected.shape) # [10, 1500, 768]
41
+
42
+ # # Compute attention scores: reduce feature dim to scalar per time step
43
+ # attn_scores = projected.mean(dim=2) # [10, 1500]
44
+ # attn_weights = F.softmax(attn_scores, dim=1) # [10, 1500]
45
+ # attn_weights = attn_weights.unsqueeze(2) # [10, 1500, 1]
46
+
47
+ # # Weighted average
48
+ # pooled = (projected * attn_weights).sum(dim=1, keepdim=True) # [10, 1, 768]
49
+ # print(pooled.shape) # Final shape: [10, 1, 768]
50
+
51
+
52
+
53
+ # === Helpers ===
54
def walkDIR(folder_path, include=None):
    """Recursively collect file paths under folder_path.

    include: optional iterable of filename suffixes; when given, only
    files ending with one of them are kept. Prints the match count and
    returns the list of joined paths.
    """
    matches = []
    suffixes = None if include is None else tuple(include)
    for root, _, filenames in os.walk(folder_path):
        for name in filenames:
            if suffixes is None or name.endswith(suffixes):
                matches.append(os.path.join(root, name))
    print("Files found:", len(matches))
    return matches
62
+
63
+ # === Dataset Class ===
64
class VaaniDataset(torch.utils.data.Dataset):
    """Image-only dataset: loads a file as RGB, resizes to a square,
    and scales pixel values into [0, 1] as float32."""

    def __init__(self, files_paths, im_size):
        # files_paths: list of image file paths; im_size: output side length.
        self.files_paths = files_paths
        self.im_size = im_size

    def __len__(self):
        return len(self.files_paths)

    def __getitem__(self, idx):
        # convert("RGB") guarantees 3 channels for palette/grayscale sources.
        pil_image = Image.open(self.files_paths[idx]).convert("RGB")
        transform = v2.Compose([
            v2.ToImage(),
            v2.Resize((self.im_size, self.im_size)),
            v2.ToDtype(torch.float32, scale=True),
        ])
        return transform(pil_image)
81
+
82
+
83
def create_dataloader(dataset, batch_size, debug=False, val_split=0.1, num_workers=4):
    """Build the training DataLoader.

    debug: keep only a 0.1% random subset for quick runs.
    val_split: currently unused; kept for interface compatibility.
    Returns a shuffling DataLoader that drops the last partial batch.
    """
    if debug:
        s = 0.001
        # FIX: use a dedicated generator so the split is reproducible
        # without clobbering torch's global RNG state (torch.manual_seed
        # as the positional generator arg seeded the global RNG as a side effect).
        gen = torch.Generator().manual_seed(42)
        dataset, _ = torch.utils.data.random_split(dataset, [s, 1 - s], generator=gen)
    print("Length of Train dataset:", len(dataset))

    train_dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        # FIX: persistent_workers=True raises ValueError when num_workers == 0.
        persistent_workers=num_workers > 0,
    )

    images = next(iter(train_dataloader))
    print('Total Batches:', len(train_dataloader))
    print('BATCH SHAPE:', images.shape)
    return train_dataloader
103
+
104
+ # === Audio Context Projector ===
105
+ # class AudioContextProjector(nn.Module):
106
+ # def __init__(self, audio_embed_dim):
107
+ # super().__init__()
108
+ # self.audio_embed_dim = audio_embed_dim
109
+ # self.context_projector = nn.Sequential(
110
+ # nn.Linear(audio_embed_dim, 320),
111
+ # nn.SiLU(),
112
+ # nn.Linear(320, 1)
113
+ # )
114
+
115
+ # def forward(self, audio_embedding):
116
+ # if audio_embedding.size(-1) != self.audio_embed_dim:
117
+ # raise ValueError(f"Expected audio embedding dim {self.audio_embed_dim}, got {audio_embedding.size(-1)}")
118
+ # weights = self.context_projector(audio_embedding) # [B, T, 1]
119
+ # weights = torch.softmax(weights, dim=1) # [B, T, 1]
120
+ # pooled = (audio_embedding * weights).sum(dim=1) # [B, 1280]
121
+ # return pooled.unsqueeze(1) # [B, 1, 1280]
122
+ # class AudioContextProjector(nn.Module):
123
+ # def __init__(self, audio_embed_dim=1280, output_dim=768): # Add output_dim for flexibility
124
+ # super().__init__()
125
+ # self.audio_embed_dim = audio_embed_dim
126
+ # self.context_projector = nn.Sequential(
127
+ # nn.Linear(audio_embed_dim, 320),
128
+ # nn.SiLU(),
129
+ # nn.Linear(320, output_dim) # Output 768 to match UNet's expectation
130
+ # )
131
+
132
+ # def forward(self, audio_embedding):
133
+ # if audio_embedding.size(-1) != self.audio_embed_dim:
134
+ # raise ValueError(f"Expected audio embedding dim {self.audio_embed_dim}, got {audio_embedding.size(-1)}")
135
+ # weights = self.context_projector(audio_embedding) # [B, T, 768]
136
+ # weights = torch.softmax(pooled, dim=1) # [B, T, 768]
137
+ # pooled = (audio_embedding * weights).sum(dim=1) # [B, 768]
138
+ # return pooled.unsqueeze(1) # [B, 1, 768]
139
class AudioContextProjector(nn.Module):
    """Pools a (B, T, audio_embed_dim) frame sequence into one
    (B, 1, output_dim) conditioning token.

    A two-layer MLP projects each frame into the UNet's cross-attention
    width; frames are then combined by softmax attention over time.
    """

    def __init__(self, audio_embed_dim=1280, output_dim=768):
        super().__init__()
        self.audio_embed_dim = audio_embed_dim
        self.output_dim = output_dim
        self.context_projector = nn.Sequential(
            nn.Linear(audio_embed_dim, 320),
            nn.SiLU(),
            nn.Linear(320, output_dim),
        )

    def forward(self, audio_embedding):
        if audio_embedding.size(-1) != self.audio_embed_dim:
            raise ValueError(f"Expected audio embedding dim {self.audio_embed_dim}, got {audio_embedding.size(-1)}")

        frames = self.context_projector(audio_embedding)   # [B, T, output_dim]
        scores = frames.mean(dim=2)                        # [B, T] scalar score per frame
        weights = F.softmax(scores, dim=1).unsqueeze(2)    # [B, T, 1]
        # Weighted sum over time keeps a singleton sequence axis.
        return (frames * weights).sum(dim=1, keepdim=True)  # [B, 1, output_dim]
165
+
166
+
167
+
168
+ # === Inference Function ===
169
def run_inference(pipe, unet, vae, device, context_hidden_states, save_path="inference_output.png"):
    """Run a 50-step denoising loop with the given UNet/VAE and save one image.

    context_hidden_states must be [1, 1, 768] — the pooled audio token
    produced by AudioContextProjector.
    """
    pipe.unet = unet
    pipe.vae = vae
    pipe.to(device)

    batch_size = 1
    # FIX: read in_channels from the config — direct attribute access on
    # the UNet is deprecated in recent diffusers releases.
    latents = torch.randn(
        (batch_size, pipe.unet.config.in_channels, 64, 64),
        device=device, dtype=torch.float16,
    )
    pipe.scheduler.set_timesteps(50)
    latents = latents * pipe.scheduler.init_noise_sigma

    expected_shape = (batch_size, 1, 768)  # matches AudioContextProjector output
    if context_hidden_states.shape != expected_shape:
        raise ValueError(f"Expected context_hidden_states shape {expected_shape}, got {context_hidden_states.shape}")

    for t in pipe.scheduler.timesteps:
        with torch.no_grad():
            noise_pred = pipe.unet(latents, t, encoder_hidden_states=context_hidden_states).sample
        latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample

    # Undo the latent scaling (equivalent to 1 / 0.18215 for SD 1.5).
    latents = 1 / pipe.vae.config.scaling_factor * latents
    with torch.no_grad():
        image = pipe.vae.decode(latents).sample

    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).astype("uint8"))
    image.save(save_path)
    print(f"Inference image saved to {save_path}")
199
+
200
+
201
+ # === Load Pipeline ===
202
def load_pipeline(model_id, device):
    """Load the Stable Diffusion pipeline in fp16 and return it plus its UNet and VAE."""
    pipeline = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
    pipeline = pipeline.to(device)
    return pipeline, pipeline.unet, pipeline.vae
207
+
208
+ # === Freeze Layers Function ===
209
def freeze_vae_layers(vae):
    """Freeze the VAE's encoder side; keep the decoder side trainable."""
    for frozen_part in (vae.encoder, vae.quant_conv):
        frozen_part.requires_grad_(False)
    for trainable_part in (vae.decoder, vae.post_quant_conv):
        trainable_part.requires_grad_(True)
214
+
215
def freeze_unet_layers(unet):
    """Train only cross-attention ('attn2') and 'conv2' parameters; freeze the rest."""
    for name, param in unet.named_parameters():
        param.requires_grad = ("attn2" in name) or ("conv2" in name)
221
+
222
+ # === Optimizer Setup ===
223
def setup_optimizer(vae, unet, projector, lr):
    """Build AdamW over every trainable parameter of the three modules."""
    trainable = [
        p
        for module in (vae, unet, projector)
        for p in module.parameters()
        if p.requires_grad
    ]
    return optim.AdamW(trainable, lr=lr)
229
+
230
+
231
+ # === Gradient Accumulation Function ===
232
def accumulate_gradients(optimizer, loss, gradient_accumulation_steps, step, dataloader):
    """Backprop a scaled loss; step and zero the optimizer every N micro-batches.

    Also flushes on the dataloader's final batch so a short epoch never
    leaves unapplied gradients behind.
    """
    (loss / gradient_accumulation_steps).backward()
    at_boundary = (step + 1) % gradient_accumulation_steps == 0
    at_epoch_end = (step + 1) == len(dataloader)
    if at_boundary or at_epoch_end:
        optimizer.step()
        optimizer.zero_grad()
238
+
239
+ # === Save Checkpoint Function ===
240
def save_checkpoint(epoch, unet, vae, projector, optimizer, checkpoint_path):
    """Serialize the epoch index plus all model/optimizer states into one file."""
    state = {
        'epoch': epoch,
        'unet': unet.state_dict(),
        'vae': vae.state_dict(),
        'projector': projector.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(state, checkpoint_path)
    print(f"Checkpoint saved to {checkpoint_path}")
250
+
251
+ # === Resume from Checkpoint Function ===
252
def resume_from_checkpoint(checkpoint_path, unet, vae, projector, optimizer):
    """Restore states from checkpoint_path if it exists; return the epoch to start at."""
    if not os.path.exists(checkpoint_path):
        print("No checkpoint found, starting from scratch.")
        return 0
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    unet.load_state_dict(checkpoint['unet'])
    vae.load_state_dict(checkpoint['vae'])
    projector.load_state_dict(checkpoint['projector'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming training from epoch {start_epoch}...")
    return start_epoch
265
+
266
+
267
+ # === Training Loop Function ===
268
def train_loop(dataloader, unet, vae, optimizer, gradient_accumulation_steps, device, num_epochs, samples_path, checkpoint_path, pipe, projector):
    """Fine-tune UNet/VAE/projector on images with a (dummy) audio condition.

    Per batch: encode images to latents, noise them according to the
    scheduler, predict the noise from the noisy latents, and regress it
    with MSE. Saves a checkpoint and one sample image after each epoch.
    """
    start_epoch = resume_from_checkpoint(checkpoint_path, unet, vae, projector, optimizer)
    criterion = nn.MSELoss()

    for epoch in trange(start_epoch, num_epochs, colour='red', desc=f'{device}-training', dynamic_ncols=True):
        unet.train()
        vae.train()
        projector.train()
        total_loss = 0
        step = 0

        for image in tqdm(dataloader, colour='green', desc=f'{device}-batch', dynamic_ncols=True):
            image = image.to(device, dtype=torch.float16)

            latents = vae.encode(image).latent_dist.sample() * 0.18215
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, pipe.scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()

            # Dummy (silent) audio conditioning until real embeddings are wired in.
            dummy_audio = torch.zeros(image.size(0), 1500, 1280, device=device, dtype=torch.float16)
            context_hidden_states = projector(dummy_audio)

            # FIX: the forward diffusion process must follow the noise
            # schedule. `latents + noise` ignores the per-timestep alphas and
            # trains the UNet on inputs it never sees at sampling time.
            noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)
            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=context_hidden_states).sample

            loss = criterion(noise_pred, noise)
            total_loss += loss.item()

            step += 1
            accumulate_gradients(optimizer, loss, gradient_accumulation_steps, step, dataloader)

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch + 1} | Avg Loss: {avg_loss:.6f}")

        save_checkpoint(epoch, unet, vae, projector, optimizer, checkpoint_path)
        run_inference(pipe, unet, vae, device, context_hidden_states, save_path=f"{samples_path}/inference_epoch{epoch + 1}.png")

    print("\n✅ Fine-tuning complete.")
308
+
309
+
310
+ # === Main Function ===
311
def main():
    """Entry point: configure, build all components, and launch fine-tuning."""
    model_id = "runwayml/stable-diffusion-v1-5"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    lr = 1e-5
    num_epochs = 10
    batch_size = 16
    debug = False
    gradient_accumulation_steps = 1

    checkpoint_dir = "./checkpoints"
    samples_path = "./samples"
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(samples_path, exist_ok=True)
    checkpoint_path = f"{checkpoint_dir}/checkpoint.pth"
    image_dir = "/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Images"

    pipe, unet, vae = load_pipeline(model_id, device)
    freeze_vae_layers(vae)
    freeze_unet_layers(unet)
    projector = AudioContextProjector(audio_embed_dim=1280, output_dim=768).to(device).half()
    optimizer = setup_optimizer(vae, unet, projector, lr)

    # Dataset & dataloader.
    files = walkDIR(image_dir, include=['.png', '.jpeg', '.jpg'])
    dataset = VaaniDataset(files_paths=files, im_size=256)
    sample = dataset[2]
    print('IMAGE SHAPE:', sample.shape, "Dataset len:", len(dataset))
    dataloader = create_dataloader(dataset, batch_size, debug=debug)

    train_loop(dataloader, unet, vae, optimizer, gradient_accumulation_steps,
               device, num_epochs, samples_path, checkpoint_path, pipe, projector)
340
+
341
+
342
+ if __name__ == "__main__":
343
+ main()
344
+
345
+
Vaani/SDFT/_2_DDP.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.optim as optim
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from torchvision import transforms
5
+ from torchvision.transforms import v2
6
+ from PIL import Image
7
+ from diffusers import StableDiffusionPipeline
8
+ from diffusers.optimization import get_scheduler
9
+ from torch import nn
10
+ import torch.nn.functional as F
11
+ import os
12
+ import pandas as pd
13
+ from tqdm import trange, tqdm
14
+ # DDP Imports
15
+ import torch.distributed as dist
16
+ from torch.nn.parallel import DistributedDataParallel as DDP
17
+ from torch.utils.data.distributed import DistributedSampler
18
+ import torch.multiprocessing as mp
19
+
20
+ # Set CUDA_VISIBLE_DEVICES
21
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "1"
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
+
24
+ # === Helpers ===
25
def walkDIR(folder_path, include=None):
    """Walk folder_path recursively and return paths of matching files."""
    suffixes = tuple(include) if include is not None else None
    found = [
        os.path.join(root, fname)
        for root, _, fnames in os.walk(folder_path)
        for fname in fnames
        if suffixes is None or fname.endswith(suffixes)
    ]
    print("Files found:", len(found))
    return found
33
+
34
+ # === Dataset Class ===
35
class VaaniDataset(torch.utils.data.Dataset):
    """Yields RGB images resized to (im_size, im_size), scaled to [0, 1] float32."""

    def __init__(self, files_paths, im_size):
        self.files_paths = files_paths
        self.im_size = im_size

    def __len__(self):
        return len(self.files_paths)

    def __getitem__(self, idx):
        rgb = Image.open(self.files_paths[idx]).convert("RGB")
        as_tensor = v2.ToImage()(rgb)
        resized = v2.Resize((self.im_size, self.im_size))(as_tensor)
        return v2.ToDtype(torch.float32, scale=True)(resized)
49
+
50
+ # === Modified create_dataloader for DDP and single GPU ===
51
def create_dataloader(dataset, batch_size, debug=False, val_split=0.1, num_workers=4, rank=None, is_distributed=False):
    """Build the training DataLoader, sharded via DistributedSampler under DDP.

    val_split is currently unused and kept only for interface compatibility.
    """
    if debug:
        s = 0.001
        # FIX: dedicated generator keeps the split reproducible without the
        # global-RNG side effect of torch.manual_seed(42).
        gen = torch.Generator().manual_seed(42)
        dataset, _ = torch.utils.data.random_split(dataset, [s, 1 - s], generator=gen)
    prefix = f"Rank {rank}: " if rank is not None else ""
    print(f"{prefix}Length of Train dataset: {len(dataset)}")

    # DistributedSampler shards the dataset across ranks; plain shuffle otherwise.
    sampler = DistributedSampler(dataset, shuffle=True) if is_distributed else None
    train_dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(sampler is None),
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        # FIX: persistent_workers=True is invalid (raises) with num_workers == 0.
        persistent_workers=num_workers > 0,
    )

    images = next(iter(train_dataloader))
    print(f"{prefix}Total Batches: {len(train_dataloader)}")
    print(f"{prefix}BATCH SHAPE: {images.shape}")
    return train_dataloader
78
+
79
+ # === Audio Context Projector ===
80
class AudioContextProjector(nn.Module):
    """Attention-pools (B, T, audio_embed_dim) audio frames into a single
    (B, 1, output_dim) cross-attention token."""

    def __init__(self, audio_embed_dim=1280, output_dim=768):
        super().__init__()
        self.audio_embed_dim = audio_embed_dim
        self.output_dim = output_dim
        self.context_projector = nn.Sequential(
            nn.Linear(audio_embed_dim, 320),
            nn.SiLU(),
            nn.Linear(320, output_dim),
        )

    def forward(self, audio_embedding):
        if audio_embedding.size(-1) != self.audio_embed_dim:
            raise ValueError(f"Expected audio embedding dim {self.audio_embed_dim}, got {audio_embedding.size(-1)}")
        per_frame = self.context_projector(audio_embedding)    # [B, T, output_dim]
        attn = F.softmax(per_frame.mean(dim=2), dim=1)         # [B, T]
        # Weighted temporal sum, keeping a length-1 sequence axis.
        return torch.sum(per_frame * attn.unsqueeze(-1), dim=1, keepdim=True)
100
+
101
+ # === Inference Function ===
102
def run_inference(pipe, unet, vae, device, context_hidden_states, save_path="inference_output.png", rank=0):
    """Denoise from random latents and save one sample image (rank 0 only)."""
    if rank != 0:  # only one process writes samples
        return
    # Unwrap DDP so the pipeline holds the raw modules.
    pipe.unet = unet.module if isinstance(unet, DDP) else unet
    pipe.vae = vae.module if isinstance(vae, DDP) else vae
    pipe.to(device)

    batch_size = 1
    # FIX: in_channels must come from the config; the bare attribute is
    # deprecated in recent diffusers releases.
    latents = torch.randn((batch_size, pipe.unet.config.in_channels, 64, 64), device=device, dtype=torch.float16)
    pipe.scheduler.set_timesteps(50)
    latents = latents * pipe.scheduler.init_noise_sigma

    expected_shape = (batch_size, 1, 768)  # pooled audio token shape
    if context_hidden_states.shape != expected_shape:
        raise ValueError(f"Expected context_hidden_states shape {expected_shape}, got {context_hidden_states.shape}")

    for t in pipe.scheduler.timesteps:
        with torch.no_grad():
            noise_pred = pipe.unet(latents, t, encoder_hidden_states=context_hidden_states).sample
        latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample

    latents = 1 / pipe.vae.config.scaling_factor * latents
    with torch.no_grad():
        image = pipe.vae.decode(latents).sample

    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).astype("uint8"))
    image.save(save_path)
    # The early return above guarantees rank == 0 here, so no rank prefix is needed.
    print(f"Inference image saved to {save_path}")
132
+
133
+ # === Load Pipeline ===
134
def load_pipeline(model_id, device):
    """Load the SD pipeline in fp16 on `device`; return (pipe, unet, vae)."""
    pipeline = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
    pipeline = pipeline.to(device)
    return pipeline, pipeline.unet, pipeline.vae
139
+
140
+ # === Freeze Layers Function ===
141
def freeze_vae_layers(vae):
    """Encoder side frozen, decoder side trainable."""
    vae.encoder.requires_grad_(False)
    vae.quant_conv.requires_grad_(False)
    for part in (vae.decoder, vae.post_quant_conv):
        part.requires_grad_(True)
146
+
147
def freeze_unet_layers(unet):
    """Only 'attn2' / 'conv2' parameters stay trainable; everything else is frozen."""
    for param_name, param in unet.named_parameters():
        keep = "attn2" in param_name or "conv2" in param_name
        param.requires_grad = keep
153
+
154
+ # === Optimizer Setup ===
155
def setup_optimizer(vae, unet, projector, lr):
    """AdamW over the trainable parameters of vae, unet, and projector (in that order)."""
    params = []
    for module in (vae, unet, projector):
        params.extend(p for p in module.parameters() if p.requires_grad)
    return optim.AdamW(params, lr=lr)
161
+
162
+ # === Gradient Accumulation Function ===
163
def accumulate_gradients(optimizer, loss, gradient_accumulation_steps, step, dataloader):
    """Scale-and-backprop; apply the optimizer every N micro-batches or at epoch end."""
    scaled = loss / gradient_accumulation_steps
    scaled.backward()
    next_step = step + 1
    if next_step % gradient_accumulation_steps == 0 or next_step == len(dataloader):
        optimizer.step()
        optimizer.zero_grad()
169
+
170
+ # === Save Checkpoint Function ===
171
def save_checkpoint(epoch, unet, vae, projector, optimizer, checkpoint_path, rank=0):
    """Write a full training checkpoint; no-op on non-zero DDP ranks."""
    if rank != 0:
        return

    def raw_state(model):
        # Store the inner module's state so checkpoints load without a DDP prefix.
        return model.module.state_dict() if isinstance(model, DDP) else model.state_dict()

    torch.save({
        'epoch': epoch,
        'unet': raw_state(unet),
        'vae': raw_state(vae),
        'projector': raw_state(projector),
        'optimizer': optimizer.state_dict(),
    }, checkpoint_path)
    # Guard above guarantees rank == 0 here.
    print(f"Checkpoint saved to {checkpoint_path}")
182
+
183
+ # === Resume from Checkpoint Function ===
184
def resume_from_checkpoint(checkpoint_path, unet, vae, projector, optimizer, rank=0):
    """Restore model/optimizer state if a checkpoint exists; return the start epoch.

    FIX: every rank now restores (the previous code returned early for
    rank != 0, which left non-zero replicas out of sync with rank 0 after a
    resume). State dicts are loaded into the underlying module when the model
    is DDP-wrapped, matching how save_checkpoint stores them.
    """
    if not os.path.exists(checkpoint_path):
        print(f"{'Rank ' + str(rank) + ': ' if rank != 0 else ''}No checkpoint found, starting from scratch.")
        return 0
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    for model, key in ((unet, 'unet'), (vae, 'vae'), (projector, 'projector')):
        target = model.module if isinstance(model, DDP) else model
        target.load_state_dict(checkpoint[key])
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"{'Rank ' + str(rank) + ': ' if rank != 0 else ''}Resuming training from epoch {start_epoch}...")
    return start_epoch
199
+
200
+ # === Training Loop Function ===
201
def train_loop(dataloader, unet, vae, optimizer, gradient_accumulation_steps, device, num_epochs, samples_path, checkpoint_path, pipe, projector, rank=0, is_distributed=False):
    """Fine-tune under (optional) DDP; rank 0 writes checkpoints and samples each epoch."""
    start_epoch = resume_from_checkpoint(checkpoint_path, unet, vae, projector, optimizer, rank)
    criterion = nn.MSELoss()

    for epoch in trange(start_epoch, num_epochs, colour='red', desc=f"{'Rank ' + str(rank) + ' ' if rank != 0 else ''}{device}-training", dynamic_ncols=True):
        unet.train()
        vae.train()
        projector.train()
        total_loss = 0
        step = 0

        # DistributedSampler only reshuffles per epoch when told the epoch number.
        if is_distributed and isinstance(dataloader.sampler, DistributedSampler):
            dataloader.sampler.set_epoch(epoch)

        for image in tqdm(dataloader, colour='green', desc=f"{'Rank ' + str(rank) + ' ' if rank != 0 else ''}{device}-batch", dynamic_ncols=True):
            image = image.to(device, dtype=torch.float16)

            # NOTE(review): when is_distributed, `vae` is a DDP wrapper and
            # DDP does not forward .encode — this likely needs
            # vae.module.encode under DDP; confirm before multi-GPU runs.
            latents = vae.encode(image).latent_dist.sample() * 0.18215
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, pipe.scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()

            # Dummy (silent) audio conditioning placeholder.
            dummy_audio = torch.zeros(image.size(0), 1500, 1280, device=device, dtype=torch.float16)
            context_hidden_states = projector(dummy_audio)

            # FIX: noising must follow the scheduler's alphas; `latents + noise`
            # ignores the timestep and trains on off-schedule inputs.
            noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)
            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=context_hidden_states).sample
            loss = criterion(noise_pred, noise)
            total_loss += loss.item()

            step += 1
            accumulate_gradients(optimizer, loss, gradient_accumulation_steps, step, dataloader)

        if is_distributed:
            # Average the epoch loss across all replicas.
            total_loss_tensor = torch.tensor(total_loss, device=device)
            dist.all_reduce(total_loss_tensor, op=dist.ReduceOp.SUM)
            avg_loss = total_loss_tensor.item() / (len(dataloader) * dist.get_world_size())
        else:
            avg_loss = total_loss / len(dataloader)

        if rank == 0:
            print(f"Epoch {epoch + 1} | Avg Loss: {avg_loss:.6f}")

        save_checkpoint(epoch, unet, vae, projector, optimizer, checkpoint_path, rank)
        run_inference(pipe, unet, vae, device, context_hidden_states,
                      save_path=f"{samples_path}/inference_epoch{epoch + 1}{'_rank' + str(rank) if rank != 0 else ''}.png",
                      rank=rank)

    if rank == 0:
        print("✅ Fine-tuning complete.")
250
+
251
+ # === DDP Setup Function ===
252
def setup_ddp(rank, world_size):
    """Join the NCCL process group on localhost and pin this process to its GPU."""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'  # all ranks must agree on this rendezvous port
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
257
+
258
+ # === Main Function ===
259
def main(rank=0, world_size=1, is_distributed=False):
    """Per-process entry point: set up (optional) DDP, build components, train."""
    if is_distributed:
        setup_ddp(rank, world_size)
        device = torch.device(f"cuda:{rank}")
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_id = "runwayml/stable-diffusion-v1-5"
    lr = 1e-5
    num_epochs = 10
    batch_size = 16
    debug = False
    gradient_accumulation_steps = 1

    # Only rank 0 creates directories; others wait at the barrier.
    if rank == 0:
        os.makedirs("./checkpoints", exist_ok=True)
        os.makedirs("./samples", exist_ok=True)
    if is_distributed:
        dist.barrier()

    checkpoint_path = f"./checkpoints/checkpoint{'_rank' + str(rank) if is_distributed else ''}.pth"
    samples_path = "./samples"
    image_dir = "/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Images"

    pipe, unet, vae = load_pipeline(model_id, device)
    freeze_vae_layers(vae)
    freeze_unet_layers(unet)
    projector = AudioContextProjector(audio_embed_dim=1280, output_dim=768).to(device).half()

    if is_distributed:
        unet = DDP(unet, device_ids=[rank])
        vae = DDP(vae, device_ids=[rank])
        projector = DDP(projector, device_ids=[rank])

    optimizer = setup_optimizer(vae, unet, projector, lr)

    files = walkDIR(image_dir, include=['.png', '.jpeg', '.jpg'])
    dataset = VaaniDataset(files_paths=files, im_size=256)
    if rank == 0:
        sample = dataset[2]
        print(f"{'Rank ' + str(rank) + ': ' if rank != 0 else ''}IMAGE SHAPE: {sample.shape}, Dataset len: {len(dataset)}")

    dataloader = create_dataloader(dataset, batch_size, debug=debug, rank=rank, is_distributed=is_distributed)

    train_loop(dataloader, unet, vae, optimizer, gradient_accumulation_steps, device, num_epochs,
               samples_path, checkpoint_path, pipe, projector, rank, is_distributed)

    if is_distributed:
        dist.destroy_process_group()
308
+
309
# === Entry Point ===
if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    print(f"Detected {n_gpus} GPU(s)")
    if n_gpus > 1:
        # mp.spawn prepends the rank, so args supplies (world_size, is_distributed).
        mp.spawn(main, args=(n_gpus, True), nprocs=n_gpus, join=True)
    else:
        main(rank=0, world_size=1, is_distributed=False)
Vaani/SDFT/checkpoints/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79469e5ae61b7894df2b96cdb09b873f9d0e2282f8b85d4195c5dbd16e182891
3
+ size 2866661866
Vaani/SDFT/download_model.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, StableDiffusion3Pipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)

# pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
# pipe
# del pipe

# Downloading/caching the checkpoint is the intended side effect here;
# the pipeline object itself is not used afterwards.
pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.bfloat16)
pipe
# del pipe
Vaani/SDFT/vaani-stablediffusion-finetune-kaggle.ipynb ADDED
@@ -0,0 +1,650 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 4,
6
+ "metadata": {
7
+ "trusted": true
8
+ },
9
+ "outputs": [],
10
+ "source": [
11
+ "import torch\n",
12
+ "from torch import nn\n",
13
+ "from torch.utils.data import Dataset, DataLoader\n",
14
+ "from transformers import CLIPTextModel, CLIPTokenizer\n",
15
+ "from diffusers import StableDiffusionPipeline, UNet2DConditionModel\n",
16
+ "from diffusers.optimization import get_scheduler\n",
17
+ "from accelerate import Accelerator\n",
18
+ "import torchaudio"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 5,
24
+ "metadata": {
25
+ "trusted": true
26
+ },
27
+ "outputs": [
28
+ {
29
+ "name": "stderr",
30
+ "output_type": "stream",
31
+ "text": [
32
+ "Couldn't connect to the Hub: (MaxRetryError('HTTPSConnectionPool(host=\\'huggingface.co\\', port=443): Max retries exceeded with url: /api/models/runwayml/stable-diffusion-v1-5 (Caused by NameResolutionError(\"<urllib3.connection.HTTPSConnection object at 0x7f99cc2c77d0>: Failed to resolve \\'huggingface.co\\' ([Errno -2] Name or service not known)\"))'), '(Request ID: 85a7f948-b1d1-4bb4-be97-0eaea2bfd0f8)').\n",
33
+ "Will try to load from local cache.\n",
34
+ "Loading pipeline components...: 100%|██████████| 7/7 [00:43<00:00, 6.22s/it]\n"
35
+ ]
36
+ }
37
+ ],
38
+ "source": [
39
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
40
+ "pipe = StableDiffusionPipeline.from_pretrained(\"runwayml/stable-diffusion-v1-5\", torch_dtype=torch.float16).to(device)"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": 6,
46
+ "metadata": {
47
+ "trusted": true
48
+ },
49
+ "outputs": [],
50
+ "source": [
51
+ "unet = pipe.unet\n",
52
+ "vae = pipe.vae\n",
53
+ "tokenizer = pipe.tokenizer"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": null,
59
+ "metadata": {
60
+ "trusted": true
61
+ },
62
+ "outputs": [],
63
+ "source": [
64
+ "# Your text prompt\n",
65
+ "prompt = \"a photo of an astronaut riding a horse on mars\"\n",
66
+ "\n",
67
+ "# Generate image\n",
68
+ "with torch.autocast(\"cuda\"):\n",
69
+ " image = pipe(prompt).images[0]\n",
70
+ "\n",
71
+ "# Show or save the result\n",
72
+ "image.show() # Opens in default image viewer\n",
73
+ "image.save(\"astronaut_horse_mars.png\")"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "markdown",
78
+ "metadata": {},
79
+ "source": [
80
+ "<hr style=\"height:4px;border:none;color:#ff0000;background-color:#ff0000;\">"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": 8,
86
+ "metadata": {
87
+ "execution": {
88
+ "iopub.execute_input": "2025-05-14T14:30:58.653987Z",
89
+ "iopub.status.busy": "2025-05-14T14:30:58.653745Z",
90
+ "iopub.status.idle": "2025-05-14T14:30:58.658276Z",
91
+ "shell.execute_reply": "2025-05-14T14:30:58.657649Z",
92
+ "shell.execute_reply.started": "2025-05-14T14:30:58.653970Z"
93
+ },
94
+ "trusted": true
95
+ },
96
+ "outputs": [],
97
+ "source": [
98
+ "import torch\n",
99
+ "from torch.utils.data import Dataset, DataLoader\n",
100
+ "from torchvision import transforms\n",
101
+ "from torchvision.transforms import v2\n",
102
+ "from PIL import Image\n",
103
+ "from diffusers import StableDiffusionPipeline\n",
104
+ "from diffusers.optimization import get_scheduler\n",
105
+ "from accelerate import Accelerator\n",
106
+ "from torch import nn\n",
107
+ "import os\n",
108
+ "import pandas as pd\n",
109
+ "from tqdm import trange, tqdm"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "execution_count": 9,
115
+ "metadata": {
116
+ "execution": {
117
+ "iopub.execute_input": "2025-05-14T14:30:58.659588Z",
118
+ "iopub.status.busy": "2025-05-14T14:30:58.658976Z",
119
+ "iopub.status.idle": "2025-05-14T14:31:23.063776Z",
120
+ "shell.execute_reply": "2025-05-14T14:31:23.063145Z",
121
+ "shell.execute_reply.started": "2025-05-14T14:30:58.659571Z"
122
+ },
123
+ "trusted": true
124
+ },
125
+ "outputs": [
126
+ {
127
+ "name": "stdout",
128
+ "output_type": "stream",
129
+ "text": [
130
+ "Files found: 128807\n"
131
+ ]
132
+ },
133
+ {
134
+ "data": {
135
+ "text/html": [
136
+ "<div>\n",
137
+ "<style scoped>\n",
138
+ " .dataframe tbody tr th:only-of-type {\n",
139
+ " vertical-align: middle;\n",
140
+ " }\n",
141
+ "\n",
142
+ " .dataframe tbody tr th {\n",
143
+ " vertical-align: top;\n",
144
+ " }\n",
145
+ "\n",
146
+ " .dataframe thead th {\n",
147
+ " text-align: right;\n",
148
+ " }\n",
149
+ "</style>\n",
150
+ "<table border=\"1\" class=\"dataframe\">\n",
151
+ " <thead>\n",
152
+ " <tr style=\"text-align: right;\">\n",
153
+ " <th></th>\n",
154
+ " <th>image_path</th>\n",
155
+ " </tr>\n",
156
+ " </thead>\n",
157
+ " <tbody>\n",
158
+ " <tr>\n",
159
+ " <th>0</th>\n",
160
+ " <td>/kaggle/input/vaani-images-tar/Images/IISc_Vaa...</td>\n",
161
+ " </tr>\n",
162
+ " <tr>\n",
163
+ " <th>1</th>\n",
164
+ " <td>/kaggle/input/vaani-images-tar/Images/IISc_Vaa...</td>\n",
165
+ " </tr>\n",
166
+ " <tr>\n",
167
+ " <th>2</th>\n",
168
+ " <td>/kaggle/input/vaani-images-tar/Images/IISc_Vaa...</td>\n",
169
+ " </tr>\n",
170
+ " <tr>\n",
171
+ " <th>3</th>\n",
172
+ " <td>/kaggle/input/vaani-images-tar/Images/IISc_Vaa...</td>\n",
173
+ " </tr>\n",
174
+ " <tr>\n",
175
+ " <th>4</th>\n",
176
+ " <td>/kaggle/input/vaani-images-tar/Images/IISc_Vaa...</td>\n",
177
+ " </tr>\n",
178
+ " <tr>\n",
179
+ " <th>...</th>\n",
180
+ " <td>...</td>\n",
181
+ " </tr>\n",
182
+ " <tr>\n",
183
+ " <th>128802</th>\n",
184
+ " <td>/kaggle/input/vaani-images-tar/Images/IISc_Vaa...</td>\n",
185
+ " </tr>\n",
186
+ " <tr>\n",
187
+ " <th>128803</th>\n",
188
+ " <td>/kaggle/input/vaani-images-tar/Images/IISc_Vaa...</td>\n",
189
+ " </tr>\n",
190
+ " <tr>\n",
191
+ " <th>128804</th>\n",
192
+ " <td>/kaggle/input/vaani-images-tar/Images/IISc_Vaa...</td>\n",
193
+ " </tr>\n",
194
+ " <tr>\n",
195
+ " <th>128805</th>\n",
196
+ " <td>/kaggle/input/vaani-images-tar/Images/IISc_Vaa...</td>\n",
197
+ " </tr>\n",
198
+ " <tr>\n",
199
+ " <th>128806</th>\n",
200
+ " <td>/kaggle/input/vaani-images-tar/Images/IISc_Vaa...</td>\n",
201
+ " </tr>\n",
202
+ " </tbody>\n",
203
+ "</table>\n",
204
+ "<p>128807 rows × 1 columns</p>\n",
205
+ "</div>"
206
+ ],
207
+ "text/plain": [
208
+ " image_path\n",
209
+ "0 /kaggle/input/vaani-images-tar/Images/IISc_Vaa...\n",
210
+ "1 /kaggle/input/vaani-images-tar/Images/IISc_Vaa...\n",
211
+ "2 /kaggle/input/vaani-images-tar/Images/IISc_Vaa...\n",
212
+ "3 /kaggle/input/vaani-images-tar/Images/IISc_Vaa...\n",
213
+ "4 /kaggle/input/vaani-images-tar/Images/IISc_Vaa...\n",
214
+ "... ...\n",
215
+ "128802 /kaggle/input/vaani-images-tar/Images/IISc_Vaa...\n",
216
+ "128803 /kaggle/input/vaani-images-tar/Images/IISc_Vaa...\n",
217
+ "128804 /kaggle/input/vaani-images-tar/Images/IISc_Vaa...\n",
218
+ "128805 /kaggle/input/vaani-images-tar/Images/IISc_Vaa...\n",
219
+ "128806 /kaggle/input/vaani-images-tar/Images/IISc_Vaa...\n",
220
+ "\n",
221
+ "[128807 rows x 1 columns]"
222
+ ]
223
+ },
224
+ "execution_count": 9,
225
+ "metadata": {},
226
+ "output_type": "execute_result"
227
+ }
228
+ ],
229
+ "source": [
230
+ "IMAGES_PATH = r\"/kaggle/input/vaani-images-tar/Images\"\n",
231
+ "\n",
232
+ "def walkDIR(folder_path, include=None):\n",
233
+ " file_list = []\n",
234
+ " for root, _, files in os.walk(folder_path):\n",
235
+ " for file in files:\n",
236
+ " if include is None or any(file.endswith(ext) for ext in include):\n",
237
+ " file_list.append(os.path.join(root, file))\n",
238
+ " print(\"Files found:\", len(file_list))\n",
239
+ " return file_list\n",
240
+ "\n",
241
+ "files = walkDIR(IMAGES_PATH, include=['.png', '.jpeg', '.jpg'])\n",
242
+ "df = pd.DataFrame(files, columns=['image_path'])\n",
243
+ "df"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "code",
248
+ "execution_count": 10,
249
+ "metadata": {
250
+ "execution": {
251
+ "iopub.execute_input": "2025-05-14T14:31:23.065017Z",
252
+ "iopub.status.busy": "2025-05-14T14:31:23.064553Z",
253
+ "iopub.status.idle": "2025-05-14T14:31:23.086417Z",
254
+ "shell.execute_reply": "2025-05-14T14:31:23.085628Z",
255
+ "shell.execute_reply.started": "2025-05-14T14:31:23.064991Z"
256
+ },
257
+ "trusted": true
258
+ },
259
+ "outputs": [
260
+ {
261
+ "name": "stdout",
262
+ "output_type": "stream",
263
+ "text": [
264
+ "IMAGE SHAPE: torch.Size([3, 256, 256])\n"
265
+ ]
266
+ },
267
+ {
268
+ "data": {
269
+ "text/plain": [
270
+ "128807"
271
+ ]
272
+ },
273
+ "execution_count": 10,
274
+ "metadata": {},
275
+ "output_type": "execute_result"
276
+ }
277
+ ],
278
+ "source": [
279
+ "class VaaniDataset(torch.utils.data.Dataset):\n",
280
+ " def __init__(self, files_paths, im_size):\n",
281
+ " self.files_paths = files_paths\n",
282
+ " self.im_size = im_size\n",
283
+ "\n",
284
+ " def __len__(self):\n",
285
+ " return len(self.files_paths)\n",
286
+ "\n",
287
+ " def __getitem__(self, idx):\n",
288
+ " # image = tv.io.read_image(self.files_paths[idx], mode=tv.io.ImageReadMode.RGB)\n",
289
+ " image = Image.open(self.files_paths[idx]).convert(\"RGB\")\n",
290
+ " image = v2.ToImage()(image)\n",
291
+ " # image = tv.io.decode_image(self.files_paths[idx], mode=tv.io.ImageReadMode.RGB)\n",
292
+ " image = v2.Resize((self.im_size, self.im_size))(image)\n",
293
+ " image = v2.ToDtype(torch.float32, scale=True)(image)\n",
294
+ " # image = 2*image - 1\n",
295
+ " return image\n",
296
+ "\n",
297
+ "dataset = VaaniDataset(files_paths=files, im_size=256)\n",
298
+ "image = dataset[2]\n",
299
+ "print('IMAGE SHAPE:', image.shape)\n",
300
+ "len(dataset)"
301
+ ]
302
+ },
303
+ {
304
+ "cell_type": "code",
305
+ "execution_count": 11,
306
+ "metadata": {
307
+ "execution": {
308
+ "iopub.execute_input": "2025-05-14T14:31:23.087483Z",
309
+ "iopub.status.busy": "2025-05-14T14:31:23.087211Z",
310
+ "iopub.status.idle": "2025-05-14T14:31:23.468810Z",
311
+ "shell.execute_reply": "2025-05-14T14:31:23.465992Z",
312
+ "shell.execute_reply.started": "2025-05-14T14:31:23.087458Z"
313
+ },
314
+ "trusted": true
315
+ },
316
+ "outputs": [
317
+ {
318
+ "name": "stdout",
319
+ "output_type": "stream",
320
+ "text": [
321
+ "Length of Train dataset: 129\n",
322
+ "BATCH SHAPE: torch.Size([2, 3, 256, 256])\n"
323
+ ]
324
+ }
325
+ ],
326
+ "source": [
327
+ "debug = True\n",
328
+ "\n",
329
+ "if debug:\n",
330
+ " s = 0.001\n",
331
+ " dataset, _ = torch.utils.data.random_split(dataset, [s, 1-s], torch.manual_seed(42))\n",
332
+ " print(\"Length of Train dataset:\", len(dataset))\n",
333
+ "\n",
334
+ "BATCH_SIZE = 2\n",
335
+ "\n",
336
+ "dataloader = torch.utils.data.DataLoader(\n",
337
+ " dataset, \n",
338
+ " batch_size=BATCH_SIZE, \n",
339
+ " shuffle=True, \n",
340
+ " num_workers=4,\n",
341
+ " pin_memory=True,\n",
342
+ " drop_last=True,\n",
343
+ " persistent_workers=True\n",
344
+ ")\n",
345
+ "\n",
346
+ "images = next(iter(dataloader))\n",
347
+ "print('BATCH SHAPE:', images.shape)"
348
+ ]
349
+ },
350
+ {
351
+ "cell_type": "code",
352
+ "execution_count": 17,
353
+ "metadata": {
354
+ "execution": {
355
+ "iopub.execute_input": "2025-05-14T14:31:59.796334Z",
356
+ "iopub.status.busy": "2025-05-14T14:31:59.795660Z",
357
+ "iopub.status.idle": "2025-05-14T14:31:59.800889Z",
358
+ "shell.execute_reply": "2025-05-14T14:31:59.800295Z",
359
+ "shell.execute_reply.started": "2025-05-14T14:31:59.796311Z"
360
+ },
361
+ "trusted": true
362
+ },
363
+ "outputs": [
364
+ {
365
+ "data": {
366
+ "text/plain": [
367
+ "64"
368
+ ]
369
+ },
370
+ "execution_count": 17,
371
+ "metadata": {},
372
+ "output_type": "execute_result"
373
+ }
374
+ ],
375
+ "source": [
376
+ "len(dataloader)"
377
+ ]
378
+ },
379
+ {
380
+ "cell_type": "code",
381
+ "execution_count": 12,
382
+ "metadata": {
383
+ "execution": {
384
+ "iopub.execute_input": "2025-05-14T14:31:23.470858Z",
385
+ "iopub.status.busy": "2025-05-14T14:31:23.470503Z",
386
+ "iopub.status.idle": "2025-05-14T14:31:28.213003Z",
387
+ "shell.execute_reply": "2025-05-14T14:31:28.212168Z",
388
+ "shell.execute_reply.started": "2025-05-14T14:31:23.470801Z"
389
+ },
390
+ "scrolled": true,
391
+ "trusted": true
392
+ },
393
+ "outputs": [
394
+ {
395
+ "data": {
396
+ "application/vnd.jupyter.widget-view+json": {
397
+ "model_id": "28c0c220b2cf45968b4abdecf3936bc9",
398
+ "version_major": 2,
399
+ "version_minor": 0
400
+ },
401
+ "text/plain": [
402
+ "Loading pipeline components...: 0%| | 0/7 [00:00<?, ?it/s]"
403
+ ]
404
+ },
405
+ "metadata": {},
406
+ "output_type": "display_data"
407
+ }
408
+ ],
409
+ "source": [
410
+ "# Load pretrained Stable Diffusion\n",
411
+ "pipe = StableDiffusionPipeline.from_pretrained(\n",
412
+ " \"runwayml/stable-diffusion-v1-5\", \n",
413
+ " torch_dtype=torch.float16\n",
414
+ ").to(\"cuda\")\n",
415
+ "\n",
416
+ "unet = pipe.unet\n",
417
+ "vae = pipe.vae\n",
418
+ "scheduler = pipe.scheduler"
419
+ ]
420
+ },
421
+ {
422
+ "cell_type": "code",
423
+ "execution_count": 15,
424
+ "metadata": {
425
+ "execution": {
426
+ "iopub.execute_input": "2025-05-14T14:31:39.847880Z",
427
+ "iopub.status.busy": "2025-05-14T14:31:39.847601Z",
428
+ "iopub.status.idle": "2025-05-14T14:31:39.868068Z",
429
+ "shell.execute_reply": "2025-05-14T14:31:39.867331Z",
430
+ "shell.execute_reply.started": "2025-05-14T14:31:39.847863Z"
431
+ },
432
+ "trusted": true
433
+ },
434
+ "outputs": [],
435
+ "source": [
436
+ "# Optimizer and scheduler\n",
437
+ "optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-5)\n",
438
+ "lr_scheduler = get_scheduler(\"linear\", optimizer=optimizer, num_warmup_steps=100, num_training_steps=1000)\n",
439
+ "\n",
440
+ "\n",
441
+ "accelerator = Accelerator()\n",
442
+ "unet, optimizer, dataloader = accelerator.prepare(unet, optimizer, dataloader)"
443
+ ]
444
+ },
445
+ {
446
+ "cell_type": "code",
447
+ "execution_count": 16,
448
+ "metadata": {
449
+ "execution": {
450
+ "iopub.execute_input": "2025-05-14T14:31:42.759171Z",
451
+ "iopub.status.busy": "2025-05-14T14:31:42.758886Z",
452
+ "iopub.status.idle": "2025-05-14T14:31:42.763012Z",
453
+ "shell.execute_reply": "2025-05-14T14:31:42.762387Z",
454
+ "shell.execute_reply.started": "2025-05-14T14:31:42.759152Z"
455
+ },
456
+ "trusted": true
457
+ },
458
+ "outputs": [],
459
+ "source": [
460
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
461
+ ]
462
+ },
463
+ {
464
+ "cell_type": "code",
465
+ "execution_count": 19,
466
+ "metadata": {
467
+ "execution": {
468
+ "iopub.execute_input": "2025-05-14T14:40:23.644302Z",
469
+ "iopub.status.busy": "2025-05-14T14:40:23.643598Z",
470
+ "iopub.status.idle": "2025-05-14T14:40:23.648831Z",
471
+ "shell.execute_reply": "2025-05-14T14:40:23.648151Z",
472
+ "shell.execute_reply.started": "2025-05-14T14:40:23.644244Z"
473
+ },
474
+ "trusted": true
475
+ },
476
+ "outputs": [],
477
+ "source": [
478
+ "EPOCHS = 100"
479
+ ]
480
+ },
481
+ {
482
+ "cell_type": "code",
483
+ "execution_count": null,
484
+ "metadata": {
485
+ "execution": {
486
+ "iopub.execute_input": "2025-05-14T14:40:35.244187Z",
487
+ "iopub.status.busy": "2025-05-14T14:40:35.243686Z"
488
+ },
489
+ "scrolled": true,
490
+ "trusted": true
491
+ },
492
+ "outputs": [
493
+ {
494
+ "name": "stderr",
495
+ "output_type": "stream",
496
+ "text": [
497
+ " 84%|\u001b[32m████████████████████████████████████████████████████ \u001b[0m| 84/100 [39:16<07:28, 28.02s/it]\u001b[0m"
498
+ ]
499
+ }
500
+ ],
501
+ "source": [
502
+ "# Start training loop\n",
503
+ "for epoch in trange(EPOCHS, ncols=100, colour='green'):\n",
504
+ " for step, images in enumerate(dataloader):\n",
505
+ " images = images.to(device, dtype=torch.float16)\n",
506
+ "\n",
507
+ " # Encode images to latents\n",
508
+ " latents = vae.encode(images).latent_dist.sample()\n",
509
+ " latents = latents * 0.18215\n",
510
+ "\n",
511
+ " # Sample noise and timesteps\n",
512
+ " noise = torch.randn_like(latents)\n",
513
+ " timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()\n",
514
+ " noisy_latents = scheduler.add_noise(latents, noise, timesteps)\n",
515
+ "\n",
516
+ " # Use zeroed audio embedding (like a null conditioning vector)\n",
517
+ " batch_size = images.shape[0]\n",
518
+ " cond_dim = pipe.text_encoder.config.hidden_size # 768 for SD 1.5\n",
519
+ " null_emb = torch.zeros((batch_size, 77, cond_dim), device=device, dtype=torch.float16)\n",
520
+ "\n",
521
+ " # Predict noise\n",
522
+ " noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=null_emb).sample\n",
523
+ "\n",
524
+ " # Loss and backward\n",
525
+ " loss = nn.MSELoss()(noise_pred, noise)\n",
526
+ " accelerator.backward(loss)\n",
527
+ " optimizer.step()\n",
528
+ " lr_scheduler.step()\n",
529
+ " optimizer.zero_grad()\n",
530
+ "\n",
531
+ " # if step % 10 == 0:\n",
532
+ " # print(f\"Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}\")"
533
+ ]
534
+ },
535
+ {
536
+ "cell_type": "markdown",
537
+ "metadata": {},
538
+ "source": [
539
+ "# Sampling"
540
+ ]
541
+ },
542
+ {
543
+ "cell_type": "code",
544
+ "execution_count": null,
545
+ "metadata": {
546
+ "trusted": true
547
+ },
548
+ "outputs": [],
549
+ "source": [
550
+ "import torch\n",
551
+ "from diffusers import StableDiffusionPipeline, DDIMScheduler\n",
552
+ "from PIL import Image\n",
553
+ "\n",
554
+ "# Load pretrained (or fine-tuned) Stable Diffusion\n",
555
+ "pipe = StableDiffusionPipeline.from_pretrained(\n",
556
+ " \"runwayml/stable-diffusion-v1-5\",\n",
557
+ " torch_dtype=torch.float16,\n",
558
+ ")\n",
559
+ "pipe.to(\"cuda\")\n",
560
+ "\n",
561
+ "# Optionally load fine-tuned weights\n",
562
+ "pipe.unet.load_state_dict(torch.load(\"path/to/fine_tuned_unet.pth\"))"
563
+ ]
564
+ },
565
+ {
566
+ "cell_type": "code",
567
+ "execution_count": null,
568
+ "metadata": {
569
+ "trusted": true
570
+ },
571
+ "outputs": [],
572
+ "source": [
573
+ "# Prepare dummy zero embedding\n",
574
+ "batch_size = 1\n",
575
+ "seq_len = 77 # number of tokens (CLIP text length)\n",
576
+ "embed_dim = pipe.text_encoder.config.hidden_size # 768 for CLIP\n",
577
+ "null_emb = torch.zeros((batch_size, seq_len, embed_dim), device=\"cuda\", dtype=torch.float16)\n",
578
+ "\n",
579
+ "# Sample initial noise\n",
580
+ "latents = torch.randn((batch_size, pipe.unet.in_channels, 64, 64), device=\"cuda\", dtype=torch.float16)\n",
581
+ "\n",
582
+ "# Use DDIM or default scheduler\n",
583
+ "pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)\n",
584
+ "\n",
585
+ "# Denoising loop\n",
586
+ "num_inference_steps = 50\n",
587
+ "pipe.scheduler.set_timesteps(num_inference_steps)\n",
588
+ "latents = latents * pipe.scheduler.init_noise_sigma\n",
589
+ "\n",
590
+ "for t in pipe.scheduler.timesteps:\n",
591
+ " # Predict noise using zero embedding\n",
592
+ " with torch.no_grad():\n",
593
+ " noise_pred = pipe.unet(latents, t, encoder_hidden_states=null_emb).sample\n",
594
+ "\n",
595
+ " # Compute the previous noisy sample\n",
596
+ " latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample\n",
597
+ "\n",
598
+ "# Decode latents to image\n",
599
+ "latents = 1 / 0.18215 * latents\n",
600
+ "with torch.no_grad():\n",
601
+ " image = pipe.vae.decode(latents).sample\n",
602
+ "\n",
603
+ "# Convert to PIL\n",
604
+ "image = (image / 2 + 0.5).clamp(0, 1)\n",
605
+ "image = image.cpu().permute(0, 2, 3, 1).numpy()[0]\n",
606
+ "image = Image.fromarray((image * 255).astype(\"uint8\"))\n",
607
+ "\n",
608
+ "# Save or show image\n",
609
+ "image.save(\"zero_condition_output.png\")\n",
610
+ "image.show()"
611
+ ]
612
+ }
613
+ ],
614
+ "metadata": {
615
+ "kaggle": {
616
+ "accelerator": "nvidiaTeslaT4",
617
+ "dataSources": [
618
+ {
619
+ "datasetId": 6964433,
620
+ "sourceId": 11161218,
621
+ "sourceType": "datasetVersion"
622
+ }
623
+ ],
624
+ "dockerImageVersionId": 31041,
625
+ "isGpuEnabled": true,
626
+ "isInternetEnabled": true,
627
+ "language": "python",
628
+ "sourceType": "notebook"
629
+ },
630
+ "kernelspec": {
631
+ "display_name": "Python 3",
632
+ "language": "python",
633
+ "name": "python3"
634
+ },
635
+ "language_info": {
636
+ "codemirror_mode": {
637
+ "name": "ipython",
638
+ "version": 3
639
+ },
640
+ "file_extension": ".py",
641
+ "mimetype": "text/x-python",
642
+ "name": "python",
643
+ "nbconvert_exporter": "python",
644
+ "pygments_lexer": "ipython3",
645
+ "version": "3.12.2"
646
+ }
647
+ },
648
+ "nbformat": 4,
649
+ "nbformat_minor": 4
650
+ }
Vaani/VaaniLDM/ddpm_ckpt_epoch31.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:339fa2458d3ead55689b3e219be0223ddc515874a4c03bb67bce527527076073
3
+ size 593243562
Vaani/VaaniLDM/ddpm_ckpt_epoch32.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32a6ea3e6f6558014f9eb11da6263abf02f130fdd77643889cf088f6d7077359
3
+ size 593243626
Vaani/VaaniLDM/ldmH_ckpt_epoch24.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0abb3e97bfae10d4689aeeed14bd6fcd48472be5729870a4d179a74ff67982c7
3
+ size 2476368170
Vaani/VaaniLDM/ldmH_ckpt_epoch25.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:efd5c35e1948cb47ae74cf4a09e43cf6428d623bb3bb3bec0594057a195b7953
3
+ size 2476368234
Vaani/VaaniLDM/samples/x0_0.png CHANGED

Git LFS Details

  • SHA256: 5051b0c57b98915bbd30f8f413daa87c96e3bc117dd72adf55fee55c33d75516
  • Pointer size: 131 Bytes
  • Size of remote file: 421 kB

Git LFS Details

  • SHA256: a40423339bf5053f537333df664a4945fb160181e05ae139856e2b67cd77cc16
  • Pointer size: 131 Bytes
  • Size of remote file: 426 kB
Vaani/VaaniLDM/samples/x0_1.png CHANGED
Vaani/VaaniLDM/samples/x0_10.png CHANGED
Vaani/VaaniLDM/samples/x0_100.png CHANGED
Vaani/VaaniLDM/samples/x0_101.png CHANGED
Vaani/VaaniLDM/samples/x0_102.png CHANGED
Vaani/VaaniLDM/samples/x0_103.png CHANGED
Vaani/VaaniLDM/samples/x0_104.png CHANGED
Vaani/VaaniLDM/samples/x0_105.png CHANGED
Vaani/VaaniLDM/samples/x0_106.png CHANGED
Vaani/VaaniLDM/samples/x0_107.png CHANGED
Vaani/VaaniLDM/samples/x0_108.png CHANGED
Vaani/VaaniLDM/samples/x0_109.png CHANGED
Vaani/VaaniLDM/samples/x0_11.png CHANGED
Vaani/VaaniLDM/samples/x0_110.png CHANGED
Vaani/VaaniLDM/samples/x0_111.png CHANGED
Vaani/VaaniLDM/samples/x0_112.png CHANGED
Vaani/VaaniLDM/samples/x0_113.png CHANGED
Vaani/VaaniLDM/samples/x0_114.png CHANGED
Vaani/VaaniLDM/samples/x0_115.png CHANGED
Vaani/VaaniLDM/samples/x0_116.png CHANGED
Vaani/VaaniLDM/samples/x0_117.png CHANGED
Vaani/VaaniLDM/samples/x0_118.png CHANGED
Vaani/VaaniLDM/samples/x0_119.png CHANGED
Vaani/VaaniLDM/samples/x0_12.png CHANGED
Vaani/VaaniLDM/samples/x0_120.png CHANGED
Vaani/VaaniLDM/samples/x0_121.png CHANGED
Vaani/VaaniLDM/samples/x0_122.png CHANGED
Vaani/VaaniLDM/samples/x0_123.png CHANGED
Vaani/VaaniLDM/samples/x0_124.png CHANGED
Vaani/VaaniLDM/samples/x0_125.png CHANGED
Vaani/VaaniLDM/samples/x0_126.png CHANGED
Vaani/VaaniLDM/samples/x0_127.png CHANGED
Vaani/VaaniLDM/samples/x0_128.png CHANGED
Vaani/VaaniLDM/samples/x0_129.png CHANGED
Vaani/VaaniLDM/samples/x0_13.png CHANGED
Vaani/VaaniLDM/samples/x0_130.png CHANGED
Vaani/VaaniLDM/samples/x0_131.png CHANGED
Vaani/VaaniLDM/samples/x0_132.png CHANGED