Chronos PoC: PTX Transform Selection via RLVR

Proof-of-concept RL agent that selects PTX assembly transforms to optimize GPU kernel performance on NVIDIA L4 (sm_89).

What This Is

An MLP policy trained with DA-GRPO (Demonstration-Anchored Group Relative Policy Optimization) to select sequences of PTX-level transforms that reduce GPU kernel execution cycles. Trained on 64 gemm_tile kernel variants, validated on 33 diverse Triton kernels.

Results

Metric                               Value
----------------------------------   ----------------------------------------------
Mean cycle reduction (gemm_tile)     -29.2%
Best single kernel                   -53.8% (gemm_tile 4,6,8: 1839 -> 849 cycles)
Generalization to Triton kernels     17/33 kernels improve
Best Triton improvement              -47.0% (attention_d64_kv64 with maxnreg_255)
Training time                        ~6 hours on a single NVIDIA L4
Model parameters                     ~20K

Model Architecture

Input:  25 kernel features + 21 action mask + 21 action history = 67 dims
Hidden: Linear(67, 128) -> ReLU -> Dropout(0.1) -> Linear(128, 128) -> ReLU -> Dropout(0.1)
Output: Linear(128, 21) -> mask -> softmax
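The architecture above can be sketched in PyTorch. This is a minimal illustrative reimplementation, not the shipped `TransformPolicy` class from inference.py; layer sizes follow the spec above, and the masking detail (fill invalid logits with -inf before the softmax) is an assumption consistent with the "mask -> softmax" step:

```python
import torch
import torch.nn as nn

class PolicySketch(nn.Module):
    """Illustrative sketch of the described MLP policy (not the shipped class)."""
    def __init__(self, in_dim=67, hidden=128, n_actions=21, p_drop=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(), nn.Dropout(p_drop),
            nn.Linear(hidden, hidden), nn.ReLU(), nn.Dropout(p_drop),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, features, mask, history):
        # 25 kernel features + 21-dim action mask + 21-dim action history = 67 dims
        x = torch.cat([features, mask, history], dim=-1)
        logits = self.net(x)
        # Masked (invalid) actions get -inf so they receive zero probability
        logits = logits.masked_fill(mask == 0, float("-inf"))
        return torch.softmax(logits, dim=-1)
```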

25 features: instruction counts (loads, stores, FMA, branches), vectorization ratios, cache hint coverage, register state, instruction mix ratios.
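A hedged sketch of how scalar features like these might be counted from PTX text. The shipped `extract_features_from_ptx` in inference.py is the authoritative extractor; the instruction patterns and feature names below are illustrative assumptions:

```python
import re

def count_ptx_features(ptx: str) -> dict:
    """Illustrative counts of a few scalar features; not the shipped extractor."""
    loads = len(re.findall(r"\bld\.global", ptx))
    stores = len(re.findall(r"\bst\.global", ptx))
    fmas = len(re.findall(r"\bfma\.rn", ptx))
    branches = len(re.findall(r"\bbra\b", ptx))
    vec_loads = len(re.findall(r"\bld\.global\.v[24]", ptx))
    total = max(loads + stores + fmas + branches, 1)
    return {
        "loads": loads,
        "stores": stores,
        "fma": fmas,
        "branches": branches,
        "vec_load_ratio": vec_loads / max(loads, 1),  # vectorization ratio
        "fma_ratio": fmas / total,                    # instruction mix ratio
    }
```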

21 actions: 20 PTX transforms + stop. Transforms include vectorize loads/stores, cache hints (cs/cg/ca/cv), store cache hints (cs/wt/wb), register budget limits (32/64/128/255), instruction reorder (critical_path/interleave/loads_first/stores_last), prefetch, and split vector loads.

Files

checkpoint_best.pt       # Best checkpoint (epoch 250, -29.2% mean)
checkpoint_latest.pt     # Final checkpoint (epoch 500)
inference.py             # Self-contained inference script
training_result.json     # Per-kernel results (64 gemm_tile kernels)
bc_stats.json            # Behavior cloning warm-start statistics
measure_triton_results.json  # Triton kernel measurement results (33 kernels)

Inference

Requirements

pip install torch

No GPU is required for inference; the model runs on CPU. A GPU is only needed to apply the transforms and measure cycles.

Quick Start

import torch
from inference import load_model, predict_transforms

# Load model
model = load_model("checkpoint_best.pt")

# Predict transforms for a PTX kernel
with open("your_kernel.ptx") as f:
    ptx = f.read()
actions = predict_transforms(model, ptx)
# -> ['maxnreg_128', 'vec_ld', 'vec_st']

Command Line

# Demo with synthetic features
python inference.py --checkpoint checkpoint_best.pt

# Run on a PTX file
python inference.py --checkpoint checkpoint_best.pt --ptx path/to/kernel.ptx

Programmatic Usage

import torch
from inference import TransformPolicy, extract_features_from_ptx, get_action_mask, get_action_history, ACTION_NAMES

# Load
model = TransformPolicy(hidden=128)
ckpt = torch.load("checkpoint_best.pt", map_location="cpu", weights_only=False)
model.load_state_dict(ckpt["policy"])
model.eval()

# Extract features from PTX
with open("kernel.ptx") as f:
    ptx_source = f.read()
features = extract_features_from_ptx(ptx_source)

# Predict step by step
applied = set()
for step in range(6):
    feat_t = torch.tensor(features, dtype=torch.float32)
    mask_t = torch.tensor(get_action_mask(applied), dtype=torch.float32)
    hist_t = torch.tensor(get_action_history(applied), dtype=torch.float32)

    action_id = model.get_greedy_action(feat_t, mask_t, hist_t)
    action = ACTION_NAMES[action_id]
    if action == "stop":
        break
    print(f"Step {step+1}: apply {action}")
    applied.add(action)

Training Details

Algorithm: DA-GRPO

  1. BC warm-start (50 epochs): Clone greedy search trajectories. Best accuracy: 64.5%.
  2. GRPO training (450 epochs): Hardware-in-the-loop RL with SM clock() cycle measurement.
    • Group size: 8 rollouts per kernel (1 anchor from reference policy + 7 with forced diverse first actions)
    • Advantage: MC-GRPO (median baseline per kernel, global z-normalization)
    • Reward: log(cycles_before / cycles_after) — outcome-only, terminal
    • KL penalty: beta=0.01 against BC reference policy
    • Clipped surrogate: epsilon=0.2
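The advantage computation above can be sketched as follows. This is a minimal illustration of the stated recipe (median baseline per kernel group, then global z-normalization, with reward log(cycles_before / cycles_after)); function names and the example cycle counts are assumptions, not the training code:

```python
import math
import statistics

def grpo_advantages(group_rewards):
    """Median-baseline advantages for one kernel's group of rollouts (sketch)."""
    baseline = statistics.median(group_rewards)
    return [r - baseline for r in group_rewards]

def z_normalize(advantages, eps=1e-8):
    """Global z-normalization across all kernels' pooled advantages (sketch)."""
    mean = statistics.fmean(advantages)
    std = statistics.pstdev(advantages)
    return [(a - mean) / (std + eps) for a in advantages]

# Hypothetical example: 8 rollouts for one kernel starting at 1839 cycles.
cycles_after = [1500, 1839, 849, 1700, 1839, 1200, 1839, 1600]
rewards = [math.log(1839 / c) for c in cycles_after]  # terminal, outcome-only
adv = grpo_advantages(rewards)  # best rollout (849 cycles) gets the largest advantage
```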

Action Space

20 PTX transforms: 5 mutually exclusive conflict groups (at most one transform per group may be applied) plus three independent transforms:

  • Cache hints (load): cs, cg, ca, cv
  • Store cache hints: cs, wt, wb
  • Register budget: maxnreg 32, 64, 128, 255
  • Instruction reorder: critical_path, interleave, loads_first, stores_last
  • Prefetch: L1, L2
  • Vectorize: loads, stores (independent)
  • Split: vector loads (independent)
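The conflict-group constraint can be sketched as a mask update. Group membership mirrors the list above; the exact action names, and the rule that applying any member masks out the rest of its group, are assumptions consistent with "at most one per group" (inference.py's `get_action_mask` is the authoritative version):

```python
# Sketch of conflict-group masking; action names are illustrative.
CONFLICT_GROUPS = [
    {"cache_cs", "cache_cg", "cache_ca", "cache_cv"},          # load cache hints
    {"st_cache_cs", "st_cache_wt", "st_cache_wb"},             # store cache hints
    {"maxnreg_32", "maxnreg_64", "maxnreg_128", "maxnreg_255"},
    {"reorder_critical_path", "reorder_interleave",
     "reorder_loads_first", "reorder_stores_last"},
    {"prefetch_l1", "prefetch_l2"},
]
INDEPENDENT = {"vec_ld", "vec_st", "split_vec_ld"}

def allowed_actions(applied: set) -> set:
    """Actions still selectable after the transforms in `applied` have been taken."""
    allowed = {"stop"}  # stop is always available
    allowed |= {a for a in INDEPENDENT if a not in applied}
    for group in CONFLICT_GROUPS:
        if not (group & applied):  # group untouched: all members still allowed
            allowed |= group
    return allowed
```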

Hardware

  • NVIDIA L4 GPU (sm_89, Ada Lovelace)
  • SM clock() cycle counter (1-cycle std dev, 200 samples per measurement)
  • pip-installed CUDA 12.9 ptxas

Limitations

  • Trained on gemm_tile kernels only (64 variants, m,n,k in {2,4,6,8})
  • Mode collapse: 52/64 kernels get the same sequence (vec_st + vec_ld + maxnreg_128)
  • MLP can't read PTX code — relies on 25 scalar features
  • 3% measurement error rate from cudaErrorMisalignedAddress on gemm_tile(4,6,4)
  • Reorder transforms deadlock on kernels with bar.sync barriers

License

Research prototype. Contact for usage terms.
