arcisvlm / scripts /test_inference.py
Hardik Sanghvi
feat: integrate Gemma 4 E2B backbone for production-quality VLM inference
7a564e3
Raw
History Blame Contribute Delete
6.19 kB
#!/usr/bin/env python3
"""
Standalone inference test for ArcisVLM.
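Tests that the model can generate coherent text given an image + question.
The image is synthetic random noise, so this checks the generation pipeline
(tokenizer, weights, decoding) rather than visual grounding.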
This is the FIRST thing to run after training to verify the model works.
Usage:
python3 scripts/test_inference.py --ckpt checkpoints/v4_stage3_final.pt --device cuda
python3 scripts/test_inference.py --ckpt checkpoints/v4_stage3_final.pt --device cpu
"""
import argparse
import os
import sys
import time
import torch
import yaml
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from model.vlm import VLJEPAModel
from model.tokenizer_utils import load_tokenizer, validate_tokenizer_model_match
# Prompts sent, in this order, alongside the same input image.
# One question per line; splitlines() drops the final newline, so no empty entry.
TEST_QUESTIONS = """\
What do you see in this image?
How many people are in the scene?
Describe the objects in the image.
Is there a car in this image?
What color is the main object?
What is happening in this scene?
Are there any people?
Count the vehicles.
What text is visible?
Describe the weather conditions.
""".splitlines()
def _strip_ddp_prefix(state_dict):
    """Return a copy of *state_dict* with one leading ``module.`` removed per key.

    Only the leading prefix is stripped. Using ``str.replace`` on the whole
    key would also rewrite interior matches (``encoder.submodule.weight`` ->
    ``encoder.subweight``). Under ``strict=False`` those weights would then be
    dropped silently.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _load_weights(model, ckpt_path, device):
    """Load *ckpt_path* into *model* non-strictly and print a load summary.

    Accepts either a dict holding a ``model_state_dict`` entry (with optional
    ``epoch``/``loss`` metadata) or a bare state dict. Both forms get the
    DDP-prefix cleanup and the missing/unexpected key report. If the file
    does not exist, a warning is printed and the model is left untouched.
    """
    if not os.path.exists(ckpt_path):
        print(f" [WARN] Checkpoint not found: {ckpt_path}")
        print(" Running with random weights (sanity check only)")
        return

    # weights_only=False unpickles arbitrary Python objects: only pass
    # checkpoints you produced yourself or otherwise trust.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    is_training_ckpt = "model_state_dict" in ckpt
    state_dict = ckpt["model_state_dict"] if is_training_ckpt else ckpt
    missing, unexpected = model.load_state_dict(
        _strip_ddp_prefix(state_dict), strict=False
    )

    if is_training_ckpt:
        print(f" Loaded: {ckpt_path}")
        print(f" Epoch: {ckpt.get('epoch', '?')}, Loss: {ckpt.get('loss', '?')}")
    else:
        print(f" Loaded raw state dict: {ckpt_path}")
    if missing:
        print(f" Missing keys: {len(missing)} (e.g. {missing[:3]})")
    if unexpected:
        print(f" Unexpected keys: {len(unexpected)} (e.g. {unexpected[:3]})")


def _sync(device):
    """Wait for queued CUDA work on *device* so wall-clock timings are honest."""
    if str(device).startswith("cuda"):
        torch.cuda.synchronize(device)


def _decode_answer(tokenizer, output_ids):
    """Return ``(cleaned_text, token_count)`` for the first sequence.

    Returns ``(None, 0)`` when *output_ids* is None or empty. The token count
    is ``output_ids.shape[-1]``. Whether that includes the prompt depends on
    ``model.generate`` (NOTE(review): confirm).
    """
    if output_ids is None or output_ids.numel() == 0:
        return None, 0
    text = tokenizer.decode(output_ids[0].cpu().tolist())
    text = text.replace("<pad>", "").replace("<eos>", "").strip()
    return text, output_ids.shape[-1]


def _print_diagnosis():
    """Print a short guide for interpreting the generated answers."""
    print("\n--- Diagnosis ---")
    print("If ALL answers are empty or garbage (random tokens):")
    print(" → Tokenizer mismatch between training and inference")
    print(" → Check: tokenizer vocab == model decoder vocab")
    print()
    print("If answers are repetitive (same word repeated):")
    print(" → Model collapsed during training (mode collapse)")
    print(" → Check: training loss was actually decreasing")
    print()
    print("If answers are coherent but wrong:")
    print(" → Model needs more/better training data")
    print(" → Architecture is working, just needs scale")


def main():
    """Run every TEST_QUESTIONS prompt through the model and report results.

    Parses CLI args, loads the config, tokenizer and checkpoint, then
    generates an answer per question against a random-noise image. It prints
    each answer with its latency, a timing summary over the successful
    queries, and a troubleshooting guide. A failing query is reported and
    skipped, so one error does not hide the remaining results.
    """
    parser = argparse.ArgumentParser(description="ArcisVLM Inference Test")
    parser.add_argument("--ckpt", required=True, help="Checkpoint path")
    parser.add_argument("--config", default="configs/scale_1.3b.yaml")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--max-tokens", type=int, default=64)
    parser.add_argument("--temperature", type=float, default=0.7)
    args = parser.parse_args()

    print("=" * 70)
    print("ArcisVLM Inference Test")
    print("=" * 70)

    # Load config. An empty YAML file parses to None, so fall back to {}.
    with open(args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f) or {}
    img_size = config.get("vision", {}).get("img_size", 448)

    # Load tokenizer FIRST. It is given the checkpoint's directory, presumably
    # so a tokenizer saved alongside the checkpoint is preferred.
    print("\n--- Tokenizer ---")
    ckpt_dir = os.path.dirname(args.ckpt)
    tokenizer = load_tokenizer(config, checkpoint_dir=ckpt_dir)

    # Load model
    print("\n--- Model ---")
    model = VLJEPAModel(config)
    _load_weights(model, args.ckpt, args.device)
    model = model.to(args.device)
    model.eval()
    params = sum(p.numel() for p in model.parameters())
    print(f" Parameters: {params:,} ({params/1e9:.2f}B)")

    # Validate tokenizer-model match. A mismatch is reported, but inference
    # still runs so its symptoms can be inspected.
    print("\n--- Validation ---")
    if not validate_tokenizer_model_match(tokenizer, model):
        print(" [FATAL] Tokenizer-model mismatch! Results will be garbage.")
        print(" Fix: download correct tokenizer from HuggingFace")

    print(f"\n--- Inference ({len(TEST_QUESTIONS)} questions) ---")
    print(f" Device: {args.device}")
    print(f" Max tokens: {args.max_tokens}")
    print(f" Temperature: {args.temperature}")
    print()

    # Create a dummy image (random noise — not ideal but tests generation)
    dummy_image = torch.randn(1, 3, img_size, img_size).to(args.device)

    total_time = 0.0
    n_ok = 0
    for i, question in enumerate(TEST_QUESTIONS, start=1):
        print(f" Q{i}: {question}")
        q_ids = tokenizer.encode(question)
        q_tensor = torch.tensor([q_ids], dtype=torch.long, device=args.device)

        # Flush pending GPU work so it is not billed to this query.
        _sync(args.device)
        start = time.perf_counter()
        try:
            with torch.no_grad():
                output_ids = model.generate(
                    dummy_image, q_tensor,
                    max_new_tokens=args.max_tokens,
                    temperature=args.temperature,
                )
            _sync(args.device)
            elapsed = time.perf_counter() - start
            answer, n_tokens = _decode_answer(tokenizer, output_ids)
        except Exception as e:  # diagnostic script: report and move on
            elapsed = time.perf_counter() - start
            print(f" A{i}: [ERROR] {type(e).__name__}: {e}")
            print(f" [{elapsed:.2f}s]")
            print()
            continue

        total_time += elapsed
        n_ok += 1
        if answer is None:
            print(f" A{i}: [EMPTY OUTPUT]")
            print(f" [{elapsed:.2f}s]")
        else:
            print(f" A{i}: {answer[:200]}")
            print(f" [{elapsed:.2f}s, {n_tokens} tokens]")
        print()

    # Timing stats cover only queries whose generation completed.
    print("=" * 70)
    print(f"Successful queries: {n_ok}/{len(TEST_QUESTIONS)}")
    print(f"Total inference time: {total_time:.2f}s")
    if n_ok:
        print(f"Average per query: {total_time / n_ok:.2f}s")
    else:
        print("Average per query: n/a (no successful queries)")
    print("=" * 70)

    _print_diagnosis()


if __name__ == "__main__":
    main()