Spaces: Running on A10G
| """ | |
| HF Spaces app for VLIW kernel optimization via RL. | |
| Uses actual simulator for correctness-gated cycle-count rewards. | |
| """ | |
import os
import random
import re
import shutil
import sys
import threading
import time
from copy import copy
from pathlib import Path

import gradio as gr
| # Check imports at startup | |
startup_log = []


def check_import(name, import_fn):
    """Probe one dependency, recording the outcome in ``startup_log``.

    Args:
        name: display name used in the log line.
        import_fn: zero-arg callable that imports the module and returns a
            short detail string (typically its version).

    Returns:
        True when the probe succeeded, False otherwise (the error text,
        truncated to 80 chars, is kept for the UI status box).
    """
    try:
        detail = import_fn()
    except Exception as exc:
        startup_log.append(f"[ERR] {name}: {str(exc)[:80]}")
        return False
    startup_log.append(f"[OK] {name}: {detail}")
    return True
| check_import("torch", lambda: __import__("torch").__version__) | |
| check_import("transformers", lambda: __import__("transformers").__version__) | |
| check_import("datasets", lambda: __import__("datasets").__version__) | |
| check_import("peft", lambda: __import__("peft").__version__) | |
| check_import("trl", lambda: __import__("trl").__version__) | |
| check_import("huggingface_hub", lambda: __import__("huggingface_hub").__version__) | |
| try: | |
| from trl import GRPOConfig, GRPOTrainer | |
| startup_log.append("[OK] GRPOTrainer: OK") | |
| except Exception as e: | |
| startup_log.append(f"[ERR] GRPOTrainer: {e}") | |
| try: | |
| import torch | |
| if torch.cuda.is_available(): | |
| startup_log.append(f"[OK] CUDA: {torch.cuda.get_device_name(0)}") | |
| else: | |
| startup_log.append("[ERR] CUDA: Not available") | |
| except Exception as e: | |
| startup_log.append(f"[ERR] CUDA check: {e}") | |
# Prefer simulator + KernelBuilder from bundled original_performance_takehome.
# In Spaces, this keeps evaluation consistent and enables correctness checks.
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
PERF_TAKEHOME_PATH = os.path.join(THIS_DIR, "original_performance_takehome")
if os.path.isdir(PERF_TAKEHOME_PATH):
    sys.path.insert(0, PERF_TAKEHOME_PATH)
# Import simulator components. If unavailable the app still starts, but
# verification (and therefore meaningful rewards) is disabled.
try:
    from problem import (
        Machine, Tree, Input, DebugInfo,
        build_mem_image, reference_kernel2,
        SLOT_LIMITS, VLEN, N_CORES, SCRATCH_SIZE, CoreState
    )
    from perf_takehome import KernelBuilder, HASH_STAGES
    startup_log.append("[OK] VLIW Simulator: OK")
    SIMULATOR_AVAILABLE = True
except Exception as e:
    startup_log.append(f"[ERR] VLIW Simulator: {e}")
    SIMULATOR_AVAILABLE = False
# Hugging Face Hub adapter persistence via dataset repo
try:
    from huggingface_hub import HfApi, snapshot_download
    startup_log.append("[OK] huggingface_hub: OK")
    HF_HUB_AVAILABLE = True
except Exception as e:
    startup_log.append(f"[ERR] huggingface_hub: {str(e)[:80]}")
    HF_HUB_AVAILABLE = False
# Constants
BASELINE_CYCLES = 147734  # cycle count of the unoptimized baseline (quoted in the prompt)
TARGET_CYCLES = 1363      # goal; training stops early once a correct kernel beats this
SCORE_SCALE = 3000.0      # reward numerator: score = SCORE_SCALE / cycles
# /data is the persistent disk on HF Spaces (when enabled); fall back to CWD.
PERSIST_DIR = "/data" if os.path.isdir("/data") else "."
ADAPTER_DIR = os.path.join(PERSIST_DIR, "adapters", "perf_takehome_latest")
# Dataset repo used to persist the LoRA adapter across Space restarts.
ADAPTER_DATASET_REPO = os.environ.get("ADAPTER_DATASET_REPO", "CreativeEngineer/vliw-optimizer-adapters")
ADAPTER_DATASET_SUBDIR = os.environ.get("ADAPTER_DATASET_SUBDIR", "perf_takehome_latest")
# Training state shared between the background training thread and the UI.
# Every access must hold state_lock.
training_state = {
    "is_training": False,   # True while run_training() is active
    "should_stop": False,   # set by stop_training() to request early exit
    "log": [],              # timestamped log lines streamed to the UI
    "best_cycles": BASELINE_CYCLES,  # lowest verified cycle count so far
    "best_code": None,      # code of the best verified kernel
    "step": 0,              # optimizer steps completed
}
state_lock = threading.Lock()
# Per-seed cache of evaluation contexts built by _get_eval_context().
_eval_context = {}
def get_status():
    """Render the startup diagnostics as one newline-joined string for the UI."""
    status_lines = list(startup_log)
    return "\n".join(status_lines)
def extract_code_block(text: str) -> str:
    """Pull the most relevant fenced code block out of model output.

    Search order: last closed ```python fence, last closed generic fence,
    then an unclosed ```python / ``` fence (common when generation is
    truncated), finally the whole text stripped.
    """
    # Closed fences first: python-tagged, then generic.
    for fence_pattern in (r"```python\s*(.*?)```", r"```\s*(.*?)```"):
        found = re.findall(fence_pattern, text, re.DOTALL)
        if found:
            return found[-1].strip()
    # Handle unclosed fences (common when generation truncates).
    for marker in ("```python", "```"):
        if marker in text:
            remainder = text.split(marker, 1)[1]
            if "```" in remainder:
                remainder = remainder.split("```", 1)[0]
            return remainder.strip()
    return text.strip()
def _hf_token() -> str | None:
    """Return the HF auth token from the environment, if configured.

    HF_TOKEN takes precedence; a falsy value falls through to
    HUGGINGFACE_HUB_TOKEN (matching ``a or b`` semantics).
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        token = os.environ.get("HUGGINGFACE_HUB_TOKEN")
    return token
def _ensure_dir(path: str) -> None:
    """Create *path* (and any missing parents); no-op if it already exists."""
    os.makedirs(path, exist_ok=True)
def _adapter_exists(path: str) -> bool:
    """True if *path* contains a saved PEFT adapter (adapter_config.json)."""
    marker = os.path.join(path, "adapter_config.json")
    return os.path.exists(marker)
def _try_download_adapter(add_log) -> None:
    """Restore the latest LoRA adapter from the HF dataset repo into ADAPTER_DIR.

    Best-effort: any failure (missing repo, no network, bad auth) is logged
    and swallowed so training can start from scratch.

    Args:
        add_log: callable taking one message string; used for status logging.
    """
    if not HF_HUB_AVAILABLE:
        add_log("[ERR] Hub sync disabled: huggingface_hub not available")
        return
    _ensure_dir(os.path.dirname(ADAPTER_DIR))
    allow = [f"{ADAPTER_DATASET_SUBDIR}/**"]
    try:
        snapshot_download(
            repo_id=ADAPTER_DATASET_REPO,
            repo_type="dataset",
            allow_patterns=allow,
            local_dir=os.path.dirname(ADAPTER_DIR),
            local_dir_use_symlinks=False,
            token=_hf_token(),
        )
        downloaded = os.path.join(os.path.dirname(ADAPTER_DIR), ADAPTER_DATASET_SUBDIR)
        if _adapter_exists(downloaded):
            if downloaded != ADAPTER_DIR:
                # Overwrite ADAPTER_DIR with the downloaded snapshot, preserving
                # the relative layout. shutil.copy2 replaces the hand-rolled
                # read/write loop and also keeps file metadata.
                _ensure_dir(ADAPTER_DIR)
                for root, _, files in os.walk(downloaded):
                    rel = os.path.relpath(root, downloaded)
                    dst_root = ADAPTER_DIR if rel == "." else os.path.join(ADAPTER_DIR, rel)
                    _ensure_dir(dst_root)
                    for name in files:
                        shutil.copy2(os.path.join(root, name), os.path.join(dst_root, name))
            add_log(f"[OK] Downloaded adapter from dataset: {ADAPTER_DATASET_REPO}/{ADAPTER_DATASET_SUBDIR}")
        else:
            add_log("ℹ No adapter found in dataset yet")
    except Exception as e:
        add_log(f"ℹ Adapter download skipped: {str(e)[:160]}")
def _try_upload_adapter(add_log) -> None:
    """Push the locally-saved LoRA adapter folder to the HF dataset repo.

    Best-effort: skipped (with a log line) when the hub client, a saved
    adapter, or an auth token is missing; upload errors are swallowed.
    """
    if not HF_HUB_AVAILABLE:
        add_log("[ERR] Hub sync disabled: huggingface_hub not available")
        return
    if not _adapter_exists(ADAPTER_DIR):
        add_log("ℹ No adapter to upload yet")
        return
    token = _hf_token()
    if token is None:
        add_log("ℹ No HF token set (HF_TOKEN/HUGGINGFACE_HUB_TOKEN); skipping upload")
        return
    try:
        hub = HfApi(token=token)
        # Idempotent: create the dataset repo on first use, then sync the folder.
        hub.create_repo(repo_id=ADAPTER_DATASET_REPO, repo_type="dataset", exist_ok=True)
        hub.upload_folder(
            repo_id=ADAPTER_DATASET_REPO,
            repo_type="dataset",
            folder_path=ADAPTER_DIR,
            path_in_repo=ADAPTER_DATASET_SUBDIR,
            commit_message="Update perf_takehome adapter",
        )
        add_log(f"[OK] Uploaded adapter to dataset: {ADAPTER_DATASET_REPO}/{ADAPTER_DATASET_SUBDIR}")
    except Exception as e:
        add_log(f"ℹ Adapter upload skipped: {str(e)[:160]}")
def _run_machine_with_cycle_limit(machine: Machine, max_cycles: int) -> bool:
    """Drive the simulator until every core stops, bounding ``machine.cycle``.

    Returns True when all cores halt within *max_cycles*; returns False (and
    force-stops every core) when the cycle budget is exhausted.
    """
    # Wake any cores the kernel builder left paused.
    for core in machine.cores:
        if core.state == CoreState.PAUSED:
            core.state = CoreState.RUNNING
    program = machine.program
    while any(c.state == CoreState.RUNNING for c in machine.cores):
        counted_cycle = False
        for core in machine.cores:
            # Re-check state each iteration: a step may alter other cores.
            if core.state != CoreState.RUNNING:
                continue
            if core.pc >= len(program):
                core.state = CoreState.STOPPED
                continue
            instr = program[core.pc]
            core.pc += 1
            machine.step(instr, core)
            # Bundles containing only "debug" ops do not cost a cycle.
            if any(name != "debug" for name in instr.keys()):
                counted_cycle = True
        if counted_cycle:
            machine.cycle += 1
            if machine.cycle >= max_cycles:
                for core in machine.cores:
                    core.state = CoreState.STOPPED
                return False
    return True
def _get_eval_context(seed: int) -> dict:
    """Return the cached evaluation context for *seed*, computing it on miss.

    The context bundles the randomly generated forest/input, the initial
    memory image, and the expected output slice produced by the reference
    kernel -- everything verify_perf_takehome_code() needs.

    NOTE(review): on a cache miss the lock is released during the slow
    computation, so two threads may compute the same context concurrently;
    generation is deterministic per seed, so the duplicate write is harmless.
    NOTE(review): random.seed() mutates the process-global RNG state.
    """
    with state_lock:
        cached = _eval_context.get(seed)
        if cached is not None:
            return cached
    # Deterministic problem generation: same seed -> same forest/input.
    random.seed(seed)
    forest = Tree.generate(10)
    inp = Input.generate(forest, 256, 16)
    mem0 = build_mem_image(forest, inp)
    # Run the generator-style reference kernel to completion; the final
    # yielded memory image is the ground truth.
    ref_mem = None
    for ref_mem in reference_kernel2(list(mem0)):
        pass
    if ref_mem is None:
        raise RuntimeError("Reference kernel produced no output")
    # mem[6] holds the pointer to the output values region (see prompt).
    inp_values_p = ref_mem[6]
    expected = ref_mem[inp_values_p : inp_values_p + len(inp.values)]
    ctx = {
        "forest": forest,
        "inp": inp,
        "mem0": mem0,
        "expected": expected,
        "inp_values_p": inp_values_p,
    }
    with state_lock:
        _eval_context[seed] = ctx
    return ctx
def _verify_failure(msg: str, cycles: int | None = None) -> dict:
    """Build the shared zero-score result dict used by every failure path."""
    return {
        "score": 0.0,
        "correctness": 0.0,
        "cycles": cycles,
        "msg": msg,
    }


def verify_perf_takehome_code(code: str, seed: int = 123) -> dict:
    """Execute candidate kernel-builder code and score it in the simulator.

    Pipeline: sanity-check the code text, exec it with restricted builtins,
    build and run the kernel under a cycle limit, then compare the output
    memory region against the reference kernel's output.

    Returns:
        dict with keys "score" (SCORE_SCALE/cycles on success, else 0.0),
        "correctness" (1.0 only on a verified-correct run), "cycles" (int
        cycle count when the simulator ran, else None), and "msg".
    """
    if not SIMULATOR_AVAILABLE:
        return _verify_failure("Simulator unavailable")
    try:
        code = code.strip()
        if not code:
            return _verify_failure("Empty code")
        if "OptimizedKernelBuilder" not in code:
            return _verify_failure("Missing OptimizedKernelBuilder")
        if "def run" not in code:
            return _verify_failure("Missing run()")
        # NOTE(security): exec() of model-generated code. The restricted
        # builtins below prevent casual misuse but are NOT a sandbox
        # (attribute traversal can reach the real builtins); acceptable here
        # only because completions come from our own model inside this Space.
        safe_builtins = {
            "abs": abs,
            "all": all,
            "any": any,
            "dict": dict,
            "enumerate": enumerate,
            "int": int,
            "len": len,
            "list": list,
            "max": max,
            "min": min,
            "range": range,
            "sum": sum,
            "tuple": tuple,
            "zip": zip,
        }
        exec_globals = {
            "__builtins__": safe_builtins,
            "KernelBuilder": KernelBuilder,
            "HASH_STAGES": HASH_STAGES,
            "VLEN": VLEN,
            "SLOT_LIMITS": SLOT_LIMITS,
        }
        exec(code, exec_globals)
        if "OptimizedKernelBuilder" not in exec_globals:
            return _verify_failure("OptimizedKernelBuilder not defined after exec")
        ctx = _get_eval_context(seed)
        forest = ctx["forest"]
        inp = ctx["inp"]
        mem0 = ctx["mem0"]
        kb = exec_globals["OptimizedKernelBuilder"]()
        kb.build_kernel(10, len(forest.values), 256, 16)
        machine = Machine(
            list(mem0),  # fresh copy: the machine mutates memory in place
            kb.instrs,
            kb.debug_info(),
            n_cores=N_CORES,
            trace=False,
        )
        machine.enable_pause = False
        machine.enable_debug = False
        ok = _run_machine_with_cycle_limit(machine, max_cycles=250000)
        if not ok:
            return _verify_failure(
                f"Exceeded cycle limit (cycles={machine.cycle})", int(machine.cycle)
            )
        cycles = machine.cycle
        # Reject degenerate kernels that halt immediately without real work.
        if cycles <= 100:
            return _verify_failure(f"Suspiciously low cycles ({cycles})", int(cycles))
        if cycles > 200000:
            return _verify_failure(f"Cycles too high ({cycles})", int(cycles))
        inp_values_p = ctx["inp_values_p"]
        expected = ctx["expected"]
        actual = machine.mem[inp_values_p : inp_values_p + len(inp.values)]
        if expected != actual:
            return _verify_failure(f"Incorrect output (cycles={cycles})", int(cycles))
        score = SCORE_SCALE / cycles
        return {
            "score": float(score),
            "correctness": 1.0,
            "cycles": int(cycles),
            "msg": f"Success: {cycles} cycles",
        }
    except Exception as e:
        return _verify_failure(f"Execution error: {str(e)[:200]}")
def perf_takehome_reward_fn(completions, prompts=None, **kwargs):
    """Reward function for GRPOTrainer: correctness-gated cycle-count reward.

    Invalid or incorrect kernels score 0.0; a verified-correct kernel earns
    1.0 + SCORE_SCALE / cycles. Also tracks the best (lowest-cycle) correct
    kernel in the shared training_state.
    """
    scores = []
    for item in completions:
        # Chat-style completions arrive as [{"role": ..., "content": ...}].
        if isinstance(item, list):
            text = item[0].get("content", "") if item else ""
        else:
            text = str(item)
        candidate = extract_code_block(text)
        outcome = verify_perf_takehome_code(candidate)
        value = 0.0
        if outcome.get("correctness", 0.0) > 0:
            value = float(outcome["score"]) + 1.0
        cycle_count = outcome.get("cycles")
        with state_lock:
            if isinstance(cycle_count, int) and cycle_count < training_state["best_cycles"]:
                training_state["best_cycles"] = cycle_count
                training_state["best_code"] = candidate
        scores.append(float(value))
    return scores
# Prompt template for VLIW optimization.
# NOTE: this is an f-string evaluated once at import time -- {{...}} renders
# as literal braces in the prompt text, while {BASELINE_CYCLES:,} and
# {TARGET_CYCLES:,} are interpolated here.
PERF_TAKEHOME_PROMPT = f"""Write an optimized VLIW/SIMD kernel. OUTPUT ONLY ONE ```python CODE BLOCK.
ARCHITECTURE: 12 ALU + 6 VALU (VLEN=8) + 2 load + 2 store + 1 flow slots per cycle. 1536-word scratch.
API (KernelBuilder):
- alloc_scratch(name, length) -> addr
- scratch_const(val, name) -> addr
- add(engine, slot): engine in {{alu, valu, load, store, flow}}
- alu: (op, dst, src1, src2) where op in {{+,-,*,//,%,^,&,|,<<,>>,<,==,!=,<=,>=,>}}
- valu: same ops but on vectors (VLEN=8)
- load: (load,dst,addr), (vload,dst,addr), (const,dst,val), (vbroadcast,dst,scalar_addr)
- store: (store,addr,src), (vstore,addr,src)
- flow: (select,dst,cond,t,f), (vselect,dst,cond,t,f), (cond_jump,cond,pc), (jump,pc), (halt,)
- label(name): mark code position
- build(slots, vliw=True): pack slots into VLIW bundle
MEMORY: mem[4]=forest_values, mem[5]=inp_indices, mem[6]=inp_values (256 elements each)
ALGORITHM: 16 rounds x 256 items:
load idx,val
node = tree[idx]
val = hash(val ^ node) using HASH_STAGES
idx = 2*idx + (1 if val%2==0 else 2)
idx = 0 if idx >= n_nodes else idx
store idx,val
RULES:
- Output exactly one python code block.
- The code block must define:
- class OptimizedKernelBuilder(KernelBuilder): override build_kernel() and emit instructions using add()/build()
- def run(): return any tuple (ignored), but must exist
- No imports.
Baseline: {BASELINE_CYCLES:,} cycles. Target: <{TARGET_CYCLES:,} cycles.
"""
def run_training(model_name, chunk_steps, max_total_steps, max_minutes, auto_continue):
    """Run GRPO + LoRA training with correctness-gated perf_takehome rewards.

    Args:
        model_name: HF model id to fine-tune (loaded 4-bit quantized).
        chunk_steps: optimizer steps per training chunk; the adapter is saved
            (and uploaded) after each chunk.
        max_total_steps: overall step budget, used when auto_continue is on.
        max_minutes: wall-clock budget in minutes, used when auto_continue is on.
        auto_continue: when True, chain chunks until the cycle target, step
            budget, or time budget is reached.

    Returns:
        The full training log as a newline-joined string.
    """
    import torch
    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import LoraConfig
    from peft import PeftModel
    from trl import GRPOConfig, GRPOTrainer
    from transformers import TrainerCallback
    log = []

    def add_log(msg):
        # Append a timestamped line and mirror the log into the shared state
        # so the UI's poll_log() can stream it.
        log.append(f"[{time.strftime('%H:%M:%S')}] {msg}")
        with state_lock:
            training_state["log"] = log.copy()

    # Reset shared state for a fresh run.
    with state_lock:
        training_state["is_training"] = True
        training_state["should_stop"] = False
        training_state["log"] = []
        training_state["best_cycles"] = BASELINE_CYCLES
        training_state["best_code"] = None
        training_state["step"] = 0
    try:
        add_log(f"Starting VLIW optimization training")
        add_log(f"Model: {model_name}")
        add_log(f"Chunk steps: {chunk_steps}")
        add_log(f"Auto-continue: {auto_continue} (max_total_steps={max_total_steps}, max_minutes={max_minutes})")
        add_log(f"Baseline: {BASELINE_CYCLES:,} cycles, Target: {TARGET_CYCLES:,} cycles")
        add_log(f"Adapter dir: {ADAPTER_DIR}")
        add_log(f"Adapter dataset: {ADAPTER_DATASET_REPO}/{ADAPTER_DATASET_SUBDIR}")
        # Load tokenizer
        add_log("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        add_log("[OK] Tokenizer ready")
        # Load model with 4-bit (NF4) quantization to fit GPU memory.
        add_log("Loading model (4-bit quantization)...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
        )
        add_log(f"[OK] Base model loaded on {next(base_model.parameters()).device}")
        # Try to restore adapter from the dataset repo before loading it.
        _try_download_adapter(add_log)
        # Resume LoRA adapter if present; otherwise train a fresh one.
        resume_adapter = False
        if os.path.isdir(ADAPTER_DIR) and os.path.exists(os.path.join(ADAPTER_DIR, "adapter_config.json")):
            add_log("Loading existing LoRA adapter (resume)...")
            model = PeftModel.from_pretrained(base_model, ADAPTER_DIR, is_trainable=True)
            add_log("[OK] Adapter loaded")
            resume_adapter = True
        else:
            model = base_model
        # Create dataset with prompts (GRPO samples multiple generations per prompt).
        add_log("Creating VLIW optimization dataset...")
        prompts = [PERF_TAKEHOME_PROMPT] * 16
        dataset = Dataset.from_dict({"prompt": prompts})
        add_log(f"[OK] Dataset ready: {len(prompts)} prompts")
        # LoRA config (only passed to the trainer when starting a fresh adapter).
        add_log("Setting up LoRA...")
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        progress = {"step": 0}
        start_time = time.time()
        max_seconds = float(max_minutes) * 60.0 if auto_continue else float("inf")
        total_target_steps = int(max_total_steps) if auto_continue else int(chunk_steps)

        # Custom callback: progress tracking, stop requests, target early-stop.
        class VLIWCallback(TrainerCallback):
            def on_step_end(self, args, state, control, **kwargs):
                with state_lock:
                    progress["step"] += 1
                    training_state["step"] = progress["step"]
                    if training_state["should_stop"]:
                        control.should_training_stop = True
                    if training_state["best_cycles"] <= TARGET_CYCLES:
                        control.should_training_stop = True
                return control

            def on_log(self, args, state, control, logs=None, **kwargs):
                if logs:
                    loss = logs.get("loss", "N/A")
                    reward = logs.get("reward", logs.get("mean_reward", "N/A"))
                    step = progress["step"]
                    # BUGFIX: only use the formatted message when BOTH values
                    # are numeric. The old code applied ``:.4f`` to `reward`
                    # whenever `loss` was a float, crashing on reward="N/A".
                    if isinstance(loss, float) and isinstance(reward, float):
                        add_log(f"Step {step}: loss={loss:.4f}, reward={reward:.4f}")
                    else:
                        add_log(f"Step {step}: {logs}")

        add_log("Creating GRPO trainer with perf_takehome rewards...")
        output_dir = os.path.join(PERSIST_DIR, "grpo_perf_takehome_output")
        os.makedirs(output_dir, exist_ok=True)
        add_log("[OK] Trainer config ready")
        add_log("Starting training loop...")
        add_log("(Stops early if target reached; can auto-continue in chunks)")
        chunk_idx = 0
        while True:
            # Evaluate all stop conditions before starting a new chunk.
            with state_lock:
                if training_state["should_stop"]:
                    break
                if training_state["best_cycles"] <= TARGET_CYCLES:
                    break
            if progress["step"] >= total_target_steps:
                break
            if (time.time() - start_time) >= max_seconds:
                break
            remaining = total_target_steps - progress["step"]
            this_chunk_steps = min(int(chunk_steps), int(remaining))
            if this_chunk_steps <= 0:
                break
            chunk_idx += 1
            add_log(f"Chunk {chunk_idx}: training {this_chunk_steps} steps...")
            config = GRPOConfig(
                output_dir=output_dir,
                num_train_epochs=1,
                max_steps=this_chunk_steps,
                per_device_train_batch_size=1,
                gradient_accumulation_steps=4,
                learning_rate=1e-5,
                logging_steps=1,
                save_steps=999999,  # effectively disables checkpointing; we save manually
                report_to="none",
                remove_unused_columns=False,
                max_completion_length=2048,
                num_generations=4,
            )
            trainer_kwargs = {
                "model": model,
                "args": config,
                "train_dataset": dataset,
                "reward_funcs": perf_takehome_reward_fn,
                "processing_class": tokenizer,
                "callbacks": [VLIWCallback()],
            }
            if not resume_adapter:
                trainer_kwargs["peft_config"] = lora_config
            trainer = GRPOTrainer(**trainer_kwargs)
            train_result = trainer.train()
            # BUGFIX: keep the trainer's (PEFT-wrapped) model for the next
            # chunk. Previously `model` stayed the bare base model, so each
            # auto-continue chunk passed peft_config again and re-wrapped the
            # base model with a FRESH adapter, discarding the previous
            # chunk's LoRA updates.
            model = trainer.model
            resume_adapter = True
            metrics = train_result.metrics
            add_log(f"Chunk {chunk_idx} done: steps={metrics.get('train_steps', this_chunk_steps)}")
            # Save adapter after each chunk so it persists across restarts
            try:
                os.makedirs(os.path.dirname(ADAPTER_DIR), exist_ok=True)
                trainer.save_model(ADAPTER_DIR)
                add_log(f"[OK] Saved adapter to {ADAPTER_DIR}")
                _try_upload_adapter(add_log)
            except Exception as e:
                add_log(f"[ERR] Failed to save adapter: {str(e)[:120]}")
            if not auto_continue:
                break
        # Smoke-test the trained model with one sampled generation.
        add_log("Testing trained model...")
        inputs = tokenizer(PERF_TAKEHOME_PROMPT, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1024,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )
        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
        code = extract_code_block(result)
        verify_out = verify_perf_takehome_code(code)
        if verify_out.get("correctness", 0.0) > 0:
            cycles = verify_out.get("cycles")
            add_log(f"Generated kernel verified: {cycles:,} cycles")
            speedup = BASELINE_CYCLES / max(int(cycles), 1) if isinstance(cycles, int) else 0.0
            add_log(f"Speedup: {speedup:.2f}x over baseline")
        else:
            add_log(f"Generated kernel invalid: {verify_out.get('msg', '')[:160]}")
        add_log("\n[OK] All done!")
    except Exception as e:
        import traceback
        add_log(f"[ERR] Error: {e}")
        add_log(traceback.format_exc()[:800])
    finally:
        with state_lock:
            training_state["is_training"] = False
        # Free GPU memory; `model` may be unbound if loading failed early.
        try:
            del model
            torch.cuda.empty_cache()
        except Exception:
            pass
    return "\n".join(log)
def start_training(model_name, chunk_steps, max_total_steps, max_minutes, auto_continue):
    """Launch run_training() on a daemon thread unless a run is already live."""
    with state_lock:
        if training_state["is_training"]:
            # Already running: show the current log tail instead of restarting.
            return "\n".join(training_state["log"][-200:]) or "Training already in progress. Please wait."
        training_state["is_training"] = True
        training_state["should_stop"] = False
        training_state["log"] = [f"[{time.strftime('%H:%M:%S')}] Starting training..."]
        training_state["step"] = 0
    worker = threading.Thread(
        target=run_training,
        args=(
            model_name,
            int(chunk_steps),
            int(max_total_steps),
            float(max_minutes),
            bool(auto_continue),
        ),
        daemon=True,  # don't block process shutdown
    )
    worker.start()
    return "Training started. Logs will stream below."
def stop_training():
    """Ask the training loop to stop after the current step."""
    with state_lock:
        running = training_state["is_training"]
        if running:
            training_state["should_stop"] = True
    if not running:
        return "No training in progress"
    return "Stop requested. Training will stop after current step."
# Gradio UI: one page with startup status, hyperparameters, and a log box
# refreshed on demand via poll_log().
with gr.Blocks(title="VLIW Optimizer") as demo:
    gr.Markdown("# VLIW Kernel Optimizer - RL Training")
    gr.Markdown(f"""
Train a language model with reinforcement learning (LoRA) at test time to generate correct, fast VLIW/SIMD kernels.
**Goal:** Reduce cycle count from **{BASELINE_CYCLES:,}** (baseline) to **<{TARGET_CYCLES:,}** (108x speedup)
**How it works:**
1. Model generates Python kernel builder code
2. Simulator checks correctness vs reference and measures cycles
3. GRPO updates LoRA weights; adapter is saved and reloaded from `{ADAPTER_DIR}`
""")
    with gr.Row():
        with gr.Column(scale=1):
            # Read-only dump of the startup import/CUDA checks.
            status_box = gr.Textbox(
                label="System Status",
                value=get_status(),
                lines=12,
                interactive=False,
            )
        with gr.Column(scale=2):
            model_dropdown = gr.Dropdown(
                choices=[
                    "Qwen/Qwen2.5-Coder-1.5B-Instruct",
                    "Qwen/Qwen2.5-Coder-3B-Instruct",
                ],
                value="Qwen/Qwen2.5-Coder-1.5B-Instruct",
                label="Model",
            )
            # Steps per chunk; the adapter is saved after every chunk.
            chunk_steps_slider = gr.Slider(
                minimum=5,
                maximum=100,
                value=20,
                step=5,
                label="Chunk Steps",
            )
            auto_continue_checkbox = gr.Checkbox(
                value=False,
                label="Auto-continue (chain chunks)",
            )
            # The two budgets below apply only when auto-continue is on.
            max_total_steps_slider = gr.Slider(
                minimum=5,
                maximum=500,
                value=100,
                step=5,
                label="Max Total Steps",
            )
            max_minutes_number = gr.Number(
                value=60,
                precision=0,
                label="Max Minutes",
            )
    with gr.Row():
        start_btn = gr.Button("Start Training", variant="primary")
        stop_btn = gr.Button("Stop", variant="stop")
    output_box = gr.Textbox(
        label="Training Log",
        lines=25,
        interactive=False,
        value="Click 'Start Training' to begin VLIW optimization.",
    )
    def poll_log():
        """Return the last 200 shared log lines, each truncated to 400 chars."""
        with state_lock:
            if not training_state["log"]:
                return ""
            lines = training_state["log"][-200:]
            return "\n".join(line[:400] for line in lines)
    # queue=False so clicks respond immediately; training itself runs on the
    # background thread started by start_training().
    start_btn.click(
        start_training,
        [model_dropdown, chunk_steps_slider, max_total_steps_slider, max_minutes_number, auto_continue_checkbox],
        [output_box],
        queue=False,
    )
    stop_btn.click(stop_training, [], [output_box], queue=False)
    refresh_btn = gr.Button("Refresh Log")
    refresh_btn.click(poll_log, outputs=[output_box], queue=False)
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)