#!/usr/bin/env python3
"""
Standalone inference test for ArcisVLM.

Tests that the model can generate coherent text given an image + question.
This is the FIRST thing to run after training to verify the model works.

Usage:
    python3 scripts/test_inference.py --ckpt checkpoints/v4_stage3_final.pt --device cuda
    python3 scripts/test_inference.py --ckpt checkpoints/v4_stage3_final.pt --device cpu
"""

import argparse
import os
import sys
import time

import torch
import yaml

# Make the repository root importable so the `model` package resolves when
# this file is executed directly from the scripts/ directory.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from model.vlm import VLJEPAModel
from model.tokenizer_utils import load_tokenizer, validate_tokenizer_model_match

TEST_QUESTIONS = [
    "What do you see in this image?",
    "How many people are in the scene?",
    "Describe the objects in the image.",
    "Is there a car in this image?",
    "What color is the main object?",
    "What is happening in this scene?",
    "Are there any people?",
    "Count the vehicles.",
    "What text is visible?",
    "Describe the weather conditions.",
]

# Prefix that DistributedDataParallel adds to every parameter name.
_DDP_PREFIX = "module."


def _strip_ddp_prefix(state_dict):
    """Return a copy of *state_dict* with one leading 'module.' removed per key.

    Only a leading prefix is stripped. A plain substring replace would also
    rewrite keys that merely contain the text (e.g. 'vision_module.weight'),
    and with strict=False loading such mangled keys are dropped silently.
    """
    return {
        (key[len(_DDP_PREFIX):] if key.startswith(_DDP_PREFIX) else key): value
        for key, value in state_dict.items()
    }


def _sync_device(device):
    """Wait for pending CUDA work on *device* so wall-clock timings are accurate."""
    if str(device).startswith("cuda"):
        torch.cuda.synchronize(device)


def main():
    """Load config, tokenizer and checkpoint, then run the fixed question set.

    Every question is paired with the same random-noise image, so the run
    checks that generation executes and decodes; it does not measure answer
    quality. Progress, per-question answers, timings and a short diagnosis
    guide are printed to stdout. Returns None.
    """
    parser = argparse.ArgumentParser(description="ArcisVLM Inference Test")
    parser.add_argument("--ckpt", required=True, help="Checkpoint path")
    parser.add_argument("--config", default="configs/scale_1.3b.yaml")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--max-tokens", type=int, default=64)
    parser.add_argument("--temperature", type=float, default=0.7)
    args = parser.parse_args()

    print("=" * 70)
    print("ArcisVLM Inference Test")
    print("=" * 70)

    # Load config
    with open(args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    img_size = config.get("vision", {}).get("img_size", 448)

    # Load tokenizer FIRST
    print("\n--- Tokenizer ---")
    ckpt_dir = os.path.dirname(args.ckpt)
    tokenizer = load_tokenizer(config, checkpoint_dir=ckpt_dir)

    # Load model
    print("\n--- Model ---")
    model = VLJEPAModel(config)
    if os.path.exists(args.ckpt):
        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects. Only load checkpoints from a trusted source.
        ckpt = torch.load(args.ckpt, map_location=args.device, weights_only=False)
        # A checkpoint is either a training dict that wraps the weights under
        # "model_state_dict", or the raw state dict itself.
        is_wrapped = "model_state_dict" in ckpt
        state_dict = ckpt["model_state_dict"] if is_wrapped else ckpt
        missing, unexpected = model.load_state_dict(
            _strip_ddp_prefix(state_dict), strict=False
        )
        if is_wrapped:
            print(f" Loaded: {args.ckpt}")
            print(f" Epoch: {ckpt.get('epoch', '?')}, Loss: {ckpt.get('loss', '?')}")
        else:
            print(f" Loaded raw state dict: {args.ckpt}")
        # strict=False hides key mismatches, so surface them for both formats.
        if missing:
            print(f" Missing keys: {len(missing)} (e.g. {missing[:3]})")
        if unexpected:
            print(f" Unexpected keys: {len(unexpected)} (e.g. {unexpected[:3]})")
    else:
        print(f" [WARN] Checkpoint not found: {args.ckpt}")
        print(" Running with random weights (sanity check only)")

    model = model.to(args.device)
    model.eval()
    params = sum(p.numel() for p in model.parameters())
    print(f" Parameters: {params:,} ({params/1e9:.2f}B)")

    # Validate tokenizer-model match
    print("\n--- Validation ---")
    match = validate_tokenizer_model_match(tokenizer, model)
    if not match:
        # Deliberately non-fatal: the run continues so the output can be inspected.
        print(" [FATAL] Tokenizer-model mismatch! Results will be garbage.")
        print(" Fix: download correct tokenizer from HuggingFace")

    # Run inference on test questions
    print(f"\n--- Inference ({len(TEST_QUESTIONS)} questions) ---")
    print(f" Device: {args.device}")
    print(f" Max tokens: {args.max_tokens}")
    print(f" Temperature: {args.temperature}")
    print()

    # Create a dummy image (random noise — not ideal but tests generation)
    dummy_image = torch.randn(1, 3, img_size, img_size).to(args.device)

    total_time = 0.0
    completed = 0  # generations that returned without raising
    for i, question in enumerate(TEST_QUESTIONS):
        print(f" Q{i+1}: {question}")

        # Tokenize question
        q_ids = tokenizer.encode(question)
        q_tensor = torch.tensor([q_ids], dtype=torch.long, device=args.device)

        # Generate
        _sync_device(args.device)
        start = time.perf_counter()
        with torch.no_grad():
            # Broad catch is intentional: this is a diagnostic harness and one
            # failing question should not abort the remaining ones.
            try:
                output_ids = model.generate(
                    dummy_image,
                    q_tensor,
                    max_new_tokens=args.max_tokens,
                    temperature=args.temperature,
                )
                _sync_device(args.device)
                elapsed = time.perf_counter() - start
                total_time += elapsed
                completed += 1

                # Decode
                if output_ids is not None and output_ids.numel() > 0:
                    pred_text = tokenizer.decode(output_ids[0].cpu().tolist())
                    # Clean up
                    # NOTE(review): both search strings are empty here, so the
                    # replace() calls are no-ops; the special-token literals
                    # look like they were lost. Restore them if tokens leak
                    # into the printed answers.
                    pred_text = pred_text.replace("", "").replace("", "").strip()
                    print(f" A{i+1}: {pred_text[:200]}")
                    print(f" [{elapsed:.2f}s, {output_ids.shape[-1]} tokens]")
                else:
                    print(f" A{i+1}: [EMPTY OUTPUT]")
                    print(f" [{elapsed:.2f}s]")
            except Exception as e:
                elapsed = time.perf_counter() - start
                print(f" A{i+1}: [ERROR] {e}")
                print(f" [{elapsed:.2f}s]")
        print()

    print("=" * 70)
    print(f"Total inference time: {total_time:.2f}s")
    # Failed queries contribute no time to total_time, so average over the
    # completed ones only; dividing by the full question count understates it.
    if completed:
        print(f"Average per query: {total_time / completed:.2f}s")
    else:
        print("Average per query: n/a (no query completed)")
    print("=" * 70)

    # Summary
    print("\n--- Diagnosis ---")
    print("If ALL answers are empty or garbage (random tokens):")
    print(" → Tokenizer mismatch between training and inference")
    print(" → Check: tokenizer vocab == model decoder vocab")
    print()
    print("If answers are repetitive (same word repeated):")
    print(" → Model collapsed during training (mode collapse)")
    print(" → Check: training loss was actually decreasing")
    print()
    print("If answers are coherent but wrong:")
    print(" → Model needs more/better training data")
    print(" → Architecture is working, just needs scale")


if __name__ == "__main__":
    main()