| | """Chronos PoC: PTX transform selection inference. |
| | |
| | Loads a trained TransformPolicy checkpoint and predicts the optimal |
| | sequence of PTX transforms for a given kernel. |
| | |
| | Usage: |
| | python inference.py --checkpoint checkpoint_best.pt --kernel gemm_tile --m 4 --n 6 --k 8 |
| | python inference.py --checkpoint checkpoint_best.pt --ptx path/to/kernel.ptx |
| | """ |
| |
|
| | import argparse |
| | import sys |
| | import os |
| | import json |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.distributions import Categorical |
| |
|
| |
|
| | |
| | |
| | |
| |
|
N_FEATURES = 25
N_ACTIONS = 21

# Transform vocabulary; index i is action id i. "stop" terminates decoding.
ACTION_NAMES = [
    "vec_ld", "vec_st",
    "cache_cs", "cache_cg", "cache_ca", "cache_cv",
    "st_cache_cs", "st_cache_wt", "st_cache_wb",
    "maxnreg_32", "maxnreg_64", "maxnreg_128", "maxnreg_255",
    "reorder_cp", "reorder_il", "reorder_lf", "reorder_sl",
    "prefetch_L1", "prefetch_L2",
    "split_ld",
    "stop",
]

# Names of the 25 scalar features, in the same order as the list returned
# by extract_features_from_ptx.
FEATURE_NAMES = [
    "n_instructions", "n_ld_global", "n_st_global", "n_fma",
    "n_ld_param", "n_prefetch", "n_branch",
    "n_ld_global_vec", "n_st_global_vec", "vec_ld_ratio", "vec_st_ratio",
    "n_cache_hint_ld", "n_cache_hint_st", "hint_ld_ratio", "hint_st_ratio",
    "load_ratio", "store_ratio", "fma_ratio", "compute_ratio",
    "mem_ratio", "compute_to_mem",
    "total_regs", "n_f32_regs", "n_b64_regs", "maxnreg",
]

# Mutually exclusive transform families: once one member of a group has
# been applied, the remaining members are masked out.
CONFLICT_GROUPS = {
    "cache_hints": {"cache_cs", "cache_cg", "cache_ca", "cache_cv"},
    "store_cache_hints": {"st_cache_cs", "st_cache_wt", "st_cache_wb"},
    "register_budget": {"maxnreg_32", "maxnreg_64", "maxnreg_128", "maxnreg_255"},
    "prefetch": {"prefetch_L1", "prefetch_L2"},
    "reorder": {"reorder_cp", "reorder_il", "reorder_lf", "reorder_sl"},
}


class TransformPolicy(nn.Module):
    """MLP policy for PTX transform selection.

    Input: 25 features + 21 action mask + 21 action history = 67 dims.
    Output: 21 logits; unavailable actions are forced to -inf before any
    softmax, so they can never be sampled or chosen greedily.
    """

    def __init__(self, hidden=128):
        super().__init__()
        in_dim = N_FEATURES + 2 * N_ACTIONS
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden, N_ACTIONS),
        )

    def forward(self, features, action_mask, action_history):
        """Return masked logits; positions where action_mask == 0 are -inf."""
        joint = torch.cat([features, action_mask, action_history], dim=-1)
        raw = self.net(joint)
        return raw.masked_fill(action_mask == 0, float('-inf'))

    @torch.no_grad()
    def get_greedy_action(self, features, action_mask, action_history):
        """Argmax action id for a single (unbatched) state."""
        batched = self.forward(
            features.unsqueeze(0), action_mask.unsqueeze(0),
            action_history.unsqueeze(0),
        )
        return batched.argmax(dim=-1).item()

    @torch.no_grad()
    def get_action_probs(self, features, action_mask, action_history):
        """Softmax action distribution for a single (unbatched) state."""
        batched = self.forward(
            features.unsqueeze(0), action_mask.unsqueeze(0),
            action_history.unsqueeze(0),
        )
        return F.softmax(batched, dim=-1).squeeze(0)
| |
|
| |
|
| | |
| | |
| | |
| |
|
import re


# Pre-compiled matchers for PTX instruction classes (compiled once at
# import time so the per-line loop pays no recompilation cost).
_LD_GLOBAL = re.compile(r'ld\.global')
_LD_GLOBAL_VEC = re.compile(r'ld\.global(?:\.\w+)*\.v[24]')
_ST_GLOBAL = re.compile(r'st\.global')
_ST_GLOBAL_VEC = re.compile(r'st\.global(?:\.\w+)*\.v[24]')
_FMA = re.compile(r'\bfma\.')
_MUL = re.compile(r'\bmul\.')
_ADD = re.compile(r'\badd\.')
_LD_PARAM = re.compile(r'ld\.param')
_PREFETCH = re.compile(r'prefetch\.global')
_CACHE_HINT_LD = re.compile(r'ld\.global\.(?:cs|cg|ca|cv)')
_CACHE_HINT_ST = re.compile(r'st\.global\.(?:wb|wt|cs)')
_MAXNREG = re.compile(r'\.maxnreg\s+(\d+)')


def extract_features_from_ptx(ptx_source):
    """Extract 25 scalar features from PTX source text.

    The returned list follows FEATURE_NAMES order: raw instruction
    counts, derived ratios (rounded to 4 places), then register stats.
    """
    lines = ptx_source.split('\n')

    # Pass 1: register declarations, e.g. ".reg .f32 %f<12>;" -> {'.f32': 12}.
    reg_decls = {}
    for raw in lines:
        decl = re.search(r'\.reg\s+(\.\w+)\s+%\w+<(\d+)>\s*;', raw)
        if decl:
            reg_decls[decl.group(1)] = int(decl.group(2))

    # Pass 2: classify instructions inside the kernel body ("{ ... }").
    instr = branches = 0
    loads = loads_vec = hinted_loads = 0
    stores = stores_vec = hinted_stores = 0
    fmas = muls = adds = param_loads = prefetches = 0

    inside = False
    for raw in lines:
        text = raw.strip()
        if text == '{':
            inside = True
            continue
        if text == '}':
            inside = False
            continue
        if not inside:
            continue
        # Skip blanks, comments, directives, labels, and stray punctuation.
        if not text or text.startswith(('//', '.')) or text.endswith(':'):
            continue
        if text in ('ret;', 'exit;', ')', ','):
            continue
        # Branches are tallied separately, not as generic instructions.
        if 'bra ' in text or 'bra\t' in text:
            branches += 1
            continue

        instr += 1
        loads += bool(_LD_GLOBAL.search(raw))
        loads_vec += bool(_LD_GLOBAL_VEC.search(raw))
        hinted_loads += bool(_CACHE_HINT_LD.search(raw))
        stores += bool(_ST_GLOBAL.search(raw))
        stores_vec += bool(_ST_GLOBAL_VEC.search(raw))
        hinted_stores += bool(_CACHE_HINT_ST.search(raw))
        fmas += bool(_FMA.search(raw))
        muls += bool(_MUL.search(raw))
        adds += bool(_ADD.search(raw))
        param_loads += bool(_LD_PARAM.search(raw))
        prefetches += bool(_PREFETCH.search(raw))

    # Kernel-level .maxnreg pragma, 0 when absent.
    pragma = _MAXNREG.search(ptx_source)
    maxnreg = int(pragma.group(1)) if pragma else 0

    total_regs = sum(reg_decls.values())
    denom = max(instr, 1)  # guard against empty kernels
    compute_ops = fmas + muls + adds
    mem_ops = loads + stores

    return [
        instr,
        loads,
        stores,
        fmas,
        param_loads,
        prefetches,
        branches,
        loads_vec,
        stores_vec,
        round(loads_vec / max(loads, 1), 4),
        round(stores_vec / max(stores, 1), 4),
        hinted_loads,
        hinted_stores,
        round(hinted_loads / max(loads, 1), 4),
        round(hinted_stores / max(stores, 1), 4),
        round(loads / denom, 4),
        round(stores / denom, 4),
        round(fmas / denom, 4),
        round(compute_ops / denom, 4),
        round(mem_ops / denom, 4),
        round(compute_ops / max(mem_ops, 1), 4),
        total_regs,
        reg_decls.get('.f32', 0),
        reg_decls.get('.b64', 0),
        maxnreg,
    ]
| |
|
| |
|
| | |
| | |
| | |
| |
|
def get_action_mask(applied_set):
    """Return a 0/1 availability mask aligned with ACTION_NAMES.

    "stop" is always available. An action is unavailable (0) once it has
    been applied, or when any other member of its conflict group has.
    """
    def _available(label):
        if label == "stop":
            return 1
        if label in applied_set:
            return 0
        for members in CONFLICT_GROUPS.values():
            if label in members and applied_set & members:
                return 0
        return 1

    return [_available(label) for label in ACTION_NAMES]
| |
|
| |
|
def get_action_history(applied_set):
    """Return a 0/1 vector over ACTION_NAMES marking already-applied actions."""
    return [int(label in applied_set) for label in ACTION_NAMES]
| |
|
| |
|
| | |
| | |
| | |
| |
|
def load_model(checkpoint_path, device="cpu"):
    """Load a trained TransformPolicy from a .pt checkpoint.

    Prints the checkpoint epoch (and eval improvement, when recorded)
    and returns the model in eval mode on `device`.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    state = torch.load(checkpoint_path, map_location=device, weights_only=False)

    model = TransformPolicy(hidden=128)
    model.load_state_dict(state["policy"])
    model.eval()
    model.to(device)

    epoch = state.get("epoch", "unknown")
    print(f"Loaded checkpoint from epoch {epoch}")
    if "eval_result" in state:
        mean_imp = state["eval_result"].get("mean_improvement", 0)
        print(f" Eval mean improvement: {mean_imp*100:.1f}%")
    return model
| |
|
| |
|
def predict_transforms(model, ptx_source, max_steps=6, verbose=True):
    """Predict optimal transform sequence for a PTX kernel.

    Featurizes the PTX once, then greedily decodes actions (respecting
    the conflict mask) until "stop" is chosen or max_steps is reached.

    Returns list of transform labels (excluding 'stop').
    """
    features = extract_features_from_ptx(ptx_source)
    # The feature tensor is loop-invariant; build it once.
    feat_t = torch.tensor(features, dtype=torch.float32)
    applied = set()
    chosen = []

    if verbose:
        print(f"\nKernel: {features[0]} instructions, "
              f"{features[1]} global loads, {features[2]} global stores, "
              f"{features[3]} FMA, {features[21]} total regs")

    for step in range(max_steps):
        mask_t = torch.tensor(get_action_mask(applied), dtype=torch.float32)
        hist_t = torch.tensor(get_action_history(applied), dtype=torch.float32)

        label = ACTION_NAMES[model.get_greedy_action(feat_t, mask_t, hist_t)]

        if verbose:
            probs = model.get_action_probs(feat_t, mask_t, hist_t)
            top5 = torch.topk(probs, min(5, probs.size(0)))
            top5_str = ", ".join(
                f"{ACTION_NAMES[i]}={p:.2f}"
                for p, i in zip(top5.values.tolist(), top5.indices.tolist())
            )
            print(f" Step {step+1}: {label} (top5: {top5_str})")

        if label == "stop":
            break

        chosen.append(label)
        applied.add(label)

    if verbose:
        print(f"\nPredicted sequence: {' -> '.join(chosen) if chosen else '(no transforms)'}")

    return chosen
| |
|
| |
|
| | |
| | |
| | |
| |
|
def _run_demo(model, args):
    """Greedy-decode on a hard-coded sample feature vector.

    Used when no --ptx file is supplied; mirrors the FEATURE_NAMES
    layout for a gemm_tile-like kernel so the policy can be exercised
    without compiling any PTX.
    """
    print("\nTo run on a specific kernel, use: --ptx path/to/kernel.ptx")
    print("Showing demo with a sample feature vector...")

    # 25 features in FEATURE_NAMES order (counts, ratios, register stats).
    demo_features = [
        170,    # n_instructions
        16,     # n_ld_global
        8,      # n_st_global
        48,     # n_fma
        12,     # n_ld_param
        0,      # n_prefetch
        2,      # n_branch
        0,      # n_ld_global_vec
        0,      # n_st_global_vec
        0.0,    # vec_ld_ratio
        0.0,    # vec_st_ratio
        0,      # n_cache_hint_ld
        0,      # n_cache_hint_st
        0.0,    # hint_ld_ratio
        0.0,    # hint_st_ratio
        0.094,  # load_ratio
        0.047,  # store_ratio
        0.282,  # fma_ratio
        0.388,  # compute_ratio
        0.141,  # mem_ratio
        2.75,   # compute_to_mem
        95,     # total_regs
        48,     # n_f32_regs
        16,     # n_b64_regs
        0,      # maxnreg
    ]

    applied = set()
    actions = []
    print(f"\nDemo: gemm_tile({args.m},{args.n},{args.k})-like features")
    print(f"Features: {len(demo_features)} dims")

    for step in range(6):
        feat_t = torch.tensor(demo_features, dtype=torch.float32)
        mask_t = torch.tensor(get_action_mask(applied), dtype=torch.float32)
        hist_t = torch.tensor(get_action_history(applied), dtype=torch.float32)

        action_id = model.get_greedy_action(feat_t, mask_t, hist_t)
        action_label = ACTION_NAMES[action_id]

        probs = model.get_action_probs(feat_t, mask_t, hist_t)
        top3 = torch.topk(probs, min(3, probs.size(0)))
        top3_str = ", ".join(
            f"{ACTION_NAMES[i]}={p:.2f}"
            for p, i in zip(top3.values.tolist(), top3.indices.tolist())
        )
        print(f" Step {step+1}: {action_label} (probs: {top3_str})")

        if action_label == "stop":
            break
        actions.append(action_label)
        applied.add(action_label)

    print(f"\nPredicted: {' -> '.join(actions)}")
    print("Expected for gemm_tile(4,6,8): maxnreg_128 -> vec_ld -> vec_st -> stop")


def main():
    """CLI entry point: load a checkpoint, then predict on a PTX file or
    fall back to the built-in demo when --ptx is not given."""
    parser = argparse.ArgumentParser(description="Chronos PoC: PTX transform inference")
    parser.add_argument("--checkpoint", required=True, help="Path to .pt checkpoint")
    parser.add_argument("--ptx", help="Path to PTX file")
    parser.add_argument("--kernel", default="gemm_tile",
                        help="Kernel type (for generating PTX if --ptx not provided)")
    parser.add_argument("--m", type=int, default=4)
    parser.add_argument("--n", type=int, default=6)
    parser.add_argument("--k", type=int, default=8)
    args = parser.parse_args()

    model = load_model(args.checkpoint)

    if not args.ptx:
        _run_demo(model, args)
        return

    with open(args.ptx) as f:
        ptx_source = f.read()
    print(f"\nLoaded PTX from: {args.ptx}")

    predict_transforms(model, ptx_source)
    print("\nTo apply these transforms, use the Chronos transform pipeline.")


if __name__ == "__main__":
    main()
| |
|