arcisvlm / scripts /quick_eval_guard.py
Hardik Sanghvi
feat: integrate Gemma 4 E2B backbone for production-quality VLM inference
7a564e3
Raw
History Blame Contribute Delete
8.65 kB
#!/usr/bin/env python3
"""
Guard evaluation for autoresearch — runs on a SEPARATE test set.
Ensures changes that improve val_loss don't regress on unseen data.
Uses stage3 data (detection/caption) as a domain-shifted guard set.
Exits 0 if guard_loss is acceptable, exits 1 if regression detected.
Usage:
python3 scripts/quick_eval_guard.py --ckpt checkpoints/stage2_epoch1.pt --config configs/default.yaml --device cpu
python3 scripts/quick_eval_guard.py --ckpt checkpoints/stage2_epoch1.pt --config configs/scale_1.3b.yaml --device cuda
"""
import argparse
import json
import os
import sys
import torch
import yaml
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from model.vlm import VLJEPAModel
from model.tokenizer import BPETokenizer
# ---------------------------------------------------------------------------
# Regression baseline for the guard. The pass/fail threshold is TWICE this
# value, so with the default of 12.0 the guard fails only above 24.0.
# Reference points: decode_loss on randomly initialised weights is typically
# ~10-11; after a few stage2 epochs on real data it should be ~2-4.
# The value is deliberately generous so that only clear regressions (e.g. a
# diverging run) trip the guard -- note that even an untrained model stays
# below the resulting threshold.
# ---------------------------------------------------------------------------
DEFAULT_BASELINE_LOSS: float = 12.0
def load_model_and_config(config_path: str, ckpt_path: str | None, device: str):
    """Load the YAML config, build the model, and optionally restore a checkpoint.

    Args:
        config_path: Path to the YAML config used to construct ``VLJEPAModel``.
        ckpt_path: Checkpoint file to load, or None to evaluate the freshly
            built (random-init) model. A path that does not exist is treated
            like None, but with an explicit warning so a typo is not mistaken
            for a deliberate random-init evaluation.
        device: Torch device string the model is moved to (e.g. "cpu", "cuda").

    Returns:
        ``(model, config)`` -- the model in eval mode on ``device`` and the
        parsed config dict.
    """
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    model = VLJEPAModel(config)

    if ckpt_path and os.path.exists(ckpt_path):
        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # Only point this at checkpoints you produced / trust.
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
        # Accept both training checkpoints ({"model_state_dict": ...}) and bare state dicts.
        state = ckpt.get("model_state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
        result = model.load_state_dict(state, strict=False)
        print(f"[guard] Loaded checkpoint: {ckpt_path}", file=sys.stderr)
        # strict=False tolerates key mismatches silently; surface them so a
        # mismatched checkpoint does not masquerade as a trained model.
        missing = getattr(result, "missing_keys", []) or []
        unexpected = getattr(result, "unexpected_keys", []) or []
        if missing or unexpected:
            print(
                f"[guard] WARNING: checkpoint key mismatch -- "
                f"{len(missing)} missing, {len(unexpected)} unexpected",
                file=sys.stderr,
            )
    elif ckpt_path:
        print(
            f"[guard] WARNING: checkpoint not found: {ckpt_path} -- evaluating random init",
            file=sys.stderr,
        )
    else:
        print("[guard] No checkpoint — evaluating random init", file=sys.stderr)

    model = model.to(device)
    model.eval()
    return model, config
def load_tokenizer(
    config: dict,
    search_paths: tuple[str, ...] = (
        "checkpoints/tokenizer_32k.json",
        "checkpoints/tokenizer.json",
    ),
) -> BPETokenizer:
    """Load a saved BPE tokenizer, or build a throwaway one as a last resort.

    Args:
        config: Parsed config; ``config["decoder"]["vocab_size"]`` (default
            8192) is passed to the ``BPETokenizer`` constructor.
        search_paths: Candidate tokenizer files tried in order; the first one
            that exists is loaded. Relative paths resolve against the current
            working directory.

    Returns:
        A ``BPETokenizer``. If no file is found it is trained on a placeholder
        corpus, in which case token ids will not match the ones used in
        training and the resulting guard_loss is not meaningful.
    """
    # `or {}` also covers a YAML `decoder:` key with no value (parsed as None).
    vocab_size = (config.get("decoder") or {}).get("vocab_size", 8192)
    tokenizer = BPETokenizer(vocab_size=vocab_size)
    for path in search_paths:
        if os.path.exists(path):
            # NOTE(review): the first match wins regardless of vocab_size; confirm
            # the loaded tokenizer's vocabulary fits the model's embedding table.
            tokenizer.load(path)
            print(f"[guard] Loaded tokenizer: {path}", file=sys.stderr)
            return tokenizer
    print(
        "[guard] WARNING: no tokenizer file found -- training a placeholder "
        "tokenizer; guard_loss will not be comparable to real runs",
        file=sys.stderr,
    )
    tokenizer.train(["dummy guard evaluation text"] * 50)
    return tokenizer
def _tokenize_sample(question: str, answer: str, tokenizer, img_size: int,
                     max_q: int = 64, max_a: int = 128) -> dict:
    """Encode one question/answer pair into fixed-length tensors plus a blank image.

    Args:
        question: Prompt text.
        answer: Target text.
        tokenizer: Object exposing ``encode(text)`` (returning a list of ids)
            and a ``pad_id`` attribute.
        img_size: Side length of the square all-zero placeholder image.
        max_q: Exact question length after truncation / right-padding.
        max_a: Exact answer length after truncation / right-padding.

    Returns:
        Dict with keys ``image`` (zeros, shape (3, img_size, img_size)),
        ``question_ids`` (long, length max_q), ``question_mask`` (1 where the
        question id differs from ``pad_id``, else 0) and ``answer_ids``
        (long, length max_a).
    """
    pad = tokenizer.pad_id

    def fit(token_ids, length):
        # Truncate to `length`, right-pad with pad_id, and clip to exactly `length`.
        padded = (token_ids[:length] + [pad] * length)[:length]
        return torch.tensor(padded, dtype=torch.long)

    question_tensor = fit(tokenizer.encode(question), max_q)
    answer_tensor = fit(tokenizer.encode(answer), max_a)

    # Guard eval feeds text-only data with an all-zero placeholder image: it
    # measures decode-loss regression, not visual understanding, so only the
    # image tensor's shape matters here.
    # NOTE(review): the mask is 1 for real tokens and 0 for padding (and also 0
    # for any real token equal to pad_id); confirm this is the convention
    # forward_stage2 expects for query_padding_mask.
    placeholder = torch.zeros(3, img_size, img_size)
    return {
        "image": placeholder,
        "question_ids": question_tensor,
        "question_mask": (question_tensor != pad).long(),
        "answer_ids": answer_tensor,
    }
def _extract_qa(item: dict) -> tuple[str, str]:
    """Pull a (question, answer) string pair out of one stage3 JSONL record."""
    question = item.get("question", item.get("text", ""))
    answer = item.get("answer", item.get("caption", ""))
    # LLaVA format: take the first two "conversations" turns as (question, answer)
    # -- presumably (human, assistant); verify against the stage3 data.
    if not question and "conversations" in item:
        convos = item["conversations"]
        if isinstance(convos, list) and len(convos) >= 2:
            question = convos[0].get("value", "") if isinstance(convos[0], dict) else str(convos[0])
            answer = convos[1].get("value", "") if isinstance(convos[1], dict) else str(convos[1])
    if not question:
        question = "Describe what you see in this image."
    if not answer:
        answer = "unknown"
    # Coerce in case a record stores e.g. a number; the tokenizer expects text.
    return str(question), str(answer)


def _iter_stage3_pairs(stage3_dir: str):
    """Yield (question, answer) pairs from every *.jsonl file in stage3_dir, in sorted filename order."""
    for fname in sorted(os.listdir(stage3_dir)):
        if not fname.endswith(".jsonl"):
            continue
        fpath = os.path.join(stage3_dir, fname)
        # JSONL is UTF-8 by definition; don't depend on the platform locale, and
        # don't let one bad byte abort the whole guard run.
        with open(fpath, encoding="utf-8", errors="replace") as f:
            for line in f:
                try:
                    item = json.loads(line.strip())
                except ValueError:  # json.JSONDecodeError is a ValueError subclass
                    continue
                # Skip records that are valid JSON but not objects (lists, strings, ...).
                if isinstance(item, dict):
                    yield _extract_qa(item)


def build_guard_data(config: dict, tokenizer, max_samples: int = 500,
                     stage3_dir: str = "data/downloads/stage3") -> list[dict]:
    """Build the guard evaluation set from stage3 data (detection/caption domain).

    Falls back to synthetic detection-style samples if ``stage3_dir`` is
    missing or yields no usable records.

    Args:
        config: Parsed config; ``config["vision"]["img_size"]`` sizes the
            placeholder images.
        tokenizer: Tokenizer passed through to ``_tokenize_sample``.
        max_samples: Maximum number of guard samples to return.
        stage3_dir: Directory scanned for ``*.jsonl`` files.

    Returns:
        A list of sample dicts as produced by ``_tokenize_sample``.
    """
    img_size = config["vision"]["img_size"]

    if os.path.isdir(stage3_dir):
        samples = []
        pairs = _iter_stage3_pairs(stage3_dir)
        try:
            for question, answer in pairs:
                if len(samples) >= max_samples:
                    break
                samples.append(_tokenize_sample(question, answer, tokenizer, img_size))
        finally:
            pairs.close()  # close the currently open JSONL file promptly
        if samples:
            print(f"[guard] Loaded {len(samples)} stage3 guard samples", file=sys.stderr)
            return samples

    # Fallback: synthetic detection-style guard samples that exercise a
    # different domain than the stage2 VQA data.
    detection_prompts = [
        ("Detect all objects in this surveillance frame.", "person: [120, 50, 200, 300]; car: [400, 200, 600, 350]"),
        ("What objects are visible?", "Two people walking near a parked vehicle."),
        ("Count the number of people.", "3 people detected."),
        ("Is there any suspicious activity?", "No suspicious activity detected."),
        ("Describe the scene.", "An outdoor parking lot with several vehicles and pedestrians."),
    ]
    samples = []
    for i in range(max_samples):
        q, a = detection_prompts[i % len(detection_prompts)]
        samples.append(_tokenize_sample(q, a, tokenizer, img_size))
    print(f"[guard] Using {len(samples)} synthetic guard samples (no stage3 data found)", file=sys.stderr)
    return samples
@torch.no_grad()
def evaluate_guard(model, guard_data: list[dict], device: str) -> float:
    """Return the mean per-sample ``decode_loss`` of ``model`` over ``guard_data``.

    Every sample is run on its own as a batch of one, with gradients disabled
    and the load-balance loss weight set to 0.0.

    Args:
        model: Model exposing ``eval()`` and ``forward_stage2(...)``, which
            returns a dict containing a scalar ``decode_loss`` tensor.
        guard_data: Sample dicts with ``image``, ``question_ids``,
            ``question_mask`` and ``answer_ids`` tensors (no batch dimension).
        device: Torch device string the inputs are moved to.

    Returns:
        The arithmetic mean of the per-sample losses, or 0.0 when
        ``guard_data`` is empty.
    """
    model.eval()
    input_keys = ("image", "question_ids", "question_mask", "answer_ids")
    running = 0.0
    seen = 0
    for seen, item in enumerate(guard_data, start=1):
        # Add a leading batch dimension of 1 and move to the target device.
        batch = {key: item[key].unsqueeze(0).to(device) for key in input_keys}
        result = model.forward_stage2(
            images=batch["image"],
            query_ids=batch["question_ids"],
            query_padding_mask=batch["question_mask"],
            answer_ids=batch["answer_ids"],
            load_balance_weight=0.0,
        )
        running += result["decode_loss"].item()
    return running / seen if seen else 0.0
def main():
    """CLI entry point: evaluate guard loss and exit 0 (pass) or 1 (regression).

    Prints two parseable lines to stdout (``guard_loss: X`` and
    ``threshold: Y``, where Y = 2 * --baseline) and a human-readable verdict
    to stderr. The guard fails when the loss exceeds the threshold, when the
    loss is not finite (NaN/inf, e.g. a diverged model), or when there are no
    guard samples to evaluate.
    """
    import math  # local import: only needed for the finiteness check below

    parser = argparse.ArgumentParser(description="Guard evaluation for autoresearch")
    parser.add_argument("--ckpt", type=str, default=None, help="Checkpoint path")
    parser.add_argument("--config", type=str, required=True, help="YAML config path")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--guard-samples", type=int, default=500, help="Number of guard samples")
    parser.add_argument("--baseline", type=float, default=DEFAULT_BASELINE_LOSS,
                        help="Baseline loss; threshold = 2x this value")
    args = parser.parse_args()

    model, config = load_model_and_config(args.config, args.ckpt, args.device)
    tokenizer = load_tokenizer(config)
    guard_data = build_guard_data(config, tokenizer, args.guard_samples)

    if guard_data:
        guard_loss = evaluate_guard(model, guard_data, args.device)
    else:
        # evaluate_guard would report 0.0 for an empty set, which trivially
        # passes; an empty guard set cannot vouch for anything, so fail closed.
        print("[guard] WARNING: no guard samples to evaluate", file=sys.stderr)
        guard_loss = float("nan")
    threshold = 2.0 * args.baseline

    # --- Parseable output ---
    print(f"guard_loss: {guard_loss:.4f}")
    print(f"threshold: {threshold:.4f}")

    # NaN compares False against everything, so a bare `guard_loss > threshold`
    # would let a diverged (NaN-loss) model pass the guard.
    if not math.isfinite(guard_loss) or guard_loss > threshold:
        print(f"GUARD FAILED: guard_loss {guard_loss:.4f} > threshold {threshold:.4f}", file=sys.stderr)
        sys.exit(1)
    print(f"GUARD PASSED: guard_loss {guard_loss:.4f} <= threshold {threshold:.4f}", file=sys.stderr)
    sys.exit(0)


if __name__ == "__main__":
    main()