arcisvlm / scripts /quick_eval.py
Hardik Sanghvi
feat: integrate Gemma 4 E2B backbone for production-quality VLM inference
7a564e3
Raw
History Blame Contribute Delete
14.4 kB
#!/usr/bin/env python3
"""
Quick evaluation for the autoresearch loop.
Optionally loads a checkpoint, fine-tunes for N minutes on the local stage-2 JSONL data, then evaluates on held-out lines.
Outputs (stdout, parseable by the autoresearch skill): train_loss: X.XXXX, val_loss: X.XXXX, time_elapsed: X.Xs
Usage:
python3 scripts/quick_eval.py --config configs/default.yaml --device cpu --train-minutes 0.1 --eval-samples 10
python3 scripts/quick_eval.py --ckpt checkpoints/stage2_epoch1.pt --config configs/scale_1.3b.yaml --device cuda --train-minutes 5
"""
import argparse
import math
import os
import sys
import time
import torch
import torch.nn.functional as F
import yaml
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from model.vlm import VLJEPAModel
from model.tokenizer import BPETokenizer
# ---------------------------------------------------------------------------
# Dataset helpers
# ---------------------------------------------------------------------------
class SubsampledDataset(torch.utils.data.Dataset):
    """In-memory, map-style dataset over a list of pre-built sample dicts.

    The list is stored by reference (no copy) and its elements are returned
    unchanged, so indexing follows plain list semantics. In this script the
    elements are tokenized dicts with ``image``, ``question_ids``,
    ``question_mask`` and ``answer_ids`` entries.
    """

    def __init__(self, samples: list[dict]):
        # Keep the caller's list as-is; the DataLoader indexes straight into it.
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        return sample
def load_jsonl_samples(jsonl_dir: str, tokenizer, img_size: int,
                       max_q: int = 64, max_a: int = 128,
                       max_lines: int | None = None,
                       skip_lines: int = 0) -> list[dict]:
    """Load tokenized QA samples from the ``.jsonl`` files under *jsonl_dir*.

    The ``.jsonl`` files are read in sorted filename order as one continuous
    stream of lines. ``skip_lines`` and ``max_lines`` are both measured in
    *raw lines* of that stream. Blank, malformed, or non-object lines still
    count even though they produce no sample. This keeps line-based splits
    disjoint: ``max_lines=K`` reads exactly the lines that ``skip_lines=K``
    skips, so a train/eval split by line number never overlaps.

    Args:
        jsonl_dir: Directory containing .jsonl files.
        tokenizer: BPETokenizer instance (anything with ``encode`` and
            ``pad_id`` is enough for ``_tokenize_sample``).
        img_size: Side length of the all-zero placeholder image.
        max_q: Question length in tokens after truncation/padding.
        max_a: Answer length in tokens after truncation/padding.
        max_lines: Read at most this many lines after the skipped prefix
            (None = read to the end).
        skip_lines: Number of leading lines to skip (negative = 0).

    Returns:
        One sample dict per JSON-object line inside the window. Records
        without an answer get the answer text ``"unknown"``.
    """
    import json

    skip_lines = max(0, skip_lines)
    samples: list[dict] = []
    lines_seen = 0
    for fname in sorted(os.listdir(jsonl_dir)):
        if not fname.endswith(".jsonl"):
            continue
        fpath = os.path.join(jsonl_dir, fname)
        # JSON text is UTF-8 by spec; don't depend on the platform locale.
        with open(fpath, encoding="utf-8") as f:
            for line in f:
                lines_seen += 1
                if lines_seen <= skip_lines:
                    continue
                # Window is measured in raw lines (not accepted samples), so
                # skipped/invalid lines cannot push reads past the split point.
                if max_lines is not None and lines_seen - skip_lines > max_lines:
                    return samples
                try:
                    item = json.loads(line.strip())
                except ValueError:  # json.JSONDecodeError subclasses ValueError
                    continue
                if not isinstance(item, dict):
                    # Valid JSON but not a record (bare list/number/string):
                    # skip it like a malformed line instead of crashing.
                    continue
                question, answer = _extract_qa(item)
                samples.append(_tokenize_sample(
                    question, answer or "unknown", tokenizer, img_size, max_q, max_a
                ))
    return samples
def _extract_qa(item: dict) -> tuple[str, str]:
    """Extract a ``(question, answer)`` text pair from one JSONL record.

    Layouts are tried in order:

    * LLaVA-Instruct: a ``conversations`` list with at least two turns. The
      first turn is the question and the second the answer. Dict turns
      contribute their ``value`` field; other turns are stringified.
    * VQAv2/GQA-style flat keys: ``question`` then ``text`` for the
      question; ``answer`` then ``multiple_choice_answer`` for the answer;
      otherwise the first entry of an ``answers`` list (a dict entry
      contributes its ``answer`` field).

    A field that is missing, JSON ``null``, or an empty string falls through
    to the next candidate. ``null`` is never turned into the literal text
    ``"None"``. The question finally defaults to ``"What do you see?"`` and
    the answer to ``""``. Non-string values such as numbers are stringified,
    so a numeric ``0`` answer becomes ``"0"``.
    """
    def _text(value) -> str:
        # JSON null -> "" so it reads as "missing"; everything else -> str.
        return "" if value is None else str(value)

    def _first_text(*values) -> str:
        # First candidate whose text is non-empty, else "".
        for value in values:
            text = _text(value)
            if text:
                return text
        return ""

    question = ""
    answer = ""
    # LLaVA-Instruct format
    convos = item.get("conversations")
    if isinstance(convos, list) and len(convos) >= 2:
        q_turn, a_turn = convos[0], convos[1]
        question = _text(q_turn.get("value")) if isinstance(q_turn, dict) else _text(q_turn)
        answer = _text(a_turn.get("value")) if isinstance(a_turn, dict) else _text(a_turn)
    # VQAv2/GQA format
    if not question:
        question = _first_text(item.get("question"), item.get("text")) or "What do you see?"
    if not answer:
        answer = _first_text(item.get("answer"), item.get("multiple_choice_answer"))
    if not answer:
        answers = item.get("answers")
        if isinstance(answers, list) and answers:
            head = answers[0]
            answer = _text(head.get("answer")) if isinstance(head, dict) else _text(head)
    return question, answer
def _tokenize_sample(question: str, answer: str, tokenizer, img_size: int,
                     max_q: int, max_a: int) -> dict:
    """Turn one QA pair into a fixed-length training sample dict.

    Both texts are encoded with ``tokenizer.encode``, truncated, then
    right-padded with ``tokenizer.pad_id`` to exactly ``max_q`` / ``max_a``
    tokens. ``question_mask`` is 1 wherever the question token differs from
    the pad id and 0 elsewhere.

    Returns:
        Dict with ``image`` (float zeros, shape ``(3, img_size, img_size)``),
        ``question_ids`` and ``answer_ids`` (int64), and ``question_mask``
        (int64).
    """
    def _fit(token_ids, length):
        # Truncate to `length`, then pad on the right up to `length`.
        return (token_ids[:length] + [tokenizer.pad_id] * length)[:length]

    question_ids = _fit(tokenizer.encode(question), max_q)
    answer_ids = _fit(tokenizer.encode(answer), max_a)
    question_tensor = torch.tensor(question_ids, dtype=torch.long)
    answer_tensor = torch.tensor(answer_ids, dtype=torch.long)

    # The JSONL records carry no pixels, so every sample gets an all-zero
    # placeholder image of the shape the model's forward pass requires.
    placeholder = torch.zeros(3, img_size, img_size)
    mask = (question_tensor != tokenizer.pad_id).long()
    return {
        "image": placeholder,
        "question_ids": question_tensor,
        "question_mask": mask,
        "answer_ids": answer_tensor,
    }
# ---------------------------------------------------------------------------
# Cosine warmup scheduler (standalone, no DDP dependency)
# ---------------------------------------------------------------------------
class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Linear warmup followed by cosine decay to ``min_lr``.

    With step ``s`` (``self.last_epoch``) and each param group's base LR ``b``:

    * ``s < warmup_steps``: ``b * s / warmup_steps`` (ramps up from 0).
    * otherwise: ``min_lr + (b - min_lr) * 0.5 * (1 + cos(pi * p))``, where
      ``p = (s - warmup_steps) / (total_steps - warmup_steps)`` is clamped to
      ``[0, 1]``. After ``total_steps`` the LR therefore stays at ``min_lr``.

    The clamp matters because callers may only *estimate* ``total_steps``.
    Without it the cosine would swing back up toward the base LR once
    training ran past the estimate.

    Args:
        optimizer: Wrapped optimizer.
        warmup_steps: Number of linear-warmup steps (0 disables warmup).
        total_steps: Step at which the cosine reaches ``min_lr``.
        min_lr: Floor LR at the end of the decay.
        last_epoch: Index of the last step, as in other torch schedulers.
    """

    def __init__(self, optimizer, warmup_steps: int, total_steps: int,
                 min_lr: float = 1e-7, last_epoch: int = -1):
        # Set before super().__init__(), whose initial step calls get_lr().
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_steps:
            scale = step / max(1, self.warmup_steps)
            return [base_lr * scale for base_lr in self.base_lrs]
        decay_steps = max(1, self.total_steps - self.warmup_steps)
        # Clamp so overshooting total_steps holds the LR at min_lr.
        progress = min(1.0, (step - self.warmup_steps) / decay_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return [self.min_lr + (base_lr - self.min_lr) * cosine for base_lr in self.base_lrs]
# ---------------------------------------------------------------------------
# Core routines
# ---------------------------------------------------------------------------
def load_model_and_config(config_path: str, ckpt_path: str | None, device: str):
    """Load the YAML config, build the model, and optionally load checkpoint weights.

    Args:
        config_path: Path to the YAML config passed to ``VLJEPAModel``.
        ckpt_path: Optional checkpoint file. It may be a raw state dict or a
            dict holding one under ``"model_state_dict"``. Weights load with
            ``strict=False``, and any missing or unexpected keys are reported
            on stderr. If the path is given but does not exist, a warning is
            printed and the model stays freshly initialized.
        device: Device the returned model is moved to.

    Returns:
        ``(model, config)``.
    """
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    model = VLJEPAModel(config)
    if ckpt_path and os.path.exists(ckpt_path):
        # Deserialize on CPU: the model has not been moved yet, so mapping to
        # `device` would only create a temporary device copy of every tensor
        # before load_state_dict copies it back into the CPU parameters.
        # NOTE(security): weights_only=False unpickles arbitrary objects;
        # only load checkpoints from trusted sources.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        state = ckpt.get("model_state_dict", ckpt)
        result = model.load_state_dict(state, strict=False)
        print(f"[quick_eval] Loaded checkpoint: {ckpt_path}", file=sys.stderr)
        # strict=False would otherwise hide a checkpoint that barely matches.
        missing = getattr(result, "missing_keys", [])
        unexpected = getattr(result, "unexpected_keys", [])
        if missing or unexpected:
            print(
                f"[quick_eval] WARNING: checkpoint key mismatch: "
                f"{len(missing)} missing, {len(unexpected)} unexpected",
                file=sys.stderr,
            )
        del ckpt, state
    else:
        if ckpt_path:
            print(f"[quick_eval] WARNING: checkpoint not found: {ckpt_path}", file=sys.stderr)
        print("[quick_eval] No checkpoint — initializing from scratch", file=sys.stderr)
    model = model.to(device)
    return model, config
def load_tokenizer(config: dict) -> BPETokenizer:
    """Build the project tokenizer described by *config*.

    This is a thin delegate to ``model.tokenizer_utils.load_tokenizer``. There
    is deliberately no dummy fallback, so any error raised there reaches the
    caller unchanged.
    """
    from model.tokenizer_utils import load_tokenizer as _build_tokenizer

    tokenizer = _build_tokenizer(config)
    return tokenizer
def _count_jsonl_lines(jsonl_dir: str) -> int:
    """Count raw lines across all ``.jsonl`` files in *jsonl_dir* (sorted order)."""
    total = 0
    for fname in sorted(os.listdir(jsonl_dir)):
        if fname.endswith(".jsonl"):
            # errors="replace": counting only needs the newlines, so an
            # undecodable byte must not abort the count.
            with open(os.path.join(jsonl_dir, fname), encoding="utf-8", errors="replace") as f:
                total += sum(1 for _ in f)
    return total


def build_train_eval_data(config: dict, tokenizer, eval_samples: int,
                          jsonl_dir: str = "data/downloads/stage2"):
    """Split the local stage-2 JSONL data into train and held-out eval samples.

    All ``.jsonl`` lines under *jsonl_dir* are treated as one stream of
    ``total`` lines.

    * If ``total > eval_samples + 100``, eval is the last ``eval_samples``
      lines and train is everything before them.
    * Otherwise, train is the first 80% of lines (at least one) and eval is
      at most ``eval_samples`` of the remaining lines.

    Args:
        config: Parsed config; only ``config["vision"]["img_size"]`` is read.
        tokenizer: Tokenizer forwarded to ``load_jsonl_samples``.
        eval_samples: Requested number of eval lines.
        jsonl_dir: Data directory, relative to the current working directory
            unless absolute.

    Returns:
        ``(train_samples, eval_samples)`` lists of sample dicts.

    Raises:
        RuntimeError: If the directory is missing or contains no JSONL lines.
    """
    img_size = config["vision"]["img_size"]
    total = _count_jsonl_lines(jsonl_dir) if os.path.isdir(jsonl_dir) else 0
    if total == 0:
        raise RuntimeError(
            "FATAL: No training data found for quick_eval.\n"
            "Download real data first: python3 scripts/download_all_data.py --stage 2\n"
            "Required: data/downloads/stage2/ with JSONL files"
        )

    if total > eval_samples + 100:
        # Eval = last eval_samples lines; train = everything before.
        train_count = total - eval_samples
        label = "Real data"
    else:
        # Too few lines: first 80% train, remainder eval.
        train_count = max(1, int(total * 0.8))
        label = "Real data (small)"

    train_data = load_jsonl_samples(
        jsonl_dir, tokenizer, img_size, max_lines=train_count
    )
    eval_data = load_jsonl_samples(
        jsonl_dir, tokenizer, img_size,
        skip_lines=train_count, max_lines=eval_samples
    )
    print(f"[quick_eval] {label}: {len(train_data)} train, {len(eval_data)} eval", file=sys.stderr)
    return train_data, eval_data
def _make_grad_scaler(enabled: bool):
    """Create a CUDA gradient scaler; when ``enabled`` is False every call is a pass-through.

    Prefers ``torch.amp.GradScaler``. Falls back to the older
    ``torch.cuda.amp.GradScaler`` on torch builds that lack it.
    """
    if hasattr(torch.amp, "GradScaler"):
        return torch.amp.GradScaler("cuda", enabled=enabled)
    return torch.cuda.amp.GradScaler(enabled=enabled)


def train_loop(model, train_data: list[dict], config: dict, device: str,
               train_minutes: float) -> float:
    """Fine-tune *model* stage-2 style for a wall-clock budget.

    Hyper-parameters come from ``config["train_stage2"]`` (with defaults):
    ``learning_rate``, ``gradient_clip``, ``load_balance_weight``,
    ``batch_size`` (capped at ``len(train_data)``) and ``warmup_steps``.
    The x-encoder is frozen via ``model.freeze_x_encoder()`` and AdamW
    updates the remaining trainable parameters. The loop cycles through
    shuffled epochs until ``train_minutes`` have elapsed. The deadline is
    checked before each batch, so the final step may end slightly late.

    On CUDA the forward pass runs under autocast: bf16 when supported,
    otherwise fp16. For fp16 a GradScaler is used so small gradients do not
    underflow to zero.

    Returns:
        Mean training loss over the steps taken, or NaN if no step ran
        (``train_minutes <= 0``, empty data, or the budget expired first).
    """
    if train_minutes <= 0 or len(train_data) == 0:
        return float("nan")
    stage_cfg = config.get("train_stage2", {})
    lr = stage_cfg.get("learning_rate", 1e-4)
    grad_clip = stage_cfg.get("gradient_clip", 1.0)
    lb_weight = stage_cfg.get("load_balance_weight", 0.01)
    batch_size = min(stage_cfg.get("batch_size", 4), len(train_data))

    # Freeze x_encoder for stage2-style training
    model.freeze_x_encoder()
    model.train()
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=lr, weight_decay=0.01)
    loader = torch.utils.data.DataLoader(
        SubsampledDataset(train_data), batch_size=batch_size, shuffle=True, drop_last=True
    )

    # Rough step estimate (~0.5 s/step) for the schedule length.
    estimated_steps = max(1, int((train_minutes * 60) / 0.5))
    warmup = min(stage_cfg.get("warmup_steps", 50), estimated_steps // 5)
    scheduler = CosineWarmupScheduler(optimizer, warmup, estimated_steps)

    # Mixed precision on CUDA
    device_type = device.split(":")[0]
    use_amp = device.startswith("cuda")
    autocast_dtype = torch.bfloat16 if use_amp and torch.cuda.is_bf16_supported() else torch.float16
    # fp16's narrow exponent range needs loss scaling; bf16/fp32 do not.
    scaler = _make_grad_scaler(enabled=use_amp and autocast_dtype == torch.float16)

    total_loss = 0.0
    steps = 0
    # monotonic() is immune to wall-clock adjustments, unlike time().
    deadline = time.monotonic() + train_minutes * 60
    while time.monotonic() < deadline:
        for batch in loader:
            if time.monotonic() >= deadline:
                break
            images = batch["image"].to(device)
            q_ids = batch["question_ids"].to(device)
            q_mask = batch["question_mask"].to(device)
            a_ids = batch["answer_ids"].to(device)

            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast(device_type=device_type, dtype=autocast_dtype, enabled=use_amp):
                output = model.forward_stage2(
                    images=images,
                    query_ids=q_ids,
                    query_padding_mask=q_mask,
                    answer_ids=a_ids,
                    load_balance_weight=lb_weight,
                )
            loss = output["loss"]
            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true gradient norms.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(trainable, grad_clip)
            scaler.step(optimizer)  # skips the update if fp16 grads overflowed
            scaler.update()
            scheduler.step()

            loss_value = loss.item()
            total_loss += loss_value
            steps += 1
            if steps % 20 == 0:
                print(f"[quick_eval] step {steps} train_loss={loss_value:.4f}", file=sys.stderr)

    # No steps means no measurement: report NaN rather than a fake 0.0 loss.
    avg = total_loss / steps if steps else float("nan")
    print(f"[quick_eval] Training done: {steps} steps in {train_minutes:.1f} min, avg_loss={avg:.4f}", file=sys.stderr)
    return avg
@torch.no_grad()
def evaluate(model, eval_data: list[dict], device: str,
max_samples: int | None = None) -> float:
"""Compute average decode loss on eval samples. Returns avg loss."""
model.eval()
n = len(eval_data) if max_samples is None else min(max_samples, len(eval_data))
if n == 0:
return float("nan")
total_loss = 0.0
count = 0
# Process one-by-one to avoid OOM on large batches
for i in range(n):
sample = eval_data[i]
images = sample["image"].unsqueeze(0).to(device)
q_ids = sample["question_ids"].unsqueeze(0).to(device)
q_mask = sample["question_mask"].unsqueeze(0).to(device)
a_ids = sample["answer_ids"].unsqueeze(0).to(device)
output = model.forward_stage2(
images=images,
query_ids=q_ids,
query_padding_mask=q_mask,
answer_ids=a_ids,
load_balance_weight=0.0, # no LB loss for eval
)
total_loss += output["decode_loss"].item()
count += 1
return total_loss / count
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Command-line entry point for one autoresearch evaluation round.

    Pipeline: config and model (plus optional checkpoint) -> tokenizer ->
    train/eval split -> timed fine-tuning -> held-out evaluation. At the end,
    this function prints three ``key: value`` lines to stdout (train_loss,
    val_loss, time_elapsed) for the autoresearch parser.
    """
    cli = argparse.ArgumentParser(description="Quick eval for autoresearch loop")
    cli.add_argument("--ckpt", type=str, default=None, help="Checkpoint path (optional)")
    cli.add_argument("--config", type=str, required=True, help="YAML config path")
    cli.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    cli.add_argument("--train-minutes", type=float, default=5.0, help="Minutes to train (0 = eval only)")
    cli.add_argument("--eval-samples", type=int, default=1000, help="Number of eval samples")
    opts = cli.parse_args()

    started = time.time()
    model, config = load_model_and_config(opts.config, opts.ckpt, opts.device)
    train_data, eval_data = build_train_eval_data(
        config, load_tokenizer(config), opts.eval_samples
    )
    train_loss = train_loop(model, train_data, config, opts.device, opts.train_minutes)
    val_loss = evaluate(model, eval_data, opts.device, max_samples=opts.eval_samples)
    elapsed = time.time() - started

    # --- Parseable output (autoresearch reads these lines) ---
    report = (
        f"train_loss: {train_loss:.4f}",
        f"val_loss: {val_loss:.4f}",
        f"time_elapsed: {elapsed:.1f}s",
    )
    for entry in report:
        print(entry)


if __name__ == "__main__":
    main()