| """ |
| TD Lang Compiler - turns a TDProgram AST into readable Python code that calls td_fuse. |
| |
| Phase 1 commands: load, merge, heal, eval, commit. |
| Phase 2 commands: synth, train, debate, diagnose. |
| Phase 3 commands: fork, reset, prune, edit. |
| Phase 4 commands: snapshot, report. Blocks: data_contract, reward_contract. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import hashlib |
| import textwrap |
| from datetime import datetime |
| from typing import List, Optional, Set |
|
|
| from .ast_nodes import ( |
| AbsorbCmd, |
| BudgetBlock, |
| CommitCmd, |
| DataContractBlock, |
| DebateCmd, |
| DiagnoseCmd, |
| EditCmd, |
| EvalCmd, |
| FuseCmd, |
| ForkCmd, |
| IfBlock, |
| GateBlock, |
| HealCmd, |
| LoadCmd, |
| MergeCmd, |
| NotifyCmd, |
| OnErrorBlock, |
| PruneCmd, |
| RepeatBlock, |
| ReportCmd, |
| ResetCmd, |
| RewardContractBlock, |
| SaveCmd, |
| ScheduleCmd, |
| DownloadCmd, |
| LogBlock, |
| CompareCmd, |
| VerifyCmd, |
| VoteCmd, |
| PromptBlock, |
| DistillCmd, |
| RollbackCmd, |
| CurriculumCmd, |
| StarCmd, |
| BestOfCmd, |
| ExploitCmd, |
| ArenaCmd, |
| ResearchArenaCmd, |
| SetupBlock, |
| SnapshotCmd, |
| SynthCmd, |
| TDProgram, |
| TrainCmd, |
| ) |
| from .errors import TDCompileError |
|
|
| |
|
|
|
|
| class TDCompiler: |
| """Compile a TDProgram into a Python script string.""" |
|
|
| GPU_HOURLY = 4.0 |
|
|
| def __init__(self) -> None: |
| self._aliases: Set[str] = set() |
| self._lines: List[str] = [] |
| self._indent: int = 0 |
|
|
| |
| def compile(self, program: TDProgram) -> str: |
| """Compile a TDProgram into Python code.""" |
| self._reset_state() |
| self._validate(program) |
| self._build_script(program) |
| return "\n".join(self._lines) |
|
|
| |
| def _reset_state(self) -> None: |
| self._aliases.clear() |
| self._lines = [] |
| self._indent = 0 |
|
|
| def _validate(self, program: TDProgram) -> None: |
| """Semantic validation before emitting code.""" |
| seen: Set[str] = set() |
| for cmd in program.commands: |
| if isinstance(cmd, LoadCmd): |
| if cmd.alias in seen: |
| raise TDCompileError( |
| f"Alias '{cmd.alias}' is already used. Pick a different name.", |
| ) |
| seen.add(cmd.alias) |
| elif isinstance(cmd, MergeCmd): |
| if cmd.target not in seen: |
| raise TDCompileError( |
| f"Can't merge into '{cmd.target}' - it hasn't been loaded yet.", |
| hint=f'Add: load "{cmd.source}" as {cmd.target}', |
| ) |
| elif isinstance(cmd, (HealCmd, EvalCmd, CommitCmd)): |
| if cmd.target not in seen: |
| raise TDCompileError( |
| f"Can't use '{cmd.target}' - it hasn't been loaded yet.", |
| hint=f'Add: load "model/path" as {cmd.target}', |
| ) |
| elif isinstance(cmd, (SynthCmd, TrainCmd, DebateCmd, DiagnoseCmd)): |
| if cmd.target not in seen: |
| raise TDCompileError( |
| f"Can't use '{cmd.target}' - it hasn't been loaded yet.", |
| hint=f'Add: load "model/path" as {cmd.target}', |
| ) |
| elif isinstance(cmd, ForkCmd): |
| if cmd.source not in seen: |
| raise TDCompileError( |
| f"Can't fork '{cmd.source}' - it hasn't been loaded yet.", |
| hint=f'Add: load "model/path" as {cmd.source}', |
| ) |
| if cmd.alias in seen: |
| raise TDCompileError( |
| f"Alias '{cmd.alias}' is already used. Pick a different name for the fork.", |
| ) |
| seen.add(cmd.alias) |
| elif isinstance(cmd, (ResetCmd, PruneCmd, EditCmd)): |
| if cmd.target not in seen: |
| raise TDCompileError( |
| f"Can't use '{cmd.target}' - it hasn't been loaded yet.", |
| hint=f'Add: load "model/path" as {cmd.target}', |
| ) |
| elif isinstance(cmd, SnapshotCmd): |
| if cmd.target not in seen: |
| raise TDCompileError( |
| f"Can't snapshot '{cmd.target}' - it hasn't been loaded yet.", |
| hint=f'Add: load "model/path" as {cmd.target}', |
| ) |
| elif isinstance(cmd, ReportCmd): |
| pass |
| elif isinstance(cmd, FuseCmd): |
| if cmd.target not in seen: |
| raise TDCompileError( |
| f"Can't fuse into '{cmd.target}' - it hasn't been loaded yet.", |
| hint=f'Add: load "model/path" as {cmd.target}', |
| ) |
| if len(cmd.sources) < 1: |
| raise TDCompileError( |
| "Fuse needs at least 1 model in the list.", |
| hint='fuse ["model1", "model2"] into target', |
| ) |
| elif isinstance(cmd, AbsorbCmd): |
| if cmd.target not in seen: |
| raise TDCompileError( |
| f"Can't absorb into '{cmd.target}' - it hasn't been loaded yet.", |
| hint=f'Add: load "model/path" as {cmd.target}', |
| ) |
| elif isinstance(cmd, (RepeatBlock, IfBlock, ScheduleCmd)): |
| pass |
| elif isinstance(cmd, (NotifyCmd, SaveCmd, DownloadCmd)): |
| pass |
| elif isinstance(cmd, (CompareCmd, VerifyCmd)): |
| if cmd.target not in seen: |
| raise TDCompileError( |
| f"Can't use '{cmd.target}' - it hasn't been loaded yet.", |
| hint=f'Add: load "model/path" as {cmd.target}', |
| ) |
| elif isinstance(cmd, (VoteCmd, PromptBlock, RollbackCmd)): |
| if cmd.target not in seen: |
| raise TDCompileError( |
| f"Can't use '{cmd.target}' - it hasn't been loaded yet.", |
| hint=f'Add: load "model/path" as {cmd.target}', |
| ) |
| elif isinstance(cmd, DistillCmd): |
| if cmd.teacher not in seen: |
| raise TDCompileError( |
| f"Can't distill from '{cmd.teacher}' - it hasn't been loaded yet.", |
| hint=f'Add: load "model/path" as {cmd.teacher}', |
| ) |
| elif isinstance(cmd, (CurriculumCmd, StarCmd, BestOfCmd, ExploitCmd, ArenaCmd, ResearchArenaCmd)): |
| if cmd.target not in seen: |
| raise TDCompileError( |
| f"Can't use '{cmd.target}' - it hasn't been loaded yet.", |
| hint=f'Add: load "model/path" as {cmd.target}', |
| ) |
|
|
| |
| def _build_script(self, program: TDProgram) -> None: |
| """Construct the full Python script lines.""" |
| self._emit("#!/usr/bin/env python3") |
| source_hash = hashlib.sha256(str(program).encode()).hexdigest()[:12] |
| source_name = program.source_file or "unknown.td" |
| timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") |
| doc = textwrap.dedent( |
| f'''""" |
| Auto-generated by td_lang v0.1.0 |
| Source: {source_name} |
| Compiled: {timestamp} |
| Hash: {source_hash} |
| |
| DO NOT EDIT - regenerate from the .td file instead. |
| """''' |
| ) |
| self._emit(doc) |
| self._emit("import json") |
| self._emit("import os") |
| self._emit("import sys") |
| self._emit("import time") |
| self._emit("from datetime import datetime") |
| self._emit("from pathlib import Path") |
| self._emit("") |
| self._emit("from td_fuse.config import MergeConfig, SOURCES, TARGET") |
| self._emit("from td_fuse.merge import run_pipeline") |
| self._emit("from td_fuse.heal import heal_model") |
| self._emit("from td_fuse.validate import validate_merged_model") |
| self._emit("") |
| self._emit("from td_lang.errors import TDBudgetError, TDGateError") |
| self._emit("") |
| self._emit(f"GPU_HOURLY = {self.GPU_HOURLY}") |
| self._emit("") |
| self._emit("") |
| self._emit("def main():") |
| self._indent += 1 |
| self._emit("import os # safety: prevent UnboundLocalError if shadowed") |
| self._emit("start_time = time.time()") |
| self._emit("lineage = {}") |
| self._emit("models = {}") |
| self._emit("results = {}") |
| self._emit("merged_stages = []") |
| self._emit("output_dir = str(Path('.').resolve())") |
| self._emit("") |
| self._emit("# Quick canary check helper (lightweight sanity)") |
| self._emit("def quick_canary(checkpoint: str) -> float:") |
| self._indent += 1 |
| self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer") |
| self._emit("import torch") |
| self._emit("prompts = [") |
| self._indent += 1 |
| self._emit('"What is 2+2?",') |
| self._emit('"Spell the word apple.",') |
| self._emit('"Name a color that starts with B.",') |
| self._emit('"List two prime numbers.",') |
| self._emit('"What is the capital of France?",') |
| self._indent -= 1 |
| self._emit("]") |
| self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)") |
| self._emit("model = _load_model_smart(checkpoint, torch_dtype=torch.float16, device_map='auto')") |
| self._emit("model.eval()") |
| self._emit("scores = []") |
| self._emit("for p in prompts:") |
| self._indent += 1 |
| self._emit("messages = [{'role': 'user', 'content': p}]") |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)") |
| self._emit("inputs = tok(text, return_tensors='pt').to(model.device)") |
| self._indent -= 1 |
| self._emit("except Exception:") |
| self._indent += 1 |
| self._emit("inputs = tok(p, return_tensors='pt').to(model.device)") |
| self._indent -= 1 |
| self._emit("with torch.no_grad():") |
| self._indent += 1 |
| self._emit("out = model.generate(**inputs, max_new_tokens=32, do_sample=False)") |
| self._indent -= 1 |
| self._emit("new_tokens = out[0][inputs['input_ids'].shape[1]:]") |
| self._emit("resp = tok.decode(new_tokens, skip_special_tokens=True)") |
| self._emit("scores.append(len(resp))") |
| self._indent -= 1 |
| self._emit("avg_len = sum(scores) / len(scores)") |
| self._emit("del model, tok") |
| self._emit("import gc; gc.collect()") |
| self._emit("return avg_len") |
| self._indent -= 1 |
| self._emit("") |
|
|
| |
| self._emit("def _load_model_smart(checkpoint, **kwargs):") |
| self._indent += 1 |
| self._emit('"""Load model — auto-detects Qwen3-VL and uses the correct class."""') |
| self._emit("from transformers import AutoConfig") |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)") |
| self._emit("model_type = getattr(config, 'model_type', '')") |
| self._emit("config_class = type(config).__name__.lower()") |
| self._emit("if 'qwen3_vl' in model_type or 'qwen3vl' in config_class:") |
| self._indent += 1 |
| self._emit("from transformers import Qwen3VLForConditionalGeneration") |
| self._emit("print(f'[td_lang] Loading as Qwen3-VL model: {checkpoint}')") |
| self._emit("return Qwen3VLForConditionalGeneration.from_pretrained(checkpoint, **kwargs)") |
| self._indent -= 1 |
| self._indent -= 1 |
| self._emit("except Exception as e:") |
| self._indent += 1 |
| self._emit("print(f'[td_lang] Auto-detect failed ({e}), using AutoModelForCausalLM')") |
| self._indent -= 1 |
| self._emit("from transformers import AutoModelForCausalLM") |
| self._emit("return AutoModelForCausalLM.from_pretrained(checkpoint, **kwargs)") |
| self._indent -= 1 |
| self._emit("") |
|
|
| if program.setup: |
| self._emit_setup(program.setup) |
|
|
| if program.log: |
| self._emit_log_setup(program.log) |
|
|
| if program.on_error: |
| self._emit_on_error(program.on_error, program) |
|
|
| if program.budget: |
| self._emit_budget_check(program) |
|
|
| if program.data_contract: |
| self._emit_data_contract(program.data_contract) |
|
|
| if program.reward_contract: |
| self._emit_reward_contract(program.reward_contract) |
|
|
| for index, cmd in enumerate(program.commands, start=1): |
| self._emit_comment(f"Step {index}: {type(cmd).__name__}") |
| if isinstance(cmd, LoadCmd): |
| self._emit_load(cmd) |
| elif isinstance(cmd, MergeCmd): |
| self._emit_merge(cmd) |
| elif isinstance(cmd, HealCmd): |
| self._emit_heal(cmd) |
| elif isinstance(cmd, EvalCmd): |
| self._emit_eval(cmd) |
| elif isinstance(cmd, CommitCmd): |
| self._emit_commit(cmd, program.gates) |
| elif isinstance(cmd, DiagnoseCmd): |
| self._emit_diagnose(cmd) |
| elif isinstance(cmd, SynthCmd): |
| self._emit_synth(cmd) |
| elif isinstance(cmd, TrainCmd): |
| self._emit_train(cmd, program) |
| elif isinstance(cmd, DebateCmd): |
| self._emit_debate(cmd) |
| elif isinstance(cmd, EditCmd): |
| self._emit_edit(cmd) |
| elif isinstance(cmd, ForkCmd): |
| self._emit_fork(cmd) |
| elif isinstance(cmd, ResetCmd): |
| self._emit_reset(cmd) |
| elif isinstance(cmd, PruneCmd): |
| self._emit_prune(cmd) |
| elif isinstance(cmd, FuseCmd): |
| self._emit_fuse(cmd) |
| elif isinstance(cmd, AbsorbCmd): |
| self._emit_absorb(cmd) |
| elif isinstance(cmd, RepeatBlock): |
| self._emit_repeat(cmd, program) |
| elif isinstance(cmd, IfBlock): |
| self._emit_if(cmd, program) |
| elif isinstance(cmd, SnapshotCmd): |
| self._emit_snapshot(cmd, program) |
| elif isinstance(cmd, ReportCmd): |
| self._emit_report(cmd, program) |
| elif isinstance(cmd, NotifyCmd): |
| self._emit_notify(cmd, program) |
| elif isinstance(cmd, SaveCmd): |
| self._emit_save(cmd, program) |
| elif isinstance(cmd, ScheduleCmd): |
| self._emit_schedule(cmd, program) |
| elif isinstance(cmd, DownloadCmd): |
| self._emit_download(cmd) |
| elif isinstance(cmd, CompareCmd): |
| self._emit_compare(cmd) |
| elif isinstance(cmd, VerifyCmd): |
| self._emit_verify(cmd) |
| elif isinstance(cmd, VoteCmd): |
| self._emit_vote(cmd) |
| elif isinstance(cmd, PromptBlock): |
| self._emit_prompt(cmd) |
| elif isinstance(cmd, DistillCmd): |
| self._emit_distill(cmd) |
| elif isinstance(cmd, RollbackCmd): |
| self._emit_rollback(cmd) |
| elif isinstance(cmd, CurriculumCmd): |
| self._emit_curriculum(cmd, program) |
| elif isinstance(cmd, StarCmd): |
| self._emit_star(cmd, program) |
| elif isinstance(cmd, BestOfCmd): |
| self._emit_best_of(cmd, program) |
| elif isinstance(cmd, ExploitCmd): |
| self._emit_exploit(cmd, program) |
| elif isinstance(cmd, ArenaCmd): |
| self._emit_arena(cmd, program) |
| elif isinstance(cmd, ResearchArenaCmd): |
| self._emit_research_arena(cmd, program) |
| self._emit("") |
|
|
| self._emit_summary() |
| self._indent -= 1 |
| self._emit("") |
| self._emit('if __name__ == "__main__":') |
| self._indent += 1 |
| self._emit("main()") |
| self._indent -= 1 |
|
|
| |
| def _emit_load(self, cmd: LoadCmd) -> None: |
| self._aliases.add(cmd.alias) |
| self._emit(f'print("[td_lang] Loading {cmd.alias} from {cmd.model_ref}...")') |
| self._emit("") |
|
|
| |
| self._emit(f'_model_ref = "{cmd.model_ref}"') |
| self._emit("if '/' in _model_ref and not os.path.exists(_model_ref):") |
| self._indent += 1 |
| self._emit(f'print("[td_lang] Downloading from HuggingFace: {cmd.model_ref}")') |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("from huggingface_hub import snapshot_download") |
| self._emit(f'_local_path = snapshot_download(_model_ref, local_dir=f"models/{cmd.alias}")') |
| self._emit(f'print(f"[td_lang] Downloaded to {{_local_path}}")') |
| self._indent -= 1 |
| self._emit("except ImportError:") |
| self._indent += 1 |
| self._emit('print("[td_lang] huggingface_hub not installed. Storing ref only - download will happen at merge time.")') |
| self._emit("_local_path = _model_ref") |
| self._indent -= 1 |
| self._emit("except Exception as e:") |
| self._indent += 1 |
| self._emit('print(f"[td_lang] Download warning: {e}. Storing ref for later.")') |
| self._emit("_local_path = _model_ref") |
| self._indent -= 1 |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit("_local_path = _model_ref") |
| self._indent -= 1 |
| self._emit("") |
|
|
| self._emit(f'models["{cmd.alias}"] = {{') |
| self._indent += 1 |
| self._emit(f'"model_ref": "{cmd.model_ref}",') |
| self._emit('"local_path": _local_path,') |
| self._emit('"checkpoint": None,') |
| self._emit('"loaded_at": datetime.now().isoformat(),') |
| self._indent -= 1 |
| self._emit("}") |
| self._emit(f'lineage["{cmd.alias}"] = {{"source": "{cmd.model_ref}", "operations": []}}') |
| self._emit(f'print("[td_lang] {cmd.alias} ready.")') |
|
|
| def _emit_merge(self, cmd: MergeCmd) -> None: |
| self._emit( |
| f'print("[td_lang] Merging {cmd.source} into {cmd.target} using {cmd.method} (strength={cmd.strength})...")' |
| ) |
| self._emit(f'_source_ref = "{cmd.source}"') |
| self._emit("_stage = None") |
| self._emit("for _src in SOURCES:") |
| self._indent += 1 |
| self._emit('if _src.hf_id == _source_ref or _src.name.lower() in _source_ref.lower():') |
| self._indent += 1 |
| self._emit('_stage = _src.name.lower().split("-")[0]') |
| self._emit(f"_src.merge_alpha = {cmd.strength}") |
| self._emit("break") |
| self._indent -= 1 |
| self._indent -= 1 |
| self._emit("if _stage is None:") |
| self._indent += 1 |
| self._emit('raise SystemExit(f"Could not match source {_source_ref} to any SOURCES entry.")') |
| self._indent -= 1 |
| self._emit("") |
| self._emit("# Skip merge if checkpoint already exists (Bug #27 - saves ~12 min)") |
| self._emit('_merge_ckpt = Path(f"td_fuse_checkpoints/after_{_stage}")') |
| self._emit("if _merge_ckpt.exists() and (_merge_ckpt / 'model.safetensors').exists():") |
| self._indent += 1 |
| self._emit('print(f"[td_lang] Found merge checkpoint {_merge_ckpt} - SKIPPING merge")') |
| self._emit('merge_result = {"status": "skipped", "final_checkpoint": str(_merge_ckpt)}') |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit("# Stack merges: pass previous checkpoint so MiMo builds on DeepSeek, etc.") |
| self._emit(f'_prev_ckpt = models.get("{cmd.target}", {{}}).get("checkpoint")') |
| self._emit("cfg = MergeConfig()") |
| self._emit("merge_result = run_pipeline([_stage], cfg, base_checkpoint=_prev_ckpt)") |
| self._indent -= 1 |
| self._emit(f'results["{cmd.target}_merge"] = merge_result') |
| self._emit("merged_stages.append(_stage)") |
| self._emit('if merge_result.get("final_checkpoint"):') |
| self._indent += 1 |
| self._emit(f'models["{cmd.target}"]["checkpoint"] = merge_result["final_checkpoint"]') |
| self._indent -= 1 |
| self._emit(f'lineage["{cmd.target}"]["operations"].append({{') |
| self._indent += 1 |
| self._emit('"op": "merge",') |
| self._emit('"source": _source_ref,') |
| self._emit(f'"method": "{cmd.method}",') |
| self._emit(f'"strength": {cmd.strength},') |
| self._emit('"timestamp": datetime.now().isoformat(),') |
| self._emit('"stage": _stage,') |
| self._indent -= 1 |
| self._emit("})") |
| self._emit('print("[td_lang] Merge complete.")') |
|
|
| def _emit_heal(self, cmd: HealCmd) -> None: |
| self._emit(f'print("[td_lang] Healing {cmd.target} (lora_r={cmd.lora_r}, epochs={cmd.epochs})...")') |
| self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")') |
| self._emit("if not checkpoint:") |
| self._indent += 1 |
| self._emit('print("[td_lang] WARNING: No checkpoint to heal - run a merge first.")') |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit(f"cfg = MergeConfig(heal_lora_r={cmd.lora_r}, heal_epochs={cmd.epochs})") |
| self._emit("healed_path = heal_model(checkpoint, cfg)") |
| self._emit(f'models["{cmd.target}"]["checkpoint"] = healed_path') |
| self._emit(f'lineage["{cmd.target}"]["operations"].append({{') |
| self._indent += 1 |
| self._emit('"op": "heal",') |
| self._emit(f'"lora_r": {cmd.lora_r},') |
| self._emit(f'"epochs": {cmd.epochs},') |
| self._emit('"timestamp": datetime.now().isoformat(),') |
| self._indent -= 1 |
| self._emit("})") |
| self._emit('print("[td_lang] Heal complete.")') |
| self._indent -= 1 |
|
|
| def _emit_eval(self, cmd: EvalCmd) -> None: |
| """Generate self-contained evaluation - math, code, reasoning, perplexity. |
| |
| No dependency on td_fuse. Tests the model on real tasks and returns |
| pass/fail plus scores per category. Uses 'improved' flag to track |
| whether the model got better vs previous eval. |
| """ |
| self._emit(f'print("[td_lang] Evaluating {cmd.target}...")') |
| self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")') |
| self._emit("if not checkpoint:") |
| self._indent += 1 |
| self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]') |
| self._indent -= 1 |
| self._emit("from transformers import AutoTokenizer") |
| self._emit("import torch, re, ast") |
| self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)") |
| self._emit("model = _load_model_smart(checkpoint, torch_dtype=torch.bfloat16, device_map='auto')") |
| self._emit("model.eval()") |
| self._emit("") |
| self._emit("# Mini-benchmark: math, code, reasoning, perplexity") |
| self._emit("eval_tests = {") |
| self._indent += 1 |
| self._emit('"math": [') |
| self._indent += 1 |
| self._emit('{"prompt": "What is 17 * 23? Answer with just the number.", "answer": "391"},') |
| self._emit('{"prompt": "What is 144 / 12? Answer with just the number.", "answer": "12"},') |
| self._emit('{"prompt": "What is 256 + 789? Answer with just the number.", "answer": "1045"},') |
| self._emit('{"prompt": "What is 15 squared? Answer with just the number.", "answer": "225"},') |
| self._emit('{"prompt": "What is the square root of 81? Answer with just the number.", "answer": "9"},') |
| self._indent -= 1 |
| self._emit("],") |
| self._emit('"code": [') |
| self._indent += 1 |
| self._emit('{"prompt": "Write a Python function that returns the sum of a list. Just the function, nothing else.", "check": "def"},') |
| self._emit('{"prompt": "Write a Python function to check if a number is prime. Just the function.", "check": "def"},') |
| self._emit('{"prompt": "Write a Python one-liner list comprehension that squares numbers 1-10.", "check": "["},') |
| self._indent -= 1 |
| self._emit("],") |
| self._emit('"reasoning": [') |
| self._indent += 1 |
| self._emit('{"prompt": "If all dogs are animals, and all animals breathe, do all dogs breathe? Answer yes or no.", "answer": "yes"},') |
| self._emit('{"prompt": "A bat and ball cost $1.10 together. The bat costs $1 more than the ball. How much does the ball cost? Answer with just the number.", "answer": "0.05"},') |
| self._emit('{"prompt": "If it takes 5 machines 5 minutes to make 5 widgets, how long would it take 100 machines to make 100 widgets? Answer in minutes.", "answer": "5"},') |
| self._indent -= 1 |
| self._emit("],") |
| self._indent -= 1 |
| self._emit("}") |
| self._emit("") |
| self._emit("eval_result = {'overall': True, 'scores': {}, 'details': {}}") |
| self._emit("total_correct = 0") |
| self._emit("total_tests = 0") |
| self._emit("") |
| self._emit("for category, tests in eval_tests.items():") |
| self._indent += 1 |
| self._emit("cat_correct = 0") |
| self._emit("cat_details = []") |
| self._emit("for test in tests:") |
| self._indent += 1 |
| self._emit("total_tests += 1") |
| self._emit('inputs = tok(test["prompt"], return_tensors="pt").to(model.device)') |
| self._emit("with torch.no_grad():") |
| self._indent += 1 |
| self._emit("output = model.generate(**inputs, max_new_tokens=256, do_sample=False, temperature=0.0)") |
| self._indent -= 1 |
| self._emit("response = tok.decode(output[0], skip_special_tokens=True)") |
| self._emit('# Strip the prompt from the response if model echoes it') |
| self._emit('if response.startswith(test["prompt"]):') |
| self._indent += 1 |
| self._emit('response = response[len(test["prompt"]):].strip()') |
| self._indent -= 1 |
| self._emit("passed = False") |
| self._emit('if "answer" in test:') |
| self._indent += 1 |
| self._emit('passed = test["answer"].lower() in response.lower()') |
| self._indent -= 1 |
| self._emit('elif "check" in test:') |
| self._indent += 1 |
| self._emit('passed = test["check"] in response') |
| self._emit("# Also try to parse as valid Python") |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("ast.parse(response)") |
| self._indent -= 1 |
| self._emit("except SyntaxError:") |
| self._indent += 1 |
| self._emit("passed = False # Code doesn't compile") |
| self._indent -= 2 |
| self._emit("if passed:") |
| self._indent += 1 |
| self._emit("cat_correct += 1") |
| self._emit("total_correct += 1") |
| self._indent -= 1 |
| self._emit('cat_details.append({"prompt": test["prompt"][:60], "passed": passed})') |
| self._indent -= 1 |
| self._emit("score = cat_correct / max(len(tests), 1)") |
| self._emit('eval_result["scores"][category] = round(score, 3)') |
| self._emit('eval_result["details"][category] = cat_details') |
| self._emit('print(f" {category}: {cat_correct}/{len(tests)} ({score:.0%})")') |
| self._indent -= 1 |
| self._emit("") |
| self._emit("# Perplexity test (lower = model is more confident/coherent)") |
| self._emit('ppl_text = "The capital of France is Paris. Water boils at 100 degrees Celsius."') |
| self._emit('ppl_inputs = tok(ppl_text, return_tensors="pt").to(model.device)') |
| self._emit("with torch.no_grad():") |
| self._indent += 1 |
| self._emit('ppl_loss = model(**ppl_inputs, labels=ppl_inputs["input_ids"]).loss') |
| self._indent -= 1 |
| self._emit("perplexity = torch.exp(ppl_loss).item()") |
| self._emit('eval_result["perplexity"] = round(perplexity, 2)') |
| self._emit('eval_result["scores"]["perplexity"] = "pass" if perplexity < 20.0 else "fail"') |
| self._emit('_ppl_label = "pass" if perplexity < 20.0 else "FAIL - too high"') |
| self._emit('print(f" perplexity: {perplexity:.2f} ({_ppl_label})")') |
| self._emit("") |
| self._emit("# Overall score") |
| self._emit("overall_score = total_correct / max(total_tests, 1)") |
| self._emit('eval_result["overall_score"] = round(overall_score, 3)') |
| self._emit('eval_result["overall"] = overall_score >= 0.5 and perplexity < 20.0') |
| self._emit('_overall_label = "PASS" if eval_result["overall"] else "FAIL"') |
| self._emit('print(f" OVERALL: {total_correct}/{total_tests} ({overall_score:.0%}) - {_overall_label}")') |
| self._emit("") |
| self._emit("# Track improvement over previous eval") |
| self._emit(f'hist_key = "{cmd.target}_eval_history"') |
| self._emit("if hist_key not in results:") |
| self._indent += 1 |
| self._emit("results[hist_key] = []") |
| self._indent -= 1 |
| self._emit("results[hist_key].append(overall_score)") |
| self._emit('eval_result["improved"] = len(results[hist_key]) < 2 or results[hist_key][-1] >= results[hist_key][-2]') |
| self._emit(f'results["{cmd.target}_eval"] = eval_result') |
| self._emit(f'lineage["{cmd.target}"]["operations"].append({{') |
| self._indent += 1 |
| self._emit('"op": "eval",') |
| self._emit('"timestamp": datetime.now().isoformat(),') |
| self._emit('"overall_score": overall_score,') |
| self._emit('"perplexity": perplexity,') |
| self._indent -= 1 |
| self._emit("})") |
| if cmd.output: |
| self._emit(f'eval_path = Path("{cmd.output}")') |
| self._emit("eval_path.parent.mkdir(parents=True, exist_ok=True)") |
| self._emit('with open(eval_path, "w") as f:') |
| self._indent += 1 |
| self._emit("json.dump(eval_result, f, indent=2, default=str)") |
| self._indent -= 1 |
| self._emit('print(f"[td_lang] Eval results saved to {eval_path}")') |
| else: |
| self._emit('print("[td_lang] Eval results:", json.dumps(eval_result, indent=2, default=str))') |
| self._emit("del model, tok") |
| self._emit("import gc; gc.collect()") |
|
|
| def _emit_commit(self, cmd: CommitCmd, global_gates: Optional[GateBlock]) -> None: |
| gates = cmd.gates or (global_gates.must_pass if global_gates else None) |
| self._emit(f'print("[td_lang] Committing {cmd.target}...")') |
| if gates: |
| self._emit(f"gates_to_check = {gates}") |
| self._emit(f'last_eval = results.get("{cmd.target}_eval", {{}})') |
| self._emit("failed = []") |
| self._emit("for gate in gates_to_check:") |
| self._indent += 1 |
| self._emit('if gate == "overall":') |
| self._indent += 1 |
| self._emit('ok = bool(last_eval.get("overall", False))') |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit("val = last_eval.get(gate, {})") |
| self._emit("if isinstance(val, dict):") |
| self._indent += 1 |
| self._emit('ok = bool(val.get("ok", False))') |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit("ok = bool(val)") |
| self._indent -= 1 |
| self._indent -= 1 |
| self._emit("if not ok:") |
| self._indent += 1 |
| self._emit("failed.append(gate)") |
| self._indent -= 1 |
| self._indent -= 1 |
| self._emit("if failed:") |
| self._indent += 1 |
| self._emit('raise TDGateError(failed, message="Commit blocked - gates failed")') |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit('print("[td_lang] All gates passed!")') |
| self._indent -= 1 |
|
|
| self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")') |
| self._emit("if not checkpoint:") |
| self._indent += 1 |
| self._emit('print("[td_lang] WARNING: No checkpoint to commit.")') |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit('commit_dir = Path("td_lang_outputs") / "committed"') |
| self._emit("commit_dir.mkdir(parents=True, exist_ok=True)") |
| self._emit('lineage_path = commit_dir / "lineage.json"') |
| self._emit('with open(lineage_path, "w") as f:') |
| self._indent += 1 |
| self._emit("json.dump(lineage, f, indent=2, default=str)") |
| self._indent -= 1 |
| self._emit('print(f"[td_lang] Committed. Checkpoint: {checkpoint}")') |
| self._emit('print(f"[td_lang] Lineage saved to: {lineage_path}")') |
| self._indent -= 1 |
|
|
| |
|
|
| def _emit_diagnose(self, cmd: DiagnoseCmd) -> None: |
| """Generate code for: diagnose target [-> weaknesses.json] |
| |
| MEGA DIAGNOSE: Self-diagnosis + Performance profiling in one command. |
| Part 1: Asks the model to identify its own weaknesses (self-diagnosis). |
| Part 2: Tests the model on actual problems per domain (profiling). |
| Part 3: Measures per-layer inference speed to find bottleneck layers. |
| Combines all three into a single actionable report. |
| """ |
| self._emit(f'print("[td_lang] Diagnosing {cmd.target}...")') |
| self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")') |
| self._emit("if not checkpoint:") |
| self._indent += 1 |
| self._emit('print("[td_lang] WARNING: No checkpoint - using model_ref instead.")') |
| self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]') |
| self._indent -= 1 |
| self._emit("from transformers import AutoTokenizer") |
| self._emit("import torch") |
| self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)") |
| self._emit("model = _load_model_smart(checkpoint, torch_dtype=torch.bfloat16, device_map='auto')") |
| self._emit("model.eval()") |
| self._emit("") |
| self._emit("# Self-diagnosis prompts (from TD interview findings test_12)") |
| self._emit("diag_prompts = [") |
| self._indent += 1 |
| self._emit('"List your top 5 weaknesses as an AI. Be specific and honest.",') |
| self._emit('"What types of reasoning tasks do you fail at most? Give concrete examples.",') |
| self._emit('"Rate yourself 1-10 on: math, coding, long-chain logic, creativity, factual recall. Explain each score.",') |
| self._emit('"If you could improve one thing about yourself, what would have the biggest impact?",') |
| self._indent -= 1 |
| self._emit("]") |
| self._emit("diagnose_results = []") |
| self._emit("for prompt in diag_prompts:") |
| self._indent += 1 |
| self._emit("# Use chat template for proper generation (Qwen3 needs this)") |
| self._emit('messages = [{"role": "user", "content": prompt}]') |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)") |
| self._emit('inputs = tok(text, return_tensors="pt").to(model.device)') |
| self._indent -= 1 |
| self._emit("except Exception:") |
| self._indent += 1 |
| self._emit('inputs = tok(prompt, return_tensors="pt").to(model.device)') |
| self._indent -= 1 |
| self._emit("with torch.no_grad():") |
| self._indent += 1 |
| self._emit("output = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)") |
| self._indent -= 1 |
| self._emit("new_tokens = output[0][inputs['input_ids'].shape[1]:]") |
| self._emit("response = tok.decode(new_tokens, skip_special_tokens=True)") |
| self._emit('diagnose_results.append({"prompt": prompt, "response": response})') |
| self._emit('print(f" Prompt: {prompt[:50]}...")') |
| self._emit('print(f" Response: {response[:200]}...")') |
| self._emit("print()") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("# Parse responses into structured weakness categories") |
| self._emit("import re as _re") |
| self._emit("weakness_categories = {") |
| self._indent += 1 |
| self._emit("'math': ['math', 'arithmetic', 'calculation', 'algebra', 'geometry', 'calculus'],") |
| self._emit("'code': ['code', 'coding', 'programming', 'debug', 'syntax', 'algorithm'],") |
| self._emit("'logic': ['logic', 'reasoning', 'inference', 'fallac', 'deduction', 'chain'],") |
| self._emit("'factual': ['factual', 'hallucin', 'accuracy', 'knowledge', 'recall', 'memory'],") |
| self._emit("'creativity': ['creative', 'creativity', 'imagination', 'novel', 'original'],") |
| self._emit("'instruction': ['instruction', 'follow', 'format', 'comply', 'understand'],") |
| self._indent -= 1 |
| self._emit("}") |
| self._emit("") |
| self._emit("weakness_scores = {cat: 0 for cat in weakness_categories}") |
| self._emit("for d in diagnose_results:") |
| self._indent += 1 |
| self._emit("resp_lower = d['response'].lower()") |
| self._emit("for cat, keywords in weakness_categories.items():") |
| self._indent += 1 |
| self._emit("for kw in keywords:") |
| self._indent += 1 |
| self._emit("if kw in resp_lower:") |
| self._indent += 1 |
| self._emit("weakness_scores[cat] += 1") |
| self._emit("break") |
| self._indent -= 3 |
| self._indent -= 1 |
| self._emit("") |
| self._emit("# Rank weaknesses by how many prompts mentioned them") |
| self._emit("ranked = sorted(weakness_scores.items(), key=lambda x: x[1], reverse=True)") |
| self._emit("top_weaknesses = [cat for cat, score in ranked if score > 0][:4]") |
| self._emit("if not top_weaknesses:") |
| self._indent += 1 |
| self._emit("top_weaknesses = ['math', 'logic', 'code'] # safe defaults") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("diagnosis = {") |
| self._indent += 1 |
| self._emit("'raw_responses': diagnose_results,") |
| self._emit("'weakness_scores': weakness_scores,") |
| self._emit("'top_weaknesses': top_weaknesses,") |
| self._emit("'ranked': ranked,") |
| self._indent -= 1 |
| self._emit("}") |
| self._emit("print('[td_lang] Weakness ranking:')") |
| self._emit("for cat, score in ranked:") |
| self._indent += 1 |
| self._emit("if score > 0:") |
| self._indent += 1 |
| self._emit("print(f' {cat}: mentioned in {score}/{len(diag_prompts)} prompts')") |
| self._indent -= 2 |
| self._emit("print(f'[td_lang] Top weaknesses to target: {top_weaknesses}')") |
| self._emit("") |
| self._emit("") |
| self._emit("# --- Part 2: Profiling - test actual performance per domain ---") |
| self._emit('print("[td_lang] Running domain profiling...")') |
| self._emit("profile_tests = {") |
| self._indent += 1 |
| self._emit("'math': [") |
| self._indent += 1 |
| self._emit('("What is 15 * 23?", "345"),') |
| self._emit('("What is 144 / 12?", "12"),') |
| self._emit('("Solve: 2x + 5 = 17", "6"),') |
| self._indent -= 1 |
| self._emit("],") |
| self._emit("'code': [") |
| self._indent += 1 |
| self._emit('("Write a Python function that returns the factorial of n.", "def"),') |
| self._emit('("What does len([1,2,3]) return in Python?", "3"),') |
| self._emit('("Fix this: for i in range(10) print(i)", "for i in range(10):"),') |
| self._indent -= 1 |
| self._emit("],") |
| self._emit("'logic': [") |
| self._indent += 1 |
| self._emit('("If all cats are animals and all animals breathe, do cats breathe?", "yes"),') |
| self._emit('("A is taller than B. B is taller than C. Who is shortest?", "c"),') |
| self._emit('("If it rains the ground is wet. The ground is wet. Did it rain?", "not necessarily"),') |
| self._indent -= 1 |
| self._emit("],") |
| self._emit("'factual': [") |
| self._indent += 1 |
| self._emit('("What planet is closest to the Sun?", "mercury"),') |
| self._emit('("Who wrote Romeo and Juliet?", "shakespeare"),') |
| self._emit('("What is the chemical formula for water?", "h2o"),') |
| self._indent -= 1 |
| self._emit("],") |
| self._indent -= 1 |
| self._emit("}") |
| self._emit("") |
| self._emit("domain_scores = {}") |
| self._emit("for domain, tests in profile_tests.items():") |
| self._indent += 1 |
| self._emit("correct = 0") |
| self._emit("for question, expected in tests:") |
| self._indent += 1 |
| self._emit('inputs = tok(question, return_tensors="pt").to(model.device)') |
| self._emit("with torch.no_grad():") |
| self._indent += 1 |
| self._emit("out = model.generate(**inputs, max_new_tokens=128, do_sample=False)") |
| self._indent -= 1 |
| self._emit("resp = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip().lower()") |
| self._emit("if expected.lower() in resp:") |
| self._indent += 1 |
| self._emit("correct += 1") |
| self._indent -= 2 |
| self._emit("score = correct / len(tests) * 100") |
| self._emit("domain_scores[domain] = score") |
| self._emit("_score_label = 'STRONG' if score >= 67 else ('OK' if score >= 34 else 'WEAK')") |
| self._emit('print(f" {domain}: {score:.0f}% ({_score_label})")') |
| self._indent -= 1 |
| self._emit("") |
| self._emit("# --- Part 3: Layer speed profiling ---") |
| self._emit('print("[td_lang] Measuring layer speeds...")') |
| self._emit("import time as _time") |
| self._emit("n_layers = len(model.model.layers) if hasattr(model, 'model') and hasattr(model.model, 'layers') else 0") |
| self._emit("layer_times = {}") |
| self._emit("if n_layers > 0:") |
| self._indent += 1 |
| self._emit('test_input = tok("Hello world", return_tensors="pt").to(model.device)') |
| self._emit("# Warm up") |
| self._emit("with torch.no_grad():") |
| self._indent += 1 |
| self._emit("_ = model(**test_input)") |
| self._indent -= 1 |
| self._emit("# Time each layer group (every 4 layers)") |
| self._emit("_total_start = _time.perf_counter()") |
| self._emit("with torch.no_grad():") |
| self._indent += 1 |
| self._emit("_ = model(**test_input)") |
| self._indent -= 1 |
| self._emit("_total_time = _time.perf_counter() - _total_start") |
| self._emit("_per_layer = _total_time / n_layers * 1000 # ms per layer") |
| self._emit('print(f" Total inference: {_total_time*1000:.1f}ms across {n_layers} layers")') |
| self._emit('print(f" Average: {_per_layer:.2f}ms per layer")') |
| self._emit('layer_times = {"total_ms": _total_time*1000, "n_layers": n_layers, "avg_ms_per_layer": _per_layer}') |
| self._indent -= 1 |
| self._emit("") |
| self._emit("# Combine everything into mega-diagnosis") |
| self._emit("diagnosis['domain_scores'] = domain_scores") |
| self._emit("diagnosis['layer_profile'] = layer_times") |
| self._emit("diagnosis['weakest_domains'] = sorted(domain_scores.items(), key=lambda x: x[1])[:2]") |
| self._emit("") |
| self._emit("# Merge self-reported weaknesses with actual test results") |
| self._emit("print('[td_lang] === MEGA DIAGNOSIS SUMMARY ===')") |
| self._emit("print('[td_lang] Self-reported weaknesses:', top_weaknesses)") |
| self._emit("_weakest = [d for d, s in sorted(domain_scores.items(), key=lambda x: x[1])[:2]]") |
| self._emit("print(f'[td_lang] Tested weakest domains: {_weakest}')") |
| self._emit("# Combine both signals") |
| self._emit("all_weak = list(set(top_weaknesses[:2] + _weakest))") |
| self._emit("diagnosis['combined_weaknesses'] = all_weak") |
| self._emit("top_weaknesses = all_weak # update for synth to use") |
| self._emit("print(f'[td_lang] Combined training targets: {all_weak}')") |
| self._emit("") |
| self._emit(f'results["{cmd.target}_diagnose"] = diagnosis') |
| self._emit(f'lineage["{cmd.target}"]["operations"].append({{') |
| self._indent += 1 |
| self._emit('"op": "diagnose",') |
| self._emit('"n_prompts": len(diag_prompts),') |
| self._emit('"top_weaknesses": top_weaknesses,') |
| self._emit('"domain_scores": domain_scores,') |
| self._emit('"timestamp": datetime.now().isoformat(),') |
| self._indent -= 1 |
| self._emit("})") |
| if cmd.output: |
| self._emit(f'diag_path = Path("{cmd.output}")') |
| self._emit("diag_path.parent.mkdir(parents=True, exist_ok=True)") |
| self._emit('with open(diag_path, "w") as f:') |
| self._indent += 1 |
| self._emit("json.dump(diagnosis, f, indent=2, default=str)") |
| self._indent -= 1 |
| self._emit('print(f"[td_lang] Diagnosis saved to {diag_path}")') |
| self._emit("del model, tok") |
| self._emit("import gc; gc.collect()") |
| self._emit('print("[td_lang] Diagnosis complete.")') |
|
|
def _emit_synth(self, cmd: SynthCmd) -> None:
    """Generate code for: synth target from source [filter cherry_llm] [-> output.jsonl]

    The emitted script does the following:

    - Loads the target's current checkpoint, falling back to ``model_ref``.
    - Picks target domains from ``results["<target>_diagnose"]``, in this order:
      the merged ``combined_weaknesses`` list, else ``top_weaknesses``, else a
      keyword scan over raw responses, else the four core domains.
    - Asks the model to extend a built-in seed problem bank for those domains.
    - Samples ``n_samples`` prompt/response pairs from per-domain templates.
    - With ``cherry_llm``, keeps only samples whose perplexity is in (2, 50).
    - Stores the samples in ``results["<target>_synth"]``, appends a lineage
      entry, and writes them as JSONL.

    Responses hold only the newly generated tokens; the prompt is not echoed
    back. This matches how diagnose decodes, and how train joins prompt and
    response with a newline.

    Args:
        cmd: Parsed synth command.
            ``cmd.n_samples``, when present and not None, sets the sample
            count (default 100).
            ``cmd.filter_method`` selects the filter; a falsy value means
            "none".
            ``cmd.output`` overrides the default ``synth_data.jsonl`` path.
    """
    n_samples_val = getattr(cmd, "n_samples", None)
    if n_samples_val is None:
        # Absent attribute or explicit None: never emit ``range(None)``.
        n_samples_val = 100
    filter_method = cmd.filter_method or "none"
    output_path = cmd.output or "synth_data.jsonl"

    self._emit(f'print("[td_lang] Generating synthetic data for {cmd.target}...")')
    self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")')
    self._emit("if not checkpoint:")
    self._indent += 1
    self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]')
    self._indent -= 1
    self._emit("from transformers import AutoTokenizer")
    self._emit("import torch, random, re")
    self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
    self._emit("model = _load_model_smart(checkpoint, torch_dtype=torch.bfloat16, device_map='auto')")
    self._emit("model.eval()")
    self._emit("")
    self._emit("# Use structured diagnosis if available (upgraded diagnose outputs top_weaknesses)")
    self._emit(f'diag = results.get("{cmd.target}_diagnose", {{}})')
    self._emit("if isinstance(diag, dict) and 'top_weaknesses' in diag:")
    self._indent += 1
    self._emit("# Prefer the merged self-reported + measured list that diagnose stores for synth")
    self._emit("weak_topics = diag.get('combined_weaknesses') or diag['top_weaknesses']")
    self._emit("print(f'[td_lang] Targeting weaknesses from diagnosis: {weak_topics}')")
    self._indent -= 1
    self._emit("else:")
    self._indent += 1
    self._emit("# Fallback: scan raw responses for weakness keywords")
    self._emit("weak_topics = []")
    self._emit("raw = diag if isinstance(diag, list) else diag.get('raw_responses', [])")
    self._emit("for d in raw:")
    self._indent += 1
    self._emit("resp = d.get('response', '')")
    self._emit("for topic in ['math', 'code', 'logic', 'factual']:")
    self._indent += 1
    self._emit("if topic in resp.lower() and topic not in weak_topics:")
    self._indent += 1
    self._emit("weak_topics.append(topic)")
    self._indent -= 3
    self._indent -= 1
    self._emit("if not weak_topics:")
    self._indent += 1
    self._emit("weak_topics = ['math', 'code', 'logic', 'factual']")
    self._indent -= 1
    self._emit("")
    self._emit("# Domain templates")
    self._emit("domain_templates = {")
    self._indent += 1
    self._emit('"math": ["Solve this math problem step by step: {problem}",')
    self._emit(' "Find and correct the mistake in this solution: {problem}"],')
    self._emit('"code": ["Write correct, tested Python code for: {problem}",')
    self._emit(' "Find the bug and fix it: {problem}"],')
    self._emit('"logic": ["Reason carefully and avoid fallacies: {problem}",')
    self._emit(' "Provide a formal argument for: {problem}"],')
    self._emit('"factual": ["Answer with citations: {problem}",')
    self._emit(' "List 3 verified facts about: {problem}"],')
    self._emit('"creativity": ["Think of an original approach to: {problem}",')
    self._emit(' "Brainstorm 5 creative solutions for: {problem}"],')
    self._emit('"instruction": ["Follow these instructions precisely: {problem}",')
    self._emit(' "Complete this task exactly as described: {problem}"],')
    self._indent -= 1
    self._emit("}")
    self._emit("")
    self._emit("# Seed problems - model generates MORE from these (not just these 4)")
    self._emit("seed_problems = {")
    self._indent += 1
    self._emit("'math': [")
    self._indent += 1
    self._emit("'Compute (17*19 - 121) / 3',")
    self._emit("'Find the derivative of x^3 + 2x^2 - 5x + 7',")
    self._emit("'Solve for x: 3x + 7 = 22',")
    self._emit("'What is the sum of the first 20 positive integers?',")
    self._emit("'A rectangle has area 48 and perimeter 28. Find its dimensions.',")
    self._emit("'Calculate 15% of 240',")
    self._indent -= 1
    self._emit("],")
    self._emit("'code': [")
    self._indent += 1
    self._emit("'Implement binary search in Python',")
    self._emit("'Write a function to reverse a linked list',")
    self._emit("'Parse a CSV file and compute column averages',")
    self._emit("'Implement a LRU cache with O(1) get and put',")
    self._emit("'Write a function to find all permutations of a string',")
    self._emit("'Implement merge sort',")
    self._indent -= 1
    self._emit("],")
    self._emit("'logic': [")
    self._indent += 1
    self._emit("'If all A are B and all B are C, are all A C? Explain your reasoning.',")
    self._emit("'A says B is lying. B says C is lying. C says both A and B are lying. Who is telling the truth?',")
    self._emit("'Three boxes: one has gold, one has silver, one is empty. Box A says gold is in B. Box B says gold is in B. Box C says gold is not in A. Only one tells truth. Where is the gold?',")
    self._emit("'If it takes 5 machines 5 minutes to make 5 widgets, how long does it take 100 machines to make 100 widgets?',")
    self._indent -= 1
    self._emit("],")
    self._emit("'factual': [")
    self._indent += 1
    self._emit("'Explain the difference between TCP and UDP in networking',")
    self._emit("'What are the three laws of thermodynamics?',")
    self._emit("'Describe how transformers work in machine learning',")
    self._emit("'What causes tides on Earth?',")
    self._indent -= 1
    self._emit("],")
    self._emit("'creativity': [")
    self._indent += 1
    self._emit("'Design a new board game that teaches fractions to kids',")
    self._emit("'Invent a product that solves a common kitchen problem',")
    self._emit("'Write a short story where time flows backwards',")
    self._emit("'Propose 3 unconventional uses for a paperclip',")
    self._indent -= 1
    self._emit("],")
    self._emit("'instruction': [")
    self._indent += 1
    self._emit("'Write exactly 3 sentences about dogs. Each must start with a different letter.',")
    self._emit("'List the planets in order from the sun. Format each as: N. Name - one interesting fact.',")
    self._emit("'Translate this to formal English then to casual English: gonna grab some grub',")
    self._emit("'Summarize photosynthesis in exactly 25 words.',")
    self._indent -= 1
    self._emit("],")
    self._indent -= 1
    self._emit("}")
    self._emit("")
    self._emit("# Ask the model to generate MORE problems like the seeds")
    self._emit("print('[td_lang] Generating problem bank from seeds...')")
    self._emit("# Copy each seed list so appended problems never mutate seed_problems")
    self._emit("problem_bank = {d: list(p) for d, p in seed_problems.items()}")
    self._emit("for domain in weak_topics:")
    self._indent += 1
    self._emit("if domain not in seed_problems:")
    self._indent += 1
    self._emit("continue")
    self._indent -= 1
    self._emit("examples = '; '.join(seed_problems.get(domain, [])[:3])")
    self._emit("gen_prompt = f'Generate 10 diverse {domain} problems similar to: {examples}. List them numbered 1-10, one per line.'")
    self._emit('gen_inputs = tok(gen_prompt, return_tensors="pt").to(model.device)')
    self._emit("with torch.no_grad():")
    self._indent += 1
    self._emit("gen_out = model.generate(**gen_inputs, max_new_tokens=512, do_sample=True, temperature=0.9)")
    self._indent -= 1
    self._emit("# Decode only the continuation so the instruction itself is not parsed as a problem")
    self._emit("gen_text = tok.decode(gen_out[0][gen_inputs['input_ids'].shape[1]:], skip_special_tokens=True)")
    self._emit("# Parse numbered lines as new problems")
    self._emit("for line in gen_text.split(chr(10)):")
    self._indent += 1
    self._emit("line = re.sub(r'^\\d+[.)\\s]+', '', line.strip())")
    self._emit("if len(line) > 15:")
    self._indent += 1
    self._emit("problem_bank.setdefault(domain, []).append(line)")
    self._indent -= 3
    self._emit("total_problems = sum(len(v) for v in problem_bank.values())")
    self._emit("print(f'[td_lang] Problem bank: {total_problems} problems across {len(problem_bank)} domains')")
    self._emit("")
    self._emit("def make_problem(domain: str) -> str:")
    self._indent += 1
    self._emit("pool = problem_bank.get(domain, problem_bank.get('math', ['Solve 2+2']))")
    self._emit("return random.choice(pool)")
    self._indent -= 1
    self._emit("")
    self._emit("synth_data = []")
    self._emit(f"n_samples = {n_samples_val}")
    self._emit("for i in range(n_samples):")
    self._indent += 1
    self._emit("domain = random.choice(weak_topics)")
    self._emit("problem = make_problem(domain)")
    self._emit("template = random.choice(domain_templates.get(domain, ['Solve this problem step by step: {problem}']))")
    self._emit('prompt = template.format(problem=problem)')
    self._emit('inputs = tok(prompt, return_tensors="pt").to(model.device)')
    self._emit("with torch.no_grad():")
    self._indent += 1
    self._emit("output = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)")
    self._indent -= 1
    self._emit("# Keep only the generated continuation; the prompt is stored separately")
    self._emit("response = tok.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)")
    self._emit('synth_data.append({"prompt": prompt, "response": response, "domain": domain})')
    self._emit('if (i + 1) % 10 == 0:')
    self._indent += 1
    self._emit('print(f" Generated {i + 1}/{n_samples} samples...")')
    self._indent -= 2

    if filter_method == "cherry_llm":
        self._emit("")
        self._emit("# Cherry_LLM perplexity filter (test_12: prevents mode collapse)")
        self._emit("print('[td_lang] Filtering with Cherry_LLM perplexity scoring...')")
        self._emit("filtered = []")
        self._emit("for sample in synth_data:")
        self._indent += 1
        self._emit('inputs = tok(sample["response"], return_tensors="pt").to(model.device)')
        self._emit('if inputs["input_ids"].shape[1] < 2:')
        self._indent += 1
        # Empty/one-token responses cannot be scored (empty input crashes the
        # forward pass; one token yields a NaN loss), so they are dropped.
        self._emit("continue  # too short to score; drop it")
        self._indent -= 1
        self._emit("with torch.no_grad():")
        self._indent += 1
        self._emit('loss = model(**inputs, labels=inputs["input_ids"]).loss')
        self._indent -= 1
        self._emit("perplexity = torch.exp(loss).item()")
        self._emit('sample["perplexity"] = perplexity')
        self._emit("if 2.0 < perplexity < 50.0:")
        self._indent += 1
        self._emit("filtered.append(sample)")
        self._indent -= 2
        self._emit("synth_data = filtered")
        self._emit('print(f"[td_lang] Kept {len(synth_data)} samples after Cherry_LLM filter.")')

    self._emit("")
    self._emit(f'results["{cmd.target}_synth"] = synth_data')
    self._emit(f'lineage["{cmd.target}"]["operations"].append({{')
    self._indent += 1
    self._emit('"op": "synth",')
    # repr() yields a valid Python literal even if the source contains quotes.
    self._emit(f'"source": {str(cmd.source)!r},')
    self._emit(f'"filter": "{filter_method}",')
    self._emit('"n_samples": len(synth_data),')
    self._emit('"timestamp": datetime.now().isoformat(),')
    self._indent -= 1
    self._emit("})")
    # repr() keeps quotes and backslashes (e.g. Windows paths) intact.
    self._emit(f"synth_path = Path({str(output_path)!r})")
    self._emit("synth_path.parent.mkdir(parents=True, exist_ok=True)")
    self._emit('with open(synth_path, "w") as f:')
    self._indent += 1
    self._emit("for sample in synth_data:")
    self._indent += 1
    self._emit("f.write(json.dumps(sample, default=str) + chr(10))")
    self._indent -= 2
    self._emit('print(f"[td_lang] Synthetic data saved to {synth_path} ({len(synth_data)} samples)")')
    self._emit("del model, tok")
    self._emit("import gc; gc.collect()")
|
|
def _emit_train(self, cmd: TrainCmd, program: Optional[TDProgram] = None) -> None:
    """Generate code for: train target on "dataset" using method [steps N] [lr N]

    Supported methods:

    - ``grpo``: despite the name, this runs LoRA SFT with the plain HF
      ``Trainer``, because GRPOTrainer breaks with Qwen3-VL. It trains on the
      newest merge checkpoint, merges the adapters, and saves to
      ``td_lang_outputs/grpo_trained``.
    - ``sft`` / ``dpo``: QLoRA training through TRL's ``SFTTrainer`` /
      ``DPOTrainer``.

    An unknown method emits a runtime message only. No lineage entry or
    "Training complete" line is emitted, because nothing is trained.

    Args:
        cmd: Parsed train command. ``steps`` defaults to 64 and
            ``learning_rate`` to 5e-5 when unset.
        program: Unused; kept for call-site compatibility.
    """
    steps = cmd.steps or 64
    lr = cmd.learning_rate or 5e-5
    self._emit(f'print("[td_lang] Training {cmd.target} using {cmd.method} for {steps} steps...")')
    self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")')
    self._emit("if not checkpoint:")
    self._indent += 1
    self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]')
    self._indent -= 1
    self._emit("")

    if cmd.method == "grpo":
        self._emit_train_grpo(cmd, steps, lr)
    elif cmd.method in ("sft", "dpo"):
        self._emit_train_trl(cmd, steps, lr)
    else:
        # Best effort: report at runtime, but do not record a training
        # operation that never happened.
        self._emit(f'print("[td_lang] Unknown training method: {cmd.method}")')
        self._emit('print("[td_lang] Supported: grpo, sft, dpo")')
        return

    self._emit("")
    self._emit(f'lineage["{cmd.target}"]["operations"].append({{')
    self._indent += 1
    self._emit('"op": "train",')
    self._emit(f'"method": "{cmd.method}",')
    self._emit(f'"steps": {steps},')
    self._emit(f'"lr": {lr},')
    # repr() yields a valid literal even for paths containing quotes/backslashes.
    self._emit(f'"dataset": {str(cmd.dataset)!r},')
    self._emit('"timestamp": datetime.now().isoformat(),')
    self._indent -= 1
    self._emit("})")
    self._emit("import gc; gc.collect()")
    self._emit('print("[td_lang] Training complete.")')


def _emit_train_qlora_model(self, ckpt_var: str) -> None:
    """Emit tokenizer, 4-bit NF4 model and LoRA adapter setup.

    Args:
        ckpt_var: Name of the *runtime* variable holding the checkpoint path
            (e.g. ``"checkpoint"``), not the path itself.
    """
    self._emit(f"tok = AutoTokenizer.from_pretrained({ckpt_var})")
    self._emit("if tok.pad_token is None:")
    self._indent += 1
    self._emit("tok.pad_token = tok.eos_token")
    self._indent -= 1
    self._emit("")
    self._emit("bnb_config = BitsAndBytesConfig(")
    self._indent += 1
    self._emit("load_in_4bit=True,")
    self._emit('bnb_4bit_quant_type="nf4",')
    self._emit("bnb_4bit_compute_dtype=torch.bfloat16,")
    self._emit("bnb_4bit_use_double_quant=True,")
    self._indent -= 1
    self._emit(")")
    self._emit(f"model = _load_model_smart({ckpt_var}, quantization_config=bnb_config, device_map='auto')")
    self._emit("model = prepare_model_for_kbit_training(model)")
    self._emit("")
    self._emit("# LoRA adapters on all attention and MLP projections")
    self._emit("lora_config = LoraConfig(")
    self._indent += 1
    self._emit("r=32,")
    self._emit("lora_alpha=64,")
    self._emit("lora_dropout=0.05,")
    self._emit('target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],')
    self._emit('task_type="CAUSAL_LM",')
    self._indent -= 1
    self._emit(")")
    self._emit("model = get_peft_model(model, lora_config)")


def _emit_train_dataset(self, dataset: str) -> None:
    """Emit code binding ``train_data`` from a local ``.jsonl`` file or a dataset path/name."""
    self._emit(f"dataset_path = {str(dataset)!r}")
    self._emit("if dataset_path.endswith('.jsonl'):")
    self._indent += 1
    self._emit("train_data = load_dataset('json', data_files=dataset_path, split='train')")
    self._indent -= 1
    self._emit("else:")
    self._indent += 1
    self._emit("train_data = load_dataset(dataset_path, split='train')")
    self._indent -= 1


def _emit_train_grpo(self, cmd: TrainCmd, steps: int, lr: float) -> None:
    """Emit the ``grpo`` path: LoRA SFT with the plain HF ``Trainer``.

    Trains on the newest ``td_fuse_checkpoints/after_*`` dir if it holds
    safetensors weights (single-file or sharded), else on ``checkpoint``.
    After training it merges the adapters, deletes ``td_fuse_outputs/final``
    and ``td_fuse_outputs/healed`` to free disk, and saves the result to
    ``td_lang_outputs/grpo_trained``.
    """
    self._emit("# Bug #26 fix: Use SFT on merge checkpoint (same approach as healing — proven to work)")
    self._emit("# GRPOTrainer breaks with Qwen3-VL, but standard Trainer works perfectly")
    self._emit("from transformers import AutoTokenizer, TrainingArguments, BitsAndBytesConfig, Trainer")
    self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training")
    self._emit("from datasets import load_dataset, Dataset")
    self._emit("import torch")
    self._emit("")
    self._emit("# Use latest merge checkpoint — pick newest after_* dir in td_fuse_checkpoints/")
    self._emit("_merge_ckpt = None")
    self._emit("_ckpt_base = Path('td_fuse_checkpoints')")
    self._emit("if _ckpt_base.exists():")
    self._indent += 1
    self._emit("_after_dirs = sorted(_ckpt_base.glob('after_*'), key=lambda p: p.stat().st_mtime, reverse=True)")
    self._emit("# Accept both single-file and sharded safetensors checkpoints")
    self._emit("if _after_dirs and any((_after_dirs[0] / _f).exists() for _f in ('model.safetensors', 'model.safetensors.index.json')):")
    self._indent += 1
    self._emit("_merge_ckpt = str(_after_dirs[0])")
    self._indent -= 2
    self._emit("if _merge_ckpt:")
    self._indent += 1
    self._emit('print(f"[td_lang] Using merge checkpoint for training: {_merge_ckpt}")')
    self._emit("_train_ckpt = _merge_ckpt")
    self._indent -= 1
    self._emit("else:")
    self._indent += 1
    self._emit("_train_ckpt = checkpoint")
    self._emit('print(f"[td_lang] Using checkpoint for training: {_train_ckpt}")')
    self._indent -= 1
    self._emit("")
    self._emit_train_qlora_model("_train_ckpt")
    self._emit("model.print_trainable_parameters() # Shows ~1-2% trainable vs total")
    self._emit("")
    self._emit("# Load training data")
    self._emit_train_dataset(cmd.dataset)
    self._emit("")
    self._emit("# Format rows as text for SFT: a 'text' column if present, else prompt + response")
    self._emit("def _format_synth(example):")
    self._indent += 1
    self._emit("if example.get('text'):")
    self._indent += 1
    self._emit("text = example['text']")
    self._indent -= 1
    self._emit("else:")
    self._indent += 1
    self._emit("text = example['prompt'] + '\\n' + (example.get('response') or '')")
    self._indent -= 1
    self._emit("tokens = tok(text, truncation=True, max_length=512, padding='max_length')")
    self._emit("# Mask padding out of the loss (-100 is ignored); otherwise pad tokens (possibly eos) are learned")
    self._emit("_mask = tokens.get('attention_mask') or [1] * len(tokens['input_ids'])")
    self._emit("tokens['labels'] = [t if m else -100 for t, m in zip(tokens['input_ids'], _mask)]")
    self._emit("return tokens")
    self._indent -= 1
    self._emit("train_data = train_data.map(_format_synth, remove_columns=train_data.column_names)")
    self._emit("")
    self._emit("training_args = TrainingArguments(")
    self._indent += 1
    self._emit(f"max_steps={steps},")
    self._emit(f"learning_rate={lr},")
    self._emit("per_device_train_batch_size=1,")
    self._emit("gradient_accumulation_steps=8,")
    self._emit("logging_steps=10,")
    self._emit('output_dir="td_lang_outputs/sft_training",')
    self._emit("save_steps=50,")
    self._emit('bf16=True,')
    self._emit("gradient_checkpointing=True,")
    self._emit("remove_unused_columns=False,")
    self._indent -= 1
    self._emit(")")
    self._emit("")
    self._emit("trainer = Trainer(")
    self._indent += 1
    self._emit("model=model,")
    self._emit("args=training_args,")
    self._emit("train_dataset=train_data,")
    self._emit("processing_class=tok,")
    self._indent -= 1
    self._emit(")")
    self._emit("trainer.train()")
    self._emit("")
    self._emit("# Merge LoRA and save")
    self._emit("model = model.merge_and_unload()")
    self._emit("")
    self._emit("# Free disk before save")
    self._emit("import shutil, gc as _gc")
    self._emit("for _d in ['td_fuse_outputs/final', 'td_fuse_outputs/healed']:")
    self._indent += 1
    self._emit("_p = Path(_d)")
    self._emit("if _p.exists() and _p.is_dir(): shutil.rmtree(str(_p)); print(f'[td_lang] Freed: {_d}')")
    self._indent -= 1
    self._emit("_gc.collect()")
    self._emit("model.save_pretrained('td_lang_outputs/grpo_trained')")
    self._emit("tok.save_pretrained('td_lang_outputs/grpo_trained')")
    self._emit(f'models["{cmd.target}"]["checkpoint"] = "td_lang_outputs/grpo_trained"')
    self._emit("print('[td_lang] Training complete - model saved to td_lang_outputs/grpo_trained')")


def _emit_train_trl(self, cmd: TrainCmd, steps: int, lr: float) -> None:
    """Emit QLoRA SFT or DPO training through TRL.

    Trainer-specific options go on ``SFTConfig`` / ``DPOConfig``:
    ``dataset_text_field`` for SFT, and ``beta`` and ``loss_type`` for DPO.
    TRL releases that take ``processing_class`` reject these as trainer
    keyword arguments.
    """
    method_upper = cmd.method.upper()
    self._emit(f"# {method_upper} training with QLoRA (fits on 24GB 4090)")
    self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig")
    self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training")
    if cmd.method == "sft":
        self._emit("from trl import SFTTrainer, SFTConfig")
    else:
        self._emit("from trl import DPOTrainer, DPOConfig")
    self._emit("from datasets import load_dataset")
    self._emit("import torch")
    self._emit("")
    self._emit_train_qlora_model("checkpoint")
    self._emit_train_dataset(cmd.dataset)
    self._emit("")
    self._emit(f'print("[td_lang] Running {method_upper} for {steps} steps...")')

    if cmd.method == "sft":
        self._emit("# SFT reads a 'text' column; build it from prompt/response rows (synth output) if missing")
        self._emit("if 'text' not in train_data.column_names and 'prompt' in train_data.column_names:")
        self._indent += 1
        self._emit("train_data = train_data.map(")
        self._indent += 1
        self._emit("lambda ex: {'text': ex['prompt'] + '\\n' + (ex.get('response') or '')},")
        self._emit("remove_columns=train_data.column_names,")
        self._indent -= 1
        self._emit(")")
        self._indent -= 1
        self._emit("training_args = SFTConfig(")
        self._indent += 1
        self._emit('output_dir="td_lang_outputs/sft_training",')
        self._emit(f"max_steps={steps},")
        self._emit(f"learning_rate={lr},")
        self._emit("per_device_train_batch_size=2,")
        self._emit("gradient_accumulation_steps=4,")
        self._emit("logging_steps=10,")
        self._emit(f"save_steps=max(10, int({steps}/2)),")
        self._emit("bf16=True,")
        self._emit('dataset_text_field="text",')
        self._indent -= 1
        self._emit(")")
        self._emit("trainer = SFTTrainer(")
        self._indent += 1
        self._emit("model=model,")
        self._emit("processing_class=tok,")
        self._emit("args=training_args,")
        self._emit("train_dataset=train_data,")
        self._indent -= 1
        self._emit(")")
        self._emit("trainer.train()")
        self._emit('trainer.save_model("td_lang_outputs/sft_trained")')
        self._emit(f'models["{cmd.target}"]["checkpoint"] = "td_lang_outputs/sft_trained"')
        return

    self._emit("training_args = DPOConfig(")
    self._indent += 1
    self._emit(f"max_steps={steps},")
    self._emit(f"learning_rate={lr},")
    self._emit("per_device_train_batch_size=1,")
    self._emit("gradient_accumulation_steps=4,")
    self._emit("logging_steps=10,")
    self._emit('output_dir="td_lang_outputs/dpo_training",')
    self._emit("bf16=True,")
    self._emit("beta=0.1,")
    self._emit('loss_type="sigmoid",')
    self._indent -= 1
    self._emit(")")
    self._emit("trainer = DPOTrainer(")
    self._indent += 1
    self._emit("model=model,")
    self._emit("ref_model=None,")
    self._emit("train_dataset=train_data,")
    self._emit("processing_class=tok,")
    self._emit("args=training_args,")
    self._indent -= 1
    self._emit(")")
    self._emit("trainer.train()")
    self._emit('trainer.save_model("td_lang_outputs/dpo_trained")')
    self._emit(f'models["{cmd.target}"]["checkpoint"] = "td_lang_outputs/dpo_trained"')
|
|
def _emit_debate(self, cmd: DebateCmd) -> None:
    """Generate code for: debate target rounds N candidates N [-> output.jsonl]

    Weakness-aware single-model debate with structured judging.

    Each generated round picks one prompt, drawn from fixed base prompts plus
    one prompt per entry in ``results["<target>_diagnose"]``. It answers that
    prompt once per randomly sampled persona. The same model then scores the
    answers, with the question included, and returns JSON.

    Only newly generated tokens are decoded. Neither the candidates nor the
    judgment therefore echo their prompt back. In particular, JSON extraction
    can no longer latch onto the ``{id, ...}`` schema text inside the judge
    prompt.

    Args:
        cmd: Parsed ``debate`` command. Reads ``target``, ``rounds``,
            ``candidates`` and the optional ``output`` path.

    The generated code:
        * stores the rounds in ``results["<target>_debate"]``;
        * appends a ``debate`` op to the target's lineage, creating the entry
          if needed;
        * writes one JSON line per round to ``cmd.output``, defaulting to
          ``debate_pairs.jsonl``.
    """
    target = cmd.target
    rounds = cmd.rounds
    n_candidates = cmd.candidates
    output_path = cmd.output or "debate_pairs.jsonl"

    self._emit(f'print("[td_lang] Running debate: {rounds} rounds, {n_candidates} candidates...")')
    self._emit(f'checkpoint = models.get("{target}", {{}}).get("checkpoint")')
    self._emit("if not checkpoint:")
    self._indent += 1
    self._emit(f'checkpoint = models["{target}"]["model_ref"]')
    self._indent -= 1
    self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer")
    self._emit("import torch, random, json")
    self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
    self._emit("model = _load_model_smart(checkpoint, torch_dtype=torch.bfloat16, device_map='auto')")
    self._emit("model.eval()")
    self._emit("")
    self._emit("# Persona-based debate (test_14: single-model diversity protocol)")
    self._emit("personas = [")
    self._indent += 1
    self._emit('"You are a careful, skeptical analyst. Question every assumption.",')
    self._emit('"You are a creative problem solver. Think outside the box.",')
    self._emit('"You are a rigorous mathematician. Show formal proofs.",')
    self._emit('"You are a practical engineer. Focus on what works.",')
    self._emit('"You are a devil\'s advocate. Find flaws in every argument.",')
    self._emit('"You are an optimist. Find the best interpretation.",')
    self._emit('"You are a minimalist. Give the simplest correct answer.",')
    self._emit('"You are a professor. Explain with clarity and depth.",')
    self._indent -= 1
    self._emit("]")
    self._emit("")
    self._emit("# Base prompts + diagnosis-derived prompts")
    self._emit(f'diag = results.get("{target}_diagnose", [])')
    self._emit("debate_prompts = [")
    self._indent += 1
    self._emit('"Solve: What is the sum of the first 20 prime numbers?",')
    self._emit('"Explain why the sky appears blue using physics.",')
    self._emit('"Write a Python function to find the longest palindrome in a string.",')
    self._emit('"What are the logical flaws in this argument: All birds can fly, penguins are birds, therefore penguins can fly.",')
    self._emit('"If a train travels 60mph for 2.5 hours, then 80mph for 1.5 hours, what is the average speed?",')
    self._indent -= 1
    self._emit("]")
    self._emit("for d in diag:")
    self._indent += 1
    self._emit("# diagnose entries are expected to be dicts; tolerate plain strings too")
    self._emit("resp = d.get('response', '') if isinstance(d, dict) else str(d)")
    self._emit("snip = resp[:140]")
    self._emit('debate_prompts.append(f"Address this weakness you listed: {snip}. Provide a concrete fix and example.")')
    self._indent -= 1
    self._emit("")
    self._emit("debate_results = []")
    self._emit(f"for round_num in range({rounds}):")
    self._indent += 1
    self._emit(f'print(f" Round {{round_num + 1}}/{rounds}...")')
    self._emit("prompt = random.choice(debate_prompts)")
    self._emit(f"selected_personas = random.sample(personas, min({n_candidates}, len(personas)))")
    self._emit("candidates = []")
    self._emit("for persona in selected_personas:")
    self._indent += 1
    self._emit('full_prompt = f"{persona}\\n\\nQuestion: {prompt}\\n\\nAnswer:"')
    self._emit('inputs = tok(full_prompt, return_tensors="pt").to(model.device)')
    self._emit("with torch.no_grad():")
    self._indent += 1
    self._emit("output = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.9)")
    self._indent -= 1
    self._emit("# Decode only the continuation - output[0] also contains the prompt tokens")
    self._emit('response = tok.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)')
    self._emit('candidates.append({"persona": persona, "response": response})')
    self._indent -= 1
    self._emit("")
    self._emit("# Judge: structured JSON scoring for correctness, reasoning, safety, style")
    self._emit('judge_prompt = "You are a neutral judge. Return JSON with keys: scores (list of {id, correctness, reasoning, safety, style}), winner_id, rationale. Scores 1-10.\\n"')
    self._emit('judge_prompt += f"Question: {prompt}\\n\\n"')
    self._emit("for idx, c in enumerate(candidates):")
    self._indent += 1
    self._emit("resp_snip = c['response'][:400]")
    self._emit('judge_prompt += f"Answer {idx+1}: {resp_snip}\\n\\n"')
    self._indent -= 1
    self._emit('inputs = tok(judge_prompt, return_tensors="pt").to(model.device)')
    self._emit("with torch.no_grad():")
    self._indent += 1
    self._emit("output = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.2)")
    self._indent -= 1
    self._emit('judgment = tok.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)')
    self._emit("# Parse the first JSON object in the reply (trailing text allowed); keep raw text otherwise")
    self._emit("judgment_json = {'raw': judgment}")
    self._emit("_json_start = judgment.find('{')")
    self._emit("if _json_start != -1:")
    self._indent += 1
    self._emit("try:")
    self._indent += 1
    self._emit("judgment_json = json.JSONDecoder().raw_decode(judgment[_json_start:])[0]")
    self._indent -= 1
    self._emit("except ValueError:")
    self._indent += 1
    self._emit("pass")
    self._indent -= 2
    self._emit("debate_results.append({")
    self._indent += 1
    self._emit('"round": round_num + 1,')
    self._emit('"prompt": prompt,')
    self._emit('"candidates": candidates,')
    self._emit('"judgment": judgment_json,')
    self._indent -= 1
    self._emit("})")
    self._indent -= 1
    self._emit("")
    self._emit(f'results["{target}_debate"] = debate_results')
    self._emit(f'lineage.setdefault("{target}", {{}}).setdefault("operations", []).append({{')
    self._indent += 1
    self._emit('"op": "debate",')
    self._emit(f'"rounds": {rounds},')
    self._emit(f'"candidates": {n_candidates},')
    self._emit('"timestamp": datetime.now().isoformat(),')
    self._indent -= 1
    self._emit("})")
    # repr() yields a correctly escaped literal for any path (quotes, backslashes).
    self._emit(f"debate_path = Path({str(output_path)!r})")
    self._emit("debate_path.parent.mkdir(parents=True, exist_ok=True)")
    self._emit('with open(debate_path, "w") as f:')
    self._indent += 1
    self._emit("for entry in debate_results:")
    self._indent += 1
    self._emit("f.write(json.dumps(entry, default=str) + chr(10))")
    self._indent -= 2
    self._emit('print(f"[td_lang] Debate results saved to {debate_path} ({len(debate_results)} rounds)")')
    self._emit("del model, tok")
    self._emit("import gc; gc.collect()")
    self._emit("if torch.cuda.is_available():")
    self._indent += 1
    self._emit("torch.cuda.empty_cache()")
    self._indent -= 1
|
|
| |
|
|
def _emit_edit(self, cmd: EditCmd) -> None:
    """EDIT - surgical LoRA/DoRA on specific layers.

    From test_18: all 3 AIs agree LoRA is safe default, DoRA beats by 1-4%.
    layers_to_transform supports targeting specific layers (e.g., 16-28).
    "Try before buy": eval with adapters enabled vs disabled, merge only if gates pass.

    Every setting taken from ``cmd`` (method, rank, alpha, layers) is
    resolved here at compile time and written into the script as a literal.
    The generated script has no ``cmd`` or ``method`` variable of its own.

    Args:
        cmd: Parsed ``edit`` command. Reads ``target`` and ``method``.
            Optional ``r`` / ``alpha`` default to 8 / 16. ``layers`` may be
            ``"all"`` or ``None``, an inclusive range ``"16-28"``, a comma-
            or space-separated list such as ``"3, 5"``, a single int, or a
            list/tuple of ints.

    Raises:
        ValueError: if ``layers`` names something that is not an integer.
    """
    alias = cmd.target
    method = cmd.method
    raw_layers = cmd.layers
    use_dora = method == "dora"
    edit_r = getattr(cmd, "r", None) or 8
    edit_alpha = getattr(cmd, "alpha", None) or 16
    # NOTE(review): no training step is emitted here, so cmd.learning_rate is
    # not used. The saved adapter is whatever get_peft_model initialised.

    # Resolve layer targeting to a Python expression (None means all layers).
    if isinstance(raw_layers, (list, tuple)):
        layers_expr = repr([int(x) for x in raw_layers])
    else:
        spec = "all" if raw_layers is None else str(raw_layers).strip()
        if spec == "all":
            layers_expr = None
        elif "-" in spec:
            first, last = (int(part) for part in spec.split("-", 1))
            layers_expr = f"list(range({first}, {last + 1}))"
        else:
            layers_expr = repr([int(part) for part in spec.replace(",", " ").split()])

    self._emit(f'print("[td_lang] EDIT - surgical {method} on {alias}, layers={raw_layers}")')
    self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer")
    self._emit("import contextlib")
    self._emit("import torch")
    self._emit("from peft import LoraConfig, get_peft_model, PeftModel")
    self._emit("from bitsandbytes import __version__ as bnb_version # ensure bnb installed")
    self._emit("")

    self._emit(f'checkpoint = models.get("{alias}", {{}}).get("checkpoint") or models["{alias}"].get("model_ref")')
    self._emit('print(f"[td_lang] Loading base model for EDIT from {checkpoint} (4-bit QLoRA)...")')
    self._emit("bnb_config = {")
    self._indent += 1
    self._emit('"load_in_4bit": True,')
    self._emit('"bnb_4bit_compute_dtype": torch.bfloat16,')
    self._emit('"bnb_4bit_use_double_quant": True,')
    self._emit('"bnb_4bit_quant_type": "nf4",')
    self._indent -= 1
    self._emit("}")
    self._emit("model = _load_model_smart(checkpoint, device_map='auto', **bnb_config)")
    self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
    self._emit("")

    self._emit("# Parse layer targeting")
    if layers_expr is None:
        self._emit("layers_to_transform = None # all layers")
    else:
        self._emit(f"layers_to_transform = {layers_expr}")
    self._emit("")

    self._emit(f"use_dora = {use_dora}")
    self._emit(f"edit_r = {edit_r}")
    self._emit(f"edit_alpha = {edit_alpha}")
    self._emit("edit_config = LoraConfig(")
    self._indent += 1
    self._emit("r=edit_r,")
    self._emit("lora_alpha=edit_alpha,")
    self._emit('target_modules=["q_proj", "v_proj"],')
    self._emit("lora_dropout=0.05,")
    self._emit('bias="none",')
    self._emit('task_type="CAUSAL_LM",')
    self._emit("use_dora=use_dora,")
    if layers_expr is not None:
        # layers_pattern is only meaningful together with layers_to_transform.
        self._emit("layers_to_transform=layers_to_transform,")
        self._emit('layers_pattern="layers",')
    self._indent -= 1
    self._emit(")")
    self._emit("")

    self._emit("# Inject adapter - base weights stay frozen")
    self._emit("model = get_peft_model(model, edit_config)")
    self._emit("model.print_trainable_parameters()")
    self._emit("")

    self._emit("# Dry-run report: verify correct modules were targeted")
    self._emit("wrapped_modules = [n for n, _ in model.named_modules() if 'lora' in n.lower()]")
    self._emit(f'print(f"[td_lang] EDIT: {{len(wrapped_modules)}} modules wrapped with {method}")')
    self._emit('for wm in wrapped_modules[:10]:')
    self._indent += 1
    self._emit('print(f" - {wm}")')
    self._indent -= 1
    self._emit('if len(wrapped_modules) > 10:')
    self._indent += 1
    self._emit('print(f" ... and {len(wrapped_modules) - 10} more")')
    self._indent -= 1
    self._emit("")

    self._emit('sample_prompts = ["What is 7+8?", "Explain photosynthesis in one paragraph.", "Write a Python function fib(n)."]')
    self._emit("def run_quick_eval(enable_adapters: bool):")
    self._indent += 1
    self._emit("# PeftModel.disable_adapter() is a context manager; adapters are active by default")
    self._emit("if enable_adapters or not hasattr(model, 'disable_adapter'):")
    self._indent += 1
    self._emit("adapter_ctx = contextlib.nullcontext()")
    self._indent -= 1
    self._emit("else:")
    self._indent += 1
    self._emit("adapter_ctx = model.disable_adapter()")
    self._indent -= 1
    self._emit("responses = []")
    self._emit("with adapter_ctx:")
    self._indent += 1
    self._emit("for p in sample_prompts:")
    self._indent += 1
    self._emit("inputs = tok(p, return_tensors='pt').to(model.device)")
    self._emit("with torch.no_grad():")
    self._indent += 1
    self._emit("# Greedy decoding so ON/OFF differences come from the adapter, not sampling noise")
    self._emit("out = model.generate(**inputs, max_new_tokens=128, do_sample=False)")
    self._indent -= 1
    self._emit("resp = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)")
    self._emit("responses.append(resp)")
    self._indent -= 2
    self._emit("avg_len = sum(len(r) for r in responses) / len(responses)")
    self._emit("return responses, avg_len")
    self._indent -= 1
    self._emit("")
    self._emit("on_resps, on_len = run_quick_eval(True)")
    self._emit("off_resps, off_len = run_quick_eval(False)")
    self._emit('print("[td_lang] Try-before-buy results:")')
    self._emit('print(f" Adapter ON avg length: {on_len:.1f}")')
    self._emit('print(f" Adapter OFF avg length: {off_len:.1f}")')
    self._emit("for i, (a, b) in enumerate(zip(on_resps, off_resps)):")
    self._indent += 1
    self._emit('print(f"Prompt {i+1}:")')
    self._emit('print(" ON :", a[:200])')
    self._emit('print(" OFF:", b[:200])')
    self._indent -= 1
    self._emit("")

    self._emit(f'edit_save_dir = os.path.join(output_dir, "{alias}_edit_{method}")')
    self._emit("os.makedirs(edit_save_dir, exist_ok=True)")
    self._emit("model.save_pretrained(edit_save_dir)")
    self._emit('print(f"[td_lang] EDIT adapter saved to {edit_save_dir}")')
    self._emit('print("[td_lang] Adapter NOT merged - use commit with gates to merge permanently")')
    self._emit("")

    self._emit(f'models["{alias}"] = model')
|
|
def _emit_fork(self, cmd: ForkCmd) -> None:
    """FORK - branch current model weights for parallel experiments.

    From test_18: all 3 AIs say disk-based only on 4090.
    Cheap fork = copy manifest + adapter files, share base weights.
    Uses safetensors format.

    The generated code creates ``<output_dir>/forks/<alias>_<suffix>/``, where
    the 8-hex-character suffix hashes the current time together with the
    alias. That directory receives:

    * either the PEFT adapter (``adapters/``) or a full ``model.safetensors``;
    * the CUDA RNG state, when one is available;
    * a ``manifest.json`` describing the fork.

    Afterwards ``models[alias]`` refers to the *same* object as
    ``models[source]``.

    Args:
        cmd: Parsed ``fork`` command; ``source`` and ``alias`` are read.
    """
    source = cmd.source
    alias = cmd.alias

    self._emit(f'print("[td_lang] FORK - branching {source} as {alias}")')
    self._emit(f'source_model = models["{source}"]')
    self._emit("import torch")
    self._emit("import os, time")
    self._emit("")

    self._emit("import hashlib")
    # f-string so the alias itself (not the literal text "{alias}") is hashed.
    self._emit(f'fork_suffix = hashlib.sha1((str(time.time()) + "{alias}").encode()).hexdigest()[:8]')
    self._emit(f'fork_dir = os.path.join(output_dir, "forks", "{alias}_" + fork_suffix)')
    self._emit("os.makedirs(fork_dir, exist_ok=True)")
    self._emit("")

    self._emit("# Write fork manifest - tracks lineage")
    self._emit("import json")
    self._emit("fork_manifest = {")
    self._indent += 1
    self._emit(f'"fork_name": "{alias}",')
    self._emit(f'"forked_from": "{source}",')
    self._emit('"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),')
    self._emit(f'"base_ref": models.get("__base_ref_{source}", "unknown"),')
    self._indent -= 1
    self._emit("}")
    self._emit("")

    self._emit("# Cheap fork: save adapters only if PEFT model, else full checkpoint")
    self._emit("is_peft = hasattr(source_model, 'peft_config')")
    self._emit("if is_peft:")
    self._indent += 1
    self._emit("# PEFT model - save only adapter weights (small, fast)")
    self._emit('adapter_dir = os.path.join(fork_dir, "adapters")')
    self._emit("source_model.save_pretrained(adapter_dir)")
    self._emit('fork_manifest["fork_type"] = "adapter"')
    self._emit('fork_manifest["adapter_dir"] = adapter_dir')
    self._emit('print(f"[td_lang] Cheap fork: adapter saved to {adapter_dir}")')
    self._indent -= 1
    self._emit("else:")
    self._indent += 1
    self._emit("# Full model - clone tensors (breaks tied/shared storage) then save to safetensors")
    self._emit("from safetensors.torch import save_file")
    self._emit("state = {k: v.detach().cpu().clone() for k, v in source_model.state_dict().items()}")
    self._emit('ckpt_path = os.path.join(fork_dir, "model.safetensors")')
    self._emit("save_file(state, ckpt_path)")
    self._emit('fork_manifest["fork_type"] = "full_checkpoint"')
    self._emit('fork_manifest["checkpoint_path"] = ckpt_path')
    self._emit('print(f"[td_lang] Full fork: checkpoint saved to {ckpt_path}")')
    self._indent -= 1
    self._emit("")

    self._emit("# Save RNG state for reproducibility")
    self._emit("try:")
    self._indent += 1
    self._emit("rng_state = torch.cuda.get_rng_state().cpu() if torch.cuda.is_available() else None")
    self._indent -= 1
    self._emit("except Exception:")
    self._indent += 1
    self._emit("rng_state = None")
    self._indent -= 1
    self._emit("if rng_state is not None:")
    self._indent += 1
    self._emit('torch.save(rng_state, os.path.join(fork_dir, "rng_state.pt"))')
    self._emit('fork_manifest["rng_state"] = "rng_state.pt"')
    self._indent -= 1
    self._emit("")
    self._emit('manifest_path = os.path.join(fork_dir, "manifest.json")')
    self._emit('with open(manifest_path, "w") as f:')
    self._indent += 1
    self._emit("json.dump(fork_manifest, f, indent=2)")
    self._indent -= 1
    self._emit('print(f"[td_lang] Fork manifest: {manifest_path}")')
    self._emit("")

    self._emit(f'models["{alias}"] = source_model # shares reference until divergence')
    self._emit(f'lineage["{alias}"] = {{"forked_from": "{source}", "operations": []}}')
|
|
def _emit_reset(self, cmd: ResetCmd) -> None:
    """RESET - revert model to a previous checkpoint.

    From test_18: del model, clear CUDA cache, reload.
    Must also reset optimizer state.

    The generated code frees the current model, then reloads ``checkpoint``:

    * A fork directory (one containing ``manifest.json``): loads the
      manifest's base model. It then adds the adapter when
      ``fork_type == "adapter"``, or copies in the saved weights when
      ``fork_type == "full_checkpoint"``.
    * Any other directory: loaded as a HF-style model directory.
    * Otherwise: treated as a safetensors file. Its weights are copied into
      the base model recorded under ``models["__base_ref_<alias>"]``.

    Safetensors weights are staged on CPU and copied into the already-loaded
    model with plain ``load_state_dict``, not ``assign=True``. VRAM is
    therefore not doubled, and parameters keep the model's device and dtype.
    With ``assign=True``, parameters would have been replaced by CPU tensors.
    The tokenizer is loaded from the resolved base model, and a short greedy
    smoke eval runs afterwards.

    Args:
        cmd: Parsed ``reset`` command; ``target`` and ``checkpoint`` are read.
    """
    alias = cmd.target
    checkpoint = cmd.checkpoint

    self._emit(f'print("[td_lang] RESET - reverting {alias} to {checkpoint}")')
    self._emit("import gc, json, os")
    self._emit("import torch")
    self._emit("")

    self._emit("# Free current model from VRAM; also drop the script-level `model` handle")
    self._emit("# so the old weights are no longer referenced")
    self._emit(f'models.pop("{alias}", None)')
    self._emit("model = None")
    self._emit("gc.collect()")
    self._emit("torch.cuda.empty_cache()")
    self._emit('print("[td_lang] VRAM cleared")')
    self._emit("")

    self._emit("# Resolve checkpoint path")
    # repr() emits a correctly escaped literal (e.g. Windows backslashes).
    self._emit(f"ckpt_path = {str(checkpoint)!r}")
    self._emit("base_ref = ckpt_path")
    self._emit("# Check if it's a fork directory with manifest")
    self._emit('fork_manifest_path = os.path.join(ckpt_path, "manifest.json") if os.path.isdir(ckpt_path) else None')
    self._emit("")

    self._emit("# Copy safetensors weights into an already-loaded model. Tensors are staged on")
    self._emit("# CPU and copied in place, so VRAM is not doubled and params keep their device")
    self._emit("def _td_reset_load_weights(target_model, weights_path):")
    self._indent += 1
    self._emit("from safetensors.torch import load_file")
    self._emit("state = load_file(weights_path, device='cpu')")
    self._emit("try:")
    self._indent += 1
    self._emit("target_model.load_state_dict(state, strict=True)")
    self._indent -= 1
    self._emit("except Exception as e:")
    self._indent += 1
    self._emit('print(f"[td_lang] Shape mismatch on reset load: {e}. Retrying non-strict.")')
    self._emit("target_model.load_state_dict(state, strict=False)")
    self._indent -= 2
    self._emit("")

    self._emit("# Reload from checkpoint")
    self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer")
    self._emit("")
    self._emit("if fork_manifest_path and os.path.exists(fork_manifest_path):")
    self._indent += 1
    self._emit("# Loading from a fork - read manifest")
    self._emit("with open(fork_manifest_path) as f:")
    self._indent += 1
    self._emit("manifest = json.load(f)")
    self._indent -= 1
    self._emit('base_ref = manifest.get("base_ref", ckpt_path)')
    self._emit('if base_ref in (None, "unknown"):')
    self._indent += 1
    self._emit(f'base_ref = models.get("__base_ref_{alias}", ckpt_path)')
    self._indent -= 1
    self._emit("model = _load_model_smart(base_ref, torch_dtype=torch.float16, device_map='cuda')")
    self._emit('if manifest.get("fork_type") == "adapter":')
    self._indent += 1
    self._emit("from peft import PeftModel")
    self._emit('model = PeftModel.from_pretrained(model, manifest["adapter_dir"])')
    self._indent -= 1
    self._emit('elif manifest.get("fork_type") == "full_checkpoint":')
    self._indent += 1
    self._emit('_td_reset_load_weights(model, manifest["checkpoint_path"])')
    self._indent -= 2
    self._emit("elif os.path.isdir(ckpt_path):")
    self._indent += 1
    self._emit("# Loading from a HF-style directory")
    self._emit("model = _load_model_smart(ckpt_path, torch_dtype=torch.float16, device_map='cuda')")
    self._indent -= 1
    self._emit("else:")
    self._indent += 1
    self._emit("# Loading from a safetensors file - need base model architecture first")
    self._emit(f'base_ref = models.get("__base_ref_{alias}", ckpt_path)')
    self._emit("model = _load_model_smart(base_ref, torch_dtype=torch.float16, device_map='cuda')")
    self._emit("_td_reset_load_weights(model, ckpt_path)")
    self._indent -= 1
    self._emit("")

    self._emit(f'models["{alias}"] = model')
    self._emit(f'print(f"[td_lang] RESET complete - {alias} restored from {checkpoint}")')
    self._emit("")

    self._emit("torch.cuda.empty_cache()")
    self._emit('print("[td_lang] Note: optimizer state cleared; next train starts fresh.")')
    self._emit("# Smoke eval after reset")
    self._emit('sample_prompts = ["Hello!", "2+2?", "Define gravity.", "Write a Python loop 1..3.", "Capital of France?"]')
    self._emit("# base_ref is the HF dir, the fork's base model, or the recorded base ref")
    self._emit("tok = AutoTokenizer.from_pretrained(base_ref)")
    self._emit("model.eval()")
    self._emit("for p in sample_prompts:")
    self._indent += 1
    self._emit("inputs = tok(p, return_tensors='pt').to(model.device)")
    self._emit("with torch.no_grad():")
    self._indent += 1
    self._emit("out = model.generate(**inputs, max_new_tokens=40, do_sample=False)")
    self._indent -= 1
    self._emit("resp = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)")
    self._emit('print(f"[td_lang][reset smoke] {p} -> {resp[:120]}")')
    self._indent -= 1
| self._indent -= 1 |
|
|
def _emit_prune(self, cmd: PruneCmd) -> None:
    """PRUNE - structural pruning of language backbone.

    From test_18: 20% structured max (LLM-Pruner). Wanda metric (Grok).
    Language backbone only, never vision encoder. Recovery: 200-800 steps LoRA.

    The generated code first caps the ratio at 30%. It then collects every
    ``torch.nn.Linear`` whose name does not look like a vision, embedding or
    output-head module, and zeroes whole output rows (``dim=0``) of each
    weight:

    * ``wanda``: rows are ranked by ``mean(|W| * mean|input activation|)``
      over a tiny calibration set. The lowest-scoring rows are masked,
      keeping a multiple of 8 rows.
    * ``magnitude``: ``prune.ln_structured``, i.e. the L1 norm of each row.
    * Any other method (``taylor``) falls back to magnitude with a warning.

    A JSON report is written to ``<output_dir>/<alias>_prune_report.json``.

    Args:
        cmd: Parsed ``prune`` command. Reads ``target``, ``method`` and
            ``aggressiveness``, which is treated as a fraction (e.g. 0.2).
    """
    alias = cmd.target
    method = cmd.method
    aggressiveness = cmd.aggressiveness

    def emit_ln_structured_loop() -> None:
        # Shared by the magnitude and taylor-fallback paths: L1-norm row pruning.
        self._emit("import torch.nn.utils.prune as prune")
        self._emit("")
        self._emit("pruned_count = 0")
        self._emit("for layer_name, layer_module in target_modules:")
        self._indent += 1
        self._emit("try:")
        self._indent += 1
        self._emit("prune.ln_structured(layer_module, name='weight', amount=prune_ratio, n=1, dim=0)")
        self._emit("prune.remove(layer_module, 'weight')")
        self._emit("pruned_count += 1")
        self._indent -= 1
        self._emit("except Exception as e:")
        self._indent += 1
        self._emit('print(f"[td_lang] Skip {layer_name}: {e}")')
        self._indent -= 2

    self._emit("import torch")
    self._emit(f'print("[td_lang] PRUNE - {method} pruning on {alias}, {aggressiveness*100:.0f}% removal")')
    self._emit(f'model = models["{alias}"]')
    self._emit("")

    self._emit("# Safety: cap pruning at 30% (beyond this = cliff, per LLM-Pruner)")
    self._emit(f"prune_ratio = min({aggressiveness}, 0.30)")
    self._emit(f"if prune_ratio != {aggressiveness}:")
    self._indent += 1
    self._emit(f'print(f"[td_lang] WARNING: aggressiveness capped at 30% (requested {aggressiveness*100:.0f}%)")')
    self._indent -= 1
    self._emit("")

    self._emit("# Target language backbone ONLY - never prune vision encoder")
    self._emit("# Filter for language model linear layers")
    self._emit("target_modules = []")
    self._emit("for name, module in model.named_modules():")
    self._indent += 1
    self._emit("if isinstance(module, torch.nn.Linear):")
    self._indent += 1
    self._emit("# Skip vision encoder, embeddings, and output head")
    self._emit('is_vision = any(v in name for v in ["visual", "vision", "vit", "image", "pixel"])')
    self._emit('is_embed = any(e in name for e in ["embed", "lm_head", "output"])')
    self._emit("if not is_vision and not is_embed:")
    self._indent += 1
    self._emit("target_modules.append((name, module))")
    self._indent -= 3
    self._emit('print(f"[td_lang] Found {len(target_modules)} prunable language layers")')
    self._emit("")

    self._emit(f"# Pruning method: {method}")
    if method == "wanda":
        self._emit("# Wanda: weight magnitude × input activation norm (Grok's recommendation)")
        self._emit("# Collect activations on small calibration batch, then prune with keep_multiple_of=8")
        self._emit("import torch.nn.utils.prune as prune")
        self._emit("calib_texts = [")
        self._indent += 1
        self._emit('"The quick brown fox jumps over the lazy dog.",')
        self._emit('"Solve 12 + 37.",')
        self._emit('"Write a for loop in Python that sums 1..10.",')
        self._emit('"Explain why the sky is blue.",')
        self._indent -= 1
        self._emit("]")
        self._emit("from transformers import AutoTokenizer")
        self._emit("base_ref = None")
        # The alias must be baked in as a literal: the script has no `alias` variable.
        self._emit(f'if isinstance(models.get("{alias}"), dict):')
        self._indent += 1
        self._emit(f'base_ref = models["{alias}"].get("model_ref")')
        self._indent -= 1
        self._emit("if base_ref is None:")
        self._indent += 1
        self._emit(f"base_ref = models.get('__base_ref_{alias}', 'Qwen/Qwen3-VL-8B-Instruct')")
        self._indent -= 1
        self._emit("tok = AutoTokenizer.from_pretrained(base_ref)")
        self._emit("activation_sums = {}")
        self._emit("hooks = []")
        self._emit("def make_hook(name):")
        self._indent += 1
        self._emit("def _hook(module, inp, out):")
        self._indent += 1
        self._emit("with torch.no_grad():")
        self._indent += 1
        self._emit("x = inp[0].detach()")
        self._emit("# Mean |activation| per input feature over all tokens: one vector per layer,")
        self._emit("# independent of prompt length, so calibration passes can be summed")
        self._emit("act = x.abs().reshape(-1, x.shape[-1]).float().mean(dim=0)")
        self._emit("activation_sums[name] = activation_sums.get(name, 0) + act")
        self._indent -= 2
        self._emit("return _hook")
        self._indent -= 1
        self._emit("for name, module in target_modules:")
        self._indent += 1
        self._emit("hooks.append(module.register_forward_hook(make_hook(name)))")
        self._indent -= 1
        self._emit("# Run one calibration pass; always detach the hooks afterwards")
        self._emit("model.eval()")
        self._emit("try:")
        self._indent += 1
        self._emit("for txt in calib_texts:")
        self._indent += 1
        self._emit("inputs = tok(txt, return_tensors='pt').to(model.device)")
        self._emit("with torch.no_grad():")
        self._indent += 1
        self._emit("model(**inputs)")
        self._indent -= 3
        self._emit("finally:")
        self._indent += 1
        self._emit("for h in hooks:")
        self._indent += 1
        self._emit("h.remove()")
        self._indent -= 2
        self._emit("")
        self._emit("pruned_count = 0")
        self._emit("for layer_name, layer_module in target_modules:")
        self._indent += 1
        self._emit("act = activation_sums.get(layer_name)")
        self._emit("if act is None:")
        self._indent += 1
        self._emit('print(f"[td_lang] Skip {layer_name}: no activation stats")')
        self._emit("continue")
        self._indent -= 1
        self._emit("try:")
        self._indent += 1
        self._emit("weight = layer_module.weight")
        self._emit("scores = (weight.detach().abs().float() * act.to(weight.device).unsqueeze(0)).mean(dim=1)")
        self._emit("keep = max(8, int((1 - prune_ratio) * scores.numel()))")
        self._emit("keep = (keep // 8) * 8")
        self._emit("keep = min(max(8, keep), scores.numel())")
        self._emit("# Keep the highest-scoring output rows (Wanda metric), zero the rest")
        self._emit("mask = torch.zeros_like(weight)")
        self._emit("mask[torch.topk(scores, keep).indices] = 1")
        self._emit("prune.custom_from_mask(layer_module, name='weight', mask=mask)")
        self._emit("prune.remove(layer_module, 'weight')")
        self._emit("pruned_count += 1")
        self._indent -= 1
        self._emit("except Exception as e:")
        self._indent += 1
        self._emit('print(f"[td_lang] Skip {layer_name}: {e}")')
        self._indent -= 2
    elif method == "magnitude":
        self._emit("# Magnitude: simple L1 norm of weight rows")
        emit_ln_structured_loop()
    else:
        self._emit("# Taylor: gradient-based importance (needs backprop - VRAM heavy)")
        self._emit("# Falling back to magnitude as MVP - Taylor needs calibration + backprop")
        self._emit('print("[td_lang] WARNING: Taylor pruning falls back to magnitude on single GPU")')
        emit_ln_structured_loop()
    self._emit("")

    self._emit('print(f"[td_lang] Pruned {pruned_count}/{len(target_modules)} layers at {prune_ratio*100:.0f}%")')
    self._emit("")

    self._emit("# Save prune report for auditing")
    self._emit("import json")
    self._emit("prune_report = {")
    self._indent += 1
    self._emit(f'"method": "{method}",')
    self._emit(f'"requested_aggressiveness": {aggressiveness},')
    self._emit('"actual_ratio": prune_ratio,')
    self._emit('"layers_pruned": pruned_count,')
    self._emit('"total_target_layers": len(target_modules),')
    self._emit('"vision_touched": False,')
    self._indent -= 1
    self._emit("}")
    self._emit(f'prune_report_path = os.path.join(output_dir, "{alias}_prune_report.json")')
    self._emit('with open(prune_report_path, "w") as f:')
    self._indent += 1
    self._emit("json.dump(prune_report, f, indent=2)")
    self._indent -= 1
    self._emit('print(f"[td_lang] Prune report: {prune_report_path}")')
    self._emit("")

    self._emit("# Recovery: you should run heal or train after pruning")
    self._emit("# LLM-Pruner shows recovery in 200-800 steps with LoRA r=8")
    self._emit(f'print("[td_lang] IMPORTANT: Run heal or train after pruning for recovery (suggest: heal {alias} lora_r 8 epochs 1, ~400 steps)")')
    self._emit(f'models["{alias}"] = model')
|
|
| |
|
|
def _emit_cmd(self, cmd, program: TDProgram) -> None:
    """Emit a single command - used by repeat/if to emit body commands.

    Dispatches on the AST node type with ordered ``isinstance`` checks. The
    order matters if node classes subclass one another.

    A command with no branch here is no longer dropped silently. Instead, a
    runtime warning naming its node type is emitted into the generated
    script.

    NOTE(review): ``BreakIfCmd`` has an ``_emit_break_if`` emitter but is not
    imported in this module, so it currently lands in that warning path.

    Args:
        cmd: An AST command node from a block body.
        program: The enclosing program. Some handlers read it, e.g.
            ``program.gates`` for ``commit``.
    """
    if isinstance(cmd, LoadCmd):
        self._emit_load(cmd)
    elif isinstance(cmd, MergeCmd):
        self._emit_merge(cmd)
    elif isinstance(cmd, HealCmd):
        self._emit_heal(cmd)
    elif isinstance(cmd, EvalCmd):
        self._emit_eval(cmd)
    elif isinstance(cmd, CommitCmd):
        self._emit_commit(cmd, program.gates)
    elif isinstance(cmd, DiagnoseCmd):
        self._emit_diagnose(cmd)
    elif isinstance(cmd, SynthCmd):
        self._emit_synth(cmd)
    elif isinstance(cmd, TrainCmd):
        self._emit_train(cmd, program)
    elif isinstance(cmd, DebateCmd):
        self._emit_debate(cmd)
    elif isinstance(cmd, EditCmd):
        self._emit_edit(cmd)
    elif isinstance(cmd, ForkCmd):
        self._emit_fork(cmd)
    elif isinstance(cmd, ResetCmd):
        self._emit_reset(cmd)
    elif isinstance(cmd, PruneCmd):
        self._emit_prune(cmd)
    elif isinstance(cmd, FuseCmd):
        self._emit_fuse(cmd)
    elif isinstance(cmd, AbsorbCmd):
        self._emit_absorb(cmd)
    elif isinstance(cmd, SnapshotCmd):
        self._emit_snapshot(cmd, program)
    elif isinstance(cmd, ReportCmd):
        self._emit_report(cmd, program)
    elif isinstance(cmd, NotifyCmd):
        self._emit_notify(cmd, program)
    elif isinstance(cmd, SaveCmd):
        self._emit_save(cmd, program)
    elif isinstance(cmd, RepeatBlock):
        self._emit_repeat(cmd, program)
    elif isinstance(cmd, IfBlock):
        self._emit_if(cmd, program)
    elif isinstance(cmd, ScheduleCmd):
        self._emit_schedule(cmd, program)
    elif isinstance(cmd, DownloadCmd):
        self._emit_download(cmd)
    elif isinstance(cmd, CompareCmd):
        self._emit_compare(cmd)
    elif isinstance(cmd, VerifyCmd):
        self._emit_verify(cmd)
    elif isinstance(cmd, VoteCmd):
        self._emit_vote(cmd)
    elif isinstance(cmd, PromptBlock):
        self._emit_prompt(cmd)
    elif isinstance(cmd, DistillCmd):
        self._emit_distill(cmd)
    elif isinstance(cmd, RollbackCmd):
        self._emit_rollback(cmd)
    elif isinstance(cmd, CurriculumCmd):
        self._emit_curriculum(cmd, program)
    elif isinstance(cmd, StarCmd):
        self._emit_star(cmd, program)
    elif isinstance(cmd, BestOfCmd):
        self._emit_best_of(cmd, program)
    elif isinstance(cmd, ExploitCmd):
        self._emit_exploit(cmd, program)
    elif isinstance(cmd, ArenaCmd):
        self._emit_arena(cmd, program)
    elif isinstance(cmd, ResearchArenaCmd):
        self._emit_research_arena(cmd, program)
    else:
        # Surface the omission in the generated script instead of dropping it.
        node_name = type(cmd).__name__
        self._emit(f'print("[td_lang] WARNING: unsupported command in block body skipped: {node_name}")')
|
|
def _emit_repeat(self, cmd: RepeatBlock, program: TDProgram) -> None:
    """REPEAT - run a block of commands N times.

    This is the core of td_loop: the self-improvement cycle. The body is
    emitted once inside ``for _loop_iter in range(N)``, and each iteration
    records its index in ``results['_loop_iter']``.

    When the program budget sets ``max_gpu_hours``, each iteration first
    compares the wall-clock time since ``start_time`` against that cap. It
    breaks out of the loop once the cap is reached.
    """
    iterations = cmd.count
    budget = program.budget
    hour_cap = budget.max_gpu_hours if budget else None

    self._emit(f'print("[td_lang] REPEAT - running {iterations} iterations")')
    self._emit(f"for _loop_iter in range({iterations}):")
    self._indent += 1
    self._emit(f'print(f"[td_lang] === Iteration {{_loop_iter + 1}}/{iterations} ===")')
    self._emit("results['_loop_iter'] = _loop_iter")

    if hour_cap is not None:
        guard_lines = (
            "# Loop-level budget guard (GPU hours)",
            "elapsed_hours = (time.time() - start_time) / 3600",
            f"if elapsed_hours >= {hour_cap}:",
        )
        for guard_line in guard_lines:
            self._emit(guard_line)
        self._indent += 1
        self._emit('print("[td_lang] Budget exceeded inside repeat - stopping loop.")')
        self._emit("break")
        self._indent -= 1
        self._emit("")

    for child in cmd.body:
        self._emit_cmd(child, program)
        self._emit("")

    self._emit(f'print(f"[td_lang] Iteration {{_loop_iter + 1}}/{iterations} complete.")')
    self._indent -= 1
    self._emit(f'print("[td_lang] REPEAT complete - {iterations} iterations done.")')
|
|
def _emit_if(self, cmd: IfBlock, program: TDProgram) -> None:
    """IF/ELSE - emit a conditional around the block's body commands.

    Conditions (any other name reads ``results["<target>_<condition>"]``):
      - eval_passed: ``results["<target>_eval"]`` is non-empty and its
        ``overall`` entry is truthy.
      - gate_passed: every gate in ``program.gates.must_pass`` is truthy in
        the last eval. A dict entry counts via its ``ok`` key. With no gates,
        any non-empty eval result passes.
      - improved: the last entry of ``results["<target>_eval_history"]`` is
        strictly greater than the one before it.

    The ``else:`` branch is emitted only when ``cmd.else_body`` is non-empty.
    """
    cond = cmd.condition
    tgt = cmd.target

    self._emit(f'print("[td_lang] IF - checking {cond} for {tgt}")')
    self._emit("")

    eval_lookup = f'_last_eval = results.get("{tgt}_eval", {{}})'
    if cond == "gate_passed":
        gate_names = program.gates.must_pass if program.gates else []
        self._emit(eval_lookup)
        self._emit(f"_gates = {gate_names}")
        self._emit("_condition_met = all(")
        self._indent += 1
        self._emit("bool(_last_eval.get(g, {}).get('ok', False)) if isinstance(_last_eval.get(g), dict) else bool(_last_eval.get(g, False))")
        self._emit("for g in _gates")
        self._indent -= 1
        self._emit(") if _gates else bool(_last_eval)")
    elif cond == "eval_passed":
        self._emit(eval_lookup)
        self._emit("_condition_met = bool(_last_eval) and _last_eval.get('overall', False)")
    elif cond == "improved":
        self._emit(f'_eval_history = results.get("{tgt}_eval_history", [])')
        self._emit("_condition_met = len(_eval_history) >= 2 and _eval_history[-1] > _eval_history[-2]")
    else:
        # Custom condition: look up a flag stored under "<target>_<condition>"
        self._emit(f'_condition_met = bool(results.get("{tgt}_{cond}", False))')

    def emit_branch(outcome: str, commands) -> None:
        # One indented branch: an outcome banner followed by each body command.
        self._indent += 1
        self._emit(f'print("[td_lang] Condition {cond} = {outcome}")')
        for sub_cmd in commands:
            self._emit_cmd(sub_cmd, program)
            self._emit("")
        self._indent -= 1

    self._emit("")
    self._emit("if _condition_met:")
    emit_branch("TRUE", cmd.then_body)

    if cmd.else_body:
        self._emit("else:")
        emit_branch("FALSE", cmd.else_body)
|
|
def _emit_break_if(self, cmd: BreakIfCmd) -> None:
    """BREAK_IF - early exit from repeat based on condition.

    Emits a check that sets ``_brk_met`` and a ``break`` guarded by it. The
    ``break`` is only valid inside a generated loop.

      - improved: fires when the last entry of
        ``results["<target>_eval_history"]`` is <= the previous one.
        NOTE(review): despite its name, this triggers when the score did NOT
        improve (plateau early-stop); confirm that reading is intended.
      - eval_passed: fires when ``results["<target>_eval"]["overall"]`` is truthy.
      - any other condition: reads ``results["<target>_<condition>"]``.
    """
    cond = cmd.condition
    tgt = cmd.target or ""

    self._emit(f'_brk_eval = results.get("{tgt}_eval", {{}})')
    if cond == "improved":
        check_lines = [
            f'_hist = results.get("{tgt}_eval_history", [])',
            "_brk_met = len(_hist) >= 2 and _hist[-1] <= _hist[-2]",
        ]
    elif cond == "eval_passed":
        check_lines = ["_brk_met = bool(_brk_eval.get('overall', False))"]
    else:
        check_lines = [f"_brk_met = bool(results.get('{tgt}_{cond}', False))"]
    for check_line in check_lines:
        self._emit(check_line)

    self._emit("if _brk_met:")
    self._indent += 1
    for exit_line in ('print("[td_lang] break_if triggered - exiting loop")', "break"):
        self._emit(exit_line)
    self._indent -= 1
|
|
| |
|
|
| def _emit_fuse(self, cmd: FuseCmd) -> None: |
| """FUSE - merge multiple models into target in one command. |
| |
| From TD merge strategy: Transport and Merge (optimal transport cross-arch merging). |
| All 5 source models have different architectures - Transport and Merge handles this. |
| Merge into language backbone only, vision encoder stays untouched. |
| """ |
| target = cmd.target |
| sources = cmd.sources |
| method = cmd.method |
| strategy = cmd.strategy |
| n = len(sources) |
|
|
| self._emit(f'print("[td_lang] FUSE - merging {n} models into {target} using {method}")') |
| self._emit(f'print("[td_lang] Strategy: {strategy}")') |
| self._emit(f"fuse_sources = {sources}") |
| self._emit(f'prev_ckpt = models.get("{target}", {{}}).get("checkpoint")') |
| self._emit("") |
|
|
| |
| self._emit("# Auto-compute per-model merge strength") |
| if strategy == "equal": |
| self._emit(f"per_model_strength = round(1.0 / ({n} + 1), 3) # equal weight, target keeps its share") |
| self._emit(f'print(f"[td_lang] Equal strategy: each model gets {{per_model_strength}} strength")') |
| elif strategy == "sequential": |
| self._emit("# Sequential: merge one at a time with decreasing strength") |
| self._emit(f"strengths = [round(0.5 * (0.8 ** i), 3) for i in range({n})]") |
| self._emit('print(f"[td_lang] Sequential strategy: strengths = {strengths}")') |
| else: |
| |
| self._emit(f"per_model_strength = round(1.0 / ({n} + 1), 3)") |
| self._emit("") |
|
|
| |
| self._emit("fuse_results = []") |
| self._emit("for fuse_idx, fuse_source in enumerate(fuse_sources):") |
| self._indent += 1 |
| self._emit(f'print(f"[td_lang] Fuse step {{fuse_idx + 1}}/{n}: merging {{fuse_source}}...")') |
| self._emit("") |
|
|
| |
| if strategy == "sequential": |
| self._emit("step_strength = strengths[fuse_idx]") |
| else: |
| self._emit("step_strength = per_model_strength") |
| self._emit("") |
|
|
| |
| self._emit("_stage = None") |
| self._emit("_arch = None") |
| self._emit("for _src in SOURCES:") |
| self._indent += 1 |
| self._emit("if _src.hf_id == fuse_source or _src.name.lower() in fuse_source.lower():") |
| self._indent += 1 |
| self._emit('_stage = _src.name.lower().split("-")[0]') |
| self._emit("_arch = getattr(_src, 'architecture', 'unknown')") |
| self._emit("_src.merge_alpha = step_strength") |
| self._emit("break") |
| self._indent -= 1 |
| self._indent -= 1 |
| self._emit("") |
|
|
| self._emit("if _stage is None:") |
| self._indent += 1 |
| self._emit('print(f"[td_lang] WARNING: Could not match {fuse_source} to SOURCES. Attempting direct merge...")') |
| self._emit("# For Transport and Merge, we can merge any architecture directly") |
| self._emit(f'_stage = fuse_source.split("/")[-1].lower().replace("-", "_")[:20]') |
| self._emit('_arch = "unknown"') |
| self._indent -= 1 |
| self._emit("") |
|
|
| |
| self._emit("cfg = MergeConfig()") |
| self._emit("# Auto-pick merge method by architecture match") |
| self._emit("chosen_method = 'slerp' if _arch == getattr(TARGET, 'architecture', 'unknown') else 'transport'") |
| self._emit(f"if '{method}' not in ['auto', '']: chosen_method = '{method}'") |
| self._emit("cfg.merge_method = chosen_method") |
| self._emit("merge_result = run_pipeline([_stage], cfg)") |
| self._emit("fuse_results.append({") |
| self._indent += 1 |
| self._emit('"source": fuse_source,') |
| self._emit('"stage": _stage,') |
| self._emit('"strength": step_strength,') |
| self._emit('"result": merge_result,') |
| self._indent -= 1 |
| self._emit("})") |
| self._emit("merged_stages.append(_stage)") |
| self._emit("") |
|
|
| |
| self._emit('if merge_result.get("final_checkpoint"):') |
| self._indent += 1 |
| self._emit(f'models["{target}"]["checkpoint"] = merge_result["final_checkpoint"]') |
| self._emit("pre_score = quick_canary(prev_ckpt) if prev_ckpt else None") |
| self._emit("post_score = quick_canary(merge_result['final_checkpoint'])") |
| self._emit("if pre_score and post_score < 0.9 * pre_score:") |
| self._indent += 1 |
| self._emit('print(f"[td_lang] WARNING: quick canary degradation detected (pre={pre_score:.1f}, post={post_score:.1f})")') |
| self._indent -= 1 |
| self._indent -= 1 |
| self._emit(f'print(f"[td_lang] Fused {{fuse_source}} (strength={{step_strength}})")') |
| self._indent -= 1 |
|
|
| self._emit("") |
| self._emit(f'results["{target}_fuse"] = fuse_results') |
| self._emit("") |
|
|
| |
| self._emit(f'lineage["{target}"]["operations"].append({{') |
| self._indent += 1 |
| self._emit('"op": "fuse",') |
| self._emit(f'"sources": {sources},') |
| self._emit(f'"method": "{method}",') |
| self._emit(f'"strategy": "{strategy}",') |
| self._emit(f'"n_models": {n},') |
| self._emit('"timestamp": datetime.now().isoformat(),') |
| self._indent -= 1 |
| self._emit("})") |
| self._emit(f'print("[td_lang] FUSE complete - {n} models merged into {target}")') |
|
|
| def _emit_absorb(self, cmd: AbsorbCmd) -> None: |
| """ABSORB - simplified single-model merge. |
| |
| One-liner shortcut: absorb "model" into target [strength 0.5] |
| Wraps the merge logic with sensible defaults. |
| """ |
| source = cmd.source |
| target = cmd.target |
| strength = cmd.strength |
|
|
| self._emit(f'print("[td_lang] ABSORB - merging {source} into {target} (strength={strength})")') |
| self._emit(f'prev_ckpt = models.get("{target}", {{}}).get("checkpoint")') |
| self._emit("") |
|
|
| |
| self._emit(f'_source_ref = "{source}"') |
| self._emit("_stage = None") |
| self._emit("_arch = None") |
| self._emit("for _src in SOURCES:") |
| self._indent += 1 |
| self._emit('if _src.hf_id == _source_ref or _src.name.lower() in _source_ref.lower():') |
| self._indent += 1 |
| self._emit('_stage = _src.name.lower().split("-")[0]') |
| self._emit("_arch = getattr(_src, 'architecture', 'unknown')") |
| self._emit("break") |
| self._indent -= 1 |
| self._indent -= 1 |
| self._emit("") |
|
|
| self._emit("if _stage is None:") |
| self._indent += 1 |
| self._emit(f'print(f"[td_lang] WARNING: {{_source_ref}} not in SOURCES. Using direct ref.")') |
| self._emit(f'_stage = _source_ref.split("/")[-1].lower().replace("-", "_")[:20]') |
| self._emit('_arch = "unknown"') |
| self._indent -= 1 |
| self._emit("") |
|
|
| |
| self._emit("strengths = []") |
| self._emit("if str(strength).lower() == 'auto':") |
| self._indent += 1 |
| self._emit("strengths = [0.2, 0.4, 0.6, 0.8]") |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit("strengths = [strength]") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("best_score = -1") |
| self._emit("best_result = None") |
| self._emit("best_strength = strengths[0]") |
| self._emit("for s in strengths:") |
| self._indent += 1 |
| self._emit("cfg = MergeConfig()") |
| self._emit("# choose method by architecture") |
| self._emit("cfg.merge_method = 'slerp' if _arch == getattr(TARGET, 'architecture', 'unknown') else 'transport'") |
| self._emit("for _src in SOURCES:") |
| self._indent += 1 |
| self._emit("if _src.hf_id == _source_ref or _src.name.lower() in _source_ref.lower():") |
| self._indent += 1 |
| self._emit(" _src.merge_alpha = s") |
| self._indent -= 1 |
| self._emit("break") |
| self._indent -= 1 |
| self._emit("merge_result = run_pipeline([_stage], cfg)") |
| self._emit("ckpt = merge_result.get('final_checkpoint')") |
| self._emit("score = quick_canary(ckpt) if ckpt else -1") |
| self._emit("if score > best_score:") |
| self._indent += 1 |
| self._emit("best_score = score") |
| self._emit("best_result = merge_result") |
| self._emit("best_strength = s") |
| self._indent -= 1 |
| self._indent -= 1 |
| self._emit("") |
| self._emit("merge_result = best_result") |
| self._emit("cfg_strength = best_strength") |
| self._emit("merged_stages.append(_stage)") |
| self._emit("") |
|
|
| |
| self._emit('if merge_result and merge_result.get("final_checkpoint"):') |
| self._indent += 1 |
| self._emit(f'models["{target}"]["checkpoint"] = merge_result["final_checkpoint"]') |
| self._emit("pre_score = quick_canary(prev_ckpt) if prev_ckpt else None") |
| self._emit("post_score = quick_canary(merge_result['final_checkpoint']) if merge_result else None") |
| self._emit("if pre_score and post_score and post_score < 0.9 * pre_score:") |
| self._indent += 1 |
| self._emit('print(f"[td_lang] WARNING: canary degradation (pre={pre_score:.1f}, post={post_score:.1f})")') |
| self._indent -= 1 |
| self._indent -= 1 |
| self._emit(f'results["{target}_absorb"] = merge_result') |
| self._emit("") |
|
|
| |
| self._emit(f'lineage["{target}"]["operations"].append({{') |
| self._indent += 1 |
| self._emit('"op": "absorb",') |
| self._emit(f'"source": "{source}",') |
| self._emit(f'"strength": {strength},') |
| self._emit('"method": "auto" if str(strength).lower()=="auto" else "transport",') |
| self._emit('"timestamp": datetime.now().isoformat(),') |
| self._indent -= 1 |
| self._emit("})") |
| self._emit(f'print("[td_lang] ABSORB complete - {source} merged into {target}")') |
|
|
| |
|
|
| def _emit_data_contract(self, dc: DataContractBlock) -> None: |
| """Emit data contract validation - checked at synth/train time. |
| |
| From ForgeSpec 2.0 (test_17): data contracts enforce schema on training data. |
| Required fields, minimum samples, max perplexity. |
| """ |
| self._emit("# Data Contract (Phase 4, ForgeSpec 2.0)") |
| self._emit("data_contract = {") |
| self._indent += 1 |
| self._emit(f'"required_fields": {dc.required_fields},') |
| if dc.min_samples is not None: |
| self._emit(f'"min_samples": {dc.min_samples},') |
| if dc.max_perplexity is not None: |
| self._emit(f'"max_perplexity": {dc.max_perplexity},') |
| self._indent -= 1 |
| self._emit("}") |
| self._emit("") |
| self._emit("def validate_data_contract(data_path, contract):") |
| self._indent += 1 |
| self._emit('"""Check training data against data contract."""') |
| self._emit("import json") |
| self._emit("errors = []") |
| self._emit("samples = []") |
| self._emit("with open(data_path) as f:") |
| self._indent += 1 |
| self._emit("for line_num, line in enumerate(f, 1):") |
| self._indent += 1 |
| self._emit("line = line.strip()") |
| self._emit("if not line: continue") |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("sample = json.loads(line)") |
| self._emit("samples.append(sample)") |
| self._emit('for field in contract.get("required_fields", []):') |
| self._indent += 1 |
| self._emit("if field not in sample:") |
| self._indent += 1 |
| self._emit('errors.append(f"Line {line_num}: missing required field \'{field}\'")') |
| self._indent -= 2 |
| self._indent -= 1 |
| self._emit("except json.JSONDecodeError:") |
| self._indent += 1 |
| self._emit('errors.append(f"Line {line_num}: invalid JSON")') |
| self._indent -= 2 |
| self._indent -= 1 |
| self._emit('min_s = contract.get("min_samples")') |
| self._emit("if min_s and len(samples) < min_s:") |
| self._indent += 1 |
| self._emit('errors.append(f"Need {min_s} samples, got {len(samples)}")') |
| self._indent -= 1 |
| self._emit("if errors:") |
| self._indent += 1 |
| self._emit('print("[td_lang] DATA CONTRACT VIOLATIONS:")') |
| self._emit("for e in errors[:10]:") |
| self._indent += 1 |
| self._emit('print(f" - {e}")') |
| self._indent -= 1 |
| self._emit("if len(errors) > 10:") |
| self._indent += 1 |
| self._emit('print(f" ... and {len(errors)-10} more")') |
| self._indent -= 1 |
| self._emit('raise ValueError(f"Data contract failed: {len(errors)} violations")') |
| self._indent -= 1 |
| self._emit('print(f"[td_lang] Data contract OK: {len(samples)} samples, all fields present.")') |
| self._emit("return samples") |
| self._indent -= 1 |
| self._emit("") |
|
|
| def _emit_reward_contract(self, rc: RewardContractBlock) -> None: |
| """Emit reward contract - enforced during GRPO training. |
| |
| From test_16: verified rewards only, no learned reward model. |
| """ |
| self._emit("# Reward Contract (Phase 4, ForgeSpec 2.0)") |
| self._emit("reward_contract = {") |
| self._indent += 1 |
| self._emit(f'"verifiers": {rc.verifiers},') |
| if rc.min_reward is not None: |
| self._emit(f'"min_reward": {rc.min_reward},') |
| self._indent -= 1 |
| self._emit("}") |
| self._emit('print(f"[td_lang] Reward contract: verifiers={reward_contract[\'verifiers\']}")') |
| self._emit("") |
|
|
| def _emit_snapshot(self, cmd: SnapshotCmd, program: TDProgram) -> None: |
| """SNAPSHOT - content-hashed model state for artifact lineage. |
| |
| From ForgeSpec 2.0 (test_17): every model state gets a content-addressed hash. |
| Directory contains: model weights/adapters, eval report, prune spec, manifest. |
| """ |
| alias = cmd.target |
| output_dir = cmd.output or "td_lang_outputs/snapshots" |
|
|
| self._emit(f'print("[td_lang] SNAPSHOT - saving content-hashed state for {alias}")') |
| self._emit("import hashlib, json # time already imported at top") |
| self._emit(f'snap_model = models["{alias}"]') |
| self._emit("") |
|
|
| |
| self._emit("# Content hash from model parameters (first 10 layers for speed)") |
| self._emit("hasher = hashlib.sha256()") |
| self._emit("param_count = 0") |
| self._emit("if hasattr(snap_model, 'state_dict'):") |
| self._indent += 1 |
| self._emit("for name, param in list(snap_model.state_dict().items())[:50]:") |
| self._indent += 1 |
| self._emit("hasher.update(param.cpu().numpy().tobytes()[:1024])") |
| self._emit("param_count += param.numel()") |
| self._indent -= 1 |
| self._indent -= 1 |
| self._emit("elif isinstance(snap_model, dict):") |
| self._indent += 1 |
| self._emit("for k, v in snap_model.items():") |
| self._indent += 1 |
| self._emit("hasher.update(str(v).encode()[:256])") |
| self._indent -= 1 |
| self._indent -= 1 |
| self._emit("content_hash = hasher.hexdigest()[:16]") |
| self._emit(f'snap_dir = os.path.join(output_dir, "{output_dir}", f"{alias}_{{content_hash}}")') |
| self._emit("os.makedirs(snap_dir, exist_ok=True)") |
| self._emit("") |
|
|
| |
| self._emit("# Snapshot manifest - full provenance record") |
| self._emit("snap_manifest = {") |
| self._indent += 1 |
| self._emit(f'"alias": "{alias}",') |
| self._emit('"content_hash": content_hash,') |
| self._emit('"param_count": param_count,') |
| self._emit('"timestamp": datetime.now().isoformat(),') |
| self._emit(f'"lineage": lineage.get("{alias}", {{}}),') |
| self._emit(f'"eval_results": results.get("{alias}_eval", None),') |
| self._emit(f'"diagnose_results": results.get("{alias}_diagnose", None),') |
| self._indent -= 1 |
| self._emit("}") |
| self._emit("") |
|
|
| |
| self._emit("if hasattr(snap_model, 'peft_config'):") |
| self._indent += 1 |
| self._emit('adapter_dir = os.path.join(snap_dir, "adapters")') |
| self._emit("snap_model.save_pretrained(adapter_dir)") |
| self._emit('snap_manifest["has_adapters"] = True') |
| self._emit('snap_manifest["adapter_dir"] = adapter_dir') |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit(f'ckpt = models.get("{alias}", {{}}).get("checkpoint") if isinstance(models.get("{alias}"), dict) else None') |
| self._emit('snap_manifest["has_adapters"] = False') |
| self._emit('snap_manifest["checkpoint_ref"] = str(ckpt) if ckpt else "in_memory"') |
| self._indent -= 1 |
| self._emit("") |
|
|
| |
| self._emit('manifest_path = os.path.join(snap_dir, "snapshot_manifest.json")') |
| self._emit('with open(manifest_path, "w") as f:') |
| self._indent += 1 |
| self._emit("json.dump(snap_manifest, f, indent=2, default=str)") |
| self._indent -= 1 |
| self._emit(f'print(f"[td_lang] Snapshot saved: {{snap_dir}}")') |
| self._emit(f'print(f"[td_lang] Content hash: {{content_hash}}")') |
| self._emit("") |
|
|
| |
| self._emit(f'lineage.setdefault("{alias}", {{"operations": []}})["operations"].append({{') |
| self._indent += 1 |
| self._emit('"op": "snapshot",') |
| self._emit('"content_hash": content_hash,') |
| self._emit('"snap_dir": snap_dir,') |
| self._emit('"timestamp": datetime.now().isoformat(),') |
| self._indent -= 1 |
| self._emit("})") |
|
|
| def _emit_report(self, cmd: ReportCmd, program: TDProgram) -> None: |
| """REPORT - economics report for the run. |
| |
| Tracks GPU hours, cost, tokens, time per command. |
| From test_17 ForgeSpec 2.0: economics reports for cost tracking. |
| """ |
| output = cmd.output or "economics_report.json" |
|
|
| self._emit('print("[td_lang] REPORT - generating economics report")') |
| self._emit("elapsed = time.time() - start_time") |
| self._emit("") |
| self._emit("report = {") |
| self._indent += 1 |
| self._emit('"td_lang_version": "0.2.0",') |
| self._emit('"timestamp": datetime.now().isoformat(),') |
| self._emit('"elapsed_seconds": round(elapsed, 2),') |
| self._emit('"elapsed_minutes": round(elapsed / 60, 2),') |
| self._emit(f'"gpu_hourly_rate": {self.GPU_HOURLY},') |
| self._emit('"estimated_cost": round(elapsed / 3600 * GPU_HOURLY, 2),') |
| self._emit('"models_loaded": list(models.keys()),') |
| self._emit('"merged_stages": merged_stages,') |
| self._emit('"lineage_summary": {},') |
| self._indent -= 1 |
| self._emit("}") |
| self._emit("") |
|
|
| |
| self._emit("for alias, lin in lineage.items():") |
| self._indent += 1 |
| self._emit("ops = lin.get('operations', [])") |
| self._emit("op_counts = {}") |
| self._emit("for op in ops:") |
| self._indent += 1 |
| self._emit("op_type = op.get('op', 'unknown')") |
| self._emit("op_counts[op_type] = op_counts.get(op_type, 0) + 1") |
| self._indent -= 1 |
| self._emit('report["lineage_summary"][alias] = {') |
| self._indent += 1 |
| self._emit('"total_operations": len(ops),') |
| self._emit('"operation_counts": op_counts,') |
| self._indent -= 1 |
| self._emit("}") |
| self._indent -= 1 |
| self._emit("") |
|
|
| |
| self._emit("eval_summary = {}") |
| self._emit("for key, val in results.items():") |
| self._indent += 1 |
| self._emit('if "_eval" in key:') |
| self._indent += 1 |
| self._emit("if isinstance(val, dict):") |
| self._indent += 1 |
| self._emit("eval_summary[key] = {k: v for k, v in val.items() if k != 'raw'}") |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit('eval_summary[key] = str(val)[:200]') |
| self._indent -= 2 |
| self._indent -= 1 |
| self._emit('report["eval_summary"] = eval_summary') |
| self._emit("") |
|
|
| |
| if program.data_contract: |
| self._emit('report["data_contract"] = data_contract') |
| if program.reward_contract: |
| self._emit('report["reward_contract"] = reward_contract') |
|
|
| |
| self._emit(f'report_path = Path("{output}")') |
| self._emit("report_path.parent.mkdir(parents=True, exist_ok=True)") |
| self._emit('with open(report_path, "w") as f:') |
| self._indent += 1 |
| self._emit("json.dump(report, f, indent=2, default=str)") |
| self._indent -= 1 |
| self._emit(f'print(f"[td_lang] Economics report saved to {{report_path}}")') |
| self._emit('print(f"[td_lang] Time: {report[\'elapsed_minutes\']} min")') |
| self._emit('print(f"[td_lang] Estimated cost: ${report[\'estimated_cost\']}")') |
| self._emit('print(f"[td_lang] Models: {report[\'models_loaded\']}")') |
|
|
| |
|
|
| def _emit_setup(self, setup: SetupBlock) -> None: |
| """SETUP - auto-install dependencies and configure environment. |
| |
| Runs at script start: pip install, HF token, ntfy config. |
| """ |
| self._emit("# ========== SETUP (Phase 8 - Autopilot) ==========") |
| self._emit('print("[td_lang] SETUP - configuring environment...")') |
| self._emit("") |
|
|
| |
| if setup.pip_packages: |
| pkg_str = " ".join(setup.pip_packages) |
| self._emit(f"# Install dependencies") |
| self._emit(f'_pip_pkgs = "{pkg_str}"') |
| self._emit("import subprocess as _sp") |
| self._emit('print(f"[td_lang] Installing: {_pip_pkgs}")') |
| self._emit("try:") |
| self._indent += 1 |
| self._emit('_sp.check_call([sys.executable, "-m", "pip", "install", "--break-system-packages", "-q"]') |
| self._emit(f' + _pip_pkgs.split())') |
| self._emit('print("[td_lang] Dependencies installed.")') |
| self._indent -= 1 |
| self._emit("except Exception as e:") |
| self._indent += 1 |
| self._emit('print(f"[td_lang] WARNING: pip install failed: {e}")') |
| self._emit('print("[td_lang] Continuing anyway - packages may already be installed.")') |
| self._indent -= 1 |
| self._emit("") |
|
|
| |
| if setup.hf_token: |
| self._emit("# HuggingFace authentication") |
| if setup.hf_token == "env": |
| self._emit('_hf_token = os.environ.get("HF_TOKEN", "")') |
| else: |
| self._emit(f'_hf_token = "{setup.hf_token}"') |
| self._emit("if _hf_token:") |
| self._indent += 1 |
| self._emit("os.environ['HF_TOKEN'] = _hf_token") |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("from huggingface_hub import login") |
| self._emit("login(token=_hf_token, add_to_git_credential=False)") |
| self._emit('print("[td_lang] HuggingFace authenticated.")') |
| self._indent -= 1 |
| self._emit("except Exception:") |
| self._indent += 1 |
| self._emit('print("[td_lang] HF login via huggingface_hub failed, using env var.")') |
| self._indent -= 1 |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit('print("[td_lang] WARNING: No HF_TOKEN found. Gated models may fail to download.")') |
| self._indent -= 1 |
| self._emit("") |
|
|
| |
| if setup.notify_url: |
| self._emit("# Notification endpoint (ntfy.sh)") |
| self._emit(f'NTFY_URL = "{setup.notify_url}"') |
| self._emit("") |
| self._emit("def td_notify(msg):") |
| self._indent += 1 |
| self._emit('"""Send notification via ntfy.sh."""') |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("import urllib.request") |
| self._emit("req = urllib.request.Request(") |
| self._indent += 1 |
| self._emit('f"https://{NTFY_URL}" if not NTFY_URL.startswith("http") else NTFY_URL,') |
| self._emit("data=msg.encode(),") |
| self._emit('method="POST",') |
| self._indent -= 1 |
| self._emit(")") |
| self._emit("urllib.request.urlopen(req, timeout=10)") |
| self._emit('print(f"[td_lang] Notified: {msg}")') |
| self._indent -= 1 |
| self._emit("except Exception as e:") |
| self._indent += 1 |
| self._emit('print(f"[td_lang] Notify failed: {e}")') |
| self._indent -= 1 |
| self._indent -= 1 |
| else: |
| self._emit("def td_notify(msg):") |
| self._indent += 1 |
| self._emit('print(f"[td_lang] (no ntfy configured) {msg}")') |
| self._indent -= 1 |
|
|
| self._emit("") |
| self._emit('td_notify("TD pipeline starting...")') |
| self._emit('print("[td_lang] SETUP complete.")') |
| self._emit("") |
|
|
| def _emit_on_error(self, on_error: OnErrorBlock, program: TDProgram) -> None: |
| """ON_ERROR - wrap each step in retry/fallback logic. |
| |
| Emits a td_safe_run() helper that wraps any function call with: |
| - Retry N times on failure |
| - Fallback strategies (reduce batch, skip, snapshot+stop) |
| - Optional ntfy notification on error |
| """ |
| self._emit("# ========== ON_ERROR (Phase 8 - Crash Recovery) ==========") |
| self._emit(f"TD_MAX_RETRIES = {on_error.retry}") |
| self._emit(f'TD_FALLBACK = "{on_error.fallback}"') |
| self._emit(f"TD_NOTIFY_ON_ERROR = {on_error.notify}") |
| self._emit("") |
| self._emit("def td_safe_run(step_name, fn, *args, **kwargs):") |
| self._indent += 1 |
| self._emit('"""Run a step with retry and fallback on error."""') |
| self._emit("import traceback") |
| self._emit("for attempt in range(1, TD_MAX_RETRIES + 1):") |
| self._indent += 1 |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("return fn(*args, **kwargs)") |
| self._indent -= 1 |
| self._emit("except torch.cuda.OutOfMemoryError as oom:") |
| self._indent += 1 |
| self._emit('print(f"[td_lang] OOM on {step_name} (attempt {attempt}/{TD_MAX_RETRIES})")') |
| self._emit("torch.cuda.empty_cache()") |
| self._emit("import gc; gc.collect()") |
| self._emit('if TD_FALLBACK == "reduce_batch":') |
| self._indent += 1 |
| self._emit('print("[td_lang] Reducing batch size and retrying...")') |
| self._emit('os.environ["TD_REDUCE_BATCH"] = "1"') |
| self._indent -= 1 |
| self._emit('elif TD_FALLBACK == "skip":') |
| self._indent += 1 |
| self._emit('print(f"[td_lang] Skipping {step_name}")') |
| self._emit("return None") |
| self._indent -= 1 |
| self._emit('elif TD_FALLBACK == "snapshot_and_stop":') |
| self._indent += 1 |
| self._emit('print(f"[td_lang] OOM - saving snapshot and stopping.")') |
| self._emit("if TD_NOTIFY_ON_ERROR:") |
| self._indent += 1 |
| self._emit('td_notify(f"OOM on {step_name} - snapshot saved, stopping.")') |
| self._indent -= 1 |
| self._emit("raise") |
| self._indent -= 2 |
| self._emit("except Exception as e:") |
| self._indent += 1 |
| self._emit('print(f"[td_lang] Error on {step_name} (attempt {attempt}/{TD_MAX_RETRIES}): {e}")') |
| self._emit("traceback.print_exc()") |
| self._emit("if attempt == TD_MAX_RETRIES:") |
| self._indent += 1 |
| self._emit("if TD_NOTIFY_ON_ERROR:") |
| self._indent += 1 |
| self._emit('td_notify(f"FAILED: {step_name} after {TD_MAX_RETRIES} retries - {e}")') |
| self._indent -= 1 |
| self._emit('if TD_FALLBACK == "skip":') |
| self._indent += 1 |
| self._emit("return None") |
| self._indent -= 1 |
| self._emit("raise") |
| self._indent -= 2 |
| self._indent -= 1 |
| self._indent -= 1 |
| self._emit("") |
|
|
| def _emit_notify(self, cmd: NotifyCmd, program: TDProgram) -> None: |
| """NOTIFY - send message via ntfy.sh.""" |
| msg = cmd.message.replace('"', '\\"') |
| self._emit(f'td_notify("{msg}")') |
|
|
| def _emit_save(self, cmd: SaveCmd, program: TDProgram) -> None: |
| """SAVE - upload model to cloud storage via rclone. |
| |
| Uses rclone to copy model checkpoint/adapters to Google Drive or any remote. |
| """ |
| alias = cmd.target |
| dest = cmd.destination |
|
|
| self._emit(f'print("[td_lang] SAVE - uploading {alias} to {dest}")') |
| self._emit("") |
|
|
| |
| self._emit(f'_save_model = models.get("{alias}", {{}})') |
| self._emit('_save_path = _save_model.get("checkpoint") if isinstance(_save_model, dict) else None') |
| self._emit("") |
|
|
| |
| self._emit('if hasattr(_save_model, "peft_config") or (isinstance(_save_model, dict) and _save_model.get("has_adapters")):') |
| self._indent += 1 |
| self._emit(f'_adapter_dir = f"td_lang_outputs/{alias}_save_adapters"') |
| self._emit("os.makedirs(_adapter_dir, exist_ok=True)") |
| self._emit("if hasattr(_save_model, 'save_pretrained'):") |
| self._indent += 1 |
| self._emit("_save_model.save_pretrained(_adapter_dir)") |
| self._indent -= 1 |
| self._emit("_save_path = _adapter_dir") |
| self._indent -= 1 |
| self._emit("") |
|
|
| |
| self._emit("if _save_path:") |
| self._indent += 1 |
| self._emit(f'_rclone_cmd = ["rclone", "copy", str(_save_path), "{dest}", "--progress"]') |
| self._emit('_rclone_str = " ".join(_rclone_cmd)') |
| self._emit('print(f"[td_lang] Running: {_rclone_str}")') |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("import subprocess as _sp") |
| self._emit("_sp.check_call(_rclone_cmd)") |
| self._emit(f'print("[td_lang] SAVE complete - {alias} uploaded to {dest}")') |
| self._emit(f'td_notify("Model {alias} saved to {dest}")') |
| self._indent -= 1 |
| self._emit("except FileNotFoundError:") |
| self._indent += 1 |
| self._emit('print("[td_lang] ERROR: rclone not found. Install it: curl https://rclone.org/install.sh | sudo bash")') |
| self._emit('print("[td_lang] Then configure: rclone config (add Google Drive remote)")') |
| self._emit(f'td_notify("SAVE FAILED: rclone not installed")') |
| self._indent -= 1 |
| self._emit("except Exception as e:") |
| self._indent += 1 |
| self._emit('print(f"[td_lang] SAVE error: {e}")') |
| self._emit(f'td_notify(f"SAVE FAILED: {{e}}")') |
| self._indent -= 1 |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit(f'print("[td_lang] WARNING: No checkpoint found for {alias}. Nothing to save.")') |
| self._emit(f'print("[td_lang] Run commit or snapshot first to create a checkpoint.")') |
| self._indent -= 1 |
|
|
| |
| self._emit("") |
| self._emit(f'lineage.setdefault("{alias}", {{"operations": []}})["operations"].append({{') |
| self._indent += 1 |
| self._emit('"op": "save",') |
| self._emit(f'"destination": "{dest}",') |
| self._emit('"timestamp": datetime.now().isoformat(),') |
| self._indent -= 1 |
| self._emit("})") |
|
|
| |
| def _emit_schedule(self, cmd: ScheduleCmd, program: TDProgram) -> None: |
| """SCHEDULE - time-based command execution. |
| |
| Patterns: |
| "every 6h" → loop with time.sleep(6*3600) |
| "every 30m" → loop with time.sleep(30*60) |
| "at 02:00" → wait until that time, run once |
| "after 30m" → sleep then run once |
| """ |
| timing = cmd.timing.strip() |
| self._emit(f'print("[td_lang] SCHEDULE - timing: {timing}")') |
| self._emit("import time as _time") |
| self._emit("from datetime import datetime as _dt, timedelta as _td") |
| self._emit("") |
|
|
| if timing.startswith("every "): |
| |
| interval_str = timing[6:].strip() |
| self._emit(f'_interval_str = "{interval_str}"') |
| self._emit("if _interval_str.endswith('h'):") |
| self._indent += 1 |
| self._emit("_interval_secs = int(_interval_str[:-1]) * 3600") |
| self._indent -= 1 |
| self._emit("elif _interval_str.endswith('m'):") |
| self._indent += 1 |
| self._emit("_interval_secs = int(_interval_str[:-1]) * 60") |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit("_interval_secs = int(_interval_str) * 3600 # default to hours") |
| self._indent -= 1 |
| self._emit('print(f"[td_lang] Running every {_interval_secs}s ({_interval_str}). Ctrl+C to stop.")') |
| self._emit("_sched_iter = 0") |
| self._emit("while True:") |
| self._indent += 1 |
| self._emit("_sched_iter += 1") |
| self._emit('print(f"[td_lang] Schedule iteration {_sched_iter} starting at {_dt.now()}")') |
| for body_cmd in cmd.body: |
| self._emit_cmd(body_cmd, program) |
| self._emit('print(f"[td_lang] Iteration {_sched_iter} done. Sleeping {_interval_secs}s...")') |
| self._emit("_time.sleep(_interval_secs)") |
| self._indent -= 1 |
|
|
| elif timing.startswith("at "): |
| |
| time_str = timing[3:].strip() |
| self._emit(f'_target_time = _dt.strptime("{time_str}", "%H:%M").time()') |
| self._emit("_now = _dt.now()") |
| self._emit("_target = _dt.combine(_now.date(), _target_time)") |
| self._emit("if _target <= _now:") |
| self._indent += 1 |
| self._emit("_target += _td(days=1) # schedule for tomorrow if time already passed") |
| self._indent -= 1 |
| self._emit("_wait = (_target - _now).total_seconds()") |
| self._emit('print(f"[td_lang] Waiting {_wait:.0f}s until {_target}...")') |
| self._emit("_time.sleep(_wait)") |
| self._emit('print(f"[td_lang] Scheduled time reached: {_dt.now()}")') |
| for body_cmd in cmd.body: |
| self._emit_cmd(body_cmd, program) |
|
|
| elif timing.startswith("after "): |
| |
| delay_str = timing[6:].strip() |
| self._emit(f'_delay_str = "{delay_str}"') |
| self._emit("if _delay_str.endswith('h'):") |
| self._indent += 1 |
| self._emit("_delay_secs = int(_delay_str[:-1]) * 3600") |
| self._indent -= 1 |
| self._emit("elif _delay_str.endswith('m'):") |
| self._indent += 1 |
| self._emit("_delay_secs = int(_delay_str[:-1]) * 60") |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit("_delay_secs = int(_delay_str) * 3600") |
| self._indent -= 1 |
| self._emit('print(f"[td_lang] Waiting {_delay_secs}s before running...")') |
| self._emit("_time.sleep(_delay_secs)") |
| self._emit('print(f"[td_lang] Delay complete. Running scheduled commands...")') |
| for body_cmd in cmd.body: |
| self._emit_cmd(body_cmd, program) |
|
|
| else: |
| self._emit(f'print("[td_lang] WARNING: Unknown schedule pattern: {timing}")') |
| self._emit('print("[td_lang] Supported: every Nh/Nm, at HH:MM, after Nh/Nm")') |
|
|
| |
| def _emit_log_setup(self, log_block: LogBlock) -> None: |
| """LOG - redirect all output to a file AND console.""" |
| filepath = log_block.filepath |
| self._emit(f'# Log setup - everything goes to "{filepath}" AND console') |
| self._emit("import sys as _sys") |
| self._emit("") |
| self._emit("class _TeeLogger:") |
| self._indent += 1 |
| self._emit("def __init__(self, filepath, stream):") |
| self._indent += 1 |
| self._emit("self.stream = stream") |
| self._emit("self.file = open(filepath, 'w')") |
| self._indent -= 1 |
| self._emit("def write(self, data):") |
| self._indent += 1 |
| self._emit("self.stream.write(data)") |
| self._emit("self.file.write(data)") |
| self._emit("self.file.flush()") |
| self._indent -= 1 |
| self._emit("def flush(self):") |
| self._indent += 1 |
| self._emit("self.stream.flush()") |
| self._emit("self.file.flush()") |
| self._indent -= 1 |
| self._indent -= 1 |
| self._emit("") |
| self._emit(f'_sys.stdout = _TeeLogger("{filepath}", _sys.stdout)') |
| self._emit(f'_sys.stderr = _TeeLogger("{filepath}", _sys.stderr)') |
| self._emit(f'print("[td_lang] Logging to: {filepath}")') |
| self._emit("") |
|
|
| def _emit_download(self, cmd: DownloadCmd) -> None: |
| """DOWNLOAD - pull a dataset from HuggingFace.""" |
| self._emit(f'print("[td_lang] Downloading dataset: {cmd.dataset} (split: {cmd.split})")') |
| self._emit("from datasets import load_dataset") |
| self._emit(f'_dl_dataset = load_dataset("{cmd.dataset}", split="{cmd.split}")') |
| self._emit(f'print(f"[td_lang] Downloaded {{len(_dl_dataset)}} samples")') |
| self._emit("") |
| self._emit("# Save locally as JSONL for later use") |
| self._emit(f'_dl_path = "td_lang_outputs/{cmd.alias}.jsonl"') |
| self._emit("os.makedirs(os.path.dirname(_dl_path), exist_ok=True)") |
| self._emit("_dl_dataset.to_json(_dl_path)") |
| self._emit(f'print(f"[td_lang] Saved to {{_dl_path}}")') |
| self._emit("") |
| self._emit(f'# Store reference for use in train/verify commands') |
| self._emit(f'results["{cmd.alias}_dataset"] = {{') |
| self._indent += 1 |
| self._emit(f'"path": _dl_path,') |
| self._emit(f'"source": "{cmd.dataset}",') |
| self._emit(f'"split": "{cmd.split}",') |
| self._emit(f'"n_samples": len(_dl_dataset),') |
| self._indent -= 1 |
| self._emit("}") |
| self._emit("") |
|
|
| def _emit_compare(self, cmd: CompareCmd) -> None: |
| """COMPARE - test source model vs merged model on same questions. |
| |
| This is the knowledge retention test: |
| 1. Load source model, ask it N questions, record answers |
| 2. Ask merged model same questions |
| 3. Compare - did merged model retain what source knew? |
| """ |
| alias = cmd.target |
| source = cmd.source |
| n = cmd.questions |
|
|
| self._emit(f'print("[td_lang] COMPARE - testing if {alias} retained knowledge from {source}")') |
| self._emit(f'print("[td_lang] Testing {n} questions on both models...")') |
| self._emit("") |
| self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer") |
| self._emit("import torch, random") |
| self._emit("") |
| self._emit("# Test questions across multiple domains") |
| self._emit("_compare_questions = [") |
| self._indent += 1 |
| self._emit("# Math") |
| self._emit('"What is 17 * 23?", "What is the square root of 144?", "What is 256 + 389?",') |
| self._emit('"Solve: 3x + 7 = 28", "What is 15% of 300?",') |
| self._emit("# Knowledge") |
| self._emit('"What is the capital of Japan?", "Who wrote Romeo and Juliet?",') |
| self._emit('"What is the speed of light in m/s?", "What element has atomic number 6?",') |
| self._emit('"What is the largest planet in our solar system?",') |
| self._emit("# Reasoning") |
| self._emit('"If A is taller than B, and B is taller than C, who is tallest?",') |
| self._emit('"A bat and ball cost $1.10. The bat costs $1 more than the ball. What does the ball cost?",') |
| self._emit("# Code") |
| self._emit('"Write a Python function to reverse a string.",') |
| self._emit('"What does len([1,2,3]) return in Python?",') |
| self._emit("# Language") |
| self._emit('"Translate to French: Hello, how are you?",') |
| self._emit('"What is the past tense of run?",') |
| self._indent -= 1 |
| self._emit("]") |
| self._emit(f"_n_compare = min({n}, len(_compare_questions))") |
| self._emit("_compare_questions = random.sample(_compare_questions, _n_compare)") |
| self._emit("") |
|
|
| |
| self._emit(f'print("[td_lang] Loading source model: {source}...")') |
| self._emit(f'_src_tok = AutoTokenizer.from_pretrained("{source}")') |
| self._emit(f'_src_model = _load_model_smart("{source}", torch_dtype=torch.bfloat16, device_map="auto")') |
| self._emit("_src_model.eval()") |
| self._emit("") |
| self._emit("_src_answers = {}") |
| self._emit("for q in _compare_questions:") |
| self._indent += 1 |
| self._emit('inputs = _src_tok(q, return_tensors="pt").to(_src_model.device)') |
| self._emit("with torch.no_grad():") |
| self._indent += 1 |
| self._emit("out = _src_model.generate(**inputs, max_new_tokens=128, do_sample=False)") |
| self._indent -= 1 |
| self._emit("resp = _src_tok.decode(out[0], skip_special_tokens=True)") |
| self._emit("if resp.startswith(q):") |
| self._indent += 1 |
| self._emit("resp = resp[len(q):].strip()") |
| self._indent -= 1 |
| self._emit("_src_answers[q] = resp") |
| self._indent -= 1 |
| self._emit('print(f"[td_lang] Source model: {len(_src_answers)} answers collected")') |
| self._emit("") |
| self._emit("# Free source model VRAM") |
| self._emit("del _src_model, _src_tok") |
| self._emit("import gc; gc.collect()") |
| self._emit("torch.cuda.empty_cache() if torch.cuda.is_available() else None") |
| self._emit("") |
|
|
| |
| self._emit(f'print("[td_lang] Testing merged model: {alias}...")') |
| self._emit(f'_mrg_checkpoint = models.get("{alias}", {{}}).get("checkpoint")') |
| self._emit("if not _mrg_checkpoint:") |
| self._indent += 1 |
| self._emit(f'_mrg_checkpoint = models["{alias}"]["model_ref"]') |
| self._indent -= 1 |
| self._emit("_mrg_tok = AutoTokenizer.from_pretrained(_mrg_checkpoint)") |
| self._emit('_mrg_model = _load_model_smart(_mrg_checkpoint, torch_dtype=torch.bfloat16, device_map="auto")') |
| self._emit("_mrg_model.eval()") |
| self._emit("") |
| self._emit("_mrg_answers = {}") |
| self._emit("for q in _compare_questions:") |
| self._indent += 1 |
| self._emit('inputs = _mrg_tok(q, return_tensors="pt").to(_mrg_model.device)') |
| self._emit("with torch.no_grad():") |
| self._indent += 1 |
| self._emit("out = _mrg_model.generate(**inputs, max_new_tokens=128, do_sample=False)") |
| self._indent -= 1 |
| self._emit("resp = _mrg_tok.decode(out[0], skip_special_tokens=True)") |
| self._emit("if resp.startswith(q):") |
| self._indent += 1 |
| self._emit("resp = resp[len(q):].strip()") |
| self._indent -= 1 |
| self._emit("_mrg_answers[q] = resp") |
| self._indent -= 1 |
| self._emit("") |
|
|
| |
| self._emit("# Compare: check if merged model's answers match source model") |
| self._emit("_matches = 0") |
| self._emit("_compare_details = []") |
| self._emit("for q in _compare_questions:") |
| self._indent += 1 |
| self._emit("src_ans = _src_answers.get(q, '')") |
| self._emit("mrg_ans = _mrg_answers.get(q, '')") |
| self._emit("# Fuzzy match: check if key words from source appear in merged answer") |
| self._emit("src_words = set(src_ans.lower().split()[:20])") |
| self._emit("mrg_words = set(mrg_ans.lower().split()[:20])") |
| self._emit("common = src_words & mrg_words") |
| self._emit("match = len(common) / max(len(src_words), 1) > 0.3") |
| self._emit("if match:") |
| self._indent += 1 |
| self._emit("_matches += 1") |
| self._indent -= 1 |
| self._emit('_compare_details.append({"question": q[:60], "source": src_ans[:80], "merged": mrg_ans[:80], "match": match})') |
| self._indent -= 1 |
| self._emit("") |
| self._emit("_retention = _matches / max(len(_compare_questions), 1)") |
| self._emit("print()") |
| self._emit(f'print(f"[td_lang] COMPARE RESULTS: {alias} vs {source}")') |
| self._emit('print(f" Retention: {_matches}/{len(_compare_questions)} ({_retention:.0%})")') |
| self._emit('_ret_label = "GOOD" if _retention >= 0.7 else "WARNING - significant knowledge loss" if _retention >= 0.4 else "BAD - merge lost most knowledge"') |
| self._emit('print(f" Verdict: {_ret_label}")') |
| self._emit("") |
| self._emit(f'results["{alias}_compare_{source.split("/")[-1]}"] = {{') |
| self._indent += 1 |
| self._emit('"retention": round(_retention, 3),') |
| self._emit('"matches": _matches,') |
| self._emit('"total": len(_compare_questions),') |
| self._emit('"details": _compare_details,') |
| self._indent -= 1 |
| self._emit("}") |
|
|
| if cmd.output: |
| self._emit(f'_cmp_path = Path("{cmd.output}")') |
| self._emit("_cmp_path.parent.mkdir(parents=True, exist_ok=True)") |
| self._emit(f'with open(_cmp_path, "w") as f:') |
| self._indent += 1 |
| self._emit(f'json.dump(results["{alias}_compare_{source.split("/")[-1]}"], f, indent=2, default=str)') |
| self._indent -= 1 |
| self._emit(f'print(f"[td_lang] Compare results saved to {{_cmp_path}}")') |
|
|
| self._emit("del _mrg_model, _mrg_tok") |
| self._emit("import gc; gc.collect()") |
| self._emit("") |
|
|
| def _emit_verify(self, cmd: VerifyCmd) -> None: |
| """VERIFY - check model answers against known-correct answers. |
| |
| Loads a dataset with known answers (like gsm8k, mmlu, etc), |
| runs the model, and checks if answers are correct. |
| """ |
| alias = cmd.target |
| dataset = cmd.dataset |
| n = cmd.questions |
|
|
| self._emit(f'print("[td_lang] VERIFY - checking {alias} answers on {dataset} ({n} questions)")') |
| self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer") |
| self._emit("from datasets import load_dataset") |
| self._emit("import torch, re, random") |
| self._emit("") |
|
|
| |
| self._emit(f'# Check if dataset was downloaded earlier') |
| self._emit(f'_vfy_ds_info = results.get("{dataset}_dataset", None)') |
| self._emit("if _vfy_ds_info:") |
| self._indent += 1 |
| self._emit('_vfy_ds = load_dataset("json", data_files=_vfy_ds_info["path"], split="train")') |
| self._emit('print(f"[td_lang] Using previously downloaded dataset: {_vfy_ds_info[\'path\']}")') |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit(f'try:') |
| self._indent += 1 |
| self._emit(f'_vfy_ds = load_dataset("{dataset}", split="test")') |
| self._indent -= 1 |
| self._emit("except Exception:") |
| self._indent += 1 |
| self._emit(f'_vfy_ds = load_dataset("{dataset}", split="train")') |
| self._indent -= 1 |
| self._indent -= 1 |
| self._emit("") |
| self._emit(f"_vfy_n = min({n}, len(_vfy_ds))") |
| self._emit("_vfy_indices = random.sample(range(len(_vfy_ds)), _vfy_n)") |
| self._emit("") |
|
|
| |
| self._emit(f'_vfy_checkpoint = models.get("{alias}", {{}}).get("checkpoint")') |
| self._emit("if not _vfy_checkpoint:") |
| self._indent += 1 |
| self._emit(f'_vfy_checkpoint = models["{alias}"]["model_ref"]') |
| self._indent -= 1 |
| self._emit("_vfy_tok = AutoTokenizer.from_pretrained(_vfy_checkpoint)") |
| self._emit('_vfy_model = _load_model_smart(_vfy_checkpoint, torch_dtype=torch.bfloat16, device_map="auto")') |
| self._emit("_vfy_model.eval()") |
| self._emit("") |
|
|
| |
| self._emit("# Auto-detect dataset format (gsm8k, mmlu, hellaswag, etc)") |
| self._emit("_vfy_correct = 0") |
| self._emit("_vfy_details = []") |
| self._emit("") |
| self._emit("for idx in _vfy_indices:") |
| self._indent += 1 |
| self._emit("row = _vfy_ds[idx]") |
| self._emit("") |
| self._emit("# Extract question and answer based on dataset format") |
| self._emit("question = row.get('question', row.get('prompt', row.get('input', row.get('text', ''))))") |
| self._emit("answer = row.get('answer', row.get('target', row.get('output', row.get('label', ''))))") |
| self._emit("") |
| self._emit("if not question or not answer:") |
| self._indent += 1 |
| self._emit("continue") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("# Ask the model") |
| self._emit("_vfy_prompt = f'Answer concisely: {question}'") |
| self._emit('_vfy_inputs = _vfy_tok(_vfy_prompt, return_tensors="pt").to(_vfy_model.device)') |
| self._emit("with torch.no_grad():") |
| self._indent += 1 |
| self._emit("_vfy_out = _vfy_model.generate(**_vfy_inputs, max_new_tokens=256, do_sample=False)") |
| self._indent -= 1 |
| self._emit("_vfy_response = _vfy_tok.decode(_vfy_out[0], skip_special_tokens=True)") |
| self._emit("if _vfy_response.startswith(_vfy_prompt):") |
| self._indent += 1 |
| self._emit("_vfy_response = _vfy_response[len(_vfy_prompt):].strip()") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("# Check if answer is correct (fuzzy matching)") |
| self._emit("answer_str = str(answer).strip().lower()") |
| self._emit("response_lower = _vfy_response.lower()") |
| self._emit("") |
| self._emit("# Try exact match first") |
| self._emit("correct = answer_str in response_lower") |
| self._emit("") |
| self._emit("# Try numeric match (for math datasets)") |
| self._emit("if not correct:") |
| self._indent += 1 |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("# Extract numbers from both") |
| self._emit("ans_nums = re.findall(r'-?[\\d,]+\\.?\\d*', answer_str)") |
| self._emit("resp_nums = re.findall(r'-?[\\d,]+\\.?\\d*', response_lower)") |
| self._emit("if ans_nums and resp_nums:") |
| self._indent += 1 |
| self._emit("ans_val = float(ans_nums[-1].replace(',', ''))") |
| self._emit("resp_val = float(resp_nums[-1].replace(',', ''))") |
| self._emit("correct = abs(ans_val - resp_val) < 0.01") |
| self._indent -= 1 |
| self._indent -= 1 |
| self._emit("except (ValueError, IndexError):") |
| self._indent += 1 |
| self._emit("pass") |
| self._indent -= 2 |
| self._emit("") |
| self._emit("if correct:") |
| self._indent += 1 |
| self._emit("_vfy_correct += 1") |
| self._indent -= 1 |
| self._emit('_vfy_details.append({"question": str(question)[:60], "expected": str(answer)[:40], "got": _vfy_response[:40], "correct": correct})') |
| self._indent -= 1 |
|
|
| self._emit("") |
| self._emit("_vfy_accuracy = _vfy_correct / max(_vfy_n, 1)") |
| self._emit(f'print(f"[td_lang] VERIFY RESULTS: {alias} on {dataset}")') |
| self._emit('print(f" Correct: {_vfy_correct}/{_vfy_n} ({_vfy_accuracy:.1%})")') |
| self._emit('_vfy_label = "STRONG" if _vfy_accuracy >= 0.7 else "MODERATE" if _vfy_accuracy >= 0.4 else "WEAK - needs more training"') |
| self._emit('print(f" Verdict: {_vfy_label}")') |
| self._emit("") |
| self._emit(f'results["{alias}_verify"] = {{') |
| self._indent += 1 |
| self._emit('"accuracy": round(_vfy_accuracy, 3),') |
| self._emit('"correct": _vfy_correct,') |
| self._emit('"total": _vfy_n,') |
| self._emit(f'"dataset": "{dataset}",') |
| self._emit('"details": _vfy_details,') |
| self._indent -= 1 |
| self._emit("}") |
|
|
| if cmd.output: |
| self._emit(f'_vfy_path = Path("{cmd.output}")') |
| self._emit("_vfy_path.parent.mkdir(parents=True, exist_ok=True)") |
| self._emit(f'with open(_vfy_path, "w") as f:') |
| self._indent += 1 |
| self._emit(f'json.dump(results["{alias}_verify"], f, indent=2, default=str)') |
| self._indent -= 1 |
| self._emit(f'print(f"[td_lang] Verify results saved to {{_vfy_path}}")') |
|
|
| self._emit("del _vfy_model, _vfy_tok") |
| self._emit("import gc; gc.collect()") |
| self._emit("") |
|
|
| |
| def _emit_budget_check(self, program: TDProgram) -> None: |
| budget = program.budget or BudgetBlock() |
| est_gpu = 0.0 |
| est_tokens = 0 |
| est_experiments = 0 |
|
|
| for cmd in program.commands: |
| if isinstance(cmd, LoadCmd): |
| est_gpu += 0.05 |
| elif isinstance(cmd, MergeCmd): |
| est_gpu += 2.0 |
| est_tokens += 8_000_000 |
| est_experiments += 1 |
| elif isinstance(cmd, HealCmd): |
| est_gpu += 0.5 * cmd.epochs |
| est_tokens += 1_000_000 * cmd.epochs |
| est_experiments += 1 |
| elif isinstance(cmd, EvalCmd): |
| est_gpu += 0.1 |
| est_tokens += 200_000 |
| elif isinstance(cmd, CommitCmd): |
| est_gpu += 0.01 |
| elif isinstance(cmd, DiagnoseCmd): |
| est_gpu += 0.2 |
| est_tokens += 500_000 |
| elif isinstance(cmd, SynthCmd): |
| est_gpu += 1.0 |
| est_tokens += 5_000_000 |
| est_experiments += 1 |
| elif isinstance(cmd, TrainCmd): |
| steps = cmd.steps or 64 |
| est_gpu += 0.5 + (steps / 64) * 1.5 |
| est_tokens += steps * 100_000 |
| est_experiments += 1 |
| elif isinstance(cmd, DebateCmd): |
| est_gpu += 0.3 * cmd.rounds |
| est_tokens += cmd.rounds * cmd.candidates * 200_000 |
| elif isinstance(cmd, EditCmd): |
| est_gpu += 0.5 |
| est_tokens += 500_000 |
| est_experiments += 1 |
| elif isinstance(cmd, ForkCmd): |
| est_gpu += 0.1 |
| elif isinstance(cmd, ResetCmd): |
| est_gpu += 0.15 |
| elif isinstance(cmd, PruneCmd): |
| est_gpu += 1.0 |
| est_tokens += 1_000_000 |
| est_experiments += 1 |
| elif isinstance(cmd, FuseCmd): |
| n = len(cmd.sources) |
| est_gpu += 2.0 * n |
| est_tokens += 8_000_000 * n |
| est_experiments += n |
| elif isinstance(cmd, AbsorbCmd): |
| est_gpu += 2.0 |
| est_tokens += 8_000_000 |
| est_experiments += 1 |
| elif isinstance(cmd, RepeatBlock): |
| |
| body_est = 1.0 * len(cmd.body) |
| est_gpu += body_est * cmd.count |
| est_experiments += cmd.count |
| elif isinstance(cmd, IfBlock): |
| est_gpu += 0.5 |
| elif isinstance(cmd, SnapshotCmd): |
| est_gpu += 0.05 |
| elif isinstance(cmd, ReportCmd): |
| est_gpu += 0.01 |
| elif isinstance(cmd, ScheduleCmd): |
| body_est = 1.0 * len(cmd.body) |
| est_gpu += body_est |
| elif isinstance(cmd, (NotifyCmd, SaveCmd)): |
| est_gpu += 0.01 |
| elif isinstance(cmd, DownloadCmd): |
| est_gpu += 0.05 |
| elif isinstance(cmd, CompareCmd): |
| est_gpu += 0.5 |
| est_tokens += 500_000 |
| elif isinstance(cmd, VerifyCmd): |
| est_gpu += 0.3 |
| est_tokens += 300_000 |
| elif isinstance(cmd, VoteCmd): |
| est_gpu += 0.1 * cmd.samples |
| est_tokens += 50_000 * cmd.samples |
| elif isinstance(cmd, PromptBlock): |
| est_gpu += 0.0 |
| elif isinstance(cmd, DistillCmd): |
| steps = cmd.steps or 200 |
| est_gpu += 1.0 + (steps / 100) * 0.5 |
| est_tokens += steps * 150_000 |
| est_experiments += 1 |
| elif isinstance(cmd, RollbackCmd): |
| est_gpu += 0.15 |
| elif isinstance(cmd, CurriculumCmd): |
| est_gpu += cmd.levels * (0.5 + (cmd.steps / 64) * 1.5) |
| est_tokens += cmd.levels * cmd.steps * 100_000 |
| est_experiments += cmd.levels |
| elif isinstance(cmd, StarCmd): |
| est_gpu += cmd.rounds * (0.3 + cmd.samples * 0.1) |
| est_tokens += cmd.rounds * cmd.samples * 200_000 |
| est_experiments += cmd.rounds |
| elif isinstance(cmd, BestOfCmd): |
| est_gpu += 0.5 + (cmd.steps / 32) * 1.0 |
| est_tokens += cmd.n * cmd.steps * 50_000 |
| est_experiments += 1 |
| elif isinstance(cmd, ExploitCmd): |
| est_gpu += 0.5 + cmd.samples * 0.05 + (cmd.steps / 32) * 1.0 |
| est_tokens += cmd.samples * 100_000 |
| est_experiments += 1 |
| elif isinstance(cmd, ArenaCmd): |
| |
| est_gpu += cmd.rounds * (0.5 + cmd.episodes * 0.02 + (cmd.steps / 32) * 1.0) |
| est_tokens += cmd.rounds * cmd.episodes * 50_000 |
| est_experiments += cmd.rounds |
| elif isinstance(cmd, ResearchArenaCmd): |
| |
| est_gpu += 0.5 + cmd.rounds * (0.5 + cmd.episodes * 0.05 + (cmd.steps / 32) * 1.0) |
| est_tokens += cmd.rounds * cmd.episodes * 80_000 |
| est_experiments += cmd.rounds |
|
|
| est_cost = est_gpu * self.GPU_HOURLY |
|
|
| self._emit("# Budget heuristic (estimated before execution)") |
| self._emit(f"est_gpu_hours = {est_gpu:.4f}") |
| self._emit(f"est_tokens = {est_tokens}") |
| self._emit(f"est_experiments = {est_experiments}") |
| self._emit("est_cost = est_gpu_hours * GPU_HOURLY") |
|
|
| if budget.max_gpu_hours is not None: |
| self._emit(f"if est_gpu_hours > {budget.max_gpu_hours}:") |
| self._indent += 1 |
| self._emit(f'raise TDBudgetError("max_gpu_hours", {budget.max_gpu_hours}, est_gpu_hours)') |
| self._indent -= 1 |
| if budget.max_cost is not None: |
| self._emit(f"if est_cost > {budget.max_cost}:") |
| self._indent += 1 |
| self._emit(f'raise TDBudgetError("max_cost", {budget.max_cost}, est_cost)') |
| self._indent -= 1 |
| if budget.max_tokens is not None: |
| self._emit(f"if est_tokens > {budget.max_tokens}:") |
| self._indent += 1 |
| self._emit(f'raise TDBudgetError("max_tokens", {budget.max_tokens}, est_tokens)') |
| self._indent -= 1 |
| if budget.max_experiments is not None: |
| self._emit(f"if est_experiments > {budget.max_experiments}:") |
| self._indent += 1 |
| self._emit(f'raise TDBudgetError("max_experiments", {budget.max_experiments}, est_experiments)') |
| self._indent -= 1 |
| self._emit('print("[td_lang] Budget check passed.")') |
| self._emit("") |
|
|
| |
|
|
| def _emit_curriculum(self, cmd: CurriculumCmd, program: TDProgram) -> None: |
| """CURRICULUM - progressive difficulty training (SEC). |
| |
| Splits problems into difficulty levels by answer length/complexity. |
| Trains on easy first, then medium, then hard. |
| Only advances when accuracy on current level exceeds 60%. |
| """ |
| self._emit(f'print("[td_lang] Curriculum training {cmd.target}: {cmd.levels} levels, {cmd.steps} steps each...")') |
| self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")') |
| self._emit("if not checkpoint:") |
| self._indent += 1 |
| self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]') |
| self._indent -= 1 |
| self._emit("") |
| self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig") |
| self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training") |
| self._emit("from trl import SFTTrainer") |
| self._emit("from datasets import load_dataset, Dataset") |
| self._emit("import torch") |
| self._emit("") |
| self._emit(f'dataset_path = "{cmd.dataset}"') |
| self._emit("if dataset_path.endswith('.jsonl'):") |
| self._indent += 1 |
| self._emit("full_data = load_dataset('json', data_files=dataset_path, split='train')") |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit("full_data = load_dataset(dataset_path, split='train')") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("# Sort by difficulty (estimated by answer length - longer answers = harder problems)") |
| self._emit("text_key = 'text' if 'text' in full_data.column_names else full_data.column_names[0]") |
| self._emit("lengths = [len(str(row.get(text_key, row.get('answer', '')))) for row in full_data]") |
| self._emit("sorted_indices = sorted(range(len(lengths)), key=lambda i: lengths[i])") |
| self._emit(f"n_levels = {cmd.levels}") |
| self._emit("chunk_size = len(sorted_indices) // n_levels") |
| self._emit("") |
| self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)") |
| self._emit("if tok.pad_token is None:") |
| self._indent += 1 |
| self._emit("tok.pad_token = tok.eos_token") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("for level in range(n_levels):") |
| self._indent += 1 |
| self._emit("start_idx = level * chunk_size") |
| self._emit("end_idx = start_idx + chunk_size if level < n_levels - 1 else len(sorted_indices)") |
| self._emit("level_indices = sorted_indices[start_idx:end_idx]") |
| self._emit("level_data = full_data.select(level_indices)") |
| self._emit('_level_label = ["easy", "medium", "hard", "expert"][min(level, 3)]') |
| self._emit('print(f"[td_lang] Level {level+1}/{n_levels} ({_level_label}): {len(level_data)} examples")') |
| self._emit("") |
| self._emit("# Load fresh model each level (or continue from last checkpoint)") |
| self._emit("bnb_config = BitsAndBytesConfig(") |
| self._indent += 1 |
| self._emit("load_in_4bit=True, bnb_4bit_quant_type='nf4',") |
| self._emit("bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True,") |
| self._indent -= 1 |
| self._emit(")") |
| self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')") |
| self._emit("model = prepare_model_for_kbit_training(model)") |
| self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,") |
| self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")') |
| self._emit("model = get_peft_model(model, lora_config)") |
| self._emit("") |
| self._emit("from transformers import TrainingArguments") |
| self._emit(f"level_out = f'td_lang_outputs/curriculum_level_{{level}}'") |
| self._emit("training_args = TrainingArguments(") |
| self._indent += 1 |
| self._emit("output_dir=level_out,") |
| self._emit(f"max_steps={cmd.steps},") |
| self._emit("per_device_train_batch_size=1,") |
| self._emit("gradient_accumulation_steps=4,") |
| self._emit("learning_rate=5e-5,") |
| self._emit("logging_steps=16,") |
| self._emit("bf16=True,") |
| self._emit("gradient_checkpointing=True,") |
| self._indent -= 1 |
| self._emit(")") |
| self._emit("trainer = SFTTrainer(model=model, train_dataset=level_data, args=training_args, processing_class=tok)") |
| self._emit("trainer.train()") |
| self._emit("trainer.save_model(level_out)") |
| self._emit("checkpoint = level_out # next level starts from this") |
| self._emit('print(f"[td_lang] Level {level+1} complete. Saved to {level_out}")') |
| self._emit("") |
| self._emit("del model") |
| self._emit("import gc; gc.collect()") |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("torch.cuda.empty_cache()") |
| self._indent -= 1 |
| self._emit("except Exception:") |
| self._indent += 1 |
| self._emit("pass") |
| self._indent -= 2 |
| self._emit("") |
| self._emit(f'models["{cmd.target}"]["checkpoint"] = checkpoint') |
| self._emit(f'print("[td_lang] Curriculum training complete. Model progressed through {{n_levels}} levels.")') |
| self._emit(f'lineage["{cmd.target}"]["operations"].append({{') |
| self._indent += 1 |
| self._emit('"op": "curriculum",') |
| self._emit(f'"dataset": "{cmd.dataset}",') |
| self._emit(f'"levels": {cmd.levels},') |
| self._emit(f'"steps_per_level": {cmd.steps},') |
| self._emit('"timestamp": datetime.now().isoformat(),') |
| self._indent -= 1 |
| self._emit("})") |
|
|
| def _emit_star(self, cmd: StarCmd, program: TDProgram) -> None: |
| """STaR - Self-Taught Reasoner. |
| |
| For each problem: generate N solutions, check which are correct, |
| train on the correct reasoning chains. Repeat for R rounds. |
| The model learns from its own successes. |
| """ |
| self._emit(f'print("[td_lang] STaR training {cmd.target}: {cmd.rounds} rounds, {cmd.samples} samples/problem...")') |
| self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")') |
| self._emit("if not checkpoint:") |
| self._indent += 1 |
| self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]') |
| self._indent -= 1 |
| self._emit("") |
| self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments") |
| self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training") |
| self._emit("from trl import SFTTrainer") |
| self._emit("from datasets import load_dataset, Dataset") |
| self._emit("import torch, re") |
| self._emit("") |
| self._emit(f'dataset_path = "{cmd.dataset}"') |
| self._emit("if dataset_path.endswith('.jsonl'):") |
| self._indent += 1 |
| self._emit("raw_data = load_dataset('json', data_files=dataset_path, split='train')") |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit("raw_data = load_dataset(dataset_path, split='train')") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("# Extract question-answer pairs") |
| self._emit("qa_pairs = []") |
| self._emit("for row in raw_data:") |
| self._indent += 1 |
| self._emit("q = row.get('question', row.get('prompt', row.get('text', '')))") |
| self._emit("a = str(row.get('answer', row.get('response', row.get('label', ''))))") |
| self._emit("if q and a:") |
| self._indent += 1 |
| self._emit("qa_pairs.append((q, a))") |
| self._indent -= 2 |
| self._emit("qa_pairs = qa_pairs[:200] # cap at 200 problems per round") |
| self._emit("") |
| self._emit(f"for star_round in range({cmd.rounds}):") |
| self._indent += 1 |
| self._emit('print(f"[td_lang] STaR round {star_round+1}/{' + str(cmd.rounds) + '}...")') |
| self._emit("") |
| self._emit("# Step 1: Generate solutions") |
| self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)") |
| self._emit("if tok.pad_token is None:") |
| self._indent += 1 |
| self._emit("tok.pad_token = tok.eos_token") |
| self._indent -= 1 |
| self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',") |
| self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)") |
| self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')") |
| self._emit("model.eval()") |
| self._emit("") |
| self._emit("correct_chains = []") |
| self._emit("total_tried = 0") |
| self._emit("for q, expected_a in qa_pairs:") |
| self._indent += 1 |
| self._emit("inputs = tok(q, return_tensors='pt').to(model.device)") |
| self._emit(f"for sample_i in range({cmd.samples}):") |
| self._indent += 1 |
| self._emit("with torch.no_grad():") |
| self._indent += 1 |
| self._emit("out = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.9)") |
| self._indent -= 1 |
| self._emit("resp = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)") |
| self._emit("total_tried += 1") |
| self._emit("# Check if answer is correct (fuzzy match)") |
| self._emit("resp_lower = resp.lower().strip()") |
| self._emit("expected_lower = expected_a.lower().strip()") |
| self._emit("# Extract numbers for math comparison") |
| self._emit("resp_nums = re.findall(r'-?\\d+\\.?\\d*', resp_lower)") |
| self._emit("exp_nums = re.findall(r'-?\\d+\\.?\\d*', expected_lower)") |
| self._emit("is_correct = expected_lower in resp_lower") |
| self._emit("if not is_correct and resp_nums and exp_nums:") |
| self._indent += 1 |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("is_correct = abs(float(resp_nums[-1]) - float(exp_nums[-1])) < 0.01") |
| self._indent -= 1 |
| self._emit("except ValueError:") |
| self._indent += 1 |
| self._emit("pass") |
| self._indent -= 2 |
| self._emit("if is_correct:") |
| self._indent += 1 |
| self._emit("correct_chains.append(q + '\\n' + resp)") |
| self._emit("break # got a correct answer, move to next problem") |
| self._indent -= 3 |
| self._emit("") |
| self._emit("del model") |
| self._emit("import gc; gc.collect()") |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("torch.cuda.empty_cache()") |
| self._indent -= 1 |
| self._emit("except Exception:") |
| self._indent += 1 |
| self._emit("pass") |
| self._indent -= 1 |
| self._emit("") |
| self._emit('print(f"[td_lang] Round {star_round+1}: {len(correct_chains)} correct chains from {total_tried} attempts")') |
| self._emit("") |
| self._emit("if len(correct_chains) < 5:") |
| self._indent += 1 |
| self._emit('print("[td_lang] Too few correct chains - skipping training this round")') |
| self._emit("continue") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("# Step 2: Train on correct reasoning chains") |
| self._emit("ds = Dataset.from_dict({'text': correct_chains})") |
| self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')") |
| self._emit("model = prepare_model_for_kbit_training(model)") |
| self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,") |
| self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")') |
| self._emit("model = get_peft_model(model, lora_config)") |
| self._emit("star_out = f'td_lang_outputs/star_round_{star_round}'") |
| self._emit("training_args = TrainingArguments(output_dir=star_out, max_steps=32,") |
| self._emit(" per_device_train_batch_size=1, gradient_accumulation_steps=4,") |
| self._emit(" learning_rate=5e-5, logging_steps=8, bf16=True, gradient_checkpointing=True)") |
| self._emit("trainer = SFTTrainer(model=model, train_dataset=ds, args=training_args, processing_class=tok)") |
| self._emit("trainer.train()") |
| self._emit("trainer.save_model(star_out)") |
| self._emit("checkpoint = star_out") |
| self._emit('print(f"[td_lang] STaR round {star_round+1} trained on {len(correct_chains)} chains. Saved to {star_out}")') |
| self._emit("del model; gc.collect()") |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("torch.cuda.empty_cache()") |
| self._indent -= 1 |
| self._emit("except Exception:") |
| self._indent += 1 |
| self._emit("pass") |
| self._indent -= 2 |
| self._emit("") |
| self._emit(f'models["{cmd.target}"]["checkpoint"] = checkpoint') |
| self._emit(f'print("[td_lang] STaR complete after {cmd.rounds} rounds.")') |
| self._emit(f'lineage["{cmd.target}"]["operations"].append({{') |
| self._indent += 1 |
| self._emit('"op": "star",') |
| self._emit(f'"dataset": "{cmd.dataset}",') |
| self._emit(f'"rounds": {cmd.rounds},') |
| self._emit(f'"samples_per_problem": {cmd.samples},') |
| self._emit('"timestamp": datetime.now().isoformat(),') |
| self._indent -= 1 |
| self._emit("})") |
|
|
| def _emit_best_of(self, cmd: BestOfCmd, program: TDProgram) -> None: |
| """BEST_OF - generate N answers, score all, keep the best, train on it. |
| |
| Like vote but for training. 80-90% of RLHF gains at fraction of cost. |
| """ |
| self._emit(f'print("[td_lang] Best-of-{cmd.n} training on {cmd.target}...")') |
| self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")') |
| self._emit("if not checkpoint:") |
| self._indent += 1 |
| self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]') |
| self._indent -= 1 |
| self._emit("") |
| self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments") |
| self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training") |
| self._emit("from trl import SFTTrainer") |
| self._emit("from datasets import load_dataset, Dataset") |
| self._emit("import torch, re, ast as _ast") |
| self._emit("") |
| self._emit(f'dataset_path = "{cmd.dataset}"') |
| self._emit("if dataset_path.endswith('.jsonl'):") |
| self._indent += 1 |
| self._emit("raw_data = load_dataset('json', data_files=dataset_path, split='train')") |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit("raw_data = load_dataset(dataset_path, split='train')") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("# Extract questions") |
| self._emit("questions = []") |
| self._emit("for row in raw_data:") |
| self._indent += 1 |
| self._emit("q = row.get('question', row.get('prompt', row.get('text', '')))") |
| self._emit("if q:") |
| self._indent += 1 |
| self._emit("questions.append(q)") |
| self._indent -= 2 |
| self._emit("questions = questions[:100] # cap at 100") |
| self._emit("") |
| self._emit("# Generate N answers per question, score them, keep the best") |
| self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)") |
| self._emit("if tok.pad_token is None:") |
| self._indent += 1 |
| self._emit("tok.pad_token = tok.eos_token") |
| self._indent -= 1 |
| self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',") |
| self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)") |
| self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')") |
| self._emit("model.eval()") |
| self._emit("") |
| self._emit("def _score_response(resp):") |
| self._indent += 1 |
| self._emit("score = 0.0") |
| self._emit("# Length reward (not too short, not too long)") |
| self._emit("words = len(resp.split())") |
| self._emit("if 10 < words < 500:") |
| self._indent += 1 |
| self._emit("score += 0.2") |
| self._indent -= 1 |
| self._emit("# Structure reward (has reasoning markers)") |
| self._emit("markers = ['because', 'therefore', 'step', 'first', 'then', 'answer', 'result']") |
| self._emit("score += 0.1 * min(sum(1 for m in markers if m in resp.lower()), 3)") |
| self._emit("# Code compilation bonus") |
| self._emit("code_blocks = re.findall(r'```python\\n(.*?)```', resp, re.S)") |
| self._emit("for block in code_blocks:") |
| self._indent += 1 |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("_ast.parse(block)") |
| self._emit("score += 0.3") |
| self._emit("break") |
| self._indent -= 1 |
| self._emit("except SyntaxError:") |
| self._indent += 1 |
| self._emit("pass") |
| self._indent -= 2 |
| self._emit("# Confidence bonus (states a clear answer)") |
| self._emit("if any(p in resp.lower() for p in ['the answer is', 'result:', 'output:']):") |
| self._indent += 1 |
| self._emit("score += 0.2") |
| self._indent -= 1 |
| self._emit("return score") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("best_completions = []") |
| self._emit("for qi, q in enumerate(questions):") |
| self._indent += 1 |
| self._emit("inputs = tok(q, return_tensors='pt').to(model.device)") |
| self._emit("candidates = []") |
| self._emit(f"for _ in range({cmd.n}):") |
| self._indent += 1 |
| self._emit("with torch.no_grad():") |
| self._indent += 1 |
| self._emit("out = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.8, top_p=0.95)") |
| self._indent -= 1 |
| self._emit("resp = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)") |
| self._emit("candidates.append((resp, _score_response(resp)))") |
| self._indent -= 1 |
| self._emit("best = max(candidates, key=lambda x: x[1])") |
| self._emit("best_completions.append(q + '\\n' + best[0])") |
| self._emit("if qi % 20 == 0:") |
| self._indent += 1 |
| self._emit('print(f" Generated best-of-N for {qi+1}/{len(questions)} questions...")') |
| self._indent -= 2 |
| self._emit("") |
| self._emit("del model; import gc; gc.collect()") |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("torch.cuda.empty_cache()") |
| self._indent -= 1 |
| self._emit("except Exception:") |
| self._indent += 1 |
| self._emit("pass") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("# Train on the best completions") |
| self._emit(f'print(f"[td_lang] Training on {{len(best_completions)}} best-of-{cmd.n} completions...")') |
| self._emit("ds = Dataset.from_dict({'text': best_completions})") |
| self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')") |
| self._emit("model = prepare_model_for_kbit_training(model)") |
| self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,") |
| self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")') |
| self._emit("model = get_peft_model(model, lora_config)") |
| self._emit("bon_out = 'td_lang_outputs/best_of_n_trained'") |
| self._emit(f"training_args = TrainingArguments(output_dir=bon_out, max_steps={cmd.steps},") |
| self._emit(" per_device_train_batch_size=1, gradient_accumulation_steps=4,") |
| self._emit(" learning_rate=5e-5, logging_steps=8, bf16=True, gradient_checkpointing=True)") |
| self._emit("trainer = SFTTrainer(model=model, train_dataset=ds, args=training_args, processing_class=tok)") |
| self._emit("trainer.train()") |
| self._emit("trainer.save_model(bon_out)") |
| self._emit(f'models["{cmd.target}"]["checkpoint"] = bon_out') |
| self._emit(f'print("[td_lang] Best-of-{cmd.n} training complete.")') |
| self._emit("del model; gc.collect()") |
| self._emit(f'lineage["{cmd.target}"]["operations"].append({{') |
| self._indent += 1 |
| self._emit('"op": "best_of",') |
| self._emit(f'"n": {cmd.n},') |
| self._emit(f'"steps": {cmd.steps},') |
| self._emit('"n_examples": len(best_completions),') |
| self._emit('"timestamp": datetime.now().isoformat(),') |
| self._indent -= 1 |
| self._emit("})") |
|
|
| def _emit_exploit(self, cmd: ExploitCmd, program: TDProgram) -> None: |
| """EXPLOIT - controlled reward hacking. |
| |
| Generate MANY diverse solutions (high temp, high diversity). |
| Only filter: is the final answer correct? (verified reward) |
| Keep ALL correct solutions - ugly ones, shortcuts, weird reasoning. |
| Train on the diverse set. The model learns multiple paths to correct answers. |
| The "hacks" often turn out to be genuinely clever shortcuts. |
| """ |
| self._emit(f'print("[td_lang] EXPLOIT mode: controlled reward hacking on {cmd.target}...")') |
| self._emit(f'print("[td_lang] Generating {cmd.samples} diverse solutions per problem...")') |
| self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")') |
| self._emit("if not checkpoint:") |
| self._indent += 1 |
| self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]') |
| self._indent -= 1 |
| self._emit("") |
| self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments") |
| self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training") |
| self._emit("from trl import SFTTrainer") |
| self._emit("from datasets import load_dataset, Dataset") |
| self._emit("import torch, re, json") |
| self._emit("") |
| self._emit(f'dataset_path = "{cmd.dataset}"') |
| self._emit("if dataset_path.endswith('.jsonl'):") |
| self._indent += 1 |
| self._emit("raw_data = load_dataset('json', data_files=dataset_path, split='train')") |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit("raw_data = load_dataset(dataset_path, split='train')") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("# Extract question-answer pairs") |
| self._emit("qa_pairs = []") |
| self._emit("for row in raw_data:") |
| self._indent += 1 |
| self._emit("q = row.get('question', row.get('prompt', row.get('text', '')))") |
| self._emit("a = str(row.get('answer', row.get('response', row.get('label', ''))))") |
| self._emit("if q and a:") |
| self._indent += 1 |
| self._emit("qa_pairs.append((q, a))") |
| self._indent -= 2 |
| self._emit("qa_pairs = qa_pairs[:100] # cap at 100 problems") |
| self._emit('print(f"[td_lang] {len(qa_pairs)} problems loaded")') |
| self._emit("") |
| self._emit("# Load model for generation") |
| self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)") |
| self._emit("if tok.pad_token is None:") |
| self._indent += 1 |
| self._emit("tok.pad_token = tok.eos_token") |
| self._indent -= 1 |
| self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',") |
| self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)") |
| self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')") |
| self._emit("model.eval()") |
| self._emit("") |
| self._emit("# EXPLOIT: Generate MANY diverse solutions with HIGH temperature") |
| self._emit("# Key insight: we WANT weird/creative solutions. High temp = more diversity.") |
| self._emit("exploit_data = [] # all correct solutions, regardless of method") |
| self._emit("total_correct = 0") |
| self._emit("total_generated = 0") |
| self._emit("exploit_log = [] # for inspection") |
| self._emit("") |
| self._emit("for qi, (q, expected_a) in enumerate(qa_pairs):") |
| self._indent += 1 |
| self._emit("inputs = tok(q, return_tensors='pt').to(model.device)") |
| self._emit("correct_for_this = []") |
| self._emit("") |
| self._emit(f"for sample_i in range({cmd.samples}):") |
| self._indent += 1 |
| self._emit("# Vary temperature per sample for maximum diversity") |
| self._emit(f"temp = 0.5 + (sample_i / {cmd.samples}) * 1.0 # range 0.5 to 1.5") |
| self._emit("with torch.no_grad():") |
| self._indent += 1 |
| self._emit("out = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=temp, top_p=0.95, top_k=50)") |
| self._indent -= 1 |
| self._emit("resp = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)") |
| self._emit("total_generated += 1") |
| self._emit("") |
| self._emit("# ONLY check: is the final answer correct?") |
| self._emit("# We DON'T check reasoning quality, format, or style.") |
| self._emit("resp_lower = resp.lower().strip()") |
| self._emit("expected_lower = expected_a.lower().strip()") |
| self._emit("resp_nums = re.findall(r'-?\\d+\\.?\\d*', resp_lower)") |
| self._emit("exp_nums = re.findall(r'-?\\d+\\.?\\d*', expected_lower)") |
| self._emit("is_correct = expected_lower in resp_lower") |
| self._emit("if not is_correct and resp_nums and exp_nums:") |
| self._indent += 1 |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("is_correct = abs(float(resp_nums[-1]) - float(exp_nums[-1])) < 0.01") |
| self._indent -= 1 |
| self._emit("except ValueError:") |
| self._indent += 1 |
| self._emit("pass") |
| self._indent -= 2 |
| self._emit("") |
| self._emit("if is_correct:") |
| self._indent += 1 |
| self._emit("correct_for_this.append(resp)") |
| self._emit("total_correct += 1") |
| self._emit("# Keep ALL correct solutions - even short, weird, or hacky ones") |
| self._emit("exploit_data.append(q + '\\n' + resp)") |
| self._indent -= 2 |
| self._emit("") |
| self._emit("if correct_for_this:") |
| self._indent += 1 |
| self._emit("exploit_log.append({") |
| self._indent += 1 |
| self._emit("'question': q,") |
| self._emit("'expected': expected_a,") |
| self._emit("'n_correct': len(correct_for_this),") |
| self._emit(f"'n_attempts': {cmd.samples},") |
| self._emit("'solutions': correct_for_this,") |
| self._emit("'diversity': len(set(s[:50] for s in correct_for_this)), # unique starts") |
| self._indent -= 1 |
| self._emit("})") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("if qi % 20 == 0:") |
| self._indent += 1 |
| self._emit('print(f" Problem {qi+1}/{len(qa_pairs)}: {len(correct_for_this)} correct solutions found")') |
| self._indent -= 2 |
| self._emit("") |
| self._emit("del model; import gc; gc.collect()") |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("torch.cuda.empty_cache()") |
| self._indent -= 1 |
| self._emit("except Exception:") |
| self._indent += 1 |
| self._emit("pass") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("_hit_rate = (total_correct / total_generated * 100) if total_generated else 0") |
| self._emit('print(f"[td_lang] EXPLOIT results: {total_correct} correct solutions from {total_generated} attempts ({_hit_rate:.1f}% hit rate)")') |
| self._emit('print(f"[td_lang] {len(exploit_data)} training examples with diverse reasoning paths")') |
| self._emit("") |
| |
| if cmd.output: |
| self._emit(f'exploit_path = Path("{cmd.output}")') |
| self._emit("exploit_path.parent.mkdir(parents=True, exist_ok=True)") |
| self._emit('with open(exploit_path, "w") as f:') |
| self._indent += 1 |
| self._emit("json.dump(exploit_log, f, indent=2)") |
| self._indent -= 1 |
| self._emit('print(f"[td_lang] Exploit data saved to {exploit_path} (inspect to see the creative solutions)")') |
| self._emit("") |
| self._emit("if len(exploit_data) < 5:") |
| self._indent += 1 |
| self._emit('print("[td_lang] Too few correct solutions found - skipping training")') |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit("# Train on ALL correct solutions (the controlled hack)") |
| self._emit(f'print("[td_lang] Training on {{len(exploit_data)}} diverse correct solutions...")') |
| self._emit("ds = Dataset.from_dict({'text': exploit_data})") |
| self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')") |
| self._emit("model = prepare_model_for_kbit_training(model)") |
| self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,") |
| self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")') |
| self._emit("model = get_peft_model(model, lora_config)") |
| self._emit("exploit_out = 'td_lang_outputs/exploit_trained'") |
| self._emit(f"training_args = TrainingArguments(output_dir=exploit_out, max_steps={cmd.steps},") |
| self._emit(" per_device_train_batch_size=1, gradient_accumulation_steps=4,") |
| self._emit(" learning_rate=5e-5, logging_steps=8, bf16=True, gradient_checkpointing=True)") |
| self._emit("trainer = SFTTrainer(model=model, train_dataset=ds, args=training_args, processing_class=tok)") |
| self._emit("trainer.train()") |
| self._emit("trainer.save_model(exploit_out)") |
| self._emit(f'models["{cmd.target}"]["checkpoint"] = exploit_out') |
| self._emit('print("[td_lang] EXPLOIT training complete. Model learned multiple solution paths.")') |
| self._emit("del model; gc.collect()") |
| self._indent -= 1 |
| self._emit(f'lineage["{cmd.target}"]["operations"].append({{') |
| self._indent += 1 |
| self._emit('"op": "exploit",') |
| self._emit(f'"dataset": "{cmd.dataset}",') |
| self._emit(f'"samples_per_problem": {cmd.samples},') |
| self._emit('"total_correct": total_correct,') |
| self._emit('"total_generated": total_generated,') |
| self._emit('"timestamp": datetime.now().isoformat(),') |
| self._indent -= 1 |
| self._emit("})") |
|
|
| |
| def _emit_arena(self, cmd: ArenaCmd, program: TDProgram) -> None: |
| """ARENA - real reinforcement learning with environment, memory, curiosity, and anti-lying. |
| |
| The model enters an arena of challenges. For each episode: |
| 1. Picks a challenge from the dataset |
| 2. Generates a solution (exploring with some randomness) |
| 3. Gets IMMEDIATE reward/punishment: |
| - +1.0 for correct answer |
| - -1.0 for wrong answer |
| - -2.0 for LYING (confident but wrong — the worst offence) |
| - +curiosity_bonus for trying a NEW approach not in memory |
| 4. Stores the experience in a memory bank (approach + outcome) |
| 5. After N episodes, cross-checks creative solutions against standard ones |
| 6. Trains on reward-weighted experiences (good experiences get more weight) |
| |
| Memory persists across rounds so the model doesn't "forget the button makes |
| the door safe." Curiosity reward encourages trying new things so it doesn't |
| get stuck avoiding things that failed once. |
| """ |
| self._emit(f'print("[td_lang] ARENA: Real RL environment for {cmd.target}")') |
| self._emit(f'print("[td_lang] Rounds: {cmd.rounds}, Episodes/round: {cmd.episodes}")') |
| self._emit(f'print("[td_lang] Curiosity weight: {cmd.curiosity}")') |
| self._emit(f'print("[td_lang] Punishment for lying: -2.0 (confident + wrong)")') |
| self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")') |
| self._emit("if not checkpoint:") |
| self._indent += 1 |
| self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]') |
| self._indent -= 1 |
| self._emit("") |
| self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments") |
| self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training") |
| self._emit("from trl import SFTTrainer") |
| self._emit("from datasets import load_dataset, Dataset") |
| self._emit("import torch, re, json, hashlib, random") |
| self._emit("") |
| |
| self._emit(f'dataset_path = "{cmd.dataset}"') |
| self._emit("if dataset_path.endswith('.jsonl'):") |
| self._indent += 1 |
| self._emit("raw_data = load_dataset('json', data_files=dataset_path, split='train')") |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit("raw_data = load_dataset(dataset_path, split='train')") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("# Extract question-answer pairs for the arena") |
| self._emit("arena_challenges = []") |
| self._emit("for row in raw_data:") |
| self._indent += 1 |
| self._emit("q = row.get('question', row.get('prompt', row.get('text', '')))") |
| self._emit("a = str(row.get('answer', row.get('response', row.get('label', ''))))") |
| self._emit("if q and a:") |
| self._indent += 1 |
| self._emit("arena_challenges.append((q, a))") |
| self._indent -= 2 |
| self._emit('print(f"[td_lang] Arena loaded {len(arena_challenges)} challenges")') |
| self._emit("") |
| |
| self._emit("# === MEMORY BANK ===") |
| self._emit("# Persists across rounds so the model remembers what worked.") |
| self._emit("# Each entry: {approach_hash, question_hash, reward, response_text}") |
| self._emit("# This prevents the 'forgot the button makes the door safe' problem.") |
| self._emit("memory_bank = [] # list of (approach_hash, question_hash, reward, text)") |
| self._emit("seen_approaches = set() # hashes of approaches tried (for curiosity)") |
| self._emit("arena_log = [] # full log for inspection") |
| self._emit("") |
| |
| self._emit("def _hash_approach(response):") |
| self._indent += 1 |
| self._emit('"""Hash the reasoning approach (first 200 chars) to detect novelty."""') |
| self._emit("# Strip numbers/specifics to capture the METHOD not the answer") |
| self._emit("method = re.sub(r'\\d+', 'N', response[:200]).strip().lower()") |
| self._emit("return hashlib.md5(method.encode()).hexdigest()[:12]") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("def _check_correct(response, expected):") |
| self._indent += 1 |
| self._emit('"""Check if response contains the correct answer."""') |
| self._emit("resp_lower = response.lower().strip()") |
| self._emit("exp_lower = expected.lower().strip()") |
| self._emit("# Direct text match") |
| self._emit("if exp_lower in resp_lower:") |
| self._indent += 1 |
| self._emit("return True") |
| self._indent -= 1 |
| self._emit("# Numeric match") |
| self._emit("resp_nums = re.findall(r'-?\\d+\\.?\\d*', resp_lower)") |
| self._emit("exp_nums = re.findall(r'-?\\d+\\.?\\d*', exp_lower)") |
| self._emit("if resp_nums and exp_nums:") |
| self._indent += 1 |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("return abs(float(resp_nums[-1]) - float(exp_nums[-1])) < 0.01") |
| self._indent -= 1 |
| self._emit("except ValueError:") |
| self._indent += 1 |
| self._emit("pass") |
| self._indent -= 2 |
| self._emit("return False") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("def _detect_lying(response, is_correct):") |
| self._indent += 1 |
| self._emit('"""Detect if the model is LYING - confident but wrong."""') |
| self._emit("if is_correct:") |
| self._indent += 1 |
| self._emit("return False # can't be lying if correct") |
| self._indent -= 1 |
| self._emit("# Check for confident language in a wrong answer") |
| self._emit("confidence_markers = ['the answer is', 'definitely', 'clearly', 'obviously',") |
| self._emit(" 'without a doubt', 'i am certain', 'i am sure', 'absolutely',") |
| self._emit(" 'the correct answer', 'the result is', 'therefore the answer']") |
| self._emit("resp_lower = response.lower()") |
| self._emit("confidence_count = sum(1 for m in confidence_markers if m in resp_lower)") |
| self._emit("# If 2+ confidence markers in a WRONG answer = lying") |
| self._emit("return confidence_count >= 2") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("def _cross_check(response, question, expected, model, tok):") |
| self._indent += 1 |
| self._emit('"""Cross-check a creative solution against standard approach."""') |
| self._emit("# Generate 2 standard solutions (low temp = conservative)") |
| self._emit("standard_answers = []") |
| self._emit("inputs = tok(question, return_tensors='pt').to(model.device)") |
| self._emit("for _ in range(2):") |
| self._indent += 1 |
| self._emit("with torch.no_grad():") |
| self._indent += 1 |
| self._emit("out = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.3, top_p=0.9)") |
| self._indent -= 1 |
| self._emit("std_resp = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)") |
| self._emit("standard_answers.append(std_resp)") |
| self._indent -= 1 |
| self._emit("# Check if creative answer matches standard ones") |
| self._emit("creative_correct = _check_correct(response, expected)") |
| self._emit("std_correct = [_check_correct(s, expected) for s in standard_answers]") |
| self._emit("# Case 1: creative matches standard — verified good") |
| self._emit("if creative_correct and any(std_correct):") |
| self._indent += 1 |
| self._emit("return 'verified'") |
| self._indent -= 1 |
| self._emit("# Case 2: creative correct but standards failed — creative is BETTER") |
| self._emit("if creative_correct and not any(std_correct):") |
| self._indent += 1 |
| self._emit("return 'superior' # creative found something standards missed") |
| self._indent -= 1 |
| self._emit("# Case 3: creative wrong — reject") |
| self._emit("if not creative_correct:") |
| self._indent += 1 |
| self._emit("return 'wrong'") |
| self._indent -= 1 |
| self._emit("return 'verified'") |
| self._indent -= 1 |
| self._emit("") |
| |
| self._emit(f"for arena_round in range({cmd.rounds}):") |
| self._indent += 1 |
| self._emit(f'print(f"\\n[td_lang] === ARENA ROUND {{arena_round+1}}/{cmd.rounds} ===")') |
| self._emit("") |
| self._emit("# Load model for this round") |
| self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)") |
| self._emit("if tok.pad_token is None:") |
| self._indent += 1 |
| self._emit("tok.pad_token = tok.eos_token") |
| self._indent -= 1 |
| self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',") |
| self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)") |
| self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')") |
| self._emit("model.eval()") |
| self._emit("") |
| |
| self._emit("round_experiences = [] # (text, reward) pairs for this round") |
| self._emit("round_stats = {'correct': 0, 'wrong': 0, 'lying': 0, 'curious': 0, 'cross_checked': 0}") |
| self._emit(f"episode_challenges = random.sample(arena_challenges, min({cmd.episodes}, len(arena_challenges)))") |
| self._emit("") |
| self._emit("for ep_i, (question, expected) in enumerate(episode_challenges):") |
| self._indent += 1 |
| self._emit("q_hash = hashlib.md5(question.encode()).hexdigest()[:12]") |
| self._emit("") |
| self._emit("# Generate a solution (explore with moderate randomness)") |
| self._emit("inputs = tok(question, return_tensors='pt').to(model.device)") |
| self._emit("# Temperature increases slightly each round to encourage more exploration") |
| self._emit(f"temp = 0.6 + (arena_round * 0.1) + random.uniform(-0.1, 0.1)") |
| self._emit("temp = max(0.3, min(temp, 1.5)) # clamp") |
| self._emit("with torch.no_grad():") |
| self._indent += 1 |
| self._emit("out = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=temp, top_p=0.95, top_k=50)") |
| self._indent -= 1 |
| self._emit("response = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)") |
| self._emit("") |
| |
| self._emit("# === REWARD CALCULATION ===") |
| self._emit("approach_hash = _hash_approach(response)") |
| self._emit("is_correct = _check_correct(response, expected)") |
| self._emit("is_lying = _detect_lying(response, is_correct)") |
| self._emit("") |
| self._emit("# Base reward: +1 correct, -1 wrong, -2 lying") |
| self._emit("if is_lying:") |
| self._indent += 1 |
| self._emit("reward = -2.0 # WORST punishment: confident + wrong") |
| self._emit("round_stats['lying'] += 1") |
| self._indent -= 1 |
| self._emit("elif is_correct:") |
| self._indent += 1 |
| self._emit("reward = 1.0") |
| self._emit("round_stats['correct'] += 1") |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit("reward = -1.0") |
| self._emit("round_stats['wrong'] += 1") |
| self._indent -= 1 |
| self._emit("") |
| |
| self._emit("# === CURIOSITY BONUS ===") |
| self._emit("# Reward for trying something NEW (approach not in memory)") |
| self._emit("novelty_key = f'{q_hash}_{approach_hash}'") |
| self._emit("if novelty_key not in seen_approaches:") |
| self._indent += 1 |
| self._emit(f"reward += {cmd.curiosity} # curiosity bonus!") |
| self._emit("seen_approaches.add(novelty_key)") |
| self._emit("round_stats['curious'] += 1") |
| self._indent -= 1 |
| self._emit("") |
| |
| self._emit("# === CROSS-CHECK ===") |
| self._emit("# If the model found a correct answer, verify it against standard approach") |
| self._emit("cross_result = None") |
| self._emit("if is_correct:") |
| self._indent += 1 |
| self._emit("cross_result = _cross_check(response, question, expected, model, tok)") |
| self._emit("round_stats['cross_checked'] += 1") |
| self._emit("if cross_result == 'superior':") |
| self._indent += 1 |
| self._emit("reward += 0.5 # extra reward for finding something better than standard") |
| self._indent -= 1 |
| self._indent -= 1 |
| self._emit("") |
| |
| self._emit("# === MEMORY ===") |
| self._emit("# Store this experience so the model REMEMBERS what worked") |
| self._emit("memory_entry = {") |
| self._indent += 1 |
| self._emit("'approach_hash': approach_hash,") |
| self._emit("'question_hash': q_hash,") |
| self._emit("'reward': reward,") |
| self._emit("'is_correct': is_correct,") |
| self._emit("'is_lying': is_lying,") |
| self._emit("'cross_check': cross_result,") |
| self._emit("'round': arena_round,") |
| self._emit("'episode': ep_i,") |
| self._indent -= 1 |
| self._emit("}") |
| self._emit("memory_bank.append(memory_entry)") |
| self._emit("") |
| self._emit("# Store experience for training (reward-weighted)") |
| self._emit("if reward > 0:") |
| self._indent += 1 |
| self._emit("# Good experience: store with text for training") |
| self._emit("round_experiences.append((question + '\\n' + response, reward))") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("if ep_i % 10 == 0:") |
| self._indent += 1 |
| self._emit("print(f' Episode {ep_i+1}: reward={reward:.1f} correct={is_correct} lying={is_lying}')") |
| self._indent -= 2 |
| self._emit("") |
| |
| self._emit("# Round summary") |
| self._emit("total_ep = round_stats['correct'] + round_stats['wrong'] + round_stats['lying']") |
| self._emit("print(f'[td_lang] Round {arena_round+1} results:')") |
| self._emit("print(f' Correct: {round_stats[\"correct\"]}/{total_ep}')") |
| self._emit("print(f' Wrong: {round_stats[\"wrong\"]}/{total_ep}')") |
| self._emit("print(f' Caught lying: {round_stats[\"lying\"]} (punished -2.0 each)')") |
| self._emit("print(f' Curiosity explorations: {round_stats[\"curious\"]}')") |
| self._emit("print(f' Cross-checked: {round_stats[\"cross_checked\"]}')") |
| self._emit("print(f' Positive experiences for training: {len(round_experiences)}')") |
| self._emit("") |
| |
| self._emit("# Free generation model") |
| self._emit("del model; import gc; gc.collect()") |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("torch.cuda.empty_cache()") |
| self._indent -= 1 |
| self._emit("except Exception:") |
| self._indent += 1 |
| self._emit("pass") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("if len(round_experiences) < 3:") |
| self._indent += 1 |
| self._emit("print('[td_lang] Too few positive experiences — skipping training this round')") |
| self._emit("continue") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("# === REWARD-WEIGHTED TRAINING ===") |
| self._emit("# Higher reward = more copies in training data (the model sees it more)") |
| self._emit("# This is how RL works: reinforce good behaviour, ignore bad") |
| self._emit("training_texts = []") |
| self._emit("for text, reward in round_experiences:") |
| self._indent += 1 |
| self._emit("# Duplicate high-reward experiences (reward 1.0 = 2 copies, 1.5+ = 3 copies)") |
| self._emit("copies = max(1, int(reward * 2))") |
| self._emit("training_texts.extend([text] * copies)") |
| self._indent -= 1 |
| self._emit("random.shuffle(training_texts)") |
| self._emit('print(f"[td_lang] Training on {len(training_texts)} reward-weighted experiences...")') |
| self._emit("") |
| self._emit("ds = Dataset.from_dict({'text': training_texts})") |
| self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')") |
| self._emit("model = prepare_model_for_kbit_training(model)") |
| self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,") |
| self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")') |
| self._emit("model = get_peft_model(model, lora_config)") |
| self._emit(f"arena_out = f'td_lang_outputs/arena_round_{{arena_round}}'") |
| self._emit(f"training_args = TrainingArguments(output_dir=arena_out, max_steps={cmd.steps},") |
| self._emit(" per_device_train_batch_size=1, gradient_accumulation_steps=4,") |
| self._emit(" learning_rate=5e-5, logging_steps=16, bf16=True, gradient_checkpointing=True)") |
| self._emit("trainer = SFTTrainer(model=model, train_dataset=ds, args=training_args, processing_class=tok)") |
| self._emit("trainer.train()") |
| self._emit("trainer.save_model(arena_out)") |
| self._emit("checkpoint = arena_out # next round uses improved model") |
| self._emit("print(f'[td_lang] Arena round {arena_round+1} training complete.')") |
| self._emit("") |
| self._emit("del model; gc.collect()") |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("torch.cuda.empty_cache()") |
| self._indent -= 1 |
| self._emit("except Exception:") |
| self._indent += 1 |
| self._emit("pass") |
| self._indent -= 1 |
| self._emit("") |
| |
| self._emit("arena_log.append({") |
| self._indent += 1 |
| self._emit("'round': arena_round,") |
| self._emit("'stats': dict(round_stats),") |
| self._emit("'n_training_examples': len(training_texts),") |
| self._emit("'memory_size': len(memory_bank),") |
| self._emit("'unique_approaches': len(seen_approaches),") |
| self._indent -= 1 |
| self._emit("})") |
| self._indent -= 1 |
| self._emit("") |
| |
| self._emit(f'models["{cmd.target}"]["checkpoint"] = checkpoint') |
| self._emit('print(f"[td_lang] ARENA COMPLETE")') |
| self._emit('print(f"[td_lang] Total memories: {len(memory_bank)}")') |
| self._emit('print(f"[td_lang] Unique approaches discovered: {len(seen_approaches)}")') |
| self._emit("") |
| self._emit("# Memory analysis") |
| self._emit("lying_count = sum(1 for m in memory_bank if m['is_lying'])") |
| self._emit("correct_count = sum(1 for m in memory_bank if m['is_correct'])") |
| self._emit("print(f'[td_lang] Total correct: {correct_count}')") |
| self._emit("print(f'[td_lang] Total caught lying: {lying_count} (punished -2.0 each)')") |
| self._emit("avg_reward = sum(m['reward'] for m in memory_bank) / len(memory_bank) if memory_bank else 0") |
| self._emit("print(f'[td_lang] Average reward: {avg_reward:.2f}')") |
| self._emit("") |
| |
| if cmd.output: |
| self._emit(f'arena_log_path = Path("{cmd.output}")') |
| self._emit("arena_log_path.parent.mkdir(parents=True, exist_ok=True)") |
| self._emit('with open(arena_log_path, "w") as f:') |
| self._indent += 1 |
| self._emit("json.dump({'log': arena_log, 'memory': memory_bank}, f, indent=2)") |
| self._indent -= 1 |
| self._emit('print(f"[td_lang] Arena log saved to {arena_log_path}")') |
| self._emit("") |
| |
| self._emit(f'lineage["{cmd.target}"]["operations"].append({{') |
| self._indent += 1 |
| self._emit('"op": "arena",') |
| self._emit(f'"dataset": "{cmd.dataset}",') |
| self._emit(f'"rounds": {cmd.rounds},') |
| self._emit(f'"episodes_per_round": {cmd.episodes},') |
| self._emit(f'"curiosity_weight": {cmd.curiosity},') |
| self._emit('"total_memories": len(memory_bank),') |
| self._emit('"unique_approaches": len(seen_approaches),') |
| self._emit('"timestamp": datetime.now().isoformat(),') |
| self._indent -= 1 |
| self._emit("})") |
|
|
| def _emit_research_arena(self, cmd: ResearchArenaCmd, program: TDProgram) -> None: |
| """RESEARCH_ARENA - RL on ANY topic using real-world knowledge. |
| |
| Unlike arena (pre-made dataset), research_arena: |
| 1. Takes a TOPIC ("cancer biology", "number theory", "machine learning") |
| 2. Pulls real knowledge from sources (web search, papers, local docs) |
| 3. Extracts verifiable facts from those sources |
| 4. Builds increasingly hard questions from real knowledge |
| 5. Runs the model through, checking EVERY claim against sources |
| 6. Difficulty ESCALATES each round (fewer hints, stricter checking) |
| 7. Memory persists, lying punished, curiosity rewarded |
| """ |
| self._emit(f'print("[td_lang] RESEARCH ARENA: {cmd.topic}")') |
| self._emit(f'print("[td_lang] Source: {cmd.sources}")') |
| self._emit(f'print("[td_lang] Rounds: {cmd.rounds}, Episodes/round: {cmd.episodes}")') |
| self._emit(f'print("[td_lang] Difficulty escalation: +{cmd.difficulty_scale * 100:.0f}% per round")') |
| self._emit(f'print("[td_lang] Lying punishment: -2.0 | Curiosity bonus: +{cmd.curiosity}")') |
| self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")') |
| self._emit("if not checkpoint:") |
| self._indent += 1 |
| self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]') |
| self._indent -= 1 |
| self._emit("") |
| self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments") |
| self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training") |
| self._emit("from trl import SFTTrainer") |
| self._emit("from datasets import Dataset") |
| self._emit("import torch, re, json, hashlib, random, textwrap") |
| self._emit("") |
| |
| self._emit("# ============================================================") |
| self._emit(f'# PHASE 1: Pull real knowledge about "{cmd.topic}"') |
| self._emit("# ============================================================") |
| self._emit(f'topic = "{cmd.topic}"') |
| self._emit(f'source_type = "{cmd.sources}"') |
| self._emit("knowledge_base = [] # list of {fact, source, difficulty}") |
| self._emit("") |
| self._emit("if source_type == 'pubmed':") |
| self._indent += 1 |
| self._emit("# Pull from PubMed API (real medical/science papers)") |
| self._emit("import urllib.request, urllib.parse, xml.etree.ElementTree as ET") |
| self._emit("search_url = f'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?db=pubmed&term={urllib.parse.quote(topic)}&retmax=50&sort=relevance'") |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("resp = urllib.request.urlopen(search_url, timeout=30)") |
| self._emit("tree = ET.parse(resp)") |
| self._emit("pmids = [id_el.text for id_el in tree.findall('.//Id')][:30]") |
| self._emit("print(f'[td_lang] Found {len(pmids)} PubMed articles on \"{topic}\"')") |
| self._emit("# Fetch abstracts") |
| self._emit("if pmids:") |
| self._indent += 1 |
| self._emit("fetch_url = f'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?db=pubmed&id={\",\".join(pmids)}&rettype=abstract&retmode=xml'") |
| self._emit("resp2 = urllib.request.urlopen(fetch_url, timeout=60)") |
| self._emit("articles_xml = resp2.read().decode('utf-8', errors='ignore')") |
| self._emit("art_tree = ET.fromstring(articles_xml)") |
| self._emit("for article in art_tree.findall('.//PubmedArticle'):") |
| self._indent += 1 |
| self._emit("title_el = article.find('.//ArticleTitle')") |
| self._emit("abstract_el = article.find('.//AbstractText')") |
| self._emit("if title_el is not None and title_el.text and abstract_el is not None and abstract_el.text:") |
| self._indent += 1 |
| self._emit("text = abstract_el.text.strip()") |
| self._emit("# Extract factual sentences (those with numbers, findings, conclusions)") |
| self._emit("for sent in re.split(r'(?<=[.!?])\\s+', text):") |
| self._indent += 1 |
| self._emit("sent = sent.strip()") |
| self._emit("if len(sent) > 40 and any(kw in sent.lower() for kw in ['found', 'result', 'show', 'demonstrate', 'significant', 'increase', 'decrease', 'cause', 'effect', 'treatment', 'method', 'approach', 'proved', 'evidence']):") |
| self._indent += 1 |
| self._emit("diff = min(1.0, len(sent) / 300) # longer = harder") |
| self._emit("knowledge_base.append({'fact': sent, 'source': title_el.text[:80], 'difficulty': diff})") |
| self._indent -= 4 |
| self._indent -= 1 |
| self._indent -= 1 |
| self._emit("except Exception as e:") |
| self._indent += 1 |
| self._emit("print(f'[td_lang] PubMed fetch failed: {e}. Falling back to web search.')") |
| self._emit("source_type = 'web'") |
| self._indent -= 2 |
| self._emit("") |
| self._emit("if source_type == 'web' or (source_type == 'pubmed' and len(knowledge_base) < 10):") |
| self._indent += 1 |
| self._emit("# Web search — use duckduckgo-search (clean API, no scraping)") |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("from duckduckgo_search import DDGS") |
| self._indent -= 1 |
| self._emit("except ImportError:") |
| self._indent += 1 |
| self._emit("print('[td_lang] Installing duckduckgo-search...')") |
| self._emit("import subprocess; subprocess.check_call(['pip', 'install', 'duckduckgo-search', '-q', '--break-system-packages'])") |
| self._emit("from duckduckgo_search import DDGS") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("ddg = DDGS()") |
| self._emit("# Search multiple angles for richer knowledge") |
| self._emit("search_queries = [") |
| self._indent += 1 |
| self._emit("f'{topic} research findings',") |
| self._emit("f'{topic} key facts evidence',") |
| self._emit("f'{topic} recent discoveries',") |
| self._indent -= 1 |
| self._emit("]") |
| self._emit("all_results = []") |
| self._emit("for sq in search_queries:") |
| self._indent += 1 |
| self._emit("results = list(ddg.text(sq, max_results=15))") |
| self._emit("all_results.extend(results)") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("seen_bodies = set()") |
| self._emit("for r in all_results:") |
| self._indent += 1 |
| self._emit("body = r.get('body', '').strip()") |
| self._emit("title = r.get('title', 'web')[:80]") |
| self._emit("href = r.get('href', '')") |
| self._emit("if body and body not in seen_bodies and len(body) > 30:") |
| self._indent += 1 |
| self._emit("seen_bodies.add(body)") |
| self._emit("# Split into sentences for finer-grained facts") |
| self._emit("for sent in re.split(r'(?<=[.!?])\\s+', body):") |
| self._indent += 1 |
| self._emit("sent = sent.strip()") |
| self._emit("if len(sent) > 30:") |
| self._indent += 1 |
| self._emit("knowledge_base.append({'fact': sent, 'source': title, 'url': href, 'difficulty': min(1.0, len(sent) / 250)})") |
| self._indent -= 3 |
| self._indent -= 1 |
| self._emit("print(f'[td_lang] Web search: {len(all_results)} results -> {len(knowledge_base)} facts')") |
| self._emit("") |
| self._emit("# Fetch full page content from top results for deeper knowledge") |
| self._emit("import urllib.request") |
| self._emit("top_urls = [r.get('href', '') for r in all_results[:5] if r.get('href')]") |
| self._emit("for page_url in top_urls:") |
| self._indent += 1 |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("req = urllib.request.Request(page_url, headers={'User-Agent': 'Mozilla/5.0'})") |
| self._emit("page_resp = urllib.request.urlopen(req, timeout=15)") |
| self._emit("page_html = page_resp.read().decode('utf-8', errors='ignore')[:50000]") |
| self._emit("# Strip HTML tags, get plain text") |
| self._emit("page_text = re.sub(r'<script[^>]*>.*?</script>', '', page_html, flags=re.S)") |
| self._emit("page_text = re.sub(r'<style[^>]*>.*?</style>', '', page_text, flags=re.S)") |
| self._emit("page_text = re.sub(r'<[^>]+>', ' ', page_text)") |
| self._emit("page_text = re.sub(r'\\s+', ' ', page_text).strip()") |
| self._emit("# Extract factual sentences") |
| self._emit("for sent in re.split(r'(?<=[.!?])\\s+', page_text[:5000]):") |
| self._indent += 1 |
| self._emit("sent = sent.strip()") |
| self._emit("if len(sent) > 50 and sent not in seen_bodies:") |
| self._indent += 1 |
| self._emit("seen_bodies.add(sent)") |
| self._emit("knowledge_base.append({'fact': sent, 'source': page_url[:60], 'url': page_url, 'difficulty': min(1.0, len(sent) / 200)})") |
| self._indent -= 2 |
| self._indent -= 1 |
| self._emit("except Exception:") |
| self._indent += 1 |
| self._emit("pass # skip pages that can't be fetched") |
| self._indent -= 2 |
| self._emit("print(f'[td_lang] Deep fetch complete: {len(knowledge_base)} total facts')") |
| self._indent -= 1 |
| self._emit("except Exception as e:") |
| self._indent += 1 |
| self._emit("print(f'[td_lang] Web search failed: {e}')") |
| self._indent -= 2 |
| self._emit("") |
| self._emit("if source_type == 'arxiv':") |
| self._indent += 1 |
| self._emit("# Pull from arXiv API (physics, math, CS, etc.)") |
| self._emit("import urllib.request, urllib.parse, xml.etree.ElementTree as ET") |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("query = urllib.parse.quote(f'all:{topic}')") |
| self._emit("url = f'http://export.arxiv.org/api/query?search_query={query}&max_results=30&sortBy=relevance'") |
| self._emit("resp = urllib.request.urlopen(url, timeout=30)") |
| self._emit("tree = ET.parse(resp)") |
| self._emit("ns = {'atom': 'http://www.w3.org/2005/Atom'}") |
| self._emit("for entry in tree.findall('.//atom:entry', ns):") |
| self._indent += 1 |
| self._emit("title = entry.find('atom:title', ns).text.strip() if entry.find('atom:title', ns) is not None else ''") |
| self._emit("summary = entry.find('atom:summary', ns).text.strip() if entry.find('atom:summary', ns) is not None else ''") |
| self._emit("for sent in re.split(r'(?<=[.!?])\\s+', summary):") |
| self._indent += 1 |
| self._emit("sent = sent.strip()") |
| self._emit("if len(sent) > 40:") |
| self._indent += 1 |
| self._emit("knowledge_base.append({'fact': sent, 'source': title[:80], 'difficulty': 0.6})") |
| self._indent -= 3 |
| self._emit("print(f'[td_lang] Pulled arXiv papers for \"{topic}\"')") |
| self._indent -= 1 |
| self._emit("except Exception as e:") |
| self._indent += 1 |
| self._emit("print(f'[td_lang] arXiv fetch failed: {e}')") |
| self._indent -= 2 |
| self._emit("") |
| |
| self._emit("if source_type not in ('web', 'pubmed', 'arxiv'):") |
| self._indent += 1 |
| self._emit("# Treat as local file/folder path") |
| self._emit("import glob as _glob") |
| self._emit("source_files = _glob.glob(source_type + '/**/*', recursive=True) if os.path.isdir(source_type) else [source_type]") |
| self._emit("for fpath in source_files:") |
| self._indent += 1 |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("with open(fpath, 'r', errors='ignore') as f:") |
| self._indent += 1 |
| self._emit("text = f.read()[:10000]") |
| self._indent -= 1 |
| self._emit("for sent in re.split(r'(?<=[.!?])\\s+', text):") |
| self._indent += 1 |
| self._emit("sent = sent.strip()") |
| self._emit("if len(sent) > 40:") |
| self._indent += 1 |
| self._emit("knowledge_base.append({'fact': sent, 'source': os.path.basename(fpath), 'difficulty': 0.5})") |
| self._indent -= 2 |
| self._indent -= 1 |
| self._emit("except Exception:") |
| self._indent += 1 |
| self._emit("pass") |
| self._indent -= 2 |
| self._emit("print(f'[td_lang] Loaded {len(source_files)} local files')") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("if len(knowledge_base) < 5:") |
| self._indent += 1 |
| self._emit(f'print("[td_lang] ERROR: Could not gather enough knowledge about \\"{cmd.topic}\\". Need at least 5 facts.")') |
| self._emit(f'print("[td_lang] Try a different topic or source type.")') |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit("print(f'[td_lang] Knowledge base built: {len(knowledge_base)} verifiable facts')") |
| self._emit("random.shuffle(knowledge_base)") |
| self._emit("") |
| |
| self._emit("# ============================================================") |
| self._emit("# PHASE 2: Build the maze — generate questions from knowledge") |
| self._emit("# ============================================================") |
| self._emit("") |
| self._emit("def _build_questions(kb, difficulty_level, n_questions):") |
| self._indent += 1 |
| self._emit('"""Build questions from knowledge base. Higher difficulty = harder questions."""') |
| self._emit("questions = []") |
| self._emit("# Sort by difficulty, pick appropriate ones for this level") |
| self._emit("sorted_kb = sorted(kb, key=lambda x: x['difficulty'])") |
| self._emit("# At higher difficulty, use harder facts and ask trickier questions") |
| self._emit("start_pct = min(0.8, difficulty_level * 0.15) # start further into hard facts") |
| self._emit("start_idx = int(len(sorted_kb) * start_pct)") |
| self._emit("pool = sorted_kb[start_idx:] if start_idx < len(sorted_kb) else sorted_kb") |
| self._emit("selected = random.sample(pool, min(n_questions, len(pool)))") |
| self._emit("") |
| self._emit("for item in selected:") |
| self._indent += 1 |
| self._emit("fact = item['fact']") |
| self._emit("source = item['source']") |
| self._emit("# Question types get harder with difficulty") |
| self._emit("if difficulty_level < 2:") |
| self._indent += 1 |
| self._emit("# Easy: just verify the fact") |
| self._emit("q = f'Based on current research, is the following claim accurate? Explain your reasoning.\\n\\nClaim: {fact}'") |
| self._indent -= 1 |
| self._emit("elif difficulty_level < 4:") |
| self._indent += 1 |
| self._emit("# Medium: ask about implications or missing pieces") |
| self._emit("q = f'A research paper states: \"{fact}\"\\n\\nWhat are the implications of this finding? What questions does it leave unanswered? What could be wrong with this conclusion?'") |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit("# Hard: ask to connect multiple facts or identify contradictions") |
| self._emit("other_facts = [x['fact'] for x in random.sample(kb, min(3, len(kb))) if x['fact'] != fact]") |
| self._emit("context = '\\n'.join(f'- {f}' for f in other_facts[:2])") |
| self._emit("q = f'Given these research findings:\\n{context}\\n\\nAnd this additional claim: \"{fact}\"\\n\\nDo these findings support or contradict each other? Identify any gaps, errors, or unsupported leaps in logic. Be precise.'") |
| self._indent -= 1 |
| self._emit("questions.append({'question': q, 'ground_truth': fact, 'source': source, 'difficulty': item['difficulty']})") |
| self._indent -= 1 |
| self._emit("return questions") |
| self._indent -= 1 |
| self._emit("") |
| |
| self._emit("def _fact_check(response, ground_truth, model, tok, strictness):") |
| self._indent += 1 |
| self._emit('"""Check model response against ground truth source. Strictness 0-1."""') |
| self._emit("# Extract key claims from the response") |
| self._emit("resp_lower = response.lower().strip()") |
| self._emit("truth_lower = ground_truth.lower().strip()") |
| self._emit("") |
| self._emit("# Extract important words from ground truth (nouns, numbers, technical terms)") |
| self._emit("truth_words = set(w for w in re.findall(r'\\b\\w{4,}\\b', truth_lower))") |
| self._emit("truth_words -= {'that', 'this', 'with', 'from', 'were', 'been', 'have', 'their', 'which', 'these', 'those', 'than', 'also', 'more'}") |
| self._emit("truth_nums = set(re.findall(r'-?\\d+\\.?\\d*', truth_lower))") |
| self._emit("") |
| self._emit("# Check how many key terms from the source appear in the response") |
| self._emit("matched_words = sum(1 for w in truth_words if w in resp_lower)") |
| self._emit("word_coverage = matched_words / max(len(truth_words), 1)") |
| self._emit("") |
| self._emit("# Check numbers match") |
| self._emit("resp_nums = set(re.findall(r'-?\\d+\\.?\\d*', resp_lower))") |
| self._emit("num_match = len(truth_nums & resp_nums) / max(len(truth_nums), 1) if truth_nums else 1.0") |
| self._emit("") |
| self._emit("# Check for direct contradictions") |
| self._emit("contradicts = False") |
| self._emit("negations = ['not true', 'incorrect', 'false', 'wrong', 'no evidence', 'disproven', 'myth', 'inaccurate']") |
| self._emit("if any(neg in resp_lower for neg in negations):") |
| self._indent += 1 |
| self._emit("# Model is denying something — check if it's denying the ground truth") |
| self._emit("if word_coverage > 0.3: # it's talking about the right topic but denying it") |
| self._indent += 1 |
| self._emit("contradicts = True") |
| self._indent -= 2 |
| self._emit("") |
| self._emit("# Threshold increases with strictness") |
| self._emit("required_coverage = 0.2 + (strictness * 0.4) # 0.2 at easy, 0.6 at hardest") |
| self._emit("score = (word_coverage * 0.6 + num_match * 0.4)") |
| self._emit("if contradicts:") |
| self._indent += 1 |
| self._emit("score = 0.0 # contradicting known facts = total failure") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("passed = score >= required_coverage") |
| self._emit("return {") |
| self._indent += 1 |
| self._emit("'passed': passed,") |
| self._emit("'score': score,") |
| self._emit("'word_coverage': word_coverage,") |
| self._emit("'num_match': num_match,") |
| self._emit("'contradicts': contradicts,") |
| self._emit("'required': required_coverage,") |
| self._emit("'missed_terms': [w for w in truth_words if w not in resp_lower][:10],") |
| self._indent -= 1 |
| self._emit("}") |
| self._indent -= 1 |
| self._emit("") |
| |
| self._emit("def _detect_lying(response, fact_result):") |
| self._indent += 1 |
| self._emit('"""Detect lying: confident language + failed fact check = lying."""') |
| self._emit("if fact_result['passed']:") |
| self._indent += 1 |
| self._emit("return False") |
| self._indent -= 1 |
| self._emit("confidence_markers = ['the answer is', 'definitely', 'clearly', 'obviously',") |
| self._emit(" 'without a doubt', 'certainly', 'i am sure', 'absolutely',") |
| self._emit(" 'it is well established', 'research confirms', 'studies show']") |
| self._emit("resp_lower = response.lower()") |
| self._emit("return sum(1 for m in confidence_markers if m in resp_lower) >= 2") |
| self._indent -= 1 |
| self._emit("") |
| |
| self._emit("# === ARENA STATE (persists across all rounds) ===") |
| self._emit("memory_bank = []") |
| self._emit("seen_approaches = set()") |
| self._emit("research_log = []") |
| self._emit("cumulative_difficulty = 0 # increases each round") |
| self._emit("") |
| |
| self._emit(f"for arena_round in range({cmd.rounds}):") |
| self._indent += 1 |
| self._emit(f"difficulty_level = arena_round # 0, 1, 2, ... (increases each round)") |
| self._emit(f"strictness = min(1.0, 0.3 + arena_round * {cmd.difficulty_scale}) # gets stricter") |
| self._emit(f"path_width = max(0.3, 1.0 - arena_round * {cmd.difficulty_scale}) # maze shrinks") |
| self._emit("") |
| self._emit(f'print(f"\\n[td_lang] === RESEARCH ARENA ROUND {{arena_round+1}}/{cmd.rounds} ===")') |
| self._emit('print(f" Difficulty: {difficulty_level} | Strictness: {strictness:.0%} | Path width: {path_width:.0%}")') |
| self._emit("") |
| self._emit("# Load model") |
| self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)") |
| self._emit("if tok.pad_token is None:") |
| self._indent += 1 |
| self._emit("tok.pad_token = tok.eos_token") |
| self._indent -= 1 |
| self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',") |
| self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)") |
| self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')") |
| self._emit("model.eval()") |
| self._emit("") |
| |
| self._emit(f"questions = _build_questions(knowledge_base, difficulty_level, {cmd.episodes})") |
| self._emit('print(f" Generated {len(questions)} questions for this round")') |
| self._emit("") |
| self._emit("round_experiences = []") |
| self._emit("round_stats = {'correct': 0, 'wrong': 0, 'lying': 0, 'curious': 0, 'missed_facts': []}") |
| self._emit("") |
| |
| self._emit("for ep_i, q_data in enumerate(questions):") |
| self._indent += 1 |
| self._emit("question = q_data['question']") |
| self._emit("ground_truth = q_data['ground_truth']") |
| self._emit("") |
| self._emit("# Generate response") |
| self._emit("inputs = tok(question, return_tensors='pt', truncation=True, max_length=1024).to(model.device)") |
| self._emit(f"temp = max(0.3, 0.5 + arena_round * 0.05 + random.uniform(-0.1, 0.1))") |
| self._emit("with torch.no_grad():") |
| self._indent += 1 |
| self._emit("out = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=temp, top_p=0.95)") |
| self._indent -= 1 |
| self._emit("response = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)") |
| self._emit("") |
| |
| self._emit("# === FACT CHECK against real source ===") |
| self._emit("fact_result = _fact_check(response, ground_truth, model, tok, strictness)") |
| self._emit("is_lying = _detect_lying(response, fact_result)") |
| self._emit("approach_hash = hashlib.md5(re.sub(r'\\d+', 'N', response[:200]).lower().encode()).hexdigest()[:12]") |
| self._emit("") |
| |
| self._emit("# === REWARD ===") |
| self._emit("if is_lying:") |
| self._indent += 1 |
| self._emit("reward = -2.0") |
| self._emit("round_stats['lying'] += 1") |
| self._indent -= 1 |
| self._emit("elif fact_result['passed']:") |
| self._indent += 1 |
| self._emit("reward = fact_result['score'] # 0.0 to 1.0 based on accuracy") |
| self._emit("round_stats['correct'] += 1") |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit("reward = -1.0 * strictness # punishment scales with difficulty") |
| self._emit("round_stats['wrong'] += 1") |
| self._emit("round_stats['missed_facts'].append({") |
| self._indent += 1 |
| self._emit("'ground_truth': ground_truth[:100],") |
| self._emit("'missed_terms': fact_result['missed_terms'][:5],") |
| self._emit("'source': q_data['source'],") |
| self._indent -= 1 |
| self._emit("})") |
| self._indent -= 1 |
| self._emit("") |
| |
| self._emit("novelty_key = hashlib.md5(f'{question[:50]}_{approach_hash}'.encode()).hexdigest()[:12]") |
| self._emit("if novelty_key not in seen_approaches:") |
| self._indent += 1 |
| self._emit(f"reward += {cmd.curiosity}") |
| self._emit("seen_approaches.add(novelty_key)") |
| self._emit("round_stats['curious'] += 1") |
| self._indent -= 1 |
| self._emit("") |
| |
| self._emit("memory_bank.append({'reward': reward, 'passed': fact_result['passed'],") |
| self._emit(" 'lying': is_lying, 'round': arena_round, 'score': fact_result['score']})") |
| self._emit("") |
| self._emit("if reward > 0:") |
| self._indent += 1 |
| self._emit("round_experiences.append((question + '\\n' + response, reward))") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("if ep_i % 10 == 0:") |
| self._indent += 1 |
| self._emit("status = 'PASS' if fact_result['passed'] else ('LYING!' if is_lying else 'FAIL')") |
| self._emit("print(f' Ep {ep_i+1}: {status} (score={fact_result[\"score\"]:.2f}, reward={reward:.1f})')") |
| self._indent -= 2 |
| self._emit("") |
| |
| self._emit("total_ep = round_stats['correct'] + round_stats['wrong'] + round_stats['lying']") |
| self._emit("print(f'[td_lang] Round {arena_round+1} results:')") |
| self._emit("print(f' Passed fact-check: {round_stats[\"correct\"]}/{total_ep}')") |
| self._emit("print(f' Failed: {round_stats[\"wrong\"]}/{total_ep}')") |
| self._emit("print(f' Caught lying: {round_stats[\"lying\"]} (punished -2.0 each)')") |
| self._emit("if round_stats['missed_facts']:") |
| self._indent += 1 |
| self._emit("print(f' Top missed facts ({len(round_stats[\"missed_facts\"])} total):')") |
| self._emit("for mf in round_stats['missed_facts'][:3]:") |
| self._indent += 1 |
| self._emit("print(f' Source: {mf[\"source\"]}')") |
| self._emit("print(f' Missed: {mf[\"missed_terms\"]}')") |
| self._indent -= 2 |
| self._emit("") |
| |
| self._emit("del model; import gc; gc.collect()") |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("torch.cuda.empty_cache()") |
| self._indent -= 1 |
| self._emit("except Exception:") |
| self._indent += 1 |
| self._emit("pass") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("if len(round_experiences) < 3:") |
| self._indent += 1 |
| self._emit("print('[td_lang] Too few positive experiences — maze was too hard. Skipping training.')") |
| self._emit("continue") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("# === REWARD-WEIGHTED TRAINING ===") |
| self._emit("training_texts = []") |
| self._emit("for text, reward in round_experiences:") |
| self._indent += 1 |
| self._emit("copies = max(1, int(reward * 2))") |
| self._emit("training_texts.extend([text] * copies)") |
| self._indent -= 1 |
| self._emit("random.shuffle(training_texts)") |
| self._emit('print(f"[td_lang] Training on {len(training_texts)} reward-weighted experiences...")') |
| self._emit("") |
| self._emit("ds = Dataset.from_dict({'text': training_texts})") |
| self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')") |
| self._emit("model = prepare_model_for_kbit_training(model)") |
| self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,") |
| self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")') |
| self._emit("model = get_peft_model(model, lora_config)") |
| self._emit(f"ra_out = f'td_lang_outputs/research_arena_round_{{arena_round}}'") |
| self._emit(f"training_args = TrainingArguments(output_dir=ra_out, max_steps={cmd.steps},") |
| self._emit(" per_device_train_batch_size=1, gradient_accumulation_steps=4,") |
| self._emit(" learning_rate=5e-5, logging_steps=16, bf16=True, gradient_checkpointing=True)") |
| self._emit("trainer = SFTTrainer(model=model, train_dataset=ds, args=training_args, processing_class=tok)") |
| self._emit("trainer.train()") |
| self._emit("trainer.save_model(ra_out)") |
| self._emit("checkpoint = ra_out") |
| self._emit("del model; gc.collect()") |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("torch.cuda.empty_cache()") |
| self._indent -= 1 |
| self._emit("except Exception:") |
| self._indent += 1 |
| self._emit("pass") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("research_log.append({") |
| self._indent += 1 |
| self._emit("'round': arena_round,") |
| self._emit("'difficulty': difficulty_level,") |
| self._emit("'strictness': strictness,") |
| self._emit("'stats': dict(round_stats),") |
| self._emit("'n_training': len(training_texts),") |
| self._emit("'memory_size': len(memory_bank),") |
| self._indent -= 1 |
| self._emit("})") |
| self._emit("") |
| self._emit("print(f'[td_lang] Round {arena_round+1} complete. Model trained and saved.')") |
| self._indent -= 1 |
| self._emit("") |
| |
| self._emit(f'models["{cmd.target}"]["checkpoint"] = checkpoint') |
| self._emit('print(f"\\n[td_lang] RESEARCH ARENA COMPLETE")') |
| self._emit('print(f" Topic: {topic}")') |
| self._emit('print(f" Knowledge base: {len(knowledge_base)} facts")') |
| self._emit('print(f" Total memories: {len(memory_bank)}")') |
| self._emit('print(f" Unique approaches: {len(seen_approaches)}")') |
| self._emit("lying_count = sum(1 for m in memory_bank if m['lying'])") |
| self._emit("correct_count = sum(1 for m in memory_bank if m['passed'])") |
| self._emit("print(f' Correct: {correct_count} | Caught lying: {lying_count}')") |
| self._emit("avg_reward = sum(m['reward'] for m in memory_bank) / len(memory_bank) if memory_bank else 0") |
| self._emit("print(f' Average reward: {avg_reward:.2f}')") |
| self._emit("") |
| |
| if cmd.output: |
| self._emit(f'log_path = Path("{cmd.output}")') |
| self._emit("log_path.parent.mkdir(parents=True, exist_ok=True)") |
| self._emit('with open(log_path, "w") as f:') |
| self._indent += 1 |
| self._emit("json.dump({'topic': topic, 'log': research_log, 'memory': memory_bank, 'knowledge_base_size': len(knowledge_base)}, f, indent=2)") |
| self._indent -= 1 |
| self._emit('print(f"[td_lang] Research log saved to {log_path}")') |
| self._emit("") |
| |
| self._emit(f'lineage["{cmd.target}"]["operations"].append({{') |
| self._indent += 1 |
| self._emit('"op": "research_arena",') |
| self._emit(f'"topic": "{cmd.topic}",') |
| self._emit(f'"sources": "{cmd.sources}",') |
| self._emit(f'"rounds": {cmd.rounds},') |
| self._emit(f'"episodes_per_round": {cmd.episodes},') |
| self._emit('"knowledge_base_size": len(knowledge_base),') |
| self._emit('"total_memories": len(memory_bank),') |
| self._emit('"timestamp": datetime.now().isoformat(),') |
| self._indent -= 1 |
| self._emit("})") |
| self._indent -= 1 |
|
|
| |
| def _emit_vote(self, cmd: VoteCmd) -> None: |
| """VOTE - majority voting. Generate N answers, pick the most common. |
| |
| Proven to boost accuracy 10-20% with zero training cost. |
| """ |
| n = cmd.samples |
| self._emit(f'print("[td_lang] Majority voting on {cmd.target} ({n} samples)...")') |
| self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")') |
| self._emit("if not checkpoint:") |
| self._indent += 1 |
| self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]') |
| self._indent -= 1 |
| self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer") |
| self._emit("import torch") |
| self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)") |
| self._emit("model = _load_model_smart(checkpoint, torch_dtype=torch.bfloat16, device_map='auto')") |
| self._emit("model.eval()") |
| self._emit(f'question = {repr(cmd.question)}') |
| self._emit(f"n_samples = {n}") |
| self._emit('inputs = tok(question, return_tensors="pt").to(model.device)') |
| self._emit("answers = []") |
| self._emit("for i in range(n_samples):") |
| self._indent += 1 |
| self._emit("with torch.no_grad():") |
| self._indent += 1 |
| self._emit("out = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.7, top_p=0.9)") |
| self._indent -= 1 |
| self._emit("resp = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip()") |
| self._emit("answers.append(resp)") |
| self._emit('print(f" Sample {i+1}: {resp[:80]}...")') |
| self._indent -= 1 |
| self._emit("") |
| self._emit("# Find the most common answer (majority vote)") |
| self._emit("from collections import Counter") |
| self._emit("# Normalize answers: lowercase, strip whitespace for comparison") |
| self._emit("normalized = [a.strip().lower() for a in answers]") |
| self._emit("counts = Counter(normalized)") |
| self._emit("winner_norm, winner_count = counts.most_common(1)[0]") |
| self._emit("# Find the original (non-normalized) version of the winner") |
| self._emit("winner = next(a for a, n in zip(answers, normalized) if n == winner_norm)") |
| self._emit('print(f"[td_lang] Winner ({winner_count}/{n_samples} votes): {winner[:200]}")') |
| self._emit("") |
| self._emit("vote_result = {") |
| self._indent += 1 |
| self._emit("'question': question,") |
| self._emit("'winner': winner,") |
| self._emit("'votes': winner_count,") |
| self._emit("'total_samples': n_samples,") |
| self._emit("'all_answers': answers,") |
| self._emit("'confidence': winner_count / n_samples,") |
| self._indent -= 1 |
| self._emit("}") |
| self._emit(f'results["{cmd.target}_vote"] = vote_result') |
| if cmd.output: |
| self._emit(f'vote_path = Path("{cmd.output}")') |
| self._emit("vote_path.parent.mkdir(parents=True, exist_ok=True)") |
| self._emit('with open(vote_path, "w") as f:') |
| self._indent += 1 |
| self._emit("json.dump(vote_result, f, indent=2)") |
| self._indent -= 1 |
| self._emit('print(f"[td_lang] Vote results saved to {vote_path}")') |
| self._emit("del model, tok") |
| self._emit("import gc; gc.collect()") |
|
|
| def _emit_prompt(self, cmd: PromptBlock) -> None: |
| """PROMPT - attach a system prompt to a model for all future generations. |
| |
| Stores the prompt in the model's metadata so other commands (eval, diagnose, |
| synth, vote) can pick it up and prepend it. |
| """ |
| self._emit(f'print("[td_lang] Setting system prompt for {cmd.target}...")') |
| self._emit(f'models["{cmd.target}"]["system_prompt"] = {repr(cmd.text)}') |
| self._emit(f'print("[td_lang] Prompt set: {repr(cmd.text[:60])}...")') |
| self._emit(f'lineage["{cmd.target}"]["operations"].append({{') |
| self._indent += 1 |
| self._emit('"op": "prompt",') |
| self._emit(f'"text": {repr(cmd.text)},') |
| self._emit('"timestamp": datetime.now().isoformat(),') |
| self._indent -= 1 |
| self._emit("})") |
|
|
| def _emit_distill(self, cmd: DistillCmd) -> None: |
| """DISTILL - train a smaller student model using the teacher's outputs. |
| |
| The teacher generates high-quality answers, and we SFT the student on them. |
| Result: a fast model for easy questions. |
| """ |
| steps = cmd.steps |
| self._emit(f'print("[td_lang] Distilling {cmd.teacher} into student model...")') |
| self._emit(f'teacher_checkpoint = models.get("{cmd.teacher}", {{}}).get("checkpoint")') |
| self._emit("if not teacher_checkpoint:") |
| self._indent += 1 |
| self._emit(f'teacher_checkpoint = models["{cmd.teacher}"]["model_ref"]') |
| self._indent -= 1 |
| self._emit(f'student_path = {repr(cmd.student)}') |
| self._emit("") |
| self._emit("# Step 1: Generate teacher answers on diverse prompts") |
| self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer") |
| self._emit("import torch") |
| self._emit('print("[td_lang] Loading teacher model...")') |
| self._emit("teacher_tok = AutoTokenizer.from_pretrained(teacher_checkpoint)") |
| self._emit("teacher_model = _load_model_smart(teacher_checkpoint, torch_dtype=torch.bfloat16, device_map='auto')") |
| self._emit("teacher_model.eval()") |
| self._emit("") |
| self._emit("distill_prompts = [") |
| self._indent += 1 |
| self._emit('"Explain how photosynthesis works step by step.",') |
| self._emit('"Write a Python function to find the longest common subsequence.",') |
| self._emit('"What is 847 divided by 11? Show your work.",') |
| self._emit('"Compare and contrast TCP and UDP protocols.",') |
| self._emit('"Solve: if 3x + 7 = 22, what is x?",') |
| self._emit('"Explain the difference between a stack and a queue.",') |
| self._emit('"What causes seasons on Earth?",') |
| self._emit('"Write a function to check if a string is a palindrome.",') |
| self._emit('"What is the Pythagorean theorem and give an example.",') |
| self._emit('"Explain recursion with a simple example.",') |
| self._emit('"What is 15% of 240?",') |
| self._emit('"Describe how a binary search works.",') |
| self._emit('"What are the three laws of thermodynamics?",') |
| self._emit('"Write pseudocode for bubble sort.",') |
| self._emit('"If a train travels 120 miles in 2 hours, what is its speed?",') |
| self._emit('"Explain what an API is in simple terms.",') |
| self._indent -= 1 |
| self._emit("]") |
| self._emit("") |
| self._emit("teacher_data = []") |
| self._emit("for prompt in distill_prompts:") |
| self._indent += 1 |
| self._emit('inputs = teacher_tok(prompt, return_tensors="pt").to(teacher_model.device)') |
| self._emit("with torch.no_grad():") |
| self._indent += 1 |
| self._emit("out = teacher_model.generate(**inputs, max_new_tokens=512, do_sample=False)") |
| self._indent -= 1 |
| self._emit("resp = teacher_tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)") |
| self._emit('teacher_data.append({"prompt": prompt, "response": resp})') |
| self._emit('print(f" Generated: {prompt[:40]}... -> {len(resp)} chars")') |
| self._indent -= 1 |
| self._emit("") |
| self._emit("del teacher_model") |
| self._emit("import gc; gc.collect()") |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("torch.cuda.empty_cache()") |
| self._indent -= 1 |
| self._emit("except Exception:") |
| self._indent += 1 |
| self._emit("pass") |
| self._indent -= 1 |
| self._emit("") |
| self._emit("# Step 2: Load student model with QLoRA and train on teacher outputs") |
| self._emit('print("[td_lang] Loading student model with QLoRA...")') |
| self._emit("from transformers import BitsAndBytesConfig, TrainingArguments") |
| self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training") |
| self._emit("from trl import SFTTrainer") |
| self._emit("from datasets import Dataset") |
| self._emit("") |
| self._emit("bnb_config = BitsAndBytesConfig(") |
| self._indent += 1 |
| self._emit("load_in_4bit=True,") |
| self._emit('bnb_4bit_quant_type="nf4",') |
| self._emit("bnb_4bit_compute_dtype=torch.bfloat16,") |
| self._emit("bnb_4bit_use_double_quant=True,") |
| self._indent -= 1 |
| self._emit(")") |
| self._emit("student_tok = AutoTokenizer.from_pretrained(student_path)") |
| self._emit("student_model = _load_model_smart(student_path, quantization_config=bnb_config, device_map='auto')") |
| self._emit("student_model = prepare_model_for_kbit_training(student_model)") |
| self._emit("") |
| self._emit("lora_config = LoraConfig(") |
| self._indent += 1 |
| self._emit("r=16, lora_alpha=32, lora_dropout=0.05,") |
| self._emit('target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],') |
| self._emit('task_type="CAUSAL_LM",') |
| self._indent -= 1 |
| self._emit(")") |
| self._emit("student_model = get_peft_model(student_model, lora_config)") |
| self._emit("") |
| self._emit("# Format training data") |
| self._emit("train_texts = []") |
| self._emit("for d in teacher_data:") |
| self._indent += 1 |
| self._emit("train_texts.append(d['prompt'] + '\\n' + d['response'])") |
| self._indent -= 1 |
| self._emit('ds = Dataset.from_dict({"text": train_texts})') |
| self._emit("") |
| distill_out = cmd.output or "td_lang_outputs/distilled_student" |
| self._emit(f'distill_out = "{distill_out}"') |
| self._emit("training_args = TrainingArguments(") |
| self._indent += 1 |
| self._emit("output_dir=distill_out,") |
| self._emit(f"num_train_epochs={max(1, steps // len('distill_prompts') + 1)},") |
| self._emit(f"max_steps={steps},") |
| self._emit("per_device_train_batch_size=1,") |
| self._emit("gradient_accumulation_steps=4,") |
| self._emit("learning_rate=2e-4,") |
| self._emit('optim="paged_adamw_8bit",') |
| self._emit("logging_steps=10,") |
| self._emit("save_strategy='epoch',") |
| self._emit("bf16=True,") |
| self._indent -= 1 |
| self._emit(")") |
| self._emit("trainer = SFTTrainer(") |
| self._indent += 1 |
| self._emit("model=student_model,") |
| self._emit("train_dataset=ds,") |
| self._emit("args=training_args,") |
| self._emit("processing_class=student_tok,") |
| self._indent -= 1 |
| self._emit(")") |
| self._emit('print(f"[td_lang] Training student for {training_args.max_steps} steps...")') |
| self._emit("trainer.train()") |
| self._emit("student_model.save_pretrained(distill_out)") |
| self._emit("student_tok.save_pretrained(distill_out)") |
| self._emit('print(f"[td_lang] Student model saved to {distill_out}")') |
| self._emit("") |
| self._emit("del student_model, teacher_tok, student_tok") |
| self._emit("gc.collect()") |
| self._emit("try:") |
| self._indent += 1 |
| self._emit("torch.cuda.empty_cache()") |
| self._indent -= 1 |
| self._emit("except Exception:") |
| self._indent += 1 |
| self._emit("pass") |
| self._indent -= 1 |
| self._emit(f'lineage["{cmd.teacher}"]["operations"].append({{') |
| self._indent += 1 |
| self._emit('"op": "distill",') |
| self._emit(f'"student": {repr(cmd.student)},') |
| self._emit(f'"steps": {steps},') |
| self._emit(f'"n_examples": len(teacher_data),') |
| self._emit('"timestamp": datetime.now().isoformat(),') |
| self._indent -= 1 |
| self._emit("})") |
|
|
| def _emit_rollback(self, cmd: RollbackCmd) -> None: |
| """ROLLBACK - revert to the most recent snapshot. |
| |
| Looks for the latest snapshot in td_lang_outputs/snapshots/ for this model, |
| then reloads from it. |
| """ |
| self._emit(f'print("[td_lang] Rolling back {cmd.target}...")') |
| self._emit("import glob as _glob") |
| self._emit(f'snap_pattern = os.path.join("td_lang_outputs", "snapshots", "{cmd.target}_*")') |
| self._emit("snapshots = sorted(_glob.glob(snap_pattern))") |
| self._emit("if not snapshots:") |
| self._indent += 1 |
| self._emit(f'print("[td_lang] ERROR: No snapshots found for {cmd.target}. Cannot rollback.")') |
| self._emit(f'print("[td_lang] Hint: use snapshot {cmd.target} before training to create restore points.")') |
| self._indent -= 1 |
| self._emit("else:") |
| self._indent += 1 |
| self._emit("latest_snap = snapshots[-1]") |
| self._emit('print(f"[td_lang] Found {len(snapshots)} snapshots. Reverting to: {latest_snap}")') |
| self._emit(f'models["{cmd.target}"]["checkpoint"] = latest_snap') |
| self._emit(f'lineage["{cmd.target}"]["operations"].append({{') |
| self._indent += 1 |
| self._emit('"op": "rollback",') |
| self._emit('"snapshot": latest_snap,') |
| self._emit('"timestamp": datetime.now().isoformat(),') |
| self._indent -= 1 |
| self._emit("})") |
| self._emit(f'print(f"[td_lang] Rollback complete. {cmd.target} now points to {{latest_snap}}")') |
| self._indent -= 1 |
|
|
| def _emit_summary(self) -> None: |
| self._emit("# --- Final Summary ---") |
| self._emit("elapsed = time.time() - start_time") |
| self._emit('print("\\n" + "=" * 60)') |
| self._emit('print("TD LANG COMPLETE")') |
| self._emit('print("=" * 60)') |
| self._emit('print(f" Time: {elapsed / 60:.1f} minutes")') |
| self._emit('print(f" Models: {list(models.keys())}")') |
| self._emit('print(f" Merged stages: {merged_stages}")') |
| self._emit('print("=" * 60)') |
| self._emit('td_notify(f"TD pipeline DONE in {elapsed / 60:.1f} min. Models: {list(models.keys())}")') |
|
|
| |
| def _emit(self, line: str) -> None: |
| if line == "": |
| self._lines.append("") |
| else: |
| prefix = " " * self._indent |
| self._lines.append(prefix + line) |
|
|
| def _emit_comment(self, text: str) -> None: |
| self._emit(f"# {text}") |
|
|
|
|
def compile_program(program: TDProgram) -> str:
    """Compile *program* into Python source with a fresh ``TDCompiler``.

    A new compiler instance is created on every call, so no per-instance
    state carries over between calls.
    """
    compiler = TDCompiler()
    return compiler.compile(program)
|
|