| | """ |
| | HF Spaces app for VLIW kernel optimization via RL. |
| | Uses actual simulator for correctness-gated cycle-count rewards. |
| | """ |
| | import os |
| | import sys |
| | import gradio as gr |
| | import threading |
| | import time |
| | import random |
| | import re |
| | from copy import copy |
| | from pathlib import Path |
| |
|
| | |
# Human-readable startup diagnostics, appended during module import and
# rendered verbatim in the UI status box via get_status().
startup_log = []
|
def check_import(name, import_fn):
    """Probe an optional dependency and record the outcome in startup_log.

    Calls ``import_fn`` and logs ``[OK]`` with its return value (typically a
    version string), or ``[ERR]`` with a truncated error message.

    Returns:
        True when the probe succeeded, False otherwise.
    """
    try:
        outcome = import_fn()
    except Exception as exc:
        startup_log.append(f"[ERR] {name}: {str(exc)[:80]}")
        return False
    startup_log.append(f"[OK] {name}: {outcome}")
    return True
| |
|
# Probe every heavyweight dependency up front so the status box shows
# exactly which parts of the stack are usable before training is attempted.
check_import("torch", lambda: __import__("torch").__version__)
check_import("transformers", lambda: __import__("transformers").__version__)
check_import("datasets", lambda: __import__("datasets").__version__)
check_import("peft", lambda: __import__("peft").__version__)
check_import("trl", lambda: __import__("trl").__version__)
check_import("huggingface_hub", lambda: __import__("huggingface_hub").__version__)
| |
|
# GRPO availability check: the names are imported only to verify they exist;
# run_training() re-imports them when it actually needs them.
try:
    from trl import GRPOConfig, GRPOTrainer
    startup_log.append("[OK] GRPOTrainer: OK")
except Exception as e:
    startup_log.append(f"[ERR] GRPOTrainer: {e}")


# CUDA availability check (4-bit GRPO training is impractical on CPU).
try:
    import torch
    if torch.cuda.is_available():
        startup_log.append(f"[OK] CUDA: {torch.cuda.get_device_name(0)}")
    else:
        startup_log.append("[ERR] CUDA: Not available")
except Exception as e:
    startup_log.append(f"[ERR] CUDA check: {e}")
| |
|
| | |
| | |
# Make the bundled take-home simulator package importable when it ships
# alongside this app.
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
PERF_TAKEHOME_PATH = os.path.join(THIS_DIR, "original_performance_takehome")
if os.path.isdir(PERF_TAKEHOME_PATH):
    sys.path.insert(0, PERF_TAKEHOME_PATH)
| |
|
| | |
# The VLIW simulator supplies the machine model, reference kernels, and the
# KernelBuilder API that generated code subclasses. Without it, no rewards
# can be computed (verify_perf_takehome_code short-circuits).
try:
    from problem import (
        Machine, Tree, Input, DebugInfo,
        build_mem_image, reference_kernel2,
        SLOT_LIMITS, VLEN, N_CORES, SCRATCH_SIZE, CoreState
    )
    from perf_takehome import KernelBuilder, HASH_STAGES
    startup_log.append("[OK] VLIW Simulator: OK")
    SIMULATOR_AVAILABLE = True
except Exception as e:
    startup_log.append(f"[ERR] VLIW Simulator: {e}")
    SIMULATOR_AVAILABLE = False


# Hub access is optional: it is used only to persist/restore the LoRA
# adapter between Space restarts.
try:
    from huggingface_hub import HfApi, snapshot_download
    startup_log.append("[OK] huggingface_hub: OK")
    HF_HUB_AVAILABLE = True
except Exception as e:
    startup_log.append(f"[ERR] huggingface_hub: {str(e)[:80]}")
    HF_HUB_AVAILABLE = False
| |
|
| | |
# Scoring constants: baseline/target cycle counts for the take-home kernel,
# the scale of the cycle-based reward (SCORE_SCALE / cycles), and the small
# shaping bonuses granted to incorrect-but-partially-valid generations.
BASELINE_CYCLES = 147734
TARGET_CYCLES = 1363
SCORE_SCALE = 3000.0
PARSE_REWARD = 0.02
API_REWARD = 0.05
EXEC_REWARD = 0.10
# /data is the HF Spaces persistent volume when attached; fall back to CWD.
PERSIST_DIR = "/data" if os.path.isdir("/data") else "."
ADAPTER_DIR = os.path.join(PERSIST_DIR, "adapters", "perf_takehome_latest")
ADAPTER_DATASET_REPO = os.environ.get("ADAPTER_DATASET_REPO", "CreativeEngineer/vliw-optimizer-adapters")
ADAPTER_DATASET_SUBDIR = os.environ.get("ADAPTER_DATASET_SUBDIR", "perf_takehome_latest")


# Shared training state, guarded by state_lock. Written by the background
# training thread and the reward function; read by the Gradio log poller.
training_state = {
    "is_training": False,        # a run_training thread is active
    "should_stop": False,        # cooperative stop flag set by stop_training()
    "log": [],                   # timestamped log lines for the UI
    "best_cycles": BASELINE_CYCLES,  # fewest cycles of any correct kernel so far
    "best_code": None,           # source of that best kernel
    "step": 0,                   # global optimizer step count
}
state_lock = threading.Lock()


# Per-seed evaluation fixtures (forest, input, expected output), built
# lazily by _get_eval_context().
_eval_context = {}
| |
|
| |
|
def get_status():
    """Render the accumulated startup diagnostics as one multi-line string."""
    lines = list(startup_log)
    return "\n".join(lines)
| |
|
| |
|
def extract_code_block(text: str) -> str:
    """Pull the last fenced code block out of *text*.

    Preference order: the last ```python fence, then the last generic
    ``` fence, then a lenient split for an unterminated opening fence,
    and finally the whole text stripped.
    """
    # Properly closed fences first; the LAST match wins so a final answer
    # overrides earlier scratch blocks.
    for fence_pattern in (r"```python\s*(.*?)```", r"```\s*(.*?)```"):
        found = re.findall(fence_pattern, text, re.DOTALL)
        if found:
            return found[-1].strip()

    # Lenient fallback: an opening fence that was never closed.
    for marker in ("```python", "```"):
        if marker in text:
            tail = text.split(marker, 1)[1]
            if "```" in tail:
                tail = tail.split("```", 1)[0]
            return tail.strip()

    # No fences at all: treat the whole completion as code.
    return text.strip()
| |
|
| |
|
| | def _hf_token() -> str | None: |
| | return os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") |
| |
|
| |
|
| | def _ensure_dir(path: str) -> None: |
| | Path(path).mkdir(parents=True, exist_ok=True) |
| |
|
| |
|
| | def _adapter_exists(path: str) -> bool: |
| | return os.path.exists(os.path.join(path, "adapter_config.json")) |
| |
|
| |
|
def _try_download_adapter(add_log) -> None:
    """Fetch the latest LoRA adapter snapshot from the HF dataset repo.

    Downloads only the configured subdirectory of the dataset, then mirrors
    it into ADAPTER_DIR so training can resume from it. All failures are
    logged and swallowed: a missing or unreachable adapter simply means
    training starts fresh.

    Args:
        add_log: callable taking one message string (the run's logger).
    """
    import shutil  # local import: only needed for the mirror step

    if not HF_HUB_AVAILABLE:
        add_log("[ERR] Hub sync disabled: huggingface_hub not available")
        return
    _ensure_dir(os.path.dirname(ADAPTER_DIR))
    allow = [f"{ADAPTER_DATASET_SUBDIR}/**"]
    try:
        snapshot_download(
            repo_id=ADAPTER_DATASET_REPO,
            repo_type="dataset",
            allow_patterns=allow,
            local_dir=os.path.dirname(ADAPTER_DIR),
            local_dir_use_symlinks=False,
            token=_hf_token(),
        )
        downloaded = os.path.join(os.path.dirname(ADAPTER_DIR), ADAPTER_DATASET_SUBDIR)
        if _adapter_exists(downloaded):
            if downloaded != ADAPTER_DIR:
                # Mirror the snapshot into the canonical adapter location.
                # shutil.copy2 streams file contents in chunks instead of
                # loading each file fully into memory (adapter shards can
                # be large) and preserves metadata.
                _ensure_dir(ADAPTER_DIR)
                for root, _, files in os.walk(downloaded):
                    rel = os.path.relpath(root, downloaded)
                    dst_root = ADAPTER_DIR if rel == "." else os.path.join(ADAPTER_DIR, rel)
                    _ensure_dir(dst_root)
                    for name in files:
                        shutil.copy2(os.path.join(root, name), os.path.join(dst_root, name))
            add_log(f"[OK] Downloaded adapter from dataset: {ADAPTER_DATASET_REPO}/{ADAPTER_DATASET_SUBDIR}")
        else:
            add_log("[INFO] No adapter found in dataset yet")
    except Exception as e:
        add_log(f"[INFO] Adapter download skipped: {str(e)[:160]}")
| |
|
| |
|
def _try_upload_adapter(add_log) -> None:
    """Push the local adapter directory to the HF dataset repo, best effort.

    Skips quietly (with an [INFO] log line) when the hub client is missing,
    no adapter has been saved yet, or no auth token is configured. Upload
    errors are logged and swallowed.
    """
    # Preconditions, each with its own log message.
    if not HF_HUB_AVAILABLE:
        add_log("[ERR] Hub sync disabled: huggingface_hub not available")
        return
    if not _adapter_exists(ADAPTER_DIR):
        add_log("[INFO] No adapter to upload yet")
        return
    token = _hf_token()
    if token is None:
        add_log("[INFO] No HF token set (HF_TOKEN/HUGGINGFACE_HUB_TOKEN); skipping upload")
        return

    try:
        hub = HfApi(token=token)
        # Idempotent: create the dataset repo if it does not exist yet.
        hub.create_repo(repo_id=ADAPTER_DATASET_REPO, repo_type="dataset", exist_ok=True)
        hub.upload_folder(
            repo_id=ADAPTER_DATASET_REPO,
            repo_type="dataset",
            folder_path=ADAPTER_DIR,
            path_in_repo=ADAPTER_DATASET_SUBDIR,
            commit_message="Update perf_takehome adapter",
        )
        add_log(f"[OK] Uploaded adapter to dataset: {ADAPTER_DATASET_REPO}/{ADAPTER_DATASET_SUBDIR}")
    except Exception as e:
        add_log(f"[INFO] Adapter upload skipped: {str(e)[:160]}")
| |
|
| |
|
def _run_machine_with_cycle_limit(machine: Machine, max_cycles: int) -> bool:
    """Drive the simulator until all cores stop or the cycle budget is hit.

    Returns True when the program ran to completion within max_cycles,
    False when the budget was exceeded (all cores are force-stopped).
    """
    # Resume any cores the simulator left PAUSED so the program runs
    # in one uninterrupted shot.
    for core in machine.cores:
        if core.state == CoreState.PAUSED:
            core.state = CoreState.RUNNING
    while any(c.state == CoreState.RUNNING for c in machine.cores):
        has_non_debug = False
        for core in machine.cores:
            if core.state != CoreState.RUNNING:
                continue
            # Falling off the end of the program stops the core.
            if core.pc >= len(machine.program):
                core.state = CoreState.STOPPED
                continue
            instr = machine.program[core.pc]
            core.pc += 1
            machine.step(instr, core)
            # Debug-only bundles are free: the cycle counter only advances
            # when some core executed a non-debug slot this pass.
            if any(name != "debug" for name in instr.keys()):
                has_non_debug = True
        if has_non_debug:
            machine.cycle += 1
        if machine.cycle >= max_cycles:
            # Budget exhausted: halt every core and report failure.
            for core in machine.cores:
                core.state = CoreState.STOPPED
            return False
    return True
| |
|
| |
|
def _get_eval_context(seed: int) -> dict:
    """Build (or fetch from cache) the evaluation fixture for *seed*.

    The fixture holds the generated forest and input, the initial memory
    image, and the expected output computed by the reference kernel.

    The entire build now runs under state_lock: generation draws from the
    *global* random module, so two threads seeding/drawing concurrently
    could interleave and cache a nondeterministic fixture (the previous
    check-release-compute pattern had exactly that race). The lock is held
    for a noticeable moment only once per seed; later calls hit the cache.
    """
    with state_lock:
        cached = _eval_context.get(seed)
        if cached is not None:
            return cached

        random.seed(seed)
        forest = Tree.generate(10)
        inp = Input.generate(forest, 256, 16)
        mem0 = build_mem_image(forest, inp)

        # The reference kernel is a generator; drain it and keep the final
        # memory image it yields.
        ref_mem = None
        for ref_mem in reference_kernel2(list(mem0)):
            pass
        if ref_mem is None:
            raise RuntimeError("Reference kernel produced no output")

        # mem[6] holds the pointer to the values region (see prompt layout);
        # the expected output is the values slice after the reference run.
        inp_values_p = ref_mem[6]
        expected = ref_mem[inp_values_p : inp_values_p + len(inp.values)]
        ctx = {
            "forest": forest,
            "inp": inp,
            "mem0": mem0,
            "expected": expected,
            "inp_values_p": inp_values_p,
        }
        _eval_context[seed] = ctx
        return ctx
| |
|
| |
|
def verify_perf_takehome_code(code: str, seed: int = 123) -> dict:
    """Compile, execute, and simulate a candidate kernel; return a scored result.

    Gates, in order: non-empty code -> syntax -> required API surface
    (OptimizedKernelBuilder / run) -> exec -> simulate -> cycle sanity ->
    output correctness. The result dict always has the same keys:
    score, correctness, cycles, msg, parse_ok, api_ok, exec_ok.
    """

    def result(score=0.0, correctness=0.0, cycles=None, msg="",
               parse_ok=False, api_ok=False, exec_ok=False):
        # Single point of construction so every early exit reports the
        # same dict shape (previously 11 hand-copied literals).
        return {
            "score": float(score),
            "correctness": float(correctness),
            "cycles": cycles,
            "msg": msg,
            "parse_ok": parse_ok,
            "api_ok": api_ok,
            "exec_ok": exec_ok,
        }

    if not SIMULATOR_AVAILABLE:
        return result(msg="Simulator unavailable")

    try:
        code = code.strip()
        if not code:
            return result(msg="Empty code")

        # Syntax gate: compile without executing.
        try:
            compile(code, "<string>", "exec")
        except Exception as e:
            return result(msg=f"Syntax error: {str(e)[:200]}")

        # Cheap textual API gates before paying for exec.
        if "OptimizedKernelBuilder" not in code:
            return result(msg="Missing OptimizedKernelBuilder", parse_ok=True)
        if "def run" not in code:
            return result(msg="Missing run()", parse_ok=True)

        # NOTE(security): exec of model-generated code. The restricted
        # builtins raise the bar but are NOT a real sandbox — the exposed
        # KernelBuilder object offers paths back to full Python. Acceptable
        # here only because inputs come from our own model, not end users.
        safe_builtins = {
            "abs": abs,
            "all": all,
            "any": any,
            "dict": dict,
            "enumerate": enumerate,
            "int": int,
            "len": len,
            "list": list,
            "max": max,
            "min": min,
            "range": range,
            "sum": sum,
            "tuple": tuple,
            "zip": zip,
        }
        exec_globals = {
            "__builtins__": safe_builtins,
            "KernelBuilder": KernelBuilder,
            "HASH_STAGES": HASH_STAGES,
            "VLEN": VLEN,
            "SLOT_LIMITS": SLOT_LIMITS,
        }

        try:
            exec(code, exec_globals)
        except Exception as e:
            return result(msg=f"Execution error: {str(e)[:200]}",
                          parse_ok=True, api_ok=True)

        if "OptimizedKernelBuilder" not in exec_globals:
            return result(msg="OptimizedKernelBuilder not defined after exec",
                          parse_ok=True, api_ok=True, exec_ok=True)

        # Deterministic fixture shared across evaluations of this seed.
        ctx = _get_eval_context(seed)
        forest = ctx["forest"]
        inp = ctx["inp"]
        mem0 = ctx["mem0"]

        kb = exec_globals["OptimizedKernelBuilder"]()
        kb.build_kernel(10, len(forest.values), 256, 16)

        machine = Machine(
            list(mem0),
            kb.instrs,
            kb.debug_info(),
            n_cores=N_CORES,
            trace=False,
        )
        machine.enable_pause = False
        machine.enable_debug = False

        ok = _run_machine_with_cycle_limit(machine, max_cycles=250000)
        if not ok:
            return result(cycles=int(machine.cycle),
                          msg=f"Exceeded cycle limit (cycles={machine.cycle})",
                          parse_ok=True, api_ok=True, exec_ok=True)
        cycles = machine.cycle

        # Sanity window: a kernel that "finishes" in ~no time almost
        # certainly skipped the work (the correctness check below would
        # also catch it, but this gives a clearer message).
        if cycles <= 100:
            return result(cycles=int(cycles),
                          msg=f"Suspiciously low cycles ({cycles})",
                          parse_ok=True, api_ok=True, exec_ok=True)
        if cycles > 200000:
            return result(cycles=int(cycles),
                          msg=f"Cycles too high ({cycles})",
                          parse_ok=True, api_ok=True, exec_ok=True)

        # Correctness gate: final values region must match the reference.
        inp_values_p = ctx["inp_values_p"]
        expected = ctx["expected"]
        actual = machine.mem[inp_values_p : inp_values_p + len(inp.values)]
        if expected != actual:
            return result(cycles=int(cycles),
                          msg=f"Incorrect output (cycles={cycles})",
                          parse_ok=True, api_ok=True, exec_ok=True)

        # Correct: reward is inversely proportional to the cycle count.
        return result(score=SCORE_SCALE / cycles, correctness=1.0,
                      cycles=int(cycles),
                      msg=f"Success: {cycles} cycles",
                      parse_ok=True, api_ok=True, exec_ok=True)
    except Exception as e:
        # Unexpected failure anywhere in the pipeline (e.g. inside the
        # simulator itself). The old code labelled this "Execution error",
        # which misattributed simulator bugs to the generated code. Flags
        # stay conservatively False because we cannot tell which stage
        # was actually reached.
        return result(msg=f"Verification error: {str(e)[:200]}")
| |
|
| |
|
def perf_takehome_reward_fn(completions, prompts=None, **kwargs):
    """GRPO reward: correctness-gated cycle score plus small shaping bonuses.

    A correct kernel earns 1.0 plus the simulator's cycle-based score;
    an incorrect one collects partial credit for parsing, exposing the
    required API, and executing cleanly. Also records the best (fewest
    cycles) correct kernel seen so far in the shared training_state.

    Returns:
        A list of float rewards, one per completion.
    """
    rewards = []
    for completion in completions:
        # Chat-style completions arrive as [{"role": ..., "content": ...}].
        if isinstance(completion, list):
            text = completion[0].get("content", "") if completion else ""
        else:
            text = str(completion)

        code = extract_code_block(text)
        outcome = verify_perf_takehome_code(code)

        if outcome.get("correctness", 0.0) > 0:
            reward = float(outcome["score"]) + 1.0
        else:
            reward = (
                (PARSE_REWARD if outcome.get("parse_ok") else 0.0)
                + (API_REWARD if outcome.get("api_ok") else 0.0)
                + (EXEC_REWARD if outcome.get("exec_ok") else 0.0)
            )

        cycles = outcome.get("cycles")
        with state_lock:
            # Track the global best so the UI / early-stop logic see it.
            if isinstance(cycles, int) and cycles < training_state["best_cycles"]:
                training_state["best_cycles"] = cycles
                training_state["best_code"] = code
        rewards.append(float(reward))
    return rewards
| |
|
| |
|
| | |
# Few-shot skeletons demonstrating only the required output FORMAT (they
# are not performant kernels); appended to the main prompt below.
FEWSHOT_EXAMPLES = """Example format (not optimized):
```python
class OptimizedKernelBuilder(KernelBuilder):
    def build_kernel(self, forest_height, n_nodes, batch_size, rounds):
        self.add("flow", ("halt",))

def run():
    return (0,)
```

Example with scratch + load:
```python
class OptimizedKernelBuilder(KernelBuilder):
    def build_kernel(self, forest_height, n_nodes, batch_size, rounds):
        tmp = self.alloc_scratch("tmp")
        self.add("load", ("const", tmp, 0))
        self.add("flow", ("halt",))

def run():
    return (0,)
```
"""


# Training prompt: machine description, KernelBuilder API, memory layout,
# and the algorithm the kernel must implement. This is an f-string, so
# literal braces in the API lists are doubled ({{...}}).
PERF_TAKEHOME_PROMPT = f"""Write an optimized VLIW/SIMD kernel. OUTPUT ONLY ONE ```python CODE BLOCK.

ARCHITECTURE: 12 ALU + 6 VALU (VLEN=8) + 2 load + 2 store + 1 flow slots per cycle. 1536-word scratch.

API (KernelBuilder):
- alloc_scratch(name, length) -> addr
- scratch_const(val, name) -> addr
- add(engine, slot): engine in {{alu, valu, load, store, flow}}
  - alu: (op, dst, src1, src2) where op in {{+,-,*,//,%,^,&,|,<<,>>,<,==,!=,<=,>=,>}}
  - valu: same ops but on vectors (VLEN=8)
  - load: (load,dst,addr), (vload,dst,addr), (const,dst,val), (vbroadcast,dst,scalar_addr)
  - store: (store,addr,src), (vstore,addr,src)
  - flow: (select,dst,cond,t,f), (vselect,dst,cond,t,f), (cond_jump,cond,pc), (jump,pc), (halt,)
- label(name): mark code position
- build(slots, vliw=True): pack slots into VLIW bundle

MEMORY: mem[4]=forest_values, mem[5]=inp_indices, mem[6]=inp_values (256 elements each)

ALGORITHM: 16 rounds x 256 items:
  load idx,val
  node = tree[idx]
  val = hash(val ^ node) using HASH_STAGES
  idx = 2*idx + (1 if val%2==0 else 2)
  idx = 0 if idx >= n_nodes else idx
  store idx,val

RULES:
- Output exactly one python code block.
- The code block must define:
  - class OptimizedKernelBuilder(KernelBuilder): override build_kernel() and emit instructions using add()/build()
  - def run(): return any tuple (ignored), but must exist
- No imports.

Baseline: {BASELINE_CYCLES:,} cycles. Target: <{TARGET_CYCLES:,} cycles.

{FEWSHOT_EXAMPLES}
"""
| |
|
| |
|
def run_training(model_name, chunk_steps, max_total_steps, max_minutes, auto_continue):
    """Run GRPO + LoRA training with correctness-gated perf_takehome rewards.

    Intended to run on a background daemon thread. Progress and log lines
    are published into the shared training_state (under state_lock) so the
    Gradio UI can poll them. Training proceeds in chunks; after every chunk
    the adapter is saved locally and uploaded to the Hub dataset.

    Args:
        model_name: HF model id to fine-tune (loaded 4-bit/NF4).
        chunk_steps: optimizer steps per chunk.
        max_total_steps: overall step budget (used when auto_continue).
        max_minutes: wall-clock budget in minutes (used when auto_continue).
        auto_continue: chain chunks until a budget or the cycle target is
            reached; otherwise run exactly one chunk.

    Returns:
        The accumulated log as a single newline-joined string.
    """
    # Heavy imports are deferred so the app can start (and report status)
    # even when some of these packages are broken.
    import torch
    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import LoraConfig
    from peft import PeftModel
    from trl import GRPOConfig, GRPOTrainer
    from transformers import TrainerCallback

    log = []

    def add_log(msg):
        # Timestamped line, mirrored into shared state for the UI poller.
        log.append(f"[{time.strftime('%H:%M:%S')}] {msg}")
        with state_lock:
            training_state["log"] = log.copy()

    with state_lock:
        training_state["is_training"] = True
        training_state["should_stop"] = False
        training_state["log"] = []
        training_state["best_cycles"] = BASELINE_CYCLES
        training_state["best_code"] = None
        training_state["step"] = 0

    try:
        add_log(f"Starting VLIW optimization training")
        add_log(f"Model: {model_name}")
        add_log(f"Chunk steps: {chunk_steps}")
        add_log(f"Auto-continue: {auto_continue} (max_total_steps={max_total_steps}, max_minutes={max_minutes})")
        add_log(f"Baseline: {BASELINE_CYCLES:,} cycles, Target: {TARGET_CYCLES:,} cycles")
        add_log(f"Adapter dir: {ADAPTER_DIR}")
        add_log(f"Adapter dataset: {ADAPTER_DATASET_REPO}/{ADAPTER_DATASET_SUBDIR}")

        add_log("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        add_log("[OK] Tokenizer ready")

        add_log("Loading model (4-bit quantization)...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
        )
        add_log(f"[OK] Base model loaded on {next(base_model.parameters()).device}")

        # Restore the latest adapter from the Hub (if any) so runs resume.
        _try_download_adapter(add_log)

        resume_adapter = False
        if os.path.isdir(ADAPTER_DIR) and os.path.exists(os.path.join(ADAPTER_DIR, "adapter_config.json")):
            add_log("Loading existing LoRA adapter (resume)...")
            model = PeftModel.from_pretrained(base_model, ADAPTER_DIR, is_trainable=True)
            add_log("[OK] Adapter loaded")
            resume_adapter = True
        else:
            model = base_model

        # One prompt repeated: GRPO explores via sampled completions.
        add_log("Creating VLIW optimization dataset...")
        prompts = [PERF_TAKEHOME_PROMPT] * 16
        dataset = Dataset.from_dict({"prompt": prompts})
        add_log(f"[OK] Dataset ready: {len(prompts)} prompts")

        add_log("Setting up LoRA...")
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )

        progress = {"step": 0}
        start_time = time.time()
        max_seconds = float(max_minutes) * 60.0 if auto_continue else float("inf")
        total_target_steps = int(max_total_steps) if auto_continue else int(chunk_steps)

        class VLIWCallback(TrainerCallback):
            """Bridges trainer progress into training_state and enforces stops."""

            def on_step_end(self, args, state, control, **kwargs):
                with state_lock:
                    progress["step"] += 1
                    training_state["step"] = progress["step"]
                    if training_state["should_stop"]:
                        control.should_training_stop = True
                    if training_state["best_cycles"] <= TARGET_CYCLES:
                        # A generated kernel already beat the target.
                        control.should_training_stop = True
                return control

            def on_log(self, args, state, control, logs=None, **kwargs):
                if logs:
                    loss = logs.get("loss", "N/A")
                    reward = logs.get("reward", logs.get("mean_reward", "N/A"))
                    step = progress["step"]
                    # Format each metric only when it is numeric: either may
                    # independently be the "N/A" placeholder, and formatting
                    # a str with :.4f raises (the old single f-string crashed
                    # whenever loss was a float but reward was "N/A").
                    if isinstance(loss, float) and isinstance(reward, float):
                        add_log(f"Step {step}: loss={loss:.4f}, reward={reward:.4f}")
                    elif isinstance(loss, float):
                        add_log(f"Step {step}: loss={loss:.4f}, reward={reward}")
                    else:
                        add_log(f"Step {step}: {logs}")

        add_log("Creating GRPO trainer with perf_takehome rewards...")
        output_dir = os.path.join(PERSIST_DIR, "grpo_perf_takehome_output")
        os.makedirs(output_dir, exist_ok=True)

        add_log("[OK] Trainer config ready")
        add_log("Starting training loop...")
        add_log("(Stops early if target reached; can auto-continue in chunks)")

        chunk_idx = 0
        while True:
            # Stop conditions checked between chunks: user stop, target hit,
            # step budget exhausted, or wall-clock budget exhausted.
            with state_lock:
                if training_state["should_stop"]:
                    break
                if training_state["best_cycles"] <= TARGET_CYCLES:
                    break

            if progress["step"] >= total_target_steps:
                break
            if (time.time() - start_time) >= max_seconds:
                break

            remaining = total_target_steps - progress["step"]
            this_chunk_steps = min(int(chunk_steps), int(remaining))
            if this_chunk_steps <= 0:
                break

            chunk_idx += 1
            add_log(f"Chunk {chunk_idx}: training {this_chunk_steps} steps...")

            config = GRPOConfig(
                output_dir=output_dir,
                num_train_epochs=1,
                max_steps=this_chunk_steps,
                per_device_train_batch_size=1,
                gradient_accumulation_steps=4,
                learning_rate=1e-5,
                logging_steps=1,
                # Effectively disables trainer checkpointing; we save manually.
                save_steps=999999,
                report_to="none",
                remove_unused_columns=False,
                max_completion_length=2048,
                num_generations=4,
            )

            trainer_kwargs = {
                "model": model,
                "args": config,
                "train_dataset": dataset,
                "reward_funcs": perf_takehome_reward_fn,
                "processing_class": tokenizer,
                "callbacks": [VLIWCallback()],
            }
            if not resume_adapter:
                # Fresh run: let the trainer wrap the base model with LoRA.
                trainer_kwargs["peft_config"] = lora_config

            trainer = GRPOTrainer(**trainer_kwargs)

            train_result = trainer.train()
            metrics = train_result.metrics
            add_log(f"Chunk {chunk_idx} done: steps={metrics.get('train_steps', this_chunk_steps)}")

            # Persist after every chunk so progress survives Space restarts.
            try:
                os.makedirs(os.path.dirname(ADAPTER_DIR), exist_ok=True)
                trainer.save_model(ADAPTER_DIR)
                add_log(f"[OK] Saved adapter to {ADAPTER_DIR}")
                _try_upload_adapter(add_log)
            except Exception as e:
                add_log(f"[ERR] Failed to save adapter: {str(e)[:120]}")

            if not auto_continue:
                break

        # Smoke-test the trained model with one sampled generation.
        add_log("Testing trained model...")
        inputs = tokenizer(PERF_TAKEHOME_PROMPT, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1024,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )
        result = tokenizer.decode(outputs[0], skip_special_tokens=True)

        code = extract_code_block(result)
        verify_out = verify_perf_takehome_code(code)
        if verify_out.get("correctness", 0.0) > 0:
            cycles = verify_out.get("cycles")
            add_log(f"Generated kernel verified: {cycles:,} cycles")
            speedup = BASELINE_CYCLES / max(int(cycles), 1) if isinstance(cycles, int) else 0.0
            add_log(f"Speedup: {speedup:.2f}x over baseline")
        else:
            add_log(f"Generated kernel invalid: {verify_out.get('msg', '')[:160]}")

        add_log("\n[OK] All done!")

    except Exception as e:
        import traceback
        add_log(f"[ERR] Error: {e}")
        add_log(traceback.format_exc()[:800])
    finally:
        with state_lock:
            training_state["is_training"] = False
        try:
            # Free GPU memory for a subsequent run. `model` may be unbound
            # when loading failed; that NameError is caught below (the old
            # bare `except:` also swallowed KeyboardInterrupt/SystemExit).
            del model
            torch.cuda.empty_cache()
        except Exception:
            pass

    return "\n".join(log)
| |
|
| |
|
def start_training(model_name, chunk_steps, max_total_steps, max_minutes, auto_continue):
    """Launch run_training on a daemon thread, unless one is already active.

    Marks training as started (and seeds the log) under the lock BEFORE
    spawning the thread, so a rapid second click cannot start a second run.

    Returns:
        A status string for the UI log box.
    """
    with state_lock:
        if training_state["is_training"]:
            # Already running: surface the tail of the current run's log.
            return "\n".join(training_state["log"][-200:]) or "Training already in progress. Please wait."
        training_state["is_training"] = True
        training_state["should_stop"] = False
        training_state["log"] = [f"[{time.strftime('%H:%M:%S')}] Starting training..."]
        training_state["step"] = 0

    # Daemon thread: it must not block Space shutdown.
    worker_args = (
        model_name,
        int(chunk_steps),
        int(max_total_steps),
        float(max_minutes),
        bool(auto_continue),
    )
    worker = threading.Thread(target=run_training, args=worker_args, daemon=True)
    worker.start()
    return "Training started. Logs will stream below."
| |
|
| |
|
def stop_training():
    """Ask the training loop to stop cooperatively after its current step."""
    with state_lock:
        if not training_state["is_training"]:
            return "No training in progress"
        # The trainer callback polls this flag at every step end.
        training_state["should_stop"] = True
    return "Stop requested. Training will stop after current step."
| |
|
| |
|
| | |
# --- Gradio UI: status panel, training controls, and a polled log box. ---
with gr.Blocks(title="VLIW Optimizer") as demo:
    gr.Markdown("# VLIW Kernel Optimizer - RL Training")
    gr.Markdown(f"""
    Train a language model with reinforcement learning (LoRA) at test time to generate correct, fast VLIW/SIMD kernels.

    **Goal:** Reduce cycle count from **{BASELINE_CYCLES:,}** (baseline) to **<{TARGET_CYCLES:,}** (108x speedup)

    **How it works:**
    1. Model generates Python kernel builder code
    2. Simulator checks correctness vs reference and measures cycles
    3. GRPO updates LoRA weights; adapter is saved and reloaded from `{ADAPTER_DIR}`
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Startup diagnostics collected at import time (read-only).
            status_box = gr.Textbox(
                label="System Status",
                value=get_status(),
                lines=12,
                interactive=False,
            )

        with gr.Column(scale=2):
            # Training hyperparameters passed straight to start_training().
            model_dropdown = gr.Dropdown(
                choices=[
                    "Qwen/Qwen2.5-Coder-1.5B-Instruct",
                    "Qwen/Qwen2.5-Coder-3B-Instruct",
                ],
                value="Qwen/Qwen2.5-Coder-1.5B-Instruct",
                label="Model",
            )
            chunk_steps_slider = gr.Slider(
                minimum=5,
                maximum=100,
                value=20,
                step=5,
                label="Chunk Steps",
            )
            auto_continue_checkbox = gr.Checkbox(
                value=False,
                label="Auto-continue (chain chunks)",
            )
            # The two budgets below only apply when auto-continue is on.
            max_total_steps_slider = gr.Slider(
                minimum=5,
                maximum=500,
                value=100,
                step=5,
                label="Max Total Steps",
            )
            max_minutes_number = gr.Number(
                value=60,
                precision=0,
                label="Max Minutes",
            )

    with gr.Row():
        start_btn = gr.Button("Start Training", variant="primary")
        stop_btn = gr.Button("Stop", variant="stop")

    output_box = gr.Textbox(
        label="Training Log",
        lines=25,
        interactive=False,
        value="Click 'Start Training' to begin VLIW optimization.",
    )

    def poll_log():
        """Return the last 200 shared log lines (each truncated to 400 chars)."""
        with state_lock:
            if not training_state["log"]:
                return ""
            lines = training_state["log"][-200:]
            return "\n".join(line[:400] for line in lines)

    # queue=False: these are quick control actions, not generation jobs.
    start_btn.click(
        start_training,
        [model_dropdown, chunk_steps_slider, max_total_steps_slider, max_minutes_number, auto_continue_checkbox],
        [output_box],
        queue=False,
    )
    stop_btn.click(stop_training, [], [output_box], queue=False)
    refresh_btn = gr.Button("Refresh Log")
    refresh_btn.click(poll_log, outputs=[output_box], queue=False)
| |
|
if __name__ == "__main__":
    # 0.0.0.0:7860 is the standard binding for HF Spaces containers.
    demo.launch(server_name="0.0.0.0", server_port=7860)
| |
|