#!/usr/bin/env python3
# arcisvlm/scripts/experiment_visual_prefix.py
# Last commit 7a564e3 by Hardik Sanghvi:
#   "feat: integrate Gemma 4 E2B backbone for production-quality VLM inference"
"""
Quick validation experiment: single visual token vs. multi-token visual prefix.

Trains two TINY encoder-decoder models (each a 2-layer ViT-style encoder plus a
2-layer decoder; same backbone design, trained independently) on up to 1K real
images for 10 epochs by default, then compares greedy generations on training
samples to determine whether the single-token bottleneck is real.

Result interpretation:
  - If BOTH collapse → data/training issue, not architecture
  - If multi-token works but single-token doesn't → architecture bottleneck confirmed
  - If single-token works too → data scale was the issue all along

Usage:
  python3 scripts/experiment_visual_prefix.py --data_dir data/downloads/stage1 --device cuda
  python3 scripts/experiment_visual_prefix.py --data_dir data/downloads/stage2_fullscale --device cuda
"""
import argparse
import json
import os
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
# Prepend the parent of this script's directory (the repository root when the
# script lives in scripts/) to sys.path so project modules are importable when
# the script is run directly.
# NOTE(review): nothing in the visible code imports from the repo — confirm
# whether this is still needed.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# ============================================================================
# Tiny Model A: Single visual token (current architecture)
# ============================================================================
class TinyEncoderA(nn.Module):
    """Tiny ViT that compresses an image into ONE L2-normalized vector.

    Mirrors the "single visual token" design under test: patch embedding,
    two Transformer encoder layers, mean pooling over all patches, then a
    linear projection to ``embed_dim`` followed by L2 normalization.

    Args:
        img_size: Side length of the square input image, in pixels.
        patch_size: Side length of each square patch (also the conv stride).
        hidden_dim: Transformer width; must be divisible by the 4 heads.
        embed_dim: Size of the returned embedding.
    """

    def __init__(self, img_size=224, patch_size=16, hidden_dim=256, embed_dim=512):
        super().__init__()
        grid = img_size // patch_size
        self.patch_embed = nn.Conv2d(3, hidden_dim, patch_size, patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, grid * grid, hidden_dim) * 0.02)
        self.blocks = nn.ModuleList(
            nn.TransformerEncoderLayer(
                hidden_dim,
                nhead=4,
                dim_feedforward=4 * hidden_dim,
                batch_first=True,
            )
            for _ in range(2)
        )
        self.norm = nn.LayerNorm(hidden_dim)
        self.proj = nn.Linear(hidden_dim, embed_dim)

    def forward(self, images):
        """Encode ``images`` [B, 3, H, W] into unit-norm vectors [B, embed_dim]."""
        patches = self.patch_embed(images)                       # [B, D, H/p, W/p]
        tokens = patches.flatten(start_dim=2).transpose(1, 2)    # [B, N, D]
        tokens = tokens + self.pos_embed[:, : tokens.size(1)]
        for layer in self.blocks:
            tokens = layer(tokens)
        # Average over patches: the whole image collapses to one vector.
        pooled = self.norm(tokens).mean(dim=1)
        return F.normalize(self.proj(pooled), p=2, dim=-1)
class TinyDecoderA(nn.Module):
    """Decoder conditioned on a single visual token (current architecture).

    The image embedding is projected to ``hidden_dim`` and placed at position 0,
    in front of the answer token embeddings. The model is trained with
    next-token prediction under a causal mask.

    ``nn.TransformerDecoderLayer`` is used decoder-only style: each layer's
    "memory" is that layer's own input, masked with the same causal mask, so
    no position can see the future through either attention path.

    Args:
        embed_dim: Size of the incoming image embedding.
        hidden_dim: Transformer width; must be divisible by the 4 heads.
        vocab_size: Number of token ids; id 0 is padding (zero embedding and
            ignored by the loss).
        max_seq: Maximum number of answer tokens. The positional table holds
            ``max_seq + 1`` rows (one extra for the visual token).
    """

    def __init__(self, embed_dim=512, hidden_dim=256, vocab_size=1000, max_seq=64):
        super().__init__()
        self.embed_proj = nn.Linear(embed_dim, hidden_dim)
        self.token_embed = nn.Embedding(vocab_size, hidden_dim, padding_idx=0)
        self.pos_embed = nn.Parameter(torch.randn(1, max_seq + 1, hidden_dim) * 0.02)
        self.blocks = nn.ModuleList([
            nn.TransformerDecoderLayer(hidden_dim, nhead=4, dim_feedforward=hidden_dim * 4, batch_first=True)
            for _ in range(2)
        ])
        self.norm = nn.LayerNorm(hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size

    def forward(self, pred_embedding, target_ids):
        """Return the mean next-token cross-entropy over non-pad targets.

        Args:
            pred_embedding: Image embeddings, [B, embed_dim].
            target_ids: Token ids, [B, T]; entries equal to 0 are ignored.

        Raises:
            ValueError: If ``1 + T`` exceeds the positional table length.
        """
        embed_token = self.embed_proj(pred_embedding).unsqueeze(1)  # [B, 1, D]
        token_embeds = self.token_embed(target_ids)                 # [B, T, D]
        x = torch.cat([embed_token, token_embeds], dim=1)           # [B, 1+T, D]
        T = x.shape[1]
        if T > self.pos_embed.shape[1]:
            raise ValueError(
                f"sequence length {T} exceeds positional capacity {self.pos_embed.shape[1]}"
            )
        x = x + self.pos_embed[:, :T, :]
        mask = nn.Transformer.generate_square_subsequent_mask(T, device=x.device)
        for block in self.blocks:
            x = block(x, x, tgt_mask=mask, memory_mask=mask)
        logits = self.lm_head(self.norm(x))
        # Position i predicts target_ids[:, i]: the visual token predicts the
        # first target id, that token predicts the next, and so on. The final
        # position has no target and is dropped.
        shift_logits = logits[:, :-1, :].contiguous()
        return F.cross_entropy(
            shift_logits.view(-1, self.vocab_size),
            target_ids.contiguous().view(-1),
            ignore_index=0,
        )

    @torch.no_grad()
    def generate(self, pred_embedding, max_tokens=32, eos_id=2):
        """Greedy-decode token ids for ONE image.

        Args:
            pred_embedding: Image embedding, [1, embed_dim]. Batch size must be
                1 because ids are collected with ``.item()``.
            max_tokens: Upper bound on generated tokens. Generation also stops
                once the sequence would exceed the positional table
                (``max_seq + 1`` positions) instead of failing on a shape error.
            eos_id: Id that terminates generation; it is included in the output.

        Returns:
            List of generated ids. Position 0 is trained to predict the first
            target id, so when targets start with a BOS id the first generated
            id is usually that BOS id.
        """
        max_positions = self.pos_embed.shape[1]
        x = self.embed_proj(pred_embedding).unsqueeze(1)  # [1, 1, D]
        generated = []
        for _ in range(max_tokens):
            T = x.shape[1]
            if T > max_positions:
                break  # positional table exhausted
            h = x + self.pos_embed[:, :T, :]
            mask = nn.Transformer.generate_square_subsequent_mask(T, device=x.device)
            for block in self.blocks:
                h = block(h, h, tgt_mask=mask, memory_mask=mask)
            logits = self.lm_head(self.norm(h)[:, -1, :])
            next_token = logits.argmax(dim=-1)  # [1]
            token_id = next_token.item()
            generated.append(token_id)
            if token_id == eos_id:
                break
            next_embed = self.token_embed(next_token.unsqueeze(1))  # [1, 1, D]
            x = torch.cat([x, next_embed], dim=1)
        return generated
# ============================================================================
# Tiny Model B: Multi-token visual prefix (proposed fix)
# ============================================================================
class TinyEncoderB(nn.Module):
    """Tiny ViT that summarizes an image into ``n_visual_tokens`` vectors.

    Patch embedding and two Transformer encoder layers produce one token per
    patch. A fixed set of learned query vectors then cross-attends to those
    patch tokens, downsampling them to ``n_visual_tokens`` outputs.

    Args:
        img_size: Side length of the square input image, in pixels.
        patch_size: Side length of each square patch (also the conv stride).
        hidden_dim: Transformer width; must be divisible by the 4 heads.
        n_visual_tokens: Number of output tokens (learned queries).
    """

    def __init__(self, img_size=224, patch_size=16, hidden_dim=256, n_visual_tokens=16):
        super().__init__()
        grid = img_size // patch_size
        self.patch_embed = nn.Conv2d(3, hidden_dim, patch_size, patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, grid * grid, hidden_dim) * 0.02)
        self.blocks = nn.ModuleList(
            nn.TransformerEncoderLayer(
                hidden_dim,
                nhead=4,
                dim_feedforward=4 * hidden_dim,
                batch_first=True,
            )
            for _ in range(2)
        )
        self.norm = nn.LayerNorm(hidden_dim)
        # Learned queries; their count fixes how many visual tokens come out.
        self.queries = nn.Parameter(torch.randn(1, n_visual_tokens, hidden_dim) * 0.02)
        self.cross_attn = nn.MultiheadAttention(hidden_dim, num_heads=4, batch_first=True)
        self.query_norm = nn.LayerNorm(hidden_dim)

    def forward(self, images):
        """Encode ``images`` [B, 3, H, W] into visual tokens [B, n_visual_tokens, D]."""
        tokens = self.patch_embed(images).flatten(start_dim=2).transpose(1, 2)
        tokens = tokens + self.pos_embed[:, : tokens.size(1)]
        for layer in self.blocks:
            tokens = layer(tokens)
        patch_tokens = self.norm(tokens)  # [B, N_patches, D]
        batch_queries = self.queries.expand(patch_tokens.size(0), -1, -1)
        summary, _attn_weights = self.cross_attn(batch_queries, patch_tokens, patch_tokens)
        return self.query_norm(summary)
class TinyDecoderB(nn.Module):
    """Decoder conditioned on a multi-token visual prefix (proposed fix).

    The ``N`` visual tokens are prepended to the answer token embeddings. The
    whole sequence is trained with next-token prediction under a causal mask,
    with the loss computed only on answer positions.

    ``nn.TransformerDecoderLayer`` is used decoder-only style: each layer's
    "memory" is that layer's own input, masked with the same causal mask.

    Args:
        hidden_dim: Transformer width; must match the visual prefix width and
            be divisible by the 4 heads.
        vocab_size: Number of token ids; id 0 is padding (zero embedding and
            ignored by the loss).
        max_seq: Maximum number of answer tokens.
        max_prefix: Maximum number of visual prefix tokens. The positional
            table holds ``max_seq + max_prefix`` rows; the default of 32 keeps
            the previous fixed capacity.
    """

    def __init__(self, hidden_dim=256, vocab_size=1000, max_seq=64, max_prefix=32):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, hidden_dim, padding_idx=0)
        self.pos_embed = nn.Parameter(torch.randn(1, max_seq + max_prefix, hidden_dim) * 0.02)
        self.blocks = nn.ModuleList([
            nn.TransformerDecoderLayer(hidden_dim, nhead=4, dim_feedforward=hidden_dim * 4, batch_first=True)
            for _ in range(2)
        ])
        self.norm = nn.LayerNorm(hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size

    def forward(self, visual_prefix, target_ids):
        """Return the mean next-token cross-entropy over non-pad answer targets.

        Args:
            visual_prefix: Visual tokens, [B, N, hidden_dim].
            target_ids: Token ids, [B, T]; entries equal to 0 are ignored.

        Raises:
            ValueError: If ``N + T`` exceeds the positional table length.
        """
        N = visual_prefix.shape[1]
        token_embeds = self.token_embed(target_ids)            # [B, T, D]
        x = torch.cat([visual_prefix, token_embeds], dim=1)    # [B, N+T, D]
        T_total = x.shape[1]
        if T_total > self.pos_embed.shape[1]:
            raise ValueError(
                f"sequence length {T_total} exceeds positional capacity {self.pos_embed.shape[1]}"
            )
        x = x + self.pos_embed[:, :T_total, :]
        mask = nn.Transformer.generate_square_subsequent_mask(T_total, device=x.device)
        for block in self.blocks:
            x = block(x, x, tgt_mask=mask, memory_mask=mask)
        logits = self.lm_head(self.norm(x))
        # Logits at positions N-1 .. N+T-2 (the last visual slot, then each
        # answer token except the final one) predict target_ids[:, 0 .. T-1].
        answer_logits = logits[:, N - 1:-1, :].contiguous()   # [B, T, vocab]
        return F.cross_entropy(
            answer_logits.view(-1, self.vocab_size),
            target_ids.contiguous().view(-1),
            ignore_index=0,
        )

    @torch.no_grad()
    def generate(self, visual_prefix, max_tokens=32, eos_id=2):
        """Greedy-decode token ids for ONE image.

        Args:
            visual_prefix: Visual tokens, [1, N, hidden_dim]. Batch size must be
                1 because ids are collected with ``.item()``.
            max_tokens: Upper bound on generated tokens. Generation also stops
                once prefix + generated tokens would exceed the positional table,
                instead of failing on a shape error.
            eos_id: Id that terminates generation; it is included in the output.

        Returns:
            List of generated ids. The last visual slot is trained to predict the
            first target id, so when targets start with a BOS id the first
            generated id is usually that BOS id.
        """
        max_positions = self.pos_embed.shape[1]
        x = visual_prefix  # [1, N, D]
        generated = []
        for _ in range(max_tokens):
            T = x.shape[1]
            if T > max_positions:
                break  # positional table exhausted
            h = x + self.pos_embed[:, :T, :]
            mask = nn.Transformer.generate_square_subsequent_mask(T, device=x.device)
            for block in self.blocks:
                h = block(h, h, tgt_mask=mask, memory_mask=mask)
            logits = self.lm_head(self.norm(h)[:, -1, :])
            next_token = logits.argmax(dim=-1)  # [1]
            token_id = next_token.item()
            generated.append(token_id)
            if token_id == eos_id:
                break
            next_embed = self.token_embed(next_token.unsqueeze(1))  # [1, 1, D]
            x = torch.cat([x, next_embed], dim=1)
        return generated
# ============================================================================
# Simple Dataset
# ============================================================================
class SimpleVQADataset(Dataset):
    """Load up to ``max_samples`` real (image, answer) pairs from JSONL files.

    Every ``*.jsonl`` file in ``jsonl_dir`` (in sorted filename order) is scanned
    for JSON-object records whose ``image_path`` is a string naming an existing
    path and whose ``answer`` is a non-empty string. Relative image paths are
    checked against the current working directory. Images are decoded eagerly
    into normalized tensors. Each answer becomes a fixed-length id sequence:
    ``<bos> w1 .. wk <eos> <pad> ..``.

    Vocabulary: ``<pad>``=0, ``<bos>``=1, ``<eos>``=2, then answer words in
    alphabetical order until ``vocab_size`` ids are used. Words that do not fit
    in the vocabulary are encoded as 0, so they act as padding and are ignored
    by the loss.

    Args:
        jsonl_dir: Directory containing ``*.jsonl`` files.
        max_samples: Maximum number of records to collect.
        img_size: Images are resized to ``img_size x img_size``.
        vocab_size: Upper bound on vocabulary size, including the special ids.
        max_answer_words: Number of leading answer words kept. It is used for
            BOTH vocabulary building and encoding, so no vocab slot is spent on
            words the model never sees.
        seq_len: Length of every encoded id sequence (padded or truncated).
    """

    PAD_ID = 0
    BOS_ID = 1
    EOS_ID = 2

    def __init__(
        self,
        jsonl_dir,
        max_samples=1000,
        img_size=224,
        vocab_size=1000,
        max_answer_words=10,
        seq_len=16,
    ):
        # PIL/torchvision are only needed once a dataset is actually built.
        from PIL import Image
        from torchvision import transforms

        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            # Standard ImageNet per-channel mean/std values.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        self.samples = []
        self.vocab_size = vocab_size
        self.max_answer_words = max_answer_words
        self.seq_len = seq_len

        raw_samples = self._read_raw_samples(jsonl_dir, max_samples)
        self.word2id = self._build_vocab(
            [answer for _, answer in raw_samples], vocab_size, max_answer_words
        )
        self.id2word = {v: k for k, v in self.word2id.items()}

        n_skipped = 0
        for img_path, answer in raw_samples:
            try:
                # The context manager releases the file handle Image.open holds.
                with Image.open(img_path) as img:
                    pixels = self.transform(img.convert("RGB"))
            except Exception:
                # Best-effort: skip unreadable/corrupt images, but count them.
                n_skipped += 1
                continue
            ids = torch.tensor(self._encode(answer), dtype=torch.long)
            self.samples.append((pixels, ids))

        summary = f" Loaded {len(self.samples)} samples, vocab {len(self.word2id)} words"
        if n_skipped:
            summary += f" ({n_skipped} unreadable images skipped)"
        print(summary)

    @staticmethod
    def _read_raw_samples(jsonl_dir, max_samples):
        """Return up to ``max_samples`` valid (image_path, answer) pairs from ``jsonl_dir``."""
        raw_samples = []
        if max_samples <= 0:
            return raw_samples
        for fname in sorted(os.listdir(jsonl_dir)):
            if not fname.endswith(".jsonl"):
                continue
            with open(os.path.join(jsonl_dir, fname), encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        item = json.loads(line)
                    except json.JSONDecodeError:
                        continue
                    if not isinstance(item, dict):
                        continue
                    img_path = item.get("image_path")
                    answer = item.get("answer", "")
                    # isinstance guards: os.path.exists(int) would test a file
                    # descriptor, and a non-str answer cannot be tokenized.
                    if (
                        isinstance(img_path, str)
                        and isinstance(answer, str)
                        and answer
                        and os.path.exists(img_path)
                    ):
                        raw_samples.append((img_path, answer))
                        if len(raw_samples) >= max_samples:
                            return raw_samples
        return raw_samples

    @staticmethod
    def _build_vocab(answers, vocab_size, max_words):
        """Build word->id: specials first, then the alphabetically sorted words."""
        words = set()
        for answer in answers:
            words.update(answer.lower().split()[:max_words])
        word2id = {"<pad>": 0, "<bos>": 1, "<eos>": 2}
        for w in sorted(words):
            if len(word2id) >= vocab_size:
                break
            word2id[w] = len(word2id)
        return word2id

    def _encode(self, answer):
        """Encode ``answer`` as ``seq_len`` ids: BOS, words (OOV -> PAD), EOS, padding."""
        ids = [self.BOS_ID]
        ids.extend(
            self.word2id.get(w, self.PAD_ID)
            for w in answer.lower().split()[: self.max_answer_words]
        )
        ids.append(self.EOS_ID)
        ids.extend([self.PAD_ID] * self.seq_len)
        return ids[: self.seq_len]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image_tensor, id_tensor)`` for sample ``idx``."""
        return self.samples[idx]

    def decode(self, ids):
        """Join the words for ``ids``, dropping pad/bos/eos (ids 0-2)."""
        return " ".join(self.id2word.get(i, "?") for i in ids if i > 2)
# ============================================================================
# Experiment
# ============================================================================
def _train_pair(encoder, decoder, loader, epochs, device, lr=1e-3):
    """Jointly train ``encoder`` + ``decoder`` with AdamW, printing mean loss per epoch."""
    encoder.train()
    decoder.train()
    optimizer = torch.optim.AdamW(
        list(encoder.parameters()) + list(decoder.parameters()), lr=lr
    )
    for epoch in range(epochs):
        total_loss = 0.0
        n_batches = 0
        for images, targets in loader:
            images, targets = images.to(device), targets.to(device)
            loss = decoder(encoder(images), targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        # max(..., 1) keeps an (unexpected) empty loader from dividing by zero.
        print(f" Epoch {epoch+1}/{epochs}: loss={total_loss/max(n_batches, 1):.4f}")


@torch.no_grad()
def _show_generations(encoder, decoder, dataset, device, label, n_examples=5):
    """Greedy-decode the first ``n_examples`` training samples and print them next to their targets."""
    encoder.eval()
    decoder.eval()
    print(f"\n Generation test ({label}):")
    for i in range(min(n_examples, len(dataset))):
        img, target = dataset[i]
        visual = encoder(img.unsqueeze(0).to(device))
        gen_ids = decoder.generate(visual, max_tokens=16)
        print(f" [{i}] Target: {dataset.decode(target.tolist())}")
        print(f" Generated: {dataset.decode(gen_ids)}")
        print(f" Raw IDs: {gen_ids[:10]}")


def run_experiment(data_dir, device, max_samples=1000, epochs=10, batch_size=16):
    """Train the single-token (A) and multi-token (B) models and print their generations.

    Both models are trained on the same data loader, then each greedily decodes
    the first few training samples so their outputs can be compared by eye.

    Args:
        data_dir: Directory of ``*.jsonl`` image/answer files.
        device: Torch device string, e.g. ``"cuda"`` or ``"cpu"``.
        max_samples: Maximum number of image/answer pairs to load.
        epochs: Training epochs for EACH model.
        batch_size: Desired batch size; capped at the dataset size.
    """
    print("=" * 70)
    print("EXPERIMENT: Single Visual Token vs Multi-Token Visual Prefix")
    print("=" * 70)
    dataset = SimpleVQADataset(data_dir, max_samples=max_samples, img_size=224)
    if len(dataset) < 10:
        print("FATAL: Not enough data. Need at least 10 real image-text pairs.")
        return
    # With drop_last=True, a dataset smaller than one batch would yield zero
    # batches (and a ZeroDivisionError when averaging the loss), so cap the
    # batch size at the dataset size.
    loader = DataLoader(
        dataset,
        batch_size=min(batch_size, len(dataset)),
        shuffle=True,
        drop_last=True,
    )
    vocab_size = len(dataset.word2id)

    # ---- Model A: Single visual token ----
    print("\n--- Model A: Single Visual Token (Current Architecture) ---")
    enc_a = TinyEncoderA(img_size=224, patch_size=16, hidden_dim=256, embed_dim=512).to(device)
    dec_a = TinyDecoderA(embed_dim=512, hidden_dim=256, vocab_size=vocab_size).to(device)
    _train_pair(enc_a, dec_a, loader, epochs, device)
    _show_generations(enc_a, dec_a, dataset, device, "Model A")

    # ---- Model B: Multi-token visual prefix ----
    print("\n--- Model B: Multi-Token Visual Prefix (Proposed Fix) ---")
    enc_b = TinyEncoderB(img_size=224, patch_size=16, hidden_dim=256, n_visual_tokens=16).to(device)
    dec_b = TinyDecoderB(hidden_dim=256, vocab_size=vocab_size).to(device)
    _train_pair(enc_b, dec_b, loader, epochs, device)
    _show_generations(enc_b, dec_b, dataset, device, "Model B")

    print("\n" + "=" * 70)
    print("EXPERIMENT COMPLETE")
    print("=" * 70)
    print()
    print("Interpretation:")
    print(" - If BOTH collapse → data or training issue")
    print(" - If B works but A doesn't → architecture bottleneck confirmed")
    print(" - If A works too → data scale was the issue, not architecture")
if __name__ == "__main__":
    # Command-line entry point; the module docstring doubles as --help text.
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--data_dir",
        default="data/downloads/stage1",
        help="directory containing *.jsonl files with image_path/answer records",
    )
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="torch device to train on (default: cuda if available, else cpu)",
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=1000,
        help="maximum number of image-answer pairs to load",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
        help="training epochs for each model",
    )
    args = parser.parse_args()
    run_experiment(args.data_dir, args.device, args.max_samples, args.epochs)