# chronos-poc / inference.py  (author: JayLuci4)
# Chronos PoC: PTX transform selection via RLVR (DA-GRPO)
# commit adbcafd (verified)
"""Chronos PoC: PTX transform selection inference.
Loads a trained TransformPolicy checkpoint and predicts the optimal
sequence of PTX transforms for a given kernel.
Usage:
python inference.py --checkpoint checkpoint_best.pt --kernel gemm_tile --m 4 --n 6 --k 8
python inference.py --checkpoint checkpoint_best.pt --ptx path/to/kernel.ptx
"""
import argparse
import sys
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
# ---------------------------------------------------------------------------
# Model definition (self-contained, no external dependencies for inference)
# ---------------------------------------------------------------------------
N_FEATURES = 25  # number of scalar kernel features the model was trained on
N_ACTIONS = 21   # 20 PTX transforms + the terminal "stop" action

# Action vocabulary; index i of the policy's output logits corresponds to
# ACTION_NAMES[i].
ACTION_NAMES = [
    "vec_ld", "vec_st",
    "cache_cs", "cache_cg", "cache_ca", "cache_cv",
    "st_cache_cs", "st_cache_wt", "st_cache_wb",
    "maxnreg_32", "maxnreg_64", "maxnreg_128", "maxnreg_255",
    "reorder_cp", "reorder_il", "reorder_lf", "reorder_sl",
    "prefetch_L1", "prefetch_L2",
    "split_ld",
    "stop",
]

# Names of the 25 scalar features, in the order produced by
# extract_features_from_ptx.
FEATURE_NAMES = [
    "n_instructions", "n_ld_global", "n_st_global", "n_fma",
    "n_ld_param", "n_prefetch", "n_branch",
    "n_ld_global_vec", "n_st_global_vec", "vec_ld_ratio", "vec_st_ratio",
    "n_cache_hint_ld", "n_cache_hint_st", "hint_ld_ratio", "hint_st_ratio",
    "load_ratio", "store_ratio", "fma_ratio", "compute_ratio",
    "mem_ratio", "compute_to_mem",
    "total_regs", "n_f32_regs", "n_b64_regs", "maxnreg",
]

# Groups of mutually exclusive transforms: once any member of a group has
# been applied, the remaining members are masked out (see get_action_mask).
CONFLICT_GROUPS = {
    "cache_hints": {"cache_cs", "cache_cg", "cache_ca", "cache_cv"},
    "store_cache_hints": {"st_cache_cs", "st_cache_wt", "st_cache_wb"},
    "register_budget": {"maxnreg_32", "maxnreg_64", "maxnreg_128", "maxnreg_255"},
    "prefetch": {"prefetch_L1", "prefetch_L2"},
    "reorder": {"reorder_cp", "reorder_il", "reorder_lf", "reorder_sl"},
}


class TransformPolicy(nn.Module):
    """MLP policy over PTX transforms.

    The input is the concatenation of kernel features, the current action
    mask, and the action history (25 + 21 + 21 = 67 dims).  The output is
    one logit per action, with unavailable actions (mask == 0) forced to
    -inf before any softmax/argmax.
    """

    def __init__(self, hidden=128):
        super().__init__()
        in_dim = N_FEATURES + 2 * N_ACTIONS  # 67
        # Two hidden ReLU layers with dropout.  The attribute is named
        # "net" so existing checkpoints ("net.0.weight", ...) keep loading.
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden, N_ACTIONS),
        )

    def forward(self, features, action_mask, action_history):
        """Return masked logits of shape (..., N_ACTIONS)."""
        state = torch.cat([features, action_mask, action_history], dim=-1)
        raw_logits = self.net(state)
        # Masked-out actions can never win an argmax or receive probability.
        return raw_logits.masked_fill(action_mask == 0, float('-inf'))

    @torch.no_grad()
    def get_greedy_action(self, features, action_mask, action_history):
        """Return the index (int) of the highest-logit available action."""
        batched = (features.unsqueeze(0), action_mask.unsqueeze(0),
                   action_history.unsqueeze(0))
        return self.forward(*batched).argmax(dim=-1).item()

    @torch.no_grad()
    def get_action_probs(self, features, action_mask, action_history):
        """Return the softmax distribution over actions, shape (N_ACTIONS,)."""
        batched = (features.unsqueeze(0), action_mask.unsqueeze(0),
                   action_history.unsqueeze(0))
        return F.softmax(self.forward(*batched), dim=-1).squeeze(0)
# ---------------------------------------------------------------------------
# Feature extraction (self-contained, regex-based)
# ---------------------------------------------------------------------------
import re

# Pre-compiled instruction classifiers.  PTX is line-oriented, so simple
# substring patterns suffice; compiling once keeps per-line scanning cheap.
_LD_GLOBAL = re.compile(r'ld\.global')
_LD_GLOBAL_VEC = re.compile(r'ld\.global(?:\.\w+)*\.v[24]')
_ST_GLOBAL = re.compile(r'st\.global')
_ST_GLOBAL_VEC = re.compile(r'st\.global(?:\.\w+)*\.v[24]')
_FMA = re.compile(r'\bfma\.')
_MUL = re.compile(r'\bmul\.')
_ADD = re.compile(r'\badd\.')
_LD_PARAM = re.compile(r'ld\.param')
_PREFETCH = re.compile(r'prefetch\.global')
_CACHE_HINT_LD = re.compile(r'ld\.global\.(?:cs|cg|ca|cv)')
_CACHE_HINT_ST = re.compile(r'st\.global\.(?:wb|wt|cs)')
_MAXNREG = re.compile(r'\.maxnreg\s+(\d+)')
# Register declaration, e.g. ".reg .f32 %f<48>;" -> type ".f32", count 48.
# Hoisted out of the parsing loop so the pattern is compiled exactly once.
_REG_DECL = re.compile(r'\.reg\s+(\.\w+)\s+%\w+<(\d+)>\s*;')


def extract_features_from_ptx(ptx_source):
    """Extract 25 scalar features from PTX source text.

    Parameters
    ----------
    ptx_source : str
        Full text of a PTX kernel (may be empty).

    Returns
    -------
    list
        25 numbers in FEATURE_NAMES order: raw instruction counts,
        derived ratios (rounded to 4 decimals), and register statistics.
    """
    lines = ptx_source.split('\n')  # split once, reused by both passes

    # --- Pass 1: register declarations (they may appear anywhere).
    # A later declaration of the same register type overwrites an earlier one.
    reg_decls = {}
    for line in lines:
        m = _REG_DECL.search(line)
        if m:
            reg_decls[m.group(1)] = int(m.group(2))

    # --- Pass 2: classify instructions inside the kernel body ({ ... }).
    n_instr = 0
    n_ld_global = 0
    n_ld_global_vec = 0
    n_st_global = 0
    n_st_global_vec = 0
    n_fma = 0
    n_mul = 0
    n_add = 0
    n_ld_param = 0
    n_prefetch = 0
    n_cache_hint_ld = 0
    n_cache_hint_st = 0
    n_branch = 0
    in_body = False
    for line in lines:
        stripped = line.strip()
        if stripped == '{':
            in_body = True
            continue
        if stripped == '}':
            in_body = False
            continue
        if not in_body:
            continue
        # Skip blanks, comments, directives, labels, and structural tokens.
        if not stripped or stripped.startswith('//') or stripped.startswith('.'):
            continue
        if stripped.endswith(':'):  # label
            continue
        if stripped in ('ret;', 'exit;', ')', ','):
            continue
        # Branches are tallied separately and excluded from n_instr.
        if 'bra ' in stripped or 'bra\t' in stripped:
            n_branch += 1
            continue
        n_instr += 1
        if _LD_GLOBAL.search(line):
            n_ld_global += 1
            # Vector / cache-hint loads are refinements of ld.global.
            if _LD_GLOBAL_VEC.search(line):
                n_ld_global_vec += 1
            if _CACHE_HINT_LD.search(line):
                n_cache_hint_ld += 1
        if _ST_GLOBAL.search(line):
            n_st_global += 1
            if _ST_GLOBAL_VEC.search(line):
                n_st_global_vec += 1
            if _CACHE_HINT_ST.search(line):
                n_cache_hint_st += 1
        if _FMA.search(line):
            n_fma += 1
        if _MUL.search(line):
            n_mul += 1
        if _ADD.search(line):
            n_add += 1
        if _LD_PARAM.search(line):
            n_ld_param += 1
        if _PREFETCH.search(line):
            n_prefetch += 1

    # The .maxnreg directive sits outside the body, so scan the whole source.
    m = _MAXNREG.search(ptx_source)
    maxnreg = int(m.group(1)) if m else 0

    total_regs = sum(reg_decls.values())
    n_f32_regs = reg_decls.get('.f32', 0)
    n_b64_regs = reg_decls.get('.b64', 0)

    n_total = max(n_instr, 1)  # avoid division by zero on empty kernels
    n_compute = n_fma + n_mul + n_add
    n_mem = n_ld_global + n_st_global
    return [
        n_instr,
        n_ld_global,
        n_st_global,
        n_fma,
        n_ld_param,
        n_prefetch,
        n_branch,
        n_ld_global_vec,
        n_st_global_vec,
        round(n_ld_global_vec / max(n_ld_global, 1), 4),   # vec_ld_ratio
        round(n_st_global_vec / max(n_st_global, 1), 4),   # vec_st_ratio
        n_cache_hint_ld,
        n_cache_hint_st,
        round(n_cache_hint_ld / max(n_ld_global, 1), 4),   # hint_ld_ratio
        round(n_cache_hint_st / max(n_st_global, 1), 4),   # hint_st_ratio
        round(n_ld_global / n_total, 4),                   # load_ratio
        round(n_st_global / n_total, 4),                   # store_ratio
        round(n_fma / n_total, 4),                         # fma_ratio
        round(n_compute / n_total, 4),                     # compute_ratio
        round(n_mem / n_total, 4),                         # mem_ratio
        round(n_compute / max(n_mem, 1), 4),               # compute_to_mem
        total_regs,
        n_f32_regs,
        n_b64_regs,
        maxnreg,
    ]
# ---------------------------------------------------------------------------
# Action mask and history
# ---------------------------------------------------------------------------
def get_action_mask(applied_set):
    """Return a 0/1 availability flag for each action, in ACTION_NAMES order.

    An action is unavailable (0) when it has already been applied, or when
    any member of one of its CONFLICT_GROUPS has been applied.  "stop" is
    always available.
    """
    mask = []
    for label in ACTION_NAMES:
        if label == "stop":
            mask.append(1)      # termination is always a legal move
        elif label in applied_set:
            mask.append(0)      # each transform may be applied at most once
        else:
            blocked = any(
                label in group and applied_set & group
                for group in CONFLICT_GROUPS.values()
            )
            mask.append(0 if blocked else 1)
    return mask
def get_action_history(applied_set):
    """Return a 0/1 flag per action marking transforms already applied."""
    return [int(label in applied_set) for label in ACTION_NAMES]
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def load_model(checkpoint_path, device="cpu"):
    """Deserialize a trained TransformPolicy from a .pt checkpoint.

    Parameters
    ----------
    checkpoint_path : str
        Path to a checkpoint containing at least a "policy" state dict.
    device : str
        Device to map the weights onto (default "cpu").

    Returns
    -------
    TransformPolicy
        The loaded model, in eval mode, on `device`.
    """
    # NOTE(review): weights_only=False runs the full pickle machinery --
    # only load checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)

    model = TransformPolicy(hidden=128)
    model.load_state_dict(ckpt["policy"])
    model.eval()
    model.to(device)

    print(f"Loaded checkpoint from epoch {ckpt.get('epoch', 'unknown')}")
    if "eval_result" in ckpt:
        mean_imp = ckpt["eval_result"].get("mean_improvement", 0)
        print(f" Eval mean improvement: {mean_imp*100:.1f}%")
    return model
def predict_transforms(model, ptx_source, max_steps=6, verbose=True):
    """Greedily roll out the policy on one PTX kernel.

    Parameters
    ----------
    model : TransformPolicy
        Trained policy (expected to be in eval mode).
    ptx_source : str
        PTX text; features are extracted with extract_features_from_ptx.
    max_steps : int
        Upper bound on transforms before the rollout is cut off.
    verbose : bool
        When True, print kernel stats and per-step top-5 probabilities.

    Returns
    -------
    list of str
        Chosen transform labels in order, excluding the terminal "stop".
    """
    features = extract_features_from_ptx(ptx_source)
    chosen = []
    applied = set()

    if verbose:
        print(f"\nKernel: {features[0]} instructions, "
              f"{features[1]} global loads, {features[2]} global stores, "
              f"{features[3]} FMA, {features[21]} total regs")

    for step in range(max_steps):
        feat_t = torch.tensor(features, dtype=torch.float32)
        mask_t = torch.tensor(get_action_mask(applied), dtype=torch.float32)
        hist_t = torch.tensor(get_action_history(applied), dtype=torch.float32)

        label = ACTION_NAMES[model.get_greedy_action(feat_t, mask_t, hist_t)]

        if verbose:
            probs = model.get_action_probs(feat_t, mask_t, hist_t)
            top = torch.topk(probs, min(5, probs.size(0)))
            ranked = ", ".join(
                f"{ACTION_NAMES[i]}={p:.2f}"
                for p, i in zip(top.values.tolist(), top.indices.tolist())
            )
            print(f" Step {step+1}: {label} (top5: {ranked})")

        if label == "stop":
            break
        chosen.append(label)
        applied.add(label)

    if verbose:
        print(f"\nPredicted sequence: {' -> '.join(chosen) if chosen else '(no transforms)'}")
    return chosen
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _demo_feature_vector():
    """Synthetic 25-dim feature vector resembling gemm_tile(4,6,8).

    Mirrors the best kernel seen during training (-53.8% improvement); used
    when no --ptx file is supplied so the model can still be exercised.
    """
    return [
        170,    # n_instructions
        16,     # n_ld_global
        8,      # n_st_global
        48,     # n_fma
        12,     # n_ld_param
        0,      # n_prefetch
        2,      # n_branch
        0,      # n_ld_global_vec
        0,      # n_st_global_vec
        0.0,    # vec_ld_ratio
        0.0,    # vec_st_ratio
        0,      # n_cache_hint_ld
        0,      # n_cache_hint_st
        0.0,    # hint_ld_ratio
        0.0,    # hint_st_ratio
        0.094,  # load_ratio
        0.047,  # store_ratio
        0.282,  # fma_ratio
        0.388,  # compute_ratio
        0.141,  # mem_ratio
        2.75,   # compute_to_mem
        95,     # total_regs
        48,     # n_f32_regs
        16,     # n_b64_regs
        0,      # maxnreg
    ]


def _run_demo(model, m, n, k):
    """Greedy rollout on the synthetic demo features, printing each step.

    NOTE(review): m/n/k only affect the banner text -- the demo feature
    vector itself is fixed to the gemm_tile(4,6,8) values above.
    """
    demo_features = _demo_feature_vector()
    applied = set()
    actions = []
    print(f"\nDemo: gemm_tile({m},{n},{k})-like features")
    print(f"Features: {len(demo_features)} dims")
    for step in range(6):
        feat_t = torch.tensor(demo_features, dtype=torch.float32)
        mask_t = torch.tensor(get_action_mask(applied), dtype=torch.float32)
        hist_t = torch.tensor(get_action_history(applied), dtype=torch.float32)
        action_id = model.get_greedy_action(feat_t, mask_t, hist_t)
        action_label = ACTION_NAMES[action_id]
        probs = model.get_action_probs(feat_t, mask_t, hist_t)
        top3 = torch.topk(probs, min(3, probs.size(0)))
        top3_str = ", ".join(
            f"{ACTION_NAMES[i]}={p:.2f}"
            for p, i in zip(top3.values.tolist(), top3.indices.tolist())
        )
        print(f" Step {step+1}: {action_label} (probs: {top3_str})")
        if action_label == "stop":
            break
        actions.append(action_label)
        applied.add(action_label)
    print(f"\nPredicted: {' -> '.join(actions)}")
    print(f"Expected for gemm_tile(4,6,8): maxnreg_128 -> vec_ld -> vec_st -> stop")


def main():
    """CLI entry point: load a checkpoint, then predict transforms for a
    PTX file (--ptx) or run the built-in synthetic demo."""
    parser = argparse.ArgumentParser(description="Chronos PoC: PTX transform inference")
    parser.add_argument("--checkpoint", required=True, help="Path to .pt checkpoint")
    parser.add_argument("--ptx", help="Path to PTX file")
    parser.add_argument("--kernel", default="gemm_tile",
                        help="Kernel type (for generating PTX if --ptx not provided)")
    parser.add_argument("--m", type=int, default=4)
    parser.add_argument("--n", type=int, default=6)
    parser.add_argument("--k", type=int, default=8)
    args = parser.parse_args()

    model = load_model(args.checkpoint)

    # Guard clause: without a PTX file, fall back to the synthetic demo.
    if not args.ptx:
        print(f"\nTo run on a specific kernel, use: --ptx path/to/kernel.ptx")
        print("Showing demo with a sample feature vector...")
        _run_demo(model, args.m, args.n, args.k)
        return

    with open(args.ptx) as f:
        ptx_source = f.read()
    print(f"\nLoaded PTX from: {args.ptx}")
    actions = predict_transforms(model, ptx_source)
    print(f"\nTo apply these transforms, use the Chronos transform pipeline.")


if __name__ == "__main__":
    main()