eeshaAI commited on
Commit
e8cf613
·
verified ·
1 Parent(s): 395c0d2

Updated: scaled pipeline with real data (10K images, 5K LLM samples, checkpoint/resume support)

Browse files
Files changed (1) hide show
  1. train_full_pipeline.py +493 -322
train_full_pipeline.py CHANGED
@@ -1,12 +1,13 @@
1
  #!/usr/bin/env python3
2
  """
3
- Full Pipeline: Train VQ-VAE → Tokenize OpenVid → Train LLM → Push to EeshaAI/zeeb
4
- =================================================================================
5
  Runs on HuggingFace Spaces (free CPU tier, 16GB RAM).
 
6
 
7
- Phase 1: Train VQ-VAE on COCO 2017 images (118K real images, streaming)
8
- Phase 2: Stream 10K clips from OpenVid-1M tokenize via trained VQ-VAE → save integers
9
- Phase 3: Fine-tune OLMo 2 1B with LoRA on 10K tokenized samples
10
  Phase 4: Push trained model to EeshaAI/zeeb
11
  """
12
 
@@ -17,8 +18,9 @@ import time
17
  import gc
18
  import threading
19
  import traceback
 
20
  import numpy as np
21
- from typing import Optional
22
 
23
  import torch
24
  import torch.nn as nn
@@ -38,27 +40,36 @@ VIDEO_START = "<video_start>"
38
  VIDEO_END = "<video_end>"
39
  VIDEO_PAD = "<video_pad>"
40
 
 
 
 
 
 
41
  # VQ-VAE training
42
  VQ_VAE_EPOCHS = 5
43
- VQ_VAE_LR = 1e-3
44
- VQ_VAE_BATCH = 32
45
- VQ_VAE_IMG_SIZE = 128 # resize images to 128x128
 
46
 
47
- # Dataset preparation
48
- NUM_OPENVID_CLIPS = 10000
49
- TOKENS_PER_CLIP = 128 # number of visual tokens per video clip
50
 
51
  # LLM training
52
- NUM_EPOCHS = 3
53
  LORA_R = 4
54
  LORA_ALPHA = 8
55
  LORA_DROPOUT = 0.05
56
- LEARNING_RATE = 1e-4
57
  BATCH_SIZE = 1
58
- MAX_SEQ_LEN = 384
59
- GRADIENT_ACCUMULATION = 4
 
 
60
 
61
- LOG_FILE = "/tmp/pipeline_log.txt"
 
62
 
63
 
64
  # ============================================================================
@@ -69,30 +80,84 @@ class Logger:
69
  self.path = path
70
  self.lock = threading.Lock()
71
  with open(path, "w") as f:
72
- f.write("🚀 Zeeb Full Pipeline Starting...\n\n")
73
 
74
  def log(self, msg):
 
 
75
  with self.lock:
76
- with open(self.path, "a") as f:
77
- f.write(msg)
78
- f.flush()
79
- print(msg, end="", flush=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
 
82
  # ============================================================================
83
- # VQ-VAE MODEL
84
  # ============================================================================
85
  class Encoder(nn.Module):
86
  def __init__(self, in_channels=3, latent_dim=LATENT_DIM):
87
  super().__init__()
88
  self.net = nn.Sequential(
89
- nn.Conv2d(in_channels, 64, 4, stride=2, padding=1), # 64x64
90
  nn.ReLU(),
91
- nn.Conv2d(64, 128, 4, stride=2, padding=1), # 32x32
92
  nn.ReLU(),
93
- nn.Conv2d(128, 256, 4, stride=2, padding=1), # 16x16
94
  nn.ReLU(),
95
- nn.Conv2d(256, latent_dim, 4, stride=2, padding=1), # 8x8
96
  )
97
 
98
  def forward(self, x):
@@ -134,13 +199,13 @@ class Decoder(nn.Module):
134
  def __init__(self, out_channels=3, latent_dim=LATENT_DIM):
135
  super().__init__()
136
  self.net = nn.Sequential(
137
- nn.ConvTranspose2d(latent_dim, 256, 4, stride=2, padding=1), # 16x16
138
  nn.ReLU(),
139
- nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), # 32x32
140
  nn.ReLU(),
141
- nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), # 64x64
142
  nn.ReLU(),
143
- nn.ConvTranspose2d(64, out_channels, 4, stride=2, padding=1), # 128x128
144
  nn.Sigmoid(),
145
  )
146
 
@@ -190,49 +255,120 @@ class VQVAE(nn.Module):
190
 
191
 
192
  # ============================================================================
193
- # PHASE 1: TRAIN VQ-VAE ON COCO IMAGES
194
  # ============================================================================
195
- def train_vq_vae(logger: Logger) -> VQVAE:
196
- """Train VQ-VAE on COCO 2017 images (streaming, so no massive download)."""
197
- logger.log("=" * 60 + "\n")
198
- logger.log("PHASE 1: Training VQ-VAE on COCO 2017 images\n")
199
- logger.log("=" * 60 + "\n\n")
200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  from datasets import load_dataset
202
- from torchvision import transforms
203
-
204
- # Try multiple COCO/image datasets (some have compatibility issues)
205
- logger.log("📦 Loading image dataset (trying multiple sources)...\n")
206
- coco = None
207
- image_key = "image"
208
-
209
  dataset_sources = [
210
- ("detection-datasets/coco", "train", "image"),
211
- ("rafaelpadilla/coco2017", "train", "image"),
212
- ("frgfm/imagenette", "train", "image"),
213
- ("zh-plus/tiny-imagenet", "train", "image"),
214
- ("cifar10", "train", "img"),
215
  ]
216
-
217
- for ds_name, ds_split, ds_img_key in dataset_sources:
218
  try:
219
- logger.log(f" Trying {ds_name}...\n")
220
- coco = load_dataset(ds_name, split=ds_split, streaming=True, trust_remote_code=True)
221
- # Test first item
222
- test_item = next(iter(coco))
223
- if ds_img_key in test_item:
224
- image_key = ds_img_key
225
- logger.log(f" ✅ Using {ds_name} (image key: '{image_key}')\n")
226
- break
227
- else:
228
- logger.log(f" ⚠️ No '{ds_img_key}' key in {ds_name}, keys: {list(test_item.keys())}\n")
229
- coco = None
 
 
 
 
 
 
 
 
 
 
 
230
  except Exception as e:
231
- logger.log(f" ❌ {ds_name} failed: {str(e)[:100]}\n")
232
- coco = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
 
234
- if coco is None:
235
- logger.log("❌ No dataset could be loaded! Cannot train VQ-VAE.\n")
 
 
236
  return None
237
 
238
  # Image transforms
@@ -242,7 +378,7 @@ def train_vq_vae(logger: Logger) -> VQVAE:
242
  ])
243
 
244
  class ImageStreamDataset(IterableDataset):
245
- def __init__(self, hf_dataset, transform, img_key, max_samples=50000):
246
  self.dataset = hf_dataset
247
  self.transform = transform
248
  self.img_key = img_key
@@ -258,23 +394,38 @@ def train_vq_vae(logger: Logger) -> VQVAE:
258
  if img.mode != "RGB":
259
  img = img.convert("RGB")
260
  tensor = self.transform(img)
261
- yield tensor
262
  count += 1
 
263
  except Exception:
264
  continue
265
 
266
- dataset = ImageStreamDataset(coco, transform, image_key, max_samples=50000)
267
  dataloader = DataLoader(dataset, batch_size=VQ_VAE_BATCH, num_workers=0)
268
 
269
- # Initialize model
270
  model = VQVAE()
271
  n_params = sum(p.numel() for p in model.parameters()) / 1e6
272
- logger.log(f"VQ-VAE initialized: {n_params:.1f}M parameters\n")
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
  optimizer = torch.optim.Adam(model.parameters(), lr=VQ_VAE_LR)
 
275
  model.train()
276
 
277
- for epoch in range(VQ_VAE_EPOCHS):
 
 
278
  epoch_loss = 0.0
279
  epoch_recon = 0.0
280
  epoch_vq = 0.0
@@ -288,6 +439,7 @@ def train_vq_vae(logger: Logger) -> VQVAE:
288
 
289
  optimizer.zero_grad()
290
  loss.backward()
 
291
  optimizer.step()
292
 
293
  epoch_loss += loss.item()
@@ -295,202 +447,103 @@ def train_vq_vae(logger: Logger) -> VQVAE:
295
  epoch_vq += vq_loss.item()
296
  num_batches += 1
297
 
298
- if batch_idx % 50 == 0 and batch_idx > 0:
299
  avg = epoch_loss / num_batches
300
  avg_r = epoch_recon / num_batches
301
  avg_v = epoch_vq / num_batches
302
  logger.log(f" Epoch {epoch+1}/{VQ_VAE_EPOCHS} | Batch {batch_idx} | "
303
  f"Loss: {avg:.4f} (recon: {avg_r:.4f}, vq: {avg_v:.4f})\n")
304
 
305
- del recon, vq_loss, loss
306
- if batch_idx % 200 == 0:
307
  gc.collect()
308
 
 
 
309
  elapsed = time.time() - start_time
310
  avg_loss = epoch_loss / max(num_batches, 1)
311
- logger.log(f"\n📈 Epoch {epoch+1} done. Avg Loss: {avg_loss:.4f} | "
 
312
  f"Batches: {num_batches} | Time: {elapsed:.0f}s\n\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
 
314
- # Save
 
 
 
 
 
315
  torch.save(model.state_dict(), "vq_vae_real.pt")
316
- logger.log("✅ VQ-VAE saved to vq_vae_real.pt\n\n")
 
 
317
  return model
318
 
319
 
320
  # ============================================================================
321
- # PHASE 2: TOKENIZE OPENVID-1M DATASET
322
  # ============================================================================
323
- def tokenize_openvid(logger: Logger, vq_vae: Optional[VQVAE] = None):
324
- """Stream OpenVid-1M, tokenize videos with VQ-VAE, save tokenized data."""
325
  logger.log("=" * 60 + "\n")
326
- logger.log("PHASE 2: Tokenizing OpenVid-1M dataset (10K clips)\n")
327
  logger.log("=" * 60 + "\n\n")
328
 
 
 
 
 
 
 
 
 
 
 
 
 
329
  # Load VQ-VAE if not provided
330
  if vq_vae is None:
331
- if os.path.exists("vq_vae_real.pt"):
 
 
 
 
 
332
  vq_vae = VQVAE()
333
  vq_vae.load_state_dict(torch.load("vq_vae_real.pt", map_location="cpu", weights_only=False))
334
- logger.log("Loaded trained VQ-VAE from vq_vae_real.pt\n")
335
  else:
336
- logger.log("No trained VQ-VAE found! Run Phase 1 first.\n")
337
  return None
338
 
339
  vq_vae.eval()
340
-
341
- from datasets import load_dataset
342
-
343
- logger.log("📦 Loading OpenVid-1M dataset (streaming)...\n")
344
- try:
345
- dataset = load_dataset("NJU-PCALab/OpenVid-1M", split="train", streaming=True, trust_remote_code=True)
346
- except Exception as e:
347
- logger.log(f"⚠️ OpenVid-1M load error: {e}\n")
348
- logger.log("🔄 Trying alternative: WebVid-2M...\n")
349
- try:
350
- dataset = load_dataset("tmpdump/webvid10m", split="train", streaming=True, trust_remote_code=True)
351
- except Exception as e2:
352
- logger.log(f"⚠️ WebVid load error: {e2}\n")
353
- logger.log("🔄 Falling back to COCO captions (image-only, but much more data)...\n")
354
- return _tokenize_coco_fallback(logger, vq_vae)
355
-
356
- # Tokenize clips
357
- tokenized_data = []
358
- count = 0
359
- errors = 0
360
-
361
- for item in dataset:
362
- if count >= NUM_OPENVID_CLIPS:
363
- break
364
-
365
- try:
366
- # Get text caption
367
- caption = ""
368
- for key in ["caption", "text", "description", "title"]:
369
- if key in item and item[key]:
370
- caption = item[key]
371
- break
372
-
373
- if not caption:
374
- caption = f"video clip {count}"
375
-
376
- # Get video frames
377
- video = item.get("video", None)
378
- if video is None:
379
- errors += 1
380
- continue
381
-
382
- # Process video frames
383
- import io
384
- from PIL import Image
385
-
386
- frames = []
387
- if hasattr(video, 'read'):
388
- # It's bytes
389
- pass
390
-
391
- # Try to extract frames
392
- if isinstance(video, dict) and "bytes" in video:
393
- video_bytes = video["bytes"]
394
- elif isinstance(video, bytes):
395
- video_bytes = video
396
- else:
397
- errors += 1
398
- continue
399
-
400
- # Use imageio or decord to extract frames
401
- try:
402
- import imageio
403
- reader = imageio.get_reader(io.BytesIO(video_bytes), format='mp4')
404
- for i, frame in enumerate(reader):
405
- if i >= 4: # Take first 4 frames
406
- break
407
- img = Image.fromarray(frame).convert("RGB").resize((128, 128))
408
- frames.append(np.array(img))
409
- reader.close()
410
- except Exception:
411
- errors += 1
412
- continue
413
-
414
- if not frames:
415
- errors += 1
416
- continue
417
-
418
- # Tokenize frames through VQ-VAE
419
- from torchvision import transforms
420
- transform = transforms.ToTensor()
421
- all_tokens = []
422
-
423
- for frame in frames:
424
- img_tensor = transform(Image.fromarray(frame)).unsqueeze(0)
425
- with torch.no_grad():
426
- tokens = vq_vae.encode(img_tensor)
427
- all_tokens.extend(tokens.flatten().tolist())
428
-
429
- # Truncate/pad to fixed length
430
- all_tokens = all_tokens[:TOKENS_PER_CLIP]
431
- while len(all_tokens) < TOKENS_PER_CLIP:
432
- all_tokens.append(0)
433
-
434
- tokenized_data.append({
435
- "text_prompt": caption,
436
- "video_tokens": all_tokens,
437
- })
438
-
439
- count += 1
440
- if count % 100 == 0:
441
- logger.log(f" Tokenized {count}/{NUM_OPENVID_CLIPS} clips (errors: {errors})\n")
442
-
443
- except Exception as e:
444
- errors += 1
445
- if errors <= 3:
446
- logger.log(f" ⚠️ Error on item: {e}\n")
447
- continue
448
-
449
- if not tokenized_data:
450
- logger.log("❌ No clips tokenized from OpenVid-1M! Falling back to COCO captions.\n")
451
- return _tokenize_coco_fallback(logger, vq_vae)
452
-
453
- # Save
454
- with open("tokenized_dataset.json", "w") as f:
455
- json.dump(tokenized_data, f)
456
-
457
- logger.log(f"\n✅ Tokenized {len(tokenized_data)} clips saved to tokenized_dataset.json\n")
458
- logger.log(f" Errors: {errors}\n\n")
459
- return tokenized_data
460
-
461
-
462
- def _tokenize_coco_fallback(logger: Logger, vq_vae: VQVAE):
463
- """Fallback: tokenize image-text pairs from available datasets."""
464
- logger.log("📦 Using image-text pairs as fallback (50K samples)...\n")
465
-
466
  from datasets import load_dataset
467
  from torchvision import transforms
468
  from PIL import Image
469
 
470
- # Try multiple datasets
471
- ds = None
472
- image_key = "image"
473
- caption_key = "text"
474
-
475
- for ds_name, ds_split, img_k, cap_k in [
476
- ("detection-datasets/coco", "train", "image", "caption"),
477
- ("frgfm/imagenette", "train", "image", "label"),
478
- ("cifar10", "train", "img", "label"),
479
- ]:
480
- try:
481
- logger.log(f" Trying {ds_name}...\n")
482
- ds = load_dataset(ds_name, split=ds_split, streaming=True, trust_remote_code=True)
483
- test = next(iter(ds))
484
- image_key = img_k if img_k in test else "image"
485
- caption_key = cap_k if cap_k in test else "text"
486
- logger.log(f" ✅ Using {ds_name} (img='{image_key}', cap='{caption_key}')\n")
487
- break
488
- except Exception as e:
489
- logger.log(f" ❌ {ds_name}: {str(e)[:100]}\n")
490
- ds = None
491
-
492
  if ds is None:
493
- logger.log("No dataset available for tokenization!\n")
494
  return None
495
 
496
  transform = transforms.Compose([
@@ -498,50 +551,43 @@ def _tokenize_coco_fallback(logger: Logger, vq_vae: VQVAE):
498
  transforms.ToTensor(),
499
  ])
500
 
501
- vq_vae.eval()
502
  tokenized_data = []
503
  count = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
504
 
505
- # Label mapping for datasets that only have class labels
506
- label_names = {
507
- "cifar10": ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"],
508
- }
509
 
510
  for item in ds:
511
- if count >= 50000:
512
  break
513
 
514
  try:
515
- img = item[image_key]
516
  if img.mode != "RGB":
517
  img = img.convert("RGB")
518
 
519
- # Get caption
520
- caption = ""
521
- if caption_key in item and item[caption_key] is not None:
522
- cap = item[caption_key]
523
- if isinstance(cap, list):
524
- caption = cap[0] if cap else ""
525
- elif isinstance(cap, int):
526
- # It's a class label - convert to text
527
- ds_name_short = ds_name.split("/")[0] if "/" in ds_name else ds_name
528
- if ds_name_short in label_names and cap < len(label_names[ds_name_short]):
529
- caption = f"a photo of a {label_names[ds_name_short][cap]}"
530
- else:
531
- caption = f"image class {cap}"
532
- else:
533
- caption = str(cap)
534
- if not caption:
535
- caption = f"image {count}"
536
 
537
  img_tensor = transform(img).unsqueeze(0)
538
  with torch.no_grad():
539
  tokens = vq_vae.encode(img_tensor)
540
  flat_tokens = tokens.flatten().tolist()
541
 
542
- # Truncate/pad
543
- flat_tokens = flat_tokens[:TOKENS_PER_CLIP]
544
- while len(flat_tokens) < TOKENS_PER_CLIP:
545
  flat_tokens.append(0)
546
 
547
  tokenized_data.append({
@@ -550,71 +596,110 @@ def _tokenize_coco_fallback(logger: Logger, vq_vae: VQVAE):
550
  })
551
 
552
  count += 1
553
- if count % 1000 == 0:
554
- logger.log(f" Tokenized {count}/50000 images\n")
555
- # Save checkpoint periodically
556
- if count % 10000 == 0:
557
- with open("tokenized_dataset.json", "w") as f:
558
- json.dump(tokenized_data, f)
559
- logger.log(f" 💾 Checkpoint saved ({len(tokenized_data)} samples)\n")
560
-
561
- except Exception:
 
 
 
 
 
 
562
  continue
563
 
564
- # Final save
 
 
 
 
 
 
 
 
 
565
  with open("tokenized_dataset.json", "w") as f:
566
  json.dump(tokenized_data, f)
567
 
568
- logger.log(f"\n✅ Tokenized {len(tokenized_data)} images saved to tokenized_dataset.json\n\n")
 
 
 
 
 
569
  return tokenized_data
570
 
571
 
572
  # ============================================================================
573
  # PHASE 3: TRAIN LLM WITH LORA
574
  # ============================================================================
575
- def train_llm(logger: Logger):
576
  """Fine-tune OLMo 2 1B with LoRA on tokenized data."""
577
  logger.log("=" * 60 + "\n")
578
- logger.log("PHASE 3: Fine-tuning OLMo 2 1B + LoRA\n")
579
  logger.log("=" * 60 + "\n\n")
580
 
 
 
 
 
581
  from transformers import AutoModelForCausalLM, AutoTokenizer
582
  from peft import LoraConfig, get_peft_model, TaskType
583
 
584
  # Load data
585
- data_path = "tokenized_dataset.json"
586
  if not os.path.exists(data_path):
587
- logger.log("❌ No tokenized dataset found! Run Phase 2 first.\n")
 
 
 
588
  return
589
 
590
  with open(data_path) as f:
591
- data = json.load(f)
592
- logger.log(f"📊 Loaded {len(data)} training samples\n")
 
 
 
 
 
 
 
 
 
 
 
593
 
594
  # Tokenizer
595
- logger.log("📦 Loading OLMo 2 1B tokenizer...\n")
596
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
597
  if tokenizer.pad_token is None:
598
  tokenizer.pad_token = tokenizer.eos_token
599
 
600
  # Model
601
- logger.log("📦 Loading model (fp32, CPU)...\n")
602
  model = AutoModelForCausalLM.from_pretrained(
603
  MODEL_NAME, trust_remote_code=True, torch_dtype=torch.float32
604
  )
605
- logger.log(f"✅ Model loaded. Original vocab: {len(tokenizer)}\n")
 
606
 
607
  # Expand vocab
608
- logger.log(f"🔤 Adding {CODEBOOK_SIZE} visual tokens...\n")
609
  visual_tokens = [VIDEO_START, VIDEO_END, VIDEO_PAD]
610
  for i in range(CODEBOOK_SIZE):
611
  visual_tokens.append(f"<v_{i}>")
612
  tokenizer.add_tokens(visual_tokens)
613
  model.resize_token_embeddings(len(tokenizer))
614
- logger.log(f"New vocab: {len(tokenizer)}\n")
615
 
616
  # LoRA
617
- logger.log(f"🔧 Applying LoRA (r={LORA_R})...\n")
618
  lora_config = LoraConfig(
619
  r=LORA_R, lora_alpha=LORA_ALPHA,
620
  target_modules=["q_proj", "v_proj"],
@@ -624,11 +709,11 @@ def train_llm(logger: Logger):
624
  model = get_peft_model(model, lora_config)
625
  trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
626
  total = sum(p.numel() for p in model.parameters())
627
- logger.log(f"LoRA: {trainable:,} / {total:,} trainable ({100*trainable/total:.2f}%)\n")
628
 
629
  # Dataset
630
  class VideoTokenDataset(Dataset):
631
- def __init__(self, data, max_tokens=TOKENS_PER_CLIP):
632
  self.data = data
633
  self.max_tokens = max_tokens
634
 
@@ -646,13 +731,32 @@ def train_llm(logger: Logger):
646
  dataset = VideoTokenDataset(data)
647
  dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
648
  total_steps = NUM_EPOCHS * len(dataloader)
649
- logger.log(f"📊 {len(dataset)} samples × {NUM_EPOCHS} epochs = {total_steps} steps\n\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
650
 
651
- # Train
652
- optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
653
  model.train()
654
  global_step = 0
655
  running_loss = 0.0
 
656
  start_time = time.time()
657
 
658
  for epoch in range(NUM_EPOCHS):
@@ -663,8 +767,8 @@ def train_llm(logger: Logger):
663
  prompt = batch["prompt"][0]
664
  video_tokens = batch["video_tokens"][0]
665
 
666
- # Format: use 64 visual tokens per sample for memory
667
- token_str = " ".join(f"<v_{t.item()}>" for t in video_tokens[:64])
668
  text = f"Create a video of: {prompt} {VIDEO_START} {token_str} {VIDEO_END}"
669
 
670
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN, padding="max_length")
@@ -686,73 +790,140 @@ def train_llm(logger: Logger):
686
  if batch_idx % 100 == 0:
687
  elapsed = time.time() - start_time
688
  speed = global_step / elapsed if elapsed > 0 else 0
 
689
  logger.log(f" Epoch {epoch+1}/{NUM_EPOCHS} | Step {batch_idx+1}/{len(dataloader)} | "
690
  f"Loss: {batch_loss:.4f} | Avg: {epoch_loss/num_batches:.4f} | "
691
- f"Speed: {speed:.2f} steps/s\n")
692
 
693
- del outputs, loss
694
- gc.collect()
 
 
 
 
 
 
 
 
 
 
 
 
 
695
 
696
- logger.log(f"\n📈 Epoch {epoch+1} done. Avg Loss: {epoch_loss/num_batches:.4f}\n\n")
 
 
 
 
 
 
 
697
 
698
  total_time = time.time() - start_time
699
- logger.log(f"✅ Training complete in {total_time:.0f}s ({total_time/60:.1f} min)\n")
700
- logger.log(f" Final avg loss: {running_loss/global_step:.4f}\n\n")
 
701
 
702
  # Merge & save
703
- logger.log("🔀 Merging LoRA base model...\n")
704
  model = model.merge_and_unload()
705
 
706
- save_dir = "./trained_model"
 
707
  model.save_pretrained(save_dir, safe_serialization=True)
708
  tokenizer.save_pretrained(save_dir)
709
 
710
- # Also save VQ-VAE
711
- if os.path.exists("vq_vae_real.pt"):
 
712
  import shutil
713
- shutil.copy("vq_vae_real.pt", f"{save_dir}/vq_vae_final.pt")
 
 
 
714
 
715
  # Copy tokenized dataset
716
- if os.path.exists("tokenized_dataset.json"):
717
  import shutil
718
- shutil.copy("tokenized_dataset.json", f"{save_dir}/tokenized_dataset.json")
 
719
 
720
- logger.log("Model saved locally.\n")
721
 
722
- # Push
723
- logger.log(f"🚀 Pushing to {REPO_ID}...\n")
724
- from huggingface_hub import HfApi
725
- api = HfApi(token=HF_TOKEN)
726
- try:
727
- api.create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True)
728
- except:
729
- pass
730
- api.upload_folder(folder_path=save_dir, repo_id=REPO_ID, repo_type="model",
731
- commit_message=f"LoRA OLMo 2 1B (r={LORA_R}, {NUM_EPOCHS} epochs, {len(data)} samples)")
732
- logger.log(f"✅ Pushed to https://huggingface.co/{REPO_ID}\n\n")
 
 
 
 
 
 
 
 
 
 
 
 
733
 
734
 
735
  # ============================================================================
736
  # MAIN PIPELINE
737
  # ============================================================================
738
- def run_pipeline(log_path: str = LOG_FILE):
 
 
 
739
  logger = Logger(log_path)
 
 
 
 
 
740
 
741
  try:
742
  # Phase 1: Train VQ-VAE
743
- vq_vae = train_vq_vae(logger)
 
 
 
 
 
 
744
  gc.collect()
745
 
746
  # Phase 2: Tokenize dataset
747
- tokenize_openvid(logger, vq_vae)
 
 
 
 
 
748
  gc.collect()
749
 
750
  # Phase 3: Train LLM
751
- train_llm(logger)
 
 
 
 
752
 
753
- logger.log("\n🎉 FULL PIPELINE COMPLETE!\n")
 
 
 
 
754
  except Exception as e:
755
- logger.log(f"\n❌ PIPELINE ERROR: {e}\n")
756
  logger.log(traceback.format_exc())
757
 
758
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ Full Pipeline: Train VQ-VAE → Tokenize Data → Train LLM → Push to EeshaAI/zeeb
4
+ ================================================================================
5
  Runs on HuggingFace Spaces (free CPU tier, 16GB RAM).
6
+ Uses /data/ persistent volume for checkpoints (survives Space restarts).
7
 
8
+ Phase 1: Train VQ-VAE on real images (COCO/imagenette, streaming)
9
+ Phase 2: Tokenize image-text pairs through trained VQ-VAE
10
+ Phase 3: Fine-tune OLMo 2 1B with LoRA on tokenized data
11
  Phase 4: Push trained model to EeshaAI/zeeb
12
  """
13
 
 
18
  import gc
19
  import threading
20
  import traceback
21
+ import hashlib
22
  import numpy as np
23
+ from typing import Optional, List, Dict, Any
24
 
25
  import torch
26
  import torch.nn as nn
 
40
  VIDEO_END = "<video_end>"
41
  VIDEO_PAD = "<video_pad>"
42
 
43
+ # Persistent storage
44
+ DATA_DIR = os.environ.get("DATA_DIR", "/data")
45
+ PERSIST_DIR = os.path.join(DATA_DIR, "zeeb_checkpoints")
46
+ os.makedirs(PERSIST_DIR, exist_ok=True)
47
+
48
  # VQ-VAE training
49
  VQ_VAE_EPOCHS = 5
50
+ VQ_VAE_LR = 3e-4
51
+ VQ_VAE_BATCH = 8
52
+ VQ_VAE_IMG_SIZE = 128
53
+ VQ_VAE_MAX_IMAGES = 10000 # Train on 10K real images
54
 
55
+ # Tokenization
56
+ TOKENS_PER_SAMPLE = 64 # 8x8 grid
57
+ NUM_TOKENIZE_SAMPLES = 10000 # Tokenize 10K image-text pairs
58
 
59
  # LLM training
60
+ NUM_EPOCHS = 2
61
  LORA_R = 4
62
  LORA_ALPHA = 8
63
  LORA_DROPOUT = 0.05
64
+ LEARNING_RATE = 5e-5
65
  BATCH_SIZE = 1
66
+ MAX_SEQ_LEN = 256
67
+ GRADIENT_ACCUMULATION = 8
68
+ LLM_TRAIN_SAMPLES = 5000 # Train on 5K samples (feasible on CPU)
69
+ SAVE_EVERY = 500 # Save checkpoint every N steps
70
 
71
+ LOG_FILE = os.path.join(DATA_DIR, "pipeline_log.txt")
72
+ STATE_FILE = os.path.join(PERSIST_DIR, "pipeline_state.json")
73
 
74
 
75
  # ============================================================================
 
80
  self.path = path
81
  self.lock = threading.Lock()
82
  with open(path, "w") as f:
83
+ f.write("Zeeb Full Pipeline Starting...\n\n")
84
 
85
  def log(self, msg):
86
+ timestamp = time.strftime("%H:%M:%S")
87
+ line = f"[{timestamp}] {msg}"
88
  with self.lock:
89
+ try:
90
+ with open(self.path, "a") as f:
91
+ f.write(line)
92
+ f.flush()
93
+ except:
94
+ pass
95
+ print(line, end="", flush=True)
96
+
97
+
98
+ # ============================================================================
99
+ # PIPELINE STATE (for resume after restart)
100
+ # ============================================================================
101
+ class PipelineState:
102
+ """Track pipeline progress so we can resume after Space restarts."""
103
+
104
+ def __init__(self):
105
+ self.state = {
106
+ "phase": 0, # 0=not started, 1=vq_vae, 2=tokenize, 3=llm, 4=done
107
+ "vq_vae_done": False,
108
+ "vq_vae_epoch": 0,
109
+ "vq_vae_batch": 0,
110
+ "tokenize_done": False,
111
+ "tokenize_count": 0,
112
+ "llm_done": False,
113
+ "llm_step": 0,
114
+ "llm_epoch": 0,
115
+ "pushed": False,
116
+ }
117
+ self.load()
118
+
119
+ def load(self):
120
+ if os.path.exists(STATE_FILE):
121
+ try:
122
+ with open(STATE_FILE) as f:
123
+ saved = json.load(f)
124
+ self.state.update(saved)
125
+ except:
126
+ pass
127
+
128
+ def save(self):
129
+ try:
130
+ with open(STATE_FILE, "w") as f:
131
+ json.dump(self.state, f, indent=2)
132
+ except:
133
+ pass
134
+
135
+ def update(self, **kwargs):
136
+ self.state.update(kwargs)
137
+ self.save()
138
+
139
+ @property
140
+ def phase(self):
141
+ return self.state.get("phase", 0)
142
+
143
+ def is_done(self, phase_name):
144
+ return self.state.get(f"{phase_name}_done", False)
145
 
146
 
147
  # ============================================================================
148
+ # VQ-VAE MODEL (same architecture as in generation code)
149
  # ============================================================================
150
  class Encoder(nn.Module):
151
  def __init__(self, in_channels=3, latent_dim=LATENT_DIM):
152
  super().__init__()
153
  self.net = nn.Sequential(
154
+ nn.Conv2d(in_channels, 64, 4, stride=2, padding=1), # -> 64x64
155
  nn.ReLU(),
156
+ nn.Conv2d(64, 128, 4, stride=2, padding=1), # -> 32x32
157
  nn.ReLU(),
158
+ nn.Conv2d(128, 256, 4, stride=2, padding=1), # -> 16x16
159
  nn.ReLU(),
160
+ nn.Conv2d(256, latent_dim, 4, stride=2, padding=1), # -> 8x8
161
  )
162
 
163
  def forward(self, x):
 
199
  def __init__(self, out_channels=3, latent_dim=LATENT_DIM):
200
  super().__init__()
201
  self.net = nn.Sequential(
202
+ nn.ConvTranspose2d(latent_dim, 256, 4, stride=2, padding=1), # -> 16x16
203
  nn.ReLU(),
204
+ nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), # -> 32x32
205
  nn.ReLU(),
206
+ nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), # -> 64x64
207
  nn.ReLU(),
208
+ nn.ConvTranspose2d(64, out_channels, 4, stride=2, padding=1), # -> 128x128
209
  nn.Sigmoid(),
210
  )
211
 
 
255
 
256
 
257
  # ============================================================================
258
+ # DATASET HELPERS
259
  # ============================================================================
 
 
 
 
 
260
 
261
+ # Imagenette class names for generating captions
262
+ IMAGENETTE_CLASSES = {
263
+ 0: "a fish in water",
264
+ 1: "a dog running in a field",
265
+ 2: "a cassette player on a table",
266
+ 3: "a chainsaw cutting wood",
267
+ 4: "a church with a tall steeple",
268
+ 5: "a French horn on stage",
269
+ 6: "a garbage truck on the street",
270
+ 7: "a gas station at night",
271
+ 8: "a golf ball on a green",
272
+ 9: "a parachute in the sky",
273
+ }
274
+
275
+ CIFAR10_CLASSES = ["airplane flying", "automobile on road", "bird in tree",
276
+ "cat sitting", "deer in forest", "dog playing", "frog on lily pad",
277
+ "horse running", "ship on ocean", "truck driving"]
278
+
279
+
280
+ def load_image_dataset(logger: Logger):
281
+ """Load an image dataset for VQ-VAE training. Returns (stream, image_key, caption_key, name)."""
282
  from datasets import load_dataset
283
+
284
+ # Try datasets with both images and good captions
 
 
 
 
 
285
  dataset_sources = [
286
+ # (dataset_name, split, image_key, caption_key, description)
287
+ ("detection-datasets/coco", "train", "image", "caption", "COCO 2017 (detection)"),
288
+ ("frgfm/imagenette", "train", "image", "label", "Imagenette (10 classes)"),
289
+ ("cifar10", "train", "img", "label", "CIFAR-10"),
 
290
  ]
291
+
292
+ for ds_name, ds_split, img_key, cap_key, desc in dataset_sources:
293
  try:
294
+ logger.log(f" Trying {ds_name} ({desc})...\n")
295
+ ds = load_dataset(ds_name, split=ds_split, streaming=True, trust_remote_code=True)
296
+ test_item = next(iter(ds))
297
+
298
+ # Verify keys exist
299
+ actual_img_key = img_key if img_key in test_item else None
300
+ actual_cap_key = cap_key if cap_key in test_item else None
301
+
302
+ if actual_img_key is None:
303
+ # Try common alternatives
304
+ for k in ["image", "img", "png", "jpg"]:
305
+ if k in test_item:
306
+ actual_img_key = k
307
+ break
308
+
309
+ if actual_img_key is None:
310
+ logger.log(f" No image key found in {ds_name}. Keys: {list(test_item.keys())}\n")
311
+ continue
312
+
313
+ logger.log(f" Using {ds_name}! img_key='{actual_img_key}', cap_key='{actual_cap_key}'\n")
314
+ return ds, actual_img_key, actual_cap_key, ds_name
315
+
316
  except Exception as e:
317
+ logger.log(f" Failed: {str(e)[:100]}\n")
318
+ continue
319
+
320
+ return None, None, None, None
321
+
322
+
323
+ def get_caption(item, cap_key, ds_name, index):
324
+ """Extract or generate a caption for a dataset item."""
325
+ if cap_key and cap_key in item and item[cap_key] is not None:
326
+ cap = item[cap_key]
327
+ if isinstance(cap, list):
328
+ return cap[0] if cap else f"image {index}"
329
+ elif isinstance(cap, str):
330
+ return cap
331
+ elif isinstance(cap, int):
332
+ # Class label - convert to descriptive caption
333
+ if "imagenette" in ds_name.lower():
334
+ return IMAGENETTE_CLASSES.get(cap, f"photo of object {cap}")
335
+ elif "cifar" in ds_name.lower():
336
+ return CIFAR10_CLASSES[cap] if cap < len(CIFAR10_CLASSES) else f"photo of class {cap}"
337
+ else:
338
+ return f"photo of a {cap}"
339
+ return f"image {index}"
340
+
341
+
342
+ # ============================================================================
343
+ # PHASE 1: TRAIN VQ-VAE ON REAL IMAGES
344
+ # ============================================================================
345
+ def train_vq_vae(logger: Logger, state: PipelineState) -> VQVAE:
346
+ """Train VQ-VAE on real images with checkpoint/resume support."""
347
+ logger.log("=" * 60 + "\n")
348
+ logger.log("PHASE 1: Training VQ-VAE on real images\n")
349
+ logger.log("=" * 60 + "\n\n")
350
+
351
+ from datasets import load_dataset
352
+ from torchvision import transforms
353
+ from PIL import Image
354
+
355
+ # Check if already done
356
+ if state.is_done("vq_vae"):
357
+ logger.log("VQ-VAE already trained! Loading checkpoint...\n")
358
+ ckpt_path = os.path.join(PERSIST_DIR, "vq_vae_best.pt")
359
+ if os.path.exists(ckpt_path):
360
+ model = VQVAE()
361
+ model.load_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=False))
362
+ logger.log("Loaded trained VQ-VAE from checkpoint.\n")
363
+ return model
364
+ else:
365
+ logger.log("Checkpoint not found, retraining...\n")
366
+ state.update(vq_vae_done=False)
367
 
368
+ # Load dataset
369
+ ds, img_key, cap_key, ds_name = load_image_dataset(logger)
370
+ if ds is None:
371
+ logger.log("No dataset available! Cannot train VQ-VAE.\n")
372
  return None
373
 
374
  # Image transforms
 
378
  ])
379
 
380
  class ImageStreamDataset(IterableDataset):
381
+ def __init__(self, hf_dataset, transform, img_key, max_samples):
382
  self.dataset = hf_dataset
383
  self.transform = transform
384
  self.img_key = img_key
 
394
  if img.mode != "RGB":
395
  img = img.convert("RGB")
396
  tensor = self.transform(img)
 
397
  count += 1
398
+ yield tensor
399
  except Exception:
400
  continue
401
 
402
+ dataset = ImageStreamDataset(ds, transform, img_key, VQ_VAE_MAX_IMAGES)
403
  dataloader = DataLoader(dataset, batch_size=VQ_VAE_BATCH, num_workers=0)
404
 
405
+ # Initialize or resume model
406
  model = VQVAE()
407
  n_params = sum(p.numel() for p in model.parameters()) / 1e6
408
+ logger.log(f"VQ-VAE initialized: {n_params:.1f}M parameters\n")
409
+
410
+ # Resume from checkpoint if available
411
+ resume_ckpt = os.path.join(PERSIST_DIR, "vq_vae_latest.pt")
412
+ start_epoch = 0
413
+ if os.path.exists(resume_ckpt):
414
+ try:
415
+ ckpt = torch.load(resume_ckpt, map_location="cpu", weights_only=False)
416
+ model.load_state_dict(ckpt["model_state_dict"])
417
+ start_epoch = ckpt.get("epoch", 0)
418
+ logger.log(f"Resumed VQ-VAE from epoch {start_epoch}\n")
419
+ except:
420
+ logger.log("Could not resume checkpoint, starting fresh.\n")
421
 
422
  optimizer = torch.optim.Adam(model.parameters(), lr=VQ_VAE_LR)
423
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=VQ_VAE_EPOCHS)
424
  model.train()
425
 
426
+ best_loss = float('inf')
427
+
428
+ for epoch in range(start_epoch, VQ_VAE_EPOCHS):
429
  epoch_loss = 0.0
430
  epoch_recon = 0.0
431
  epoch_vq = 0.0
 
439
 
440
  optimizer.zero_grad()
441
  loss.backward()
442
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
443
  optimizer.step()
444
 
445
  epoch_loss += loss.item()
 
447
  epoch_vq += vq_loss.item()
448
  num_batches += 1
449
 
450
+ if batch_idx % 100 == 0 and batch_idx > 0:
451
  avg = epoch_loss / num_batches
452
  avg_r = epoch_recon / num_batches
453
  avg_v = epoch_vq / num_batches
454
  logger.log(f" Epoch {epoch+1}/{VQ_VAE_EPOCHS} | Batch {batch_idx} | "
455
  f"Loss: {avg:.4f} (recon: {avg_r:.4f}, vq: {avg_v:.4f})\n")
456
 
457
+ del recon, vq_loss, loss, batch
458
+ if batch_idx % 100 == 0:
459
  gc.collect()
460
 
461
+ # End of epoch
462
+ scheduler.step()
463
  elapsed = time.time() - start_time
464
  avg_loss = epoch_loss / max(num_batches, 1)
465
+ avg_recon = epoch_recon / max(num_batches, 1)
466
+ logger.log(f"\nEpoch {epoch+1} done. Loss: {avg_loss:.4f} (recon: {avg_recon:.4f}) | "
467
  f"Batches: {num_batches} | Time: {elapsed:.0f}s\n\n")
468
+
469
+ # Save checkpoint
470
+ ckpt_path = os.path.join(PERSIST_DIR, "vq_vae_latest.pt")
471
+ torch.save({
472
+ "epoch": epoch + 1,
473
+ "model_state_dict": model.state_dict(),
474
+ "optimizer_state_dict": optimizer.state_dict(),
475
+ "loss": avg_loss,
476
+ }, ckpt_path)
477
+
478
+ # Save best model
479
+ if avg_loss < best_loss:
480
+ best_loss = avg_loss
481
+ best_path = os.path.join(PERSIST_DIR, "vq_vae_best.pt")
482
+ torch.save(model.state_dict(), best_path)
483
+ logger.log(f" New best model! Loss: {avg_loss:.4f}\n")
484
+
485
+ state.update(vq_vae_epoch=epoch + 1)
486
+ gc.collect()
487
 
488
+ # Save final
489
+ final_path = os.path.join(PERSIST_DIR, "vq_vae_best.pt")
490
+ if not os.path.exists(final_path):
491
+ torch.save(model.state_dict(), final_path)
492
+
493
+ # Also save to root for easy access
494
  torch.save(model.state_dict(), "vq_vae_real.pt")
495
+
496
+ state.update(vq_vae_done=True, phase=2)
497
+ logger.log(f"VQ-VAE training complete! Best loss: {best_loss:.4f}\n\n")
498
  return model
499
 
500
 
501
  # ============================================================================
502
+ # PHASE 2: TOKENIZE IMAGE-TEXT PAIRS
503
  # ============================================================================
504
+ def tokenize_dataset(logger: Logger, state: PipelineState, vq_vae: Optional[VQVAE] = None):
505
+ """Tokenize image-text pairs through trained VQ-VAE."""
506
  logger.log("=" * 60 + "\n")
507
+ logger.log("PHASE 2: Tokenizing image-text pairs\n")
508
  logger.log("=" * 60 + "\n\n")
509
 
510
+ if state.is_done("tokenize"):
511
+ logger.log("Tokenization already done! Loading cached data...\n")
512
+ data_path = os.path.join(PERSIST_DIR, "tokenized_dataset.json")
513
+ if os.path.exists(data_path):
514
+ with open(data_path) as f:
515
+ data = json.load(f)
516
+ logger.log(f"Loaded {len(data)} tokenized samples.\n")
517
+ return data
518
+ else:
519
+ logger.log("Cached data not found, re-tokenizing...\n")
520
+ state.update(tokenize_done=False)
521
+
522
  # Load VQ-VAE if not provided
523
  if vq_vae is None:
524
+ ckpt_path = os.path.join(PERSIST_DIR, "vq_vae_best.pt")
525
+ if os.path.exists(ckpt_path):
526
+ vq_vae = VQVAE()
527
+ vq_vae.load_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=False))
528
+ logger.log("Loaded trained VQ-VAE for tokenization.\n")
529
+ elif os.path.exists("vq_vae_real.pt"):
530
  vq_vae = VQVAE()
531
  vq_vae.load_state_dict(torch.load("vq_vae_real.pt", map_location="cpu", weights_only=False))
532
+ logger.log("Loaded VQ-VAE from vq_vae_real.pt.\n")
533
  else:
534
+ logger.log("No trained VQ-VAE found! Run Phase 1 first.\n")
535
  return None
536
 
537
  vq_vae.eval()
538
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
539
  from datasets import load_dataset
540
  from torchvision import transforms
541
  from PIL import Image
542
 
543
+ # Load dataset with captions
544
+ ds, img_key, cap_key, ds_name = load_image_dataset(logger)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
545
  if ds is None:
546
+ logger.log("No dataset available for tokenization!\n")
547
  return None
548
 
549
  transform = transforms.Compose([
 
551
  transforms.ToTensor(),
552
  ])
553
 
 
554
  tokenized_data = []
555
  count = 0
556
+ errors = 0
557
+
558
+ # Check for partial tokenization (resume support)
559
+ partial_path = os.path.join(PERSIST_DIR, "tokenized_partial.json")
560
+ if os.path.exists(partial_path):
561
+ try:
562
+ with open(partial_path) as f:
563
+ tokenized_data = json.load(f)
564
+ count = len(tokenized_data)
565
+ logger.log(f"Resuming tokenization from {count} samples.\n")
566
+ except:
567
+ tokenized_data = []
568
+ count = 0
569
 
570
+ logger.log(f"Tokenizing up to {NUM_TOKENIZE_SAMPLES} images...\n")
 
 
 
571
 
572
  for item in ds:
573
+ if count >= NUM_TOKENIZE_SAMPLES:
574
  break
575
 
576
  try:
577
+ img = item[img_key]
578
  if img.mode != "RGB":
579
  img = img.convert("RGB")
580
 
581
+ caption = get_caption(item, cap_key, ds_name, count)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
582
 
583
  img_tensor = transform(img).unsqueeze(0)
584
  with torch.no_grad():
585
  tokens = vq_vae.encode(img_tensor)
586
  flat_tokens = tokens.flatten().tolist()
587
 
588
+ # Truncate/pad to fixed length
589
+ flat_tokens = flat_tokens[:TOKENS_PER_SAMPLE]
590
+ while len(flat_tokens) < TOKENS_PER_SAMPLE:
591
  flat_tokens.append(0)
592
 
593
  tokenized_data.append({
 
596
  })
597
 
598
  count += 1
599
+
600
+ if count % 500 == 0:
601
+ logger.log(f" Tokenized {count}/{NUM_TOKENIZE_SAMPLES} images (errors: {errors})\n")
602
+ # Save partial progress
603
+ with open(partial_path, "w") as f:
604
+ json.dump(tokenized_data, f)
605
+
606
+ del img_tensor, tokens
607
+ if count % 200 == 0:
608
+ gc.collect()
609
+
610
+ except Exception as e:
611
+ errors += 1
612
+ if errors <= 5:
613
+ logger.log(f" Error on item {count}: {str(e)[:80]}\n")
614
  continue
615
 
616
+ if not tokenized_data:
617
+ logger.log("No images tokenized!\n")
618
+ return None
619
+
620
+ # Save final
621
+ data_path = os.path.join(PERSIST_DIR, "tokenized_dataset.json")
622
+ with open(data_path, "w") as f:
623
+ json.dump(tokenized_data, f)
624
+
625
+ # Also save to root
626
  with open("tokenized_dataset.json", "w") as f:
627
  json.dump(tokenized_data, f)
628
 
629
+ # Clean up partial
630
+ if os.path.exists(partial_path):
631
+ os.remove(partial_path)
632
+
633
+ state.update(tokenize_done=True, tokenize_count=len(tokenized_data), phase=3)
634
+ logger.log(f"\nTokenized {len(tokenized_data)} images saved (errors: {errors})\n\n")
635
  return tokenized_data
636
 
637
 
638
  # ============================================================================
639
  # PHASE 3: TRAIN LLM WITH LORA
640
  # ============================================================================
641
+ def train_llm(logger: Logger, state: PipelineState):
642
  """Fine-tune OLMo 2 1B with LoRA on tokenized data."""
643
  logger.log("=" * 60 + "\n")
644
+ logger.log("PHASE 3: Fine-tuning OLMo 2 1B + LoRA on real data\n")
645
  logger.log("=" * 60 + "\n\n")
646
 
647
+ if state.is_done("llm"):
648
+ logger.log("LLM already trained! Skipping.\n")
649
+ return
650
+
651
  from transformers import AutoModelForCausalLM, AutoTokenizer
652
  from peft import LoraConfig, get_peft_model, TaskType
653
 
654
  # Load data
655
+ data_path = os.path.join(PERSIST_DIR, "tokenized_dataset.json")
656
  if not os.path.exists(data_path):
657
+ data_path = "tokenized_dataset.json"
658
+
659
+ if not os.path.exists(data_path):
660
+ logger.log("No tokenized dataset found! Run Phase 2 first.\n")
661
  return
662
 
663
  with open(data_path) as f:
664
+ all_data = json.load(f)
665
+
666
+ # Limit to training samples
667
+ data = all_data[:LLM_TRAIN_SAMPLES]
668
+ logger.log(f"Loaded {len(all_data)} total samples, using {len(data)} for training\n")
669
+
670
+ # Quick data quality check
671
+ if data:
672
+ sample = data[0]
673
+ logger.log(f"Sample prompt: '{sample['text_prompt']}'\n")
674
+ logger.log(f"Sample tokens (first 10): {sample['video_tokens'][:10]}\n")
675
+ unique_tokens = len(set(sample['video_tokens']))
676
+ logger.log(f"Unique tokens in sample: {unique_tokens}\n\n")
677
 
678
  # Tokenizer
679
+ logger.log("Loading OLMo 2 1B tokenizer...\n")
680
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
681
  if tokenizer.pad_token is None:
682
  tokenizer.pad_token = tokenizer.eos_token
683
 
684
  # Model
685
+ logger.log("Loading model (fp32, CPU)...\n")
686
  model = AutoModelForCausalLM.from_pretrained(
687
  MODEL_NAME, trust_remote_code=True, torch_dtype=torch.float32
688
  )
689
+ orig_vocab = len(tokenizer)
690
+ logger.log(f"Model loaded. Original vocab: {orig_vocab}\n")
691
 
692
  # Expand vocab
693
+ logger.log(f"Adding {CODEBOOK_SIZE} visual tokens...\n")
694
  visual_tokens = [VIDEO_START, VIDEO_END, VIDEO_PAD]
695
  for i in range(CODEBOOK_SIZE):
696
  visual_tokens.append(f"<v_{i}>")
697
  tokenizer.add_tokens(visual_tokens)
698
  model.resize_token_embeddings(len(tokenizer))
699
+ logger.log(f"New vocab: {len(tokenizer)}\n")
700
 
701
  # LoRA
702
+ logger.log(f"Applying LoRA (r={LORA_R})...\n")
703
  lora_config = LoraConfig(
704
  r=LORA_R, lora_alpha=LORA_ALPHA,
705
  target_modules=["q_proj", "v_proj"],
 
709
  model = get_peft_model(model, lora_config)
710
  trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
711
  total = sum(p.numel() for p in model.parameters())
712
+ logger.log(f"LoRA: {trainable:,} / {total:,} trainable ({100*trainable/total:.2f}%)\n")
713
 
714
  # Dataset
715
  class VideoTokenDataset(Dataset):
716
+ def __init__(self, data, max_tokens=TOKENS_PER_SAMPLE):
717
  self.data = data
718
  self.max_tokens = max_tokens
719
 
 
731
  dataset = VideoTokenDataset(data)
732
  dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
733
  total_steps = NUM_EPOCHS * len(dataloader)
734
+ logger.log(f"{len(dataset)} samples x {NUM_EPOCHS} epochs = {total_steps} steps\n\n")
735
+
736
+ # Optimizer - Adafactor is more memory-efficient for CPU
737
+ from transformers import Adafactor
738
+ optimizer = Adafactor(
739
+ model.parameters(), lr=LEARNING_RATE,
740
+ relative_step=False, scale_parameter=False, warmup_init=False
741
+ )
742
+
743
+ # Resume from checkpoint if available
744
+ start_step = state.state.get("llm_step", 0)
745
+ start_epoch = state.state.get("llm_epoch", 0)
746
+
747
+ llm_ckpt_dir = os.path.join(PERSIST_DIR, "llm_checkpoint")
748
+ if start_step > 0 and os.path.exists(llm_ckpt_dir):
749
+ try:
750
+ logger.log(f"Resuming LLM training from step {start_step}, epoch {start_epoch}\n")
751
+ # We'd need to skip dataloader steps - for simplicity, restart epoch
752
+ start_step = 0
753
+ except:
754
+ pass
755
 
 
 
756
  model.train()
757
  global_step = 0
758
  running_loss = 0.0
759
+ best_loss = float('inf')
760
  start_time = time.time()
761
 
762
  for epoch in range(NUM_EPOCHS):
 
767
  prompt = batch["prompt"][0]
768
  video_tokens = batch["video_tokens"][0]
769
 
770
+ # Format training text
771
+ token_str = " ".join(f"<v_{t.item()}>" for t in video_tokens)
772
  text = f"Create a video of: {prompt} {VIDEO_START} {token_str} {VIDEO_END}"
773
 
774
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN, padding="max_length")
 
790
  if batch_idx % 100 == 0:
791
  elapsed = time.time() - start_time
792
  speed = global_step / elapsed if elapsed > 0 else 0
793
+ eta = (total_steps - global_step) / speed if speed > 0 else 0
794
  logger.log(f" Epoch {epoch+1}/{NUM_EPOCHS} | Step {batch_idx+1}/{len(dataloader)} | "
795
  f"Loss: {batch_loss:.4f} | Avg: {epoch_loss/num_batches:.4f} | "
796
+ f"Speed: {speed:.2f} steps/s | ETA: {eta/60:.0f}m\n")
797
 
798
+ # Save checkpoint periodically
799
+ if global_step % SAVE_EVERY == 0 and global_step > 0:
800
+ ckpt_loss = running_loss / global_step
801
+ logger.log(f" Saving checkpoint at step {global_step} (loss: {ckpt_loss:.4f})...\n")
802
+ try:
803
+ os.makedirs(llm_ckpt_dir, exist_ok=True)
804
+ model.save_pretrained(llm_ckpt_dir)
805
+ tokenizer.save_pretrained(llm_ckpt_dir)
806
+ state.update(llm_step=global_step, llm_epoch=epoch)
807
+ except Exception as e:
808
+ logger.log(f" Checkpoint save failed: {str(e)[:80]}\n")
809
+
810
+ del outputs, loss, inputs
811
+ if batch_idx % 50 == 0:
812
+ gc.collect()
813
 
814
+ avg_epoch_loss = epoch_loss / max(num_batches, 1)
815
+ logger.log(f"\nEpoch {epoch+1} done. Avg Loss: {avg_epoch_loss:.4f}\n\n")
816
+
817
+ # Save best model
818
+ if avg_epoch_loss < best_loss:
819
+ best_loss = avg_epoch_loss
820
+
821
+ state.update(llm_epoch=epoch + 1)
822
 
823
  total_time = time.time() - start_time
824
+ final_loss = running_loss / max(global_step, 1)
825
+ logger.log(f"Training complete in {total_time:.0f}s ({total_time/60:.1f} min)\n")
826
+ logger.log(f"Final avg loss: {final_loss:.4f}\n\n")
827
 
828
  # Merge & save
829
+ logger.log("Merging LoRA into base model...\n")
830
  model = model.merge_and_unload()
831
 
832
+ save_dir = os.path.join(PERSIST_DIR, "trained_model")
833
+ os.makedirs(save_dir, exist_ok=True)
834
  model.save_pretrained(save_dir, safe_serialization=True)
835
  tokenizer.save_pretrained(save_dir)
836
 
837
+ # Also save VQ-VAE checkpoint
838
+ vq_path = os.path.join(PERSIST_DIR, "vq_vae_best.pt")
839
+ if os.path.exists(vq_path):
840
  import shutil
841
+ shutil.copy(vq_path, os.path.join(save_dir, "vq_vae_final.pt"))
842
+ elif os.path.exists("vq_vae_real.pt"):
843
+ import shutil
844
+ shutil.copy("vq_vae_real.pt", os.path.join(save_dir, "vq_vae_final.pt"))
845
 
846
  # Copy tokenized dataset
847
+ if os.path.exists(os.path.join(PERSIST_DIR, "tokenized_dataset.json")):
848
  import shutil
849
+ shutil.copy(os.path.join(PERSIST_DIR, "tokenized_dataset.json"),
850
+ os.path.join(save_dir, "tokenized_dataset.json"))
851
 
852
+ logger.log("Model saved locally.\n")
853
 
854
+ # Push to Hub
855
+ if HF_TOKEN:
856
+ logger.log(f"Pushing to {REPO_ID}...\n")
857
+ try:
858
+ from huggingface_hub import HfApi
859
+ api = HfApi(token=HF_TOKEN)
860
+ try:
861
+ api.create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True)
862
+ except:
863
+ pass
864
+ api.upload_folder(
865
+ folder_path=save_dir, repo_id=REPO_ID, repo_type="model",
866
+ commit_message=f"LoRA OLMo 2 1B (r={LORA_R}, {NUM_EPOCHS} epochs, {len(data)} real samples, loss={final_loss:.4f})"
867
+ )
868
+ logger.log(f"Pushed to https://huggingface.co/{REPO_ID}\n\n")
869
+ state.update(pushed=True)
870
+ except Exception as e:
871
+ logger.log(f"Push failed: {str(e)[:200]}\n")
872
+ logger.log("Model is saved locally and can be pushed manually.\n\n")
873
+ else:
874
+ logger.log("No HF_TOKEN set, skipping push.\n")
875
+
876
+ state.update(llm_done=True, phase=4)
877
 
878
 
879
  # ============================================================================
880
  # MAIN PIPELINE
881
  # ============================================================================
882
+ def run_pipeline(log_path: str = None):
883
+ if log_path is None:
884
+ log_path = LOG_FILE
885
+
886
  logger = Logger(log_path)
887
+ state = PipelineState()
888
+
889
+ logger.log(f"Pipeline state: Phase {state.phase}\n")
890
+ logger.log(f"Persistent dir: {PERSIST_DIR}\n")
891
+ logger.log(f"Data dir contents: {os.listdir(PERSIST_DIR) if os.path.exists(PERSIST_DIR) else 'empty'}\n\n")
892
 
893
  try:
894
  # Phase 1: Train VQ-VAE
895
+ if not state.is_done("vq_vae"):
896
+ state.update(phase=1)
897
+ vq_vae = train_vq_vae(logger, state)
898
+ else:
899
+ logger.log("Skipping Phase 1 (already done)\n")
900
+ vq_vae = None
901
+
902
  gc.collect()
903
 
904
  # Phase 2: Tokenize dataset
905
+ if not state.is_done("tokenize"):
906
+ state.update(phase=2)
907
+ tokenize_dataset(logger, state, vq_vae)
908
+ else:
909
+ logger.log("Skipping Phase 2 (already done)\n")
910
+
911
  gc.collect()
912
 
913
  # Phase 3: Train LLM
914
+ if not state.is_done("llm"):
915
+ state.update(phase=3)
916
+ train_llm(logger, state)
917
+ else:
918
+ logger.log("Skipping Phase 3 (already done)\n")
919
 
920
+ logger.log("\n" + "=" * 60 + "\n")
921
+ logger.log("FULL PIPELINE COMPLETE!\n")
922
+ logger.log("=" * 60 + "\n")
923
+ state.update(phase=4)
924
+
925
  except Exception as e:
926
+ logger.log(f"\nPIPELINE ERROR: {e}\n")
927
  logger.log(traceback.format_exc())
928
 
929