{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": [],
   "gpuType": "T4",
   "name": "Zeeb_Video_LLM_Training.ipynb"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  },
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 🎬 Zeeb: Video-LLM Training on T4 GPU\n",
    "\n",
    "**OLMo 2 1B + LoRA + VQ-VAE → Text-to-Video Generation**\n",
    "\n",
    "This notebook trains the full pipeline on a **Google Colab T4 GPU** and pushes checkpoints to HuggingFace incrementally.\n",
    "\n",
    "## Pipeline Overview\n",
    "1. **Phase 1**: Train VQ-VAE on real images (COCO, streaming)\n",
    "2. **Phase 2**: Tokenize image-text pairs through the trained VQ-VAE\n",
    "3. **Phase 3**: Fine-tune OLMo 2 1B + LoRA on the tokenized data → push to EeshaAI/zeeb\n",
    "\n",
    "## Key Features\n",
    "- ✅ **Incremental checkpoint pushing** to HuggingFace (pushed weights survive Colab disconnects)\n",
    "- ✅ **Resume from the latest local checkpoint** if training is interrupted\n",
    "- ✅ **HuggingFace Trainer** with `push_to_hub=True` and `save_strategy=\"steps\"`\n",
    "- ✅ **Real data** from COCO, with imagenette and CIFAR-10 as fallbacks (up to 20K images for the VQ-VAE by default)\n",
    "- ✅ **GPU-accelerated** training on a T4\n",
    "\n",
    "**Make sure you select a GPU runtime**: Runtime → Change runtime type → T4 GPU"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ⚙️ Cell 1: Setup & Authentication"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title 1. Install Dependencies\n",
    "# imageio-ffmpeg provides the ffmpeg backend imageio needs to write the .mp4 in cell 11\n",
    "!pip install -q torch torchvision transformers peft accelerate datasets huggingface_hub safetensors imageio imageio-ffmpeg Pillow\n",
    "\n",
    "import torch\n",
    "print(f\"PyTorch: {torch.__version__}\")\n",
    "print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
    "if torch.cuda.is_available():\n",
    "    print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
    "    print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n",
    "else:\n",
    "    raise RuntimeError(\"No GPU detected! Go to Runtime → Change runtime type → T4 GPU\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title 2. HuggingFace Authentication\n",
    "from huggingface_hub import HfApi, login\n",
    "import os\n",
    "\n",
    "# 🔑 Paste your HuggingFace token here (must have write access to EeshaAI/zeeb)\n",
    "HF_TOKEN = \"YOUR_HF_TOKEN_HERE\"  # @param {type:\"string\"}\n",
    "\n",
    "login(token=HF_TOKEN)\n",
    "\n",
    "api = HfApi()\n",
    "user_info = api.whoami()\n",
    "print(f\"Logged in as: {user_info['name']}\")\n",
    "\n",
    "REPO_ID = \"EeshaAI/zeeb\"\n",
    "api.create_repo(repo_id=REPO_ID, repo_type=\"model\", exist_ok=True)\n",
    "print(f\"Model repo: https://huggingface.co/{REPO_ID}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🧠 Cell 2: VQ-VAE Model Definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title 3. VQ-VAE Architecture\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "CODEBOOK_SIZE = 1024\n",
    "CODEBOOK_DIM = 256\n",
    "LATENT_DIM = 256\n",
    "\n",
    "class Encoder(nn.Module):\n",
    "    def __init__(self, in_channels=3, latent_dim=LATENT_DIM):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Conv2d(in_channels, 64, 4, stride=2, padding=1),  # -> 64x64\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # -> 32x32\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(128, 256, 4, stride=2, padding=1),  # -> 16x16\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(256, latent_dim, 4, stride=2, padding=1),  # -> 8x8\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "\n",
    "\n",
    "class VectorQuantizer(nn.Module):\n",
    "    def __init__(self, codebook_size=CODEBOOK_SIZE, codebook_dim=CODEBOOK_DIM, commitment_cost=0.25):\n",
    "        super().__init__()\n",
    "        self.codebook_size = codebook_size\n",
    "        self.codebook_dim = codebook_dim\n",
    "        self.commitment_cost = commitment_cost\n",
    "        self.codebook = nn.Embedding(codebook_size, codebook_dim)\n",
    "        self.codebook.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)\n",
    "\n",
    "    def forward(self, z):\n",
    "        B, H, W, C = z.shape\n",
    "        z_flat = z.reshape(-1, C)\n",
    "        # Squared L2 distance to every codebook entry, expanded as |z|^2 - 2 z.e + |e|^2\n",
    "        # (avoids materialising an N x K x C tensor); argmin has no gradient anyway.\n",
    "        with torch.no_grad():\n",
    "            codebook = self.codebook.weight\n",
    "            dist = (z_flat.pow(2).sum(1, keepdim=True)\n",
    "                    - 2 * z_flat @ codebook.t()\n",
    "                    + codebook.pow(2).sum(1))\n",
    "        indices = dist.argmin(dim=1)\n",
    "        z_q = self.codebook(indices).reshape(B, H, W, C)\n",
    "        commitment_loss = F.mse_loss(z_flat, z_q.reshape(-1, C).detach())\n",
    "        codebook_loss = F.mse_loss(z_q.reshape(-1, C), z_flat.detach())\n",
    "        loss = codebook_loss + self.commitment_cost * commitment_loss\n",
    "        z_q_st = z + (z_q - z).detach()  # straight-through estimator\n",
    "        return z_q_st, loss, indices.reshape(B, H, W)\n",
    "\n",
    "\n",
    "class Decoder(nn.Module):\n",
    "    def __init__(self, out_channels=3, latent_dim=LATENT_DIM):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.ConvTranspose2d(latent_dim, 256, 4, stride=2, padding=1),  # -> 16x16\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),  # -> 32x32\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # -> 64x64\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(64, out_channels, 4, stride=2, padding=1),  # -> 128x128\n",
    "            nn.Sigmoid(),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "\n",
    "\n",
    "class VQVAE(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.encoder = Encoder()\n",
    "        self.quantizer = VectorQuantizer()\n",
    "        self.proj_in = nn.Linear(LATENT_DIM, CODEBOOK_DIM)\n",
    "        self.proj_out = nn.Linear(CODEBOOK_DIM, LATENT_DIM)\n",
    "        self.decoder = Decoder()\n",
    "\n",
    "    def forward(self, x):\n",
    "        z = self.encoder(x)\n",
    "        z = z.permute(0, 2, 3, 1)\n",
    "        z = self.proj_in(z)\n",
    "        z_q, vq_loss, indices = self.quantizer(z)\n",
    "        z_q = self.proj_out(z_q)\n",
    "        z_q = z_q.permute(0, 3, 1, 2)\n",
    "        recon = self.decoder(z_q)\n",
    "        return recon, vq_loss, indices\n",
    "\n",
    "    def encode(self, x):\n",
    "        z = self.encoder(x)\n",
    "        z = z.permute(0, 2, 3, 1)\n",
    "        z = self.proj_in(z)\n",
    "        _, _, indices = self.quantizer(z)\n",
    "        return indices\n",
    "\n",
    "    def decode_tokens(self, token_ids, grid_h=8, grid_w=8):\n",
    "        # Build the index tensor on the codebook's device (CUDA here), not the CPU default\n",
    "        device = self.quantizer.codebook.weight.device\n",
    "        if isinstance(token_ids, list):\n",
    "            token_ids = torch.tensor(token_ids, dtype=torch.long)\n",
    "        token_ids = token_ids.to(device)[:grid_h * grid_w]\n",
    "        if len(token_ids) < grid_h * grid_w:\n",
    "            pad = torch.zeros(grid_h * grid_w - len(token_ids), dtype=torch.long, device=device)\n",
    "            token_ids = torch.cat([token_ids, pad])\n",
    "        z_q = self.quantizer.codebook(token_ids)\n",
    "        z_q = self.proj_out(z_q)\n",
    "        z_q = z_q.reshape(1, grid_h, grid_w, -1).permute(0, 3, 1, 2)\n",
    "        return self.decoder(z_q)\n",
    "\n",
    "# Test\n",
    "vq_vae = VQVAE().cuda()\n",
    "test_input = torch.randn(2, 3, 128, 128).cuda()\n",
    "recon, vq_loss, indices = vq_vae(test_input)\n",
    "print(f\"VQ-VAE test: input {test_input.shape} -> recon {recon.shape}, indices {indices.shape}, loss {vq_loss.item():.4f}\")\n",
    "n_params = sum(p.numel() for p in vq_vae.parameters()) / 1e6\n",
    "print(f\"Parameters: {n_params:.1f}M\")\n",
    "del vq_vae, test_input\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🖼️ Phase 1: Train VQ-VAE on Real Images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title 4. Phase 1: Train VQ-VAE\n",
    "from datasets import load_dataset\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader, IterableDataset\n",
    "import time\n",
    "\n",
    "# Set VQ_VAE_ALREADY_TRAINED=True to load the VQ-VAE already pushed to HF instead of retraining\n",
    "VQ_VAE_ALREADY_TRAINED = False  # @param {type:\"boolean\"}\n",
    "VQ_VAE_EPOCHS = 5  # @param {type:\"integer\"}\n",
    "VQ_VAE_LR = 3e-4  # @param {type:\"number\"}\n",
    "VQ_VAE_BATCH = 32  # @param {type:\"integer\"}\n",
    "VQ_VAE_MAX_IMAGES = 20000  # @param {type:\"integer\"}\n",
    "VQ_VAE_IMG_SIZE = 128  # @param {type:\"integer\"}\n",
    "\n",
    "if VQ_VAE_ALREADY_TRAINED:\n",
    "    print(\"Skipping VQ-VAE training (already trained)\")\n",
    "    vq_vae = VQVAE().cuda()  # load_state_dict copies into existing params, so move the model first\n",
    "    # Download from HF if available\n",
    "    try:\n",
    "        from huggingface_hub import hf_hub_download\n",
    "        vq_path = hf_hub_download(REPO_ID, \"vq_vae_final.pt\", repo_type=\"model\")\n",
    "        vq_vae.load_state_dict(torch.load(vq_path, map_location=\"cuda\", weights_only=False))\n",
    "        torch.save(vq_vae.state_dict(), \"vq_vae_final.pt\")  # local copy used by cells 10 and 11\n",
    "        print(f\"Loaded VQ-VAE from {REPO_ID}\")\n",
    "    except Exception:\n",
    "        print(\"Could not download VQ-VAE, training from scratch\")\n",
    "        VQ_VAE_ALREADY_TRAINED = False\n",
    "\n",
    "if not VQ_VAE_ALREADY_TRAINED:\n",
    "    # Load dataset\n",
    "    print(\"Loading image dataset...\")\n",
    "    ds = None\n",
    "    image_key = \"image\"\n",
    "    cap_key = None\n",
    "    ds_name = \"\"\n",
    "\n",
    "    for name, split, ik, ck in [\n",
    "        (\"detection-datasets/coco\", \"train\", \"image\", \"caption\"),\n",
    "        (\"frgfm/imagenette\", \"train\", \"image\", \"label\"),\n",
    "        (\"cifar10\", \"train\", \"img\", \"label\"),\n",
    "    ]:\n",
    "        try:\n",
    "            print(f\"  Trying {name}...\")\n",
    "            ds = load_dataset(name, split=split, streaming=True, trust_remote_code=True)\n",
    "            test_item = next(iter(ds))\n",
    "            if ik in test_item:\n",
    "                image_key = ik\n",
    "                cap_key = ck if ck in test_item else None\n",
    "                ds_name = name\n",
    "                print(f\"  Using {name}!\")\n",
    "                break\n",
    "            ds = None\n",
    "        except Exception as e:\n",
    "            print(f\"  Failed: {str(e)[:80]}\")\n",
    "            ds = None\n",
    "\n",
    "    if ds is None:\n",
    "        raise RuntimeError(\"No dataset available!\")\n",
    "\n",
    "    # Transforms\n",
    "    transform = transforms.Compose([\n",
    "        transforms.Resize((VQ_VAE_IMG_SIZE, VQ_VAE_IMG_SIZE)),\n",
    "        transforms.ToTensor(),\n",
    "    ])\n",
    "\n",
    "    class ImageStreamDataset(IterableDataset):\n",
    "        def __init__(self, hf_ds, transform, img_key, max_samples):\n",
    "            self.ds = hf_ds\n",
    "            self.transform = transform\n",
    "            self.img_key = img_key\n",
    "            self.max = max_samples\n",
    "\n",
    "        def __iter__(self):\n",
    "            count = 0\n",
    "            for item in self.ds:\n",
    "                if count >= self.max:\n",
    "                    break\n",
    "                try:\n",
    "                    img = item[self.img_key]\n",
    "                    if img.mode != \"RGB\":\n",
    "                        img = img.convert(\"RGB\")\n",
    "                    tensor = self.transform(img)\n",
    "                except Exception:\n",
    "                    continue  # skip corrupt or undecodable images\n",
    "                # yield outside the try so GeneratorExit is never swallowed\n",
    "                yield tensor\n",
    "                count += 1\n",
    "\n",
    "    dataset = ImageStreamDataset(ds, transform, image_key, VQ_VAE_MAX_IMAGES)\n",
    "    dataloader = DataLoader(dataset, batch_size=VQ_VAE_BATCH, num_workers=2, pin_memory=True)\n",
    "\n",
    "    # Initialize model\n",
    "    vq_vae = VQVAE().cuda()\n",
    "    optimizer = torch.optim.Adam(vq_vae.parameters(), lr=VQ_VAE_LR)\n",
    "    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=VQ_VAE_EPOCHS)\n",
    "\n",
    "    # Training loop\n",
    "    print(f\"\\nTraining VQ-VAE: {VQ_VAE_EPOCHS} epochs, {VQ_VAE_MAX_IMAGES} images, batch {VQ_VAE_BATCH}\")\n",
    "    vq_vae.train()\n",
    "    best_loss = float('inf')\n",
    "\n",
    "    for epoch in range(VQ_VAE_EPOCHS):\n",
    "        epoch_loss = 0.0\n",
    "        epoch_recon = 0.0\n",
    "        epoch_vq = 0.0\n",
    "        n_batches = 0\n",
    "        start = time.time()\n",
    "\n",
    "        for batch_idx, batch in enumerate(dataloader):\n",
    "            batch = batch.cuda()\n",
    "            recon, vq_loss, _ = vq_vae(batch)\n",
    "            recon_loss = F.mse_loss(recon, batch)\n",
    "            loss = recon_loss + vq_loss\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            torch.nn.utils.clip_grad_norm_(vq_vae.parameters(), 1.0)\n",
    "            optimizer.step()\n",
    "\n",
    "            epoch_loss += loss.item()\n",
    "            epoch_recon += recon_loss.item()\n",
    "            epoch_vq += vq_loss.item()\n",
    "            n_batches += 1\n",
    "\n",
    "            if batch_idx % 100 == 0 and batch_idx > 0:\n",
    "                avg = epoch_loss / n_batches\n",
    "                print(f\"  Epoch {epoch+1}/{VQ_VAE_EPOCHS} | Batch {batch_idx} | Loss: {avg:.4f} (recon: {epoch_recon/n_batches:.4f}, vq: {epoch_vq/n_batches:.4f})\")\n",
    "\n",
    "        scheduler.step()\n",
    "        elapsed = time.time() - start\n",
    "        avg_loss = epoch_loss / max(n_batches, 1)\n",
    "        print(f\"\\n  Epoch {epoch+1} done. Loss: {avg_loss:.4f} | Batches: {n_batches} | Time: {elapsed:.0f}s\")\n",
    "\n",
    "        # Keep the best epoch's weights locally\n",
    "        if avg_loss < best_loss:\n",
    "            best_loss = avg_loss\n",
    "            torch.save(vq_vae.state_dict(), \"vq_vae_best.pt\")\n",
    "            print(f\"  New best model! Loss: {avg_loss:.4f}\")\n",
    "\n",
    "        # Push VQ-VAE checkpoint to HF after each epoch\n",
    "        torch.save(vq_vae.state_dict(), \"vq_vae_final.pt\")\n",
    "        try:\n",
    "            api.upload_file(\n",
    "                path_or_fileobj=\"vq_vae_final.pt\",\n",
    "                path_in_repo=\"vq_vae_final.pt\",\n",
    "                repo_id=REPO_ID,\n",
    "                repo_type=\"model\",\n",
    "                commit_message=f\"VQ-VAE epoch {epoch+1}, loss {avg_loss:.4f}\"\n",
    "            )\n",
    "            print(f\"  Pushed VQ-VAE checkpoint to HF!\")\n",
    "        except Exception as e:\n",
    "            print(f\"  Push failed: {e}\")\n",
    "\n",
    "    print(f\"\\nVQ-VAE training complete! Best loss: {best_loss:.4f}\")\n",
    "    vq_vae.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🔢 Phase 2: Tokenize Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title 5. Phase 2: Tokenize Image-Text Pairs\n",
    "import json\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "NUM_TOKENIZE = 50000  # @param {type:\"integer\"}\n",
    "TOKENS_PER_SAMPLE = 64  # 8x8 grid\n",
    "\n",
    "# Caption helpers\n",
    "IMAGENETTE_CLASSES = {\n",
    "    0: \"a fish in water\", 1: \"a dog running in a field\", 2: \"a cassette player on a table\",\n",
    "    3: \"a chainsaw cutting wood\", 4: \"a church with a tall steeple\", 5: \"a French horn on stage\",\n",
    "    6: \"a garbage truck on the street\", 7: \"a gas station at night\", 8: \"a golf ball on a green\",\n",
    "    9: \"a parachute in the sky\",\n",
    "}\n",
    "CIFAR10_CLASSES = [\"airplane flying\", \"automobile on road\", \"bird in tree\", \"cat sitting\",\n",
    "                   \"deer in forest\", \"dog playing\", \"frog on lily pad\", \"horse running\",\n",
    "                   \"ship on ocean\", \"truck driving\"]\n",
    "\n",
    "def get_caption(item, cap_key, ds_name, idx):\n",
    "    if cap_key and cap_key in item and item[cap_key] is not None:\n",
    "        cap = item[cap_key]\n",
    "        if isinstance(cap, list):\n",
    "            return cap[0] if cap else f\"image {idx}\"\n",
    "        elif isinstance(cap, str):\n",
    "            return cap\n",
    "        elif isinstance(cap, int):\n",
    "            if \"imagenette\" in ds_name.lower():\n",
    "                return IMAGENETTE_CLASSES.get(cap, f\"photo of object {cap}\")\n",
    "            elif \"cifar\" in ds_name.lower():\n",
    "                return CIFAR10_CLASSES[cap] if cap < len(CIFAR10_CLASSES) else f\"photo of class {cap}\"\n",
    "        return f\"photo of a {cap}\"\n",
    "    return f\"image {idx}\"\n",
    "\n",
    "# Load dataset for tokenization (re-load to get fresh stream)\n",
    "print(\"Loading dataset for tokenization...\")\n",
    "ds = None\n",
    "image_key = \"image\"\n",
    "cap_key = None\n",
    "ds_name = \"\"\n",
    "\n",
    "for name, split, ik, ck in [\n",
    "    (\"detection-datasets/coco\", \"train\", \"image\", \"caption\"),\n",
    "    (\"frgfm/imagenette\", \"train\", \"image\", \"label\"),\n",
    "    (\"cifar10\", \"train\", \"img\", \"label\"),\n",
    "]:\n",
    "    try:\n",
    "        ds = load_dataset(name, split=split, streaming=True, trust_remote_code=True)\n",
    "        test_item = next(iter(ds))\n",
    "        if ik in test_item:\n",
    "            image_key = ik\n",
    "            cap_key = ck if ck in test_item else None\n",
    "            ds_name = name\n",
    "            print(f\"Using {name}\")\n",
    "            break\n",
    "        ds = None\n",
    "    except Exception:\n",
    "        ds = None\n",
    "\n",
    "if ds is None:\n",
    "    raise RuntimeError(\"No dataset!\")\n",
    "\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize((VQ_VAE_IMG_SIZE, VQ_VAE_IMG_SIZE)),\n",
    "    transforms.ToTensor(),\n",
    "])\n",
    "\n",
    "vq_vae.eval()\n",
    "tokenized_data = []\n",
    "count = 0\n",
    "errors = 0\n",
    "\n",
    "print(f\"Tokenizing {NUM_TOKENIZE} images...\")\n",
    "for item in ds:\n",
    "    if count >= NUM_TOKENIZE:\n",
    "        break\n",
    "    try:\n",
    "        img = item[image_key]\n",
    "        if img.mode != \"RGB\":\n",
    "            img = img.convert(\"RGB\")\n",
    "        caption = get_caption(item, cap_key, ds_name, count)\n",
    "\n",
    "        img_tensor = transform(img).unsqueeze(0).cuda()\n",
    "        with torch.no_grad():\n",
    "            tokens = vq_vae.encode(img_tensor)\n",
    "        flat_tokens = tokens.flatten().tolist()\n",
    "\n",
    "        flat_tokens = flat_tokens[:TOKENS_PER_SAMPLE]\n",
    "        while len(flat_tokens) < TOKENS_PER_SAMPLE:\n",
    "            flat_tokens.append(0)\n",
    "\n",
    "        tokenized_data.append({\n",
    "            \"text_prompt\": caption,\n",
    "            \"video_tokens\": flat_tokens,\n",
    "        })\n",
    "        count += 1\n",
    "\n",
    "        if count % 2000 == 0:\n",
    "            print(f\"  Tokenized {count}/{NUM_TOKENIZE} (errors: {errors})\")\n",
    "            # Save checkpoint\n",
    "            with open(\"tokenized_dataset.json\", \"w\") as f:\n",
    "                json.dump(tokenized_data, f)\n",
    "            # Push to HF\n",
    "            try:\n",
    "                api.upload_file(\n",
    "                    path_or_fileobj=\"tokenized_dataset.json\",\n",
    "                    path_in_repo=\"tokenized_dataset.json\",\n",
    "                    repo_id=REPO_ID,\n",
    "                    repo_type=\"model\",\n",
    "                    commit_message=f\"Tokenized {count} samples\"\n",
    "                )\n",
    "            except Exception:\n",
    "                pass\n",
    "\n",
    "        del img_tensor\n",
    "        if count % 500 == 0:\n",
    "            torch.cuda.empty_cache()\n",
    "\n",
    "    except Exception as e:\n",
    "        errors += 1\n",
    "        if errors <= 3:\n",
    "            print(f\"  Error: {str(e)[:60]}\")\n",
    "        continue\n",
    "\n",
    "# Final save & push\n",
    "with open(\"tokenized_dataset.json\", \"w\") as f:\n",
    "    json.dump(tokenized_data, f)\n",
    "\n",
    "api.upload_file(\n",
    "    path_or_fileobj=\"tokenized_dataset.json\",\n",
    "    path_in_repo=\"tokenized_dataset.json\",\n",
    "    repo_id=REPO_ID,\n",
    "    repo_type=\"model\",\n",
    "    commit_message=f\"Tokenized {len(tokenized_data)} samples (complete)\"\n",
    ")\n",
    "\n",
    "print(f\"\\nTokenization complete: {len(tokenized_data)} samples ({errors} errors)\")\n",
    "print(f\"Sample: '{tokenized_data[0]['text_prompt']}' -> {tokenized_data[0]['video_tokens'][:10]}\")\n",
    "print(f\"Unique tokens in sample: {len(set(tokenized_data[0]['video_tokens']))}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🚀 Phase 3: Fine-tune LLM with LoRA (GPU + Incremental Push)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title 6. Phase 3: Setup LLM + LoRA with HuggingFace Trainer\n",
    "from transformers import (\n",
    "    AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer,\n",
    "    DataCollatorForLanguageModeling, TrainerCallback\n",
    ")\n",
    "from peft import LoraConfig, get_peft_model, TaskType\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "# Hyperparameters\n",
    "LORA_R = 8  # @param {type:\"integer\"}\n",
    "LORA_ALPHA = 16  # @param {type:\"integer\"}\n",
    "LORA_DROPOUT = 0.05  # @param {type:\"number\"}\n",
    "LEARNING_RATE = 2e-4  # @param {type:\"number\"}\n",
    "BATCH_SIZE = 2  # @param {type:\"integer\"}\n",
    "GRADIENT_ACCUMULATION = 8  # @param {type:\"integer\"}\n",
    "NUM_EPOCHS = 3  # @param {type:\"integer\"}\n",
    "MAX_SEQ_LEN = 256  # @param {type:\"integer\"}\n",
    "WARMUP_RATIO = 0.03  # @param {type:\"number\"}\n",
    "WEIGHT_DECAY = 0.01  # @param {type:\"number\"}\n",
    "SAVE_STEPS = 200  # @param {type:\"integer\"}\n",
    "EVAL_STEPS = 200  # @param {type:\"integer\"}\n",
    "FP16 = True  # @param {type:\"boolean\"}\n",
    "TRAIN_ON_ALL_DATA = False  # @param {type:\"boolean\"}\n",
    "LLM_TRAIN_SAMPLES = 10000  # @param {type:\"integer\"}\n",
    "\n",
    "MODEL_NAME = \"allenai/OLMo-2-0425-1B-Instruct\"\n",
    "VIDEO_START = \"\"\n",
    "VIDEO_END = \"\"\n",
    "VIDEO_PAD = \"\"\n",
    "\n",
    "# Load tokenizer\n",
    "print(\"Loading tokenizer...\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
    "if tokenizer.pad_token is None:\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "orig_vocab = len(tokenizer)\n",
    "print(f\"Original vocab: {orig_vocab}\")\n",
    "\n",
    "# Expand vocab with visual tokens\n",
    "visual_tokens = [VIDEO_START, VIDEO_END, VIDEO_PAD]\n",
    "for i in range(CODEBOOK_SIZE):\n",
    "    visual_tokens.append(f\"\")\n",
    "tokenizer.add_tokens(visual_tokens)\n",
    "print(f\"Expanded vocab: {len(tokenizer)} (+{len(tokenizer) - orig_vocab} visual tokens)\")\n",
    "\n",
    "# Load model\n",
    "print(\"Loading model...\")\n",
    "dtype = torch.float16 if FP16 else torch.float32\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    MODEL_NAME, trust_remote_code=True, torch_dtype=dtype\n",
    ")\n",
    "model.resize_token_embeddings(len(tokenizer))\n",
    "print(f\"Model loaded: {MODEL_NAME}\")\n",
    "\n",
    "# Apply LoRA\n",
    "print(f\"Applying LoRA (r={LORA_R})...\")\n",
    "lora_config = LoraConfig(\n",
    "    r=LORA_R,\n",
    "    lora_alpha=LORA_ALPHA,\n",
    "    target_modules=[\"q_proj\", \"v_proj\", \"k_proj\", \"o_proj\"],  # all four attention projections\n",
    "    lora_dropout=LORA_DROPOUT,\n",
    "    bias=\"none\",\n",
    "    task_type=TaskType.CAUSAL_LM,\n",
    ")\n",
    "model = get_peft_model(model, lora_config)\n",
    "trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "total = sum(p.numel() for p in model.parameters())\n",
    "print(f\"LoRA: {trainable:,} / {total:,} trainable ({100*trainable/total:.2f}%)\")\n",
    "model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title 7. Create Training Dataset\n",
    "import json\n",
    "\n",
    "class VideoTokenDataset(Dataset):\n",
    "    def __init__(self, data, tokenizer, max_tokens=64, max_len=256):\n",
    "        self.data = data\n",
    "        self.tokenizer = tokenizer\n",
    "        self.max_tokens = max_tokens\n",
    "        self.max_len = max_len\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        item = self.data[idx]\n",
    "        prompt = item[\"text_prompt\"]\n",
    "        tokens = item[\"video_tokens\"][:self.max_tokens]\n",
    "        while len(tokens) < self.max_tokens:\n",
    "            tokens.append(0)\n",
    "        token_str = \" \".join(f\"\" for t in tokens)\n",
    "        text = f\"Create a video of: {prompt} {VIDEO_START} {token_str} {VIDEO_END}\"\n",
    "\n",
    "        encoding = self.tokenizer(\n",
    "            text, return_tensors=\"pt\", truncation=True,\n",
    "            max_length=self.max_len, padding=\"max_length\"\n",
    "        )\n",
    "        input_ids = encoding[\"input_ids\"].squeeze()\n",
    "        attention_mask = encoding[\"attention_mask\"].squeeze()\n",
    "        labels = input_ids.clone()\n",
    "        # Don't compute loss on padding; mask by attention_mask because pad_token may equal eos_token\n",
    "        labels[attention_mask == 0] = -100\n",
    "\n",
    "        return {\n",
    "            \"input_ids\": input_ids,\n",
    "            \"attention_mask\": attention_mask,\n",
    "            \"labels\": labels,\n",
    "        }\n",
    "\n",
    "# Load data\n",
    "with open(\"tokenized_dataset.json\") as f:\n",
    "    all_data = json.load(f)\n",
    "\n",
    "if not TRAIN_ON_ALL_DATA:\n",
    "    all_data = all_data[:LLM_TRAIN_SAMPLES]\n",
    "\n",
    "print(f\"Training on {len(all_data)} samples\")\n",
    "\n",
    "# Split into train/eval\n",
    "split_idx = int(len(all_data) * 0.95)\n",
    "train_data = all_data[:split_idx]\n",
    "eval_data = all_data[split_idx:]\n",
    "\n",
    "train_dataset = VideoTokenDataset(train_data, tokenizer, max_len=MAX_SEQ_LEN)\n",
    "eval_dataset = VideoTokenDataset(eval_data, tokenizer, max_len=MAX_SEQ_LEN)\n",
    "\n",
    "print(f\"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}\")\n",
    "\n",
    "# Test one sample\n",
    "sample = train_dataset[0]\n",
    "decoded = tokenizer.decode(sample[\"input_ids\"][:80], skip_special_tokens=False)\n",
    "print(f\"Sample: {decoded[:200]}...\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title 8. Configure HuggingFace Trainer with Incremental Push\n",
    "\n",
    "# Training arguments with push_to_hub for incremental checkpoint saves\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"./zeeb-checkpoints\",\n",
    "\n",
    "    # Training params\n",
    "    num_train_epochs=NUM_EPOCHS,\n",
    "    per_device_train_batch_size=BATCH_SIZE,\n",
    "    per_device_eval_batch_size=BATCH_SIZE,\n",
    "    gradient_accumulation_steps=GRADIENT_ACCUMULATION,\n",
    "    learning_rate=LEARNING_RATE,\n",
    "    weight_decay=WEIGHT_DECAY,\n",
    "    warmup_ratio=WARMUP_RATIO,\n",
    "    lr_scheduler_type=\"cosine\",\n",
    "    max_grad_norm=1.0,\n",
    "\n",
    "    # Precision\n",
    "    fp16=FP16,\n",
    "    bf16=False,\n",
    "\n",
    "    # Logging\n",
    "    logging_steps=10,\n",
    "    logging_first_step=True,\n",
    "\n",
    "    # Saving - INCREMENTAL PUSH TO HF\n",
    "    save_strategy=\"steps\",\n",
    "    save_steps=SAVE_STEPS,\n",
    "    save_total_limit=3,  # Keep only 3 checkpoints on disk\n",
    "\n",
    "    # Evaluation\n",
    "    eval_strategy=\"steps\",\n",
    "    eval_steps=EVAL_STEPS,\n",
    "\n",
    "    # INCREMENTAL PUSH TO HUGGINGFACE\n",
    "    push_to_hub=True,\n",
    "    hub_model_id=REPO_ID,\n",
    "    hub_token=HF_TOKEN,\n",
    "    hub_strategy=\"every_save\",  # Push the model weights every time a checkpoint is saved\n",
    "\n",
    "    # Resuming is handled in cell 9 via trainer.train(resume_from_checkpoint=...);\n",
    "    # the TrainingArguments field of that name is not used by Trainer itself.\n",
    "\n",
    "    # Performance\n",
    "    dataloader_num_workers=2,\n",
    "    dataloader_pin_memory=True,\n",
    "    gradient_checkpointing=True,  # Save memory\n",
    "    optim=\"adamw_torch\",\n",
    "\n",
    "    # Misc\n",
    "    remove_unused_columns=False,\n",
    "    report_to=\"none\",  # Disable wandb/tensorboard\n",
    "    run_name=\"zeeb-video-llm\",\n",
    ")\n",
    "\n",
    "print(\"Training Arguments:\")\n",
    "print(f\"  Epochs: {NUM_EPOCHS}\")\n",
    "print(f\"  Batch: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} accumulation = effective {BATCH_SIZE * GRADIENT_ACCUMULATION}\")\n",
    "print(f\"  LR: {LEARNING_RATE}, Scheduler: cosine\")\n",
    "print(f\"  FP16: {FP16}\")\n",
    "print(f\"  Save every {SAVE_STEPS} steps → push to HF\")\n",
    "print(f\"  Push to: {REPO_ID}\")\n",
    "print(f\"  Hub strategy: every_save (incremental push)\")\n",
    "print(f\"  Gradient checkpointing: True\")\n",
    "print(f\"  Resume: from the latest local checkpoint in ./zeeb-checkpoints (cell 9)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title 9. 🚀 START TRAINING! (with auto-resume)\n",
    "import os\n",
    "\n",
    "# Check for existing local checkpoints to resume from\n",
    "checkpoint_dir = \"./zeeb-checkpoints\"\n",
    "resume_ckpt = None\n",
    "if os.path.exists(checkpoint_dir):\n",
    "    checkpoints = [d for d in os.listdir(checkpoint_dir) if d.startswith(\"checkpoint-\")]\n",
    "    if checkpoints:\n",
    "        latest = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))[-1]\n",
    "        resume_ckpt = os.path.join(checkpoint_dir, latest)\n",
    "        print(f\"Found checkpoint to resume from: {resume_ckpt}\")\n",
    "\n",
    "# Create trainer\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=eval_dataset,\n",
    "    data_collator=None,  # default_data_collator: samples are already padded to max_len\n",
    ")\n",
    "\n",
    "# Calculate total steps\n",
    "total_steps = (len(train_dataset) // (BATCH_SIZE * GRADIENT_ACCUMULATION)) * NUM_EPOCHS\n",
    "print(f\"\\nTotal training steps: ~{total_steps}\")\n",
    "print(f\"Checkpoints will be pushed every {SAVE_STEPS} steps ({total_steps // SAVE_STEPS} pushes)\")\n",
    "print(f\"\\nStarting training...\")\n",
    "print(f\"If the kernel restarts, re-run cells 1, 2, 3, 6, 7, 8 and then this cell; it resumes from the latest local checkpoint.\\n\")\n",
    "\n",
    "# Train! (auto-resumes from a local checkpoint if available)\n",
    "train_result = trainer.train(resume_from_checkpoint=resume_ckpt)\n",
    "\n",
    "print(f\"\\nTraining complete!\")\n",
    "print(f\"  Final loss: {train_result.training_loss:.4f}\")\n",
    "print(f\"  Total steps: {train_result.global_step}\")\n",
    "print(f\"  Training time: {train_result.metrics['train_runtime']:.0f}s ({train_result.metrics['train_runtime']/60:.1f} min)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title 10. Merge LoRA & Push Final Model to HuggingFace\n",
    "print(\"Merging LoRA weights into base model...\")\n",
    "model = model.merge_and_unload()\n",
    "\n",
    "# Save locally\n",
    "final_dir = \"./zeeb-final\"\n",
    "model.save_pretrained(final_dir, safe_serialization=True)\n",
    "tokenizer.save_pretrained(final_dir)\n",
    "\n",
    "# Copy VQ-VAE checkpoint\n",
    "import shutil\n",
    "if os.path.exists(\"vq_vae_final.pt\"):\n",
    "    shutil.copy(\"vq_vae_final.pt\", f\"{final_dir}/vq_vae_final.pt\")\n",
    "if os.path.exists(\"tokenized_dataset.json\"):\n",
    "    shutil.copy(\"tokenized_dataset.json\", f\"{final_dir}/tokenized_dataset.json\")\n",
    "\n",
    "# Push final merged model to HuggingFace\n",
    "print(f\"Pushing final model to {REPO_ID}...\")\n",
    "model.push_to_hub(\n",
    "    REPO_ID,\n",
    "    token=HF_TOKEN,\n",
    "    commit_message=f\"Zeeb v2: OLMo 2 1B + LoRA (r={LORA_R}), {NUM_EPOCHS} epochs, {len(train_data)} samples, GPU-trained\"\n",
    ")\n",
    "tokenizer.push_to_hub(\n",
    "    REPO_ID,\n",
    "    token=HF_TOKEN,\n",
    "    commit_message=f\"Zeeb v2: tokenizer with visual tokens\"\n",
    ")\n",
    "\n",
    "# Push additional files\n",
    "for fname in [\"vq_vae_final.pt\", \"tokenized_dataset.json\"]:\n",
    "    if os.path.exists(fname):\n",
    "        api.upload_file(\n",
    "            path_or_fileobj=fname,\n",
    "            path_in_repo=fname,\n",
    "            repo_id=REPO_ID,\n",
    "            repo_type=\"model\",\n",
    "            commit_message=f\"Add {fname}\"\n",
    "        )\n",
    "\n",
    "print(f\"\\n✅ Final model pushed to https://huggingface.co/{REPO_ID}\")\n",
    "print(\"This model can now be loaded in the HF Space for video generation!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🧪 Test: Generate a Video with the Trained Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title 11. Test Video Generation\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "import imageio\n",
    "\n",
    "PROMPT = \"A cat jumping on a sofa\"  # @param {type:\"string\"}\n",
    "MAX_TOKENS = 64  # @param {type:\"integer\"}\n",
    "TEMPERATURE = 0.9  # @param {type:\"number\"}\n",
    "TOP_K = 50  # @param {type:\"integer\"}\n",
    "\n",
    "# Get visual token IDs\n",
    "VIDEO_START_ID = tokenizer.convert_tokens_to_ids(\"\")\n",
    "VIDEO_END_ID = tokenizer.convert_tokens_to_ids(\"\")\n",
    "V_TOKEN_START_ID = tokenizer.convert_tokens_to_ids(\"\")\n",
    "V_TOKEN_END_ID = tokenizer.convert_tokens_to_ids(\"\")\n",
    "\n",
    "# Load VQ-VAE for decoding\n",
    "vq_vae = VQVAE().cuda()\n",
    "if os.path.exists(\"vq_vae_final.pt\"):\n",
    "    vq_vae.load_state_dict(torch.load(\"vq_vae_final.pt\", map_location=\"cuda\", weights_only=False))\n",
    "    print(\"Loaded trained VQ-VAE\")\n",
    "vq_vae.eval()\n",
    "\n",
    "# Generate with constrained decoding\n",
    "text = f\"Create a video of: {PROMPT} \"\n",
    "inputs = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=256)\n",
    "current_ids = inputs[\"input_ids\"].cuda()\n",
    "\n",
    "# Only visual tokens and the end marker may be sampled; keep the mask on the logits' device\n",
    "vocab_size = len(tokenizer)\n",
    "visual_mask = torch.zeros(vocab_size, dtype=torch.bool, device=current_ids.device)\n",
    "visual_mask[V_TOKEN_START_ID:V_TOKEN_END_ID + 1] = True\n",
    "visual_mask[VIDEO_END_ID] = True\n",
    "\n",
    "visual_token_ids = []\n",
    "model.eval()\n",
    "\n",
    "print(f\"Generating visual tokens for: '{PROMPT}'\")\n",
    "with torch.no_grad():\n",
    "    for step in range(MAX_TOKENS):\n",
    "        outputs = model(input_ids=current_ids)\n",
    "        logits = outputs.logits[:, -1, :].float()  # sample in fp32, not the fp16 model dtype\n",
    "        masked = logits.clone()\n",
    "        masked[0, ~visual_mask] = float('-inf')\n",
    "        masked = masked / max(TEMPERATURE, 0.01)\n",
    "        if TOP_K > 0:\n",
    "            top_k_values, _ = torch.topk(masked[0], min(TOP_K, masked.size(-1)))\n",
    "            threshold = top_k_values[-1]\n",
    "            masked[0, masked[0] < threshold] = float('-inf')\n",
    "        probs = F.softmax(masked, dim=-1)\n",
    "        next_token = torch.multinomial(probs, num_samples=1)\n",
    "        next_id = next_token.item()\n",
    "        if next_id == VIDEO_END_ID:\n",
    "            break\n",
    "        visual_idx = next_id - V_TOKEN_START_ID\n",
    "        visual_token_ids.append(visual_idx)\n",
    "        current_ids = torch.cat([current_ids, next_token], dim=-1)\n",
    "\n",
    "print(f\"Generated {len(visual_token_ids)} visual tokens ({len(set(visual_token_ids))} unique)\")\n",
    "\n",
    "# Decode through VQ-VAE\n",
    "grid_h, grid_w = 8, 8\n",
    "tokens_per_frame = grid_h * grid_w\n",
    "num_frames = max(1, len(visual_token_ids) // tokens_per_frame)\n",
    "\n",
    "frames = []\n",
    "for fi in range(num_frames):\n",
    "    ft = visual_token_ids[fi*tokens_per_frame:(fi+1)*tokens_per_frame]\n",
    "    frame_tensor = vq_vae.decode_tokens(ft, grid_h, grid_w)\n",
    "    frame_np = (frame_tensor[0].permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8)\n",
    "    frames.append(frame_np)\n",
    "\n",
    "# Save video\n",
    "if frames:\n",
    "    upscaled = [np.array(Image.fromarray(f).resize((256, 256), Image.BILINEAR)) for f in frames]\n",
    "    output_path = \"/content/generated_video.mp4\"\n",
    "    imageio.mimsave(output_path, upscaled, fps=2)\n",
    "    print(f\"Video saved: {output_path} ({len(upscaled)} frames, 256x256)\")\n",
    "\n",
    "    # Display first frame\n",
    "    from IPython.display import display\n",
    "    display(Image.fromarray(upscaled[0]))\n",
    "else:\n",
    "    print(\"No frames generated\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 📊 Summary & Next Steps\n",
    "\n",
    "### What was trained:\n",
    "- **VQ-VAE**: 3.8M params, trained on real images (COCO, or the imagenette/CIFAR-10 fallback), maps images ↔ discrete tokens\n",
    "- **OLMo 2 1B + LoRA**: the base weights stay frozen and only the LoRA adapters are trained (the exact trainable count is printed by `model.print_trainable_parameters()` in cell 6) to predict visual tokens from text\n",
    "\n",
    "### How to improve further:\n",
    "1. **More data**: Use 50K+ samples instead of 10K\n",
    "2. **Bigger LoRA**: Increase r from 8 to 16-32\n",
    "3. **More target modules**: Add \"gate_proj\", \"up_proj\", \"down_proj\" to LoRA targets\n",
    "4. **Video data**: Use OpenVid-1M with actual video frames (multiple frames per clip)\n",
    "5. **Larger codebook**: 4096 or 8192 entries instead of 1024\n",
    "6. **Higher resolution**: 256x256 VQ-VAE instead of 128x128\n",
    "7. **Multi-frame**: Encode 4-8 frames per video, not just 1\n",
    "\n",
    "### Resume after Colab disconnect:\n",
    "If only the kernel restarted and `./zeeb-checkpoints` is still on disk, re-run cells 1, 2, 3, 6, 7, 8, and 9; cell 9 resumes from the latest local checkpoint. If the runtime was reset, local files are gone: the HF repo still holds the latest pushed weights, `vq_vae_final.pt`, and `tokenized_dataset.json`, but `hub_strategy=\"every_save\"` does not push optimizer or scheduler state, so the Trainer cannot auto-resume from the Hub."
   ]
  }
 ]
}