# Chronos PoC: PTX Transform Selection via RLVR
Proof-of-concept RL agent that selects PTX assembly transforms to optimize GPU kernel performance on NVIDIA L4 (sm_89).
## What This Is
An MLP policy trained with DA-GRPO (Demonstration-Anchored Group Relative Policy Optimization) to select sequences of PTX-level transforms that reduce GPU kernel execution cycles. Trained on 64 gemm_tile kernel variants, validated on 33 diverse Triton kernels.
## Results
| Metric | Value |
|---|---|
| Mean cycle reduction (gemm_tile) | -29.2% |
| Best single kernel | -53.8% (gemm_tile 4,6,8: 1839 -> 849 cycles) |
| Generalization to Triton kernels | 17/33 kernels improve |
| Best Triton improvement | -47.0% (attention_d64_kv64 with maxnreg_255) |
| Training time | ~6 hours on single NVIDIA L4 |
| Model parameters | ~20K |
## Model Architecture
```
Input:  25 kernel features + 21 action mask + 21 action history = 67 dims
Hidden: Linear(67, 128) -> ReLU -> Dropout(0.1) -> Linear(128, 128) -> ReLU -> Dropout(0.1)
Output: Linear(128, 21) -> mask -> softmax
```
- **25 features:** instruction counts (loads, stores, FMA, branches), vectorization ratios, cache-hint coverage, register state, and instruction-mix ratios.
- **21 actions:** 20 PTX transforms plus stop. Transforms include vectorize loads/stores, load cache hints (cs/cg/ca/cv), store cache hints (cs/wt/wb), register budget limits (maxnreg 32/64/128/255), instruction reordering (critical_path/interleave/loads_first/stores_last), prefetch (L1/L2), and split vector loads.
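The architecture above corresponds roughly to the following PyTorch module. This is an illustrative reconstruction, not the exact `TransformPolicy` from `inference.py`; the class and variable names here are my own.

```python
import torch
import torch.nn as nn

N_FEATURES, N_ACTIONS = 25, 21
INPUT_DIM = N_FEATURES + 2 * N_ACTIONS  # 25 features + 21 mask + 21 history = 67

class TransformPolicySketch(nn.Module):
    """MLP policy: concatenated state -> masked softmax over 21 actions."""
    def __init__(self, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(INPUT_DIM, hidden), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(hidden, hidden), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(hidden, N_ACTIONS),
        )

    def forward(self, features, mask, history):
        logits = self.net(torch.cat([features, mask, history], dim=-1))
        # Invalid actions get -inf logits, so they receive zero probability
        logits = logits.masked_fill(mask == 0, float("-inf"))
        return torch.softmax(logits, dim=-1)

policy = TransformPolicySketch().eval()
probs = policy(torch.randn(N_FEATURES), torch.ones(N_ACTIONS), torch.zeros(N_ACTIONS))
```

Masking before the softmax (rather than after) keeps the output a valid probability distribution over the remaining legal actions.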
## Files

```
checkpoint_best.pt            # Best checkpoint (epoch 250, -29.2% mean)
checkpoint_latest.pt          # Final checkpoint (epoch 500)
inference.py                  # Self-contained inference script
training_result.json          # Per-kernel results (64 gemm_tile kernels)
bc_stats.json                 # Behavior cloning warm-start statistics
measure_triton_results.json   # Triton kernel measurement results (33 kernels)
```
## Inference
### Requirements

```bash
pip install torch
```

No GPU is required for inference; the model runs on CPU. A GPU is only needed to actually apply the transforms and measure cycles.
### Quick Start

```python
import torch
from inference import load_model, predict_transforms

# Load model
model = load_model("checkpoint_best.pt")

# Predict transforms for a PTX kernel
with open("your_kernel.ptx") as f:
    ptx = f.read()

actions = predict_transforms(model, ptx)
# -> ['maxnreg_128', 'vec_ld', 'vec_st']
```
### Command Line

```bash
# Demo with synthetic features
python inference.py --checkpoint checkpoint_best.pt

# Run on a PTX file
python inference.py --checkpoint checkpoint_best.pt --ptx path/to/kernel.ptx
```
### Programmatic Usage

```python
import torch
from inference import (
    TransformPolicy, extract_features_from_ptx,
    get_action_mask, get_action_history, ACTION_NAMES,
)

# Load
model = TransformPolicy(hidden=128)
ckpt = torch.load("checkpoint_best.pt", map_location="cpu", weights_only=False)
model.load_state_dict(ckpt["policy"])
model.eval()

# Extract features from PTX
with open("kernel.ptx") as f:
    ptx_source = f.read()
features = extract_features_from_ptx(ptx_source)

# Predict step by step
applied = set()
for step in range(6):
    feat_t = torch.tensor(features, dtype=torch.float32)
    mask_t = torch.tensor(get_action_mask(applied), dtype=torch.float32)
    hist_t = torch.tensor(get_action_history(applied), dtype=torch.float32)
    action_id = model.get_greedy_action(feat_t, mask_t, hist_t)
    action = ACTION_NAMES[action_id]
    if action == "stop":
        break
    print(f"Step {step+1}: apply {action}")
    applied.add(action)
```
## Training Details
### Algorithm: DA-GRPO
- BC warm-start (50 epochs): Clone greedy search trajectories. Best accuracy: 64.5%.
- GRPO training (450 epochs): Hardware-in-the-loop RL with SM clock() cycle measurement.
- Group size: 8 rollouts per kernel (1 anchor from reference policy + 7 with forced diverse first actions)
- Advantage: MC-GRPO (median baseline per kernel, global z-normalization)
- Reward: log(cycles_before / cycles_after), an outcome-only terminal reward
- KL penalty: beta=0.01 against BC reference policy
- Clipped surrogate: epsilon=0.2
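The reward and MC-GRPO advantage computation described above can be sketched as follows. This is an illustrative reconstruction; the cycle counts below are made up, and the exact normalization in training may differ.

```python
import math
import statistics

def reward(cycles_before, cycles_after):
    """Outcome-only terminal reward: log of the cycle speedup."""
    return math.log(cycles_before / cycles_after)

def mc_grpo_advantages(reward_groups):
    """MC-GRPO-style advantages: subtract the per-kernel median baseline,
    then z-normalize globally over all rollouts."""
    centered = []
    for rewards in reward_groups:  # one group of rollouts per kernel
        baseline = statistics.median(rewards)
        centered.extend(r - baseline for r in rewards)
    mu = statistics.mean(centered)
    sigma = statistics.pstdev(centered) or 1.0  # guard against all-equal rewards
    return [(c - mu) / sigma for c in centered]

# Two kernels, 8 rollouts each (hypothetical cycle counts for illustration)
groups = [
    [reward(1839, c) for c in (1839, 1500, 1700, 1839, 1200, 849, 1839, 1600)],
    [reward(1000, c) for c in (1000, 990, 1010, 950, 1000, 980, 1000, 1020)],
]
advantages = mc_grpo_advantages(groups)
```

The median baseline makes the anchor rollout (and any no-op rollout, reward 0) a natural reference within each group, while the global z-normalization keeps gradient scale comparable across kernels with very different headroom.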
## Action Space
20 PTX transforms, 17 of which fall into 5 mutually exclusive conflict groups (at most one transform per group can be applied); the remaining 3 are independent:
- Cache hints (load): cs, cg, ca, cv
- Store cache hints: cs, wt, wb
- Register budget: maxnreg 32, 64, 128, 255
- Instruction reorder: critical_path, interleave, loads_first, stores_last
- Prefetch: L1, L2
- Vectorize: loads, stores (independent)
- Split: vector loads (independent)
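The conflict-group constraint can be enforced with an action mask along these lines. This is a hypothetical sketch: only `vec_ld`, `vec_st`, `maxnreg_*`, and `stop` are action names confirmed by this repo, and the real `get_action_mask` in `inference.py` may differ.

```python
CONFLICT_GROUPS = {
    "ld_cache": ["ld_cs", "ld_cg", "ld_ca", "ld_cv"],
    "st_cache": ["st_cs", "st_wt", "st_wb"],
    "maxnreg":  ["maxnreg_32", "maxnreg_64", "maxnreg_128", "maxnreg_255"],
    "reorder":  ["reorder_critical_path", "reorder_interleave",
                 "reorder_loads_first", "reorder_stores_last"],
    "prefetch": ["prefetch_l1", "prefetch_l2"],
}
INDEPENDENT = ["vec_ld", "vec_st", "split_vec_ld"]
ACTION_NAMES = [a for g in CONFLICT_GROUPS.values() for a in g] + INDEPENDENT + ["stop"]

def get_action_mask(applied):
    """Return a 0/1 mask over ACTION_NAMES: applying any member of a
    conflict group blocks the entire group; 'stop' is never blocked."""
    blocked = set(applied)
    for group in CONFLICT_GROUPS.values():
        if blocked.intersection(group):
            blocked.update(group)
    return [0.0 if name in blocked else 1.0 for name in ACTION_NAMES]
```

For example, after applying a load cache hint, all four `ld_*` actions are masked out for the rest of the episode, while the independent vectorize/split actions remain selectable.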
## Hardware
- NVIDIA L4 GPU (sm_89, Ada Lovelace)
- SM clock() cycle counter (1-cycle std dev, 200 samples per measurement)
- pip-installed CUDA 12.9 ptxas
## Limitations
- Trained on gemm_tile kernels only (64 variants, m,n,k in {2,4,6,8})
- Mode collapse: 52/64 kernels get the same sequence (vec_st + vec_ld + maxnreg_128)
- The MLP cannot read PTX source; it relies entirely on the 25 scalar features
- 3% measurement error rate from `cudaErrorMisalignedAddress` on gemm_tile(4,6,4)
- Reorder transforms deadlock on kernels with `bar.sync` barriers
## References
- CuAsmRL (CGO 2025): PPO on SASS scheduling
- Dr. Kernel (2026): REINFORCE for Triton kernels
- DeepSeek-R1 (2025): GRPO algorithm
- MC-GRPO (2025): Median-centered baseline
## License
Research prototype. Contact for usage terms.