"""Chronos PoC: PTX transform selection inference.

Loads a trained TransformPolicy checkpoint and predicts the optimal sequence
of PTX transforms for a given kernel.

Usage:
    python inference.py --checkpoint checkpoint_best.pt --kernel gemm_tile --m 4 --n 6 --k 8
    python inference.py --checkpoint checkpoint_best.pt --ptx path/to/kernel.ptx
"""

import argparse
import json
import os
import re
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

# ---------------------------------------------------------------------------
# Model definition (self-contained, no external dependencies for inference)
# ---------------------------------------------------------------------------

N_FEATURES = 25  # Model was trained with 25 scalar features
N_ACTIONS = 21

# Hidden width used when it cannot be inferred from a checkpoint.
_DEFAULT_HIDDEN = 128

ACTION_NAMES = [
    "vec_ld", "vec_st",
    "cache_cs", "cache_cg", "cache_ca", "cache_cv",
    "st_cache_cs", "st_cache_wt", "st_cache_wb",
    "maxnreg_32", "maxnreg_64", "maxnreg_128", "maxnreg_255",
    "reorder_cp", "reorder_il", "reorder_lf", "reorder_sl",
    "prefetch_L1", "prefetch_L2",
    "split_ld",
    "stop",
]

FEATURE_NAMES = [
    "n_instructions", "n_ld_global", "n_st_global", "n_fma", "n_ld_param",
    "n_prefetch", "n_branch", "n_ld_global_vec", "n_st_global_vec",
    "vec_ld_ratio", "vec_st_ratio", "n_cache_hint_ld", "n_cache_hint_st",
    "hint_ld_ratio", "hint_st_ratio", "load_ratio", "store_ratio",
    "fma_ratio", "compute_ratio", "mem_ratio", "compute_to_mem",
    "total_regs", "n_f32_regs", "n_b64_regs", "maxnreg",
]

# Transforms inside one group are mutually exclusive: once one member has
# been applied, get_action_mask() disables the remaining members.
CONFLICT_GROUPS = {
    "cache_hints": {"cache_cs", "cache_cg", "cache_ca", "cache_cv"},
    "store_cache_hints": {"st_cache_cs", "st_cache_wt", "st_cache_wb"},
    "register_budget": {"maxnreg_32", "maxnreg_64", "maxnreg_128", "maxnreg_255"},
    "prefetch": {"prefetch_L1", "prefetch_L2"},
    "reorder": {"reorder_cp", "reorder_il", "reorder_lf", "reorder_sl"},
}


class TransformPolicy(nn.Module):
    """MLP policy for PTX transform selection.

    Input: 25 features + 21 action mask + 21 action history = 67 dims
    Output: 21 logits (masked before softmax)
    """

    def __init__(self, hidden=_DEFAULT_HIDDEN):
        super().__init__()
        input_dim = N_FEATURES + N_ACTIONS + N_ACTIONS  # 67
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden, N_ACTIONS),
        )

    def forward(self, features, action_mask, action_history):
        """Return action logits, with masked-out actions set to -inf.

        The three inputs are concatenated along the last dimension, so they
        must agree on every leading dimension.
        """
        x = torch.cat([features, action_mask, action_history], dim=-1)
        logits = self.net(x)
        # Entries whose mask value is 0 can never win argmax / get probability.
        return logits.masked_fill(action_mask == 0, float('-inf'))

    @torch.no_grad()
    def get_greedy_action(self, features, action_mask, action_history):
        """Return the index (int) of the highest-logit allowed action.

        A leading batch dimension is added to each input before the forward
        pass, and the result is reduced with ``.item()``, so the inputs are
        expected to describe a single state.
        """
        logits = self.forward(
            features.unsqueeze(0),
            action_mask.unsqueeze(0),
            action_history.unsqueeze(0),
        )
        return logits.argmax(dim=-1).item()

    @torch.no_grad()
    def get_action_probs(self, features, action_mask, action_history):
        """Return the softmax over masked logits for a single state."""
        logits = self.forward(
            features.unsqueeze(0),
            action_mask.unsqueeze(0),
            action_history.unsqueeze(0),
        )
        return F.softmax(logits, dim=-1).squeeze(0)


# ---------------------------------------------------------------------------
# Feature extraction (self-contained, regex-based)
# ---------------------------------------------------------------------------

_LD_GLOBAL = re.compile(r'ld\.global')
_LD_GLOBAL_VEC = re.compile(r'ld\.global(?:\.\w+)*\.v[24]')
_ST_GLOBAL = re.compile(r'st\.global')
_ST_GLOBAL_VEC = re.compile(r'st\.global(?:\.\w+)*\.v[24]')
_FMA = re.compile(r'\bfma\.')
_MUL = re.compile(r'\bmul\.')
_ADD = re.compile(r'\badd\.')
_LD_PARAM = re.compile(r'ld\.param')
_PREFETCH = re.compile(r'prefetch\.global')
_CACHE_HINT_LD = re.compile(r'ld\.global\.(?:cs|cg|ca|cv)')
_CACHE_HINT_ST = re.compile(r'st\.global\.(?:wb|wt|cs)')
_MAXNREG = re.compile(r'\.maxnreg\s+(\d+)')
# Hoisted so the pattern is not re-resolved for every source line.
_REG_DECL = re.compile(r'\.reg\s+(\.\w+)\s+%\w+<(\d+)>\s*;')

# Per-instruction counters. Every pattern is tested independently; the
# vector / cache-hint patterns each contain their parent ``ld.global`` /
# ``st.global`` text, so they can only fire when the parent does.
_INSTRUCTION_PATTERNS = (
    ("ld_global", _LD_GLOBAL),
    ("ld_global_vec", _LD_GLOBAL_VEC),
    ("cache_hint_ld", _CACHE_HINT_LD),
    ("st_global", _ST_GLOBAL),
    ("st_global_vec", _ST_GLOBAL_VEC),
    ("cache_hint_st", _CACHE_HINT_ST),
    ("fma", _FMA),
    ("mul", _MUL),
    ("add", _ADD),
    ("ld_param", _LD_PARAM),
    ("prefetch", _PREFETCH),
)

# Body lines that are neither counted as instructions nor as branches.
_NON_INSTRUCTION_LINES = ('ret;', 'exit;', ')', ',')


def _ratio(numerator, denominator):
    """Return numerator / max(denominator, 1), rounded to 4 decimals."""
    return round(numerator / max(denominator, 1), 4)


def extract_features_from_ptx(ptx_source: str) -> list:
    """Extract 25 scalar features from PTX source text.

    Args:
        ptx_source: Full PTX text.

    Returns:
        A list of 25 numbers ordered exactly like ``FEATURE_NAMES``.

    Register declarations and ``.maxnreg`` are searched across the whole
    text; instruction counts only consider lines between a line that is
    exactly ``{`` and a line that is exactly ``}``.
    """
    counts = dict.fromkeys((name for name, _ in _INSTRUCTION_PATTERNS), 0)
    reg_decls = {}
    n_instr = 0
    n_branch = 0
    in_body = False

    for line in ptx_source.split('\n'):
        # A later declaration of the same register type overwrites an
        # earlier one (dict assignment), it is not summed.
        decl = _REG_DECL.search(line)
        if decl:
            reg_decls[decl.group(1)] = int(decl.group(2))

        stripped = line.strip()
        # NOTE(review): body tracking is a flag, not a depth counter, so a
        # nested ``{ ... }`` scope ends counting at its closing brace.
        # Left as-is on purpose: presumably the training-time extractor
        # behaves the same way and changing it would skew features.
        if stripped == '{':
            in_body = True
            continue
        if stripped == '}':
            in_body = False
            continue
        if not in_body or not stripped:
            continue
        # Comments, directives and labels are not instructions.
        if stripped.startswith(('//', '.')) or stripped.endswith(':'):
            continue
        if stripped in _NON_INSTRUCTION_LINES:
            continue
        # Branches are tallied separately and excluded from n_instr.
        if 'bra ' in stripped or 'bra\t' in stripped:
            n_branch += 1
            continue

        n_instr += 1
        for name, pattern in _INSTRUCTION_PATTERNS:
            if pattern.search(line):
                counts[name] += 1

    maxnreg_match = _MAXNREG.search(ptx_source)
    maxnreg = int(maxnreg_match.group(1)) if maxnreg_match else 0

    n_ld_global = counts["ld_global"]
    n_st_global = counts["st_global"]
    n_fma = counts["fma"]
    n_compute = n_fma + counts["mul"] + counts["add"]
    n_mem = n_ld_global + n_st_global

    return [
        n_instr,
        n_ld_global,
        n_st_global,
        n_fma,
        counts["ld_param"],
        counts["prefetch"],
        n_branch,
        counts["ld_global_vec"],
        counts["st_global_vec"],
        _ratio(counts["ld_global_vec"], n_ld_global),   # vec_ld_ratio
        _ratio(counts["st_global_vec"], n_st_global),   # vec_st_ratio
        counts["cache_hint_ld"],
        counts["cache_hint_st"],
        _ratio(counts["cache_hint_ld"], n_ld_global),   # hint_ld_ratio
        _ratio(counts["cache_hint_st"], n_st_global),   # hint_st_ratio
        _ratio(n_ld_global, n_instr),                   # load_ratio
        _ratio(n_st_global, n_instr),                   # store_ratio
        _ratio(n_fma, n_instr),                         # fma_ratio
        _ratio(n_compute, n_instr),                     # compute_ratio
        _ratio(n_mem, n_instr),                         # mem_ratio
        _ratio(n_compute, n_mem),                       # compute_to_mem
        sum(reg_decls.values()),                        # total_regs
        reg_decls.get('.f32', 0),
        reg_decls.get('.b64', 0),
        maxnreg,
    ]


# ---------------------------------------------------------------------------
# Action mask and history
# ---------------------------------------------------------------------------

def get_action_mask(applied_set: set) -> list:
    """Return a 0/1 list over ``ACTION_NAMES`` marking selectable actions.

    ``stop`` is always selectable. Any other action is disabled when it has
    already been applied or when another member of one of its
    ``CONFLICT_GROUPS`` has been applied.
    """
    mask = []
    for label in ACTION_NAMES:
        if label == "stop":
            allowed = True
        elif label in applied_set:
            allowed = False
        else:
            allowed = not any(
                label in group and applied_set & group
                for group in CONFLICT_GROUPS.values()
            )
        mask.append(1 if allowed else 0)
    return mask


def get_action_history(applied_set: set) -> list:
    """Return a 0/1 list over ``ACTION_NAMES`` marking applied actions."""
    return [int(name in applied_set) for name in ACTION_NAMES]


# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------

def load_model(checkpoint_path, device="cpu"):
    """Load trained TransformPolicy from checkpoint.

    Args:
        checkpoint_path: Path to a checkpoint holding the policy state dict
            under the ``"policy"`` key.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        The policy in eval mode on ``device``.

    Raises:
        KeyError: If the checkpoint has no ``"policy"`` entry.
    """
    # SECURITY: weights_only=False performs full pickle deserialization,
    # which can execute arbitrary code. Only load checkpoints from a
    # trusted source. Kept because the checkpoint may carry non-tensor
    # metadata (e.g. "eval_result") that the restricted loader can reject.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)

    if "policy" not in ckpt:
        raise KeyError(
            f"checkpoint {checkpoint_path!r} has no 'policy' state dict "
            f"(found keys: {sorted(map(str, ckpt))})"
        )
    state_dict = ckpt["policy"]

    # Infer the hidden width from the first linear layer so checkpoints
    # trained with a non-default width load too; fall back to the default.
    first_weight = state_dict.get("net.0.weight")
    hidden = first_weight.shape[0] if first_weight is not None else _DEFAULT_HIDDEN

    model = TransformPolicy(hidden=hidden)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)

    epoch = ckpt.get("epoch", "unknown")
    print(f"Loaded checkpoint from epoch {epoch}")
    if "eval_result" in ckpt:
        mean_imp = ckpt["eval_result"].get("mean_improvement", 0)
        print(f" Eval mean improvement: {mean_imp*100:.1f}%")
    return model


def _model_device(model):
    """Return the device of the model's first parameter, or CPU if none."""
    parameters = getattr(model, "parameters", None)
    first = next(parameters(), None) if callable(parameters) else None
    return first.device if first is not None else torch.device("cpu")


def _greedy_rollout(model, features, max_steps, report_top_k=0, report_tag=""):
    """Greedily pick transforms until ``stop`` or ``max_steps`` is reached.

    Args:
        model: Object providing ``get_greedy_action`` and ``get_action_probs``.
        features: Sequence of ``N_FEATURES`` numbers.
        max_steps: Maximum number of actions to select.
        report_top_k: When non-zero, print each step with that many of the
            most probable actions.
        report_tag: Label printed in front of the per-step probabilities.

    Returns:
        Selected transform labels in order, without ``stop``.

    Raises:
        ValueError: If ``features`` does not have ``N_FEATURES`` entries.
    """
    if len(features) != N_FEATURES:
        raise ValueError(
            f"expected {N_FEATURES} features, got {len(features)}"
        )

    # Build inputs on the model's device so a non-CPU model works as well.
    device = _model_device(model)
    # The feature vector never changes during the rollout: build it once.
    feat_t = torch.tensor(features, dtype=torch.float32, device=device)

    applied = set()
    actions = []
    for step in range(max_steps):
        mask_t = torch.tensor(
            get_action_mask(applied), dtype=torch.float32, device=device)
        hist_t = torch.tensor(
            get_action_history(applied), dtype=torch.float32, device=device)

        action_id = model.get_greedy_action(feat_t, mask_t, hist_t)
        action_label = ACTION_NAMES[action_id]

        if report_top_k:
            probs = model.get_action_probs(feat_t, mask_t, hist_t)
            top = torch.topk(probs, min(report_top_k, probs.size(0)))
            shown = ", ".join(
                f"{ACTION_NAMES[i]}={p:.2f}"
                for p, i in zip(top.values.tolist(), top.indices.tolist())
            )
            print(f" Step {step + 1}: {action_label} ({report_tag}: {shown})")

        if action_label == "stop":
            break
        actions.append(action_label)
        applied.add(action_label)
    return actions


def predict_transforms(model, ptx_source, max_steps=6, verbose=True):
    """Predict optimal transform sequence for a PTX kernel.

    Args:
        model: Policy used for greedy action selection.
        ptx_source: PTX text of the kernel.
        max_steps: Maximum number of transforms to select.
        verbose: When true, print a kernel summary, each step with its five
            most probable actions, and the final sequence.

    Returns list of transform labels (excluding 'stop').
    """
    features = extract_features_from_ptx(ptx_source)

    if verbose:
        print(f"\nKernel: {features[0]} instructions, "
              f"{features[1]} global loads, {features[2]} global stores, "
              f"{features[3]} FMA, {features[21]} total regs")

    actions = _greedy_rollout(
        model, features, max_steps,
        report_top_k=5 if verbose else 0,
        report_tag="top5",
    )

    if verbose:
        print(f"\nPredicted sequence: {' -> '.join(actions) if actions else '(no transforms)'}")
    return actions


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

# Demo: synthetic feature vector matching gemm_tile(4,6,8)
# (the best kernel from training: -53.8% improvement)
_DEMO_FEATURES = [
    170,    # n_instructions
    16,     # n_ld_global
    8,      # n_st_global
    48,     # n_fma
    12,     # n_ld_param
    0,      # n_prefetch
    2,      # n_branch
    0,      # n_ld_global_vec
    0,      # n_st_global_vec
    0.0,    # vec_ld_ratio
    0.0,    # vec_st_ratio
    0,      # n_cache_hint_ld
    0,      # n_cache_hint_st
    0.0,    # hint_ld_ratio
    0.0,    # hint_st_ratio
    0.094,  # load_ratio
    0.047,  # store_ratio
    0.282,  # fma_ratio
    0.388,  # compute_ratio
    0.141,  # mem_ratio
    2.75,   # compute_to_mem
    95,     # total_regs
    48,     # n_f32_regs
    16,     # n_b64_regs
    0,      # maxnreg
]


def _run_demo(model, m, n, k):
    """Run the policy on the built-in demo feature vector and print it."""
    print("\nTo run on a specific kernel, use: --ptx path/to/kernel.ptx")
    print("Showing demo with a sample feature vector...")
    print(f"\nDemo: gemm_tile({m},{n},{k})-like features")
    print(f"Features: {len(_DEMO_FEATURES)} dims")

    actions = _greedy_rollout(
        model, _DEMO_FEATURES, 6, report_top_k=3, report_tag="probs")

    print(f"\nPredicted: {' -> '.join(actions)}")
    print("Expected for gemm_tile(4,6,8): maxnreg_128 -> vec_ld -> vec_st -> stop")


def main():
    """Command-line entry point: load a checkpoint and predict transforms."""
    parser = argparse.ArgumentParser(description="Chronos PoC: PTX transform inference")
    parser.add_argument("--checkpoint", required=True, help="Path to .pt checkpoint")
    parser.add_argument("--ptx", help="Path to PTX file")
    parser.add_argument("--kernel", default="gemm_tile",
                        help="Kernel type (for generating PTX if --ptx not provided)")
    parser.add_argument("--m", type=int, default=4)
    parser.add_argument("--n", type=int, default=6)
    parser.add_argument("--k", type=int, default=8)
    args = parser.parse_args()

    model = load_model(args.checkpoint)

    if not args.ptx:
        # No kernel supplied: fall back to the synthetic demo vector.
        _run_demo(model, args.m, args.n, args.k)
        return

    with open(args.ptx, encoding="utf-8") as f:
        ptx_source = f.read()
    print(f"\nLoaded PTX from: {args.ptx}")

    predict_transforms(model, ptx_source)
    print("\nTo apply these transforms, use the Chronos transform pipeline.")


if __name__ == "__main__":
    main()