eeshaAI commited on
Commit
e53a305
Β·
verified Β·
1 Parent(s): 6e8dde1

Update train_full_pipeline.py: full training pipeline with real datasets

Browse files
Files changed (1) hide show
  1. train_full_pipeline.py +690 -0
train_full_pipeline.py ADDED
@@ -0,0 +1,690 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Full Pipeline: Train VQ-VAE β†’ Tokenize OpenVid β†’ Train LLM β†’ Push to EeshaAI/zeeb
4
+ =================================================================================
5
+ Runs on HuggingFace Spaces (free CPU tier, 16GB RAM).
6
+
7
+ Phase 1: Train VQ-VAE on COCO 2017 images (118K real images, streaming)
8
+ Phase 2: Stream 10K clips from OpenVid-1M β†’ tokenize via trained VQ-VAE β†’ save integers
9
+ Phase 3: Fine-tune OLMo 2 1B with LoRA on 10K tokenized samples
10
+ Phase 4: Push trained model to EeshaAI/zeeb
11
+ """
12
+
13
+ import os
14
+ import sys
15
+ import json
16
+ import time
17
+ import gc
18
+ import threading
19
+ import traceback
20
+ import numpy as np
21
+ from typing import Optional
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+ from torch.utils.data import DataLoader, Dataset, IterableDataset
27
+
28
+ # ============================================================================
29
+ # CONFIGURATION
30
+ # ============================================================================
31
+ HF_TOKEN = os.environ.get("HF_TOKEN", "")
32
+ REPO_ID = "eeshaAI/zeeb"
33
+ MODEL_NAME = "allenai/OLMo-2-0425-1B-Instruct"
34
+ CODEBOOK_SIZE = 1024
35
+ CODEBOOK_DIM = 256
36
+ LATENT_DIM = 256
37
+ VIDEO_START = "<video_start>"
38
+ VIDEO_END = "<video_end>"
39
+ VIDEO_PAD = "<video_pad>"
40
+
41
+ # VQ-VAE training
42
+ VQ_VAE_EPOCHS = 5
43
+ VQ_VAE_LR = 1e-3
44
+ VQ_VAE_BATCH = 32
45
+ VQ_VAE_IMG_SIZE = 128 # resize images to 128x128
46
+
47
+ # Dataset preparation
48
+ NUM_OPENVID_CLIPS = 10000
49
+ TOKENS_PER_CLIP = 128 # number of visual tokens per video clip
50
+
51
+ # LLM training
52
+ NUM_EPOCHS = 3
53
+ LORA_R = 4
54
+ LORA_ALPHA = 8
55
+ LORA_DROPOUT = 0.05
56
+ LEARNING_RATE = 1e-4
57
+ BATCH_SIZE = 1
58
+ MAX_SEQ_LEN = 384
59
+ GRADIENT_ACCUMULATION = 4
60
+
61
+ LOG_FILE = "/tmp/pipeline_log.txt"
62
+
63
+
64
+ # ============================================================================
65
+ # LOGGING
66
+ # ============================================================================
67
+ class Logger:
68
+ def __init__(self, path):
69
+ self.path = path
70
+ self.lock = threading.Lock()
71
+ with open(path, "w") as f:
72
+ f.write("πŸš€ Zeeb Full Pipeline Starting...\n\n")
73
+
74
+ def log(self, msg):
75
+ with self.lock:
76
+ with open(self.path, "a") as f:
77
+ f.write(msg)
78
+ f.flush()
79
+ print(msg, end="", flush=True)
80
+
81
+
82
+ # ============================================================================
83
+ # VQ-VAE MODEL
84
+ # ============================================================================
85
+ class Encoder(nn.Module):
86
+ def __init__(self, in_channels=3, latent_dim=LATENT_DIM):
87
+ super().__init__()
88
+ self.net = nn.Sequential(
89
+ nn.Conv2d(in_channels, 64, 4, stride=2, padding=1), # β†’ 64x64
90
+ nn.ReLU(),
91
+ nn.Conv2d(64, 128, 4, stride=2, padding=1), # β†’ 32x32
92
+ nn.ReLU(),
93
+ nn.Conv2d(128, 256, 4, stride=2, padding=1), # β†’ 16x16
94
+ nn.ReLU(),
95
+ nn.Conv2d(256, latent_dim, 4, stride=2, padding=1), # β†’ 8x8
96
+ )
97
+
98
+ def forward(self, x):
99
+ return self.net(x)
100
+
101
+
102
+ class VectorQuantizer(nn.Module):
103
+ def __init__(self, codebook_size=CODEBOOK_SIZE, codebook_dim=CODEBOOK_DIM, commitment_cost=0.25):
104
+ super().__init__()
105
+ self.codebook_size = codebook_size
106
+ self.codebook_dim = codebook_dim
107
+ self.commitment_cost = commitment_cost
108
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
109
+ self.codebook.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
110
+
111
+ def forward(self, z):
112
+ # z: [B, H, W, C] (channels last)
113
+ B, H, W, C = z.shape
114
+ z_flat = z.reshape(-1, C)
115
+
116
+ # Find nearest codebook entry
117
+ dist = (z_flat.unsqueeze(1) - self.codebook.weight.unsqueeze(0)).pow(2).sum(-1)
118
+ indices = dist.argmin(dim=1)
119
+
120
+ z_q = self.codebook(indices).reshape(B, H, W, C)
121
+
122
+ # Losses
123
+ commitment_loss = F.mse_loss(z_flat, z_q.reshape(-1, C).detach())
124
+ codebook_loss = F.mse_loss(z_q.reshape(-1, C), z_flat.detach())
125
+ loss = codebook_loss + self.commitment_cost * commitment_loss
126
+
127
+ # Straight-through estimator
128
+ z_q_st = z + (z_q - z).detach()
129
+
130
+ return z_q_st, loss, indices.reshape(B, H, W)
131
+
132
+
133
+ class Decoder(nn.Module):
134
+ def __init__(self, out_channels=3, latent_dim=LATENT_DIM):
135
+ super().__init__()
136
+ self.net = nn.Sequential(
137
+ nn.ConvTranspose2d(latent_dim, 256, 4, stride=2, padding=1), # β†’ 16x16
138
+ nn.ReLU(),
139
+ nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), # β†’ 32x32
140
+ nn.ReLU(),
141
+ nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), # β†’ 64x64
142
+ nn.ReLU(),
143
+ nn.ConvTranspose2d(64, out_channels, 4, stride=2, padding=1), # β†’ 128x128
144
+ nn.Sigmoid(),
145
+ )
146
+
147
+ def forward(self, x):
148
+ return self.net(x)
149
+
150
+
151
+ class VQVAE(nn.Module):
152
+ def __init__(self):
153
+ super().__init__()
154
+ self.encoder = Encoder()
155
+ self.quantizer = VectorQuantizer()
156
+ self.proj_in = nn.Linear(LATENT_DIM, CODEBOOK_DIM)
157
+ self.proj_out = nn.Linear(CODEBOOK_DIM, LATENT_DIM)
158
+ self.decoder = Decoder()
159
+
160
+ def forward(self, x):
161
+ z = self.encoder(x) # [B, C, H, W]
162
+ z = z.permute(0, 2, 3, 1) # [B, H, W, C]
163
+ z = self.proj_in(z) # [B, H, W, codebook_dim]
164
+ z_q, vq_loss, indices = self.quantizer(z)
165
+ z_q = self.proj_out(z_q) # [B, H, W, latent_dim]
166
+ z_q = z_q.permute(0, 3, 1, 2) # [B, C, H, W]
167
+ recon = self.decoder(z_q)
168
+ return recon, vq_loss, indices
169
+
170
+ def encode(self, x):
171
+ """Encode image to token indices."""
172
+ z = self.encoder(x)
173
+ z = z.permute(0, 2, 3, 1)
174
+ z = self.proj_in(z)
175
+ _, _, indices = self.quantizer(z)
176
+ return indices # [B, H, W]
177
+
178
+ def decode_tokens(self, token_ids, grid_h=8, grid_w=8):
179
+ """Decode token IDs back to image."""
180
+ if isinstance(token_ids, list):
181
+ token_ids = torch.tensor(token_ids, dtype=torch.long)
182
+ token_ids = token_ids[:grid_h * grid_w]
183
+ if len(token_ids) < grid_h * grid_w:
184
+ token_ids = torch.cat([token_ids, torch.zeros(grid_h * grid_w - len(token_ids), dtype=torch.long)])
185
+
186
+ z_q = self.quantizer.codebook(token_ids) # [H*W, D]
187
+ z_q = self.proj_out(z_q) # [H*W, latent_dim]
188
+ z_q = z_q.reshape(1, grid_h, grid_w, -1).permute(0, 3, 1, 2)
189
+ return self.decoder(z_q)
190
+
191
+
192
+ # ============================================================================
193
+ # PHASE 1: TRAIN VQ-VAE ON COCO IMAGES
194
+ # ============================================================================
195
+ def train_vq_vae(logger: Logger) -> VQVAE:
196
+ """Train VQ-VAE on COCO 2017 images (streaming, so no massive download)."""
197
+ logger.log("=" * 60 + "\n")
198
+ logger.log("PHASE 1: Training VQ-VAE on COCO 2017 images\n")
199
+ logger.log("=" * 60 + "\n\n")
200
+
201
+ from datasets import load_dataset
202
+ from torchvision import transforms
203
+
204
+ # Load COCO in streaming mode
205
+ logger.log("πŸ“¦ Loading COCO 2017 dataset (streaming)...\n")
206
+ coco = load_dataset("HuggingFaceM4/COCO", split="train", streaming=True, trust_remote_code=True)
207
+
208
+ # Image transforms
209
+ transform = transforms.Compose([
210
+ transforms.Resize((VQ_VAE_IMG_SIZE, VQ_VAE_IMG_SIZE)),
211
+ transforms.ToTensor(), # [0, 1]
212
+ ])
213
+
214
+ class COCOStreamDataset(IterableDataset):
215
+ def __init__(self, hf_dataset, transform, max_samples=50000):
216
+ self.dataset = hf_dataset
217
+ self.transform = transform
218
+ self.max_samples = max_samples
219
+
220
+ def __iter__(self):
221
+ count = 0
222
+ for item in self.dataset:
223
+ if count >= self.max_samples:
224
+ break
225
+ try:
226
+ img = item["image"]
227
+ if img.mode != "RGB":
228
+ img = img.convert("RGB")
229
+ tensor = self.transform(img)
230
+ yield tensor
231
+ count += 1
232
+ except Exception:
233
+ continue
234
+
235
+ dataset = COCOStreamDataset(coco, transform, max_samples=50000)
236
+ dataloader = DataLoader(dataset, batch_size=VQ_VAE_BATCH, num_workers=0)
237
+
238
+ # Initialize model
239
+ model = VQVAE()
240
+ n_params = sum(p.numel() for p in model.parameters()) / 1e6
241
+ logger.log(f"βœ… VQ-VAE initialized: {n_params:.1f}M parameters\n")
242
+
243
+ optimizer = torch.optim.Adam(model.parameters(), lr=VQ_VAE_LR)
244
+ model.train()
245
+
246
+ for epoch in range(VQ_VAE_EPOCHS):
247
+ epoch_loss = 0.0
248
+ epoch_recon = 0.0
249
+ epoch_vq = 0.0
250
+ num_batches = 0
251
+ start_time = time.time()
252
+
253
+ for batch_idx, batch in enumerate(dataloader):
254
+ recon, vq_loss, _ = model(batch)
255
+ recon_loss = F.mse_loss(recon, batch)
256
+ loss = recon_loss + vq_loss
257
+
258
+ optimizer.zero_grad()
259
+ loss.backward()
260
+ optimizer.step()
261
+
262
+ epoch_loss += loss.item()
263
+ epoch_recon += recon_loss.item()
264
+ epoch_vq += vq_loss.item()
265
+ num_batches += 1
266
+
267
+ if batch_idx % 50 == 0 and batch_idx > 0:
268
+ avg = epoch_loss / num_batches
269
+ avg_r = epoch_recon / num_batches
270
+ avg_v = epoch_vq / num_batches
271
+ logger.log(f" Epoch {epoch+1}/{VQ_VAE_EPOCHS} | Batch {batch_idx} | "
272
+ f"Loss: {avg:.4f} (recon: {avg_r:.4f}, vq: {avg_v:.4f})\n")
273
+
274
+ del recon, vq_loss, loss
275
+ if batch_idx % 200 == 0:
276
+ gc.collect()
277
+
278
+ elapsed = time.time() - start_time
279
+ avg_loss = epoch_loss / max(num_batches, 1)
280
+ logger.log(f"\nπŸ“ˆ Epoch {epoch+1} done. Avg Loss: {avg_loss:.4f} | "
281
+ f"Batches: {num_batches} | Time: {elapsed:.0f}s\n\n")
282
+
283
+ # Save
284
+ torch.save(model.state_dict(), "vq_vae_real.pt")
285
+ logger.log("βœ… VQ-VAE saved to vq_vae_real.pt\n\n")
286
+ return model
287
+
288
+
289
+ # ============================================================================
290
+ # PHASE 2: TOKENIZE OPENVID-1M DATASET
291
+ # ============================================================================
292
+ def tokenize_openvid(logger: Logger, vq_vae: Optional[VQVAE] = None):
293
+ """Stream OpenVid-1M, tokenize videos with VQ-VAE, save tokenized data."""
294
+ logger.log("=" * 60 + "\n")
295
+ logger.log("PHASE 2: Tokenizing OpenVid-1M dataset (10K clips)\n")
296
+ logger.log("=" * 60 + "\n\n")
297
+
298
+ # Load VQ-VAE if not provided
299
+ if vq_vae is None:
300
+ if os.path.exists("vq_vae_real.pt"):
301
+ vq_vae = VQVAE()
302
+ vq_vae.load_state_dict(torch.load("vq_vae_real.pt", map_location="cpu", weights_only=False))
303
+ logger.log("βœ… Loaded trained VQ-VAE from vq_vae_real.pt\n")
304
+ else:
305
+ logger.log("❌ No trained VQ-VAE found! Run Phase 1 first.\n")
306
+ return None
307
+
308
+ vq_vae.eval()
309
+
310
+ from datasets import load_dataset
311
+
312
+ logger.log("πŸ“¦ Loading OpenVid-1M dataset (streaming)...\n")
313
+ try:
314
+ dataset = load_dataset("NJU-PCALab/OpenVid-1M", split="train", streaming=True, trust_remote_code=True)
315
+ except Exception as e:
316
+ logger.log(f"⚠️ OpenVid-1M load error: {e}\n")
317
+ logger.log("πŸ”„ Trying alternative: WebVid-2M...\n")
318
+ try:
319
+ dataset = load_dataset("tmpdump/webvid10m", split="train", streaming=True, trust_remote_code=True)
320
+ except Exception as e2:
321
+ logger.log(f"⚠️ WebVid load error: {e2}\n")
322
+ logger.log("πŸ”„ Falling back to COCO captions (image-only, but much more data)...\n")
323
+ return _tokenize_coco_fallback(logger, vq_vae)
324
+
325
+ # Tokenize clips
326
+ tokenized_data = []
327
+ count = 0
328
+ errors = 0
329
+
330
+ for item in dataset:
331
+ if count >= NUM_OPENVID_CLIPS:
332
+ break
333
+
334
+ try:
335
+ # Get text caption
336
+ caption = ""
337
+ for key in ["caption", "text", "description", "title"]:
338
+ if key in item and item[key]:
339
+ caption = item[key]
340
+ break
341
+
342
+ if not caption:
343
+ caption = f"video clip {count}"
344
+
345
+ # Get video frames
346
+ video = item.get("video", None)
347
+ if video is None:
348
+ errors += 1
349
+ continue
350
+
351
+ # Process video frames
352
+ import io
353
+ from PIL import Image
354
+
355
+ frames = []
356
+ if hasattr(video, 'read'):
357
+ # It's bytes
358
+ pass
359
+
360
+ # Try to extract frames
361
+ if isinstance(video, dict) and "bytes" in video:
362
+ video_bytes = video["bytes"]
363
+ elif isinstance(video, bytes):
364
+ video_bytes = video
365
+ else:
366
+ errors += 1
367
+ continue
368
+
369
+ # Use imageio or decord to extract frames
370
+ try:
371
+ import imageio
372
+ reader = imageio.get_reader(io.BytesIO(video_bytes), format='mp4')
373
+ for i, frame in enumerate(reader):
374
+ if i >= 4: # Take first 4 frames
375
+ break
376
+ img = Image.fromarray(frame).convert("RGB").resize((128, 128))
377
+ frames.append(np.array(img))
378
+ reader.close()
379
+ except Exception:
380
+ errors += 1
381
+ continue
382
+
383
+ if not frames:
384
+ errors += 1
385
+ continue
386
+
387
+ # Tokenize frames through VQ-VAE
388
+ from torchvision import transforms
389
+ transform = transforms.ToTensor()
390
+ all_tokens = []
391
+
392
+ for frame in frames:
393
+ img_tensor = transform(Image.fromarray(frame)).unsqueeze(0)
394
+ with torch.no_grad():
395
+ tokens = vq_vae.encode(img_tensor)
396
+ all_tokens.extend(tokens.flatten().tolist())
397
+
398
+ # Truncate/pad to fixed length
399
+ all_tokens = all_tokens[:TOKENS_PER_CLIP]
400
+ while len(all_tokens) < TOKENS_PER_CLIP:
401
+ all_tokens.append(0)
402
+
403
+ tokenized_data.append({
404
+ "text_prompt": caption,
405
+ "video_tokens": all_tokens,
406
+ })
407
+
408
+ count += 1
409
+ if count % 100 == 0:
410
+ logger.log(f" Tokenized {count}/{NUM_OPENVID_CLIPS} clips (errors: {errors})\n")
411
+
412
+ except Exception as e:
413
+ errors += 1
414
+ if errors <= 3:
415
+ logger.log(f" ⚠️ Error on item: {e}\n")
416
+ continue
417
+
418
+ if not tokenized_data:
419
+ logger.log("❌ No clips tokenized from OpenVid-1M! Falling back to COCO captions.\n")
420
+ return _tokenize_coco_fallback(logger, vq_vae)
421
+
422
+ # Save
423
+ with open("tokenized_dataset.json", "w") as f:
424
+ json.dump(tokenized_data, f)
425
+
426
+ logger.log(f"\nβœ… Tokenized {len(tokenized_data)} clips saved to tokenized_dataset.json\n")
427
+ logger.log(f" Errors: {errors}\n\n")
428
+ return tokenized_data
429
+
430
+
431
+ def _tokenize_coco_fallback(logger: Logger, vq_vae: VQVAE):
432
+ """Fallback: tokenize COCO captions as image-text pairs."""
433
+ logger.log("πŸ“¦ Using COCO captions as image-text pairs (50K samples)...\n")
434
+
435
+ from datasets import load_dataset
436
+ from torchvision import transforms
437
+ from PIL import Image
438
+
439
+ coco = load_dataset("HuggingFaceM4/COCO", split="train", streaming=True, trust_remote_code=True)
440
+ transform = transforms.Compose([
441
+ transforms.Resize((VQ_VAE_IMG_SIZE, VQ_VAE_IMG_SIZE)),
442
+ transforms.ToTensor(),
443
+ ])
444
+
445
+ vq_vae.eval()
446
+ tokenized_data = []
447
+ count = 0
448
+
449
+ for item in coco:
450
+ if count >= 50000:
451
+ break
452
+
453
+ try:
454
+ img = item["image"]
455
+ if img.mode != "RGB":
456
+ img = img.convert("RGB")
457
+
458
+ caption = ""
459
+ if "caption" in item:
460
+ caption = item["caption"] if isinstance(item["caption"], str) else item["caption"][0]
461
+ elif "text" in item:
462
+ caption = item["text"]
463
+ if not caption:
464
+ caption = f"image {count}"
465
+
466
+ img_tensor = transform(img).unsqueeze(0)
467
+ with torch.no_grad():
468
+ tokens = vq_vae.encode(img_tensor)
469
+ flat_tokens = tokens.flatten().tolist()
470
+
471
+ # Truncate/pad
472
+ flat_tokens = flat_tokens[:TOKENS_PER_CLIP]
473
+ while len(flat_tokens) < TOKENS_PER_CLIP:
474
+ flat_tokens.append(0)
475
+
476
+ tokenized_data.append({
477
+ "text_prompt": caption,
478
+ "video_tokens": flat_tokens,
479
+ })
480
+
481
+ count += 1
482
+ if count % 1000 == 0:
483
+ logger.log(f" Tokenized {count}/50000 images\n")
484
+ # Save checkpoint periodically
485
+ if count % 10000 == 0:
486
+ with open("tokenized_dataset.json", "w") as f:
487
+ json.dump(tokenized_data, f)
488
+ logger.log(f" πŸ’Ύ Checkpoint saved ({len(tokenized_data)} samples)\n")
489
+
490
+ except Exception:
491
+ continue
492
+
493
+ # Final save
494
+ with open("tokenized_dataset.json", "w") as f:
495
+ json.dump(tokenized_data, f)
496
+
497
+ logger.log(f"\nβœ… Tokenized {len(tokenized_data)} images saved to tokenized_dataset.json\n\n")
498
+ return tokenized_data
499
+
500
+
501
+ # ============================================================================
502
+ # PHASE 3: TRAIN LLM WITH LORA
503
+ # ============================================================================
504
+ def train_llm(logger: Logger):
505
+ """Fine-tune OLMo 2 1B with LoRA on tokenized data."""
506
+ logger.log("=" * 60 + "\n")
507
+ logger.log("PHASE 3: Fine-tuning OLMo 2 1B + LoRA\n")
508
+ logger.log("=" * 60 + "\n\n")
509
+
510
+ from transformers import AutoModelForCausalLM, AutoTokenizer
511
+ from peft import LoraConfig, get_peft_model, TaskType
512
+
513
+ # Load data
514
+ data_path = "tokenized_dataset.json"
515
+ if not os.path.exists(data_path):
516
+ logger.log("❌ No tokenized dataset found! Run Phase 2 first.\n")
517
+ return
518
+
519
+ with open(data_path) as f:
520
+ data = json.load(f)
521
+ logger.log(f"πŸ“Š Loaded {len(data)} training samples\n")
522
+
523
+ # Tokenizer
524
+ logger.log("πŸ“¦ Loading OLMo 2 1B tokenizer...\n")
525
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
526
+ if tokenizer.pad_token is None:
527
+ tokenizer.pad_token = tokenizer.eos_token
528
+
529
+ # Model
530
+ logger.log("πŸ“¦ Loading model (fp32, CPU)...\n")
531
+ model = AutoModelForCausalLM.from_pretrained(
532
+ MODEL_NAME, trust_remote_code=True, torch_dtype=torch.float32
533
+ )
534
+ logger.log(f"βœ… Model loaded. Original vocab: {len(tokenizer)}\n")
535
+
536
+ # Expand vocab
537
+ logger.log(f"πŸ”€ Adding {CODEBOOK_SIZE} visual tokens...\n")
538
+ visual_tokens = [VIDEO_START, VIDEO_END, VIDEO_PAD]
539
+ for i in range(CODEBOOK_SIZE):
540
+ visual_tokens.append(f"<v_{i}>")
541
+ tokenizer.add_tokens(visual_tokens)
542
+ model.resize_token_embeddings(len(tokenizer))
543
+ logger.log(f"βœ… New vocab: {len(tokenizer)}\n")
544
+
545
+ # LoRA
546
+ logger.log(f"πŸ”§ Applying LoRA (r={LORA_R})...\n")
547
+ lora_config = LoraConfig(
548
+ r=LORA_R, lora_alpha=LORA_ALPHA,
549
+ target_modules=["q_proj", "v_proj"],
550
+ lora_dropout=LORA_DROPOUT, bias="none",
551
+ task_type=TaskType.CAUSAL_LM,
552
+ )
553
+ model = get_peft_model(model, lora_config)
554
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
555
+ total = sum(p.numel() for p in model.parameters())
556
+ logger.log(f"βœ… LoRA: {trainable:,} / {total:,} trainable ({100*trainable/total:.2f}%)\n")
557
+
558
+ # Dataset
559
+ class VideoTokenDataset(Dataset):
560
+ def __init__(self, data, max_tokens=TOKENS_PER_CLIP):
561
+ self.data = data
562
+ self.max_tokens = max_tokens
563
+
564
+ def __len__(self):
565
+ return len(self.data)
566
+
567
+ def __getitem__(self, idx):
568
+ item = self.data[idx]
569
+ prompt = item["text_prompt"]
570
+ tokens = item["video_tokens"][:self.max_tokens]
571
+ while len(tokens) < self.max_tokens:
572
+ tokens.append(0)
573
+ return {"prompt": prompt, "video_tokens": torch.tensor(tokens, dtype=torch.long)}
574
+
575
+ dataset = VideoTokenDataset(data)
576
+ dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
577
+ total_steps = NUM_EPOCHS * len(dataloader)
578
+ logger.log(f"πŸ“Š {len(dataset)} samples Γ— {NUM_EPOCHS} epochs = {total_steps} steps\n\n")
579
+
580
+ # Train
581
+ optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
582
+ model.train()
583
+ global_step = 0
584
+ running_loss = 0.0
585
+ start_time = time.time()
586
+
587
+ for epoch in range(NUM_EPOCHS):
588
+ epoch_loss = 0.0
589
+ num_batches = 0
590
+
591
+ for batch_idx, batch in enumerate(dataloader):
592
+ prompt = batch["prompt"][0]
593
+ video_tokens = batch["video_tokens"][0]
594
+
595
+ # Format: use 64 visual tokens per sample for memory
596
+ token_str = " ".join(f"<v_{t.item()}>" for t in video_tokens[:64])
597
+ text = f"Create a video of: {prompt} {VIDEO_START} {token_str} {VIDEO_END}"
598
+
599
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN, padding="max_length")
600
+ outputs = model(**inputs, labels=inputs["input_ids"])
601
+ loss = outputs.loss / GRADIENT_ACCUMULATION
602
+ loss.backward()
603
+
604
+ if (batch_idx + 1) % GRADIENT_ACCUMULATION == 0 or (batch_idx + 1) == len(dataloader):
605
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
606
+ optimizer.step()
607
+ optimizer.zero_grad()
608
+
609
+ global_step += 1
610
+ batch_loss = loss.item() * GRADIENT_ACCUMULATION
611
+ epoch_loss += batch_loss
612
+ running_loss += batch_loss
613
+ num_batches += 1
614
+
615
+ if batch_idx % 100 == 0:
616
+ elapsed = time.time() - start_time
617
+ speed = global_step / elapsed if elapsed > 0 else 0
618
+ logger.log(f" Epoch {epoch+1}/{NUM_EPOCHS} | Step {batch_idx+1}/{len(dataloader)} | "
619
+ f"Loss: {batch_loss:.4f} | Avg: {epoch_loss/num_batches:.4f} | "
620
+ f"Speed: {speed:.2f} steps/s\n")
621
+
622
+ del outputs, loss
623
+ gc.collect()
624
+
625
+ logger.log(f"\nπŸ“ˆ Epoch {epoch+1} done. Avg Loss: {epoch_loss/num_batches:.4f}\n\n")
626
+
627
+ total_time = time.time() - start_time
628
+ logger.log(f"βœ… Training complete in {total_time:.0f}s ({total_time/60:.1f} min)\n")
629
+ logger.log(f" Final avg loss: {running_loss/global_step:.4f}\n\n")
630
+
631
+ # Merge & save
632
+ logger.log("πŸ”€ Merging LoRA β†’ base model...\n")
633
+ model = model.merge_and_unload()
634
+
635
+ save_dir = "./trained_model"
636
+ model.save_pretrained(save_dir, safe_serialization=True)
637
+ tokenizer.save_pretrained(save_dir)
638
+
639
+ # Also save VQ-VAE
640
+ if os.path.exists("vq_vae_real.pt"):
641
+ import shutil
642
+ shutil.copy("vq_vae_real.pt", f"{save_dir}/vq_vae_final.pt")
643
+
644
+ # Copy tokenized dataset
645
+ if os.path.exists("tokenized_dataset.json"):
646
+ import shutil
647
+ shutil.copy("tokenized_dataset.json", f"{save_dir}/tokenized_dataset.json")
648
+
649
+ logger.log("βœ… Model saved locally.\n")
650
+
651
+ # Push
652
+ logger.log(f"πŸš€ Pushing to {REPO_ID}...\n")
653
+ from huggingface_hub import HfApi
654
+ api = HfApi(token=HF_TOKEN)
655
+ try:
656
+ api.create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True)
657
+ except:
658
+ pass
659
+ api.upload_folder(folder_path=save_dir, repo_id=REPO_ID, repo_type="model",
660
+ commit_message=f"LoRA OLMo 2 1B (r={LORA_R}, {NUM_EPOCHS} epochs, {len(data)} samples)")
661
+ logger.log(f"βœ… Pushed to https://huggingface.co/{REPO_ID}\n\n")
662
+
663
+
664
+ # ============================================================================
665
+ # MAIN PIPELINE
666
+ # ============================================================================
667
+ def run_pipeline(log_path: str = LOG_FILE):
668
+ logger = Logger(log_path)
669
+
670
+ try:
671
+ # Phase 1: Train VQ-VAE
672
+ vq_vae = train_vq_vae(logger)
673
+ gc.collect()
674
+
675
+ # Phase 2: Tokenize dataset
676
+ tokenize_openvid(logger, vq_vae)
677
+ gc.collect()
678
+
679
+ # Phase 3: Train LLM
680
+ train_llm(logger)
681
+
682
+ logger.log("\nπŸŽ‰ FULL PIPELINE COMPLETE!\n")
683
+ except Exception as e:
684
+ logger.log(f"\n❌ PIPELINE ERROR: {e}\n")
685
+ logger.log(traceback.format_exc())
686
+
687
+
688
+ # CLI
689
+ if __name__ == "__main__":
690
+ run_pipeline()