arcisvlm / scripts /profile_model.py
Hardik Sanghvi
feat: integrate Gemma 4 E2B backbone for production-quality VLM inference
7a564e3
Raw
History Blame Contribute Delete
7.53 kB
#!/usr/bin/env python3
"""
Profile ArcisVLM inference to find bottleneck kernels.
Runs torch.profiler on N dummy inference passes and reports:
- Top-10 kernels/ops by GPU time % (by CPU time % when run with --device cpu)
- Total inference time, tokens/sec, peak memory
- Saves Chrome-compatible profiler trace to profiling/trace.json
Usage:
python3 scripts/profile_model.py --config configs/default.yaml --device cpu --num-samples 10
python3 scripts/profile_model.py --ckpt checkpoints/stage2_final.pt --config configs/scale_1.3b.yaml --device cuda
"""
import argparse
import os
import sys
import time
import torch
import yaml
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from model.vlm import VLJEPAModel
from model.tokenizer import BPETokenizer
def load_model_and_config(config_path: str, ckpt_path: str | None, device: str):
    """Load the YAML config, build the model, and optionally restore a checkpoint.

    Args:
        config_path: Path to the YAML file passed to ``VLJEPAModel``.
        ckpt_path: Optional checkpoint path. When it is ``None``/empty, or the
            file does not exist, the model keeps its freshly initialised
            weights and a message is written to stderr.
        device: Device string used both as ``map_location`` for the checkpoint
            and as the target of ``model.to``.

    Returns:
        ``(model, config)`` with the model moved to ``device`` and in eval mode.
    """
    # YAML is UTF-8 by convention; do not depend on the platform default encoding.
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    model = VLJEPAModel(config)

    if ckpt_path and os.path.exists(ckpt_path):
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
        # Only point --ckpt at checkpoint files from a trusted source.
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
        # Accept either a training checkpoint dict or a bare state dict.
        state = ckpt.get("model_state_dict", ckpt)
        load_result = model.load_state_dict(state, strict=False)
        print(f"[profile] Loaded checkpoint: {ckpt_path}", file=sys.stderr)
        # strict=False tolerates mismatches; surface them so nobody profiles a
        # partially random model without knowing it.
        if load_result.missing_keys or load_result.unexpected_keys:
            print(
                f"[profile] WARNING: checkpoint mismatch — "
                f"{len(load_result.missing_keys)} missing, "
                f"{len(load_result.unexpected_keys)} unexpected keys",
                file=sys.stderr,
            )
    elif ckpt_path:
        # A path was requested but is absent: keep the best-effort fallback,
        # but say so explicitly instead of claiming no checkpoint was given.
        print(
            f"[profile] WARNING: checkpoint not found: {ckpt_path} — profiling random init",
            file=sys.stderr,
        )
    else:
        print("[profile] No checkpoint — profiling random init", file=sys.stderr)

    model = model.to(device)
    model.eval()
    return model, config
def make_dummy_inputs(config: dict, device: str, batch_size: int = 1):
    """Build a random image batch, random query token ids, and an all-ones mask.

    Reads ``config["vision"]["img_size"]`` and ``config["decoder"]["vocab_size"]``;
    ``config["predictor"]["max_query_len"]`` is optional and defaults to 64.

    Returns:
        ``(images, query_ids, query_mask)`` where ``images`` is
        ``(batch_size, 3, img_size, img_size)`` and the two query tensors are
        ``(batch_size, q_len)`` with ``q_len = min(32, max_query_len)``.
    """
    side = config["vision"]["img_size"]
    n_vocab = config["decoder"]["vocab_size"]
    predictor_cfg = config.get("predictor", {})

    # Profiling uses a short query: 32 tokens unless the config caps it lower.
    query_len = min(32, predictor_cfg.get("max_query_len", 64))
    token_shape = (batch_size, query_len)

    images = torch.randn((batch_size, 3, side, side), device=device)
    # Token ids are drawn from [1, vocab_size), so id 0 never appears.
    query_ids = torch.randint(1, n_vocab, token_shape, device=device)
    query_mask = torch.ones(token_shape, dtype=torch.long, device=device)
    return images, query_ids, query_mask
def warmup_model(model, config, device, n_warmup: int = 3):
    """Run ``n_warmup`` untimed generation passes before profiling starts.

    One dummy batch is built and reused for every pass. Each pass calls
    ``model.generate`` with 16 new tokens at temperature 0.8 under
    ``torch.no_grad``. On a CUDA device the function then waits for all queued
    GPU work to finish so it does not leak into the profiled region.
    """
    images, query_ids, query_mask = make_dummy_inputs(config, device)

    with torch.no_grad():
        for _ in range(n_warmup):
            model.generate(
                images,
                query_ids,
                query_mask,
                max_new_tokens=16,
                temperature=0.8,
            )

    on_gpu = device.startswith("cuda")
    if on_gpu:
        torch.cuda.synchronize()
def _first_available_time(evt, attr_names):
    """Return the first attribute in *attr_names* that *evt* exposes, else 0."""
    for attr in attr_names:
        value = getattr(evt, attr, None)
        if value is not None:
            return value
    return 0


def profile_inference(model, config: dict, device: str, num_samples: int,
                      trace_path: str, max_new_tokens: int = 32):
    """Run ``num_samples`` generation passes under torch.profiler and summarise them.

    Args:
        model: Object exposing ``generate(images, query_ids, query_mask,
            max_new_tokens=..., temperature=...)``.
        config: Model config dict, forwarded to ``make_dummy_inputs``.
        device: Device string; anything starting with ``"cuda"`` enables CUDA
            activity tracing, synchronisation and peak-memory reporting.
        num_samples: Number of ``generate`` calls to profile.
        trace_path: Destination of the Chrome trace; its parent directory is
            created if needed.
        max_new_tokens: Forwarded to ``model.generate``.

    Returns:
        Dict with keys ``ranked_kernels`` (list of per-op dicts sorted by
        descending time), ``wall_time_s``, ``total_tokens``, ``tokens_per_sec``,
        ``peak_memory_gb``, ``num_samples`` and ``trace_path``.

        In each ranked entry ``time_us``/``pct`` are based on *self* time (time
        spent in the op itself, excluding children) so that percentages add up
        to 100; ``cpu_time_us``/``cuda_time_us`` remain inclusive totals.
    """
    is_cuda = device.startswith("cuda")
    # Honour an explicit index such as "cuda:1" instead of the current device.
    cuda_device = torch.device(device) if is_cuda else None

    # Filesystem work and input synthesis happen before the timer starts so
    # they do not pollute the throughput number or the kernel ranking.
    os.makedirs(os.path.dirname(trace_path) or ".", exist_ok=True)
    images, q_ids, q_mask = make_dummy_inputs(config, device)

    activities = [torch.profiler.ProfilerActivity.CPU]
    if is_cuda:
        activities.append(torch.profiler.ProfilerActivity.CUDA)
        torch.cuda.reset_peak_memory_stats(cuda_device)
        torch.cuda.synchronize(cuda_device)

    total_tokens = 0
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        profile_memory=True,
        with_stack=False,
    ) as prof:
        # Start the clock only once the profiler is active, so its start-up
        # cost is excluded from wall time.
        t_start = time.perf_counter()
        with torch.no_grad():
            for _ in range(num_samples):
                generated = model.generate(
                    images, q_ids, q_mask,
                    max_new_tokens=max_new_tokens,
                    temperature=0.8,
                )
                # NOTE(review): counts dim 1 of the returned tensor per call;
                # assumes batch size 1 and that this is the token axis — confirm
                # against model.generate.
                total_tokens += generated.shape[1]
        if is_cuda:
            # Wait for queued GPU work before stopping the clock.
            torch.cuda.synchronize(cuda_device)
        t_end = time.perf_counter()

    # Save Chrome trace
    prof.export_chrome_trace(trace_path)

    # Newer PyTorch releases name device timings `*device_time_total`; older
    # ones use `*cuda_time_total`. Try the new name first so neither silently
    # reads as zero.
    if is_cuda:
        self_time_attrs = ("self_device_time_total", "self_cuda_time_total")
    else:
        self_time_attrs = ("self_cpu_time_total",)
    device_total_attrs = ("device_time_total", "cuda_time_total")

    timed_events = [
        (_first_available_time(evt, self_time_attrs), evt)
        for evt in prof.key_averages()
    ]
    # Self times do not overlap between parent and child ops, so their sum is a
    # valid denominator for percentages.
    total_kernel_time = sum(self_time for self_time, _ in timed_events)
    timed_events.sort(key=lambda pair: pair[0], reverse=True)

    ranked = [
        {
            "name": evt.key,
            "calls": evt.count,
            "time_us": kernel_time,
            "pct": (kernel_time / total_kernel_time * 100) if total_kernel_time > 0 else 0.0,
            "cpu_time_us": evt.cpu_time_total,
            "cuda_time_us": _first_available_time(evt, device_total_attrs),
        }
        for kernel_time, evt in timed_events
    ]

    wall_time = t_end - t_start
    tokens_per_sec = total_tokens / wall_time if wall_time > 0 else 0
    peak_mem_gb = (torch.cuda.max_memory_allocated(cuda_device) / 1e9) if is_cuda else 0.0

    return {
        "ranked_kernels": ranked,
        "wall_time_s": wall_time,
        "total_tokens": total_tokens,
        "tokens_per_sec": tokens_per_sec,
        "peak_memory_gb": peak_mem_gb,
        "num_samples": num_samples,
        "trace_path": trace_path,
    }
def print_report(results: dict):
    """Write the profiling summary and the ten highest-ranked kernels to stdout.

    ``results`` must provide ``num_samples``, ``wall_time_s``, ``total_tokens``,
    ``tokens_per_sec``, ``peak_memory_gb``, ``trace_path`` and
    ``ranked_kernels`` (a list of dicts with ``pct``, ``time_us``, ``calls``
    and ``name``). The peak-memory line is omitted when the value is not
    positive.
    """
    heavy_rule = "=" * 70
    light_rule = "-" * 70

    # Summary section.
    print(heavy_rule)
    print("ArcisVLM Inference Profiling Report")
    print(heavy_rule)
    print(f" Samples: {results['num_samples']}")
    print(f" Wall time: {results['wall_time_s']:.2f}s")
    print(f" Total tokens: {results['total_tokens']}")
    print(f" Tokens/sec: {results['tokens_per_sec']:.1f}")
    gpu_memory_recorded = results["peak_memory_gb"] > 0
    if gpu_memory_recorded:
        print(f" Peak GPU mem: {results['peak_memory_gb']:.2f} GB")
    print(f" Trace saved: {results['trace_path']}")
    print()

    # Kernel table: only the first ten entries of the pre-sorted list.
    print("Top-10 Kernels by GPU/CPU Time:")
    print(light_rule)
    print(f" {'Rank':>4} {'%':>6} {'Time(us)':>10} {'Calls':>6} {'Kernel'}")
    print(light_rule)
    top_entries = results["ranked_kernels"][:10]
    for i, k in enumerate(top_entries):
        print(f" {i+1:>4} {k['pct']:>5.1f}% {k['time_us']:>10.0f} {k['calls']:>6} {k['name']}")
    print(light_rule)
def main():
    """CLI entry point: parse arguments, load the model, warm up, profile, report.

    Progress messages go to stderr; the final report is printed to stdout by
    ``print_report``. The trace is written to ``<trace-dir>/trace.json``.
    """
    # Prefer the GPU when one is visible; otherwise default to CPU.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    cli = argparse.ArgumentParser(description="Profile ArcisVLM inference")
    cli.add_argument("--ckpt", type=str, default=None, help="Checkpoint path")
    cli.add_argument("--config", type=str, required=True, help="YAML config path")
    cli.add_argument("--device", type=str, default=default_device)
    cli.add_argument("--num-samples", type=int, default=100, help="Number of inference samples")
    cli.add_argument("--max-new-tokens", type=int, default=32, help="Tokens to generate per sample")
    cli.add_argument("--trace-dir", type=str, default="profiling", help="Directory for trace output")
    cli.add_argument("--warmup", type=int, default=3, help="Warmup iterations before profiling")
    opts = cli.parse_args()

    model, config = load_model_and_config(opts.config, opts.ckpt, opts.device)

    # Warmup passes are run outside the profiler.
    print(f"[profile] Warming up ({opts.warmup} iters)...", file=sys.stderr)
    warmup_model(model, config, opts.device, opts.warmup)

    # Profiled passes.
    print(f"[profile] Profiling {opts.num_samples} samples on {opts.device}...", file=sys.stderr)
    trace_file = os.path.join(opts.trace_dir, "trace.json")
    report = profile_inference(
        model,
        config,
        opts.device,
        opts.num_samples,
        trace_file,
        opts.max_new_tokens,
    )
    print_report(report)
# Run the profiler CLI only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()