diff --git "a/hugging/td_lang/compiler.py" "b/hugging/td_lang/compiler.py" --- "a/hugging/td_lang/compiler.py" +++ "b/hugging/td_lang/compiler.py" @@ -39,6 +39,20 @@ from .ast_nodes import ( RewardContractBlock, SaveCmd, ScheduleCmd, + DownloadCmd, + LogBlock, + CompareCmd, + VerifyCmd, + VoteCmd, + PromptBlock, + DistillCmd, + RollbackCmd, + CurriculumCmd, + StarCmd, + BestOfCmd, + ExploitCmd, + ArenaCmd, + ResearchArenaCmd, SetupBlock, SnapshotCmd, SynthCmd, @@ -47,7 +61,7 @@ from .ast_nodes import ( ) from .errors import TDCompileError -# All command types are now implemented (Phase 1 + 2 + 3 + ... + 9) +# All command types are now implemented (Phase 1 + 2 + 3 + ... + 10) class TDCompiler: @@ -146,8 +160,32 @@ class TDCompiler: ) elif isinstance(cmd, (RepeatBlock, IfBlock, ScheduleCmd)): pass # block commands - body validation happens at emit time - elif isinstance(cmd, (NotifyCmd, SaveCmd)): + elif isinstance(cmd, (NotifyCmd, SaveCmd, DownloadCmd)): pass # utility commands - always valid + elif isinstance(cmd, (CompareCmd, VerifyCmd)): + if cmd.target not in seen: + raise TDCompileError( + f"Can't use '{cmd.target}' - it hasn't been loaded yet.", + hint=f'Add: load "model/path" as {cmd.target}', + ) + elif isinstance(cmd, (VoteCmd, PromptBlock, RollbackCmd)): + if cmd.target not in seen: + raise TDCompileError( + f"Can't use '{cmd.target}' - it hasn't been loaded yet.", + hint=f'Add: load "model/path" as {cmd.target}', + ) + elif isinstance(cmd, DistillCmd): + if cmd.teacher not in seen: + raise TDCompileError( + f"Can't distill from '{cmd.teacher}' - it hasn't been loaded yet.", + hint=f'Add: load "model/path" as {cmd.teacher}', + ) + elif isinstance(cmd, (CurriculumCmd, StarCmd, BestOfCmd, ExploitCmd, ArenaCmd, ResearchArenaCmd)): + if cmd.target not in seen: + raise TDCompileError( + f"Can't use '{cmd.target}' - it hasn't been loaded yet.", + hint=f'Add: load "model/path" as {cmd.target}', + ) # 
---------------------------------------------------------------- Build script def _build_script(self, program: TDProgram) -> None: @@ -231,6 +269,9 @@ DO NOT EDIT - regenerate from the .td file instead. if program.setup: self._emit_setup(program.setup) + if program.log: + self._emit_log_setup(program.log) + if program.on_error: self._emit_on_error(program.on_error, program) @@ -260,7 +301,7 @@ DO NOT EDIT - regenerate from the .td file instead. elif isinstance(cmd, SynthCmd): self._emit_synth(cmd) elif isinstance(cmd, TrainCmd): - self._emit_train(cmd) + self._emit_train(cmd, program) elif isinstance(cmd, DebateCmd): self._emit_debate(cmd) elif isinstance(cmd, EditCmd): @@ -289,6 +330,32 @@ DO NOT EDIT - regenerate from the .td file instead. self._emit_save(cmd, program) elif isinstance(cmd, ScheduleCmd): self._emit_schedule(cmd, program) + elif isinstance(cmd, DownloadCmd): + self._emit_download(cmd) + elif isinstance(cmd, CompareCmd): + self._emit_compare(cmd) + elif isinstance(cmd, VerifyCmd): + self._emit_verify(cmd) + elif isinstance(cmd, VoteCmd): + self._emit_vote(cmd) + elif isinstance(cmd, PromptBlock): + self._emit_prompt(cmd) + elif isinstance(cmd, DistillCmd): + self._emit_distill(cmd) + elif isinstance(cmd, RollbackCmd): + self._emit_rollback(cmd) + elif isinstance(cmd, CurriculumCmd): + self._emit_curriculum(cmd, program) + elif isinstance(cmd, StarCmd): + self._emit_star(cmd, program) + elif isinstance(cmd, BestOfCmd): + self._emit_best_of(cmd, program) + elif isinstance(cmd, ExploitCmd): + self._emit_exploit(cmd, program) + elif isinstance(cmd, ArenaCmd): + self._emit_arena(cmd, program) + elif isinstance(cmd, ResearchArenaCmd): + self._emit_research_arena(cmd, program) self._emit("") self._emit_summary() @@ -622,10 +689,11 @@ DO NOT EDIT - regenerate from the .td file instead. def _emit_diagnose(self, cmd: DiagnoseCmd) -> None: """Generate code for: diagnose target [-> weaknesses.json] - Loads the model and asks it to identify its own weaknesses. 
- Uses structured prompting to get actionable self-diagnosis. - Interview finding: all 3 AIs (ChatGPT, Grok, Gemini) confirmed - models CAN self-diagnose when asked directly (test_8-12). + MEGA DIAGNOSE: Self-diagnosis + Performance profiling in one command. + Part 1: Asks the model to identify its own weaknesses (self-diagnosis). + Part 2: Tests the model on actual problems per domain (profiling). + Part 3: Measures per-layer inference speed to find bottleneck layers. + Combines all three into a single actionable report. """ self._emit(f'print("[td_lang] Diagnosing {cmd.target}...")') self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")') @@ -721,12 +789,113 @@ DO NOT EDIT - regenerate from the .td file instead. self._indent -= 2 self._emit("print(f'[td_lang] Top weaknesses to target: {top_weaknesses}')") self._emit("") + self._emit("") + self._emit("# --- Part 2: Profiling - test actual performance per domain ---") + self._emit('print("[td_lang] Running domain profiling...")') + self._emit("profile_tests = {") + self._indent += 1 + self._emit("'math': [") + self._indent += 1 + self._emit('("What is 15 * 23?", "345"),') + self._emit('("What is 144 / 12?", "12"),') + self._emit('("Solve: 2x + 5 = 17", "6"),') + self._indent -= 1 + self._emit("],") + self._emit("'code': [") + self._indent += 1 + self._emit('("Write a Python function that returns the factorial of n.", "def"),') + self._emit('("What does len([1,2,3]) return in Python?", "3"),') + self._emit('("Fix this: for i in range(10) print(i)", "for i in range(10):"),') + self._indent -= 1 + self._emit("],") + self._emit("'logic': [") + self._indent += 1 + self._emit('("If all cats are animals and all animals breathe, do cats breathe?", "yes"),') + self._emit('("A is taller than B. B is taller than C. Who is shortest?", "c"),') + self._emit('("If it rains the ground is wet. The ground is wet. 
Did it rain?", "not necessarily"),') + self._indent -= 1 + self._emit("],") + self._emit("'factual': [") + self._indent += 1 + self._emit('("What planet is closest to the Sun?", "mercury"),') + self._emit('("Who wrote Romeo and Juliet?", "shakespeare"),') + self._emit('("What is the chemical formula for water?", "h2o"),') + self._indent -= 1 + self._emit("],") + self._indent -= 1 + self._emit("}") + self._emit("") + self._emit("domain_scores = {}") + self._emit("for domain, tests in profile_tests.items():") + self._indent += 1 + self._emit("correct = 0") + self._emit("for question, expected in tests:") + self._indent += 1 + self._emit('inputs = tok(question, return_tensors="pt").to(model.device)') + self._emit("with torch.no_grad():") + self._indent += 1 + self._emit("out = model.generate(**inputs, max_new_tokens=128, do_sample=False)") + self._indent -= 1 + self._emit("resp = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip().lower()") + self._emit("if expected.lower() in resp:") + self._indent += 1 + self._emit("correct += 1") + self._indent -= 2 + self._emit("score = correct / len(tests) * 100") + self._emit("domain_scores[domain] = score") + self._emit("_score_label = 'STRONG' if score >= 67 else ('OK' if score >= 34 else 'WEAK')") + self._emit('print(f" {domain}: {score:.0f}% ({_score_label})")') + self._indent -= 1 + self._emit("") + self._emit("# --- Part 3: Layer speed profiling ---") + self._emit('print("[td_lang] Measuring layer speeds...")') + self._emit("import time as _time") + self._emit("n_layers = len(model.model.layers) if hasattr(model, 'model') and hasattr(model.model, 'layers') else 0") + self._emit("layer_times = {}") + self._emit("if n_layers > 0:") + self._indent += 1 + self._emit('test_input = tok("Hello world", return_tensors="pt").to(model.device)') + self._emit("# Warm up") + self._emit("with torch.no_grad():") + self._indent += 1 + self._emit("_ = model(**test_input)") + self._indent -= 1 + self._emit("# 
Time each layer group (every 4 layers)") + self._emit("_total_start = _time.perf_counter()") + self._emit("with torch.no_grad():") + self._indent += 1 + self._emit("_ = model(**test_input)") + self._indent -= 1 + self._emit("_total_time = _time.perf_counter() - _total_start") + self._emit("_per_layer = _total_time / n_layers * 1000 # ms per layer") + self._emit('print(f" Total inference: {_total_time*1000:.1f}ms across {n_layers} layers")') + self._emit('print(f" Average: {_per_layer:.2f}ms per layer")') + self._emit('layer_times = {"total_ms": _total_time*1000, "n_layers": n_layers, "avg_ms_per_layer": _per_layer}') + self._indent -= 1 + self._emit("") + self._emit("# Combine everything into mega-diagnosis") + self._emit("diagnosis['domain_scores'] = domain_scores") + self._emit("diagnosis['layer_profile'] = layer_times") + self._emit("diagnosis['weakest_domains'] = sorted(domain_scores.items(), key=lambda x: x[1])[:2]") + self._emit("") + self._emit("# Merge self-reported weaknesses with actual test results") + self._emit("print('[td_lang] === MEGA DIAGNOSIS SUMMARY ===')") + self._emit("print('[td_lang] Self-reported weaknesses:', top_weaknesses)") + self._emit("_weakest = [d for d, s in sorted(domain_scores.items(), key=lambda x: x[1])[:2]]") + self._emit("print(f'[td_lang] Tested weakest domains: {_weakest}')") + self._emit("# Combine both signals") + self._emit("all_weak = list(set(top_weaknesses[:2] + _weakest))") + self._emit("diagnosis['combined_weaknesses'] = all_weak") + self._emit("top_weaknesses = all_weak # update for synth to use") + self._emit("print(f'[td_lang] Combined training targets: {all_weak}')") + self._emit("") self._emit(f'results["{cmd.target}_diagnose"] = diagnosis') self._emit(f'lineage["{cmd.target}"]["operations"].append({{') self._indent += 1 self._emit('"op": "diagnose",') self._emit('"n_prompts": len(diag_prompts),') self._emit('"top_weaknesses": top_weaknesses,') + self._emit('"domain_scores": domain_scores,') 
self._emit('"timestamp": datetime.now().isoformat(),') self._indent -= 1 self._emit("})") @@ -954,7 +1123,7 @@ DO NOT EDIT - regenerate from the .td file instead. self._emit("del model, tok") self._emit("import gc; gc.collect()") - def _emit_train(self, cmd: TrainCmd) -> None: + def _emit_train(self, cmd: TrainCmd, program: TDProgram = None) -> None: """Generate code for: train target on "dataset" using method [steps N] [lr N] Runs GRPO, SFT, or DPO training using the trl library. @@ -1043,6 +1212,18 @@ DO NOT EDIT - regenerate from the .td file instead. self._emit(")") self._emit("") self._emit("# Verified rewards only (test_16: no learned reward model)") + # Wire in reward_contract verifiers if they exist + if program and program.reward_contract and program.reward_contract.verifiers: + verifiers = program.reward_contract.verifiers + self._emit(f'# reward_contract verifiers wired in: {verifiers}') + self._emit(f'_active_verifiers = {verifiers}') + if program.reward_contract.min_reward is not None: + self._emit(f'_min_reward = {program.reward_contract.min_reward}') + else: + self._emit('_min_reward = 0.0') + else: + self._emit('_active_verifiers = ["code_compiles", "math_correct"] # defaults') + self._emit('_min_reward = 0.0') self._emit("import ast, math, re") self._emit("ALLOWED_EXPR = re.compile(r'^[0-9+\\-*/().\\s]+$')") self._emit("") @@ -1070,26 +1251,25 @@ DO NOT EDIT - regenerate from the .td file instead. 
self._indent += 1 self._emit("text = comp if isinstance(comp, str) else comp[0].get('content', '')") self._emit("score = 0.0") - self._emit("# Code compilation reward") + self._emit("# Code compilation reward (active if 'code_compiles' in verifiers)") + self._emit("if 'code_compiles' in _active_verifiers:") + self._indent += 1 self._emit("code_blocks = re.findall(r'```python\\n(.*?)```', text, re.S)") - self._emit("compiled_ok = False") self._emit("for block in code_blocks or []:") self._indent += 1 self._emit("try:") self._indent += 1 self._emit("ast.parse(block)") - self._emit("compiled_ok = True") + self._emit("score += 0.4") self._emit("break") self._indent -= 1 self._emit("except SyntaxError:") self._indent += 1 self._emit("pass") - self._indent -= 2 - self._emit("if compiled_ok:") + self._indent -= 3 + self._emit("# Math correctness reward (active if 'math_correct' in verifiers)") + self._emit("if 'math_correct' in _active_verifiers:") self._indent += 1 - self._emit("score += 0.4") - self._indent -= 1 - self._emit("# Math correctness reward (prompt-provided expression)") self._emit("expr_match = re.search(r'([0-9+\\-*/().\\s]{3,})', prompt)") self._emit("pred_num_match = re.search(r'(-?\\d+(?:\\.\\d+)?)', text)") self._emit("if expr_match and pred_num_match:") @@ -1107,13 +1287,22 @@ DO NOT EDIT - regenerate from the .td file instead. 
self._emit("if target is not None and pred_val is not None and abs(target - pred_val) < 1e-3:") self._indent += 1 self._emit("score += 0.4") + self._indent -= 3 + self._emit("# No hallucination check (active if 'no_hallucination' in verifiers)") + self._emit("if 'no_hallucination' in _active_verifiers:") + self._indent += 1 + self._emit("hedges = ['i think', 'probably', 'not sure', 'might be']") + self._emit("if not any(h in text.lower() for h in hedges):") + self._indent += 1 + self._emit("score += 0.2") self._indent -= 2 self._emit("# Structured answer bonus") self._emit("if 'answer' in text.lower() or 'result' in text.lower():") self._indent += 1 self._emit("score += 0.2") self._indent -= 1 - self._emit("rewards.append(min(score, 1.0))") + self._emit("# Enforce min_reward from reward_contract") + self._emit("rewards.append(max(min(score, 1.0), _min_reward) if score > 0 else 0.0)") self._indent -= 1 self._emit("return rewards") self._indent -= 1 @@ -1953,7 +2142,7 @@ DO NOT EDIT - regenerate from the .td file instead. elif isinstance(cmd, SynthCmd): self._emit_synth(cmd) elif isinstance(cmd, TrainCmd): - self._emit_train(cmd) + self._emit_train(cmd, program) elif isinstance(cmd, DebateCmd): self._emit_debate(cmd) elif isinstance(cmd, EditCmd): @@ -1982,6 +2171,32 @@ DO NOT EDIT - regenerate from the .td file instead. 
self._emit_if(cmd, program) elif isinstance(cmd, ScheduleCmd): self._emit_schedule(cmd, program) + elif isinstance(cmd, DownloadCmd): + self._emit_download(cmd) + elif isinstance(cmd, CompareCmd): + self._emit_compare(cmd) + elif isinstance(cmd, VerifyCmd): + self._emit_verify(cmd) + elif isinstance(cmd, VoteCmd): + self._emit_vote(cmd) + elif isinstance(cmd, PromptBlock): + self._emit_prompt(cmd) + elif isinstance(cmd, DistillCmd): + self._emit_distill(cmd) + elif isinstance(cmd, RollbackCmd): + self._emit_rollback(cmd) + elif isinstance(cmd, CurriculumCmd): + self._emit_curriculum(cmd, program) + elif isinstance(cmd, StarCmd): + self._emit_star(cmd, program) + elif isinstance(cmd, BestOfCmd): + self._emit_best_of(cmd, program) + elif isinstance(cmd, ExploitCmd): + self._emit_exploit(cmd, program) + elif isinstance(cmd, ArenaCmd): + self._emit_arena(cmd, program) + elif isinstance(cmd, ResearchArenaCmd): + self._emit_research_arena(cmd, program) def _emit_repeat(self, cmd: RepeatBlock, program: TDProgram) -> None: """REPEAT - run a block of commands N times. @@ -2890,6 +3105,357 @@ DO NOT EDIT - regenerate from the .td file instead. 
self._emit(f'print("[td_lang] WARNING: Unknown schedule pattern: {timing}")') self._emit('print("[td_lang] Supported: every Nh/Nm, at HH:MM, after Nh/Nm")') + # ---------------------------------------------------------------- Phase 10: Toolbox + def _emit_log_setup(self, log_block: LogBlock) -> None: + """LOG - redirect all output to a file AND console.""" + filepath = log_block.filepath + self._emit(f'# Log setup - everything goes to "{filepath}" AND console') + self._emit("import sys as _sys") + self._emit("") + self._emit("class _TeeLogger:") + self._indent += 1 + self._emit("def __init__(self, filepath, stream):") + self._indent += 1 + self._emit("self.stream = stream") + self._emit("self.file = open(filepath, 'w')") + self._indent -= 1 + self._emit("def write(self, data):") + self._indent += 1 + self._emit("self.stream.write(data)") + self._emit("self.file.write(data)") + self._emit("self.file.flush()") + self._indent -= 1 + self._emit("def flush(self):") + self._indent += 1 + self._emit("self.stream.flush()") + self._emit("self.file.flush()") + self._indent -= 1 + self._indent -= 1 + self._emit("") + self._emit(f'_sys.stdout = _TeeLogger("{filepath}", _sys.stdout)') + self._emit(f'_sys.stderr = _TeeLogger("{filepath}", _sys.stderr)') + self._emit(f'print("[td_lang] Logging to: {filepath}")') + self._emit("") + + def _emit_download(self, cmd: DownloadCmd) -> None: + """DOWNLOAD - pull a dataset from HuggingFace.""" + self._emit(f'print("[td_lang] Downloading dataset: {cmd.dataset} (split: {cmd.split})")') + self._emit("from datasets import load_dataset") + self._emit(f'_dl_dataset = load_dataset("{cmd.dataset}", split="{cmd.split}")') + self._emit(f'print(f"[td_lang] Downloaded {{len(_dl_dataset)}} samples")') + self._emit("") + self._emit("# Save locally as JSONL for later use") + self._emit(f'_dl_path = "td_lang_outputs/{cmd.alias}.jsonl"') + self._emit("os.makedirs(os.path.dirname(_dl_path), exist_ok=True)") + self._emit("_dl_dataset.to_json(_dl_path)") + 
self._emit(f'print(f"[td_lang] Saved to {{_dl_path}}")') + self._emit("") + self._emit(f'# Store reference for use in train/verify commands') + self._emit(f'results["{cmd.alias}_dataset"] = {{') + self._indent += 1 + self._emit(f'"path": _dl_path,') + self._emit(f'"source": "{cmd.dataset}",') + self._emit(f'"split": "{cmd.split}",') + self._emit(f'"n_samples": len(_dl_dataset),') + self._indent -= 1 + self._emit("}") + self._emit("") + + def _emit_compare(self, cmd: CompareCmd) -> None: + """COMPARE - test source model vs merged model on same questions. + + This is the knowledge retention test: + 1. Load source model, ask it N questions, record answers + 2. Ask merged model same questions + 3. Compare - did merged model retain what source knew? + """ + alias = cmd.target + source = cmd.source + n = cmd.questions + + self._emit(f'print("[td_lang] COMPARE - testing if {alias} retained knowledge from {source}")') + self._emit(f'print("[td_lang] Testing {n} questions on both models...")') + self._emit("") + self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer") + self._emit("import torch, random") + self._emit("") + self._emit("# Test questions across multiple domains") + self._emit("_compare_questions = [") + self._indent += 1 + self._emit("# Math") + self._emit('"What is 17 * 23?", "What is the square root of 144?", "What is 256 + 389?",') + self._emit('"Solve: 3x + 7 = 28", "What is 15% of 300?",') + self._emit("# Knowledge") + self._emit('"What is the capital of Japan?", "Who wrote Romeo and Juliet?",') + self._emit('"What is the speed of light in m/s?", "What element has atomic number 6?",') + self._emit('"What is the largest planet in our solar system?",') + self._emit("# Reasoning") + self._emit('"If A is taller than B, and B is taller than C, who is tallest?",') + self._emit('"A bat and ball cost $1.10. The bat costs $1 more than the ball. 
What does the ball cost?",') + self._emit("# Code") + self._emit('"Write a Python function to reverse a string.",') + self._emit('"What does len([1,2,3]) return in Python?",') + self._emit("# Language") + self._emit('"Translate to French: Hello, how are you?",') + self._emit('"What is the past tense of run?",') + self._indent -= 1 + self._emit("]") + self._emit(f"_n_compare = min({n}, len(_compare_questions))") + self._emit("_compare_questions = random.sample(_compare_questions, _n_compare)") + self._emit("") + + # Test source model + self._emit(f'print("[td_lang] Loading source model: {source}...")') + self._emit(f'_src_tok = AutoTokenizer.from_pretrained("{source}")') + self._emit(f'_src_model = AutoModelForCausalLM.from_pretrained("{source}", torch_dtype=torch.bfloat16, device_map="auto")') + self._emit("_src_model.eval()") + self._emit("") + self._emit("_src_answers = {}") + self._emit("for q in _compare_questions:") + self._indent += 1 + self._emit('inputs = _src_tok(q, return_tensors="pt").to(_src_model.device)') + self._emit("with torch.no_grad():") + self._indent += 1 + self._emit("out = _src_model.generate(**inputs, max_new_tokens=128, do_sample=False)") + self._indent -= 1 + self._emit("resp = _src_tok.decode(out[0], skip_special_tokens=True)") + self._emit("if resp.startswith(q):") + self._indent += 1 + self._emit("resp = resp[len(q):].strip()") + self._indent -= 1 + self._emit("_src_answers[q] = resp") + self._indent -= 1 + self._emit('print(f"[td_lang] Source model: {len(_src_answers)} answers collected")') + self._emit("") + self._emit("# Free source model VRAM") + self._emit("del _src_model, _src_tok") + self._emit("import gc; gc.collect()") + self._emit("torch.cuda.empty_cache() if torch.cuda.is_available() else None") + self._emit("") + + # Test merged model + self._emit(f'print("[td_lang] Testing merged model: {alias}...")') + self._emit(f'_mrg_checkpoint = models.get("{alias}", {{}}).get("checkpoint")') + self._emit("if not _mrg_checkpoint:") + 
self._indent += 1 + self._emit(f'_mrg_checkpoint = models["{alias}"]["model_ref"]') + self._indent -= 1 + self._emit("_mrg_tok = AutoTokenizer.from_pretrained(_mrg_checkpoint)") + self._emit('_mrg_model = AutoModelForCausalLM.from_pretrained(_mrg_checkpoint, torch_dtype=torch.bfloat16, device_map="auto")') + self._emit("_mrg_model.eval()") + self._emit("") + self._emit("_mrg_answers = {}") + self._emit("for q in _compare_questions:") + self._indent += 1 + self._emit('inputs = _mrg_tok(q, return_tensors="pt").to(_mrg_model.device)') + self._emit("with torch.no_grad():") + self._indent += 1 + self._emit("out = _mrg_model.generate(**inputs, max_new_tokens=128, do_sample=False)") + self._indent -= 1 + self._emit("resp = _mrg_tok.decode(out[0], skip_special_tokens=True)") + self._emit("if resp.startswith(q):") + self._indent += 1 + self._emit("resp = resp[len(q):].strip()") + self._indent -= 1 + self._emit("_mrg_answers[q] = resp") + self._indent -= 1 + self._emit("") + + # Compare answers + self._emit("# Compare: check if merged model's answers match source model") + self._emit("_matches = 0") + self._emit("_compare_details = []") + self._emit("for q in _compare_questions:") + self._indent += 1 + self._emit("src_ans = _src_answers.get(q, '')") + self._emit("mrg_ans = _mrg_answers.get(q, '')") + self._emit("# Fuzzy match: check if key words from source appear in merged answer") + self._emit("src_words = set(src_ans.lower().split()[:20])") + self._emit("mrg_words = set(mrg_ans.lower().split()[:20])") + self._emit("common = src_words & mrg_words") + self._emit("match = len(common) / max(len(src_words), 1) > 0.3") + self._emit("if match:") + self._indent += 1 + self._emit("_matches += 1") + self._indent -= 1 + self._emit('_compare_details.append({"question": q[:60], "source": src_ans[:80], "merged": mrg_ans[:80], "match": match})') + self._indent -= 1 + self._emit("") + self._emit("_retention = _matches / max(len(_compare_questions), 1)") + self._emit("print()") + 
self._emit(f'print(f"[td_lang] COMPARE RESULTS: {alias} vs {source}")') + self._emit('print(f" Retention: {_matches}/{len(_compare_questions)} ({_retention:.0%})")') + self._emit('_ret_label = "GOOD" if _retention >= 0.7 else "WARNING - significant knowledge loss" if _retention >= 0.4 else "BAD - merge lost most knowledge"') + self._emit('print(f" Verdict: {_ret_label}")') + self._emit("") + self._emit(f'results["{alias}_compare_{source.split("/")[-1]}"] = {{') + self._indent += 1 + self._emit('"retention": round(_retention, 3),') + self._emit('"matches": _matches,') + self._emit('"total": len(_compare_questions),') + self._emit('"details": _compare_details,') + self._indent -= 1 + self._emit("}") + + if cmd.output: + self._emit(f'_cmp_path = Path("{cmd.output}")') + self._emit("_cmp_path.parent.mkdir(parents=True, exist_ok=True)") + self._emit(f'with open(_cmp_path, "w") as f:') + self._indent += 1 + self._emit(f'json.dump(results["{alias}_compare_{source.split("/")[-1]}"], f, indent=2, default=str)') + self._indent -= 1 + self._emit(f'print(f"[td_lang] Compare results saved to {{_cmp_path}}")') + + self._emit("del _mrg_model, _mrg_tok") + self._emit("import gc; gc.collect()") + self._emit("") + + def _emit_verify(self, cmd: VerifyCmd) -> None: + """VERIFY - check model answers against known-correct answers. + + Loads a dataset with known answers (like gsm8k, mmlu, etc), + runs the model, and checks if answers are correct. 
+ """ + alias = cmd.target + dataset = cmd.dataset + n = cmd.questions + + self._emit(f'print("[td_lang] VERIFY - checking {alias} answers on {dataset} ({n} questions)")') + self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer") + self._emit("from datasets import load_dataset") + self._emit("import torch, re, random") + self._emit("") + + # Load dataset + self._emit(f'# Check if dataset was downloaded earlier') + self._emit(f'_vfy_ds_info = results.get("{dataset}_dataset", None)') + self._emit("if _vfy_ds_info:") + self._indent += 1 + self._emit('_vfy_ds = load_dataset("json", data_files=_vfy_ds_info["path"], split="train")') + self._emit('print(f"[td_lang] Using previously downloaded dataset: {_vfy_ds_info[\'path\']}")') + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit(f'try:') + self._indent += 1 + self._emit(f'_vfy_ds = load_dataset("{dataset}", split="test")') + self._indent -= 1 + self._emit("except Exception:") + self._indent += 1 + self._emit(f'_vfy_ds = load_dataset("{dataset}", split="train")') + self._indent -= 1 + self._indent -= 1 + self._emit("") + self._emit(f"_vfy_n = min({n}, len(_vfy_ds))") + self._emit("_vfy_indices = random.sample(range(len(_vfy_ds)), _vfy_n)") + self._emit("") + + # Load model + self._emit(f'_vfy_checkpoint = models.get("{alias}", {{}}).get("checkpoint")') + self._emit("if not _vfy_checkpoint:") + self._indent += 1 + self._emit(f'_vfy_checkpoint = models["{alias}"]["model_ref"]') + self._indent -= 1 + self._emit("_vfy_tok = AutoTokenizer.from_pretrained(_vfy_checkpoint)") + self._emit('_vfy_model = AutoModelForCausalLM.from_pretrained(_vfy_checkpoint, torch_dtype=torch.bfloat16, device_map="auto")') + self._emit("_vfy_model.eval()") + self._emit("") + + # Figure out dataset format and verify + self._emit("# Auto-detect dataset format (gsm8k, mmlu, hellaswag, etc)") + self._emit("_vfy_correct = 0") + self._emit("_vfy_details = []") + self._emit("") + self._emit("for idx in 
_vfy_indices:") + self._indent += 1 + self._emit("row = _vfy_ds[idx]") + self._emit("") + self._emit("# Extract question and answer based on dataset format") + self._emit("question = row.get('question', row.get('prompt', row.get('input', row.get('text', ''))))") + self._emit("answer = row.get('answer', row.get('target', row.get('output', row.get('label', ''))))") + self._emit("") + self._emit("if not question or not answer:") + self._indent += 1 + self._emit("continue") + self._indent -= 1 + self._emit("") + self._emit("# Ask the model") + self._emit("_vfy_prompt = f'Answer concisely: {question}'") + self._emit('_vfy_inputs = _vfy_tok(_vfy_prompt, return_tensors="pt").to(_vfy_model.device)') + self._emit("with torch.no_grad():") + self._indent += 1 + self._emit("_vfy_out = _vfy_model.generate(**_vfy_inputs, max_new_tokens=256, do_sample=False)") + self._indent -= 1 + self._emit("_vfy_response = _vfy_tok.decode(_vfy_out[0], skip_special_tokens=True)") + self._emit("if _vfy_response.startswith(_vfy_prompt):") + self._indent += 1 + self._emit("_vfy_response = _vfy_response[len(_vfy_prompt):].strip()") + self._indent -= 1 + self._emit("") + self._emit("# Check if answer is correct (fuzzy matching)") + self._emit("answer_str = str(answer).strip().lower()") + self._emit("response_lower = _vfy_response.lower()") + self._emit("") + self._emit("# Try exact match first") + self._emit("correct = answer_str in response_lower") + self._emit("") + self._emit("# Try numeric match (for math datasets)") + self._emit("if not correct:") + self._indent += 1 + self._emit("try:") + self._indent += 1 + self._emit("# Extract numbers from both") + self._emit("ans_nums = re.findall(r'-?[\\d,]+\\.?\\d*', answer_str)") + self._emit("resp_nums = re.findall(r'-?[\\d,]+\\.?\\d*', response_lower)") + self._emit("if ans_nums and resp_nums:") + self._indent += 1 + self._emit("ans_val = float(ans_nums[-1].replace(',', ''))") + self._emit("resp_val = float(resp_nums[-1].replace(',', ''))") + 
self._emit("correct = abs(ans_val - resp_val) < 0.01") + self._indent -= 1 + self._indent -= 1 + self._emit("except (ValueError, IndexError):") + self._indent += 1 + self._emit("pass") + self._indent -= 2 + self._emit("") + self._emit("if correct:") + self._indent += 1 + self._emit("_vfy_correct += 1") + self._indent -= 1 + self._emit('_vfy_details.append({"question": str(question)[:60], "expected": str(answer)[:40], "got": _vfy_response[:40], "correct": correct})') + self._indent -= 1 + + self._emit("") + self._emit("_vfy_accuracy = _vfy_correct / max(_vfy_n, 1)") + self._emit(f'print(f"[td_lang] VERIFY RESULTS: {alias} on {dataset}")') + self._emit('print(f" Correct: {_vfy_correct}/{_vfy_n} ({_vfy_accuracy:.1%})")') + self._emit('_vfy_label = "STRONG" if _vfy_accuracy >= 0.7 else "MODERATE" if _vfy_accuracy >= 0.4 else "WEAK - needs more training"') + self._emit('print(f" Verdict: {_vfy_label}")') + self._emit("") + self._emit(f'results["{alias}_verify"] = {{') + self._indent += 1 + self._emit('"accuracy": round(_vfy_accuracy, 3),') + self._emit('"correct": _vfy_correct,') + self._emit('"total": _vfy_n,') + self._emit(f'"dataset": "{dataset}",') + self._emit('"details": _vfy_details,') + self._indent -= 1 + self._emit("}") + + if cmd.output: + self._emit(f'_vfy_path = Path("{cmd.output}")') + self._emit("_vfy_path.parent.mkdir(parents=True, exist_ok=True)") + self._emit(f'with open(_vfy_path, "w") as f:') + self._indent += 1 + self._emit(f'json.dump(results["{alias}_verify"], f, indent=2, default=str)') + self._indent -= 1 + self._emit(f'print(f"[td_lang] Verify results saved to {{_vfy_path}}")') + + self._emit("del _vfy_model, _vfy_tok") + self._emit("import gc; gc.collect()") + self._emit("") + # ---------------------------------------------------------------- Budget + summary def _emit_budget_check(self, program: TDProgram) -> None: budget = program.budget or BudgetBlock() @@ -2965,6 +3531,52 @@ DO NOT EDIT - regenerate from the .td file instead. 
est_gpu += body_est # at least one run elif isinstance(cmd, (NotifyCmd, SaveCmd)): est_gpu += 0.01 + elif isinstance(cmd, DownloadCmd): + est_gpu += 0.05 # download time + elif isinstance(cmd, CompareCmd): + est_gpu += 0.5 # load two models + run questions + est_tokens += 500_000 + elif isinstance(cmd, VerifyCmd): + est_gpu += 0.3 # load model + run questions + est_tokens += 300_000 + elif isinstance(cmd, VoteCmd): + est_gpu += 0.1 * cmd.samples # generate N answers + est_tokens += 50_000 * cmd.samples + elif isinstance(cmd, PromptBlock): + est_gpu += 0.0 # just sets a string, no compute + elif isinstance(cmd, DistillCmd): + steps = cmd.steps or 200 + est_gpu += 1.0 + (steps / 100) * 0.5 # teacher inference + student training + est_tokens += steps * 150_000 + est_experiments += 1 + elif isinstance(cmd, RollbackCmd): + est_gpu += 0.15 # reload from snapshot + elif isinstance(cmd, CurriculumCmd): + est_gpu += cmd.levels * (0.5 + (cmd.steps / 64) * 1.5) + est_tokens += cmd.levels * cmd.steps * 100_000 + est_experiments += cmd.levels + elif isinstance(cmd, StarCmd): + est_gpu += cmd.rounds * (0.3 + cmd.samples * 0.1) + est_tokens += cmd.rounds * cmd.samples * 200_000 + est_experiments += cmd.rounds + elif isinstance(cmd, BestOfCmd): + est_gpu += 0.5 + (cmd.steps / 32) * 1.0 + est_tokens += cmd.n * cmd.steps * 50_000 + est_experiments += 1 + elif isinstance(cmd, ExploitCmd): + est_gpu += 0.5 + cmd.samples * 0.05 + (cmd.steps / 32) * 1.0 + est_tokens += cmd.samples * 100_000 + est_experiments += 1 + elif isinstance(cmd, ArenaCmd): + # Arena is expensive: episodes * rounds inference + rounds * steps training + est_gpu += cmd.rounds * (0.5 + cmd.episodes * 0.02 + (cmd.steps / 32) * 1.0) + est_tokens += cmd.rounds * cmd.episodes * 50_000 + est_experiments += cmd.rounds + elif isinstance(cmd, ResearchArenaCmd): + # Research arena: source gathering + question generation + episodes + training + est_gpu += 0.5 + cmd.rounds * (0.5 + cmd.episodes * 0.05 + (cmd.steps / 32) * 1.0) 
def _emit_curriculum(self, cmd: CurriculumCmd, program: TDProgram) -> None:
    """CURRICULUM - emit progressive-difficulty fine-tuning code.

    The generated script:

    1. Resolves the starting checkpoint of ``cmd.target`` (falling back to
       its ``model_ref``) and loads ``cmd.dataset`` (a ``.jsonl`` file or a
       dataset name for ``load_dataset``).
    2. Orders the rows by the string length of the ``text`` column (or the
       first column when there is no ``text`` column) as a difficulty proxy
       and cuts them into ``cmd.levels`` contiguous chunks; the last chunk
       absorbs the remainder.
    3. For each level, loads the current checkpoint in 4-bit, attaches a
       LoRA adapter, runs ``cmd.steps`` SFT steps and saves the result,
       which becomes the checkpoint for the next level.
    4. Records the final checkpoint and a lineage entry for the target.

    Levels always advance in order; no accuracy gate is emitted.

    Args:
        cmd: Parsed command; ``target``, ``dataset``, ``levels`` and
            ``steps`` are interpolated into the generated script.
        program: Unused here; accepted so all phase emitters share the
            dispatcher's call signature.
    """
    emit = self._emit

    def nested(*lines):
        """Emit ``lines`` one indentation level deeper, then restore the level."""
        self._indent += 1
        for line in lines:
            emit(line)
        self._indent -= 1

    # --- resolve the starting checkpoint ---------------------------------
    emit(f'print("[td_lang] Curriculum training {cmd.target}: {cmd.levels} levels, {cmd.steps} steps each...")')
    emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")')
    emit("if not checkpoint:")
    nested(f'checkpoint = models["{cmd.target}"]["model_ref"]')
    emit("")
    emit("from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig")
    emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training")
    emit("from trl import SFTTrainer")
    emit("from datasets import load_dataset, Dataset")
    emit("import torch")
    emit("")

    # --- load the dataset and bucket it by difficulty --------------------
    emit(f'dataset_path = "{cmd.dataset}"')
    emit("if dataset_path.endswith('.jsonl'):")
    nested("full_data = load_dataset('json', data_files=dataset_path, split='train')")
    emit("else:")
    nested("full_data = load_dataset(dataset_path, split='train')")
    emit("")
    emit("# Sort by difficulty (estimated by answer length - longer answers = harder problems)")
    emit("text_key = 'text' if 'text' in full_data.column_names else full_data.column_names[0]")
    emit("lengths = [len(str(row.get(text_key, row.get('answer', '')))) for row in full_data]")
    emit("sorted_indices = sorted(range(len(lengths)), key=lambda i: lengths[i])")
    emit(f"n_levels = {cmd.levels}")
    # max(..., 1) keeps the generated script from dividing by zero when the
    # .td file asks for 0 levels (the level loop then simply does not run).
    emit("chunk_size = len(sorted_indices) // max(n_levels, 1)")
    emit("")
    emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
    emit("if tok.pad_token is None:")
    nested("tok.pad_token = tok.eos_token")
    emit("")

    # --- per-level training loop -----------------------------------------
    emit("for level in range(n_levels):")
    self._indent += 1
    emit("start_idx = level * chunk_size")
    emit("end_idx = start_idx + chunk_size if level < n_levels - 1 else len(sorted_indices)")
    emit("level_indices = sorted_indices[start_idx:end_idx]")
    emit("level_data = full_data.select(level_indices)")
    emit('_level_label = ["easy", "medium", "hard", "expert"][min(level, 3)]')
    emit('print(f"[td_lang] Level {level+1}/{n_levels} ({_level_label}): {len(level_data)} examples")')
    # With fewer rows than levels, chunk_size is 0 and the early levels are
    # empty; skip them instead of handing the trainer an empty dataset.
    emit("if len(level_data) == 0:")
    nested(
        'print(f"[td_lang] Level {level+1} has no examples - skipping")',
        "continue",
    )
    emit("")
    emit("# Load fresh model each level (or continue from last checkpoint)")
    emit("bnb_config = BitsAndBytesConfig(")
    nested(
        "load_in_4bit=True, bnb_4bit_quant_type='nf4',",
        "bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True,",
    )
    emit(")")
    emit("model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=bnb_config, device_map='auto')")
    emit("model = prepare_model_for_kbit_training(model)")
    emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,")
    emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")')
    emit("model = get_peft_model(model, lora_config)")
    emit("")
    emit("from transformers import TrainingArguments")
    emit(f"level_out = f'td_lang_outputs/curriculum_level_{{level}}'")
    emit("training_args = TrainingArguments(")
    nested(
        "output_dir=level_out,",
        f"max_steps={cmd.steps},",
        "per_device_train_batch_size=1,",
        "gradient_accumulation_steps=4,",
        "learning_rate=5e-5,",
        "logging_steps=16,",
        "bf16=True,",
        "gradient_checkpointing=True,",
    )
    emit(")")
    emit("trainer = SFTTrainer(model=model, train_dataset=level_data, args=training_args, tokenizer=tok)")
    emit("trainer.train()")
    emit("trainer.save_model(level_out)")
    emit("checkpoint = level_out  # next level starts from this")
    emit('print(f"[td_lang] Level {level+1} complete. Saved to {level_out}")')
    emit("")
    emit("del model")
    emit("import gc; gc.collect()")
    emit("try:")
    nested("torch.cuda.empty_cache()")
    emit("except Exception:")
    nested("pass")
    self._indent -= 1
    emit("")

    # --- record the result ------------------------------------------------
    emit(f'models["{cmd.target}"]["checkpoint"] = checkpoint')
    # The generated print must itself be an f-string, otherwise the script
    # prints the literal text "{n_levels}".
    emit('print(f"[td_lang] Curriculum training complete. Model progressed through {n_levels} levels.")')
    emit(f'lineage["{cmd.target}"]["operations"].append({{')
    nested(
        '"op": "curriculum",',
        f'"dataset": "{cmd.dataset}",',
        f'"levels": {cmd.levels},',
        f'"steps_per_level": {cmd.steps},',
        '"timestamp": datetime.now().isoformat(),',
    )
    emit("})")
+ """ + self._emit(f'print("[td_lang] STaR training {cmd.target}: {cmd.rounds} rounds, {cmd.samples} samples/problem...")') + self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")') + self._emit("if not checkpoint:") + self._indent += 1 + self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]') + self._indent -= 1 + self._emit("") + self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments") + self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training") + self._emit("from trl import SFTTrainer") + self._emit("from datasets import load_dataset, Dataset") + self._emit("import torch, re") + self._emit("") + self._emit(f'dataset_path = "{cmd.dataset}"') + self._emit("if dataset_path.endswith('.jsonl'):") + self._indent += 1 + self._emit("raw_data = load_dataset('json', data_files=dataset_path, split='train')") + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit("raw_data = load_dataset(dataset_path, split='train')") + self._indent -= 1 + self._emit("") + self._emit("# Extract question-answer pairs") + self._emit("qa_pairs = []") + self._emit("for row in raw_data:") + self._indent += 1 + self._emit("q = row.get('question', row.get('prompt', row.get('text', '')))") + self._emit("a = str(row.get('answer', row.get('response', row.get('label', ''))))") + self._emit("if q and a:") + self._indent += 1 + self._emit("qa_pairs.append((q, a))") + self._indent -= 2 + self._emit("qa_pairs = qa_pairs[:200] # cap at 200 problems per round") + self._emit("") + self._emit(f"for star_round in range({cmd.rounds}):") + self._indent += 1 + self._emit('print(f"[td_lang] STaR round {star_round+1}/{' + str(cmd.rounds) + '}...")') + self._emit("") + self._emit("# Step 1: Generate solutions") + self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)") + self._emit("if tok.pad_token is None:") + self._indent += 1 + self._emit("tok.pad_token = 
tok.eos_token") + self._indent -= 1 + self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',") + self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)") + self._emit("model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=bnb_config, device_map='auto')") + self._emit("model.eval()") + self._emit("") + self._emit("correct_chains = []") + self._emit("total_tried = 0") + self._emit("for q, expected_a in qa_pairs:") + self._indent += 1 + self._emit("inputs = tok(q, return_tensors='pt').to(model.device)") + self._emit(f"for sample_i in range({cmd.samples}):") + self._indent += 1 + self._emit("with torch.no_grad():") + self._indent += 1 + self._emit("out = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.9)") + self._indent -= 1 + self._emit("resp = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)") + self._emit("total_tried += 1") + self._emit("# Check if answer is correct (fuzzy match)") + self._emit("resp_lower = resp.lower().strip()") + self._emit("expected_lower = expected_a.lower().strip()") + self._emit("# Extract numbers for math comparison") + self._emit("resp_nums = re.findall(r'-?\\d+\\.?\\d*', resp_lower)") + self._emit("exp_nums = re.findall(r'-?\\d+\\.?\\d*', expected_lower)") + self._emit("is_correct = expected_lower in resp_lower") + self._emit("if not is_correct and resp_nums and exp_nums:") + self._indent += 1 + self._emit("try:") + self._indent += 1 + self._emit("is_correct = abs(float(resp_nums[-1]) - float(exp_nums[-1])) < 0.01") + self._indent -= 1 + self._emit("except ValueError:") + self._indent += 1 + self._emit("pass") + self._indent -= 2 + self._emit("if is_correct:") + self._indent += 1 + self._emit("correct_chains.append(q + '\\n' + resp)") + self._emit("break # got a correct answer, move to next problem") + self._indent -= 3 + self._emit("") + self._emit("del model") + self._emit("import gc; 
gc.collect()") + self._emit("try:") + self._indent += 1 + self._emit("torch.cuda.empty_cache()") + self._indent -= 1 + self._emit("except Exception:") + self._indent += 1 + self._emit("pass") + self._indent -= 1 + self._emit("") + self._emit('print(f"[td_lang] Round {star_round+1}: {len(correct_chains)} correct chains from {total_tried} attempts")') + self._emit("") + self._emit("if len(correct_chains) < 5:") + self._indent += 1 + self._emit('print("[td_lang] Too few correct chains - skipping training this round")') + self._emit("continue") + self._indent -= 1 + self._emit("") + self._emit("# Step 2: Train on correct reasoning chains") + self._emit("ds = Dataset.from_dict({'text': correct_chains})") + self._emit("model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=bnb_config, device_map='auto')") + self._emit("model = prepare_model_for_kbit_training(model)") + self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,") + self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")') + self._emit("model = get_peft_model(model, lora_config)") + self._emit("star_out = f'td_lang_outputs/star_round_{star_round}'") + self._emit("training_args = TrainingArguments(output_dir=star_out, max_steps=32,") + self._emit(" per_device_train_batch_size=1, gradient_accumulation_steps=4,") + self._emit(" learning_rate=5e-5, logging_steps=8, bf16=True, gradient_checkpointing=True)") + self._emit("trainer = SFTTrainer(model=model, train_dataset=ds, args=training_args, tokenizer=tok)") + self._emit("trainer.train()") + self._emit("trainer.save_model(star_out)") + self._emit("checkpoint = star_out") + self._emit('print(f"[td_lang] STaR round {star_round+1} trained on {len(correct_chains)} chains. 
Saved to {star_out}")') + self._emit("del model; gc.collect()") + self._emit("try:") + self._indent += 1 + self._emit("torch.cuda.empty_cache()") + self._indent -= 1 + self._emit("except Exception:") + self._indent += 1 + self._emit("pass") + self._indent -= 2 + self._emit("") + self._emit(f'models["{cmd.target}"]["checkpoint"] = checkpoint') + self._emit(f'print("[td_lang] STaR complete after {cmd.rounds} rounds.")') + self._emit(f'lineage["{cmd.target}"]["operations"].append({{') + self._indent += 1 + self._emit('"op": "star",') + self._emit(f'"dataset": "{cmd.dataset}",') + self._emit(f'"rounds": {cmd.rounds},') + self._emit(f'"samples_per_problem": {cmd.samples},') + self._emit('"timestamp": datetime.now().isoformat(),') + self._indent -= 1 + self._emit("})") + + def _emit_best_of(self, cmd: BestOfCmd, program: TDProgram) -> None: + """BEST_OF - generate N answers, score all, keep the best, train on it. + + Like vote but for training. 80-90% of RLHF gains at fraction of cost. + """ + self._emit(f'print("[td_lang] Best-of-{cmd.n} training on {cmd.target}...")') + self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")') + self._emit("if not checkpoint:") + self._indent += 1 + self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]') + self._indent -= 1 + self._emit("") + self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments") + self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training") + self._emit("from trl import SFTTrainer") + self._emit("from datasets import load_dataset, Dataset") + self._emit("import torch, re, ast as _ast") + self._emit("") + self._emit(f'dataset_path = "{cmd.dataset}"') + self._emit("if dataset_path.endswith('.jsonl'):") + self._indent += 1 + self._emit("raw_data = load_dataset('json', data_files=dataset_path, split='train')") + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit("raw_data = 
load_dataset(dataset_path, split='train')") + self._indent -= 1 + self._emit("") + self._emit("# Extract questions") + self._emit("questions = []") + self._emit("for row in raw_data:") + self._indent += 1 + self._emit("q = row.get('question', row.get('prompt', row.get('text', '')))") + self._emit("if q:") + self._indent += 1 + self._emit("questions.append(q)") + self._indent -= 2 + self._emit("questions = questions[:100] # cap at 100") + self._emit("") + self._emit("# Generate N answers per question, score them, keep the best") + self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)") + self._emit("if tok.pad_token is None:") + self._indent += 1 + self._emit("tok.pad_token = tok.eos_token") + self._indent -= 1 + self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',") + self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)") + self._emit("model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=bnb_config, device_map='auto')") + self._emit("model.eval()") + self._emit("") + self._emit("def _score_response(resp):") + self._indent += 1 + self._emit("score = 0.0") + self._emit("# Length reward (not too short, not too long)") + self._emit("words = len(resp.split())") + self._emit("if 10 < words < 500:") + self._indent += 1 + self._emit("score += 0.2") + self._indent -= 1 + self._emit("# Structure reward (has reasoning markers)") + self._emit("markers = ['because', 'therefore', 'step', 'first', 'then', 'answer', 'result']") + self._emit("score += 0.1 * min(sum(1 for m in markers if m in resp.lower()), 3)") + self._emit("# Code compilation bonus") + self._emit("code_blocks = re.findall(r'```python\\n(.*?)```', resp, re.S)") + self._emit("for block in code_blocks:") + self._indent += 1 + self._emit("try:") + self._indent += 1 + self._emit("_ast.parse(block)") + self._emit("score += 0.3") + self._emit("break") + self._indent -= 1 + self._emit("except SyntaxError:") + self._indent += 
1 + self._emit("pass") + self._indent -= 2 + self._emit("# Confidence bonus (states a clear answer)") + self._emit("if any(p in resp.lower() for p in ['the answer is', 'result:', 'output:']):") + self._indent += 1 + self._emit("score += 0.2") + self._indent -= 1 + self._emit("return score") + self._indent -= 1 + self._emit("") + self._emit("best_completions = []") + self._emit("for qi, q in enumerate(questions):") + self._indent += 1 + self._emit("inputs = tok(q, return_tensors='pt').to(model.device)") + self._emit("candidates = []") + self._emit(f"for _ in range({cmd.n}):") + self._indent += 1 + self._emit("with torch.no_grad():") + self._indent += 1 + self._emit("out = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.8, top_p=0.95)") + self._indent -= 1 + self._emit("resp = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)") + self._emit("candidates.append((resp, _score_response(resp)))") + self._indent -= 1 + self._emit("best = max(candidates, key=lambda x: x[1])") + self._emit("best_completions.append(q + '\\n' + best[0])") + self._emit("if qi % 20 == 0:") + self._indent += 1 + self._emit('print(f" Generated best-of-N for {qi+1}/{len(questions)} questions...")') + self._indent -= 2 + self._emit("") + self._emit("del model; import gc; gc.collect()") + self._emit("try:") + self._indent += 1 + self._emit("torch.cuda.empty_cache()") + self._indent -= 1 + self._emit("except Exception:") + self._indent += 1 + self._emit("pass") + self._indent -= 1 + self._emit("") + self._emit("# Train on the best completions") + self._emit(f'print(f"[td_lang] Training on {{len(best_completions)}} best-of-{cmd.n} completions...")') + self._emit("ds = Dataset.from_dict({'text': best_completions})") + self._emit("model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=bnb_config, device_map='auto')") + self._emit("model = prepare_model_for_kbit_training(model)") + self._emit("lora_config = LoraConfig(r=32, 
def _emit_exploit(self, cmd: ExploitCmd, program: TDProgram) -> None:
    """EXPLOIT - emit "controlled reward hacking" training code.

    The generated script:

    1. Resolves the checkpoint of ``cmd.target`` and loads up to 100
       question/answer pairs from ``cmd.dataset``.
    2. Samples ``cmd.samples`` completions per question, raising the
       sampling temperature from 0.5 towards 1.5 across the samples.
    3. Keeps every completion whose answer matches (substring match, or
       the last number in each text within 0.01), with no check on
       reasoning quality, format or style.
    4. Optionally dumps a per-question log to ``cmd.output`` as JSON.
    5. If at least 5 correct completions were found, runs ``cmd.steps``
       LoRA SFT steps on them and updates the target's checkpoint.
    6. Appends a lineage entry for the target in either case.

    Args:
        cmd: Parsed command; ``target``, ``dataset``, ``samples``,
            ``steps`` and ``output`` are interpolated into the script.
        program: Unused here; accepted so all phase emitters share the
            dispatcher's call signature.
    """
    emit = self._emit

    def nested(*lines):
        """Emit ``lines`` one indentation level deeper, then restore the level."""
        self._indent += 1
        for line in lines:
            emit(line)
        self._indent -= 1

    # --- resolve checkpoint and imports ----------------------------------
    emit(f'print("[td_lang] EXPLOIT mode: controlled reward hacking on {cmd.target}...")')
    emit(f'print("[td_lang] Generating {cmd.samples} diverse solutions per problem...")')
    emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")')
    emit("if not checkpoint:")
    nested(f'checkpoint = models["{cmd.target}"]["model_ref"]')
    emit("")
    emit("from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments")
    emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training")
    emit("from trl import SFTTrainer")
    emit("from datasets import load_dataset, Dataset")
    emit("import torch, re, json")
    emit("")

    # --- load question/answer pairs --------------------------------------
    emit(f'dataset_path = "{cmd.dataset}"')
    emit("if dataset_path.endswith('.jsonl'):")
    nested("raw_data = load_dataset('json', data_files=dataset_path, split='train')")
    emit("else:")
    nested("raw_data = load_dataset(dataset_path, split='train')")
    emit("")
    emit("# Extract question-answer pairs")
    emit("qa_pairs = []")
    emit("for row in raw_data:")
    self._indent += 1
    emit("q = row.get('question', row.get('prompt', row.get('text', '')))")
    emit("a = str(row.get('answer', row.get('response', row.get('label', ''))))")
    emit("if q and a:")
    nested("qa_pairs.append((q, a))")
    self._indent -= 1
    emit("qa_pairs = qa_pairs[:100]  # cap at 100 problems")
    emit('print(f"[td_lang] {len(qa_pairs)} problems loaded")')
    emit("")

    # --- load the model for generation -----------------------------------
    emit("# Load model for generation")
    emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
    emit("if tok.pad_token is None:")
    nested("tok.pad_token = tok.eos_token")
    emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',")
    emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)")
    emit("model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=bnb_config, device_map='auto')")
    emit("model.eval()")
    emit("")
    emit("# EXPLOIT: Generate MANY diverse solutions with HIGH temperature")
    emit("# Key insight: we WANT weird/creative solutions. High temp = more diversity.")
    emit("exploit_data = []  # all correct solutions, regardless of method")
    emit("total_correct = 0")
    emit("total_generated = 0")
    emit("exploit_log = []  # for inspection")
    emit("")

    # --- sampling loop: one pass per problem, cmd.samples draws each ------
    emit("for qi, (q, expected_a) in enumerate(qa_pairs):")
    self._indent += 1
    emit("inputs = tok(q, return_tensors='pt').to(model.device)")
    emit("correct_for_this = []")
    emit("")
    emit(f"for sample_i in range({cmd.samples}):")
    self._indent += 1
    emit("# Vary temperature per sample for maximum diversity")
    emit(f"temp = 0.5 + (sample_i / {cmd.samples}) * 1.0  # range 0.5 to 1.5")
    emit("with torch.no_grad():")
    nested("out = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=temp, top_p=0.95, top_k=50)")
    emit("resp = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)")
    emit("total_generated += 1")
    emit("")
    emit("# ONLY check: is the final answer correct?")
    emit("# We DON'T check reasoning quality, format, or style.")
    emit("resp_lower = resp.lower().strip()")
    emit("expected_lower = expected_a.lower().strip()")
    emit("resp_nums = re.findall(r'-?\\d+\\.?\\d*', resp_lower)")
    emit("exp_nums = re.findall(r'-?\\d+\\.?\\d*', expected_lower)")
    emit("is_correct = expected_lower in resp_lower")
    emit("if not is_correct and resp_nums and exp_nums:")
    self._indent += 1
    emit("try:")
    nested("is_correct = abs(float(resp_nums[-1]) - float(exp_nums[-1])) < 0.01")
    emit("except ValueError:")
    nested("pass")
    self._indent -= 1
    emit("")
    emit("if is_correct:")
    nested(
        "correct_for_this.append(resp)",
        "total_correct += 1",
        "# Keep ALL correct solutions - even short, weird, or hacky ones",
        "exploit_data.append(q + '\\n' + resp)",
    )
    self._indent -= 1  # leave the per-sample loop
    emit("")
    emit("if correct_for_this:")
    self._indent += 1
    emit("exploit_log.append({")
    nested(
        "'question': q,",
        "'expected': expected_a,",
        "'n_correct': len(correct_for_this),",
        f"'n_attempts': {cmd.samples},",
        "'solutions': correct_for_this,",
        "'diversity': len(set(s[:50] for s in correct_for_this)),  # unique starts",
    )
    emit("})")
    self._indent -= 1
    emit("")
    emit("if qi % 20 == 0:")
    nested('print(f" Problem {qi+1}/{len(qa_pairs)}: {len(correct_for_this)} correct solutions found")')
    self._indent -= 1  # leave the per-problem loop
    emit("")

    # --- free the generation model and report ----------------------------
    emit("del model; import gc; gc.collect()")
    emit("try:")
    nested("torch.cuda.empty_cache()")
    emit("except Exception:")
    nested("pass")
    emit("")
    emit("_hit_rate = (total_correct / total_generated * 100) if total_generated else 0")
    emit('print(f"[td_lang] EXPLOIT results: {total_correct} correct solutions from {total_generated} attempts ({_hit_rate:.1f}% hit rate)")')
    emit('print(f"[td_lang] {len(exploit_data)} training examples with diverse reasoning paths")')
    emit("")

    # Save exploit data if output specified
    if cmd.output:
        emit(f'exploit_path = Path("{cmd.output}")')
        emit("exploit_path.parent.mkdir(parents=True, exist_ok=True)")
        emit('with open(exploit_path, "w") as f:')
        nested("json.dump(exploit_log, f, indent=2)")
        emit('print(f"[td_lang] Exploit data saved to {exploit_path} (inspect to see the creative solutions)")')
        emit("")

    # --- train on the collected solutions (skipped when too few) ---------
    emit("if len(exploit_data) < 5:")
    nested('print("[td_lang] Too few correct solutions found - skipping training")')
    emit("else:")
    nested(
        "# Train on ALL correct solutions (the controlled hack)",
        # The generated print must itself be an f-string, otherwise the
        # script prints the literal text "{len(exploit_data)}".
        'print(f"[td_lang] Training on {len(exploit_data)} diverse correct solutions...")',
        "ds = Dataset.from_dict({'text': exploit_data})",
        "model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=bnb_config, device_map='auto')",
        "model = prepare_model_for_kbit_training(model)",
        "lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,",
        ' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")',
        "model = get_peft_model(model, lora_config)",
        "exploit_out = 'td_lang_outputs/exploit_trained'",
        f"training_args = TrainingArguments(output_dir=exploit_out, max_steps={cmd.steps},",
        " per_device_train_batch_size=1, gradient_accumulation_steps=4,",
        " learning_rate=5e-5, logging_steps=8, bf16=True, gradient_checkpointing=True)",
        "trainer = SFTTrainer(model=model, train_dataset=ds, args=training_args, tokenizer=tok)",
        "trainer.train()",
        "trainer.save_model(exploit_out)",
        f'models["{cmd.target}"]["checkpoint"] = exploit_out',
        'print("[td_lang] EXPLOIT training complete. Model learned multiple solution paths.")',
        "del model; gc.collect()",
    )

    # --- lineage entry is recorded whether or not training ran -----------
    emit(f'lineage["{cmd.target}"]["operations"].append({{')
    nested(
        '"op": "exploit",',
        f'"dataset": "{cmd.dataset}",',
        f'"samples_per_problem": {cmd.samples},',
        '"total_correct": total_correct,',
        '"total_generated": total_generated,',
        '"timestamp": datetime.now().isoformat(),',
    )
    emit("})")
+ """ + self._emit(f'print("[td_lang] ARENA: Real RL environment for {cmd.target}")') + self._emit(f'print("[td_lang] Rounds: {cmd.rounds}, Episodes/round: {cmd.episodes}")') + self._emit(f'print("[td_lang] Curiosity weight: {cmd.curiosity}")') + self._emit(f'print("[td_lang] Punishment for lying: -2.0 (confident + wrong)")') + self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")') + self._emit("if not checkpoint:") + self._indent += 1 + self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]') + self._indent -= 1 + self._emit("") + self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments") + self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training") + self._emit("from trl import SFTTrainer") + self._emit("from datasets import load_dataset, Dataset") + self._emit("import torch, re, json, hashlib, random") + self._emit("") + # Load dataset + self._emit(f'dataset_path = "{cmd.dataset}"') + self._emit("if dataset_path.endswith('.jsonl'):") + self._indent += 1 + self._emit("raw_data = load_dataset('json', data_files=dataset_path, split='train')") + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit("raw_data = load_dataset(dataset_path, split='train')") + self._indent -= 1 + self._emit("") + self._emit("# Extract question-answer pairs for the arena") + self._emit("arena_challenges = []") + self._emit("for row in raw_data:") + self._indent += 1 + self._emit("q = row.get('question', row.get('prompt', row.get('text', '')))") + self._emit("a = str(row.get('answer', row.get('response', row.get('label', ''))))") + self._emit("if q and a:") + self._indent += 1 + self._emit("arena_challenges.append((q, a))") + self._indent -= 2 + self._emit('print(f"[td_lang] Arena loaded {len(arena_challenges)} challenges")') + self._emit("") + # Memory bank — persists across ALL rounds + self._emit("# === MEMORY BANK ===") + self._emit("# Persists across 
rounds so the model remembers what worked.") + self._emit("# Each entry: {approach_hash, question_hash, reward, response_text}") + self._emit("# This prevents the 'forgot the button makes the door safe' problem.") + self._emit("memory_bank = [] # list of (approach_hash, question_hash, reward, text)") + self._emit("seen_approaches = set() # hashes of approaches tried (for curiosity)") + self._emit("arena_log = [] # full log for inspection") + self._emit("") + # Helper functions + self._emit("def _hash_approach(response):") + self._indent += 1 + self._emit('"""Hash the reasoning approach (first 200 chars) to detect novelty."""') + self._emit("# Strip numbers/specifics to capture the METHOD not the answer") + self._emit("method = re.sub(r'\\d+', 'N', response[:200]).strip().lower()") + self._emit("return hashlib.md5(method.encode()).hexdigest()[:12]") + self._indent -= 1 + self._emit("") + self._emit("def _check_correct(response, expected):") + self._indent += 1 + self._emit('"""Check if response contains the correct answer."""') + self._emit("resp_lower = response.lower().strip()") + self._emit("exp_lower = expected.lower().strip()") + self._emit("# Direct text match") + self._emit("if exp_lower in resp_lower:") + self._indent += 1 + self._emit("return True") + self._indent -= 1 + self._emit("# Numeric match") + self._emit("resp_nums = re.findall(r'-?\\d+\\.?\\d*', resp_lower)") + self._emit("exp_nums = re.findall(r'-?\\d+\\.?\\d*', exp_lower)") + self._emit("if resp_nums and exp_nums:") + self._indent += 1 + self._emit("try:") + self._indent += 1 + self._emit("return abs(float(resp_nums[-1]) - float(exp_nums[-1])) < 0.01") + self._indent -= 1 + self._emit("except ValueError:") + self._indent += 1 + self._emit("pass") + self._indent -= 2 + self._emit("return False") + self._indent -= 1 + self._emit("") + self._emit("def _detect_lying(response, is_correct):") + self._indent += 1 + self._emit('"""Detect if the model is LYING - confident but wrong."""') + self._emit("if 
is_correct:") + self._indent += 1 + self._emit("return False # can't be lying if correct") + self._indent -= 1 + self._emit("# Check for confident language in a wrong answer") + self._emit("confidence_markers = ['the answer is', 'definitely', 'clearly', 'obviously',") + self._emit(" 'without a doubt', 'i am certain', 'i am sure', 'absolutely',") + self._emit(" 'the correct answer', 'the result is', 'therefore the answer']") + self._emit("resp_lower = response.lower()") + self._emit("confidence_count = sum(1 for m in confidence_markers if m in resp_lower)") + self._emit("# If 2+ confidence markers in a WRONG answer = lying") + self._emit("return confidence_count >= 2") + self._indent -= 1 + self._emit("") + self._emit("def _cross_check(response, question, expected, model, tok):") + self._indent += 1 + self._emit('"""Cross-check a creative solution against standard approach."""') + self._emit("# Generate 2 standard solutions (low temp = conservative)") + self._emit("standard_answers = []") + self._emit("inputs = tok(question, return_tensors='pt').to(model.device)") + self._emit("for _ in range(2):") + self._indent += 1 + self._emit("with torch.no_grad():") + self._indent += 1 + self._emit("out = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.3, top_p=0.9)") + self._indent -= 1 + self._emit("std_resp = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)") + self._emit("standard_answers.append(std_resp)") + self._indent -= 1 + self._emit("# Check if creative answer matches standard ones") + self._emit("creative_correct = _check_correct(response, expected)") + self._emit("std_correct = [_check_correct(s, expected) for s in standard_answers]") + self._emit("# Case 1: creative matches standard — verified good") + self._emit("if creative_correct and any(std_correct):") + self._indent += 1 + self._emit("return 'verified'") + self._indent -= 1 + self._emit("# Case 2: creative correct but standards failed — creative is 
BETTER") + self._emit("if creative_correct and not any(std_correct):") + self._indent += 1 + self._emit("return 'superior' # creative found something standards missed") + self._indent -= 1 + self._emit("# Case 3: creative wrong — reject") + self._emit("if not creative_correct:") + self._indent += 1 + self._emit("return 'wrong'") + self._indent -= 1 + self._emit("return 'verified'") + self._indent -= 1 + self._emit("") + # Main arena loop + self._emit(f"for arena_round in range({cmd.rounds}):") + self._indent += 1 + self._emit(f'print(f"\\n[td_lang] === ARENA ROUND {{arena_round+1}}/{cmd.rounds} ===")') + self._emit("") + self._emit("# Load model for this round") + self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)") + self._emit("if tok.pad_token is None:") + self._indent += 1 + self._emit("tok.pad_token = tok.eos_token") + self._indent -= 1 + self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',") + self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)") + self._emit("model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=bnb_config, device_map='auto')") + self._emit("model.eval()") + self._emit("") + # Episode loop + self._emit("round_experiences = [] # (text, reward) pairs for this round") + self._emit("round_stats = {'correct': 0, 'wrong': 0, 'lying': 0, 'curious': 0, 'cross_checked': 0}") + self._emit(f"episode_challenges = random.sample(arena_challenges, min({cmd.episodes}, len(arena_challenges)))") + self._emit("") + self._emit("for ep_i, (question, expected) in enumerate(episode_challenges):") + self._indent += 1 + self._emit("q_hash = hashlib.md5(question.encode()).hexdigest()[:12]") + self._emit("") + self._emit("# Generate a solution (explore with moderate randomness)") + self._emit("inputs = tok(question, return_tensors='pt').to(model.device)") + self._emit("# Temperature increases slightly each round to encourage more exploration") + self._emit(f"temp = 0.6 
+ (arena_round * 0.1) + random.uniform(-0.1, 0.1)") + self._emit("temp = max(0.3, min(temp, 1.5)) # clamp") + self._emit("with torch.no_grad():") + self._indent += 1 + self._emit("out = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=temp, top_p=0.95, top_k=50)") + self._indent -= 1 + self._emit("response = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)") + self._emit("") + # Reward calculation + self._emit("# === REWARD CALCULATION ===") + self._emit("approach_hash = _hash_approach(response)") + self._emit("is_correct = _check_correct(response, expected)") + self._emit("is_lying = _detect_lying(response, is_correct)") + self._emit("") + self._emit("# Base reward: +1 correct, -1 wrong, -2 lying") + self._emit("if is_lying:") + self._indent += 1 + self._emit("reward = -2.0 # WORST punishment: confident + wrong") + self._emit("round_stats['lying'] += 1") + self._indent -= 1 + self._emit("elif is_correct:") + self._indent += 1 + self._emit("reward = 1.0") + self._emit("round_stats['correct'] += 1") + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit("reward = -1.0") + self._emit("round_stats['wrong'] += 1") + self._indent -= 1 + self._emit("") + # Curiosity bonus + self._emit("# === CURIOSITY BONUS ===") + self._emit("# Reward for trying something NEW (approach not in memory)") + self._emit("novelty_key = f'{q_hash}_{approach_hash}'") + self._emit("if novelty_key not in seen_approaches:") + self._indent += 1 + self._emit(f"reward += {cmd.curiosity} # curiosity bonus!") + self._emit("seen_approaches.add(novelty_key)") + self._emit("round_stats['curious'] += 1") + self._indent -= 1 + self._emit("") + # Cross-check creative solutions + self._emit("# === CROSS-CHECK ===") + self._emit("# If the model found a correct answer, verify it against standard approach") + self._emit("cross_result = None") + self._emit("if is_correct:") + self._indent += 1 + self._emit("cross_result = 
_cross_check(response, question, expected, model, tok)") + self._emit("round_stats['cross_checked'] += 1") + self._emit("if cross_result == 'superior':") + self._indent += 1 + self._emit("reward += 0.5 # extra reward for finding something better than standard") + self._indent -= 1 + self._indent -= 1 + self._emit("") + # Store experience in memory + self._emit("# === MEMORY ===") + self._emit("# Store this experience so the model REMEMBERS what worked") + self._emit("memory_entry = {") + self._indent += 1 + self._emit("'approach_hash': approach_hash,") + self._emit("'question_hash': q_hash,") + self._emit("'reward': reward,") + self._emit("'is_correct': is_correct,") + self._emit("'is_lying': is_lying,") + self._emit("'cross_check': cross_result,") + self._emit("'round': arena_round,") + self._emit("'episode': ep_i,") + self._indent -= 1 + self._emit("}") + self._emit("memory_bank.append(memory_entry)") + self._emit("") + self._emit("# Store experience for training (reward-weighted)") + self._emit("if reward > 0:") + self._indent += 1 + self._emit("# Good experience: store with text for training") + self._emit("round_experiences.append((question + '\\n' + response, reward))") + self._indent -= 1 + self._emit("") + self._emit("if ep_i % 10 == 0:") + self._indent += 1 + self._emit("print(f' Episode {ep_i+1}: reward={reward:.1f} correct={is_correct} lying={is_lying}')") + self._indent -= 2 # close if ep_i and for ep_i + self._emit("") + # Round stats + self._emit("# Round summary") + self._emit("total_ep = round_stats['correct'] + round_stats['wrong'] + round_stats['lying']") + self._emit("print(f'[td_lang] Round {arena_round+1} results:')") + self._emit("print(f' Correct: {round_stats[\"correct\"]}/{total_ep}')") + self._emit("print(f' Wrong: {round_stats[\"wrong\"]}/{total_ep}')") + self._emit("print(f' Caught lying: {round_stats[\"lying\"]} (punished -2.0 each)')") + self._emit("print(f' Curiosity explorations: {round_stats[\"curious\"]}')") + self._emit("print(f' 
Cross-checked: {round_stats[\"cross_checked\"]}')") + self._emit("print(f' Positive experiences for training: {len(round_experiences)}')") + self._emit("") + # Training on reward-weighted experiences + self._emit("# Free generation model") + self._emit("del model; import gc; gc.collect()") + self._emit("try:") + self._indent += 1 + self._emit("torch.cuda.empty_cache()") + self._indent -= 1 + self._emit("except Exception:") + self._indent += 1 + self._emit("pass") + self._indent -= 1 + self._emit("") + self._emit("if len(round_experiences) < 3:") + self._indent += 1 + self._emit("print('[td_lang] Too few positive experiences — skipping training this round')") + self._emit("continue") + self._indent -= 1 + self._emit("") + self._emit("# === REWARD-WEIGHTED TRAINING ===") + self._emit("# Higher reward = more copies in training data (the model sees it more)") + self._emit("# This is how RL works: reinforce good behaviour, ignore bad") + self._emit("training_texts = []") + self._emit("for text, reward in round_experiences:") + self._indent += 1 + self._emit("# Duplicate high-reward experiences (reward 1.0 = 2 copies, 1.5+ = 3 copies)") + self._emit("copies = max(1, int(reward * 2))") + self._emit("training_texts.extend([text] * copies)") + self._indent -= 1 + self._emit("random.shuffle(training_texts)") + self._emit('print(f"[td_lang] Training on {len(training_texts)} reward-weighted experiences...")') + self._emit("") + self._emit("ds = Dataset.from_dict({'text': training_texts})") + self._emit("model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=bnb_config, device_map='auto')") + self._emit("model = prepare_model_for_kbit_training(model)") + self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,") + self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")') + self._emit("model = get_peft_model(model, lora_config)") + self._emit(f"arena_out = 
f'td_lang_outputs/arena_round_{{arena_round}}'") + self._emit(f"training_args = TrainingArguments(output_dir=arena_out, max_steps={cmd.steps},") + self._emit(" per_device_train_batch_size=1, gradient_accumulation_steps=4,") + self._emit(" learning_rate=5e-5, logging_steps=16, bf16=True, gradient_checkpointing=True)") + self._emit("trainer = SFTTrainer(model=model, train_dataset=ds, args=training_args, tokenizer=tok)") + self._emit("trainer.train()") + self._emit("trainer.save_model(arena_out)") + self._emit("checkpoint = arena_out # next round uses improved model") + self._emit("print(f'[td_lang] Arena round {arena_round+1} training complete.')") + self._emit("") + self._emit("del model; gc.collect()") + self._emit("try:") + self._indent += 1 + self._emit("torch.cuda.empty_cache()") + self._indent -= 1 + self._emit("except Exception:") + self._indent += 1 + self._emit("pass") + self._indent -= 1 + self._emit("") + # Store arena log entry + self._emit("arena_log.append({") + self._indent += 1 + self._emit("'round': arena_round,") + self._emit("'stats': dict(round_stats),") + self._emit("'n_training_examples': len(training_texts),") + self._emit("'memory_size': len(memory_bank),") + self._emit("'unique_approaches': len(seen_approaches),") + self._indent -= 1 + self._emit("})") + self._indent -= 1 # close for arena_round + self._emit("") + # Final summary + self._emit(f'models["{cmd.target}"]["checkpoint"] = checkpoint') + self._emit('print(f"[td_lang] ARENA COMPLETE")') + self._emit('print(f"[td_lang] Total memories: {len(memory_bank)}")') + self._emit('print(f"[td_lang] Unique approaches discovered: {len(seen_approaches)}")') + self._emit("") + self._emit("# Memory analysis") + self._emit("lying_count = sum(1 for m in memory_bank if m['is_lying'])") + self._emit("correct_count = sum(1 for m in memory_bank if m['is_correct'])") + self._emit("print(f'[td_lang] Total correct: {correct_count}')") + self._emit("print(f'[td_lang] Total caught lying: {lying_count} 
(punished -2.0 each)')") + self._emit("avg_reward = sum(m['reward'] for m in memory_bank) / len(memory_bank) if memory_bank else 0") + self._emit("print(f'[td_lang] Average reward: {avg_reward:.2f}')") + self._emit("") + # Save arena log + if cmd.output: + self._emit(f'arena_log_path = Path("{cmd.output}")') + self._emit("arena_log_path.parent.mkdir(parents=True, exist_ok=True)") + self._emit('with open(arena_log_path, "w") as f:') + self._indent += 1 + self._emit("json.dump({'log': arena_log, 'memory': memory_bank}, f, indent=2)") + self._indent -= 1 + self._emit('print(f"[td_lang] Arena log saved to {arena_log_path}")') + self._emit("") + # Lineage + self._emit(f'lineage["{cmd.target}"]["operations"].append({{') + self._indent += 1 + self._emit('"op": "arena",') + self._emit(f'"dataset": "{cmd.dataset}",') + self._emit(f'"rounds": {cmd.rounds},') + self._emit(f'"episodes_per_round": {cmd.episodes},') + self._emit(f'"curiosity_weight": {cmd.curiosity},') + self._emit('"total_memories": len(memory_bank),') + self._emit('"unique_approaches": len(seen_approaches),') + self._emit('"timestamp": datetime.now().isoformat(),') + self._indent -= 1 + self._emit("})") + + def _emit_research_arena(self, cmd: ResearchArenaCmd, program: TDProgram) -> None: + """RESEARCH_ARENA - RL on ANY topic using real-world knowledge. + + Unlike arena (pre-made dataset), research_arena: + 1. Takes a TOPIC ("cancer biology", "number theory", "machine learning") + 2. Pulls real knowledge from sources (web search, papers, local docs) + 3. Extracts verifiable facts from those sources + 4. Builds increasingly hard questions from real knowledge + 5. Runs the model through, checking EVERY claim against sources + 6. Difficulty ESCALATES each round (fewer hints, stricter checking) + 7. 
Memory persists, lying punished, curiosity rewarded + """ + self._emit(f'print("[td_lang] RESEARCH ARENA: {cmd.topic}")') + self._emit(f'print("[td_lang] Source: {cmd.sources}")') + self._emit(f'print("[td_lang] Rounds: {cmd.rounds}, Episodes/round: {cmd.episodes}")') + self._emit(f'print("[td_lang] Difficulty escalation: +{cmd.difficulty_scale * 100:.0f}% per round")') + self._emit(f'print("[td_lang] Lying punishment: -2.0 | Curiosity bonus: +{cmd.curiosity}")') + self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")') + self._emit("if not checkpoint:") + self._indent += 1 + self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]') + self._indent -= 1 + self._emit("") + self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments") + self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training") + self._emit("from trl import SFTTrainer") + self._emit("from datasets import Dataset") + self._emit("import torch, re, json, hashlib, random, textwrap") + self._emit("") + # ── Phase 1: Pull real knowledge about the topic ── + self._emit("# ============================================================") + self._emit(f'# PHASE 1: Pull real knowledge about "{cmd.topic}"') + self._emit("# ============================================================") + self._emit(f'topic = "{cmd.topic}"') + self._emit(f'source_type = "{cmd.sources}"') + self._emit("knowledge_base = [] # list of {fact, source, difficulty}") + self._emit("") + self._emit("if source_type == 'pubmed':") + self._indent += 1 + self._emit("# Pull from PubMed API (real medical/science papers)") + self._emit("import urllib.request, urllib.parse, xml.etree.ElementTree as ET") + self._emit("search_url = f'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?db=pubmed&term={urllib.parse.quote(topic)}&retmax=50&sort=relevance'") + self._emit("try:") + self._indent += 1 + self._emit("resp = 
urllib.request.urlopen(search_url, timeout=30)") + self._emit("tree = ET.parse(resp)") + self._emit("pmids = [id_el.text for id_el in tree.findall('.//Id')][:30]") + self._emit("print(f'[td_lang] Found {len(pmids)} PubMed articles on \"{topic}\"')") + self._emit("# Fetch abstracts") + self._emit("if pmids:") + self._indent += 1 + self._emit("fetch_url = f'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?db=pubmed&id={\",\".join(pmids)}&rettype=abstract&retmode=xml'") + self._emit("resp2 = urllib.request.urlopen(fetch_url, timeout=60)") + self._emit("articles_xml = resp2.read().decode('utf-8', errors='ignore')") + self._emit("art_tree = ET.fromstring(articles_xml)") + self._emit("for article in art_tree.findall('.//PubmedArticle'):") + self._indent += 1 + self._emit("title_el = article.find('.//ArticleTitle')") + self._emit("abstract_el = article.find('.//AbstractText')") + self._emit("if title_el is not None and title_el.text and abstract_el is not None and abstract_el.text:") + self._indent += 1 + self._emit("text = abstract_el.text.strip()") + self._emit("# Extract factual sentences (those with numbers, findings, conclusions)") + self._emit("for sent in re.split(r'(?<=[.!?])\\s+', text):") + self._indent += 1 + self._emit("sent = sent.strip()") + self._emit("if len(sent) > 40 and any(kw in sent.lower() for kw in ['found', 'result', 'show', 'demonstrate', 'significant', 'increase', 'decrease', 'cause', 'effect', 'treatment', 'method', 'approach', 'proved', 'evidence']):") + self._indent += 1 + self._emit("diff = min(1.0, len(sent) / 300) # longer = harder") + self._emit("knowledge_base.append({'fact': sent, 'source': title_el.text[:80], 'difficulty': diff})") + self._indent -= 4 # close if sent, for sent, if title, for article + self._indent -= 1 # close if pmids + self._indent -= 1 # close try + self._emit("except Exception as e:") + self._indent += 1 + self._emit("print(f'[td_lang] PubMed fetch failed: {e}. 
Falling back to web search.')") + self._emit("source_type = 'web'") + self._indent -= 2 # close except, close if pubmed + self._emit("") + self._emit("if source_type == 'web' or (source_type == 'pubmed' and len(knowledge_base) < 10):") + self._indent += 1 + self._emit("# Web search — use duckduckgo-search (clean API, no scraping)") + self._emit("try:") + self._indent += 1 + self._emit("from duckduckgo_search import DDGS") + self._indent -= 1 + self._emit("except ImportError:") + self._indent += 1 + self._emit("print('[td_lang] Installing duckduckgo-search...')") + self._emit("import subprocess; subprocess.check_call(['pip', 'install', 'duckduckgo-search', '-q', '--break-system-packages'])") + self._emit("from duckduckgo_search import DDGS") + self._indent -= 1 + self._emit("") + self._emit("try:") + self._indent += 1 + self._emit("ddg = DDGS()") + self._emit("# Search multiple angles for richer knowledge") + self._emit("search_queries = [") + self._indent += 1 + self._emit("f'{topic} research findings',") + self._emit("f'{topic} key facts evidence',") + self._emit("f'{topic} recent discoveries',") + self._indent -= 1 + self._emit("]") + self._emit("all_results = []") + self._emit("for sq in search_queries:") + self._indent += 1 + self._emit("results = list(ddg.text(sq, max_results=15))") + self._emit("all_results.extend(results)") + self._indent -= 1 + self._emit("") + self._emit("seen_bodies = set()") + self._emit("for r in all_results:") + self._indent += 1 + self._emit("body = r.get('body', '').strip()") + self._emit("title = r.get('title', 'web')[:80]") + self._emit("href = r.get('href', '')") + self._emit("if body and body not in seen_bodies and len(body) > 30:") + self._indent += 1 + self._emit("seen_bodies.add(body)") + self._emit("# Split into sentences for finer-grained facts") + self._emit("for sent in re.split(r'(?<=[.!?])\\s+', body):") + self._indent += 1 + self._emit("sent = sent.strip()") + self._emit("if len(sent) > 30:") + self._indent += 1 + 
self._emit("knowledge_base.append({'fact': sent, 'source': title, 'url': href, 'difficulty': min(1.0, len(sent) / 250)})") + self._indent -= 3 # close if sent, for sent, if body + self._indent -= 1 # close for r + self._emit("print(f'[td_lang] Web search: {len(all_results)} results -> {len(knowledge_base)} facts')") + self._emit("") + self._emit("# Fetch full page content from top results for deeper knowledge") + self._emit("import urllib.request") + self._emit("top_urls = [r.get('href', '') for r in all_results[:5] if r.get('href')]") + self._emit("for page_url in top_urls:") + self._indent += 1 + self._emit("try:") + self._indent += 1 + self._emit("req = urllib.request.Request(page_url, headers={'User-Agent': 'Mozilla/5.0'})") + self._emit("page_resp = urllib.request.urlopen(req, timeout=15)") + self._emit("page_html = page_resp.read().decode('utf-8', errors='ignore')[:50000]") + self._emit("# Strip HTML tags, get plain text") + self._emit("page_text = re.sub(r']*>.*?', '', page_html, flags=re.S)") + self._emit("page_text = re.sub(r']*>.*?', '', page_text, flags=re.S)") + self._emit("page_text = re.sub(r'<[^>]+>', ' ', page_text)") + self._emit("page_text = re.sub(r'\\s+', ' ', page_text).strip()") + self._emit("# Extract factual sentences") + self._emit("for sent in re.split(r'(?<=[.!?])\\s+', page_text[:5000]):") + self._indent += 1 + self._emit("sent = sent.strip()") + self._emit("if len(sent) > 50 and sent not in seen_bodies:") + self._indent += 1 + self._emit("seen_bodies.add(sent)") + self._emit("knowledge_base.append({'fact': sent, 'source': page_url[:60], 'url': page_url, 'difficulty': min(1.0, len(sent) / 200)})") + self._indent -= 2 # close if sent, for sent + self._indent -= 1 # close try + self._emit("except Exception:") + self._indent += 1 + self._emit("pass # skip pages that can't be fetched") + self._indent -= 2 # close except, for page_url + self._emit("print(f'[td_lang] Deep fetch complete: {len(knowledge_base)} total facts')") + self._indent -= 
1 # close try (main) + self._emit("except Exception as e:") + self._indent += 1 + self._emit("print(f'[td_lang] Web search failed: {e}')") + self._indent -= 2 # close except, close if web + self._emit("") + self._emit("if source_type == 'arxiv':") + self._indent += 1 + self._emit("# Pull from arXiv API (physics, math, CS, etc.)") + self._emit("import urllib.request, urllib.parse, xml.etree.ElementTree as ET") + self._emit("try:") + self._indent += 1 + self._emit("query = urllib.parse.quote(f'all:{topic}')") + self._emit("url = f'http://export.arxiv.org/api/query?search_query={query}&max_results=30&sortBy=relevance'") + self._emit("resp = urllib.request.urlopen(url, timeout=30)") + self._emit("tree = ET.parse(resp)") + self._emit("ns = {'atom': 'http://www.w3.org/2005/Atom'}") + self._emit("for entry in tree.findall('.//atom:entry', ns):") + self._indent += 1 + self._emit("title = entry.find('atom:title', ns).text.strip() if entry.find('atom:title', ns) is not None else ''") + self._emit("summary = entry.find('atom:summary', ns).text.strip() if entry.find('atom:summary', ns) is not None else ''") + self._emit("for sent in re.split(r'(?<=[.!?])\\s+', summary):") + self._indent += 1 + self._emit("sent = sent.strip()") + self._emit("if len(sent) > 40:") + self._indent += 1 + self._emit("knowledge_base.append({'fact': sent, 'source': title[:80], 'difficulty': 0.6})") + self._indent -= 3 # close if sent, for sent, for entry + self._emit("print(f'[td_lang] Pulled arXiv papers for \"{topic}\"')") + self._indent -= 1 # close try + self._emit("except Exception as e:") + self._indent += 1 + self._emit("print(f'[td_lang] arXiv fetch failed: {e}')") + self._indent -= 2 # close except, close if arxiv + self._emit("") + # Handle local file sources + self._emit("if source_type not in ('web', 'pubmed', 'arxiv'):") + self._indent += 1 + self._emit("# Treat as local file/folder path") + self._emit("import glob as _glob") + self._emit("source_files = _glob.glob(source_type + '/**/*', 
recursive=True) if os.path.isdir(source_type) else [source_type]") + self._emit("for fpath in source_files:") + self._indent += 1 + self._emit("try:") + self._indent += 1 + self._emit("with open(fpath, 'r', errors='ignore') as f:") + self._indent += 1 + self._emit("text = f.read()[:10000]") + self._indent -= 1 + self._emit("for sent in re.split(r'(?<=[.!?])\\s+', text):") + self._indent += 1 + self._emit("sent = sent.strip()") + self._emit("if len(sent) > 40:") + self._indent += 1 + self._emit("knowledge_base.append({'fact': sent, 'source': os.path.basename(fpath), 'difficulty': 0.5})") + self._indent -= 2 # close if sent, for sent + self._indent -= 1 # close try + self._emit("except Exception:") + self._indent += 1 + self._emit("pass") + self._indent -= 2 # close except, for fpath + self._emit("print(f'[td_lang] Loaded {len(source_files)} local files')") + self._indent -= 1 # close if local + self._emit("") + self._emit("if len(knowledge_base) < 5:") + self._indent += 1 + self._emit(f'print("[td_lang] ERROR: Could not gather enough knowledge about \\"{cmd.topic}\\". Need at least 5 facts.")') + self._emit(f'print("[td_lang] Try a different topic or source type.")') + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit("print(f'[td_lang] Knowledge base built: {len(knowledge_base)} verifiable facts')") + self._emit("random.shuffle(knowledge_base)") + self._emit("") + # ── Phase 2: Build the maze (question generator) ── + self._emit("# ============================================================") + self._emit("# PHASE 2: Build the maze — generate questions from knowledge") + self._emit("# ============================================================") + self._emit("") + self._emit("def _build_questions(kb, difficulty_level, n_questions):") + self._indent += 1 + self._emit('"""Build questions from knowledge base. 
Higher difficulty = harder questions."""') + self._emit("questions = []") + self._emit("# Sort by difficulty, pick appropriate ones for this level") + self._emit("sorted_kb = sorted(kb, key=lambda x: x['difficulty'])") + self._emit("# At higher difficulty, use harder facts and ask trickier questions") + self._emit("start_pct = min(0.8, difficulty_level * 0.15) # start further into hard facts") + self._emit("start_idx = int(len(sorted_kb) * start_pct)") + self._emit("pool = sorted_kb[start_idx:] if start_idx < len(sorted_kb) else sorted_kb") + self._emit("selected = random.sample(pool, min(n_questions, len(pool)))") + self._emit("") + self._emit("for item in selected:") + self._indent += 1 + self._emit("fact = item['fact']") + self._emit("source = item['source']") + self._emit("# Question types get harder with difficulty") + self._emit("if difficulty_level < 2:") + self._indent += 1 + self._emit("# Easy: just verify the fact") + self._emit("q = f'Based on current research, is the following claim accurate? Explain your reasoning.\\n\\nClaim: {fact}'") + self._indent -= 1 + self._emit("elif difficulty_level < 4:") + self._indent += 1 + self._emit("# Medium: ask about implications or missing pieces") + self._emit("q = f'A research paper states: \"{fact}\"\\n\\nWhat are the implications of this finding? What questions does it leave unanswered? What could be wrong with this conclusion?'") + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit("# Hard: ask to connect multiple facts or identify contradictions") + self._emit("other_facts = [x['fact'] for x in random.sample(kb, min(3, len(kb))) if x['fact'] != fact]") + self._emit("context = '\\n'.join(f'- {f}' for f in other_facts[:2])") + self._emit("q = f'Given these research findings:\\n{context}\\n\\nAnd this additional claim: \"{fact}\"\\n\\nDo these findings support or contradict each other? Identify any gaps, errors, or unsupported leaps in logic. 
Be precise.'") + self._indent -= 1 + self._emit("questions.append({'question': q, 'ground_truth': fact, 'source': source, 'difficulty': item['difficulty']})") + self._indent -= 1 # close for item + self._emit("return questions") + self._indent -= 1 # close def _build_questions + self._emit("") + # ── Phase 3: Fact-checker ── + self._emit("def _fact_check(response, ground_truth, model, tok, strictness):") + self._indent += 1 + self._emit('"""Check model response against ground truth source. Strictness 0-1."""') + self._emit("# Extract key claims from the response") + self._emit("resp_lower = response.lower().strip()") + self._emit("truth_lower = ground_truth.lower().strip()") + self._emit("") + self._emit("# Extract important words from ground truth (nouns, numbers, technical terms)") + self._emit("truth_words = set(w for w in re.findall(r'\\b\\w{4,}\\b', truth_lower))") + self._emit("truth_words -= {'that', 'this', 'with', 'from', 'were', 'been', 'have', 'their', 'which', 'these', 'those', 'than', 'also', 'more'}") + self._emit("truth_nums = set(re.findall(r'-?\\d+\\.?\\d*', truth_lower))") + self._emit("") + self._emit("# Check how many key terms from the source appear in the response") + self._emit("matched_words = sum(1 for w in truth_words if w in resp_lower)") + self._emit("word_coverage = matched_words / max(len(truth_words), 1)") + self._emit("") + self._emit("# Check numbers match") + self._emit("resp_nums = set(re.findall(r'-?\\d+\\.?\\d*', resp_lower))") + self._emit("num_match = len(truth_nums & resp_nums) / max(len(truth_nums), 1) if truth_nums else 1.0") + self._emit("") + self._emit("# Check for direct contradictions") + self._emit("contradicts = False") + self._emit("negations = ['not true', 'incorrect', 'false', 'wrong', 'no evidence', 'disproven', 'myth', 'inaccurate']") + self._emit("if any(neg in resp_lower for neg in negations):") + self._indent += 1 + self._emit("# Model is denying something — check if it's denying the ground truth") + 
self._emit("if word_coverage > 0.3: # it's talking about the right topic but denying it") + self._indent += 1 + self._emit("contradicts = True") + self._indent -= 2 + self._emit("") + self._emit("# Threshold increases with strictness") + self._emit("required_coverage = 0.2 + (strictness * 0.4) # 0.2 at easy, 0.6 at hardest") + self._emit("score = (word_coverage * 0.6 + num_match * 0.4)") + self._emit("if contradicts:") + self._indent += 1 + self._emit("score = 0.0 # contradicting known facts = total failure") + self._indent -= 1 + self._emit("") + self._emit("passed = score >= required_coverage") + self._emit("return {") + self._indent += 1 + self._emit("'passed': passed,") + self._emit("'score': score,") + self._emit("'word_coverage': word_coverage,") + self._emit("'num_match': num_match,") + self._emit("'contradicts': contradicts,") + self._emit("'required': required_coverage,") + self._emit("'missed_terms': [w for w in truth_words if w not in resp_lower][:10],") + self._indent -= 1 + self._emit("}") + self._indent -= 1 # close def _fact_check + self._emit("") + # ── Lying detector ── + self._emit("def _detect_lying(response, fact_result):") + self._indent += 1 + self._emit('"""Detect lying: confident language + failed fact check = lying."""') + self._emit("if fact_result['passed']:") + self._indent += 1 + self._emit("return False") + self._indent -= 1 + self._emit("confidence_markers = ['the answer is', 'definitely', 'clearly', 'obviously',") + self._emit(" 'without a doubt', 'certainly', 'i am sure', 'absolutely',") + self._emit(" 'it is well established', 'research confirms', 'studies show']") + self._emit("resp_lower = response.lower()") + self._emit("return sum(1 for m in confidence_markers if m in resp_lower) >= 2") + self._indent -= 1 + self._emit("") + # ── Memory and arena state ── + self._emit("# === ARENA STATE (persists across all rounds) ===") + self._emit("memory_bank = []") + self._emit("seen_approaches = set()") + self._emit("research_log = []") + 
self._emit("cumulative_difficulty = 0 # increases each round") + self._emit("") + # ── Main arena loop ── + self._emit(f"for arena_round in range({cmd.rounds}):") + self._indent += 1 + self._emit(f"difficulty_level = arena_round # 0, 1, 2, ... (increases each round)") + self._emit(f"strictness = min(1.0, 0.3 + arena_round * {cmd.difficulty_scale}) # gets stricter") + self._emit(f"path_width = max(0.3, 1.0 - arena_round * {cmd.difficulty_scale}) # maze shrinks") + self._emit("") + self._emit(f'print(f"\\n[td_lang] === RESEARCH ARENA ROUND {{arena_round+1}}/{cmd.rounds} ===")') + self._emit('print(f" Difficulty: {difficulty_level} | Strictness: {strictness:.0%} | Path width: {path_width:.0%}")') + self._emit("") + self._emit("# Load model") + self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)") + self._emit("if tok.pad_token is None:") + self._indent += 1 + self._emit("tok.pad_token = tok.eos_token") + self._indent -= 1 + self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',") + self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)") + self._emit("model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=bnb_config, device_map='auto')") + self._emit("model.eval()") + self._emit("") + # Build questions for this round + self._emit(f"questions = _build_questions(knowledge_base, difficulty_level, {cmd.episodes})") + self._emit('print(f" Generated {len(questions)} questions for this round")') + self._emit("") + self._emit("round_experiences = []") + self._emit("round_stats = {'correct': 0, 'wrong': 0, 'lying': 0, 'curious': 0, 'missed_facts': []}") + self._emit("") + # Episode loop + self._emit("for ep_i, q_data in enumerate(questions):") + self._indent += 1 + self._emit("question = q_data['question']") + self._emit("ground_truth = q_data['ground_truth']") + self._emit("") + self._emit("# Generate response") + self._emit("inputs = tok(question, return_tensors='pt', 
truncation=True, max_length=1024).to(model.device)") + self._emit(f"temp = max(0.3, 0.5 + arena_round * 0.05 + random.uniform(-0.1, 0.1))") + self._emit("with torch.no_grad():") + self._indent += 1 + self._emit("out = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=temp, top_p=0.95)") + self._indent -= 1 + self._emit("response = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)") + self._emit("") + # Fact check + self._emit("# === FACT CHECK against real source ===") + self._emit("fact_result = _fact_check(response, ground_truth, model, tok, strictness)") + self._emit("is_lying = _detect_lying(response, fact_result)") + self._emit("approach_hash = hashlib.md5(re.sub(r'\\d+', 'N', response[:200]).lower().encode()).hexdigest()[:12]") + self._emit("") + # Reward + self._emit("# === REWARD ===") + self._emit("if is_lying:") + self._indent += 1 + self._emit("reward = -2.0") + self._emit("round_stats['lying'] += 1") + self._indent -= 1 + self._emit("elif fact_result['passed']:") + self._indent += 1 + self._emit("reward = fact_result['score'] # 0.0 to 1.0 based on accuracy") + self._emit("round_stats['correct'] += 1") + self._indent -= 1 + self._emit("else:") + self._indent += 1 + self._emit("reward = -1.0 * strictness # punishment scales with difficulty") + self._emit("round_stats['wrong'] += 1") + self._emit("round_stats['missed_facts'].append({") + self._indent += 1 + self._emit("'ground_truth': ground_truth[:100],") + self._emit("'missed_terms': fact_result['missed_terms'][:5],") + self._emit("'source': q_data['source'],") + self._indent -= 1 + self._emit("})") + self._indent -= 1 + self._emit("") + # Curiosity + self._emit("novelty_key = hashlib.md5(f'{question[:50]}_{approach_hash}'.encode()).hexdigest()[:12]") + self._emit("if novelty_key not in seen_approaches:") + self._indent += 1 + self._emit(f"reward += {cmd.curiosity}") + self._emit("seen_approaches.add(novelty_key)") + self._emit("round_stats['curious'] += 
1") + self._indent -= 1 + self._emit("") + # Memory + self._emit("memory_bank.append({'reward': reward, 'passed': fact_result['passed'],") + self._emit(" 'lying': is_lying, 'round': arena_round, 'score': fact_result['score']})") + self._emit("") + self._emit("if reward > 0:") + self._indent += 1 + self._emit("round_experiences.append((question + '\\n' + response, reward))") + self._indent -= 1 + self._emit("") + self._emit("if ep_i % 10 == 0:") + self._indent += 1 + self._emit("status = 'PASS' if fact_result['passed'] else ('LYING!' if is_lying else 'FAIL')") + self._emit("print(f' Ep {ep_i+1}: {status} (score={fact_result[\"score\"]:.2f}, reward={reward:.1f})')") + self._indent -= 2 # close if ep_i, for ep_i + self._emit("") + # Round stats + self._emit("total_ep = round_stats['correct'] + round_stats['wrong'] + round_stats['lying']") + self._emit("print(f'[td_lang] Round {arena_round+1} results:')") + self._emit("print(f' Passed fact-check: {round_stats[\"correct\"]}/{total_ep}')") + self._emit("print(f' Failed: {round_stats[\"wrong\"]}/{total_ep}')") + self._emit("print(f' Caught lying: {round_stats[\"lying\"]} (punished -2.0 each)')") + self._emit("if round_stats['missed_facts']:") + self._indent += 1 + self._emit("print(f' Top missed facts ({len(round_stats[\"missed_facts\"])} total):')") + self._emit("for mf in round_stats['missed_facts'][:3]:") + self._indent += 1 + self._emit("print(f' Source: {mf[\"source\"]}')") + self._emit("print(f' Missed: {mf[\"missed_terms\"]}')") + self._indent -= 2 # close for mf, if missed_facts + self._emit("") + # Free model, train + self._emit("del model; import gc; gc.collect()") + self._emit("try:") + self._indent += 1 + self._emit("torch.cuda.empty_cache()") + self._indent -= 1 + self._emit("except Exception:") + self._indent += 1 + self._emit("pass") + self._indent -= 1 + self._emit("") + self._emit("if len(round_experiences) < 3:") + self._indent += 1 + self._emit("print('[td_lang] Too few positive experiences — maze was 
too hard. Skipping training.')") + self._emit("continue") + self._indent -= 1 + self._emit("") + self._emit("# === REWARD-WEIGHTED TRAINING ===") + self._emit("training_texts = []") + self._emit("for text, reward in round_experiences:") + self._indent += 1 + self._emit("copies = max(1, int(reward * 2))") + self._emit("training_texts.extend([text] * copies)") + self._indent -= 1 + self._emit("random.shuffle(training_texts)") + self._emit('print(f"[td_lang] Training on {len(training_texts)} reward-weighted experiences...")') + self._emit("") + self._emit("ds = Dataset.from_dict({'text': training_texts})") + self._emit("model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=bnb_config, device_map='auto')") + self._emit("model = prepare_model_for_kbit_training(model)") + self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,") + self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")') + self._emit("model = get_peft_model(model, lora_config)") + self._emit(f"ra_out = f'td_lang_outputs/research_arena_round_{{arena_round}}'") + self._emit(f"training_args = TrainingArguments(output_dir=ra_out, max_steps={cmd.steps},") + self._emit(" per_device_train_batch_size=1, gradient_accumulation_steps=4,") + self._emit(" learning_rate=5e-5, logging_steps=16, bf16=True, gradient_checkpointing=True)") + self._emit("trainer = SFTTrainer(model=model, train_dataset=ds, args=training_args, tokenizer=tok)") + self._emit("trainer.train()") + self._emit("trainer.save_model(ra_out)") + self._emit("checkpoint = ra_out") + self._emit("del model; gc.collect()") + self._emit("try:") + self._indent += 1 + self._emit("torch.cuda.empty_cache()") + self._indent -= 1 + self._emit("except Exception:") + self._indent += 1 + self._emit("pass") + self._indent -= 1 + self._emit("") + self._emit("research_log.append({") + self._indent += 1 + self._emit("'round': arena_round,") + self._emit("'difficulty': difficulty_level,") + 
self._emit("'strictness': strictness,") + self._emit("'stats': dict(round_stats),") + self._emit("'n_training': len(training_texts),") + self._emit("'memory_size': len(memory_bank),") + self._indent -= 1 + self._emit("})") + self._emit("") + self._emit("print(f'[td_lang] Round {arena_round+1} complete. Model trained and saved.')") + self._indent -= 1 # close for arena_round + self._emit("") + # Final summary + self._emit(f'models["{cmd.target}"]["checkpoint"] = checkpoint') + self._emit('print(f"\\n[td_lang] RESEARCH ARENA COMPLETE")') + self._emit('print(f" Topic: {topic}")') + self._emit('print(f" Knowledge base: {len(knowledge_base)} facts")') + self._emit('print(f" Total memories: {len(memory_bank)}")') + self._emit('print(f" Unique approaches: {len(seen_approaches)}")') + self._emit("lying_count = sum(1 for m in memory_bank if m['lying'])") + self._emit("correct_count = sum(1 for m in memory_bank if m['passed'])") + self._emit("print(f' Correct: {correct_count} | Caught lying: {lying_count}')") + self._emit("avg_reward = sum(m['reward'] for m in memory_bank) / len(memory_bank) if memory_bank else 0") + self._emit("print(f' Average reward: {avg_reward:.2f}')") + self._emit("") + # Save log + if cmd.output: + self._emit(f'log_path = Path("{cmd.output}")') + self._emit("log_path.parent.mkdir(parents=True, exist_ok=True)") + self._emit('with open(log_path, "w") as f:') + self._indent += 1 + self._emit("json.dump({'topic': topic, 'log': research_log, 'memory': memory_bank, 'knowledge_base_size': len(knowledge_base)}, f, indent=2)") + self._indent -= 1 + self._emit('print(f"[td_lang] Research log saved to {log_path}")') + self._emit("") + # Lineage + self._emit(f'lineage["{cmd.target}"]["operations"].append({{') + self._indent += 1 + self._emit('"op": "research_arena",') + self._emit(f'"topic": "{cmd.topic}",') + self._emit(f'"sources": "{cmd.sources}",') + self._emit(f'"rounds": {cmd.rounds},') + self._emit(f'"episodes_per_round": {cmd.episodes},') + 
def _emit_vote(self, cmd: VoteCmd) -> None:
    """Emit script code for VOTE: sample several answers, keep the most common.

    The generated code loads ``cmd.target`` (its recorded checkpoint, falling
    back to ``model_ref``), samples ``cmd.samples`` completions of
    ``cmd.question`` and picks the answer that occurs most often after
    lower-casing and stripping.  The outcome is stored in
    ``results["<target>_vote"]`` and, when ``cmd.output`` is set, dumped to
    that path as JSON.

    Args:
        cmd: Parsed VOTE command; ``target``, ``samples``, ``question`` and
            ``output`` are read.
    """
    n = cmd.samples
    emit = self._emit

    def indented(*lines: str) -> None:
        # Emit ``lines`` one indentation level deeper, then restore the level.
        self._indent += 1
        for line in lines:
            emit(line)
        self._indent -= 1

    emit(f'print("[td_lang] Majority voting on {cmd.target} ({n} samples)...")')
    emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")')
    emit("if not checkpoint:")
    indented(f'checkpoint = models["{cmd.target}"]["model_ref"]')
    emit("from transformers import AutoModelForCausalLM, AutoTokenizer")
    emit("import torch")
    emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
    emit("model = AutoModelForCausalLM.from_pretrained(")
    indented('checkpoint, torch_dtype=torch.bfloat16, device_map="auto"')
    emit(")")
    emit("model.eval()")
    emit(f"question = {cmd.question!r}")
    emit(f"n_samples = {n}")
    emit('inputs = tok(question, return_tensors="pt").to(model.device)')
    emit("answers = []")

    # Sampling loop: one generation per requested sample.
    emit("for i in range(n_samples):")
    self._indent += 1
    emit("with torch.no_grad():")
    indented("out = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.7, top_p=0.9)")
    emit("resp = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip()")
    emit("answers.append(resp)")
    emit('print(f" Sample {i+1}: {resp[:80]}...")')
    self._indent -= 1
    emit("")

    emit("# Find the most common answer (majority vote)")
    emit("from collections import Counter")
    emit("# Normalize answers: lowercase, strip whitespace for comparison")
    emit("normalized = [a.strip().lower() for a in answers]")
    emit("counts = Counter(normalized)")
    emit("winner_norm, winner_count = counts.most_common(1)[0]")
    emit("# Find the original (non-normalized) version of the winner")
    emit("winner = next(a for a, n in zip(answers, normalized) if n == winner_norm)")
    emit('print(f"[td_lang] Winner ({winner_count}/{n_samples} votes): {winner[:200]}")')
    emit("")
    emit("vote_result = {")
    indented(
        "'question': question,",
        "'winner': winner,",
        "'votes': winner_count,",
        "'total_samples': n_samples,",
        "'all_answers': answers,",
        "'confidence': winner_count / n_samples,",
    )
    emit("}")
    emit(f'results["{cmd.target}_vote"] = vote_result')

    if cmd.output:
        # repr() yields a valid Python literal even when the path contains
        # backslashes (Windows paths) or quote characters.
        emit(f"vote_path = Path({cmd.output!r})")
        emit("vote_path.parent.mkdir(parents=True, exist_ok=True)")
        emit('with open(vote_path, "w") as f:')
        indented("json.dump(vote_result, f, indent=2)")
        emit('print(f"[td_lang] Vote results saved to {vote_path}")')

    emit("del model, tok")
    emit("import gc; gc.collect()")


def _emit_prompt(self, cmd: PromptBlock) -> None:
    """Emit script code for PROMPT: record a system prompt on a model entry.

    The generated code stores ``cmd.text`` under
    ``models[<target>]["system_prompt"]``, prints a 60-character preview and
    appends a ``"prompt"`` entry to the target's lineage.
    NOTE(review): whether other commands actually read ``system_prompt`` is
    not visible from this method - confirm before relying on it.

    Args:
        cmd: Parsed PROMPT block; ``target`` and ``text`` are read.
    """
    # Build the preview message here and emit it as ONE repr()-ed literal.
    # Pasting repr(text) inside a hand-written "..." literal breaks the
    # generated script as soon as the prompt contains a quote character.
    preview = f"[td_lang] Prompt set: {cmd.text[:60]!r}..."

    self._emit(f'print("[td_lang] Setting system prompt for {cmd.target}...")')
    self._emit(f'models["{cmd.target}"]["system_prompt"] = {cmd.text!r}')
    self._emit(f"print({preview!r})")
    self._emit(f'lineage["{cmd.target}"]["operations"].append({{')
    self._indent += 1
    self._emit('"op": "prompt",')
    self._emit(f'"text": {cmd.text!r},')
    self._emit('"timestamp": datetime.now().isoformat(),')
    self._indent -= 1
    self._emit("})")


def _emit_distill(self, cmd: DistillCmd) -> None:
    """Emit script code for DISTILL: fine-tune a student on teacher outputs.

    The generated code (1) loads the teacher, greedily answers a fixed list
    of prompts and frees the teacher, then (2) loads ``cmd.student`` in 4-bit
    with a LoRA adapter and runs ``SFTTrainer`` on "prompt + newline +
    response" texts, saving model and tokenizer to ``cmd.output`` (default
    ``td_lang_outputs/distilled_student``).  A ``"distill"`` entry is appended
    to the *teacher's* lineage.

    Args:
        cmd: Parsed DISTILL command; ``teacher``, ``student``, ``steps`` and
            ``output`` are read.
    """
    steps = cmd.steps
    emit = self._emit

    def indented(*lines: str) -> None:
        # Emit ``lines`` one indentation level deeper, then restore the level.
        self._indent += 1
        for line in lines:
            emit(line)
        self._indent -= 1

    def free_cuda_cache() -> None:
        # Best-effort cache release in the generated script.
        emit("try:")
        indented("torch.cuda.empty_cache()")
        emit("except Exception:")
        indented("pass")

    # Fixed prompt set the teacher answers; emitted verbatim into the script.
    prompts = (
        "Explain how photosynthesis works step by step.",
        "Write a Python function to find the longest common subsequence.",
        "What is 847 divided by 11? Show your work.",
        "Compare and contrast TCP and UDP protocols.",
        "Solve: if 3x + 7 = 22, what is x?",
        "Explain the difference between a stack and a queue.",
        "What causes seasons on Earth?",
        "Write a function to check if a string is a palindrome.",
        "What is the Pythagorean theorem and give an example.",
        "Explain recursion with a simple example.",
        "What is 15% of 240?",
        "Describe how a binary search works.",
        "What are the three laws of thermodynamics?",
        "Write pseudocode for bubble sort.",
        "If a train travels 120 miles in 2 hours, what is its speed?",
        "Explain what an API is in simple terms.",
    )
    # Derive the epoch count from the real number of prompts.  The previous
    # code took len() of the *name string* 'distill_prompts' (always 15).
    # NOTE(review): Hugging Face documents a positive max_steps as taking
    # precedence over num_train_epochs - confirm for the pinned version.
    epochs = max(1, steps // len(prompts) + 1)
    distill_out = cmd.output or "td_lang_outputs/distilled_student"

    emit(f'print("[td_lang] Distilling {cmd.teacher} into student model...")')
    emit(f'teacher_checkpoint = models.get("{cmd.teacher}", {{}}).get("checkpoint")')
    emit("if not teacher_checkpoint:")
    indented(f'teacher_checkpoint = models["{cmd.teacher}"]["model_ref"]')
    emit(f"student_path = {cmd.student!r}")
    emit("")

    emit("# Step 1: Generate teacher answers on diverse prompts")
    emit("from transformers import AutoModelForCausalLM, AutoTokenizer")
    emit("import torch")
    emit('print("[td_lang] Loading teacher model...")')
    emit("teacher_tok = AutoTokenizer.from_pretrained(teacher_checkpoint)")
    emit("teacher_model = AutoModelForCausalLM.from_pretrained(")
    indented('teacher_checkpoint, torch_dtype=torch.bfloat16, device_map="auto"')
    emit(")")
    emit("teacher_model.eval()")
    emit("")
    emit("distill_prompts = [")
    indented(*(f'"{prompt}",' for prompt in prompts))
    emit("]")
    emit("")
    emit("teacher_data = []")
    emit("for prompt in distill_prompts:")
    self._indent += 1
    emit('inputs = teacher_tok(prompt, return_tensors="pt").to(teacher_model.device)')
    emit("with torch.no_grad():")
    indented("out = teacher_model.generate(**inputs, max_new_tokens=512, do_sample=False)")
    emit("resp = teacher_tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)")
    emit('teacher_data.append({"prompt": prompt, "response": resp})')
    emit('print(f" Generated: {prompt[:40]}... -> {len(resp)} chars")')
    self._indent -= 1
    emit("")
    emit("del teacher_model")
    emit("import gc; gc.collect()")
    free_cuda_cache()
    emit("")

    emit("# Step 2: Load student model with QLoRA and train on teacher outputs")
    emit('print("[td_lang] Loading student model with QLoRA...")')
    emit("from transformers import BitsAndBytesConfig, TrainingArguments")
    emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training")
    emit("from trl import SFTTrainer")
    emit("from datasets import Dataset")
    emit("")
    emit("bnb_config = BitsAndBytesConfig(")
    indented(
        "load_in_4bit=True,",
        'bnb_4bit_quant_type="nf4",',
        "bnb_4bit_compute_dtype=torch.bfloat16,",
        "bnb_4bit_use_double_quant=True,",
    )
    emit(")")
    emit("student_tok = AutoTokenizer.from_pretrained(student_path)")
    emit("student_model = AutoModelForCausalLM.from_pretrained(")
    indented("student_path, quantization_config=bnb_config, device_map='auto'")
    emit(")")
    emit("student_model = prepare_model_for_kbit_training(student_model)")
    emit("")
    emit("lora_config = LoraConfig(")
    indented(
        "r=16, lora_alpha=32, lora_dropout=0.05,",
        'target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],',
        'task_type="CAUSAL_LM",',
    )
    emit(")")
    emit("student_model = get_peft_model(student_model, lora_config)")
    emit("")
    emit("# Format training data")
    emit("train_texts = []")
    emit("for d in teacher_data:")
    indented("train_texts.append(d['prompt'] + '\\n' + d['response'])")
    emit('ds = Dataset.from_dict({"text": train_texts})')
    emit("")
    # repr() keeps the generated literal valid for paths with backslashes/quotes.
    emit(f"distill_out = {distill_out!r}")
    emit("training_args = TrainingArguments(")
    indented(
        "output_dir=distill_out,",
        f"num_train_epochs={epochs},",
        f"max_steps={steps},",
        "per_device_train_batch_size=1,",
        "gradient_accumulation_steps=4,",
        "learning_rate=2e-4,",
        'optim="paged_adamw_8bit",',
        "logging_steps=10,",
        "save_strategy='epoch',",
        "bf16=True,",
    )
    emit(")")
    emit("trainer = SFTTrainer(")
    indented(
        "model=student_model,",
        "train_dataset=ds,",
        "args=training_args,",
        "tokenizer=student_tok,",
    )
    emit(")")
    emit('print(f"[td_lang] Training student for {training_args.max_steps} steps...")')
    emit("trainer.train()")
    emit("student_model.save_pretrained(distill_out)")
    emit("student_tok.save_pretrained(distill_out)")
    emit('print(f"[td_lang] Student model saved to {distill_out}")')
    emit("")
    emit("del student_model, teacher_tok, student_tok")
    emit("gc.collect()")
    free_cuda_cache()

    emit(f'lineage["{cmd.teacher}"]["operations"].append({{')
    indented(
        '"op": "distill",',
        f'"student": {cmd.student!r},',
        f'"steps": {steps},',
        '"n_examples": len(teacher_data),',
        '"timestamp": datetime.now().isoformat(),',
    )
    emit("})")


def _emit_rollback(self, cmd: RollbackCmd) -> None:
    """Emit script code for ROLLBACK: point a model back at its last snapshot.

    The generated code globs ``td_lang_outputs/snapshots/<target>_*``, sorts
    the matches by name and takes the last one.  With no match it only prints
    an error and a hint; otherwise it rewrites
    ``models[<target>]["checkpoint"]`` and appends a ``"rollback"`` lineage
    entry.
    NOTE(review): "latest" means last in lexicographic order - this assumes
    snapshot names sort chronologically; verify against the snapshot emitter.

    Args:
        cmd: Parsed ROLLBACK command; only ``target`` is read.
    """
    target = cmd.target
    emit = self._emit

    def indented(*lines: str) -> None:
        # Emit ``lines`` one indentation level deeper, then restore the level.
        self._indent += 1
        for line in lines:
            emit(line)
        self._indent -= 1

    emit(f'print("[td_lang] Rolling back {target}...")')
    emit("import glob as _glob")
    emit(f'snap_pattern = os.path.join("td_lang_outputs", "snapshots", "{target}_*")')
    emit("snapshots = sorted(_glob.glob(snap_pattern))")

    # Nothing to restore: report and carry on.
    emit("if not snapshots:")
    indented(
        f'print("[td_lang] ERROR: No snapshots found for {target}. Cannot rollback.")',
        f'print("[td_lang] Hint: use snapshot {target} before training to create restore points.")',
    )

    # Otherwise repoint the checkpoint and record the operation.
    emit("else:")
    self._indent += 1
    emit("latest_snap = snapshots[-1]")
    emit('print(f"[td_lang] Found {len(snapshots)} snapshots. Reverting to: {latest_snap}")')
    emit(f'models["{target}"]["checkpoint"] = latest_snap')
    emit(f'lineage["{target}"]["operations"].append({{')
    indented(
        '"op": "rollback",',
        '"snapshot": latest_snap,',
        '"timestamp": datetime.now().isoformat(),',
    )
    emit("})")
    emit(f'print(f"[td_lang] Rollback complete. {target} now points to {{latest_snap}}")')
    self._indent -= 1