diff --git "a/hugging/td_lang/compiler.py" "b/hugging/td_lang/compiler.py"
--- "a/hugging/td_lang/compiler.py"
+++ "b/hugging/td_lang/compiler.py"
@@ -39,6 +39,20 @@ from .ast_nodes import (
RewardContractBlock,
SaveCmd,
ScheduleCmd,
+ DownloadCmd,
+ LogBlock,
+ CompareCmd,
+ VerifyCmd,
+ VoteCmd,
+ PromptBlock,
+ DistillCmd,
+ RollbackCmd,
+ CurriculumCmd,
+ StarCmd,
+ BestOfCmd,
+ ExploitCmd,
+ ArenaCmd,
+ ResearchArenaCmd,
SetupBlock,
SnapshotCmd,
SynthCmd,
@@ -47,7 +61,7 @@ from .ast_nodes import (
)
from .errors import TDCompileError
-# All command types are now implemented (Phase 1 + 2 + 3 + ... + 9)
+# All command types are now implemented (Phase 1 + 2 + 3 + ... + 12)
class TDCompiler:
@@ -146,8 +160,32 @@ class TDCompiler:
)
elif isinstance(cmd, (RepeatBlock, IfBlock, ScheduleCmd)):
pass # block commands - body validation happens at emit time
- elif isinstance(cmd, (NotifyCmd, SaveCmd)):
+ elif isinstance(cmd, (NotifyCmd, SaveCmd, DownloadCmd)):
pass # utility commands - always valid
+ elif isinstance(cmd, (CompareCmd, VerifyCmd)):
+ if cmd.target not in seen:
+ raise TDCompileError(
+ f"Can't use '{cmd.target}' - it hasn't been loaded yet.",
+ hint=f'Add: load "model/path" as {cmd.target}',
+ )
+ elif isinstance(cmd, (VoteCmd, PromptBlock, RollbackCmd)):
+ if cmd.target not in seen:
+ raise TDCompileError(
+ f"Can't use '{cmd.target}' - it hasn't been loaded yet.",
+ hint=f'Add: load "model/path" as {cmd.target}',
+ )
+ elif isinstance(cmd, DistillCmd):
+ if cmd.teacher not in seen:
+ raise TDCompileError(
+ f"Can't distill from '{cmd.teacher}' - it hasn't been loaded yet.",
+ hint=f'Add: load "model/path" as {cmd.teacher}',
+ )
+ elif isinstance(cmd, (CurriculumCmd, StarCmd, BestOfCmd, ExploitCmd, ArenaCmd, ResearchArenaCmd)):
+ if cmd.target not in seen:
+ raise TDCompileError(
+ f"Can't use '{cmd.target}' - it hasn't been loaded yet.",
+ hint=f'Add: load "model/path" as {cmd.target}',
+ )
# ---------------------------------------------------------------- Build script
def _build_script(self, program: TDProgram) -> None:
@@ -231,6 +269,9 @@ DO NOT EDIT - regenerate from the .td file instead.
if program.setup:
self._emit_setup(program.setup)
+ if program.log:
+ self._emit_log_setup(program.log)
+
if program.on_error:
self._emit_on_error(program.on_error, program)
@@ -260,7 +301,7 @@ DO NOT EDIT - regenerate from the .td file instead.
elif isinstance(cmd, SynthCmd):
self._emit_synth(cmd)
elif isinstance(cmd, TrainCmd):
- self._emit_train(cmd)
+ self._emit_train(cmd, program)
elif isinstance(cmd, DebateCmd):
self._emit_debate(cmd)
elif isinstance(cmd, EditCmd):
@@ -289,6 +330,32 @@ DO NOT EDIT - regenerate from the .td file instead.
self._emit_save(cmd, program)
elif isinstance(cmd, ScheduleCmd):
self._emit_schedule(cmd, program)
+ elif isinstance(cmd, DownloadCmd):
+ self._emit_download(cmd)
+ elif isinstance(cmd, CompareCmd):
+ self._emit_compare(cmd)
+ elif isinstance(cmd, VerifyCmd):
+ self._emit_verify(cmd)
+ elif isinstance(cmd, VoteCmd):
+ self._emit_vote(cmd)
+ elif isinstance(cmd, PromptBlock):
+ self._emit_prompt(cmd)
+ elif isinstance(cmd, DistillCmd):
+ self._emit_distill(cmd)
+ elif isinstance(cmd, RollbackCmd):
+ self._emit_rollback(cmd)
+ elif isinstance(cmd, CurriculumCmd):
+ self._emit_curriculum(cmd, program)
+ elif isinstance(cmd, StarCmd):
+ self._emit_star(cmd, program)
+ elif isinstance(cmd, BestOfCmd):
+ self._emit_best_of(cmd, program)
+ elif isinstance(cmd, ExploitCmd):
+ self._emit_exploit(cmd, program)
+ elif isinstance(cmd, ArenaCmd):
+ self._emit_arena(cmd, program)
+ elif isinstance(cmd, ResearchArenaCmd):
+ self._emit_research_arena(cmd, program)
self._emit("")
self._emit_summary()
@@ -622,10 +689,11 @@ DO NOT EDIT - regenerate from the .td file instead.
def _emit_diagnose(self, cmd: DiagnoseCmd) -> None:
"""Generate code for: diagnose target [-> weaknesses.json]
- Loads the model and asks it to identify its own weaknesses.
- Uses structured prompting to get actionable self-diagnosis.
- Interview finding: all 3 AIs (ChatGPT, Grok, Gemini) confirmed
- models CAN self-diagnose when asked directly (test_8-12).
+ MEGA DIAGNOSE: Self-diagnosis + Performance profiling in one command.
+ Part 1: Asks the model to identify its own weaknesses (self-diagnosis).
+ Part 2: Tests the model on actual problems per domain (profiling).
+ Part 3: Times one full forward pass and reports the average latency per layer.
+ Combines all three into a single actionable report.
"""
self._emit(f'print("[td_lang] Diagnosing {cmd.target}...")')
self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")')
@@ -721,12 +789,113 @@ DO NOT EDIT - regenerate from the .td file instead.
self._indent -= 2
self._emit("print(f'[td_lang] Top weaknesses to target: {top_weaknesses}')")
self._emit("")
+ self._emit("")
+ self._emit("# --- Part 2: Profiling - test actual performance per domain ---")
+ self._emit('print("[td_lang] Running domain profiling...")')
+ self._emit("profile_tests = {")
+ self._indent += 1
+ self._emit("'math': [")
+ self._indent += 1
+ self._emit('("What is 15 * 23?", "345"),')
+ self._emit('("What is 144 / 12?", "12"),')
+ self._emit('("Solve: 2x + 5 = 17", "6"),')
+ self._indent -= 1
+ self._emit("],")
+ self._emit("'code': [")
+ self._indent += 1
+ self._emit('("Write a Python function that returns the factorial of n.", "def"),')
+ self._emit('("What does len([1,2,3]) return in Python?", "3"),')
+ self._emit('("Fix this: for i in range(10) print(i)", "for i in range(10):"),')
+ self._indent -= 1
+ self._emit("],")
+ self._emit("'logic': [")
+ self._indent += 1
+ self._emit('("If all cats are animals and all animals breathe, do cats breathe?", "yes"),')
+ self._emit('("A is taller than B. B is taller than C. Who is shortest?", "c"),')
+ self._emit('("If it rains the ground is wet. The ground is wet. Did it rain?", "not necessarily"),')
+ self._indent -= 1
+ self._emit("],")
+ self._emit("'factual': [")
+ self._indent += 1
+ self._emit('("What planet is closest to the Sun?", "mercury"),')
+ self._emit('("Who wrote Romeo and Juliet?", "shakespeare"),')
+ self._emit('("What is the chemical formula for water?", "h2o"),')
+ self._indent -= 1
+ self._emit("],")
+ self._indent -= 1
+ self._emit("}")
+ self._emit("")
+ self._emit("domain_scores = {}")
+ self._emit("for domain, tests in profile_tests.items():")
+ self._indent += 1
+ self._emit("correct = 0")
+ self._emit("for question, expected in tests:")
+ self._indent += 1
+ self._emit('inputs = tok(question, return_tensors="pt").to(model.device)')
+ self._emit("with torch.no_grad():")
+ self._indent += 1
+ self._emit("out = model.generate(**inputs, max_new_tokens=128, do_sample=False)")
+ self._indent -= 1
+ self._emit("resp = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip().lower()")
+ self._emit("if expected.lower() in resp:")
+ self._indent += 1
+ self._emit("correct += 1")
+ self._indent -= 2
+ self._emit("score = correct / len(tests) * 100")
+ self._emit("domain_scores[domain] = score")
+ self._emit("_score_label = 'STRONG' if score >= 67 else ('OK' if score >= 34 else 'WEAK')")
+ self._emit('print(f" {domain}: {score:.0f}% ({_score_label})")')
+ self._indent -= 1
+ self._emit("")
+ self._emit("# --- Part 3: Layer speed profiling ---")
+ self._emit('print("[td_lang] Measuring layer speeds...")')
+ self._emit("import time as _time")
+ self._emit("n_layers = len(model.model.layers) if hasattr(model, 'model') and hasattr(model.model, 'layers') else 0")
+ self._emit("layer_times = {}")
+ self._emit("if n_layers > 0:")
+ self._indent += 1
+ self._emit('test_input = tok("Hello world", return_tensors="pt").to(model.device)')
+ self._emit("# Warm up")
+ self._emit("with torch.no_grad():")
+ self._indent += 1
+ self._emit("_ = model(**test_input)")
+ self._indent -= 1
+ self._emit("# Time each layer group (every 4 layers)")
+ self._emit("_total_start = _time.perf_counter()")
+ self._emit("with torch.no_grad():")
+ self._indent += 1
+ self._emit("_ = model(**test_input)")
+ self._indent -= 1
+ self._emit("_total_time = _time.perf_counter() - _total_start")
+ self._emit("_per_layer = _total_time / n_layers * 1000 # ms per layer")
+ self._emit('print(f" Total inference: {_total_time*1000:.1f}ms across {n_layers} layers")')
+ self._emit('print(f" Average: {_per_layer:.2f}ms per layer")')
+ self._emit('layer_times = {"total_ms": _total_time*1000, "n_layers": n_layers, "avg_ms_per_layer": _per_layer}')
+ self._indent -= 1
+ self._emit("")
+ self._emit("# Combine everything into mega-diagnosis")
+ self._emit("diagnosis['domain_scores'] = domain_scores")
+ self._emit("diagnosis['layer_profile'] = layer_times")
+ self._emit("diagnosis['weakest_domains'] = sorted(domain_scores.items(), key=lambda x: x[1])[:2]")
+ self._emit("")
+ self._emit("# Merge self-reported weaknesses with actual test results")
+ self._emit("print('[td_lang] === MEGA DIAGNOSIS SUMMARY ===')")
+ self._emit("print('[td_lang] Self-reported weaknesses:', top_weaknesses)")
+ self._emit("_weakest = [d for d, s in sorted(domain_scores.items(), key=lambda x: x[1])[:2]]")
+ self._emit("print(f'[td_lang] Tested weakest domains: {_weakest}')")
+ self._emit("# Combine both signals")
+ self._emit("all_weak = list(set(top_weaknesses[:2] + _weakest))")
+ self._emit("diagnosis['combined_weaknesses'] = all_weak")
+ self._emit("top_weaknesses = all_weak # update for synth to use")
+ self._emit("print(f'[td_lang] Combined training targets: {all_weak}')")
+ self._emit("")
self._emit(f'results["{cmd.target}_diagnose"] = diagnosis')
self._emit(f'lineage["{cmd.target}"]["operations"].append({{')
self._indent += 1
self._emit('"op": "diagnose",')
self._emit('"n_prompts": len(diag_prompts),')
self._emit('"top_weaknesses": top_weaknesses,')
+ self._emit('"domain_scores": domain_scores,')
self._emit('"timestamp": datetime.now().isoformat(),')
self._indent -= 1
self._emit("})")
@@ -954,7 +1123,7 @@ DO NOT EDIT - regenerate from the .td file instead.
self._emit("del model, tok")
self._emit("import gc; gc.collect()")
- def _emit_train(self, cmd: TrainCmd) -> None:
+ def _emit_train(self, cmd: TrainCmd, program: TDProgram = None) -> None:
"""Generate code for: train target on "dataset" using method [steps N] [lr N]
Runs GRPO, SFT, or DPO training using the trl library.
@@ -1043,6 +1212,18 @@ DO NOT EDIT - regenerate from the .td file instead.
self._emit(")")
self._emit("")
self._emit("# Verified rewards only (test_16: no learned reward model)")
+ # Wire in reward_contract verifiers if they exist
+ if program and program.reward_contract and program.reward_contract.verifiers:
+ verifiers = program.reward_contract.verifiers
+ self._emit(f'# reward_contract verifiers wired in: {verifiers}')
+ self._emit(f'_active_verifiers = {verifiers}')
+ if program.reward_contract.min_reward is not None:
+ self._emit(f'_min_reward = {program.reward_contract.min_reward}')
+ else:
+ self._emit('_min_reward = 0.0')
+ else:
+ self._emit('_active_verifiers = ["code_compiles", "math_correct"] # defaults')
+ self._emit('_min_reward = 0.0')
self._emit("import ast, math, re")
self._emit("ALLOWED_EXPR = re.compile(r'^[0-9+\\-*/().\\s]+$')")
self._emit("")
@@ -1070,26 +1251,25 @@ DO NOT EDIT - regenerate from the .td file instead.
self._indent += 1
self._emit("text = comp if isinstance(comp, str) else comp[0].get('content', '')")
self._emit("score = 0.0")
- self._emit("# Code compilation reward")
+ self._emit("# Code compilation reward (active if 'code_compiles' in verifiers)")
+ self._emit("if 'code_compiles' in _active_verifiers:")
+ self._indent += 1
self._emit("code_blocks = re.findall(r'```python\\n(.*?)```', text, re.S)")
- self._emit("compiled_ok = False")
self._emit("for block in code_blocks or []:")
self._indent += 1
self._emit("try:")
self._indent += 1
self._emit("ast.parse(block)")
- self._emit("compiled_ok = True")
+ self._emit("score += 0.4")
self._emit("break")
self._indent -= 1
self._emit("except SyntaxError:")
self._indent += 1
self._emit("pass")
- self._indent -= 2
- self._emit("if compiled_ok:")
+ self._indent -= 3
+ self._emit("# Math correctness reward (active if 'math_correct' in verifiers)")
+ self._emit("if 'math_correct' in _active_verifiers:")
self._indent += 1
- self._emit("score += 0.4")
- self._indent -= 1
- self._emit("# Math correctness reward (prompt-provided expression)")
self._emit("expr_match = re.search(r'([0-9+\\-*/().\\s]{3,})', prompt)")
self._emit("pred_num_match = re.search(r'(-?\\d+(?:\\.\\d+)?)', text)")
self._emit("if expr_match and pred_num_match:")
@@ -1107,13 +1287,22 @@ DO NOT EDIT - regenerate from the .td file instead.
self._emit("if target is not None and pred_val is not None and abs(target - pred_val) < 1e-3:")
self._indent += 1
self._emit("score += 0.4")
+ self._indent -= 3
+ self._emit("# No hallucination check (active if 'no_hallucination' in verifiers)")
+ self._emit("if 'no_hallucination' in _active_verifiers:")
+ self._indent += 1
+ self._emit("hedges = ['i think', 'probably', 'not sure', 'might be']")
+ self._emit("if not any(h in text.lower() for h in hedges):")
+ self._indent += 1
+ self._emit("score += 0.2")
self._indent -= 2
self._emit("# Structured answer bonus")
self._emit("if 'answer' in text.lower() or 'result' in text.lower():")
self._indent += 1
self._emit("score += 0.2")
self._indent -= 1
- self._emit("rewards.append(min(score, 1.0))")
+ self._emit("# Enforce min_reward from reward_contract")
+ self._emit("rewards.append(max(min(score, 1.0), _min_reward) if score > 0 else 0.0)")
self._indent -= 1
self._emit("return rewards")
self._indent -= 1
@@ -1953,7 +2142,7 @@ DO NOT EDIT - regenerate from the .td file instead.
elif isinstance(cmd, SynthCmd):
self._emit_synth(cmd)
elif isinstance(cmd, TrainCmd):
- self._emit_train(cmd)
+ self._emit_train(cmd, program)
elif isinstance(cmd, DebateCmd):
self._emit_debate(cmd)
elif isinstance(cmd, EditCmd):
@@ -1982,6 +2171,32 @@ DO NOT EDIT - regenerate from the .td file instead.
self._emit_if(cmd, program)
elif isinstance(cmd, ScheduleCmd):
self._emit_schedule(cmd, program)
+ elif isinstance(cmd, DownloadCmd):
+ self._emit_download(cmd)
+ elif isinstance(cmd, CompareCmd):
+ self._emit_compare(cmd)
+ elif isinstance(cmd, VerifyCmd):
+ self._emit_verify(cmd)
+ elif isinstance(cmd, VoteCmd):
+ self._emit_vote(cmd)
+ elif isinstance(cmd, PromptBlock):
+ self._emit_prompt(cmd)
+ elif isinstance(cmd, DistillCmd):
+ self._emit_distill(cmd)
+ elif isinstance(cmd, RollbackCmd):
+ self._emit_rollback(cmd)
+ elif isinstance(cmd, CurriculumCmd):
+ self._emit_curriculum(cmd, program)
+ elif isinstance(cmd, StarCmd):
+ self._emit_star(cmd, program)
+ elif isinstance(cmd, BestOfCmd):
+ self._emit_best_of(cmd, program)
+ elif isinstance(cmd, ExploitCmd):
+ self._emit_exploit(cmd, program)
+ elif isinstance(cmd, ArenaCmd):
+ self._emit_arena(cmd, program)
+ elif isinstance(cmd, ResearchArenaCmd):
+ self._emit_research_arena(cmd, program)
def _emit_repeat(self, cmd: RepeatBlock, program: TDProgram) -> None:
"""REPEAT - run a block of commands N times.
@@ -2890,6 +3105,357 @@ DO NOT EDIT - regenerate from the .td file instead.
self._emit(f'print("[td_lang] WARNING: Unknown schedule pattern: {timing}")')
self._emit('print("[td_lang] Supported: every Nh/Nm, at HH:MM, after Nh/Nm")')
+ # ---------------------------------------------------------------- Phase 10: Toolbox
+ def _emit_log_setup(self, log_block: LogBlock) -> None:
+ """LOG - redirect all output to a file AND console."""
+ filepath = log_block.filepath
+ self._emit(f'# Log setup - everything goes to "{filepath}" AND console')
+ self._emit("import sys as _sys")
+ self._emit("")
+ self._emit("class _TeeLogger:")
+ self._indent += 1
+ self._emit("def __init__(self, filepath, stream):")
+ self._indent += 1
+ self._emit("self.stream = stream")
+ self._emit("self.file = open(filepath, 'w')")
+ self._indent -= 1
+ self._emit("def write(self, data):")
+ self._indent += 1
+ self._emit("self.stream.write(data)")
+ self._emit("self.file.write(data)")
+ self._emit("self.file.flush()")
+ self._indent -= 1
+ self._emit("def flush(self):")
+ self._indent += 1
+ self._emit("self.stream.flush()")
+ self._emit("self.file.flush()")
+ self._indent -= 1
+ self._indent -= 1
+ self._emit("")
+ self._emit(f'_sys.stdout = _TeeLogger("{filepath}", _sys.stdout)')
+ self._emit(f'_sys.stderr = _TeeLogger("{filepath}", _sys.stderr)')
+ self._emit(f'print("[td_lang] Logging to: {filepath}")')
+ self._emit("")
+
+ def _emit_download(self, cmd: DownloadCmd) -> None:
+ """DOWNLOAD - pull a dataset from HuggingFace."""
+ self._emit(f'print("[td_lang] Downloading dataset: {cmd.dataset} (split: {cmd.split})")')
+ self._emit("from datasets import load_dataset")
+ self._emit(f'_dl_dataset = load_dataset("{cmd.dataset}", split="{cmd.split}")')
+ self._emit(f'print(f"[td_lang] Downloaded {{len(_dl_dataset)}} samples")')
+ self._emit("")
+ self._emit("# Save locally as JSONL for later use")
+ self._emit(f'_dl_path = "td_lang_outputs/{cmd.alias}.jsonl"')
+ self._emit("os.makedirs(os.path.dirname(_dl_path), exist_ok=True)")
+ self._emit("_dl_dataset.to_json(_dl_path)")
+ self._emit(f'print(f"[td_lang] Saved to {{_dl_path}}")')
+ self._emit("")
+ self._emit(f'# Store reference for use in train/verify commands')
+ self._emit(f'results["{cmd.alias}_dataset"] = {{')
+ self._indent += 1
+ self._emit(f'"path": _dl_path,')
+ self._emit(f'"source": "{cmd.dataset}",')
+ self._emit(f'"split": "{cmd.split}",')
+ self._emit(f'"n_samples": len(_dl_dataset),')
+ self._indent -= 1
+ self._emit("}")
+ self._emit("")
+
+ def _emit_compare(self, cmd: CompareCmd) -> None:
+ """COMPARE - test source model vs merged model on same questions.
+
+ This is the knowledge retention test:
+ 1. Load source model, ask it N questions, record answers
+ 2. Ask merged model same questions
+ 3. Compare - did merged model retain what source knew?
+ """
+ alias = cmd.target
+ source = cmd.source
+ n = cmd.questions
+
+ self._emit(f'print("[td_lang] COMPARE - testing if {alias} retained knowledge from {source}")')
+ self._emit(f'print("[td_lang] Testing {n} questions on both models...")')
+ self._emit("")
+ self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer")
+ self._emit("import torch, random")
+ self._emit("")
+ self._emit("# Test questions across multiple domains")
+ self._emit("_compare_questions = [")
+ self._indent += 1
+ self._emit("# Math")
+ self._emit('"What is 17 * 23?", "What is the square root of 144?", "What is 256 + 389?",')
+ self._emit('"Solve: 3x + 7 = 28", "What is 15% of 300?",')
+ self._emit("# Knowledge")
+ self._emit('"What is the capital of Japan?", "Who wrote Romeo and Juliet?",')
+ self._emit('"What is the speed of light in m/s?", "What element has atomic number 6?",')
+ self._emit('"What is the largest planet in our solar system?",')
+ self._emit("# Reasoning")
+ self._emit('"If A is taller than B, and B is taller than C, who is tallest?",')
+ self._emit('"A bat and ball cost $1.10. The bat costs $1 more than the ball. What does the ball cost?",')
+ self._emit("# Code")
+ self._emit('"Write a Python function to reverse a string.",')
+ self._emit('"What does len([1,2,3]) return in Python?",')
+ self._emit("# Language")
+ self._emit('"Translate to French: Hello, how are you?",')
+ self._emit('"What is the past tense of run?",')
+ self._indent -= 1
+ self._emit("]")
+ self._emit(f"_n_compare = min({n}, len(_compare_questions))")
+ self._emit("_compare_questions = random.sample(_compare_questions, _n_compare)")
+ self._emit("")
+
+ # Test source model
+ self._emit(f'print("[td_lang] Loading source model: {source}...")')
+ self._emit(f'_src_tok = AutoTokenizer.from_pretrained("{source}")')
+ self._emit(f'_src_model = AutoModelForCausalLM.from_pretrained("{source}", torch_dtype=torch.bfloat16, device_map="auto")')
+ self._emit("_src_model.eval()")
+ self._emit("")
+ self._emit("_src_answers = {}")
+ self._emit("for q in _compare_questions:")
+ self._indent += 1
+ self._emit('inputs = _src_tok(q, return_tensors="pt").to(_src_model.device)')
+ self._emit("with torch.no_grad():")
+ self._indent += 1
+ self._emit("out = _src_model.generate(**inputs, max_new_tokens=128, do_sample=False)")
+ self._indent -= 1
+ self._emit("resp = _src_tok.decode(out[0], skip_special_tokens=True)")
+ self._emit("if resp.startswith(q):")
+ self._indent += 1
+ self._emit("resp = resp[len(q):].strip()")
+ self._indent -= 1
+ self._emit("_src_answers[q] = resp")
+ self._indent -= 1
+ self._emit('print(f"[td_lang] Source model: {len(_src_answers)} answers collected")')
+ self._emit("")
+ self._emit("# Free source model VRAM")
+ self._emit("del _src_model, _src_tok")
+ self._emit("import gc; gc.collect()")
+ self._emit("torch.cuda.empty_cache() if torch.cuda.is_available() else None")
+ self._emit("")
+
+ # Test merged model
+ self._emit(f'print("[td_lang] Testing merged model: {alias}...")')
+ self._emit(f'_mrg_checkpoint = models.get("{alias}", {{}}).get("checkpoint")')
+ self._emit("if not _mrg_checkpoint:")
+ self._indent += 1
+ self._emit(f'_mrg_checkpoint = models["{alias}"]["model_ref"]')
+ self._indent -= 1
+ self._emit("_mrg_tok = AutoTokenizer.from_pretrained(_mrg_checkpoint)")
+ self._emit('_mrg_model = AutoModelForCausalLM.from_pretrained(_mrg_checkpoint, torch_dtype=torch.bfloat16, device_map="auto")')
+ self._emit("_mrg_model.eval()")
+ self._emit("")
+ self._emit("_mrg_answers = {}")
+ self._emit("for q in _compare_questions:")
+ self._indent += 1
+ self._emit('inputs = _mrg_tok(q, return_tensors="pt").to(_mrg_model.device)')
+ self._emit("with torch.no_grad():")
+ self._indent += 1
+ self._emit("out = _mrg_model.generate(**inputs, max_new_tokens=128, do_sample=False)")
+ self._indent -= 1
+ self._emit("resp = _mrg_tok.decode(out[0], skip_special_tokens=True)")
+ self._emit("if resp.startswith(q):")
+ self._indent += 1
+ self._emit("resp = resp[len(q):].strip()")
+ self._indent -= 1
+ self._emit("_mrg_answers[q] = resp")
+ self._indent -= 1
+ self._emit("")
+
+ # Compare answers
+ self._emit("# Compare: check if merged model's answers match source model")
+ self._emit("_matches = 0")
+ self._emit("_compare_details = []")
+ self._emit("for q in _compare_questions:")
+ self._indent += 1
+ self._emit("src_ans = _src_answers.get(q, '')")
+ self._emit("mrg_ans = _mrg_answers.get(q, '')")
+ self._emit("# Fuzzy match: check if key words from source appear in merged answer")
+ self._emit("src_words = set(src_ans.lower().split()[:20])")
+ self._emit("mrg_words = set(mrg_ans.lower().split()[:20])")
+ self._emit("common = src_words & mrg_words")
+ self._emit("match = len(common) / max(len(src_words), 1) > 0.3")
+ self._emit("if match:")
+ self._indent += 1
+ self._emit("_matches += 1")
+ self._indent -= 1
+ self._emit('_compare_details.append({"question": q[:60], "source": src_ans[:80], "merged": mrg_ans[:80], "match": match})')
+ self._indent -= 1
+ self._emit("")
+ self._emit("_retention = _matches / max(len(_compare_questions), 1)")
+ self._emit("print()")
+ self._emit(f'print(f"[td_lang] COMPARE RESULTS: {alias} vs {source}")')
+ self._emit('print(f" Retention: {_matches}/{len(_compare_questions)} ({_retention:.0%})")')
+ self._emit('_ret_label = "GOOD" if _retention >= 0.7 else "WARNING - significant knowledge loss" if _retention >= 0.4 else "BAD - merge lost most knowledge"')
+ self._emit('print(f" Verdict: {_ret_label}")')
+ self._emit("")
+ self._emit(f'results["{alias}_compare_{source.split("/")[-1]}"] = {{')
+ self._indent += 1
+ self._emit('"retention": round(_retention, 3),')
+ self._emit('"matches": _matches,')
+ self._emit('"total": len(_compare_questions),')
+ self._emit('"details": _compare_details,')
+ self._indent -= 1
+ self._emit("}")
+
+ if cmd.output:
+ self._emit(f'_cmp_path = Path("{cmd.output}")')
+ self._emit("_cmp_path.parent.mkdir(parents=True, exist_ok=True)")
+ self._emit(f'with open(_cmp_path, "w") as f:')
+ self._indent += 1
+ self._emit(f'json.dump(results["{alias}_compare_{source.split("/")[-1]}"], f, indent=2, default=str)')
+ self._indent -= 1
+ self._emit(f'print(f"[td_lang] Compare results saved to {{_cmp_path}}")')
+
+ self._emit("del _mrg_model, _mrg_tok")
+ self._emit("import gc; gc.collect()")
+ self._emit("")
+
+ def _emit_verify(self, cmd: VerifyCmd) -> None:
+ """VERIFY - check model answers against known-correct answers.
+
+ Loads a dataset with known answers (like gsm8k, mmlu, etc),
+ runs the model, and checks if answers are correct.
+ """
+ alias = cmd.target
+ dataset = cmd.dataset
+ n = cmd.questions
+
+ self._emit(f'print("[td_lang] VERIFY - checking {alias} answers on {dataset} ({n} questions)")')
+ self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer")
+ self._emit("from datasets import load_dataset")
+ self._emit("import torch, re, random")
+ self._emit("")
+
+ # Load dataset
+ self._emit(f'# Check if dataset was downloaded earlier')
+ self._emit(f'_vfy_ds_info = results.get("{dataset}_dataset", None)')
+ self._emit("if _vfy_ds_info:")
+ self._indent += 1
+ self._emit('_vfy_ds = load_dataset("json", data_files=_vfy_ds_info["path"], split="train")')
+ self._emit('print(f"[td_lang] Using previously downloaded dataset: {_vfy_ds_info[\'path\']}")')
+ self._indent -= 1
+ self._emit("else:")
+ self._indent += 1
+ self._emit(f'try:')
+ self._indent += 1
+ self._emit(f'_vfy_ds = load_dataset("{dataset}", split="test")')
+ self._indent -= 1
+ self._emit("except Exception:")
+ self._indent += 1
+ self._emit(f'_vfy_ds = load_dataset("{dataset}", split="train")')
+ self._indent -= 1
+ self._indent -= 1
+ self._emit("")
+ self._emit(f"_vfy_n = min({n}, len(_vfy_ds))")
+ self._emit("_vfy_indices = random.sample(range(len(_vfy_ds)), _vfy_n)")
+ self._emit("")
+
+ # Load model
+ self._emit(f'_vfy_checkpoint = models.get("{alias}", {{}}).get("checkpoint")')
+ self._emit("if not _vfy_checkpoint:")
+ self._indent += 1
+ self._emit(f'_vfy_checkpoint = models["{alias}"]["model_ref"]')
+ self._indent -= 1
+ self._emit("_vfy_tok = AutoTokenizer.from_pretrained(_vfy_checkpoint)")
+ self._emit('_vfy_model = AutoModelForCausalLM.from_pretrained(_vfy_checkpoint, torch_dtype=torch.bfloat16, device_map="auto")')
+ self._emit("_vfy_model.eval()")
+ self._emit("")
+
+ # Figure out dataset format and verify
+ self._emit("# Auto-detect dataset format (gsm8k, mmlu, hellaswag, etc)")
+ self._emit("_vfy_correct = 0")
+ self._emit("_vfy_details = []")
+ self._emit("")
+ self._emit("for idx in _vfy_indices:")
+ self._indent += 1
+ self._emit("row = _vfy_ds[idx]")
+ self._emit("")
+ self._emit("# Extract question and answer based on dataset format")
+ self._emit("question = row.get('question', row.get('prompt', row.get('input', row.get('text', ''))))")
+ self._emit("answer = row.get('answer', row.get('target', row.get('output', row.get('label', ''))))")
+ self._emit("")
+ self._emit("if not question or not answer:")
+ self._indent += 1
+ self._emit("continue")
+ self._indent -= 1
+ self._emit("")
+ self._emit("# Ask the model")
+ self._emit("_vfy_prompt = f'Answer concisely: {question}'")
+ self._emit('_vfy_inputs = _vfy_tok(_vfy_prompt, return_tensors="pt").to(_vfy_model.device)')
+ self._emit("with torch.no_grad():")
+ self._indent += 1
+ self._emit("_vfy_out = _vfy_model.generate(**_vfy_inputs, max_new_tokens=256, do_sample=False)")
+ self._indent -= 1
+ self._emit("_vfy_response = _vfy_tok.decode(_vfy_out[0], skip_special_tokens=True)")
+ self._emit("if _vfy_response.startswith(_vfy_prompt):")
+ self._indent += 1
+ self._emit("_vfy_response = _vfy_response[len(_vfy_prompt):].strip()")
+ self._indent -= 1
+ self._emit("")
+ self._emit("# Check if answer is correct (fuzzy matching)")
+ self._emit("answer_str = str(answer).strip().lower()")
+ self._emit("response_lower = _vfy_response.lower()")
+ self._emit("")
+ self._emit("# Try exact match first")
+ self._emit("correct = answer_str in response_lower")
+ self._emit("")
+ self._emit("# Try numeric match (for math datasets)")
+ self._emit("if not correct:")
+ self._indent += 1
+ self._emit("try:")
+ self._indent += 1
+ self._emit("# Extract numbers from both")
+ self._emit("ans_nums = re.findall(r'-?[\\d,]+\\.?\\d*', answer_str)")
+ self._emit("resp_nums = re.findall(r'-?[\\d,]+\\.?\\d*', response_lower)")
+ self._emit("if ans_nums and resp_nums:")
+ self._indent += 1
+ self._emit("ans_val = float(ans_nums[-1].replace(',', ''))")
+ self._emit("resp_val = float(resp_nums[-1].replace(',', ''))")
+ self._emit("correct = abs(ans_val - resp_val) < 0.01")
+ self._indent -= 1
+ self._indent -= 1
+ self._emit("except (ValueError, IndexError):")
+ self._indent += 1
+ self._emit("pass")
+ self._indent -= 2
+ self._emit("")
+ self._emit("if correct:")
+ self._indent += 1
+ self._emit("_vfy_correct += 1")
+ self._indent -= 1
+ self._emit('_vfy_details.append({"question": str(question)[:60], "expected": str(answer)[:40], "got": _vfy_response[:40], "correct": correct})')
+ self._indent -= 1
+
+ self._emit("")
+ self._emit("_vfy_accuracy = _vfy_correct / max(_vfy_n, 1)")
+ self._emit(f'print(f"[td_lang] VERIFY RESULTS: {alias} on {dataset}")')
+ self._emit('print(f" Correct: {_vfy_correct}/{_vfy_n} ({_vfy_accuracy:.1%})")')
+ self._emit('_vfy_label = "STRONG" if _vfy_accuracy >= 0.7 else "MODERATE" if _vfy_accuracy >= 0.4 else "WEAK - needs more training"')
+ self._emit('print(f" Verdict: {_vfy_label}")')
+ self._emit("")
+ self._emit(f'results["{alias}_verify"] = {{')
+ self._indent += 1
+ self._emit('"accuracy": round(_vfy_accuracy, 3),')
+ self._emit('"correct": _vfy_correct,')
+ self._emit('"total": _vfy_n,')
+ self._emit(f'"dataset": "{dataset}",')
+ self._emit('"details": _vfy_details,')
+ self._indent -= 1
+ self._emit("}")
+
+ if cmd.output:
+ self._emit(f'_vfy_path = Path("{cmd.output}")')
+ self._emit("_vfy_path.parent.mkdir(parents=True, exist_ok=True)")
+ self._emit(f'with open(_vfy_path, "w") as f:')
+ self._indent += 1
+ self._emit(f'json.dump(results["{alias}_verify"], f, indent=2, default=str)')
+ self._indent -= 1
+ self._emit(f'print(f"[td_lang] Verify results saved to {{_vfy_path}}")')
+
+ self._emit("del _vfy_model, _vfy_tok")
+ self._emit("import gc; gc.collect()")
+ self._emit("")
+
# ---------------------------------------------------------------- Budget + summary
def _emit_budget_check(self, program: TDProgram) -> None:
budget = program.budget or BudgetBlock()
@@ -2965,6 +3531,52 @@ DO NOT EDIT - regenerate from the .td file instead.
est_gpu += body_est # at least one run
elif isinstance(cmd, (NotifyCmd, SaveCmd)):
est_gpu += 0.01
+ elif isinstance(cmd, DownloadCmd):
+ est_gpu += 0.05 # download time
+ elif isinstance(cmd, CompareCmd):
+ est_gpu += 0.5 # load two models + run questions
+ est_tokens += 500_000
+ elif isinstance(cmd, VerifyCmd):
+ est_gpu += 0.3 # load model + run questions
+ est_tokens += 300_000
+ elif isinstance(cmd, VoteCmd):
+ est_gpu += 0.1 * cmd.samples # generate N answers
+ est_tokens += 50_000 * cmd.samples
+ elif isinstance(cmd, PromptBlock):
+ est_gpu += 0.0 # just sets a string, no compute
+ elif isinstance(cmd, DistillCmd):
+ steps = cmd.steps or 200
+ est_gpu += 1.0 + (steps / 100) * 0.5 # teacher inference + student training
+ est_tokens += steps * 150_000
+ est_experiments += 1
+ elif isinstance(cmd, RollbackCmd):
+ est_gpu += 0.15 # reload from snapshot
+ elif isinstance(cmd, CurriculumCmd):
+ est_gpu += cmd.levels * (0.5 + (cmd.steps / 64) * 1.5)
+ est_tokens += cmd.levels * cmd.steps * 100_000
+ est_experiments += cmd.levels
+ elif isinstance(cmd, StarCmd):
+ est_gpu += cmd.rounds * (0.3 + cmd.samples * 0.1)
+ est_tokens += cmd.rounds * cmd.samples * 200_000
+ est_experiments += cmd.rounds
+ elif isinstance(cmd, BestOfCmd):
+ est_gpu += 0.5 + (cmd.steps / 32) * 1.0
+ est_tokens += cmd.n * cmd.steps * 50_000
+ est_experiments += 1
+ elif isinstance(cmd, ExploitCmd):
+ est_gpu += 0.5 + cmd.samples * 0.05 + (cmd.steps / 32) * 1.0
+ est_tokens += cmd.samples * 100_000
+ est_experiments += 1
+ elif isinstance(cmd, ArenaCmd):
+ # Arena is expensive: episodes * rounds inference + rounds * steps training
+ est_gpu += cmd.rounds * (0.5 + cmd.episodes * 0.02 + (cmd.steps / 32) * 1.0)
+ est_tokens += cmd.rounds * cmd.episodes * 50_000
+ est_experiments += cmd.rounds
+ elif isinstance(cmd, ResearchArenaCmd):
+ # Research arena: source gathering + question generation + episodes + training
+ est_gpu += 0.5 + cmd.rounds * (0.5 + cmd.episodes * 0.05 + (cmd.steps / 32) * 1.0)
+ est_tokens += cmd.rounds * cmd.episodes * 80_000 # more tokens per episode (verification)
+ est_experiments += cmd.rounds
est_cost = est_gpu * self.GPU_HOURLY
@@ -2997,6 +3609,1830 @@ DO NOT EDIT - regenerate from the .td file instead.
self._emit('print("[td_lang] Budget check passed.")')
self._emit("")
+ # ---------------------------------------------------------------- Phase 12: RL & Fine-Tuning
+
+ def _emit_curriculum(self, cmd: CurriculumCmd, program: TDProgram) -> None:
+ """CURRICULUM - progressive difficulty training (SEC).
+
+ Splits problems into difficulty levels by answer length/complexity.
+ Trains on easy first, then medium, then hard.
+ Only advances when accuracy on current level exceeds 60%.
+ """
+ self._emit(f'print("[td_lang] Curriculum training {cmd.target}: {cmd.levels} levels, {cmd.steps} steps each...")')
+ self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")')
+ self._emit("if not checkpoint:")
+ self._indent += 1
+ self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]')
+ self._indent -= 1
+ self._emit("")
+ self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig")
+ self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training")
+ self._emit("from trl import SFTTrainer")
+ self._emit("from datasets import load_dataset, Dataset")
+ self._emit("import torch")
+ self._emit("")
+ self._emit(f'dataset_path = "{cmd.dataset}"')
+ self._emit("if dataset_path.endswith('.jsonl'):")
+ self._indent += 1
+ self._emit("full_data = load_dataset('json', data_files=dataset_path, split='train')")
+ self._indent -= 1
+ self._emit("else:")
+ self._indent += 1
+ self._emit("full_data = load_dataset(dataset_path, split='train')")
+ self._indent -= 1
+ self._emit("")
+ self._emit("# Sort by difficulty (estimated by answer length - longer answers = harder problems)")
+ self._emit("text_key = 'text' if 'text' in full_data.column_names else full_data.column_names[0]")
+ self._emit("lengths = [len(str(row.get(text_key, row.get('answer', '')))) for row in full_data]")
+ self._emit("sorted_indices = sorted(range(len(lengths)), key=lambda i: lengths[i])")
+ self._emit(f"n_levels = {cmd.levels}")
+ self._emit("chunk_size = len(sorted_indices) // n_levels")
+ self._emit("")
+ self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
+ self._emit("if tok.pad_token is None:")
+ self._indent += 1
+ self._emit("tok.pad_token = tok.eos_token")
+ self._indent -= 1
+ self._emit("")
+ self._emit("for level in range(n_levels):")
+ self._indent += 1
+ self._emit("start_idx = level * chunk_size")
+ self._emit("end_idx = start_idx + chunk_size if level < n_levels - 1 else len(sorted_indices)")
+ self._emit("level_indices = sorted_indices[start_idx:end_idx]")
+ self._emit("level_data = full_data.select(level_indices)")
+ self._emit('_level_label = ["easy", "medium", "hard", "expert"][min(level, 3)]')
+ self._emit('print(f"[td_lang] Level {level+1}/{n_levels} ({_level_label}): {len(level_data)} examples")')
+ self._emit("")
+ self._emit("# Load fresh model each level (or continue from last checkpoint)")
+ self._emit("bnb_config = BitsAndBytesConfig(")
+ self._indent += 1
+ self._emit("load_in_4bit=True, bnb_4bit_quant_type='nf4',")
+ self._emit("bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True,")
+ self._indent -= 1
+ self._emit(")")
+ self._emit("model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=bnb_config, device_map='auto')")
+ self._emit("model = prepare_model_for_kbit_training(model)")
+ self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,")
+ self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")')
+ self._emit("model = get_peft_model(model, lora_config)")
+ self._emit("")
+ self._emit("from transformers import TrainingArguments")
+ self._emit(f"level_out = f'td_lang_outputs/curriculum_level_{{level}}'")
+ self._emit("training_args = TrainingArguments(")
+ self._indent += 1
+ self._emit("output_dir=level_out,")
+ self._emit(f"max_steps={cmd.steps},")
+ self._emit("per_device_train_batch_size=1,")
+ self._emit("gradient_accumulation_steps=4,")
+ self._emit("learning_rate=5e-5,")
+ self._emit("logging_steps=16,")
+ self._emit("bf16=True,")
+ self._emit("gradient_checkpointing=True,")
+ self._indent -= 1
+ self._emit(")")
+ self._emit("trainer = SFTTrainer(model=model, train_dataset=level_data, args=training_args, tokenizer=tok)")
+ self._emit("trainer.train()")
+ self._emit("trainer.save_model(level_out)")
+ self._emit("checkpoint = level_out # next level starts from this")
+ self._emit('print(f"[td_lang] Level {level+1} complete. Saved to {level_out}")')
+ self._emit("")
+ self._emit("del model")
+ self._emit("import gc; gc.collect()")
+ self._emit("try:")
+ self._indent += 1
+ self._emit("torch.cuda.empty_cache()")
+ self._indent -= 1
+ self._emit("except Exception:")
+ self._indent += 1
+ self._emit("pass")
+ self._indent -= 2
+ self._emit("")
+ self._emit(f'models["{cmd.target}"]["checkpoint"] = checkpoint')
+ self._emit(f'print("[td_lang] Curriculum training complete. Model progressed through {{n_levels}} levels.")')
+ self._emit(f'lineage["{cmd.target}"]["operations"].append({{')
+ self._indent += 1
+ self._emit('"op": "curriculum",')
+ self._emit(f'"dataset": "{cmd.dataset}",')
+ self._emit(f'"levels": {cmd.levels},')
+ self._emit(f'"steps_per_level": {cmd.steps},')
+ self._emit('"timestamp": datetime.now().isoformat(),')
+ self._indent -= 1
+ self._emit("})")
+
+ def _emit_star(self, cmd: StarCmd, program: TDProgram) -> None:  # emits the STaR generate / filter / train loop; `program` is not referenced in this method
+ """STaR - Self-Taught Reasoner.
+
+ For each problem: generate N solutions, check which are correct,
+ train on the correct reasoning chains. Repeat for R rounds.
+ The model learns from its own successes.
+ """
+ self._emit(f'print("[td_lang] STaR training {cmd.target}: {cmd.rounds} rounds, {cmd.samples} samples/problem...")')
+ self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")')  # emitted: prefer a checkpoint already recorded for the target
+ self._emit("if not checkpoint:")
+ self._indent += 1  # presumably _emit indents each generated line by self._indent levels - confirm against _emit
+ self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]')  # emitted fallback: the model reference stored for the target
+ self._indent -= 1
+ self._emit("")
+ self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments")
+ self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training")
+ self._emit("from trl import SFTTrainer")
+ self._emit("from datasets import load_dataset, Dataset")
+ self._emit("import torch, re")
+ self._emit("")
+ self._emit(f'dataset_path = "{cmd.dataset}"')
+ self._emit("if dataset_path.endswith('.jsonl'):")  # emitted: .jsonl paths are read as local JSON, anything else is passed to load_dataset as-is
+ self._indent += 1
+ self._emit("raw_data = load_dataset('json', data_files=dataset_path, split='train')")
+ self._indent -= 1
+ self._emit("else:")
+ self._indent += 1
+ self._emit("raw_data = load_dataset(dataset_path, split='train')")
+ self._indent -= 1
+ self._emit("")
+ self._emit("# Extract question-answer pairs")
+ self._emit("qa_pairs = []")
+ self._emit("for row in raw_data:")
+ self._indent += 1
+ self._emit("q = row.get('question', row.get('prompt', row.get('text', '')))")  # emitted: first of question / prompt / text that is present
+ self._emit("a = str(row.get('answer', row.get('response', row.get('label', ''))))")
+ self._emit("if q and a:")  # emitted: rows missing either side are dropped
+ self._indent += 1
+ self._emit("qa_pairs.append((q, a))")
+ self._indent -= 2  # closes the emitted "if q and a" and the "for row" loop
+ self._emit("qa_pairs = qa_pairs[:200] # cap at 200 problems per round")
+ self._emit("")
+ self._emit(f"for star_round in range({cmd.rounds}):")
+ self._indent += 1  # generated lines from here to the final "-= 2" form the per-round loop body
+ self._emit('print(f"[td_lang] STaR round {star_round+1}/{' + str(cmd.rounds) + '}...")')  # cmd.rounds is spliced between literal braces, so the emitted f-string shows it as a constant
+ self._emit("")
+ self._emit("# Step 1: Generate solutions")
+ self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
+ self._emit("if tok.pad_token is None:")
+ self._indent += 1
+ self._emit("tok.pad_token = tok.eos_token")
+ self._indent -= 1
+ self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',")
+ self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)")
+ self._emit("model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=bnb_config, device_map='auto')")
+ self._emit("model.eval()")
+ self._emit("")
+ self._emit("correct_chains = []")  # emitted: reset every round, so each round trains only on its own chains
+ self._emit("total_tried = 0")
+ self._emit("for q, expected_a in qa_pairs:")
+ self._indent += 1
+ self._emit("inputs = tok(q, return_tensors='pt').to(model.device)")
+ self._emit(f"for sample_i in range({cmd.samples}):")
+ self._indent += 1
+ self._emit("with torch.no_grad():")
+ self._indent += 1
+ self._emit("out = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.9)")
+ self._indent -= 1
+ self._emit("resp = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)")  # emitted: decode only the tokens generated after the prompt
+ self._emit("total_tried += 1")
+ self._emit("# Check if answer is correct (fuzzy match)")
+ self._emit("resp_lower = resp.lower().strip()")
+ self._emit("expected_lower = expected_a.lower().strip()")
+ self._emit("# Extract numbers for math comparison")
+ self._emit("resp_nums = re.findall(r'-?\\d+\\.?\\d*', resp_lower)")
+ self._emit("exp_nums = re.findall(r'-?\\d+\\.?\\d*', expected_lower)")
+ self._emit("is_correct = expected_lower in resp_lower")  # emitted: substring match first, numeric comparison only as a fallback
+ self._emit("if not is_correct and resp_nums and exp_nums:")
+ self._indent += 1
+ self._emit("try:")
+ self._indent += 1
+ self._emit("is_correct = abs(float(resp_nums[-1]) - float(exp_nums[-1])) < 0.01")  # emitted: compares the LAST number found on each side
+ self._indent -= 1
+ self._emit("except ValueError:")
+ self._indent += 1
+ self._emit("pass")
+ self._indent -= 2  # closes the emitted "except" body and the numeric-fallback "if"
+ self._emit("if is_correct:")
+ self._indent += 1
+ self._emit("correct_chains.append(q + '\\n' + resp)")
+ self._emit("break # got a correct answer, move to next problem")  # so at most one chain is kept per problem per round
+ self._indent -= 3  # closes "if is_correct", the sample loop and the question loop
+ self._emit("")
+ self._emit("del model")
+ self._emit("import gc; gc.collect()")
+ self._emit("try:")
+ self._indent += 1
+ self._emit("torch.cuda.empty_cache()")
+ self._indent -= 1
+ self._emit("except Exception:")
+ self._indent += 1
+ self._emit("pass")
+ self._indent -= 1
+ self._emit("")
+ self._emit('print(f"[td_lang] Round {star_round+1}: {len(correct_chains)} correct chains from {total_tried} attempts")')
+ self._emit("")
+ self._emit("if len(correct_chains) < 5:")
+ self._indent += 1
+ self._emit('print("[td_lang] Too few correct chains - skipping training this round")')
+ self._emit("continue")  # emitted: jumps to the next round, leaving checkpoint unchanged
+ self._indent -= 1
+ self._emit("")
+ self._emit("# Step 2: Train on correct reasoning chains")
+ self._emit("ds = Dataset.from_dict({'text': correct_chains})")
+ self._emit("model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=bnb_config, device_map='auto')")
+ self._emit("model = prepare_model_for_kbit_training(model)")
+ self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,")
+ self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")')
+ self._emit("model = get_peft_model(model, lora_config)")
+ self._emit("star_out = f'td_lang_outputs/star_round_{star_round}'")
+ self._emit("training_args = TrainingArguments(output_dir=star_out, max_steps=32,")  # NOTE: max_steps is fixed at 32 here, not taken from cmd
+ self._emit(" per_device_train_batch_size=1, gradient_accumulation_steps=4,")
+ self._emit(" learning_rate=5e-5, logging_steps=8, bf16=True, gradient_checkpointing=True)")
+ self._emit("trainer = SFTTrainer(model=model, train_dataset=ds, args=training_args, tokenizer=tok)")
+ self._emit("trainer.train()")
+ self._emit("trainer.save_model(star_out)")
+ self._emit("checkpoint = star_out")  # emitted: the next round generates from this round's output directory
+ self._emit('print(f"[td_lang] STaR round {star_round+1} trained on {len(correct_chains)} chains. Saved to {star_out}")')
+ self._emit("del model; gc.collect()")  # gc was imported by the emitted line earlier in the same loop body
+ self._emit("try:")
+ self._indent += 1
+ self._emit("torch.cuda.empty_cache()")
+ self._indent -= 1
+ self._emit("except Exception:")
+ self._indent += 1
+ self._emit("pass")
+ self._indent -= 2  # closes the emitted "except" body and the per-round "for" loop
+ self._emit("")
+ self._emit(f'models["{cmd.target}"]["checkpoint"] = checkpoint')  # emitted after the loop: last trained round, or the starting checkpoint if every round was skipped
+ self._emit(f'print("[td_lang] STaR complete after {cmd.rounds} rounds.")')
+ self._emit(f'lineage["{cmd.target}"]["operations"].append({{')
+ self._indent += 1
+ self._emit('"op": "star",')
+ self._emit(f'"dataset": "{cmd.dataset}",')
+ self._emit(f'"rounds": {cmd.rounds},')
+ self._emit(f'"samples_per_problem": {cmd.samples},')
+ self._emit('"timestamp": datetime.now().isoformat(),')  # assumes the generated script already imports datetime - confirm in the script header
+ self._indent -= 1
+ self._emit("})")
+
+ def _emit_best_of(self, cmd: BestOfCmd, program: TDProgram) -> None:  # emits best-of-N sampling, heuristic scoring and one SFT pass; `program` is not referenced in this method
+ """BEST_OF - generate N answers, score all, keep the best, train on it.
+
+ Like vote but for training. 80-90% of RLHF gains at fraction of cost.
+ """
+ self._emit(f'print("[td_lang] Best-of-{cmd.n} training on {cmd.target}...")')
+ self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")')  # emitted: prefer a checkpoint already recorded for the target
+ self._emit("if not checkpoint:")
+ self._indent += 1  # presumably _emit indents each generated line by self._indent levels - confirm against _emit
+ self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]')  # emitted fallback: the model reference stored for the target
+ self._indent -= 1
+ self._emit("")
+ self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments")
+ self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training")
+ self._emit("from trl import SFTTrainer")
+ self._emit("from datasets import load_dataset, Dataset")
+ self._emit("import torch, re, ast as _ast")
+ self._emit("")
+ self._emit(f'dataset_path = "{cmd.dataset}"')
+ self._emit("if dataset_path.endswith('.jsonl'):")  # emitted: .jsonl paths are read as local JSON, anything else is passed to load_dataset as-is
+ self._indent += 1
+ self._emit("raw_data = load_dataset('json', data_files=dataset_path, split='train')")
+ self._indent -= 1
+ self._emit("else:")
+ self._indent += 1
+ self._emit("raw_data = load_dataset(dataset_path, split='train')")
+ self._indent -= 1
+ self._emit("")
+ self._emit("# Extract questions")
+ self._emit("questions = []")
+ self._emit("for row in raw_data:")
+ self._indent += 1
+ self._emit("q = row.get('question', row.get('prompt', row.get('text', '')))")  # emitted: first of question / prompt / text that is present
+ self._emit("if q:")
+ self._indent += 1
+ self._emit("questions.append(q)")
+ self._indent -= 2  # closes the emitted "if q" and the "for row" loop
+ self._emit("questions = questions[:100] # cap at 100")
+ self._emit("")
+ self._emit("# Generate N answers per question, score them, keep the best")
+ self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
+ self._emit("if tok.pad_token is None:")
+ self._indent += 1
+ self._emit("tok.pad_token = tok.eos_token")
+ self._indent -= 1
+ self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',")
+ self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)")
+ self._emit("model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=bnb_config, device_map='auto')")
+ self._emit("model.eval()")
+ self._emit("")
+ self._emit("def _score_response(resp):")  # emitted helper: a text-only heuristic score; no reference answer is consulted
+ self._indent += 1
+ self._emit("score = 0.0")
+ self._emit("# Length reward (not too short, not too long)")
+ self._emit("words = len(resp.split())")
+ self._emit("if 10 < words < 500:")
+ self._indent += 1
+ self._emit("score += 0.2")
+ self._indent -= 1
+ self._emit("# Structure reward (has reasoning markers)")
+ self._emit("markers = ['because', 'therefore', 'step', 'first', 'then', 'answer', 'result']")
+ self._emit("score += 0.1 * min(sum(1 for m in markers if m in resp.lower()), 3)")  # emitted: 0.1 per distinct marker present, capped at three
+ self._emit("# Code compilation bonus")
+ self._emit("code_blocks = re.findall(r'```python\\n(.*?)```', resp, re.S)")
+ self._emit("for block in code_blocks:")
+ self._indent += 1
+ self._emit("try:")
+ self._indent += 1
+ self._emit("_ast.parse(block)")  # emitted: the block is only parsed, never executed
+ self._emit("score += 0.3")
+ self._emit("break")  # emitted: the bonus is awarded at most once per response
+ self._indent -= 1
+ self._emit("except SyntaxError:")
+ self._indent += 1
+ self._emit("pass")
+ self._indent -= 2  # closes the emitted "except" body and the "for block" loop
+ self._emit("# Confidence bonus (states a clear answer)")
+ self._emit("if any(p in resp.lower() for p in ['the answer is', 'result:', 'output:']):")
+ self._indent += 1
+ self._emit("score += 0.2")
+ self._indent -= 1
+ self._emit("return score")
+ self._indent -= 1  # closes the emitted _score_response definition
+ self._emit("")
+ self._emit("best_completions = []")
+ self._emit("for qi, q in enumerate(questions):")
+ self._indent += 1
+ self._emit("inputs = tok(q, return_tensors='pt').to(model.device)")
+ self._emit("candidates = []")
+ self._emit(f"for _ in range({cmd.n}):")
+ self._indent += 1
+ self._emit("with torch.no_grad():")
+ self._indent += 1
+ self._emit("out = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.8, top_p=0.95)")
+ self._indent -= 1
+ self._emit("resp = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)")  # emitted: decode only the tokens generated after the prompt
+ self._emit("candidates.append((resp, _score_response(resp)))")
+ self._indent -= 1
+ self._emit("best = max(candidates, key=lambda x: x[1])")  # emitted: highest heuristic score wins; max() raises on an empty list, i.e. when n is 0
+ self._emit("best_completions.append(q + '\\n' + best[0])")
+ self._emit("if qi % 20 == 0:")
+ self._indent += 1
+ self._emit('print(f" Generated best-of-N for {qi+1}/{len(questions)} questions...")')
+ self._indent -= 2  # closes the emitted progress "if" and the question loop
+ self._emit("")
+ self._emit("del model; import gc; gc.collect()")
+ self._emit("try:")
+ self._indent += 1
+ self._emit("torch.cuda.empty_cache()")
+ self._indent -= 1
+ self._emit("except Exception:")
+ self._indent += 1
+ self._emit("pass")
+ self._indent -= 1
+ self._emit("")
+ self._emit("# Train on the best completions")
+ self._emit(f'print(f"[td_lang] Training on {{len(best_completions)}} best-of-{cmd.n} completions...")')  # doubled braces survive into the emitted f-string; cmd.n is baked in at compile time
+ self._emit("ds = Dataset.from_dict({'text': best_completions})")
+ self._emit("model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=bnb_config, device_map='auto')")
+ self._emit("model = prepare_model_for_kbit_training(model)")
+ self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,")
+ self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")')
+ self._emit("model = get_peft_model(model, lora_config)")
+ self._emit("bon_out = 'td_lang_outputs/best_of_n_trained'")  # NOTE(review): constant directory, so two BEST_OF commands in one program would share it
+ self._emit(f"training_args = TrainingArguments(output_dir=bon_out, max_steps={cmd.steps},")
+ self._emit(" per_device_train_batch_size=1, gradient_accumulation_steps=4,")
+ self._emit(" learning_rate=5e-5, logging_steps=8, bf16=True, gradient_checkpointing=True)")
+ self._emit("trainer = SFTTrainer(model=model, train_dataset=ds, args=training_args, tokenizer=tok)")
+ self._emit("trainer.train()")
+ self._emit("trainer.save_model(bon_out)")
+ self._emit(f'models["{cmd.target}"]["checkpoint"] = bon_out')
+ self._emit(f'print("[td_lang] Best-of-{cmd.n} training complete.")')
+ self._emit("del model; gc.collect()")  # gc was imported by the emitted line after generation
+ self._emit(f'lineage["{cmd.target}"]["operations"].append({{')
+ self._indent += 1
+ self._emit('"op": "best_of",')
+ self._emit(f'"n": {cmd.n},')
+ self._emit(f'"steps": {cmd.steps},')
+ self._emit('"n_examples": len(best_completions),')  # evaluated at run time in the generated script, not at compile time
+ self._emit('"timestamp": datetime.now().isoformat(),')  # assumes the generated script already imports datetime - confirm in the script header
+ self._indent -= 1
+ self._emit("})")
+
+ def _emit_exploit(self, cmd: ExploitCmd, program: TDProgram) -> None:
+ """EXPLOIT - controlled reward hacking.
+
+ Generate MANY diverse solutions (high temp, high diversity).
+ Only filter: is the final answer correct? (verified reward)
+ Keep ALL correct solutions - ugly ones, shortcuts, weird reasoning.
+ Train on the diverse set. The model learns multiple paths to correct answers.
+ The "hacks" often turn out to be genuinely clever shortcuts.
+ """
+ self._emit(f'print("[td_lang] EXPLOIT mode: controlled reward hacking on {cmd.target}...")')
+ self._emit(f'print("[td_lang] Generating {cmd.samples} diverse solutions per problem...")')
+ self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")')
+ self._emit("if not checkpoint:")
+ self._indent += 1
+ self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]')
+ self._indent -= 1
+ self._emit("")
+ self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments")
+ self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training")
+ self._emit("from trl import SFTTrainer")
+ self._emit("from datasets import load_dataset, Dataset")
+ self._emit("import torch, re, json")
+ self._emit("")
+ self._emit(f'dataset_path = "{cmd.dataset}"')
+ self._emit("if dataset_path.endswith('.jsonl'):")
+ self._indent += 1
+ self._emit("raw_data = load_dataset('json', data_files=dataset_path, split='train')")
+ self._indent -= 1
+ self._emit("else:")
+ self._indent += 1
+ self._emit("raw_data = load_dataset(dataset_path, split='train')")
+ self._indent -= 1
+ self._emit("")
+ self._emit("# Extract question-answer pairs")
+ self._emit("qa_pairs = []")
+ self._emit("for row in raw_data:")
+ self._indent += 1
+ self._emit("q = row.get('question', row.get('prompt', row.get('text', '')))")
+ self._emit("a = str(row.get('answer', row.get('response', row.get('label', ''))))")
+ self._emit("if q and a:")
+ self._indent += 1
+ self._emit("qa_pairs.append((q, a))")
+ self._indent -= 2
+ self._emit("qa_pairs = qa_pairs[:100] # cap at 100 problems")
+ self._emit('print(f"[td_lang] {len(qa_pairs)} problems loaded")')
+ self._emit("")
+ self._emit("# Load model for generation")
+ self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
+ self._emit("if tok.pad_token is None:")
+ self._indent += 1
+ self._emit("tok.pad_token = tok.eos_token")
+ self._indent -= 1
+ self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',")
+ self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)")
+ self._emit("model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=bnb_config, device_map='auto')")
+ self._emit("model.eval()")
+ self._emit("")
+ self._emit("# EXPLOIT: Generate MANY diverse solutions with HIGH temperature")
+ self._emit("# Key insight: we WANT weird/creative solutions. High temp = more diversity.")
+ self._emit("exploit_data = [] # all correct solutions, regardless of method")
+ self._emit("total_correct = 0")
+ self._emit("total_generated = 0")
+ self._emit("exploit_log = [] # for inspection")
+ self._emit("")
+ self._emit("for qi, (q, expected_a) in enumerate(qa_pairs):")
+ self._indent += 1
+ self._emit("inputs = tok(q, return_tensors='pt').to(model.device)")
+ self._emit("correct_for_this = []")
+ self._emit("")
+ self._emit(f"for sample_i in range({cmd.samples}):")
+ self._indent += 1
+ self._emit("# Vary temperature per sample for maximum diversity")
+ self._emit(f"temp = 0.5 + (sample_i / {cmd.samples}) * 1.0 # range 0.5 to 1.5")
+ self._emit("with torch.no_grad():")
+ self._indent += 1
+ self._emit("out = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=temp, top_p=0.95, top_k=50)")
+ self._indent -= 1
+ self._emit("resp = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)")
+ self._emit("total_generated += 1")
+ self._emit("")
+ self._emit("# ONLY check: is the final answer correct?")
+ self._emit("# We DON'T check reasoning quality, format, or style.")
+ self._emit("resp_lower = resp.lower().strip()")
+ self._emit("expected_lower = expected_a.lower().strip()")
+ self._emit("resp_nums = re.findall(r'-?\\d+\\.?\\d*', resp_lower)")
+ self._emit("exp_nums = re.findall(r'-?\\d+\\.?\\d*', expected_lower)")
+ self._emit("is_correct = expected_lower in resp_lower")
+ self._emit("if not is_correct and resp_nums and exp_nums:")
+ self._indent += 1
+ self._emit("try:")
+ self._indent += 1
+ self._emit("is_correct = abs(float(resp_nums[-1]) - float(exp_nums[-1])) < 0.01")
+ self._indent -= 1
+ self._emit("except ValueError:")
+ self._indent += 1
+ self._emit("pass")
+ self._indent -= 2
+ self._emit("")
+ self._emit("if is_correct:")
+ self._indent += 1
+ self._emit("correct_for_this.append(resp)")
+ self._emit("total_correct += 1")
+ self._emit("# Keep ALL correct solutions - even short, weird, or hacky ones")
+ self._emit("exploit_data.append(q + '\\n' + resp)")
+ self._indent -= 2
+ self._emit("")
+ self._emit("if correct_for_this:")
+ self._indent += 1
+ self._emit("exploit_log.append({")
+ self._indent += 1
+ self._emit("'question': q,")
+ self._emit("'expected': expected_a,")
+ self._emit("'n_correct': len(correct_for_this),")
+ self._emit(f"'n_attempts': {cmd.samples},")
+ self._emit("'solutions': correct_for_this,")
+ self._emit("'diversity': len(set(s[:50] for s in correct_for_this)), # unique starts")
+ self._indent -= 1
+ self._emit("})")
+ self._indent -= 1
+ self._emit("")
+ self._emit("if qi % 20 == 0:")
+ self._indent += 1
+ self._emit('print(f" Problem {qi+1}/{len(qa_pairs)}: {len(correct_for_this)} correct solutions found")')
+ self._indent -= 2
+ self._emit("")
+ self._emit("del model; import gc; gc.collect()")
+ self._emit("try:")
+ self._indent += 1
+ self._emit("torch.cuda.empty_cache()")
+ self._indent -= 1
+ self._emit("except Exception:")
+ self._indent += 1
+ self._emit("pass")
+ self._indent -= 1
+ self._emit("")
+ self._emit("_hit_rate = (total_correct / total_generated * 100) if total_generated else 0")
+ self._emit('print(f"[td_lang] EXPLOIT results: {total_correct} correct solutions from {total_generated} attempts ({_hit_rate:.1f}% hit rate)")')
+ self._emit('print(f"[td_lang] {len(exploit_data)} training examples with diverse reasoning paths")')
+ self._emit("")
+ # Save exploit data if output specified
+ if cmd.output:
+ self._emit(f'exploit_path = Path("{cmd.output}")')
+ self._emit("exploit_path.parent.mkdir(parents=True, exist_ok=True)")
+ self._emit('with open(exploit_path, "w") as f:')
+ self._indent += 1
+ self._emit("json.dump(exploit_log, f, indent=2)")
+ self._indent -= 1
+ self._emit('print(f"[td_lang] Exploit data saved to {exploit_path} (inspect to see the creative solutions)")')
+ self._emit("")
+ self._emit("if len(exploit_data) < 5:")
+ self._indent += 1
+ self._emit('print("[td_lang] Too few correct solutions found - skipping training")')
+ self._indent -= 1
+ self._emit("else:")
+ self._indent += 1
+ self._emit("# Train on ALL correct solutions (the controlled hack)")
+ self._emit(f'print("[td_lang] Training on {{len(exploit_data)}} diverse correct solutions...")')
+ self._emit("ds = Dataset.from_dict({'text': exploit_data})")
+ self._emit("model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=bnb_config, device_map='auto')")
+ self._emit("model = prepare_model_for_kbit_training(model)")
+ self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,")
+ self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")')
+ self._emit("model = get_peft_model(model, lora_config)")
+ self._emit("exploit_out = 'td_lang_outputs/exploit_trained'")
+ self._emit(f"training_args = TrainingArguments(output_dir=exploit_out, max_steps={cmd.steps},")
+ self._emit(" per_device_train_batch_size=1, gradient_accumulation_steps=4,")
+ self._emit(" learning_rate=5e-5, logging_steps=8, bf16=True, gradient_checkpointing=True)")
+ self._emit("trainer = SFTTrainer(model=model, train_dataset=ds, args=training_args, tokenizer=tok)")
+ self._emit("trainer.train()")
+ self._emit("trainer.save_model(exploit_out)")
+ self._emit(f'models["{cmd.target}"]["checkpoint"] = exploit_out')
+ self._emit('print("[td_lang] EXPLOIT training complete. Model learned multiple solution paths.")')
+ self._emit("del model; gc.collect()")
+ self._indent -= 1
+ self._emit(f'lineage["{cmd.target}"]["operations"].append({{')
+ self._indent += 1
+ self._emit('"op": "exploit",')
+ self._emit(f'"dataset": "{cmd.dataset}",')
+ self._emit(f'"samples_per_problem": {cmd.samples},')
+ self._emit('"total_correct": total_correct,')
+ self._emit('"total_generated": total_generated,')
+ self._emit('"timestamp": datetime.now().isoformat(),')
+ self._indent -= 1
+ self._emit("})")
+
+ # ---------------------------------------------------------------- Phase 13: Real RL (Arena)
+ def _emit_arena(self, cmd: ArenaCmd, program: TDProgram) -> None:
+ """ARENA - real reinforcement learning with environment, memory, curiosity, and anti-lying.
+
+ The model enters an arena of challenges. For each episode:
+ 1. Picks a challenge from the dataset
+ 2. Generates a solution (exploring with some randomness)
+ 3. Gets IMMEDIATE reward/punishment:
+ - +1.0 for correct answer
+ - -1.0 for wrong answer
+ - -2.0 for LYING (confident but wrong — the worst offence)
+ - +curiosity_bonus for trying a NEW approach not in memory
+ 4. Stores the experience in a memory bank (approach + outcome)
+ 5. After N episodes, cross-checks creative solutions against standard ones
+ 6. Trains on reward-weighted experiences (good experiences get more weight)
+
+ Memory persists across rounds so the model doesn't "forget the button makes
+ the door safe." Curiosity reward encourages trying new things so it doesn't
+ get stuck avoiding things that failed once.
+ """
+ self._emit(f'print("[td_lang] ARENA: Real RL environment for {cmd.target}")')
+ self._emit(f'print("[td_lang] Rounds: {cmd.rounds}, Episodes/round: {cmd.episodes}")')
+ self._emit(f'print("[td_lang] Curiosity weight: {cmd.curiosity}")')
+ self._emit(f'print("[td_lang] Punishment for lying: -2.0 (confident + wrong)")')
+ self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")')
+ self._emit("if not checkpoint:")
+ self._indent += 1
+ self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]')
+ self._indent -= 1
+ self._emit("")
+ self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments")
+ self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training")
+ self._emit("from trl import SFTTrainer")
+ self._emit("from datasets import load_dataset, Dataset")
+ self._emit("import torch, re, json, hashlib, random")
+ self._emit("")
+ # Load dataset
+ self._emit(f'dataset_path = "{cmd.dataset}"')
+ self._emit("if dataset_path.endswith('.jsonl'):")
+ self._indent += 1
+ self._emit("raw_data = load_dataset('json', data_files=dataset_path, split='train')")
+ self._indent -= 1
+ self._emit("else:")
+ self._indent += 1
+ self._emit("raw_data = load_dataset(dataset_path, split='train')")
+ self._indent -= 1
+ self._emit("")
+ self._emit("# Extract question-answer pairs for the arena")
+ self._emit("arena_challenges = []")
+ self._emit("for row in raw_data:")
+ self._indent += 1
+ self._emit("q = row.get('question', row.get('prompt', row.get('text', '')))")
+ self._emit("a = str(row.get('answer', row.get('response', row.get('label', ''))))")
+ self._emit("if q and a:")
+ self._indent += 1
+ self._emit("arena_challenges.append((q, a))")
+ self._indent -= 2
+ self._emit('print(f"[td_lang] Arena loaded {len(arena_challenges)} challenges")')
+ self._emit("")
+ # Memory bank — persists across ALL rounds
+ self._emit("# === MEMORY BANK ===")
+ self._emit("# Persists across rounds so the model remembers what worked.")
+ self._emit("# Each entry: {approach_hash, question_hash, reward, response_text}")
+ self._emit("# This prevents the 'forgot the button makes the door safe' problem.")
+ self._emit("memory_bank = [] # list of (approach_hash, question_hash, reward, text)")
+ self._emit("seen_approaches = set() # hashes of approaches tried (for curiosity)")
+ self._emit("arena_log = [] # full log for inspection")
+ self._emit("")
+ # Helper functions
+ self._emit("def _hash_approach(response):")
+ self._indent += 1
+ self._emit('"""Hash the reasoning approach (first 200 chars) to detect novelty."""')
+ self._emit("# Strip numbers/specifics to capture the METHOD not the answer")
+ self._emit("method = re.sub(r'\\d+', 'N', response[:200]).strip().lower()")
+ self._emit("return hashlib.md5(method.encode()).hexdigest()[:12]")
+ self._indent -= 1
+ self._emit("")
+ self._emit("def _check_correct(response, expected):")
+ self._indent += 1
+ self._emit('"""Check if response contains the correct answer."""')
+ self._emit("resp_lower = response.lower().strip()")
+ self._emit("exp_lower = expected.lower().strip()")
+ self._emit("# Direct text match")
+ self._emit("if exp_lower in resp_lower:")
+ self._indent += 1
+ self._emit("return True")
+ self._indent -= 1
+ self._emit("# Numeric match")
+ self._emit("resp_nums = re.findall(r'-?\\d+\\.?\\d*', resp_lower)")
+ self._emit("exp_nums = re.findall(r'-?\\d+\\.?\\d*', exp_lower)")
+ self._emit("if resp_nums and exp_nums:")
+ self._indent += 1
+ self._emit("try:")
+ self._indent += 1
+ self._emit("return abs(float(resp_nums[-1]) - float(exp_nums[-1])) < 0.01")
+ self._indent -= 1
+ self._emit("except ValueError:")
+ self._indent += 1
+ self._emit("pass")
+ self._indent -= 2
+ self._emit("return False")
+ self._indent -= 1
+ self._emit("")
+ self._emit("def _detect_lying(response, is_correct):")
+ self._indent += 1
+ self._emit('"""Detect if the model is LYING - confident but wrong."""')
+ self._emit("if is_correct:")
+ self._indent += 1
+ self._emit("return False # can't be lying if correct")
+ self._indent -= 1
+ self._emit("# Check for confident language in a wrong answer")
+ self._emit("confidence_markers = ['the answer is', 'definitely', 'clearly', 'obviously',")
+ self._emit(" 'without a doubt', 'i am certain', 'i am sure', 'absolutely',")
+ self._emit(" 'the correct answer', 'the result is', 'therefore the answer']")
+ self._emit("resp_lower = response.lower()")
+ self._emit("confidence_count = sum(1 for m in confidence_markers if m in resp_lower)")
+ self._emit("# If 2+ confidence markers in a WRONG answer = lying")
+ self._emit("return confidence_count >= 2")
+ self._indent -= 1
+ self._emit("")
+ self._emit("def _cross_check(response, question, expected, model, tok):")
+ self._indent += 1
+ self._emit('"""Cross-check a creative solution against standard approach."""')
+ self._emit("# Generate 2 standard solutions (low temp = conservative)")
+ self._emit("standard_answers = []")
+ self._emit("inputs = tok(question, return_tensors='pt').to(model.device)")
+ self._emit("for _ in range(2):")
+ self._indent += 1
+ self._emit("with torch.no_grad():")
+ self._indent += 1
+ self._emit("out = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.3, top_p=0.9)")
+ self._indent -= 1
+ self._emit("std_resp = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)")
+ self._emit("standard_answers.append(std_resp)")
+ self._indent -= 1
+ self._emit("# Check if creative answer matches standard ones")
+ self._emit("creative_correct = _check_correct(response, expected)")
+ self._emit("std_correct = [_check_correct(s, expected) for s in standard_answers]")
+ self._emit("# Case 1: creative matches standard — verified good")
+ self._emit("if creative_correct and any(std_correct):")
+ self._indent += 1
+ self._emit("return 'verified'")
+ self._indent -= 1
+ self._emit("# Case 2: creative correct but standards failed — creative is BETTER")
+ self._emit("if creative_correct and not any(std_correct):")
+ self._indent += 1
+ self._emit("return 'superior' # creative found something standards missed")
+ self._indent -= 1
+ self._emit("# Case 3: creative wrong — reject")
+ self._emit("if not creative_correct:")
+ self._indent += 1
+ self._emit("return 'wrong'")
+ self._indent -= 1
+ self._emit("return 'verified'")
+ self._indent -= 1
+ self._emit("")
+ # Main arena loop
+ self._emit(f"for arena_round in range({cmd.rounds}):")
+ self._indent += 1
+ self._emit(f'print(f"\\n[td_lang] === ARENA ROUND {{arena_round+1}}/{cmd.rounds} ===")')
+ self._emit("")
+ self._emit("# Load model for this round")
+ self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
+ self._emit("if tok.pad_token is None:")
+ self._indent += 1
+ self._emit("tok.pad_token = tok.eos_token")
+ self._indent -= 1
+ self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',")
+ self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)")
+ self._emit("model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=bnb_config, device_map='auto')")
+ self._emit("model.eval()")
+ self._emit("")
+ # Episode loop
+ self._emit("round_experiences = [] # (text, reward) pairs for this round")
+ self._emit("round_stats = {'correct': 0, 'wrong': 0, 'lying': 0, 'curious': 0, 'cross_checked': 0}")
+ self._emit(f"episode_challenges = random.sample(arena_challenges, min({cmd.episodes}, len(arena_challenges)))")
+ self._emit("")
+ self._emit("for ep_i, (question, expected) in enumerate(episode_challenges):")
+ self._indent += 1
+ self._emit("q_hash = hashlib.md5(question.encode()).hexdigest()[:12]")
+ self._emit("")
+ self._emit("# Generate a solution (explore with moderate randomness)")
+ self._emit("inputs = tok(question, return_tensors='pt').to(model.device)")
+ self._emit("# Temperature increases slightly each round to encourage more exploration")
+ self._emit(f"temp = 0.6 + (arena_round * 0.1) + random.uniform(-0.1, 0.1)")
+ self._emit("temp = max(0.3, min(temp, 1.5)) # clamp")
+ self._emit("with torch.no_grad():")
+ self._indent += 1
+ self._emit("out = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=temp, top_p=0.95, top_k=50)")
+ self._indent -= 1
+ self._emit("response = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)")
+ self._emit("")
+ # Reward calculation
+ self._emit("# === REWARD CALCULATION ===")
+ self._emit("approach_hash = _hash_approach(response)")
+ self._emit("is_correct = _check_correct(response, expected)")
+ self._emit("is_lying = _detect_lying(response, is_correct)")
+ self._emit("")
+ self._emit("# Base reward: +1 correct, -1 wrong, -2 lying")
+ self._emit("if is_lying:")
+ self._indent += 1
+ self._emit("reward = -2.0 # WORST punishment: confident + wrong")
+ self._emit("round_stats['lying'] += 1")
+ self._indent -= 1
+ self._emit("elif is_correct:")
+ self._indent += 1
+ self._emit("reward = 1.0")
+ self._emit("round_stats['correct'] += 1")
+ self._indent -= 1
+ self._emit("else:")
+ self._indent += 1
+ self._emit("reward = -1.0")
+ self._emit("round_stats['wrong'] += 1")
+ self._indent -= 1
+ self._emit("")
+ # Curiosity bonus
+ self._emit("# === CURIOSITY BONUS ===")
+ self._emit("# Reward for trying something NEW (approach not in memory)")
+ self._emit("novelty_key = f'{q_hash}_{approach_hash}'")
+ self._emit("if novelty_key not in seen_approaches:")
+ self._indent += 1
+ self._emit(f"reward += {cmd.curiosity} # curiosity bonus!")
+ self._emit("seen_approaches.add(novelty_key)")
+ self._emit("round_stats['curious'] += 1")
+ self._indent -= 1
+ self._emit("")
+ # Cross-check creative solutions
+ self._emit("# === CROSS-CHECK ===")
+ self._emit("# If the model found a correct answer, verify it against standard approach")
+ self._emit("cross_result = None")
+ self._emit("if is_correct:")
+ self._indent += 1
+ self._emit("cross_result = _cross_check(response, question, expected, model, tok)")
+ self._emit("round_stats['cross_checked'] += 1")
+ self._emit("if cross_result == 'superior':")
+ self._indent += 1
+ self._emit("reward += 0.5 # extra reward for finding something better than standard")
+ self._indent -= 1
+ self._indent -= 1
+ self._emit("")
+ # Store experience in memory
+ self._emit("# === MEMORY ===")
+ self._emit("# Store this experience so the model REMEMBERS what worked")
+ self._emit("memory_entry = {")
+ self._indent += 1
+ self._emit("'approach_hash': approach_hash,")
+ self._emit("'question_hash': q_hash,")
+ self._emit("'reward': reward,")
+ self._emit("'is_correct': is_correct,")
+ self._emit("'is_lying': is_lying,")
+ self._emit("'cross_check': cross_result,")
+ self._emit("'round': arena_round,")
+ self._emit("'episode': ep_i,")
+ self._indent -= 1
+ self._emit("}")
+ self._emit("memory_bank.append(memory_entry)")
+ self._emit("")
+ self._emit("# Store experience for training (reward-weighted)")
+ self._emit("if reward > 0:")
+ self._indent += 1
+ self._emit("# Good experience: store with text for training")
+ self._emit("round_experiences.append((question + '\\n' + response, reward))")
+ self._indent -= 1
+ self._emit("")
+ self._emit("if ep_i % 10 == 0:")
+ self._indent += 1
+ self._emit("print(f' Episode {ep_i+1}: reward={reward:.1f} correct={is_correct} lying={is_lying}')")
+ self._indent -= 2 # close if ep_i and for ep_i
+ self._emit("")
+ # Round stats
+ self._emit("# Round summary")
+ self._emit("total_ep = round_stats['correct'] + round_stats['wrong'] + round_stats['lying']")
+ self._emit("print(f'[td_lang] Round {arena_round+1} results:')")
+ self._emit("print(f' Correct: {round_stats[\"correct\"]}/{total_ep}')")
+ self._emit("print(f' Wrong: {round_stats[\"wrong\"]}/{total_ep}')")
+ self._emit("print(f' Caught lying: {round_stats[\"lying\"]} (punished -2.0 each)')")
+ self._emit("print(f' Curiosity explorations: {round_stats[\"curious\"]}')")
+ self._emit("print(f' Cross-checked: {round_stats[\"cross_checked\"]}')")
+ self._emit("print(f' Positive experiences for training: {len(round_experiences)}')")
+ self._emit("")
+ # Training on reward-weighted experiences
+ self._emit("# Free generation model")
+ self._emit("del model; import gc; gc.collect()")
+ self._emit("try:")
+ self._indent += 1
+ self._emit("torch.cuda.empty_cache()")
+ self._indent -= 1
+ self._emit("except Exception:")
+ self._indent += 1
+ self._emit("pass")
+ self._indent -= 1
+ self._emit("")
+ self._emit("if len(round_experiences) < 3:")
+ self._indent += 1
+ self._emit("print('[td_lang] Too few positive experiences — skipping training this round')")
+ self._emit("continue")
+ self._indent -= 1
+ self._emit("")
+ self._emit("# === REWARD-WEIGHTED TRAINING ===")
+ self._emit("# Higher reward = more copies in training data (the model sees it more)")
+ self._emit("# This is how RL works: reinforce good behaviour, ignore bad")
+ self._emit("training_texts = []")
+ self._emit("for text, reward in round_experiences:")
+ self._indent += 1
+ self._emit("# Duplicate high-reward experiences (reward 1.0 = 2 copies, 1.5+ = 3 copies)")
+ self._emit("copies = max(1, int(reward * 2))")
+ self._emit("training_texts.extend([text] * copies)")
+ self._indent -= 1
+ self._emit("random.shuffle(training_texts)")
+ self._emit('print(f"[td_lang] Training on {len(training_texts)} reward-weighted experiences...")')
+ self._emit("")
+ self._emit("ds = Dataset.from_dict({'text': training_texts})")
+ self._emit("model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=bnb_config, device_map='auto')")
+ self._emit("model = prepare_model_for_kbit_training(model)")
+ self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,")
+ self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")')
+ self._emit("model = get_peft_model(model, lora_config)")
+ self._emit(f"arena_out = f'td_lang_outputs/arena_round_{{arena_round}}'")
+ self._emit(f"training_args = TrainingArguments(output_dir=arena_out, max_steps={cmd.steps},")
+ self._emit(" per_device_train_batch_size=1, gradient_accumulation_steps=4,")
+ self._emit(" learning_rate=5e-5, logging_steps=16, bf16=True, gradient_checkpointing=True)")
+ self._emit("trainer = SFTTrainer(model=model, train_dataset=ds, args=training_args, tokenizer=tok)")
+ self._emit("trainer.train()")
+ self._emit("trainer.save_model(arena_out)")
+ self._emit("checkpoint = arena_out # next round uses improved model")
+ self._emit("print(f'[td_lang] Arena round {arena_round+1} training complete.')")
+ self._emit("")
+ self._emit("del model; gc.collect()")
+ self._emit("try:")
+ self._indent += 1
+ self._emit("torch.cuda.empty_cache()")
+ self._indent -= 1
+ self._emit("except Exception:")
+ self._indent += 1
+ self._emit("pass")
+ self._indent -= 1
+ self._emit("")
+ # Store arena log entry
+ self._emit("arena_log.append({")
+ self._indent += 1
+ self._emit("'round': arena_round,")
+ self._emit("'stats': dict(round_stats),")
+ self._emit("'n_training_examples': len(training_texts),")
+ self._emit("'memory_size': len(memory_bank),")
+ self._emit("'unique_approaches': len(seen_approaches),")
+ self._indent -= 1
+ self._emit("})")
+ self._indent -= 1 # close for arena_round
+ self._emit("")
+ # Final summary
+ self._emit(f'models["{cmd.target}"]["checkpoint"] = checkpoint')
+ self._emit('print(f"[td_lang] ARENA COMPLETE")')
+ self._emit('print(f"[td_lang] Total memories: {len(memory_bank)}")')
+ self._emit('print(f"[td_lang] Unique approaches discovered: {len(seen_approaches)}")')
+ self._emit("")
+ self._emit("# Memory analysis")
+ self._emit("lying_count = sum(1 for m in memory_bank if m['is_lying'])")
+ self._emit("correct_count = sum(1 for m in memory_bank if m['is_correct'])")
+ self._emit("print(f'[td_lang] Total correct: {correct_count}')")
+ self._emit("print(f'[td_lang] Total caught lying: {lying_count} (punished -2.0 each)')")
+ self._emit("avg_reward = sum(m['reward'] for m in memory_bank) / len(memory_bank) if memory_bank else 0")
+ self._emit("print(f'[td_lang] Average reward: {avg_reward:.2f}')")
+ self._emit("")
+ # Save arena log
+ if cmd.output:
+ self._emit(f'arena_log_path = Path("{cmd.output}")')
+ self._emit("arena_log_path.parent.mkdir(parents=True, exist_ok=True)")
+ self._emit('with open(arena_log_path, "w") as f:')
+ self._indent += 1
+ self._emit("json.dump({'log': arena_log, 'memory': memory_bank}, f, indent=2)")
+ self._indent -= 1
+ self._emit('print(f"[td_lang] Arena log saved to {arena_log_path}")')
+ self._emit("")
+ # Lineage
+ self._emit(f'lineage["{cmd.target}"]["operations"].append({{')
+ self._indent += 1
+ self._emit('"op": "arena",')
+ self._emit(f'"dataset": "{cmd.dataset}",')
+ self._emit(f'"rounds": {cmd.rounds},')
+ self._emit(f'"episodes_per_round": {cmd.episodes},')
+ self._emit(f'"curiosity_weight": {cmd.curiosity},')
+ self._emit('"total_memories": len(memory_bank),')
+ self._emit('"unique_approaches": len(seen_approaches),')
+ self._emit('"timestamp": datetime.now().isoformat(),')
+ self._indent -= 1
+ self._emit("})")
+
+ def _emit_research_arena(self, cmd: ResearchArenaCmd, program: TDProgram) -> None:
+ """RESEARCH_ARENA - RL on ANY topic using real-world knowledge.
+
+ Unlike arena (pre-made dataset), research_arena:
+ 1. Takes a TOPIC ("cancer biology", "number theory", "machine learning")
+ 2. Pulls real knowledge from sources (web search, papers, local docs)
+ 3. Extracts verifiable facts from those sources
+ 4. Builds increasingly hard questions from real knowledge
+ 5. Runs the model through, checking EVERY claim against sources
+ 6. Difficulty ESCALATES each round (fewer hints, stricter checking)
+ 7. Memory persists, lying punished, curiosity rewarded
+ """
+ self._emit(f'print("[td_lang] RESEARCH ARENA: {cmd.topic}")')
+ self._emit(f'print("[td_lang] Source: {cmd.sources}")')
+ self._emit(f'print("[td_lang] Rounds: {cmd.rounds}, Episodes/round: {cmd.episodes}")')
+ self._emit(f'print("[td_lang] Difficulty escalation: +{cmd.difficulty_scale * 100:.0f}% per round")')
+ self._emit(f'print("[td_lang] Lying punishment: -2.0 | Curiosity bonus: +{cmd.curiosity}")')
+ self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")')
+ self._emit("if not checkpoint:")
+ self._indent += 1
+ self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]')
+ self._indent -= 1
+ self._emit("")
+ self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments")
+ self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training")
+ self._emit("from trl import SFTTrainer")
+ self._emit("from datasets import Dataset")
+ self._emit("import torch, re, json, hashlib, random, textwrap")
+ self._emit("")
+ # ── Phase 1: Pull real knowledge about the topic ──
+ self._emit("# ============================================================")
+ self._emit(f'# PHASE 1: Pull real knowledge about "{cmd.topic}"')
+ self._emit("# ============================================================")
+ self._emit(f'topic = "{cmd.topic}"')
+ self._emit(f'source_type = "{cmd.sources}"')
+ self._emit("knowledge_base = [] # list of {fact, source, difficulty}")
+ self._emit("")
+ self._emit("if source_type == 'pubmed':")
+ self._indent += 1
+ self._emit("# Pull from PubMed API (real medical/science papers)")
+ self._emit("import urllib.request, urllib.parse, xml.etree.ElementTree as ET")
+ self._emit("search_url = f'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?db=pubmed&term={urllib.parse.quote(topic)}&retmax=50&sort=relevance'")
+ self._emit("try:")
+ self._indent += 1
+ self._emit("resp = urllib.request.urlopen(search_url, timeout=30)")
+ self._emit("tree = ET.parse(resp)")
+ self._emit("pmids = [id_el.text for id_el in tree.findall('.//Id')][:30]")
+ self._emit("print(f'[td_lang] Found {len(pmids)} PubMed articles on \"{topic}\"')")
+ self._emit("# Fetch abstracts")
+ self._emit("if pmids:")
+ self._indent += 1
+ self._emit("fetch_url = f'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?db=pubmed&id={\",\".join(pmids)}&rettype=abstract&retmode=xml'")
+ self._emit("resp2 = urllib.request.urlopen(fetch_url, timeout=60)")
+ self._emit("articles_xml = resp2.read().decode('utf-8', errors='ignore')")
+ self._emit("art_tree = ET.fromstring(articles_xml)")
+ self._emit("for article in art_tree.findall('.//PubmedArticle'):")
+ self._indent += 1
+ self._emit("title_el = article.find('.//ArticleTitle')")
+ self._emit("abstract_el = article.find('.//AbstractText')")
+ self._emit("if title_el is not None and title_el.text and abstract_el is not None and abstract_el.text:")
+ self._indent += 1
+ self._emit("text = abstract_el.text.strip()")
+ self._emit("# Extract factual sentences (those with numbers, findings, conclusions)")
+ self._emit("for sent in re.split(r'(?<=[.!?])\\s+', text):")
+ self._indent += 1
+ self._emit("sent = sent.strip()")
+ self._emit("if len(sent) > 40 and any(kw in sent.lower() for kw in ['found', 'result', 'show', 'demonstrate', 'significant', 'increase', 'decrease', 'cause', 'effect', 'treatment', 'method', 'approach', 'proved', 'evidence']):")
+ self._indent += 1
+ self._emit("diff = min(1.0, len(sent) / 300) # longer = harder")
+ self._emit("knowledge_base.append({'fact': sent, 'source': title_el.text[:80], 'difficulty': diff})")
+ self._indent -= 4 # close if sent, for sent, if title, for article
+ self._indent -= 1 # close if pmids
+ self._indent -= 1 # close try
+ self._emit("except Exception as e:")
+ self._indent += 1
+ self._emit("print(f'[td_lang] PubMed fetch failed: {e}. Falling back to web search.')")
+ self._emit("source_type = 'web'")
+ self._indent -= 2 # close except, close if pubmed
+ self._emit("")
+ self._emit("if source_type == 'web' or (source_type == 'pubmed' and len(knowledge_base) < 10):")
+ self._indent += 1
+ self._emit("# Web search — use duckduckgo-search (clean API, no scraping)")
+ self._emit("try:")
+ self._indent += 1
+ self._emit("from duckduckgo_search import DDGS")
+ self._indent -= 1
+ self._emit("except ImportError:")
+ self._indent += 1
+ self._emit("print('[td_lang] Installing duckduckgo-search...')")
+ self._emit("import subprocess; subprocess.check_call(['pip', 'install', 'duckduckgo-search', '-q', '--break-system-packages'])")
+ self._emit("from duckduckgo_search import DDGS")
+ self._indent -= 1
+ self._emit("")
+ self._emit("try:")
+ self._indent += 1
+ self._emit("ddg = DDGS()")
+ self._emit("# Search multiple angles for richer knowledge")
+ self._emit("search_queries = [")
+ self._indent += 1
+ self._emit("f'{topic} research findings',")
+ self._emit("f'{topic} key facts evidence',")
+ self._emit("f'{topic} recent discoveries',")
+ self._indent -= 1
+ self._emit("]")
+ self._emit("all_results = []")
+ self._emit("for sq in search_queries:")
+ self._indent += 1
+ self._emit("results = list(ddg.text(sq, max_results=15))")
+ self._emit("all_results.extend(results)")
+ self._indent -= 1
+ self._emit("")
+ self._emit("seen_bodies = set()")
+ self._emit("for r in all_results:")
+ self._indent += 1
+ self._emit("body = r.get('body', '').strip()")
+ self._emit("title = r.get('title', 'web')[:80]")
+ self._emit("href = r.get('href', '')")
+ self._emit("if body and body not in seen_bodies and len(body) > 30:")
+ self._indent += 1
+ self._emit("seen_bodies.add(body)")
+ self._emit("# Split into sentences for finer-grained facts")
+ self._emit("for sent in re.split(r'(?<=[.!?])\\s+', body):")
+ self._indent += 1
+ self._emit("sent = sent.strip()")
+ self._emit("if len(sent) > 30:")
+ self._indent += 1
+ self._emit("knowledge_base.append({'fact': sent, 'source': title, 'url': href, 'difficulty': min(1.0, len(sent) / 250)})")
+ self._indent -= 3 # close if sent, for sent, if body
+ self._indent -= 1 # close for r
+ self._emit("print(f'[td_lang] Web search: {len(all_results)} results -> {len(knowledge_base)} facts')")
+ self._emit("")
+ self._emit("# Fetch full page content from top results for deeper knowledge")
+ self._emit("import urllib.request")
+ self._emit("top_urls = [r.get('href', '') for r in all_results[:5] if r.get('href')]")
+ self._emit("for page_url in top_urls:")
+ self._indent += 1
+ self._emit("try:")
+ self._indent += 1
+ self._emit("req = urllib.request.Request(page_url, headers={'User-Agent': 'Mozilla/5.0'})")
+ self._emit("page_resp = urllib.request.urlopen(req, timeout=15)")
+ self._emit("page_html = page_resp.read().decode('utf-8', errors='ignore')[:50000]")
+ self._emit("# Strip HTML tags, get plain text")
+ self._emit("page_text = re.sub(r'', '', page_html, flags=re.S)")
+ self._emit("page_text = re.sub(r'', '', page_text, flags=re.S)")
+ self._emit("page_text = re.sub(r'<[^>]+>', ' ', page_text)")
+ self._emit("page_text = re.sub(r'\\s+', ' ', page_text).strip()")
+ self._emit("# Extract factual sentences")
+ self._emit("for sent in re.split(r'(?<=[.!?])\\s+', page_text[:5000]):")
+ self._indent += 1
+ self._emit("sent = sent.strip()")
+ self._emit("if len(sent) > 50 and sent not in seen_bodies:")
+ self._indent += 1
+ self._emit("seen_bodies.add(sent)")
+ self._emit("knowledge_base.append({'fact': sent, 'source': page_url[:60], 'url': page_url, 'difficulty': min(1.0, len(sent) / 200)})")
+ self._indent -= 2 # close if sent, for sent
+ self._indent -= 1 # close try
+ self._emit("except Exception:")
+ self._indent += 1
+ self._emit("pass # skip pages that can't be fetched")
+ self._indent -= 2 # close except, for page_url
+ self._emit("print(f'[td_lang] Deep fetch complete: {len(knowledge_base)} total facts')")
+ self._indent -= 1 # close try (main)
+ self._emit("except Exception as e:")
+ self._indent += 1
+ self._emit("print(f'[td_lang] Web search failed: {e}')")
+ self._indent -= 2 # close except, close if web
+ self._emit("")
+ self._emit("if source_type == 'arxiv':")
+ self._indent += 1
+ self._emit("# Pull from arXiv API (physics, math, CS, etc.)")
+ self._emit("import urllib.request, urllib.parse, xml.etree.ElementTree as ET")
+ self._emit("try:")
+ self._indent += 1
+ self._emit("query = urllib.parse.quote(f'all:{topic}')")
+ self._emit("url = f'http://export.arxiv.org/api/query?search_query={query}&max_results=30&sortBy=relevance'")
+ self._emit("resp = urllib.request.urlopen(url, timeout=30)")
+ self._emit("tree = ET.parse(resp)")
+ self._emit("ns = {'atom': 'http://www.w3.org/2005/Atom'}")
+ self._emit("for entry in tree.findall('.//atom:entry', ns):")
+ self._indent += 1
+ self._emit("title = entry.find('atom:title', ns).text.strip() if entry.find('atom:title', ns) is not None else ''")
+ self._emit("summary = entry.find('atom:summary', ns).text.strip() if entry.find('atom:summary', ns) is not None else ''")
+ self._emit("for sent in re.split(r'(?<=[.!?])\\s+', summary):")
+ self._indent += 1
+ self._emit("sent = sent.strip()")
+ self._emit("if len(sent) > 40:")
+ self._indent += 1
+ self._emit("knowledge_base.append({'fact': sent, 'source': title[:80], 'difficulty': 0.6})")
+ self._indent -= 3 # close if sent, for sent, for entry
+ self._emit("print(f'[td_lang] Pulled arXiv papers for \"{topic}\"')")
+ self._indent -= 1 # close try
+ self._emit("except Exception as e:")
+ self._indent += 1
+ self._emit("print(f'[td_lang] arXiv fetch failed: {e}')")
+ self._indent -= 2 # close except, close if arxiv
+ self._emit("")
+ # Handle local file sources
+ self._emit("if source_type not in ('web', 'pubmed', 'arxiv'):")
+ self._indent += 1
+ self._emit("# Treat as local file/folder path")
+ self._emit("import glob as _glob")
+ self._emit("source_files = _glob.glob(source_type + '/**/*', recursive=True) if os.path.isdir(source_type) else [source_type]")
+ self._emit("for fpath in source_files:")
+ self._indent += 1
+ self._emit("try:")
+ self._indent += 1
+ self._emit("with open(fpath, 'r', errors='ignore') as f:")
+ self._indent += 1
+ self._emit("text = f.read()[:10000]")
+ self._indent -= 1
+ self._emit("for sent in re.split(r'(?<=[.!?])\\s+', text):")
+ self._indent += 1
+ self._emit("sent = sent.strip()")
+ self._emit("if len(sent) > 40:")
+ self._indent += 1
+ self._emit("knowledge_base.append({'fact': sent, 'source': os.path.basename(fpath), 'difficulty': 0.5})")
+ self._indent -= 2 # close if sent, for sent
+ self._indent -= 1 # close try
+ self._emit("except Exception:")
+ self._indent += 1
+ self._emit("pass")
+ self._indent -= 2 # close except, for fpath
+ self._emit("print(f'[td_lang] Loaded {len(source_files)} local files')")
+ self._indent -= 1 # close if local
+ self._emit("")
+ self._emit("if len(knowledge_base) < 5:")
+ self._indent += 1
+ self._emit(f'print("[td_lang] ERROR: Could not gather enough knowledge about \\"{cmd.topic}\\". Need at least 5 facts.")')
+ self._emit(f'print("[td_lang] Try a different topic or source type.")')
+ self._indent -= 1
+ self._emit("else:")
+ self._indent += 1
+ self._emit("print(f'[td_lang] Knowledge base built: {len(knowledge_base)} verifiable facts')")
+ self._emit("random.shuffle(knowledge_base)")
+ self._emit("")
+ # ── Phase 2: Build the maze (question generator) ──
+ self._emit("# ============================================================")
+ self._emit("# PHASE 2: Build the maze — generate questions from knowledge")
+ self._emit("# ============================================================")
+ self._emit("")
+ self._emit("def _build_questions(kb, difficulty_level, n_questions):")
+ self._indent += 1
+ self._emit('"""Build questions from knowledge base. Higher difficulty = harder questions."""')
+ self._emit("questions = []")
+ self._emit("# Sort by difficulty, pick appropriate ones for this level")
+ self._emit("sorted_kb = sorted(kb, key=lambda x: x['difficulty'])")
+ self._emit("# At higher difficulty, use harder facts and ask trickier questions")
+ self._emit("start_pct = min(0.8, difficulty_level * 0.15) # start further into hard facts")
+ self._emit("start_idx = int(len(sorted_kb) * start_pct)")
+ self._emit("pool = sorted_kb[start_idx:] if start_idx < len(sorted_kb) else sorted_kb")
+ self._emit("selected = random.sample(pool, min(n_questions, len(pool)))")
+ self._emit("")
+ self._emit("for item in selected:")
+ self._indent += 1
+ self._emit("fact = item['fact']")
+ self._emit("source = item['source']")
+ self._emit("# Question types get harder with difficulty")
+ self._emit("if difficulty_level < 2:")
+ self._indent += 1
+ self._emit("# Easy: just verify the fact")
+ self._emit("q = f'Based on current research, is the following claim accurate? Explain your reasoning.\\n\\nClaim: {fact}'")
+ self._indent -= 1
+ self._emit("elif difficulty_level < 4:")
+ self._indent += 1
+ self._emit("# Medium: ask about implications or missing pieces")
+ self._emit("q = f'A research paper states: \"{fact}\"\\n\\nWhat are the implications of this finding? What questions does it leave unanswered? What could be wrong with this conclusion?'")
+ self._indent -= 1
+ self._emit("else:")
+ self._indent += 1
+ self._emit("# Hard: ask to connect multiple facts or identify contradictions")
+ self._emit("other_facts = [x['fact'] for x in random.sample(kb, min(3, len(kb))) if x['fact'] != fact]")
+ self._emit("context = '\\n'.join(f'- {f}' for f in other_facts[:2])")
+ self._emit("q = f'Given these research findings:\\n{context}\\n\\nAnd this additional claim: \"{fact}\"\\n\\nDo these findings support or contradict each other? Identify any gaps, errors, or unsupported leaps in logic. Be precise.'")
+ self._indent -= 1
+ self._emit("questions.append({'question': q, 'ground_truth': fact, 'source': source, 'difficulty': item['difficulty']})")
+ self._indent -= 1 # close for item
+ self._emit("return questions")
+ self._indent -= 1 # close def _build_questions
+ self._emit("")
+ # ── Phase 3: Fact-checker ──
+ self._emit("def _fact_check(response, ground_truth, model, tok, strictness):")
+ self._indent += 1
+ self._emit('"""Check model response against ground truth source. Strictness 0-1."""')
+ self._emit("# Extract key claims from the response")
+ self._emit("resp_lower = response.lower().strip()")
+ self._emit("truth_lower = ground_truth.lower().strip()")
+ self._emit("")
+ self._emit("# Extract important words from ground truth (nouns, numbers, technical terms)")
+ self._emit("truth_words = set(w for w in re.findall(r'\\b\\w{4,}\\b', truth_lower))")
+ self._emit("truth_words -= {'that', 'this', 'with', 'from', 'were', 'been', 'have', 'their', 'which', 'these', 'those', 'than', 'also', 'more'}")
+ self._emit("truth_nums = set(re.findall(r'-?\\d+\\.?\\d*', truth_lower))")
+ self._emit("")
+ self._emit("# Check how many key terms from the source appear in the response")
+ self._emit("matched_words = sum(1 for w in truth_words if w in resp_lower)")
+ self._emit("word_coverage = matched_words / max(len(truth_words), 1)")
+ self._emit("")
+ self._emit("# Check numbers match")
+ self._emit("resp_nums = set(re.findall(r'-?\\d+\\.?\\d*', resp_lower))")
+ self._emit("num_match = len(truth_nums & resp_nums) / max(len(truth_nums), 1) if truth_nums else 1.0")
+ self._emit("")
+ self._emit("# Check for direct contradictions")
+ self._emit("contradicts = False")
+ self._emit("negations = ['not true', 'incorrect', 'false', 'wrong', 'no evidence', 'disproven', 'myth', 'inaccurate']")
+ self._emit("if any(neg in resp_lower for neg in negations):")
+ self._indent += 1
+ self._emit("# Model is denying something — check if it's denying the ground truth")
+ self._emit("if word_coverage > 0.3: # it's talking about the right topic but denying it")
+ self._indent += 1
+ self._emit("contradicts = True")
+ self._indent -= 2
+ self._emit("")
+ self._emit("# Threshold increases with strictness")
+ self._emit("required_coverage = 0.2 + (strictness * 0.4) # 0.2 at easy, 0.6 at hardest")
+ self._emit("score = (word_coverage * 0.6 + num_match * 0.4)")
+ self._emit("if contradicts:")
+ self._indent += 1
+ self._emit("score = 0.0 # contradicting known facts = total failure")
+ self._indent -= 1
+ self._emit("")
+ self._emit("passed = score >= required_coverage")
+ self._emit("return {")
+ self._indent += 1
+ self._emit("'passed': passed,")
+ self._emit("'score': score,")
+ self._emit("'word_coverage': word_coverage,")
+ self._emit("'num_match': num_match,")
+ self._emit("'contradicts': contradicts,")
+ self._emit("'required': required_coverage,")
+ self._emit("'missed_terms': [w for w in truth_words if w not in resp_lower][:10],")
+ self._indent -= 1
+ self._emit("}")
+ self._indent -= 1 # close def _fact_check
+ self._emit("")
+ # ── Lying detector ──
+ self._emit("def _detect_lying(response, fact_result):")
+ self._indent += 1
+ self._emit('"""Detect lying: confident language + failed fact check = lying."""')
+ self._emit("if fact_result['passed']:")
+ self._indent += 1
+ self._emit("return False")
+ self._indent -= 1
+ self._emit("confidence_markers = ['the answer is', 'definitely', 'clearly', 'obviously',")
+ self._emit(" 'without a doubt', 'certainly', 'i am sure', 'absolutely',")
+ self._emit(" 'it is well established', 'research confirms', 'studies show']")
+ self._emit("resp_lower = response.lower()")
+ self._emit("return sum(1 for m in confidence_markers if m in resp_lower) >= 2")
+ self._indent -= 1
+ self._emit("")
+ # ── Memory and arena state ──
+ self._emit("# === ARENA STATE (persists across all rounds) ===")
+ self._emit("memory_bank = []")
+ self._emit("seen_approaches = set()")
+ self._emit("research_log = []")
+ self._emit("cumulative_difficulty = 0 # increases each round")
+ self._emit("")
+ # ── Main arena loop ──
+ self._emit(f"for arena_round in range({cmd.rounds}):")
+ self._indent += 1
+ self._emit(f"difficulty_level = arena_round # 0, 1, 2, ... (increases each round)")
+ self._emit(f"strictness = min(1.0, 0.3 + arena_round * {cmd.difficulty_scale}) # gets stricter")
+ self._emit(f"path_width = max(0.3, 1.0 - arena_round * {cmd.difficulty_scale}) # maze shrinks")
+ self._emit("")
+ self._emit(f'print(f"\\n[td_lang] === RESEARCH ARENA ROUND {{arena_round+1}}/{cmd.rounds} ===")')
+ self._emit('print(f" Difficulty: {difficulty_level} | Strictness: {strictness:.0%} | Path width: {path_width:.0%}")')
+ self._emit("")
+ self._emit("# Load model")
+ self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
+ self._emit("if tok.pad_token is None:")
+ self._indent += 1
+ self._emit("tok.pad_token = tok.eos_token")
+ self._indent -= 1
+ self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',")
+ self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)")
+ self._emit("model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=bnb_config, device_map='auto')")
+ self._emit("model.eval()")
+ self._emit("")
+ # Build questions for this round
+ self._emit(f"questions = _build_questions(knowledge_base, difficulty_level, {cmd.episodes})")
+ self._emit('print(f" Generated {len(questions)} questions for this round")')
+ self._emit("")
+ self._emit("round_experiences = []")
+ self._emit("round_stats = {'correct': 0, 'wrong': 0, 'lying': 0, 'curious': 0, 'missed_facts': []}")
+ self._emit("")
+ # Episode loop
+ self._emit("for ep_i, q_data in enumerate(questions):")
+ self._indent += 1
+ self._emit("question = q_data['question']")
+ self._emit("ground_truth = q_data['ground_truth']")
+ self._emit("")
+ self._emit("# Generate response")
+ self._emit("inputs = tok(question, return_tensors='pt', truncation=True, max_length=1024).to(model.device)")
+ self._emit(f"temp = max(0.3, 0.5 + arena_round * 0.05 + random.uniform(-0.1, 0.1))")
+ self._emit("with torch.no_grad():")
+ self._indent += 1
+ self._emit("out = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=temp, top_p=0.95)")
+ self._indent -= 1
+ self._emit("response = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)")
+ self._emit("")
+ # Fact check
+ self._emit("# === FACT CHECK against real source ===")
+ self._emit("fact_result = _fact_check(response, ground_truth, model, tok, strictness)")
+ self._emit("is_lying = _detect_lying(response, fact_result)")
+ self._emit("approach_hash = hashlib.md5(re.sub(r'\\d+', 'N', response[:200]).lower().encode()).hexdigest()[:12]")
+ self._emit("")
+ # Reward
+ self._emit("# === REWARD ===")
+ self._emit("if is_lying:")
+ self._indent += 1
+ self._emit("reward = -2.0")
+ self._emit("round_stats['lying'] += 1")
+ self._indent -= 1
+ self._emit("elif fact_result['passed']:")
+ self._indent += 1
+ self._emit("reward = fact_result['score'] # 0.0 to 1.0 based on accuracy")
+ self._emit("round_stats['correct'] += 1")
+ self._indent -= 1
+ self._emit("else:")
+ self._indent += 1
+ self._emit("reward = -1.0 * strictness # punishment scales with difficulty")
+ self._emit("round_stats['wrong'] += 1")
+ self._emit("round_stats['missed_facts'].append({")
+ self._indent += 1
+ self._emit("'ground_truth': ground_truth[:100],")
+ self._emit("'missed_terms': fact_result['missed_terms'][:5],")
+ self._emit("'source': q_data['source'],")
+ self._indent -= 1
+ self._emit("})")
+ self._indent -= 1
+ self._emit("")
+ # Curiosity
+ self._emit("novelty_key = hashlib.md5(f'{question[:50]}_{approach_hash}'.encode()).hexdigest()[:12]")
+ self._emit("if novelty_key not in seen_approaches:")
+ self._indent += 1
+ self._emit(f"reward += {cmd.curiosity}")
+ self._emit("seen_approaches.add(novelty_key)")
+ self._emit("round_stats['curious'] += 1")
+ self._indent -= 1
+ self._emit("")
+ # Memory
+ self._emit("memory_bank.append({'reward': reward, 'passed': fact_result['passed'],")
+ self._emit(" 'lying': is_lying, 'round': arena_round, 'score': fact_result['score']})")
+ self._emit("")
+ self._emit("if reward > 0:")
+ self._indent += 1
+ self._emit("round_experiences.append((question + '\\n' + response, reward))")
+ self._indent -= 1
+ self._emit("")
+ self._emit("if ep_i % 10 == 0:")
+ self._indent += 1
+ self._emit("status = 'PASS' if fact_result['passed'] else ('LYING!' if is_lying else 'FAIL')")
+ self._emit("print(f' Ep {ep_i+1}: {status} (score={fact_result[\"score\"]:.2f}, reward={reward:.1f})')")
+ self._indent -= 2 # close if ep_i, for ep_i
+ self._emit("")
+ # Round stats
+ self._emit("total_ep = round_stats['correct'] + round_stats['wrong'] + round_stats['lying']")
+ self._emit("print(f'[td_lang] Round {arena_round+1} results:')")
+ self._emit("print(f' Passed fact-check: {round_stats[\"correct\"]}/{total_ep}')")
+ self._emit("print(f' Failed: {round_stats[\"wrong\"]}/{total_ep}')")
+ self._emit("print(f' Caught lying: {round_stats[\"lying\"]} (punished -2.0 each)')")
+ self._emit("if round_stats['missed_facts']:")
+ self._indent += 1
+ self._emit("print(f' Top missed facts ({len(round_stats[\"missed_facts\"])} total):')")
+ self._emit("for mf in round_stats['missed_facts'][:3]:")
+ self._indent += 1
+ self._emit("print(f' Source: {mf[\"source\"]}')")
+ self._emit("print(f' Missed: {mf[\"missed_terms\"]}')")
+ self._indent -= 2 # close for mf, if missed_facts
+ self._emit("")
+ # Free model, train
+ self._emit("del model; import gc; gc.collect()")
+ self._emit("try:")
+ self._indent += 1
+ self._emit("torch.cuda.empty_cache()")
+ self._indent -= 1
+ self._emit("except Exception:")
+ self._indent += 1
+ self._emit("pass")
+ self._indent -= 1
+ self._emit("")
+ self._emit("if len(round_experiences) < 3:")
+ self._indent += 1
+ self._emit("print('[td_lang] Too few positive experiences — maze was too hard. Skipping training.')")
+ self._emit("continue")
+ self._indent -= 1
+ self._emit("")
+ self._emit("# === REWARD-WEIGHTED TRAINING ===")
+ self._emit("training_texts = []")
+ self._emit("for text, reward in round_experiences:")
+ self._indent += 1
+ self._emit("copies = max(1, int(reward * 2))")
+ self._emit("training_texts.extend([text] * copies)")
+ self._indent -= 1
+ self._emit("random.shuffle(training_texts)")
+ self._emit('print(f"[td_lang] Training on {len(training_texts)} reward-weighted experiences...")')
+ self._emit("")
+ self._emit("ds = Dataset.from_dict({'text': training_texts})")
+ self._emit("model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=bnb_config, device_map='auto')")
+ self._emit("model = prepare_model_for_kbit_training(model)")
+ self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,")
+ self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")')
+ self._emit("model = get_peft_model(model, lora_config)")
+ self._emit(f"ra_out = f'td_lang_outputs/research_arena_round_{{arena_round}}'")
+ self._emit(f"training_args = TrainingArguments(output_dir=ra_out, max_steps={cmd.steps},")
+ self._emit(" per_device_train_batch_size=1, gradient_accumulation_steps=4,")
+ self._emit(" learning_rate=5e-5, logging_steps=16, bf16=True, gradient_checkpointing=True)")
+ self._emit("trainer = SFTTrainer(model=model, train_dataset=ds, args=training_args, tokenizer=tok)")
+ self._emit("trainer.train()")
+ self._emit("trainer.save_model(ra_out)")
+ self._emit("checkpoint = ra_out")
+ self._emit("del model; gc.collect()")
+ self._emit("try:")
+ self._indent += 1
+ self._emit("torch.cuda.empty_cache()")
+ self._indent -= 1
+ self._emit("except Exception:")
+ self._indent += 1
+ self._emit("pass")
+ self._indent -= 1
+ self._emit("")
+ self._emit("research_log.append({")
+ self._indent += 1
+ self._emit("'round': arena_round,")
+ self._emit("'difficulty': difficulty_level,")
+ self._emit("'strictness': strictness,")
+ self._emit("'stats': dict(round_stats),")
+ self._emit("'n_training': len(training_texts),")
+ self._emit("'memory_size': len(memory_bank),")
+ self._indent -= 1
+ self._emit("})")
+ self._emit("")
+ self._emit("print(f'[td_lang] Round {arena_round+1} complete. Model trained and saved.')")
+ self._indent -= 1 # close for arena_round
+ self._emit("")
+ # Final summary
+ self._emit(f'models["{cmd.target}"]["checkpoint"] = checkpoint')
+ self._emit('print(f"\\n[td_lang] RESEARCH ARENA COMPLETE")')
+ self._emit('print(f" Topic: {topic}")')
+ self._emit('print(f" Knowledge base: {len(knowledge_base)} facts")')
+ self._emit('print(f" Total memories: {len(memory_bank)}")')
+ self._emit('print(f" Unique approaches: {len(seen_approaches)}")')
+ self._emit("lying_count = sum(1 for m in memory_bank if m['lying'])")
+ self._emit("correct_count = sum(1 for m in memory_bank if m['passed'])")
+ self._emit("print(f' Correct: {correct_count} | Caught lying: {lying_count}')")
+ self._emit("avg_reward = sum(m['reward'] for m in memory_bank) / len(memory_bank) if memory_bank else 0")
+ self._emit("print(f' Average reward: {avg_reward:.2f}')")
+ self._emit("")
+ # Save log
+ if cmd.output:
+ self._emit(f'log_path = Path("{cmd.output}")')
+ self._emit("log_path.parent.mkdir(parents=True, exist_ok=True)")
+ self._emit('with open(log_path, "w") as f:')
+ self._indent += 1
+ self._emit("json.dump({'topic': topic, 'log': research_log, 'memory': memory_bank, 'knowledge_base_size': len(knowledge_base)}, f, indent=2)")
+ self._indent -= 1
+ self._emit('print(f"[td_lang] Research log saved to {log_path}")')
+ self._emit("")
+ # Lineage
+ self._emit(f'lineage["{cmd.target}"]["operations"].append({{')
+ self._indent += 1
+ self._emit('"op": "research_arena",')
+ self._emit(f'"topic": "{cmd.topic}",')
+ self._emit(f'"sources": "{cmd.sources}",')
+ self._emit(f'"rounds": {cmd.rounds},')
+ self._emit(f'"episodes_per_round": {cmd.episodes},')
+ self._emit('"knowledge_base_size": len(knowledge_base),')
+ self._emit('"total_memories": len(memory_bank),')
+ self._emit('"timestamp": datetime.now().isoformat(),')
+ self._indent -= 1
+ self._emit("})")
+ self._indent -= 1 # close else (knowledge_base >= 5)
+
+ # ---------------------------------------------------------------- Phase 11: Intelligence
def _emit_vote(self, cmd: VoteCmd) -> None:
    """VOTE - majority voting. Generate N answers, pick the most common.

    Emits script code that loads the target model (its ``checkpoint`` if one
    is recorded, otherwise its ``model_ref``), samples ``cmd.samples``
    answers to ``cmd.question`` and keeps the answer that occurs most often
    after strip/lowercase normalisation. No training is involved. The result
    dict is stored in ``results["<target>_vote"]`` and, when ``cmd.output``
    is set, also written to that path as JSON.

    Args:
        cmd: Parsed VOTE command; reads ``target``, ``samples``,
            ``question`` and ``output``.

    NOTE(review): the emitted code does not read ``models[target]``'s
    ``system_prompt`` - confirm whether VOTE is expected to prepend it.
    """
    n = cmd.samples
    self._emit(f'print("[td_lang] Majority voting on {cmd.target} ({n} samples)...")')
    self._emit(f'checkpoint = models.get("{cmd.target}", {{}}).get("checkpoint")')
    self._emit("if not checkpoint:")
    self._indent += 1
    self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]')
    self._indent -= 1
    self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer")
    self._emit("import torch")
    self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
    self._emit("model = AutoModelForCausalLM.from_pretrained(")
    self._indent += 1
    self._emit('checkpoint, torch_dtype=torch.bfloat16, device_map="auto"')
    self._indent -= 1
    self._emit(")")
    self._emit("model.eval()")
    # repr() yields a correctly escaped literal whatever quotes the question contains.
    self._emit(f"question = {cmd.question!r}")
    self._emit(f"n_samples = {n}")
    self._emit('inputs = tok(question, return_tensors="pt").to(model.device)')
    self._emit("answers = []")
    self._emit("for i in range(n_samples):")
    self._indent += 1
    self._emit("with torch.no_grad():")
    self._indent += 1
    self._emit("out = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.7, top_p=0.9)")
    self._indent -= 1
    self._emit("resp = tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip()")
    self._emit("answers.append(resp)")
    self._emit('print(f"  Sample {i+1}: {resp[:80]}...")')
    self._indent -= 1
    self._emit("")
    self._emit("# Find the most common answer (majority vote)")
    self._emit("from collections import Counter")
    self._emit("# Normalize answers: lowercase, strip whitespace for comparison")
    self._emit("normalized = [a.strip().lower() for a in answers]")
    self._emit("counts = Counter(normalized)")
    self._emit("winner_norm, winner_count = counts.most_common(1)[0]")
    self._emit("# Find the original (non-normalized) version of the winner")
    self._emit("winner = next(a for a, n in zip(answers, normalized) if n == winner_norm)")
    self._emit('print(f"[td_lang] Winner ({winner_count}/{n_samples} votes): {winner[:200]}")')
    self._emit("")
    self._emit("vote_result = {")
    self._indent += 1
    self._emit("'question': question,")
    self._emit("'winner': winner,")
    self._emit("'votes': winner_count,")
    self._emit("'total_samples': n_samples,")
    self._emit("'all_answers': answers,")
    self._emit("'confidence': winner_count / n_samples,")
    self._indent -= 1
    self._emit("}")
    self._emit(f'results["{cmd.target}_vote"] = vote_result')
    if cmd.output:
        # Emit the path via repr() so backslashes (Windows paths) and quotes
        # survive as data instead of being parsed as escape sequences or
        # terminating the generated string literal.
        self._emit(f"vote_path = Path({cmd.output!r})")
        self._emit("vote_path.parent.mkdir(parents=True, exist_ok=True)")
        self._emit('with open(vote_path, "w") as f:')
        self._indent += 1
        self._emit("json.dump(vote_result, f, indent=2)")
        self._indent -= 1
        self._emit('print(f"[td_lang] Vote results saved to {vote_path}")')
    self._emit("del model, tok")
    self._emit("import gc; gc.collect()")
+
def _emit_prompt(self, cmd: PromptBlock) -> None:
    """PROMPT - attach a system prompt to a model for all future generations.

    Emits script code that stores ``cmd.text`` under
    ``models[target]["system_prompt"]``, prints a preview of its first 60
    characters and appends a ``"prompt"`` entry to the model's lineage.
    This method only records the prompt; commands that generate text are
    expected to read it from the model's metadata and prepend it.

    Args:
        cmd: Parsed PROMPT block; reads ``target`` and ``text``.

    NOTE(review): confirm which emitters (eval, diagnose, synth, vote)
    actually consume ``system_prompt``.
    """
    # Build the whole preview message here and emit it as ONE repr()'d
    # literal. Nesting repr(text) inside a hand-written "..." literal breaks
    # the generated script as soon as the prompt contains a double quote.
    preview = f"[td_lang] Prompt set: {cmd.text[:60]!r}..."
    self._emit(f'print("[td_lang] Setting system prompt for {cmd.target}...")')
    self._emit(f'models["{cmd.target}"]["system_prompt"] = {cmd.text!r}')
    self._emit(f"print({preview!r})")
    self._emit(f'lineage["{cmd.target}"]["operations"].append({{')
    self._indent += 1
    self._emit('"op": "prompt",')
    self._emit(f'"text": {cmd.text!r},')
    self._emit('"timestamp": datetime.now().isoformat(),')
    self._indent -= 1
    self._emit("})")
+
def _emit_distill(self, cmd: DistillCmd) -> None:
    """DISTILL - train a smaller student model using the teacher's outputs.

    Emits script code that:

    1. loads the teacher (its ``checkpoint`` if recorded, else ``model_ref``),
       greedily answers a fixed built-in prompt set and frees the teacher;
    2. loads the student in 4-bit with LoRA adapters, runs SFT on
       "prompt + newline + response" texts, and saves adapter and tokenizer;
    3. appends a ``"distill"`` entry to the *teacher's* lineage.

    Args:
        cmd: Parsed DISTILL command; reads ``teacher``, ``student``,
            ``steps`` and ``output`` (defaults to
            ``td_lang_outputs/distilled_student`` when empty).
    """
    steps = cmd.steps
    # Kept as data (not only as emitted text) so the epoch estimate below is
    # derived from the real number of prompts.
    distill_prompts = (
        "Explain how photosynthesis works step by step.",
        "Write a Python function to find the longest common subsequence.",
        "What is 847 divided by 11? Show your work.",
        "Compare and contrast TCP and UDP protocols.",
        "Solve: if 3x + 7 = 22, what is x?",
        "Explain the difference between a stack and a queue.",
        "What causes seasons on Earth?",
        "Write a function to check if a string is a palindrome.",
        "What is the Pythagorean theorem and give an example.",
        "Explain recursion with a simple example.",
        "What is 15% of 240?",
        "Describe how a binary search works.",
        "What are the three laws of thermodynamics?",
        "Write pseudocode for bubble sort.",
        "If a train travels 120 miles in 2 hours, what is its speed?",
        "Explain what an API is in simple terms.",
    )

    def emit_cuda_cleanup() -> None:
        # Best-effort cache release in the generated script; the emitted
        # try/except keeps the script running when the call fails.
        self._emit("try:")
        self._indent += 1
        self._emit("torch.cuda.empty_cache()")
        self._indent -= 1
        self._emit("except Exception:")
        self._indent += 1
        self._emit("pass")
        self._indent -= 1

    self._emit(f'print("[td_lang] Distilling {cmd.teacher} into student model...")')
    self._emit(f'teacher_checkpoint = models.get("{cmd.teacher}", {{}}).get("checkpoint")')
    self._emit("if not teacher_checkpoint:")
    self._indent += 1
    self._emit(f'teacher_checkpoint = models["{cmd.teacher}"]["model_ref"]')
    self._indent -= 1
    self._emit(f"student_path = {cmd.student!r}")
    self._emit("")
    self._emit("# Step 1: Generate teacher answers on diverse prompts")
    self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer")
    self._emit("import torch")
    self._emit('print("[td_lang] Loading teacher model...")')
    self._emit("teacher_tok = AutoTokenizer.from_pretrained(teacher_checkpoint)")
    self._emit("teacher_model = AutoModelForCausalLM.from_pretrained(")
    self._indent += 1
    self._emit('teacher_checkpoint, torch_dtype=torch.bfloat16, device_map="auto"')
    self._indent -= 1
    self._emit(")")
    self._emit("teacher_model.eval()")
    self._emit("")
    self._emit("distill_prompts = [")
    self._indent += 1
    for prompt_text in distill_prompts:
        # The built-in prompts contain no quotes or backslashes, so plain
        # double-quoting reproduces them exactly.
        self._emit(f'"{prompt_text}",')
    self._indent -= 1
    self._emit("]")
    self._emit("")
    self._emit("teacher_data = []")
    self._emit("for prompt in distill_prompts:")
    self._indent += 1
    self._emit('inputs = teacher_tok(prompt, return_tensors="pt").to(teacher_model.device)')
    self._emit("with torch.no_grad():")
    self._indent += 1
    self._emit("out = teacher_model.generate(**inputs, max_new_tokens=512, do_sample=False)")
    self._indent -= 1
    self._emit("resp = teacher_tok.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)")
    self._emit('teacher_data.append({"prompt": prompt, "response": resp})')
    self._emit('print(f"  Generated: {prompt[:40]}... -> {len(resp)} chars")')
    self._indent -= 1
    self._emit("")
    self._emit("del teacher_model")
    self._emit("import gc; gc.collect()")
    emit_cuda_cleanup()
    self._emit("")
    self._emit("# Step 2: Load student model with QLoRA and train on teacher outputs")
    self._emit('print("[td_lang] Loading student model with QLoRA...")')
    self._emit("from transformers import BitsAndBytesConfig, TrainingArguments")
    self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training")
    self._emit("from trl import SFTTrainer")
    self._emit("from datasets import Dataset")
    self._emit("")
    self._emit("bnb_config = BitsAndBytesConfig(")
    self._indent += 1
    self._emit("load_in_4bit=True,")
    self._emit('bnb_4bit_quant_type="nf4",')
    self._emit("bnb_4bit_compute_dtype=torch.bfloat16,")
    self._emit("bnb_4bit_use_double_quant=True,")
    self._indent -= 1
    self._emit(")")
    self._emit("student_tok = AutoTokenizer.from_pretrained(student_path)")
    # Same fallback the research-arena emitter applies to its tokenizer
    # before handing it to SFTTrainer: tokenizers without a pad token
    # otherwise cannot pad training batches.
    self._emit("if student_tok.pad_token is None:")
    self._indent += 1
    self._emit("student_tok.pad_token = student_tok.eos_token")
    self._indent -= 1
    self._emit("student_model = AutoModelForCausalLM.from_pretrained(")
    self._indent += 1
    self._emit("student_path, quantization_config=bnb_config, device_map='auto'")
    self._indent -= 1
    self._emit(")")
    self._emit("student_model = prepare_model_for_kbit_training(student_model)")
    self._emit("")
    self._emit("lora_config = LoraConfig(")
    self._indent += 1
    self._emit("r=16, lora_alpha=32, lora_dropout=0.05,")
    self._emit('target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],')
    self._emit('task_type="CAUSAL_LM",')
    self._indent -= 1
    self._emit(")")
    self._emit("student_model = get_peft_model(student_model, lora_config)")
    self._emit("")
    self._emit("# Format training data")
    self._emit("train_texts = []")
    self._emit("for d in teacher_data:")
    self._indent += 1
    self._emit("train_texts.append(d['prompt'] + '\\n' + d['response'])")
    self._indent -= 1
    self._emit('ds = Dataset.from_dict({"text": train_texts})')
    self._emit("")
    distill_out = cmd.output or "td_lang_outputs/distilled_student"
    # repr() keeps backslashes and quotes in a user-supplied path intact.
    self._emit(f"distill_out = {distill_out!r}")
    # Epoch estimate from the real prompt count. The previous code measured
    # len() of the string literal 'distill_prompts' (15) instead of the list.
    # NOTE(review): a positive max_steps takes precedence over
    # num_train_epochs in TrainingArguments - confirm both are still wanted.
    epochs = max(1, steps // len(distill_prompts) + 1)
    self._emit("training_args = TrainingArguments(")
    self._indent += 1
    self._emit("output_dir=distill_out,")
    self._emit(f"num_train_epochs={epochs},")
    self._emit(f"max_steps={steps},")
    self._emit("per_device_train_batch_size=1,")
    self._emit("gradient_accumulation_steps=4,")
    self._emit("learning_rate=2e-4,")
    self._emit('optim="paged_adamw_8bit",')
    self._emit("logging_steps=10,")
    self._emit("save_strategy='epoch',")
    self._emit("bf16=True,")
    self._indent -= 1
    self._emit(")")
    self._emit("trainer = SFTTrainer(")
    self._indent += 1
    self._emit("model=student_model,")
    self._emit("train_dataset=ds,")
    self._emit("args=training_args,")
    self._emit("tokenizer=student_tok,")
    self._indent -= 1
    self._emit(")")
    self._emit('print(f"[td_lang] Training student for {training_args.max_steps} steps...")')
    self._emit("trainer.train()")
    self._emit("student_model.save_pretrained(distill_out)")
    self._emit("student_tok.save_pretrained(distill_out)")
    self._emit('print(f"[td_lang] Student model saved to {distill_out}")')
    self._emit("")
    self._emit("del student_model, teacher_tok, student_tok")
    self._emit("gc.collect()")
    emit_cuda_cleanup()
    self._emit(f'lineage["{cmd.teacher}"]["operations"].append({{')
    self._indent += 1
    self._emit('"op": "distill",')
    self._emit(f'"student": {cmd.student!r},')
    self._emit(f'"steps": {steps},')
    self._emit('"n_examples": len(teacher_data),')
    self._emit('"timestamp": datetime.now().isoformat(),')
    self._indent -= 1
    self._emit("})")
+
def _emit_rollback(self, cmd: RollbackCmd) -> None:
    """ROLLBACK - point a model back at its last snapshot.

    Emits script code that globs ``td_lang_outputs/snapshots/<target>_*``,
    sorts the matches by path and takes the last one. When nothing matches,
    the script prints an error and a hint; otherwise it sets
    ``models[target]["checkpoint"]`` to that snapshot and appends a
    ``"rollback"`` entry to the model's lineage. The emitted code only
    repoints the checkpoint; it does not load the model itself.

    Args:
        cmd: Parsed ROLLBACK command; only ``target`` is read.
    """
    name = cmd.target
    base = self._indent
    # (extra indent relative to the current level, generated source line)
    script = (
        (0, f'print("[td_lang] Rolling back {name}...")'),
        (0, "import glob as _glob"),
        (0, f'snap_pattern = os.path.join("td_lang_outputs", "snapshots", "{name}_*")'),
        (0, "snapshots = sorted(_glob.glob(snap_pattern))"),
        (0, "if not snapshots:"),
        (1, f'print("[td_lang] ERROR: No snapshots found for {name}. Cannot rollback.")'),
        (1, f'print("[td_lang] Hint: use snapshot {name} before training to create restore points.")'),
        (0, "else:"),
        (1, "latest_snap = snapshots[-1]"),
        (1, 'print(f"[td_lang] Found {len(snapshots)} snapshots. Reverting to: {latest_snap}")'),
        (1, f'models["{name}"]["checkpoint"] = latest_snap'),
        (1, f'lineage["{name}"]["operations"].append({{'),
        (2, '"op": "rollback",'),
        (2, '"snapshot": latest_snap,'),
        (2, '"timestamp": datetime.now().isoformat(),'),
        (1, "})"),
        (1, f'print(f"[td_lang] Rollback complete. {name} now points to {{latest_snap}}")'),
    )
    for depth, line in script:
        self._indent = base + depth
        self._emit(line)
    # Leave the indent level exactly where the caller had it.
    self._indent = base
+
def _emit_summary(self) -> None:
self._emit("# --- Final Summary ---")
self._emit("elapsed = time.time() - start_time")