#!/usr/bin/env python3
"""
Quick validation experiment: Single visual token vs Multi-token visual prefix.

Trains two TINY models (2-layer decoder, same tiny ViT backbone design) on 1K
real images for 10 epochs. Compares generation quality to determine if the
architecture bottleneck is real.

Result interpretation:
- If BOTH collapse → data/training issue, not architecture
- If multi-token works but single doesn't → architecture bottleneck confirmed
- If single-token works → data scale was the issue all along

Usage:
    python3 scripts/experiment_visual_prefix.py --data_dir data/downloads/stage1 --device cuda
    python3 scripts/experiment_visual_prefix.py --data_dir data/downloads/stage2_fullscale --device cuda
"""

import argparse
import json
import os
import sys
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Special token ids shared by the dataset and both decoders.
PAD_ID = 0
BOS_ID = 1
EOS_ID = 2
PAD_TOKEN = "<pad>"
BOS_TOKEN = "<bos>"
EOS_TOKEN = "<eos>"

# Number of leading answer words that are encoded as targets (and therefore
# the only words worth spending vocabulary slots on).
MAX_TARGET_WORDS = 10
# Fixed length every target sequence is padded/truncated to.
TARGET_LEN = 16


# ============================================================================
# Shared building blocks
# ============================================================================

class _TinyViTBackbone(nn.Module):
    """Patch embedding + 2 Transformer encoder layers + LayerNorm.

    Base class for both tiny encoders; subclasses add their own pooling head
    and define ``forward``. Attribute names match the original standalone
    encoders so state_dict keys are unchanged.
    """

    def __init__(self, img_size, patch_size, hidden_dim):
        super().__init__()
        n_patches = (img_size // patch_size) ** 2
        self.patch_embed = nn.Conv2d(3, hidden_dim, patch_size, patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, n_patches, hidden_dim) * 0.02)
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                hidden_dim, nhead=4, dim_feedforward=hidden_dim * 4, batch_first=True
            )
            for _ in range(2)
        ])
        self.norm = nn.LayerNorm(hidden_dim)

    def _patch_features(self, images):
        """Return normalized per-patch features [B, N_patches, D] for images [B, 3, H, W]."""
        x = self.patch_embed(images).flatten(2).transpose(1, 2)  # [B, N, D]
        x = x + self.pos_embed[:, :x.shape[1], :]
        for block in self.blocks:
            x = block(x)
        return self.norm(x)


class _TinyCausalDecoder(nn.Module):
    """Causal prefix-LM machinery shared by both tiny decoders.

    The sequence fed to the stack is ``[prefix ; token embeddings]``. Each
    ``nn.TransformerDecoderLayer`` receives that sequence as both ``tgt`` and
    ``memory`` with the same causal mask, so no position sees the future.
    """

    def _build_stack(self, hidden_dim, vocab_size, n_positions):
        """Create token/position embeddings, 2 decoder layers, norm and LM head."""
        self.token_embed = nn.Embedding(vocab_size, hidden_dim, padding_idx=PAD_ID)
        self.pos_embed = nn.Parameter(torch.randn(1, n_positions, hidden_dim) * 0.02)
        self.blocks = nn.ModuleList([
            nn.TransformerDecoderLayer(
                hidden_dim, nhead=4, dim_feedforward=hidden_dim * 4, batch_first=True
            )
            for _ in range(2)
        ])
        self.norm = nn.LayerNorm(hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size

    def _hidden_states(self, x):
        """Add positions to x [B, T, D], run the causal stack, return normed states.

        Raises:
            ValueError: if T exceeds the positional-embedding table.
        """
        seq_len = x.shape[1]
        capacity = self.pos_embed.shape[1]
        if seq_len > capacity:
            raise ValueError(
                f"sequence length {seq_len} exceeds positional capacity {capacity}"
            )
        x = x + self.pos_embed[:, :seq_len, :]
        mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=x.device)
        for block in self.blocks:
            x = block(x, x, tgt_mask=mask, memory_mask=mask)
        return self.norm(x)

    def _prefix_lm_loss(self, prefix, target_ids):
        """Cross-entropy of target_ids [B, T] given a prefix [B, P, D].

        Input position ``P-1+i`` (last prefix slot for i=0, else target i-1)
        predicts ``target_ids[:, i]``; PAD targets are ignored.
        """
        x = torch.cat([prefix, self.token_embed(target_ids)], dim=1)  # [B, P+T, D]
        logits = self.lm_head(self._hidden_states(x))
        first = prefix.shape[1] - 1
        answer_logits = logits[:, first:-1, :]  # [B, T, vocab]
        return F.cross_entropy(
            answer_logits.reshape(-1, self.vocab_size),
            target_ids.reshape(-1),
            ignore_index=PAD_ID,
        )

    def _greedy_decode(self, prefix, max_tokens, eos_id):
        """Greedy-decode token ids from a single-sample prefix [1, P, D].

        Stops at ``eos_id`` (which is included in the result), after
        ``max_tokens`` tokens, or when the positional table is exhausted,
        whichever comes first.

        Raises:
            ValueError: if the batch size is not 1 or the prefix alone is
                longer than the positional table.
        """
        if prefix.shape[0] != 1:
            raise ValueError(
                f"generate() decodes one sample at a time, got batch size {prefix.shape[0]}"
            )
        # Step k feeds a sequence of length P + k, so the table bounds the steps.
        budget = self.pos_embed.shape[1] - prefix.shape[1] + 1
        if budget < 1:
            raise ValueError(
                f"prefix length {prefix.shape[1]} exceeds positional capacity "
                f"{self.pos_embed.shape[1]}"
            )
        x = prefix
        generated = []
        for _ in range(min(max_tokens, budget)):
            h = self._hidden_states(x)
            next_token = self.lm_head(h[:, -1, :]).argmax(dim=-1)  # [1]
            token_id = next_token.item()
            generated.append(token_id)
            if token_id == eos_id:
                break
            x = torch.cat([x, self.token_embed(next_token.unsqueeze(0))], dim=1)
        return generated


# ============================================================================
# Tiny Model A: Single visual token (current architecture)
# ============================================================================

class TinyEncoderA(_TinyViTBackbone):
    """Simplified ViT → single embedding."""

    def __init__(self, img_size=224, patch_size=16, hidden_dim=256, embed_dim=512):
        super().__init__(img_size, patch_size, hidden_dim)
        self.proj = nn.Linear(hidden_dim, embed_dim)

    def forward(self, images):
        """Return one L2-normalized vector per image: [B, embed_dim]."""
        x = self._patch_features(images)
        x = x.mean(dim=1)  # average pool → single vector
        x = self.proj(x)
        return F.normalize(x, p=2, dim=-1)  # L2-normalize like current arch


class TinyDecoderA(_TinyCausalDecoder):
    """Single-token decoder (current architecture)."""

    def __init__(self, embed_dim=512, hidden_dim=256, vocab_size=1000, max_seq=64):
        super().__init__()
        self.embed_proj = nn.Linear(embed_dim, hidden_dim)
        # One extra position for the projected image embedding.
        self._build_stack(hidden_dim, vocab_size, max_seq + 1)

    def forward(self, pred_embedding, target_ids):
        """Return the LM loss of target_ids [B, T] given pred_embedding [B, embed_dim]."""
        embed_token = self.embed_proj(pred_embedding).unsqueeze(1)  # [B, 1, D]
        return self._prefix_lm_loss(embed_token, target_ids)

    @torch.no_grad()
    def generate(self, pred_embedding, max_tokens=32, eos_id=2):
        """Greedy-decode a list of token ids from one embedding [1, embed_dim]."""
        embed_token = self.embed_proj(pred_embedding).unsqueeze(1)
        return self._greedy_decode(embed_token, max_tokens, eos_id)


# ============================================================================
# Tiny Model B: Multi-token visual prefix (proposed fix)
# ============================================================================

class TinyEncoderB(_TinyViTBackbone):
    """ViT → multi-token output (N visual tokens)."""

    def __init__(self, img_size=224, patch_size=16, hidden_dim=256, n_visual_tokens=16):
        super().__init__(img_size, patch_size, hidden_dim)
        # Learned queries for downsampling
        self.queries = nn.Parameter(torch.randn(1, n_visual_tokens, hidden_dim) * 0.02)
        self.cross_attn = nn.MultiheadAttention(hidden_dim, num_heads=4, batch_first=True)
        self.query_norm = nn.LayerNorm(hidden_dim)

    def forward(self, images):
        """Return N visual tokens per image: [B, n_visual_tokens, hidden_dim]."""
        x = self._patch_features(images)  # [B, N_patches, D]
        # Downsample to N visual tokens via cross-attention
        queries = self.queries.expand(x.shape[0], -1, -1)
        out, _ = self.cross_attn(queries, x, x)
        return self.query_norm(out)  # [B, N_visual, D]


class TinyDecoderB(_TinyCausalDecoder):
    """Multi-token visual prefix decoder."""

    def __init__(self, hidden_dim=256, vocab_size=1000, max_seq=64, max_visual_tokens=32):
        super().__init__()
        # Positions must cover the visual prefix plus the text tokens.
        self._build_stack(hidden_dim, vocab_size, max_seq + max_visual_tokens)

    def forward(self, visual_prefix, target_ids):
        """Return the LM loss of target_ids [B, T] given visual_prefix [B, N, D].

        Loss is computed only on answer positions (after the visual prefix).
        """
        return self._prefix_lm_loss(visual_prefix, target_ids)

    @torch.no_grad()
    def generate(self, visual_prefix, max_tokens=32, eos_id=2):
        """Greedy-decode a list of token ids from one visual prefix [1, N, D]."""
        return self._greedy_decode(visual_prefix, max_tokens, eos_id)


# ============================================================================
# Simple Dataset
# ============================================================================

class SimpleVQADataset(Dataset):
    """Load N real image-text pairs from JSONL.

    Each JSONL record is expected to be an object with an ``image_path`` and
    an ``answer`` string. Images are resized, normalized and held in memory;
    answers become fixed-length id sequences ``[BOS, words..., EOS, PAD...]``.
    """

    def __init__(self, jsonl_dir, max_samples=1000, img_size=224, vocab_size=1000):
        # Imported lazily so the models can be used without PIL/torchvision.
        from PIL import Image
        from torchvision import transforms

        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        self.vocab_size = vocab_size

        raw_samples = self._read_raw_samples(jsonl_dir, max_samples)

        # Build simple word→id vocab from the words that will actually be encoded.
        all_words = set()
        for _, answer in raw_samples:
            all_words.update(answer.lower().split()[:MAX_TARGET_WORDS])
        self.word2id = self._build_vocab(all_words, vocab_size)
        self.id2word = {v: k for k, v in self.word2id.items()}

        # Process samples
        self.samples = []
        skipped = 0
        for img_path, answer in raw_samples:
            try:
                with Image.open(img_path) as handle:
                    img = self.transform(handle.convert("RGB"))
            except Exception as exc:
                # Best effort: PIL raises several unrelated exception types for
                # corrupt files, so skip the sample but keep a visible record.
                skipped += 1
                if skipped <= 3:
                    print(f" Skipping unreadable image {img_path}: {exc}")
                continue
            ids = self._encode_answer(answer, self.word2id)
            self.samples.append((img, torch.tensor(ids, dtype=torch.long)))

        print(f" Loaded {len(self.samples)} samples, vocab {len(self.word2id)} words")
        if skipped:
            print(f" Skipped {skipped} unreadable images")

    @staticmethod
    def _read_raw_samples(jsonl_dir, max_samples):
        """Collect up to max_samples ``(image_path, answer)`` pairs from *.jsonl files.

        Files are visited in sorted name order. Lines that are not valid JSON,
        are not JSON objects, lack a non-empty string answer, or point to a
        missing image are skipped.
        """
        samples = []
        for fname in sorted(os.listdir(jsonl_dir)):
            if not fname.endswith(".jsonl"):
                continue
            with open(os.path.join(jsonl_dir, fname), encoding="utf-8") as f:
                for line in f:
                    if len(samples) >= max_samples:
                        return samples
                    try:
                        item = json.loads(line.strip())
                    except json.JSONDecodeError:
                        continue
                    if not isinstance(item, dict):
                        continue
                    img_path = item.get("image_path")
                    answer = item.get("answer", "")
                    if not (isinstance(img_path, str) and img_path):
                        continue
                    if not (isinstance(answer, str) and answer):
                        continue
                    if os.path.exists(img_path):
                        samples.append((img_path, answer))
        return samples

    @staticmethod
    def _build_vocab(words, vocab_size):
        """Map PAD/BOS/EOS to ids 0/1/2, then words (alphabetically) to 3, 4, ...

        At most ``vocab_size`` entries are produced in total, but the three
        special tokens are always present.
        """
        # The special tokens need distinct keys; otherwise they collapse into
        # one entry and real words are assigned the BOS/EOS ids.
        word2id = {PAD_TOKEN: PAD_ID, BOS_TOKEN: BOS_ID, EOS_TOKEN: EOS_ID}
        for w in sorted(words):
            if len(word2id) >= vocab_size:
                break
            # setdefault so a literal "<pad>" word in the data cannot remap a special id.
            word2id.setdefault(w, len(word2id))
        return word2id

    @staticmethod
    def _encode_answer(answer, word2id, max_words=MAX_TARGET_WORDS, seq_len=TARGET_LEN):
        """Encode an answer as ``[BOS, ids..., EOS]`` padded/truncated to seq_len.

        Out-of-vocabulary words map to PAD_ID, which the decoders ignore in the loss.
        """
        ids = [BOS_ID]
        ids.extend(word2id.get(w, PAD_ID) for w in answer.lower().split()[:max_words])
        ids.append(EOS_ID)
        return (ids + [PAD_ID] * seq_len)[:seq_len]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

    def decode(self, ids):
        """Join the words for ids, dropping PAD/BOS/EOS; unknown ids become "?"."""
        return " ".join(self.id2word.get(i, "?") for i in ids if i > EOS_ID)


# ============================================================================
# Experiment
# ============================================================================

def _train_model(encoder, decoder, loader, device, epochs, lr=1e-3):
    """Jointly train encoder+decoder with AdamW and print the mean loss per epoch."""
    optimizer = torch.optim.AdamW(
        list(encoder.parameters()) + list(decoder.parameters()), lr=lr
    )
    encoder.train()
    decoder.train()
    for epoch in range(epochs):
        total_loss = 0.0
        n_batches = 0
        for images, targets in loader:
            images, targets = images.to(device), targets.to(device)
            loss = decoder(encoder(images), targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        print(f" Epoch {epoch+1}/{epochs}: loss={total_loss/max(n_batches, 1):.4f}")


def _print_generations(encoder, decoder, dataset, device, n_examples=5, max_tokens=16):
    """Greedy-decode the first few dataset items and print target vs generated text."""
    encoder.eval()
    decoder.eval()
    for i in range(min(n_examples, len(dataset))):
        img, target = dataset[i]
        with torch.no_grad():  # evaluation only; no autograd graph needed
            prefix = encoder(img.unsqueeze(0).to(device))
        gen_ids = decoder.generate(prefix, max_tokens=max_tokens)
        gen_text = dataset.decode(gen_ids)
        target_text = dataset.decode(target.tolist())
        print(f" [{i}] Target: {target_text}")
        print(f" Generated: {gen_text}")
        print(f" Raw IDs: {gen_ids[:10]}")


def run_experiment(data_dir, device, max_samples=1000, epochs=10):
    """Train Model A (single token) and Model B (visual prefix) and print samples.

    Args:
        data_dir: directory containing *.jsonl files with image/answer records.
        device: torch device string, e.g. "cuda" or "cpu".
        max_samples: maximum number of image-text pairs to load.
        epochs: training epochs for each model.

    Returns:
        None. Returns early (after printing a message) if fewer than 10
        samples could be loaded.
    """
    print("=" * 70)
    print("EXPERIMENT: Single Visual Token vs Multi-Token Visual Prefix")
    print("=" * 70)

    dataset = SimpleVQADataset(data_dir, max_samples=max_samples, img_size=224)
    if len(dataset) < 10:
        print("FATAL: Not enough data. Need at least 10 real image-text pairs.")
        return

    # drop_last=True with a batch larger than the dataset would yield zero
    # batches (and a division by zero when averaging), so cap the batch size.
    batch_size = min(16, len(dataset))
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    vocab_size = len(dataset.word2id)

    # ---- Model A: Single visual token ----
    print("\n--- Model A: Single Visual Token (Current Architecture) ---")
    enc_a = TinyEncoderA(img_size=224, patch_size=16, hidden_dim=256, embed_dim=512).to(device)
    dec_a = TinyDecoderA(embed_dim=512, hidden_dim=256, vocab_size=vocab_size).to(device)
    _train_model(enc_a, dec_a, loader, device, epochs)

    # Test generation
    print("\n Generation test (Model A):")
    _print_generations(enc_a, dec_a, dataset, device)

    # ---- Model B: Multi-token visual prefix ----
    print("\n--- Model B: Multi-Token Visual Prefix (Proposed Fix) ---")
    enc_b = TinyEncoderB(img_size=224, patch_size=16, hidden_dim=256, n_visual_tokens=16).to(device)
    dec_b = TinyDecoderB(hidden_dim=256, vocab_size=vocab_size).to(device)
    _train_model(enc_b, dec_b, loader, device, epochs)

    # Test generation
    print("\n Generation test (Model B):")
    _print_generations(enc_b, dec_b, dataset, device)

    print("\n" + "=" * 70)
    print("EXPERIMENT COMPLETE")
    print("=" * 70)
    print()
    print("Interpretation:")
    print(" - If BOTH collapse → data or training issue")
    print(" - If B works but A doesn't → architecture bottleneck confirmed")
    print(" - If A works too → data scale was the issue, not architecture")


def main():
    """Parse command-line arguments and run the experiment."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", default="data/downloads/stage1")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--max_samples", type=int, default=1000)
    parser.add_argument("--epochs", type=int, default=10)
    args = parser.parse_args()
    run_experiment(args.data_dir, args.device, args.max_samples, args.epochs)


if __name__ == "__main__":
    main()