"""
Profiling utilities: torch.profiler wrapper and analysis tools.

Following D-103: profile first, optimize only hot paths.
Uses torch.profiler to identify training loop bottlenecks.
"""

import sys
import os
import json
import math

import torch

sys.path.insert(0, os.path.dirname(__file__))

from .main import ARBModel
from .config import VOCAB, CTX


# (substrings, label) pairs checked in order against the lower-cased op name
# in the hot-path analysis; the first match wins.
_HOT_PATH_LABELS = (
    (("vq",), " → VQ candidate for Triton kernel"),
    (("moe", "scatter"), " → MoE dispatch candidate"),
    (("embed", "gather"), " → Embedding gather (existing Triton kernel)"),
    (("mm", "linear"), " → General matmul (torch.compile candidate)"),
)


def _is_cuda_device(device):
    """Return True if *device* (a str such as 'cuda:0' or a torch.device) is CUDA."""
    return torch.device(device).type == "cuda"


def _sample_batch(train_data, batch_size, ctx, device):
    """Draw ``batch_size`` random ``ctx``-long windows from *train_data*.

    Returns ``(x, targets)`` already moved to *device*.
    """
    ix = torch.randint(0, len(train_data) - ctx - 1, (batch_size,))
    x = torch.stack([train_data[j: j + ctx] for j in ix])
    # NOTE(review): targets are the window minus its first 3 positions rather
    # than a one-step shift; presumably this matches ARBModel's loss
    # convention — confirm against the real training loop.
    targets = x[:, 3:]
    return x.to(device), targets.to(device)


def _first_attr(obj, names, default=0):
    """Return the first attribute of *obj* in *names* that exists and is not None."""
    for name in names:
        value = getattr(obj, name, None)
        if value is not None:
            return value
    return default


def _summarize_event(evt):
    """Convert one key-averaged torch profiler event into a result dict."""
    return {
        "op_name": evt.key if hasattr(evt, "key") else str(evt),
        # device_time replaces deprecated cuda_time in recent PyTorch
        "cuda_time_us": _first_attr(evt, ("device_time", "cuda_time")),
        "cpu_time_us": _first_attr(evt, ("cpu_time",)),
        "calls": _first_attr(evt, ("count",), default=1),
    }


def profile_training(model, train_data, device, n_steps=20, warmup_steps=5,
                     top_k=10, batch_size=64, ctx=CTX,
                     trace_path="/tmp/profiler_trace.json"):
    """
    Profile N steps of the training loop using torch.profiler.

    Runs ``warmup_steps`` untraced steps first so one-off start-up costs stay
    out of the measurement, then traces ``n_steps`` steps with CPU activity
    (plus CUDA activity on CUDA devices). It prints the top-K table and
    exports a Chrome trace.

    NOTE(review): every step runs ``model(x, targets=targets)`` under
    ``torch.no_grad()`` with no backward pass or optimizer update. Only
    forward-pass hot paths are captured, even though the model is put in
    train mode. Confirm this is intended before drawing conclusions about
    backward/optimizer cost.

    Args:
        model: ARBModel instance
        train_data: 1D byte tensor of training data
        device: Device to run on, e.g. 'cuda', 'cuda:0' or 'cpu'
        n_steps: Number of profiled steps
        warmup_steps: Steps before profiling begins (no tracing)
        top_k: Number of top operations to print and return
        batch_size: Batch size for each step
        ctx: Context window length
        trace_path: Where to write the Chrome trace JSON. ``None`` skips the
            export. Defaults to the historical ``/tmp/profiler_trace.json``.

    Returns:
        List of at most ``top_k`` dicts with keys op_name, cuda_time_us,
        cpu_time_us and calls. They are ordered by total CUDA time (CUDA
        devices) or total CPU time (otherwise), descending — the same order as
        the printed table. cuda_time_us / cpu_time_us are the per-call
        averages torch reports (``device_time`` / ``cpu_time``).
    """
    model.train()
    use_cuda = _is_cuda_device(device)

    activities = [torch.profiler.ProfilerActivity.CPU]
    if use_cuda:
        activities.append(torch.profiler.ProfilerActivity.CUDA)
        # Stack traces and FLOP estimates are only collected for CUDA runs.
        trace_opts = {"with_stack": True, "with_flops": True}
    else:
        trace_opts = {"with_stack": False}
    prof = torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        **trace_opts,
    )

    # Warmup steps (no profiling)
    for _ in range(warmup_steps):
        x, targets = _sample_batch(train_data, batch_size, ctx, device)
        with torch.no_grad():
            model(x, targets=targets)

    # Profiled steps. The context manager calls prof.stop() even if a step
    # raises, so the profiler is never left running.
    with prof:
        for _ in range(n_steps):
            x, targets = _sample_batch(train_data, batch_size, ctx, device)
            with torch.no_grad():
                model(x, targets=targets)
        if use_cuda:
            # CUDA work is asynchronous; wait for it before the profiler stops.
            torch.cuda.synchronize(device)

    if use_cuda:
        sort_by = "cuda_time_total"
        # device_time_total replaces the deprecated cuda_time_total alias.
        total_attrs = ("device_time_total", "cuda_time_total")
    else:
        sort_by = "cpu_time_total"
        total_attrs = ("cpu_time_total",)

    key_avg = prof.key_averages()
    table = key_avg.table(sort_by=sort_by, row_limit=top_k)

    # key_averages() is in first-seen order, not cost order, so rank the
    # events explicitly. Slicing directly would return arbitrary ops rather
    # than the hottest ones.
    ranked = sorted(
        key_avg,
        key=lambda evt: _first_attr(evt, total_attrs),
        reverse=True,
    )
    top_results = [_summarize_event(evt) for evt in ranked[:top_k]]

    # Print summary
    print("\n=== Profiling Results (Top-{} Hot Paths) ===".format(top_k))
    print(table)
    print("============================================\n")

    # Save profiler output as JSON (Chrome trace format)
    if trace_path is not None:
        prof.export_chrome_trace(trace_path)

    return top_results


def _is_device_category(cat):
    """Return True if a lower-cased Chrome-trace category denotes GPU-side work."""
    # Kineto tags CUDA kernels with cat "kernel" (no "gpu" substring), while
    # memcpy/memset/annotation events use "gpu_*" categories.
    # NOTE(review): "gpu_user_annotation" spans overlap the kernels they wrap,
    # so device totals may double-count when such annotations are present.
    return "gpu" in cat or cat == "kernel"


def _aggregate_trace_events(events):
    """Sum durations (us) and call counts per event name over Chrome-trace events."""
    op_stats = {}
    for evt in events:
        if not isinstance(evt, dict):
            continue
        # Only complete ("X") events carry a real duration. Metadata ("M"),
        # flow ("s"/"f") and instant ("i") records would otherwise appear as
        # zero-time pseudo-ops and inflate call counts.
        phase = evt.get("ph")
        if phase is not None and phase != "X":
            continue
        name = evt.get("name")
        if name is None:
            name = "unknown"
        dur = evt.get("dur") or 0  # microseconds
        cat = str(evt.get("cat") or "").lower()
        stats = op_stats.setdefault(
            str(name), {"cuda_time_us": 0, "cpu_time_us": 0, "calls": 0},
        )
        if _is_device_category(cat):
            stats["cuda_time_us"] += dur
        elif "cpu" in cat or cat == "":
            stats["cpu_time_us"] += dur
        stats["calls"] += 1
    return op_stats


def _hot_path_label(op_name):
    """Return an optimisation hint for *op_name*, or '' when none applies."""
    lowered = op_name.lower()
    for needles, label in _HOT_PATH_LABELS:
        if any(needle in lowered for needle in needles):
            return label
    return ""


def _print_hot_paths(results, limit=5):
    """Print each of the top *limit* ops' share of total time with a hint label."""
    metric = "cuda_time_us"
    title = "\n=== Hot Path Analysis ==="
    total = sum(r[metric] for r in results)
    if total <= 0:
        # CPU-only traces (e.g. profile_training on 'cpu') have no device
        # time; fall back to CPU time so the analysis still says something.
        metric = "cpu_time_us"
        title = "\n=== Hot Path Analysis (CPU time) ==="
        total = sum(r[metric] for r in results)
    if total <= 0:
        return
    print(title)
    for r in results[:limit]:
        pct = (r[metric] / total) * 100
        print(f"  {r['op_name']:<40} {pct:>5.1f}%{_hot_path_label(r['op_name'])}")


def analyze_profiler_output(prof_path):
    """
    Load saved profiler JSON output and extract key insights.

    The file is a Chrome trace: either a dict with 'traceEvents' or a flat
    list of events. Durations of complete events are summed per event name.
    They are split into CUDA time (GPU categories, including Kineto's
    "kernel" category) and CPU time (categories containing "cpu", or no
    category). The function prints a table of the top 20 ops and a
    hot-path breakdown of the top 5.

    Args:
        prof_path: Path to saved profiler JSON file

    Returns:
        List of dicts with op_name, cuda_time_us, cpu_time_us and calls.
        They are sorted by total CUDA time descending, with ties broken by
        total CPU time descending.
    """
    with open(prof_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Profiler JSON can be a dict with 'traceEvents' or a flat list
    if isinstance(data, dict) and "traceEvents" in data:
        events = data["traceEvents"] or []
    elif isinstance(data, list):
        events = data
    else:
        events = []

    op_stats = _aggregate_trace_events(events)

    # Sort by CUDA time descending; CPU time ranks ops with equal (often zero)
    # CUDA time instead of leaving them in arbitrary order.
    sorted_ops = sorted(
        op_stats.items(),
        key=lambda item: (item[1]["cuda_time_us"], item[1]["cpu_time_us"]),
        reverse=True,
    )
    results = [{"op_name": name, **stats} for name, stats in sorted_ops]

    # Print formatted summary
    print("\n=== Profiler Analysis ===")
    print(f"{'Operation':<40} {'CUDA Time (us)':>15} {'CPU Time (us)':>15} {'Calls':>8}")
    print("-" * 80)
    for r in results[:20]:
        print(f"{r['op_name']:<40} {r['cuda_time_us']:>15.0f} {r['cpu_time_us']:>15.0f} {r['calls']:>8}")

    # Identify dominating patterns
    _print_hot_paths(results)
    print("============================================\n")

    return results