#!/usr/bin/env python3
"""
Profile ArcisVLM inference to find bottleneck kernels.

Runs torch.profiler on N dummy inference passes and reports:
  - Top-10 CUDA kernels by GPU time % (on CPU: top-10 ops by self CPU time %)
  - Total inference time, tokens/sec, peak memory
  - Saves Chrome-compatible profiler trace to profiling/trace.json

Usage:
    python3 scripts/profile_model.py --config configs/default.yaml --device cpu --num-samples 10
    python3 scripts/profile_model.py --ckpt checkpoints/stage2_final.pt --config configs/scale_1.3b.yaml --device cuda
"""

import argparse
import os
import sys
import time

import torch
import yaml

# Make the repository root importable so `model.*` resolves when run as a script.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from model.vlm import VLJEPAModel  # noqa: E402
from model.tokenizer import BPETokenizer  # noqa: E402,F401


def load_model_and_config(config_path: str, ckpt_path: str | None, device: str):
    """Load the YAML config, build the model, and optionally load a checkpoint.

    Args:
        config_path: Path to the YAML model config.
        ckpt_path: Optional checkpoint path. When it is None or does not exist,
            the randomly initialised model is profiled instead.
        device: Torch device string the model is moved to (e.g. "cpu", "cuda:0").

    Returns:
        ``(model, config)``: the model in eval mode on ``device`` and the
        parsed config dict.
    """
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    model = VLJEPAModel(config)

    if ckpt_path and os.path.exists(ckpt_path):
        # NOTE: weights_only=False unpickles arbitrary objects -- only load
        # checkpoints from trusted sources.
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
        state = ckpt.get("model_state_dict", ckpt)
        # strict=False tolerates key mismatches; report them so a partially
        # loaded model does not silently skew the profile.
        result = model.load_state_dict(state, strict=False)
        print(f"[profile] Loaded checkpoint: {ckpt_path}", file=sys.stderr)
        missing = getattr(result, "missing_keys", [])
        unexpected = getattr(result, "unexpected_keys", [])
        if missing:
            print(f"[profile] Warning: {len(missing)} model keys missing from checkpoint",
                  file=sys.stderr)
        if unexpected:
            print(f"[profile] Warning: {len(unexpected)} unexpected keys in checkpoint",
                  file=sys.stderr)
    else:
        if ckpt_path:
            print(f"[profile] Warning: checkpoint not found: {ckpt_path}", file=sys.stderr)
        print("[profile] No checkpoint — profiling random init", file=sys.stderr)

    model = model.to(device)
    model.eval()
    return model, config


def make_dummy_inputs(config: dict, device: str, batch_size: int = 1):
    """Create dummy image + query inputs for inference profiling.

    Returns:
        ``(images, query_ids, query_mask)`` with shapes
        ``(batch_size, 3, img_size, img_size)``, ``(batch_size, q_len)`` and
        ``(batch_size, q_len)``, where ``q_len = min(32, max_query_len)``.
    """
    img_size = config["vision"]["img_size"]
    vocab_size = config["decoder"]["vocab_size"]
    max_q = config.get("predictor", {}).get("max_query_len", 64)

    # Use a short query for profiling.
    q_len = min(32, max_q)

    images = torch.randn(batch_size, 3, img_size, img_size, device=device)
    # Token ids start at 1 -- presumably to avoid a reserved id 0 (pad?); TODO confirm.
    query_ids = torch.randint(1, vocab_size, (batch_size, q_len), device=device)
    query_mask = torch.ones(batch_size, q_len, dtype=torch.long, device=device)
    return images, query_ids, query_mask


def warmup_model(model, config, device, n_warmup: int = 3, max_new_tokens: int = 16):
    """Run a few unprofiled generate() passes so CUDA kernels are compiled/cached.

    Args:
        n_warmup: Number of warmup passes (0 disables warmup).
        max_new_tokens: Decode length per pass; matching the profiled length
            warms up the same kernels/shapes that will be measured.
    """
    images, q_ids, q_mask = make_dummy_inputs(config, device)
    with torch.no_grad():
        for _ in range(n_warmup):
            model.generate(images, q_ids, q_mask,
                           max_new_tokens=max_new_tokens, temperature=0.8)
    if device.startswith("cuda"):
        torch.cuda.synchronize(device)


def _device_time_us(evt) -> float:
    """Return an event's total device time in microseconds (0 if unavailable)."""
    # Newer PyTorch releases expose device_time_total and deprecate
    # cuda_time_total; try the new name first to avoid the warning.
    for attr in ("device_time_total", "cuda_time_total"):
        value = getattr(evt, attr, None)
        if value is not None:
            return value
    return 0


def _self_cpu_time_us(evt) -> float:
    """Return an event's exclusive (self) CPU time in microseconds."""
    return evt.self_cpu_time_total


def _is_device_kernel(evt) -> bool:
    """Return True if an averaged profiler event is a device-side kernel, not a CPU op."""
    device_type = getattr(evt, "device_type", None)
    if device_type is None or device_type == torch.autograd.DeviceType.CPU:
        return False
    # Some PyTorch versions also emit GPU-side user annotations; these span
    # kernels and would double-count their time.
    return not getattr(evt, "is_user_annotation", False)


def _rank_events(events, is_cuda: bool) -> list[dict]:
    """Rank profiler events by non-overlapping time, largest first.

    On CUDA only real device kernels are ranked, by device time. On CPU
    (or if no kernel events were captured), CPU ops are ranked by *self*
    time. Summing inclusive times would double-count nested ops, so
    percentages are computed over these non-overlapping times.
    """
    kernels = [e for e in events if _is_device_kernel(e)] if is_cuda else []
    if kernels:
        selected, time_of = kernels, _device_time_us
    else:
        if is_cuda:
            print("[profile] Warning: no CUDA kernel events recorded; "
                  "ranking CPU ops by self time instead", file=sys.stderr)
        selected = [e for e in events if not _is_device_kernel(e)]
        time_of = _self_cpu_time_us

    total_time = sum(time_of(e) for e in selected)

    ranked = []
    for evt in sorted(selected, key=time_of, reverse=True):
        kernel_time = time_of(evt)
        ranked.append({
            "name": evt.key,
            "calls": evt.count,
            "time_us": kernel_time,
            "pct": (kernel_time / total_time * 100) if total_time > 0 else 0.0,
            "cpu_time_us": evt.cpu_time_total,
            "cuda_time_us": _device_time_us(evt),
        })
    return ranked


def profile_inference(model, config: dict, device: str, num_samples: int,
                      trace_path: str, max_new_tokens: int = 32):
    """Run profiled inference and return timing + kernel stats.

    Args:
        model: Model exposing ``generate(images, ids, mask, max_new_tokens=..., temperature=...)``.
        config: Parsed model config (used to build dummy inputs).
        device: Torch device string, e.g. "cpu", "cuda" or "cuda:1".
        num_samples: Number of profiled generate() calls.
        trace_path: Where the Chrome trace JSON is written.
        max_new_tokens: Decode length per sample.

    Returns:
        dict with ``ranked_kernels``, ``wall_time_s``, ``total_tokens``,
        ``tokens_per_sec``, ``peak_memory_gb``, ``num_samples``, ``trace_path``.
        ``wall_time_s`` covers only the inference loop, but it still includes
        profiler instrumentation overhead.
    """
    is_cuda = device.startswith("cuda")

    if is_cuda:
        # Pass the device explicitly so "cuda:N" tracks the requested GPU,
        # not whichever device is current.
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.synchronize(device)

    activities = [torch.profiler.ProfilerActivity.CPU]
    if is_cuda:
        activities.append(torch.profiler.ProfilerActivity.CUDA)

    os.makedirs(os.path.dirname(trace_path) or ".", exist_ok=True)

    total_tokens = 0
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        profile_memory=True,
        with_stack=False,
    ) as prof:
        # Time inside the context: profiler start-up and the stop/event
        # collection that runs when the context exits are not inference time.
        t_start = time.perf_counter()
        with torch.no_grad():
            for _ in range(num_samples):
                images, q_ids, q_mask = make_dummy_inputs(config, device)
                generated = model.generate(
                    images, q_ids, q_mask,
                    max_new_tokens=max_new_tokens,
                    temperature=0.8,
                )
                # NOTE(review): assumes generate() returns (batch, seq) with
                # batch_size == 1; confirm whether seq includes the prompt.
                total_tokens += generated.shape[1]
        if is_cuda:
            torch.cuda.synchronize(device)
        t_end = time.perf_counter()

    prof.export_chrome_trace(trace_path)

    ranked = _rank_events(prof.key_averages(), is_cuda)

    wall_time = t_end - t_start
    tokens_per_sec = total_tokens / wall_time if wall_time > 0 else 0
    peak_mem_gb = (torch.cuda.max_memory_allocated(device) / 1e9) if is_cuda else 0.0

    return {
        "ranked_kernels": ranked,
        "wall_time_s": wall_time,
        "total_tokens": total_tokens,
        "tokens_per_sec": tokens_per_sec,
        "peak_memory_gb": peak_mem_gb,
        "num_samples": num_samples,
        "trace_path": trace_path,
    }


def print_report(results: dict):
    """Print a human-readable profiling report (as returned by profile_inference) to stdout."""
    print("=" * 70)
    print("ArcisVLM Inference Profiling Report")
    print("=" * 70)
    print(f"  Samples:         {results['num_samples']}")
    print(f"  Wall time:       {results['wall_time_s']:.2f}s")
    print(f"  Total tokens:    {results['total_tokens']}")
    print(f"  Tokens/sec:      {results['tokens_per_sec']:.1f}")
    if results["peak_memory_gb"] > 0:
        print(f"  Peak GPU mem:    {results['peak_memory_gb']:.2f} GB")
    print(f"  Trace saved:     {results['trace_path']}")
    print()
    print("Top-10 Kernels by GPU/CPU Time:")
    print("-" * 70)
    print(f"  {'Rank':>4}  {'%':>6}  {'Time(us)':>10}  {'Calls':>6}  {'Kernel'}")
    print("-" * 70)
    for i, k in enumerate(results["ranked_kernels"][:10]):
        print(f"  {i+1:>4}  {k['pct']:>5.1f}%  {k['time_us']:>10.0f}  {k['calls']:>6}  {k['name']}")
    print("-" * 70)


def main():
    """CLI entry point: parse args, load the model, warm up, profile and report."""
    parser = argparse.ArgumentParser(description="Profile ArcisVLM inference")
    parser.add_argument("--ckpt", type=str, default=None, help="Checkpoint path")
    parser.add_argument("--config", type=str, required=True, help="YAML config path")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--num-samples", type=int, default=100, help="Number of inference samples")
    parser.add_argument("--max-new-tokens", type=int, default=32, help="Tokens to generate per sample")
    parser.add_argument("--trace-dir", type=str, default="profiling", help="Directory for trace output")
    parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations before profiling")
    args = parser.parse_args()

    if args.num_samples < 1:
        parser.error("--num-samples must be >= 1")

    trace_path = os.path.join(args.trace_dir, "trace.json")

    model, config = load_model_and_config(args.config, args.ckpt, args.device)

    # Warm up with the same decode length that will be profiled.
    print(f"[profile] Warming up ({args.warmup} iters)...", file=sys.stderr)
    warmup_model(model, config, args.device, args.warmup, max_new_tokens=args.max_new_tokens)

    print(f"[profile] Profiling {args.num_samples} samples on {args.device}...", file=sys.stderr)
    results = profile_inference(
        model, config, args.device, args.num_samples, trace_path, args.max_new_tokens
    )

    print_report(results)


if __name__ == "__main__":
    main()