eeshaAI commited on
Commit
fe0fc64
·
verified ·
1 Parent(s): 9272f1f

Add Google Colab training notebook (T4 GPU + incremental HF push)

Browse files
Files changed (1) hide show
  1. Zeeb_Video_LLM_Training.ipynb +983 -0
Zeeb_Video_LLM_Training.ipynb ADDED
@@ -0,0 +1,983 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4",
8
+ "name": "Zeeb_Video_LLM_Training.ipynb"
9
+ },
10
+ "kernelspec": {
11
+ "name": "python3",
12
+ "display_name": "Python 3"
13
+ },
14
+ "language_info": {
15
+ "name": "python"
16
+ },
17
+ "accelerator": "GPU"
18
+ },
19
+ "cells": [
20
+ {
21
+ "cell_type": "markdown",
22
+ "metadata": {},
23
+ "source": [
24
+ "# 🎬 Zeeb — Video-LLM Training on T4 GPU\n",
25
+ "\n",
26
+ "**OLMo 2 1B + LoRA + VQ-VAE → Text-to-Video Generation**\n",
27
+ "\n",
28
+ "This notebook trains the full pipeline on a **Google Colab T4 GPU** and pushes checkpoints to HuggingFace incrementally.\n",
29
+ "\n",
30
+ "## Pipeline Overview\n",
31
+ "1. **Phase 1**: Train VQ-VAE on real images (COCO, streaming)\n",
32
+ "2. **Phase 2**: Tokenize image-text pairs through trained VQ-VAE\n",
33
+ "3. **Phase 3**: Fine-tune OLMo 2 1B + LoRA on tokenized data → push to EeshaAI/zeeb\n",
34
+ "\n",
35
+ "## Key Features\n",
36
+ "- ✅ **Incremental checkpoint pushing** to HuggingFace (survives Colab disconnects)\n",
37
+ "- ✅ **Resume from checkpoint** if training is interrupted\n",
38
+ "- ✅ **HuggingFace Trainer** with `push_to_hub=True` and `save_strategy=\"steps\"`\n",
39
+ "- ✅ **Real data** from COCO/imagenette (10K+ images)\n",
40
+ "- ✅ **GPU-accelerated** training (T4 = ~50x faster than CPU)\n",
41
+ "\n",
42
+ "**Make sure you select GPU runtime**: Runtime → Change runtime type → T4 GPU"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "markdown",
47
+ "metadata": {},
48
+ "source": [
49
+ "## ⚙️ Cell 1: Setup & Authentication"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": null,
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "# @title 1. Install Dependencies\n",
59
+ "!pip install -q torch torchvision transformers peft accelerate datasets huggingface_hub safetensors imageio Pillow\n",
60
+ "\n",
61
+ "import torch\n",
62
+ "print(f\"PyTorch: {torch.__version__}\")\n",
63
+ "print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
64
+ "if torch.cuda.is_available():\n",
65
+ " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
66
+ " print(f\"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB\")\n",
67
+ "else:\n",
68
+ " raise RuntimeError(\"No GPU detected! Go to Runtime → Change runtime type → T4 GPU\")"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": null,
74
+ "metadata": {},
75
+ "outputs": [],
76
+ "source": [
77
+ "# @title 2. HuggingFace Authentication\n",
78
+ "from huggingface_hub import HfApi, login\n",
79
+ "import os\n",
80
+ "\n",
81
+ "# 🔑 Paste your HuggingFace token here (must have write access to EeshaAI/zeeb)\n",
82
+ "HF_TOKEN = \"YOUR_HF_TOKEN_HERE\" # @param {type:\"string\"}\n",
83
+ "\n",
84
+ "login(token=HF_TOKEN)\n",
85
+ "\n",
86
+ "api = HfApi()\n",
87
+ "user_info = api.whoami()\n",
88
+ "print(f\"Logged in as: {user_info['name']}\")\n",
89
+ "\n",
90
+ "REPO_ID = \"EeshaAI/zeeb\"\n",
91
+ "api.create_repo(repo_id=REPO_ID, repo_type=\"model\", exist_ok=True)\n",
92
+ "print(f\"Model repo: https://huggingface.co/{REPO_ID}\")"
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "markdown",
97
+ "metadata": {},
98
+ "source": [
99
+ "## 🧠 Cell 2: VQ-VAE Model Definition"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "code",
104
+ "execution_count": null,
105
+ "metadata": {},
106
+ "outputs": [],
107
+ "source": [
108
+ "# @title 3. VQ-VAE Architecture\n",
109
+ "import torch\n",
110
+ "import torch.nn as nn\n",
111
+ "import torch.nn.functional as F\n",
112
+ "\n",
113
+ "CODEBOOK_SIZE = 1024\n",
114
+ "CODEBOOK_DIM = 256\n",
115
+ "LATENT_DIM = 256\n",
116
+ "\n",
117
+ "class Encoder(nn.Module):\n",
118
+ " def __init__(self, in_channels=3, latent_dim=LATENT_DIM):\n",
119
+ " super().__init__()\n",
120
+ " self.net = nn.Sequential(\n",
121
+ " nn.Conv2d(in_channels, 64, 4, stride=2, padding=1), # -> 64x64\n",
122
+ " nn.ReLU(),\n",
123
+ " nn.Conv2d(64, 128, 4, stride=2, padding=1), # -> 32x32\n",
124
+ " nn.ReLU(),\n",
125
+ " nn.Conv2d(128, 256, 4, stride=2, padding=1), # -> 16x16\n",
126
+ " nn.ReLU(),\n",
127
+ " nn.Conv2d(256, latent_dim, 4, stride=2, padding=1), # -> 8x8\n",
128
+ " )\n",
129
+ "\n",
130
+ " def forward(self, x):\n",
131
+ " return self.net(x)\n",
132
+ "\n",
133
+ "\n",
134
+ "class VectorQuantizer(nn.Module):\n",
135
+ " def __init__(self, codebook_size=CODEBOOK_SIZE, codebook_dim=CODEBOOK_DIM, commitment_cost=0.25):\n",
136
+ " super().__init__()\n",
137
+ " self.codebook_size = codebook_size\n",
138
+ " self.codebook_dim = codebook_dim\n",
139
+ " self.commitment_cost = commitment_cost\n",
140
+ " self.codebook = nn.Embedding(codebook_size, codebook_dim)\n",
141
+ " self.codebook.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)\n",
142
+ "\n",
143
+ " def forward(self, z):\n",
144
+ " B, H, W, C = z.shape\n",
145
+ " z_flat = z.reshape(-1, C)\n",
146
+ " dist = (z_flat.unsqueeze(1) - self.codebook.weight.unsqueeze(0)).pow(2).sum(-1)\n",
147
+ " indices = dist.argmin(dim=1)\n",
148
+ " z_q = self.codebook(indices).reshape(B, H, W, C)\n",
149
+ " commitment_loss = F.mse_loss(z_flat, z_q.reshape(-1, C).detach())\n",
150
+ " codebook_loss = F.mse_loss(z_q.reshape(-1, C), z_flat.detach())\n",
151
+ " loss = codebook_loss + self.commitment_cost * commitment_loss\n",
152
+ " z_q_st = z + (z_q - z).detach()\n",
153
+ " return z_q_st, loss, indices.reshape(B, H, W)\n",
154
+ "\n",
155
+ "\n",
156
+ "class Decoder(nn.Module):\n",
157
+ " def __init__(self, out_channels=3, latent_dim=LATENT_DIM):\n",
158
+ " super().__init__()\n",
159
+ " self.net = nn.Sequential(\n",
160
+ " nn.ConvTranspose2d(latent_dim, 256, 4, stride=2, padding=1), # -> 16x16\n",
161
+ " nn.ReLU(),\n",
162
+ " nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), # -> 32x32\n",
163
+ " nn.ReLU(),\n",
164
+ " nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), # -> 64x64\n",
165
+ " nn.ReLU(),\n",
166
+ " nn.ConvTranspose2d(64, out_channels, 4, stride=2, padding=1), # -> 128x128\n",
167
+ " nn.Sigmoid(),\n",
168
+ " )\n",
169
+ "\n",
170
+ " def forward(self, x):\n",
171
+ " return self.net(x)\n",
172
+ "\n",
173
+ "\n",
174
+ "class VQVAE(nn.Module):\n",
175
+ " def __init__(self):\n",
176
+ " super().__init__()\n",
177
+ " self.encoder = Encoder()\n",
178
+ " self.quantizer = VectorQuantizer()\n",
179
+ " self.proj_in = nn.Linear(LATENT_DIM, CODEBOOK_DIM)\n",
180
+ " self.proj_out = nn.Linear(CODEBOOK_DIM, LATENT_DIM)\n",
181
+ " self.decoder = Decoder()\n",
182
+ "\n",
183
+ " def forward(self, x):\n",
184
+ " z = self.encoder(x)\n",
185
+ " z = z.permute(0, 2, 3, 1)\n",
186
+ " z = self.proj_in(z)\n",
187
+ " z_q, vq_loss, indices = self.quantizer(z)\n",
188
+ " z_q = self.proj_out(z_q)\n",
189
+ " z_q = z_q.permute(0, 3, 1, 2)\n",
190
+ " recon = self.decoder(z_q)\n",
191
+ " return recon, vq_loss, indices\n",
192
+ "\n",
193
+ " def encode(self, x):\n",
194
+ " z = self.encoder(x)\n",
195
+ " z = z.permute(0, 2, 3, 1)\n",
196
+ " z = self.proj_in(z)\n",
197
+ " _, _, indices = self.quantizer(z)\n",
198
+ " return indices\n",
199
+ "\n",
200
+ " def decode_tokens(self, token_ids, grid_h=8, grid_w=8):\n",
201
+ " if isinstance(token_ids, list):\n",
202
+ " token_ids = torch.tensor(token_ids, dtype=torch.long)\n",
203
+ " token_ids = token_ids[:grid_h * grid_w]\n",
204
+ " if len(token_ids) < grid_h * grid_w:\n",
205
+ " token_ids = torch.cat([token_ids, torch.zeros(grid_h * grid_w - len(token_ids), dtype=torch.long)])\n",
206
+ " z_q = self.quantizer.codebook(token_ids)\n",
207
+ " z_q = self.proj_out(z_q)\n",
208
+ " z_q = z_q.reshape(1, grid_h, grid_w, -1).permute(0, 3, 1, 2)\n",
209
+ " return self.decoder(z_q)\n",
210
+ "\n",
211
+ "# Test\n",
212
+ "vq_vae = VQVAE().cuda()\n",
213
+ "test_input = torch.randn(2, 3, 128, 128).cuda()\n",
214
+ "recon, vq_loss, indices = vq_vae(test_input)\n",
215
+ "print(f\"VQ-VAE test: input {test_input.shape} -> recon {recon.shape}, indices {indices.shape}, loss {vq_loss.item():.4f}\")\n",
216
+ "n_params = sum(p.numel() for p in vq_vae.parameters()) / 1e6\n",
217
+ "print(f\"Parameters: {n_params:.1f}M\")\n",
218
+ "del vq_vae, test_input\n",
219
+ "torch.cuda.empty_cache()"
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "markdown",
224
+ "metadata": {},
225
+ "source": [
226
+ "## 🖼️ Phase 1: Train VQ-VAE on Real Images"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "code",
231
+ "execution_count": null,
232
+ "metadata": {},
233
+ "outputs": [],
234
+ "source": [
235
+ "# @title 4. Phase 1: Train VQ-VAE\n",
236
+ "from datasets import load_dataset\n",
237
+ "from torchvision import transforms\n",
238
+ "from torch.utils.data import DataLoader, IterableDataset\n",
239
+ "import time\n",
240
+ "\n",
241
+ "# Check if trained VQ-VAE already exists on HF\n",
242
+ "VQ_VAE_ALREADY_TRAINED = False # @param {type:\"boolean\"}\n",
243
+ "VQ_VAE_EPOCHS = 5 # @param {type:\"integer\"}\n",
244
+ "VQ_VAE_LR = 3e-4 # @param {type:\"number\"}\n",
245
+ "VQ_VAE_BATCH = 32 # @param {type:\"integer\"}\n",
246
+ "VQ_VAE_MAX_IMAGES = 20000 # @param {type:\"integer\"}\n",
247
+ "VQ_VAE_IMG_SIZE = 128 # @param {type:\"integer\"}\n",
248
+ "\n",
249
+ "if VQ_VAE_ALREADY_TRAINED:\n",
250
+ " print(\"Skipping VQ-VAE training (already trained)\")\n",
251
+ " vq_vae = VQVAE()\n",
252
+ " # Download from HF if available\n",
253
+ " try:\n",
254
+ " from huggingface_hub import hf_hub_download\n",
255
+ " vq_path = hf_hub_download(REPO_ID, \"vq_vae_final.pt\", repo_type=\"model\")\n",
256
+ " vq_vae.load_state_dict(torch.load(vq_path, map_location=\"cuda\", weights_only=False))\n",
257
+ " print(f\"Loaded VQ-VAE from {REPO_ID}\")\n",
258
+ " except:\n",
259
+ " print(\"Could not download VQ-VAE, training from scratch\")\n",
260
+ " VQ_VAE_ALREADY_TRAINED = False\n",
261
+ "\n",
262
+ "if not VQ_VAE_ALREADY_TRAINED:\n",
263
+ " # Load dataset\n",
264
+ " print(\"Loading image dataset...\")\n",
265
+ " ds = None\n",
266
+ " image_key = \"image\"\n",
267
+ " cap_key = None\n",
268
+ " ds_name = \"\"\n",
269
+ "\n",
270
+ " for name, split, ik, ck in [\n",
271
+ " (\"detection-datasets/coco\", \"train\", \"image\", \"caption\"),\n",
272
+ " (\"frgfm/imagenette\", \"train\", \"image\", \"label\"),\n",
273
+ " (\"cifar10\", \"train\", \"img\", \"label\"),\n",
274
+ " ]:\n",
275
+ " try:\n",
276
+ " print(f\" Trying {name}...\")\n",
277
+ " ds = load_dataset(name, split=split, streaming=True, trust_remote_code=True)\n",
278
+ " test_item = next(iter(ds))\n",
279
+ " if ik in test_item:\n",
280
+ " image_key = ik\n",
281
+ " cap_key = ck if ck in test_item else None\n",
282
+ " ds_name = name\n",
283
+ " print(f\" Using {name}!\")\n",
284
+ " break\n",
285
+ " ds = None\n",
286
+ " except Exception as e:\n",
287
+ " print(f\" Failed: {str(e)[:80]}\")\n",
288
+ " ds = None\n",
289
+ "\n",
290
+ " if ds is None:\n",
291
+ " raise RuntimeError(\"No dataset available!\")\n",
292
+ "\n",
293
+ " # Transforms\n",
294
+ " transform = transforms.Compose([\n",
295
+ " transforms.Resize((VQ_VAE_IMG_SIZE, VQ_VAE_IMG_SIZE)),\n",
296
+ " transforms.ToTensor(),\n",
297
+ " ])\n",
298
+ "\n",
299
+ " class ImageStreamDataset(IterableDataset):\n",
300
+ " def __init__(self, hf_ds, transform, img_key, max_samples):\n",
301
+ " self.ds = hf_ds\n",
302
+ " self.transform = transform\n",
303
+ " self.img_key = img_key\n",
304
+ " self.max = max_samples\n",
305
+ "\n",
306
+ " def __iter__(self):\n",
307
+ " count = 0\n",
308
+ " for item in self.ds:\n",
309
+ " if count >= self.max:\n",
310
+ " break\n",
311
+ " try:\n",
312
+ " img = item[self.img_key]\n",
313
+ " if img.mode != \"RGB\":\n",
314
+ " img = img.convert(\"RGB\")\n",
315
+ " yield self.transform(img)\n",
316
+ " count += 1\n",
317
+ " except:\n",
318
+ " continue\n",
319
+ "\n",
320
+ " dataset = ImageStreamDataset(ds, transform, image_key, VQ_VAE_MAX_IMAGES)\n",
321
+ " dataloader = DataLoader(dataset, batch_size=VQ_VAE_BATCH, num_workers=2, pin_memory=True)\n",
322
+ "\n",
323
+ " # Initialize model\n",
324
+ " vq_vae = VQVAE().cuda()\n",
325
+ " optimizer = torch.optim.Adam(vq_vae.parameters(), lr=VQ_VAE_LR)\n",
326
+ " scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=VQ_VAE_EPOCHS)\n",
327
+ "\n",
328
+ " # Training loop\n",
329
+ " print(f\"\\nTraining VQ-VAE: {VQ_VAE_EPOCHS} epochs, {VQ_VAE_MAX_IMAGES} images, batch {VQ_VAE_BATCH}\")\n",
330
+ " vq_vae.train()\n",
331
+ " best_loss = float('inf')\n",
332
+ "\n",
333
+ " for epoch in range(VQ_VAE_EPOCHS):\n",
334
+ " epoch_loss = 0.0\n",
335
+ " epoch_recon = 0.0\n",
336
+ " epoch_vq = 0.0\n",
337
+ " n_batches = 0\n",
338
+ " start = time.time()\n",
339
+ "\n",
340
+ " for batch_idx, batch in enumerate(dataloader):\n",
341
+ " batch = batch.cuda()\n",
342
+ " recon, vq_loss, _ = vq_vae(batch)\n",
343
+ " recon_loss = F.mse_loss(recon, batch)\n",
344
+ " loss = recon_loss + vq_loss\n",
345
+ "\n",
346
+ " optimizer.zero_grad()\n",
347
+ " loss.backward()\n",
348
+ " torch.nn.utils.clip_grad_norm_(vq_vae.parameters(), 1.0)\n",
349
+ " optimizer.step()\n",
350
+ "\n",
351
+ " epoch_loss += loss.item()\n",
352
+ " epoch_recon += recon_loss.item()\n",
353
+ " epoch_vq += vq_loss.item()\n",
354
+ " n_batches += 1\n",
355
+ "\n",
356
+ " if batch_idx % 100 == 0 and batch_idx > 0:\n",
357
+ " avg = epoch_loss / n_batches\n",
358
+ " print(f\" Epoch {epoch+1}/{VQ_VAE_EPOCHS} | Batch {batch_idx} | Loss: {avg:.4f} (recon: {epoch_recon/n_batches:.4f}, vq: {epoch_vq/n_batches:.4f})\")\n",
359
+ "\n",
360
+ " scheduler.step()\n",
361
+ " elapsed = time.time() - start\n",
362
+ " avg_loss = epoch_loss / max(n_batches, 1)\n",
363
+ " print(f\"\\n Epoch {epoch+1} done. Loss: {avg_loss:.4f} | Batches: {n_batches} | Time: {elapsed:.0f}s\")\n",
364
+ "\n",
365
+ " # Save best model & push to HF\n",
366
+ " if avg_loss < best_loss:\n",
367
+ " best_loss = avg_loss\n",
368
+ " torch.save(vq_vae.state_dict(), \"vq_vae_best.pt\")\n",
369
+ " print(f\" New best model! Loss: {avg_loss:.4f}\")\n",
370
+ "\n",
371
+ " # Push VQ-VAE checkpoint to HF after each epoch\n",
372
+ " torch.save(vq_vae.state_dict(), \"vq_vae_final.pt\")\n",
373
+ " try:\n",
374
+ " api.upload_file(\n",
375
+ " path_or_fileobj=\"vq_vae_final.pt\",\n",
376
+ " path_in_repo=\"vq_vae_final.pt\",\n",
377
+ " repo_id=REPO_ID,\n",
378
+ " repo_type=\"model\",\n",
379
+ " commit_message=f\"VQ-VAE epoch {epoch+1}, loss {avg_loss:.4f}\"\n",
380
+ " )\n",
381
+ " print(f\" Pushed VQ-VAE checkpoint to HF!\")\n",
382
+ " except Exception as e:\n",
383
+ " print(f\" Push failed: {e}\")\n",
384
+ "\n",
385
+ " print(f\"\\nVQ-VAE training complete! Best loss: {best_loss:.4f}\")\n",
386
+ " vq_vae.eval()"
387
+ ]
388
+ },
389
+ {
390
+ "cell_type": "markdown",
391
+ "metadata": {},
392
+ "source": [
393
+ "## 🔢 Phase 2: Tokenize Dataset"
394
+ ]
395
+ },
396
+ {
397
+ "cell_type": "code",
398
+ "execution_count": null,
399
+ "metadata": {},
400
+ "outputs": [],
401
+ "source": [
402
+ "# @title 5. Phase 2: Tokenize Image-Text Pairs\n",
403
+ "import json\n",
404
+ "import numpy as np\n",
405
+ "from PIL import Image\n",
406
+ "\n",
407
+ "NUM_TOKENIZE = 50000 # @param {type:\"integer\"}\n",
408
+ "TOKENS_PER_SAMPLE = 64 # 8x8 grid\n",
409
+ "\n",
410
+ "# Caption helpers\n",
411
+ "IMAGENETTE_CLASSES = {\n",
412
+ " 0: \"a fish in water\", 1: \"a dog running in a field\", 2: \"a cassette player on a table\",\n",
413
+ " 3: \"a chainsaw cutting wood\", 4: \"a church with a tall steeple\", 5: \"a French horn on stage\",\n",
414
+ " 6: \"a garbage truck on the street\", 7: \"a gas station at night\", 8: \"a golf ball on a green\",\n",
415
+ " 9: \"a parachute in the sky\",\n",
416
+ "}\n",
417
+ "CIFAR10_CLASSES = [\"airplane flying\", \"automobile on road\", \"bird in tree\", \"cat sitting\",\n",
418
+ " \"deer in forest\", \"dog playing\", \"frog on lily pad\", \"horse running\",\n",
419
+ " \"ship on ocean\", \"truck driving\"]\n",
420
+ "\n",
421
+ "def get_caption(item, cap_key, ds_name, idx):\n",
422
+ " if cap_key and cap_key in item and item[cap_key] is not None:\n",
423
+ " cap = item[cap_key]\n",
424
+ " if isinstance(cap, list):\n",
425
+ " return cap[0] if cap else f\"image {idx}\"\n",
426
+ " elif isinstance(cap, str):\n",
427
+ " return cap\n",
428
+ " elif isinstance(cap, int):\n",
429
+ " if \"imagenette\" in ds_name.lower():\n",
430
+ " return IMAGENETTE_CLASSES.get(cap, f\"photo of object {cap}\")\n",
431
+ " elif \"cifar\" in ds_name.lower():\n",
432
+ " return CIFAR10_CLASSES[cap] if cap < len(CIFAR10_CLASSES) else f\"photo of class {cap}\"\n",
433
+ " return f\"photo of a {cap}\"\n",
434
+ " return f\"image {idx}\"\n",
435
+ "\n",
436
+ "# Load dataset for tokenization (re-load to get fresh stream)\n",
437
+ "print(\"Loading dataset for tokenization...\")\n",
438
+ "ds = None\n",
439
+ "image_key = \"image\"\n",
440
+ "cap_key = None\n",
441
+ "ds_name = \"\"\n",
442
+ "\n",
443
+ "for name, split, ik, ck in [\n",
444
+ " (\"detection-datasets/coco\", \"train\", \"image\", \"caption\"),\n",
445
+ " (\"frgfm/imagenette\", \"train\", \"image\", \"label\"),\n",
446
+ " (\"cifar10\", \"train\", \"img\", \"label\"),\n",
447
+ "]:\n",
448
+ " try:\n",
449
+ " ds = load_dataset(name, split=split, streaming=True, trust_remote_code=True)\n",
450
+ " test_item = next(iter(ds))\n",
451
+ " if ik in test_item:\n",
452
+ " image_key = ik\n",
453
+ " cap_key = ck if ck in test_item else None\n",
454
+ " ds_name = name\n",
455
+ " print(f\"Using {name}\")\n",
456
+ " break\n",
457
+ " ds = None\n",
458
+ " except:\n",
459
+ " ds = None\n",
460
+ "\n",
461
+ "if ds is None:\n",
462
+ " raise RuntimeError(\"No dataset!\")\n",
463
+ "\n",
464
+ "transform = transforms.Compose([\n",
465
+ " transforms.Resize((VQ_VAE_IMG_SIZE, VQ_VAE_IMG_SIZE)),\n",
466
+ " transforms.ToTensor(),\n",
467
+ "])\n",
468
+ "\n",
469
+ "vq_vae.eval()\n",
470
+ "tokenized_data = []\n",
471
+ "count = 0\n",
472
+ "errors = 0\n",
473
+ "\n",
474
+ "print(f\"Tokenizing {NUM_TOKENIZE} images...\")\n",
475
+ "for item in ds:\n",
476
+ " if count >= NUM_TOKENIZE:\n",
477
+ " break\n",
478
+ " try:\n",
479
+ " img = item[image_key]\n",
480
+ " if img.mode != \"RGB\":\n",
481
+ " img = img.convert(\"RGB\")\n",
482
+ " caption = get_caption(item, cap_key, ds_name, count)\n",
483
+ "\n",
484
+ " img_tensor = transform(img).unsqueeze(0).cuda()\n",
485
+ " with torch.no_grad():\n",
486
+ " tokens = vq_vae.encode(img_tensor)\n",
487
+ " flat_tokens = tokens.flatten().tolist()\n",
488
+ "\n",
489
+ " flat_tokens = flat_tokens[:TOKENS_PER_SAMPLE]\n",
490
+ " while len(flat_tokens) < TOKENS_PER_SAMPLE:\n",
491
+ " flat_tokens.append(0)\n",
492
+ "\n",
493
+ " tokenized_data.append({\n",
494
+ " \"text_prompt\": caption,\n",
495
+ " \"video_tokens\": flat_tokens,\n",
496
+ " })\n",
497
+ " count += 1\n",
498
+ "\n",
499
+ " if count % 2000 == 0:\n",
500
+ " print(f\" Tokenized {count}/{NUM_TOKENIZE} (errors: {errors})\")\n",
501
+ " # Save checkpoint\n",
502
+ " with open(\"tokenized_dataset.json\", \"w\") as f:\n",
503
+ " json.dump(tokenized_data, f)\n",
504
+ " # Push to HF\n",
505
+ " try:\n",
506
+ " api.upload_file(\n",
507
+ " path_or_fileobj=\"tokenized_dataset.json\",\n",
508
+ " path_in_repo=\"tokenized_dataset.json\",\n",
509
+ " repo_id=REPO_ID,\n",
510
+ " repo_type=\"model\",\n",
511
+ " commit_message=f\"Tokenized {count} samples\"\n",
512
+ " )\n",
513
+ " except:\n",
514
+ " pass\n",
515
+ "\n",
516
+ " del img_tensor\n",
517
+ " if count % 500 == 0:\n",
518
+ " torch.cuda.empty_cache()\n",
519
+ "\n",
520
+ " except Exception as e:\n",
521
+ " errors += 1\n",
522
+ " if errors <= 3:\n",
523
+ " print(f\" Error: {str(e)[:60]}\")\n",
524
+ " continue\n",
525
+ "\n",
526
+ "# Final save & push\n",
527
+ "with open(\"tokenized_dataset.json\", \"w\") as f:\n",
528
+ " json.dump(tokenized_data, f)\n",
529
+ "\n",
530
+ "api.upload_file(\n",
531
+ " path_or_fileobj=\"tokenized_dataset.json\",\n",
532
+ " path_in_repo=\"tokenized_dataset.json\",\n",
533
+ " repo_id=REPO_ID,\n",
534
+ " repo_type=\"model\",\n",
535
+ " commit_message=f\"Tokenized {len(tokenized_data)} samples (complete)\"\n",
536
+ ")\n",
537
+ "\n",
538
+ "print(f\"\\nTokenization complete: {len(tokenized_data)} samples ({errors} errors)\")\n",
539
+ "print(f\"Sample: '{tokenized_data[0]['text_prompt']}' -> {tokenized_data[0]['video_tokens'][:10]}\")\n",
540
+ "print(f\"Unique tokens in sample: {len(set(tokenized_data[0]['video_tokens']))}\")"
541
+ ]
542
+ },
543
+ {
544
+ "cell_type": "markdown",
545
+ "metadata": {},
546
+ "source": [
547
+ "## 🚀 Phase 3: Fine-tune LLM with LoRA (GPU + Incremental Push)"
548
+ ]
549
+ },
550
+ {
551
+ "cell_type": "code",
552
+ "execution_count": null,
553
+ "metadata": {},
554
+ "outputs": [],
555
+ "source": [
556
+ "# @title 6. Phase 3: Setup LLM + LoRA with HuggingFace Trainer\n",
557
+ "from transformers import (\n",
558
+ " AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer,\n",
559
+ " DataCollatorForLanguageModeling, TrainerCallback\n",
560
+ ")\n",
561
+ "from peft import LoraConfig, get_peft_model, TaskType\n",
562
+ "from torch.utils.data import Dataset\n",
563
+ "\n",
564
+ "# Hyperparameters\n",
565
+ "LORA_R = 8 # @param {type:\"integer\"}\n",
566
+ "LORA_ALPHA = 16 # @param {type:\"integer\"}\n",
567
+ "LORA_DROPOUT = 0.05 # @param {type:\"number\"}\n",
568
+ "LEARNING_RATE = 2e-4 # @param {type:\"number\"}\n",
569
+ "BATCH_SIZE = 2 # @param {type:\"integer\"}\n",
570
+ "GRADIENT_ACCUMULATION = 8 # @param {type:\"integer\"}\n",
571
+ "NUM_EPOCHS = 3 # @param {type:\"integer\"}\n",
572
+ "MAX_SEQ_LEN = 256 # @param {type:\"integer\"}\n",
573
+ "WARMUP_RATIO = 0.03 # @param {type:\"number\"}\n",
574
+ "WEIGHT_DECAY = 0.01 # @param {type:\"number\"}\n",
575
+ "SAVE_STEPS = 200 # @param {type:\"integer\"}\n",
576
+ "EVAL_STEPS = 200 # @param {type:\"integer\"}\n",
577
+ "FP16 = True # @param {type:\"boolean\"}\n",
578
+ "TRAIN_ON_ALL_DATA = False # @param {type:\"boolean\"}\n",
579
+ "LLM_TRAIN_SAMPLES = 10000 # @param {type:\"integer\"}\n",
580
+ "\n",
581
+ "MODEL_NAME = \"allenai/OLMo-2-0425-1B-Instruct\"\n",
582
+ "VIDEO_START = \"<video_start>\"\n",
583
+ "VIDEO_END = \"<video_end>\"\n",
584
+ "VIDEO_PAD = \"<video_pad>\"\n",
585
+ "\n",
586
+ "# Load tokenizer\n",
587
+ "print(\"Loading tokenizer...\")\n",
588
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
589
+ "if tokenizer.pad_token is None:\n",
590
+ " tokenizer.pad_token = tokenizer.eos_token\n",
591
+ "orig_vocab = len(tokenizer)\n",
592
+ "print(f\"Original vocab: {orig_vocab}\")\n",
593
+ "\n",
594
+ "# Expand vocab with visual tokens\n",
595
+ "visual_tokens = [VIDEO_START, VIDEO_END, VIDEO_PAD]\n",
596
+ "for i in range(CODEBOOK_SIZE):\n",
597
+ " visual_tokens.append(f\"<v_{i}>\")\n",
598
+ "tokenizer.add_tokens(visual_tokens)\n",
599
+ "print(f\"Expanded vocab: {len(tokenizer)} (+{len(tokenizer) - orig_vocab} visual tokens)\")\n",
600
+ "\n",
601
+ "# Load model\n",
602
+ "print(\"Loading model...\")\n",
603
+ "dtype = torch.float16 if FP16 else torch.float32\n",
604
+ "model = AutoModelForCausalLM.from_pretrained(\n",
605
+ " MODEL_NAME, trust_remote_code=True, torch_dtype=dtype\n",
606
+ ")\n",
607
+ "model.resize_token_embeddings(len(tokenizer))\n",
608
+ "print(f\"Model loaded: {MODEL_NAME}\")\n",
609
+ "\n",
610
+ "# Apply LoRA\n",
611
+ "print(f\"Applying LoRA (r={LORA_R})...\")\n",
612
+ "lora_config = LoraConfig(\n",
613
+ " r=LORA_R,\n",
614
+ " lora_alpha=LORA_ALPHA,\n",
615
+ " target_modules=[\"q_proj\", \"v_proj\", \"k_proj\", \"o_proj\"], # More modules than before!\n",
616
+ " lora_dropout=LORA_DROPOUT,\n",
617
+ " bias=\"none\",\n",
618
+ " task_type=TaskType.CAUSAL_LM,\n",
619
+ ")\n",
620
+ "model = get_peft_model(model, lora_config)\n",
621
+ "trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
622
+ "total = sum(p.numel() for p in model.parameters())\n",
623
+ "print(f\"LoRA: {trainable:,} / {total:,} trainable ({100*trainable/total:.2f}%)\")\n",
624
+ "model.print_trainable_parameters()"
625
+ ]
626
+ },
627
+ {
628
+ "cell_type": "code",
629
+ "execution_count": null,
630
+ "metadata": {},
631
+ "outputs": [],
632
+ "source": [
633
+ "# @title 7. Create Training Dataset\n",
634
+ "class VideoTokenDataset(Dataset):\n",
635
+ " def __init__(self, data, tokenizer, max_tokens=64, max_len=256):\n",
636
+ " self.data = data\n",
637
+ " self.tokenizer = tokenizer\n",
638
+ " self.max_tokens = max_tokens\n",
639
+ " self.max_len = max_len\n",
640
+ "\n",
641
+ " def __len__(self):\n",
642
+ " return len(self.data)\n",
643
+ "\n",
644
+ " def __getitem__(self, idx):\n",
645
+ " item = self.data[idx]\n",
646
+ " prompt = item[\"text_prompt\"]\n",
647
+ " tokens = item[\"video_tokens\"][:self.max_tokens]\n",
648
+ " while len(tokens) < self.max_tokens:\n",
649
+ " tokens.append(0)\n",
650
+ " token_str = \" \".join(f\"<v_{t}>\" for t in tokens)\n",
651
+ " text = f\"Create a video of: {prompt} {VIDEO_START} {token_str} {VIDEO_END}\"\n",
652
+ "\n",
653
+ " encoding = self.tokenizer(\n",
654
+ " text, return_tensors=\"pt\", truncation=True,\n",
655
+ " max_length=self.max_len, padding=\"max_length\"\n",
656
+ " )\n",
657
+ " input_ids = encoding[\"input_ids\"].squeeze()\n",
658
+ " attention_mask = encoding[\"attention_mask\"].squeeze()\n",
659
+ " labels = input_ids.clone()\n",
660
+ " # Don't compute loss on padding\n",
661
+ " labels[labels == self.tokenizer.pad_token_id] = -100\n",
662
+ "\n",
663
+ " return {\n",
664
+ " \"input_ids\": input_ids,\n",
665
+ " \"attention_mask\": attention_mask,\n",
666
+ " \"labels\": labels,\n",
667
+ " }\n",
668
+ "\n",
669
+ "# Load data\n",
670
+ "with open(\"tokenized_dataset.json\") as f:\n",
671
+ " all_data = json.load(f)\n",
672
+ "\n",
673
+ "if not TRAIN_ON_ALL_DATA:\n",
674
+ " all_data = all_data[:LLM_TRAIN_SAMPLES]\n",
675
+ "\n",
676
+ "print(f\"Training on {len(all_data)} samples\")\n",
677
+ "\n",
678
+ "# Split into train/eval\n",
679
+ "split_idx = int(len(all_data) * 0.95)\n",
680
+ "train_data = all_data[:split_idx]\n",
681
+ "eval_data = all_data[split_idx:]\n",
682
+ "\n",
683
+ "train_dataset = VideoTokenDataset(train_data, tokenizer)\n",
684
+ "eval_dataset = VideoTokenDataset(eval_data, tokenizer)\n",
685
+ "\n",
686
+ "print(f\"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}\")\n",
687
+ "\n",
688
+ "# Test one sample\n",
689
+ "sample = train_dataset[0]\n",
690
+ "decoded = tokenizer.decode(sample[\"input_ids\"][:80], skip_special_tokens=False)\n",
691
+ "print(f\"Sample: {decoded[:200]}...\")"
692
+ ]
693
+ },
694
+ {
695
+ "cell_type": "code",
696
+ "execution_count": null,
697
+ "metadata": {},
698
+ "outputs": [],
699
+ "source": [
700
+ "# @title 8. Configure HuggingFace Trainer with Incremental Push\n",
701
+ "\n",
702
+ "# Training arguments with push_to_hub for incremental checkpoint saves\n",
703
+ "training_args = TrainingArguments(\n",
704
+ " output_dir=\"./zeeb-checkpoints\",\n",
705
+ " \n",
706
+ " # Training params\n",
707
+ " num_train_epochs=NUM_EPOCHS,\n",
708
+ " per_device_train_batch_size=BATCH_SIZE,\n",
709
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
710
+ " gradient_accumulation_steps=GRADIENT_ACCUMULATION,\n",
711
+ " learning_rate=LEARNING_RATE,\n",
712
+ " weight_decay=WEIGHT_DECAY,\n",
713
+ " warmup_ratio=WARMUP_RATIO,\n",
714
+ " lr_scheduler_type=\"cosine\",\n",
715
+ " max_grad_norm=1.0,\n",
716
+ " \n",
717
+ " # Precision\n",
718
+ " fp16=FP16,\n",
719
+ " bf16=False,\n",
720
+ " \n",
721
+ " # Logging\n",
722
+ " logging_steps=10,\n",
723
+ " logging_first_step=True,\n",
724
+ " \n",
725
+ " # Saving - INCREMENTAL PUSH TO HF\n",
726
+ " save_strategy=\"steps\",\n",
727
+ " save_steps=SAVE_STEPS,\n",
728
+ " save_total_limit=3, # Keep only 3 checkpoints on disk\n",
729
+ " \n",
730
+ " # Evaluation\n",
731
+ " eval_strategy=\"steps\",\n",
732
+ " eval_steps=EVAL_STEPS,\n",
733
+ " \n",
734
+ " # INCREMENTAL PUSH TO HUGGINGFACE\n",
735
+ " push_to_hub=True,\n",
736
+ " hub_model_id=REPO_ID,\n",
737
+ " hub_token=HF_TOKEN,\n",
738
+ " hub_strategy=\"every_save\", # Push every time we save a checkpoint!\n",
739
+ " \n",
740
+ " # Resume from checkpoint\n",
741
+ " resume_from_checkpoint=True,\n",
742
+ " \n",
743
+ " # Performance\n",
744
+ " dataloader_num_workers=2,\n",
745
+ " dataloader_pin_memory=True,\n",
746
+ " gradient_checkpointing=True, # Save memory\n",
747
+ " optim=\"adamw_torch\",\n",
748
+ " \n",
749
+ " # Misc\n",
750
+ " remove_unused_columns=False,\n",
751
+ " report_to=\"none\", # Disable wandb/tensorboard\n",
752
+ " run_name=\"zeeb-video-llm\",\n",
753
+ ")\n",
754
+ "\n",
755
+ "print(\"Training Arguments:\")\n",
756
+ "print(f\" Epochs: {NUM_EPOCHS}\")\n",
757
+ "print(f\" Batch: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} accumulation = effective {BATCH_SIZE * GRADIENT_ACCUMULATION}\")\n",
758
+ "print(f\" LR: {LEARNING_RATE}, Scheduler: cosine\")\n",
759
+ "print(f\" FP16: {FP16}\")\n",
760
+ "print(f\" Save every {SAVE_STEPS} steps → push to HF\")\n",
761
+ "print(f\" Push to: {REPO_ID}\")\n",
762
+ "print(f\" Hub strategy: every_save (incremental push)\")\n",
763
+ "print(f\" Gradient checkpointing: True\")\n",
764
+ "print(f\" Resume from checkpoint: True\")"
765
+ ]
766
+ },
767
+ {
768
+ "cell_type": "code",
769
+ "execution_count": null,
770
+ "metadata": {},
771
+ "outputs": [],
772
+ "source": [
773
+ "# @title 9. 🚀 START TRAINING! (with auto-resume)\n",
774
+ "import os\n",
775
+ "\n",
776
+ "# Check for existing checkpoints to resume from\n",
777
+ "checkpoint_dir = \"./zeeb-checkpoints\"\n",
778
+ "resume_ckpt = None\n",
779
+ "if os.path.exists(checkpoint_dir):\n",
780
+ " checkpoints = [d for d in os.listdir(checkpoint_dir) if d.startswith(\"checkpoint-\")]\n",
781
+ " if checkpoints:\n",
782
+ " latest = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))[-1]\n",
783
+ " resume_ckpt = os.path.join(checkpoint_dir, latest)\n",
784
+ " print(f\"Found checkpoint to resume from: {resume_ckpt}\")\n",
785
+ "\n",
786
+ "# Create trainer\n",
787
+ "trainer = Trainer(\n",
788
+ " model=model,\n",
789
+ " args=training_args,\n",
790
+ " train_dataset=train_dataset,\n",
791
+ " eval_dataset=eval_dataset,\n",
792
+ " data_collator=None, # Use default\n",
793
+ ")\n",
794
+ "\n",
795
+ "# Calculate total steps\n",
796
+ "total_steps = (len(train_dataset) // (BATCH_SIZE * GRADIENT_ACCUMULATION)) * NUM_EPOCHS\n",
797
+ "print(f\"\\nTotal training steps: ~{total_steps}\")\n",
798
+ "print(f\"Checkpoints will be pushed every {SAVE_STEPS} steps ({total_steps // SAVE_STEPS} pushes)\")\n",
799
+ "print(f\"\\nStarting training...\")\n",
800
+ "print(f\"If Colab disconnects, just re-run this cell — it will auto-resume!\\n\")\n",
801
+ "\n",
802
+ "# Train! (auto-resumes from checkpoint if available)\n",
803
+ "train_result = trainer.train(resume_from_checkpoint=resume_ckpt)\n",
804
+ "\n",
805
+ "print(f\"\\nTraining complete!\")\n",
806
+ "print(f\" Final loss: {train_result.training_loss:.4f}\")\n",
807
+ "print(f\" Total steps: {train_result.global_step}\")\n",
808
+ "print(f\" Training time: {train_result.metrics['train_runtime']:.0f}s ({train_result.metrics['train_runtime']/60:.1f} min)\")"
809
+ ]
810
+ },
811
+ {
812
+ "cell_type": "code",
813
+ "execution_count": null,
814
+ "metadata": {},
815
+ "outputs": [],
816
+ "source": [
817
+ "# @title 10. Merge LoRA & Push Final Model to HuggingFace\n",
818
+ "print(\"Merging LoRA weights into base model...\")\n",
819
+ "model = model.merge_and_unload()\n",
820
+ "\n",
821
+ "# Save locally\n",
822
+ "final_dir = \"./zeeb-final\"\n",
823
+ "model.save_pretrained(final_dir, safe_serialization=True)\n",
824
+ "tokenizer.save_pretrained(final_dir)\n",
825
+ "\n",
826
+ "# Copy VQ-VAE checkpoint\n",
827
+ "import shutil\n",
828
+ "if os.path.exists(\"vq_vae_final.pt\"):\n",
829
+ " shutil.copy(\"vq_vae_final.pt\", f\"{final_dir}/vq_vae_final.pt\")\n",
830
+ "if os.path.exists(\"tokenized_dataset.json\"):\n",
831
+ " shutil.copy(\"tokenized_dataset.json\", f\"{final_dir}/tokenized_dataset.json\")\n",
832
+ "\n",
833
+ "# Push final merged model to HuggingFace\n",
834
+ "print(f\"Pushing final model to {REPO_ID}...\")\n",
835
+ "model.push_to_hub(\n",
836
+ " REPO_ID,\n",
837
+ " token=HF_TOKEN,\n",
838
+ " commit_message=f\"Zeeb v2: OLMo 2 1B + LoRA (r={LORA_R}), {NUM_EPOCHS} epochs, {len(train_data)} samples, GPU-trained\"\n",
839
+ ")\n",
840
+ "tokenizer.push_to_hub(\n",
841
+ " REPO_ID,\n",
842
+ " token=HF_TOKEN,\n",
843
+ " commit_message=f\"Zeeb v2: tokenizer with visual tokens\"\n",
844
+ ")\n",
845
+ "\n",
846
+ "# Push additional files\n",
847
+ "for fname in [\"vq_vae_final.pt\", \"tokenized_dataset.json\"]:\n",
848
+ " if os.path.exists(fname):\n",
849
+ " api.upload_file(\n",
850
+ " path_or_fileobj=fname,\n",
851
+ " path_in_repo=fname,\n",
852
+ " repo_id=REPO_ID,\n",
853
+ " repo_type=\"model\",\n",
854
+ " commit_message=f\"Add {fname}\"\n",
855
+ " )\n",
856
+ "\n",
857
+ "print(f\"\\n✅ Final model pushed to https://huggingface.co/{REPO_ID}\")\n",
858
+ "print(\"This model can now be loaded in the HF Space for video generation!\")"
859
+ ]
860
+ },
861
+ {
862
+ "cell_type": "markdown",
863
+ "metadata": {},
864
+ "source": [
865
+ "## 🧪 Test: Generate a Video with the Trained Model"
866
+ ]
867
+ },
868
+ {
869
+ "cell_type": "code",
870
+ "execution_count": null,
871
+ "metadata": {},
872
+ "outputs": [],
873
+ "source": [
874
+ "# @title 11. Test Video Generation\n",
875
+ "import numpy as np\n",
876
+ "from PIL import Image\n",
877
+ "import imageio\n",
878
+ "\n",
879
+ "PROMPT = \"A cat jumping on a sofa\" # @param {type:\"string\"}\n",
880
+ "MAX_TOKENS = 64 # @param {type:\"integer\"}\n",
881
+ "TEMPERATURE = 0.9 # @param {type:\"number\"}\n",
882
+ "TOP_K = 50 # @param {type:\"integer\"}\n",
883
+ "\n",
884
+ "# Get visual token IDs\n",
885
+ "VIDEO_START_ID = tokenizer.convert_tokens_to_ids(\"<video_start>\")\n",
886
+ "VIDEO_END_ID = tokenizer.convert_tokens_to_ids(\"<video_end>\")\n",
887
+ "V_TOKEN_START_ID = tokenizer.convert_tokens_to_ids(\"<v_0>\")\n",
888
+ "V_TOKEN_END_ID = tokenizer.convert_tokens_to_ids(\"<v_1023>\")\n",
889
+ "\n",
890
+ "# Load VQ-VAE for decoding\n",
891
+ "vq_vae = VQVAE().cuda()\n",
892
+ "if os.path.exists(\"vq_vae_final.pt\"):\n",
893
+ " vq_vae.load_state_dict(torch.load(\"vq_vae_final.pt\", map_location=\"cuda\", weights_only=False))\n",
894
+ " print(\"Loaded trained VQ-VAE\")\n",
895
+ "vq_vae.eval()\n",
896
+ "\n",
897
+ "# Generate with constrained decoding\n",
898
+ "text = f\"Create a video of: {PROMPT} <video_start>\"\n",
899
+ "inputs = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=256)\n",
900
+ "current_ids = inputs[\"input_ids\"].cuda()\n",
901
+ "\n",
902
+ "vocab_size = len(tokenizer)\n",
903
+ "visual_mask = torch.zeros(vocab_size, dtype=torch.bool)\n",
904
+ "visual_mask[V_TOKEN_START_ID:V_TOKEN_END_ID + 1] = True\n",
905
+ "visual_mask[VIDEO_END_ID] = True\n",
906
+ "\n",
907
+ "visual_token_ids = []\n",
908
+ "model.eval()\n",
909
+ "\n",
910
+ "print(f\"Generating visual tokens for: '{PROMPT}'\")\n",
911
+ "with torch.no_grad():\n",
912
+ " for step in range(MAX_TOKENS):\n",
913
+ " outputs = model(input_ids=current_ids)\n",
914
+ " logits = outputs.logits[:, -1, :]\n",
915
+ " masked = logits.clone()\n",
916
+ " masked[0, ~visual_mask] = float('-inf')\n",
917
+ " masked = masked / max(TEMPERATURE, 0.01)\n",
918
+ " if TOP_K > 0:\n",
919
+ " top_k_values, _ = torch.topk(masked[0], min(TOP_K, masked.size(-1)))\n",
920
+ " threshold = top_k_values[-1]\n",
921
+ " masked[0, masked[0] < threshold] = float('-inf')\n",
922
+ " probs = F.softmax(masked, dim=-1)\n",
923
+ " next_token = torch.multinomial(probs, num_samples=1)\n",
924
+ " next_id = next_token.item()\n",
925
+ " if next_id == VIDEO_END_ID:\n",
926
+ " break\n",
927
+ " visual_idx = next_id - V_TOKEN_START_ID\n",
928
+ " visual_token_ids.append(visual_idx)\n",
929
+ " current_ids = torch.cat([current_ids, next_token], dim=-1)\n",
930
+ "\n",
931
+ "print(f\"Generated {len(visual_token_ids)} visual tokens ({len(set(visual_token_ids))} unique)\")\n",
932
+ "\n",
933
+ "# Decode through VQ-VAE\n",
934
+ "grid_h, grid_w = 8, 8\n",
935
+ "tokens_per_frame = grid_h * grid_w\n",
936
+ "num_frames = max(1, len(visual_token_ids) // tokens_per_frame)\n",
937
+ "\n",
938
+ "frames = []\n",
939
+ "for fi in range(num_frames):\n",
940
+ " ft = visual_token_ids[fi*tokens_per_frame:(fi+1)*tokens_per_frame]\n",
941
+ " frame_tensor = vq_vae.decode_tokens(ft, grid_h, grid_w)\n",
942
+ " frame_np = (frame_tensor[0].permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8)\n",
943
+ " frames.append(frame_np)\n",
944
+ "\n",
945
+ "# Save video\n",
946
+ "if frames:\n",
947
+ " upscaled = [np.array(Image.fromarray(f).resize((256, 256), Image.BILINEAR)) for f in frames]\n",
948
+ " output_path = \"/content/generated_video.mp4\"\n",
949
+ " imageio.mimsave(output_path, upscaled, fps=2)\n",
950
+ " print(f\"Video saved: {output_path} ({len(upscaled)} frames, 256x256)\")\n",
951
+ " \n",
952
+ " # Display first frame\n",
953
+ " from IPython.display import display\n",
954
+ " display(Image.fromarray(upscaled[0]))\n",
955
+ "else:\n",
956
+ " print(\"No frames generated\")"
957
+ ]
958
+ },
959
+ {
960
+ "cell_type": "markdown",
961
+ "metadata": {},
962
+ "source": [
963
+ "## 📊 Summary & Next Steps\n",
964
+ "\n",
965
+ "### What was trained:\n",
966
+ "- **VQ-VAE**: 3.8M params, trained on real COCO images, maps images ↔ discrete tokens\n",
967
+ "- **OLMo 2 1B + LoRA**: 1B params (only ~1M trainable), fine-tuned to predict visual tokens from text\n",
968
+ "\n",
969
+ "### How to improve further:\n",
970
+ "1. **More data**: Use 50K+ samples instead of 10K\n",
971
+ "2. **Bigger LoRA**: Increase r from 8 to 16-32\n",
972
+ "3. **More target modules**: Add \"gate_proj\", \"up_proj\", \"down_proj\" to LoRA targets\n",
973
+ "4. **Video data**: Use OpenVid-1M with actual video frames (multiple frames per clip)\n",
974
+ "5. **Larger codebook**: 4096 or 8192 entries instead of 1024\n",
975
+ "6. **Higher resolution**: 256x256 VQ-VAE instead of 128x128\n",
976
+ "7. **Multi-frame**: Encode 4-8 frames per video, not just 1\n",
977
+ "\n",
978
+ "### Resume after Colab disconnect:\n",
979
+ "Just re-run cells 1, 2, 3, 6, 7, 8, and 9 — the Trainer will auto-resume from the last checkpoint pushed to HF!"
980
+ ]
981
+ }
982
+ ]
983
+ }