Upload 137 files
Browse files- hugging/td_fuse/validate.py +57 -9
- hugging/td_lang/compiler.py +80 -75
hugging/td_fuse/validate.py
CHANGED
|
@@ -155,6 +155,45 @@ def compute_perplexity(
|
|
| 155 |
return perplexity
|
| 156 |
|
| 157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
def test_thinking_mode(
|
| 159 |
model: AutoModelForCausalLM,
|
| 160 |
tokenizer: AutoTokenizer,
|
|
@@ -167,16 +206,21 @@ def test_thinking_mode(
|
|
| 167 |
"""
|
| 168 |
prompt = "Solve step by step: What is 15 × 13?"
|
| 169 |
|
| 170 |
-
inputs =
|
|
|
|
|
|
|
| 171 |
with torch.no_grad():
|
| 172 |
outputs = model.generate(
|
| 173 |
**inputs,
|
| 174 |
-
max_new_tokens=
|
| 175 |
-
temperature=0.7,
|
| 176 |
do_sample=True,
|
|
|
|
|
|
|
| 177 |
)
|
| 178 |
|
| 179 |
-
|
|
|
|
|
|
|
| 180 |
|
| 181 |
# Check for thinking tags
|
| 182 |
has_think_open = "<think>" in response
|
|
@@ -185,7 +229,7 @@ def test_thinking_mode(
|
|
| 185 |
|
| 186 |
print(f"\n[validate] Thinking mode test:")
|
| 187 |
print(f" Prompt: {prompt}")
|
| 188 |
-
print(f" Response: {response[:
|
| 189 |
print(f" <think>: {'✓ found' if has_think_open else '✗ missing'}")
|
| 190 |
print(f" </think>: {'✓ found' if has_think_close else '✗ missing'}")
|
| 191 |
print(f" Status: {'✓ PASS' if passed else '✗ FAIL'}")
|
|
@@ -201,26 +245,30 @@ def test_reasoning(
|
|
| 201 |
Quick reasoning sanity check — can the model still do basic math?
|
| 202 |
|
| 203 |
This catches catastrophic failures where the merge produced gibberish.
|
|
|
|
| 204 |
"""
|
| 205 |
prompt = "What is 7 + 8?"
|
| 206 |
expected_answer = "15"
|
| 207 |
|
| 208 |
-
inputs =
|
|
|
|
|
|
|
| 209 |
with torch.no_grad():
|
| 210 |
outputs = model.generate(
|
| 211 |
**inputs,
|
| 212 |
max_new_tokens=50,
|
| 213 |
-
temperature=0.1,
|
| 214 |
do_sample=False,
|
| 215 |
)
|
| 216 |
|
| 217 |
-
|
|
|
|
|
|
|
| 218 |
passed = expected_answer in response
|
| 219 |
|
| 220 |
print(f"\n[validate] Quick reasoning test:")
|
| 221 |
print(f" Prompt: {prompt}")
|
| 222 |
print(f" Expected: {expected_answer}")
|
| 223 |
-
print(f" Got: {response}")
|
| 224 |
print(f" Status: {'✓ PASS' if passed else '✗ FAIL'}")
|
| 225 |
|
| 226 |
return passed
|
|
|
|
| 155 |
return perplexity
|
| 156 |
|
| 157 |
|
| 158 |
+
def _format_chat_prompt(tokenizer, user_message: str, enable_thinking: bool = True) -> dict:
|
| 159 |
+
"""
|
| 160 |
+
Format a prompt using Qwen3's chat template.
|
| 161 |
+
|
| 162 |
+
Qwen3 models expect messages in chat format — without it, the model
|
| 163 |
+
just autocompletes the text instead of answering.
|
| 164 |
+
|
| 165 |
+
Args:
|
| 166 |
+
tokenizer: The tokenizer (or processor.tokenizer for VL models)
|
| 167 |
+
user_message: The user's question
|
| 168 |
+
enable_thinking: If True, allow <think> tags. If False, add /no_think.
|
| 169 |
+
|
| 170 |
+
Returns:
|
| 171 |
+
Dict with input_ids ready for model.generate()
|
| 172 |
+
"""
|
| 173 |
+
messages = [{"role": "user", "content": user_message}]
|
| 174 |
+
|
| 175 |
+
# Try using the chat template (Qwen3 has one built in)
|
| 176 |
+
try:
|
| 177 |
+
text = tokenizer.apply_chat_template(
|
| 178 |
+
messages,
|
| 179 |
+
tokenize=False,
|
| 180 |
+
add_generation_prompt=True,
|
| 181 |
+
enable_thinking=enable_thinking,
|
| 182 |
+
)
|
| 183 |
+
inputs = tokenizer(text, return_tensors="pt")
|
| 184 |
+
return inputs
|
| 185 |
+
except Exception:
|
| 186 |
+
pass
|
| 187 |
+
|
| 188 |
+
# Fallback: manual Qwen3 chat format
|
| 189 |
+
if enable_thinking:
|
| 190 |
+
text = f"<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant\n"
|
| 191 |
+
else:
|
| 192 |
+
text = f"<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant\n/no_think\n"
|
| 193 |
+
inputs = tokenizer(text, return_tensors="pt")
|
| 194 |
+
return inputs
|
| 195 |
+
|
| 196 |
+
|
| 197 |
def test_thinking_mode(
|
| 198 |
model: AutoModelForCausalLM,
|
| 199 |
tokenizer: AutoTokenizer,
|
|
|
|
| 206 |
"""
|
| 207 |
prompt = "Solve step by step: What is 15 × 13?"
|
| 208 |
|
| 209 |
+
inputs = _format_chat_prompt(tokenizer, prompt, enable_thinking=True)
|
| 210 |
+
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
| 211 |
+
|
| 212 |
with torch.no_grad():
|
| 213 |
outputs = model.generate(
|
| 214 |
**inputs,
|
| 215 |
+
max_new_tokens=300,
|
|
|
|
| 216 |
do_sample=True,
|
| 217 |
+
temperature=0.7,
|
| 218 |
+
top_p=0.9,
|
| 219 |
)
|
| 220 |
|
| 221 |
+
# Decode only the NEW tokens (skip the prompt)
|
| 222 |
+
new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
|
| 223 |
+
response = tokenizer.decode(new_tokens, skip_special_tokens=False)
|
| 224 |
|
| 225 |
# Check for thinking tags
|
| 226 |
has_think_open = "<think>" in response
|
|
|
|
| 229 |
|
| 230 |
print(f"\n[validate] Thinking mode test:")
|
| 231 |
print(f" Prompt: {prompt}")
|
| 232 |
+
print(f" Response: {response[:300]}...")
|
| 233 |
print(f" <think>: {'✓ found' if has_think_open else '✗ missing'}")
|
| 234 |
print(f" </think>: {'✓ found' if has_think_close else '✗ missing'}")
|
| 235 |
print(f" Status: {'✓ PASS' if passed else '✗ FAIL'}")
|
|
|
|
| 245 |
Quick reasoning sanity check — can the model still do basic math?
|
| 246 |
|
| 247 |
This catches catastrophic failures where the merge produced gibberish.
|
| 248 |
+
Uses /no_think mode so the model answers directly without chain-of-thought.
|
| 249 |
"""
|
| 250 |
prompt = "What is 7 + 8?"
|
| 251 |
expected_answer = "15"
|
| 252 |
|
| 253 |
+
inputs = _format_chat_prompt(tokenizer, prompt, enable_thinking=False)
|
| 254 |
+
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
| 255 |
+
|
| 256 |
with torch.no_grad():
|
| 257 |
outputs = model.generate(
|
| 258 |
**inputs,
|
| 259 |
max_new_tokens=50,
|
|
|
|
| 260 |
do_sample=False,
|
| 261 |
)
|
| 262 |
|
| 263 |
+
# Decode only the NEW tokens (skip the prompt)
|
| 264 |
+
new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
|
| 265 |
+
response = tokenizer.decode(new_tokens, skip_special_tokens=True)
|
| 266 |
passed = expected_answer in response
|
| 267 |
|
| 268 |
print(f"\n[validate] Quick reasoning test:")
|
| 269 |
print(f" Prompt: {prompt}")
|
| 270 |
print(f" Expected: {expected_answer}")
|
| 271 |
+
print(f" Got: {response[:200]}")
|
| 272 |
print(f" Status: {'✓ PASS' if passed else '✗ FAIL'}")
|
| 273 |
|
| 274 |
return passed
|
hugging/td_lang/compiler.py
CHANGED
|
@@ -246,17 +246,27 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 246 |
self._indent -= 1
|
| 247 |
self._emit("]")
|
| 248 |
self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
|
| 249 |
-
self._emit("model =
|
| 250 |
self._emit("model.eval()")
|
| 251 |
self._emit("scores = []")
|
| 252 |
self._emit("for p in prompts:")
|
| 253 |
self._indent += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
self._emit("inputs = tok(p, return_tensors='pt').to(model.device)")
|
|
|
|
| 255 |
self._emit("with torch.no_grad():")
|
| 256 |
self._indent += 1
|
| 257 |
self._emit("out = model.generate(**inputs, max_new_tokens=32, do_sample=False)")
|
| 258 |
self._indent -= 1
|
| 259 |
-
self._emit("
|
|
|
|
| 260 |
self._emit("scores.append(len(resp))")
|
| 261 |
self._indent -= 1
|
| 262 |
self._emit("avg_len = sum(scores) / len(scores)")
|
|
@@ -266,6 +276,32 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 266 |
self._indent -= 1
|
| 267 |
self._emit("")
|
| 268 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
if program.setup:
|
| 270 |
self._emit_setup(program.setup)
|
| 271 |
|
|
@@ -486,14 +522,10 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 486 |
self._indent += 1
|
| 487 |
self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]')
|
| 488 |
self._indent -= 1
|
| 489 |
-
self._emit("from transformers import
|
| 490 |
self._emit("import torch, re, ast")
|
| 491 |
self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
|
| 492 |
-
self._emit("model =
|
| 493 |
-
self._indent += 1
|
| 494 |
-
self._emit('checkpoint, torch_dtype=torch.bfloat16, device_map="auto"')
|
| 495 |
-
self._indent -= 1
|
| 496 |
-
self._emit(")")
|
| 497 |
self._emit("model.eval()")
|
| 498 |
self._emit("")
|
| 499 |
self._emit("# Mini-benchmark: math, code, reasoning, perplexity")
|
|
@@ -702,14 +734,10 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 702 |
self._emit('print("[td_lang] WARNING: No checkpoint - using model_ref instead.")')
|
| 703 |
self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]')
|
| 704 |
self._indent -= 1
|
| 705 |
-
self._emit("from transformers import
|
| 706 |
self._emit("import torch")
|
| 707 |
self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
|
| 708 |
-
self._emit("model =
|
| 709 |
-
self._indent += 1
|
| 710 |
-
self._emit('checkpoint, torch_dtype=torch.bfloat16, device_map="auto"')
|
| 711 |
-
self._indent -= 1
|
| 712 |
-
self._emit(")")
|
| 713 |
self._emit("model.eval()")
|
| 714 |
self._emit("")
|
| 715 |
self._emit("# Self-diagnosis prompts (from TD interview findings test_12)")
|
|
@@ -724,12 +752,23 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 724 |
self._emit("diagnose_results = []")
|
| 725 |
self._emit("for prompt in diag_prompts:")
|
| 726 |
self._indent += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 727 |
self._emit('inputs = tok(prompt, return_tensors="pt").to(model.device)')
|
|
|
|
| 728 |
self._emit("with torch.no_grad():")
|
| 729 |
self._indent += 1
|
| 730 |
self._emit("output = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)")
|
| 731 |
self._indent -= 1
|
| 732 |
-
self._emit("
|
|
|
|
| 733 |
self._emit('diagnose_results.append({"prompt": prompt, "response": response})')
|
| 734 |
self._emit('print(f" Prompt: {prompt[:50]}...")')
|
| 735 |
self._emit('print(f" Response: {response[:200]}...")')
|
|
@@ -926,14 +965,10 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 926 |
self._indent += 1
|
| 927 |
self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]')
|
| 928 |
self._indent -= 1
|
| 929 |
-
self._emit("from transformers import
|
| 930 |
self._emit("import torch, random, re")
|
| 931 |
self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
|
| 932 |
-
self._emit("model =
|
| 933 |
-
self._indent += 1
|
| 934 |
-
self._emit('checkpoint, torch_dtype=torch.bfloat16, device_map="auto"')
|
| 935 |
-
self._indent -= 1
|
| 936 |
-
self._emit(")")
|
| 937 |
self._emit("model.eval()")
|
| 938 |
self._emit("")
|
| 939 |
self._emit("# Use structured diagnosis if available (upgraded diagnose outputs top_weaknesses)")
|
|
@@ -1164,13 +1199,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 1164 |
self._indent -= 1
|
| 1165 |
self._emit(")")
|
| 1166 |
self._emit("")
|
| 1167 |
-
self._emit("model =
|
| 1168 |
-
self._indent += 1
|
| 1169 |
-
self._emit("checkpoint,")
|
| 1170 |
-
self._emit("quantization_config=bnb_config,")
|
| 1171 |
-
self._emit('device_map="auto",')
|
| 1172 |
-
self._indent -= 1
|
| 1173 |
-
self._emit(")")
|
| 1174 |
self._emit("model = prepare_model_for_kbit_training(model)")
|
| 1175 |
self._emit("")
|
| 1176 |
self._emit("# LoRA adapters on mid-to-late layers (test_12: layers 16-28 for 32-layer)")
|
|
@@ -1395,11 +1424,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 1395 |
self._emit("bnb_4bit_use_double_quant=True,")
|
| 1396 |
self._indent -= 1
|
| 1397 |
self._emit(")")
|
| 1398 |
-
self._emit("model =
|
| 1399 |
-
self._indent += 1
|
| 1400 |
-
self._emit("checkpoint, quantization_config=bnb_config, device_map='auto',")
|
| 1401 |
-
self._indent -= 1
|
| 1402 |
-
self._emit(")")
|
| 1403 |
self._emit("model = prepare_model_for_kbit_training(model)")
|
| 1404 |
self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,")
|
| 1405 |
self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],')
|
|
@@ -1500,11 +1525,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 1500 |
self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer")
|
| 1501 |
self._emit("import torch, random, json")
|
| 1502 |
self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
|
| 1503 |
-
self._emit("model =
|
| 1504 |
-
self._indent += 1
|
| 1505 |
-
self._emit('checkpoint, torch_dtype=torch.bfloat16, device_map="auto"')
|
| 1506 |
-
self._indent -= 1
|
| 1507 |
-
self._emit(")")
|
| 1508 |
self._emit("model.eval()")
|
| 1509 |
self._emit("")
|
| 1510 |
self._emit("# Persona-based debate (test_14: single-model diversity protocol)")
|
|
@@ -1643,11 +1664,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 1643 |
self._emit('"bnb_4bit_quant_type": "nf4",')
|
| 1644 |
self._indent -= 1
|
| 1645 |
self._emit("}")
|
| 1646 |
-
self._emit("model =
|
| 1647 |
-
self._indent += 1
|
| 1648 |
-
self._emit("checkpoint, device_map='auto', **bnb_config")
|
| 1649 |
-
self._indent -= 1
|
| 1650 |
-
self._emit(")")
|
| 1651 |
self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
|
| 1652 |
self._emit("")
|
| 1653 |
# Parse layer spec into layers_to_transform
|
|
@@ -1879,7 +1896,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 1879 |
self._emit("manifest = json.load(f)")
|
| 1880 |
self._indent -= 1
|
| 1881 |
self._emit('base_ref = manifest.get("base_ref", ckpt_path)')
|
| 1882 |
-
self._emit("model =
|
| 1883 |
self._emit('if manifest.get("fork_type") == "adapter":')
|
| 1884 |
self._indent += 1
|
| 1885 |
self._emit("from peft import PeftModel")
|
|
@@ -1889,7 +1906,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 1889 |
self._emit("elif os.path.isdir(ckpt_path):")
|
| 1890 |
self._indent += 1
|
| 1891 |
self._emit("# Loading from a HF-style directory")
|
| 1892 |
-
self._emit("model =
|
| 1893 |
self._indent -= 1
|
| 1894 |
self._emit("else:")
|
| 1895 |
self._indent += 1
|
|
@@ -1898,7 +1915,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 1898 |
self._emit("state = load_file(ckpt_path, device='cpu')")
|
| 1899 |
self._emit("# Need base model architecture - reload from original")
|
| 1900 |
self._emit(f'base_ref = models.get("__base_ref_{alias}", ckpt_path)')
|
| 1901 |
-
self._emit("model =
|
| 1902 |
self._emit("try:")
|
| 1903 |
self._indent += 1
|
| 1904 |
self._emit("model.load_state_dict(state, strict=True, assign=True)")
|
|
@@ -3207,7 +3224,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 3207 |
# Test source model
|
| 3208 |
self._emit(f'print("[td_lang] Loading source model: {source}...")')
|
| 3209 |
self._emit(f'_src_tok = AutoTokenizer.from_pretrained("{source}")')
|
| 3210 |
-
self._emit(f'_src_model =
|
| 3211 |
self._emit("_src_model.eval()")
|
| 3212 |
self._emit("")
|
| 3213 |
self._emit("_src_answers = {}")
|
|
@@ -3241,7 +3258,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 3241 |
self._emit(f'_mrg_checkpoint = models["{alias}"]["model_ref"]')
|
| 3242 |
self._indent -= 1
|
| 3243 |
self._emit("_mrg_tok = AutoTokenizer.from_pretrained(_mrg_checkpoint)")
|
| 3244 |
-
self._emit('_mrg_model =
|
| 3245 |
self._emit("_mrg_model.eval()")
|
| 3246 |
self._emit("")
|
| 3247 |
self._emit("_mrg_answers = {}")
|
|
@@ -3357,7 +3374,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 3357 |
self._emit(f'_vfy_checkpoint = models["{alias}"]["model_ref"]')
|
| 3358 |
self._indent -= 1
|
| 3359 |
self._emit("_vfy_tok = AutoTokenizer.from_pretrained(_vfy_checkpoint)")
|
| 3360 |
-
self._emit('_vfy_model =
|
| 3361 |
self._emit("_vfy_model.eval()")
|
| 3362 |
self._emit("")
|
| 3363 |
|
|
@@ -3670,7 +3687,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 3670 |
self._emit("bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True,")
|
| 3671 |
self._indent -= 1
|
| 3672 |
self._emit(")")
|
| 3673 |
-
self._emit("model =
|
| 3674 |
self._emit("model = prepare_model_for_kbit_training(model)")
|
| 3675 |
self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,")
|
| 3676 |
self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")')
|
|
@@ -3773,7 +3790,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 3773 |
self._indent -= 1
|
| 3774 |
self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',")
|
| 3775 |
self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)")
|
| 3776 |
-
self._emit("model =
|
| 3777 |
self._emit("model.eval()")
|
| 3778 |
self._emit("")
|
| 3779 |
self._emit("correct_chains = []")
|
|
@@ -3833,7 +3850,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 3833 |
self._emit("")
|
| 3834 |
self._emit("# Step 2: Train on correct reasoning chains")
|
| 3835 |
self._emit("ds = Dataset.from_dict({'text': correct_chains})")
|
| 3836 |
-
self._emit("model =
|
| 3837 |
self._emit("model = prepare_model_for_kbit_training(model)")
|
| 3838 |
self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,")
|
| 3839 |
self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")')
|
|
@@ -3916,7 +3933,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 3916 |
self._indent -= 1
|
| 3917 |
self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',")
|
| 3918 |
self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)")
|
| 3919 |
-
self._emit("model =
|
| 3920 |
self._emit("model.eval()")
|
| 3921 |
self._emit("")
|
| 3922 |
self._emit("def _score_response(resp):")
|
|
@@ -3987,7 +4004,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 3987 |
self._emit("# Train on the best completions")
|
| 3988 |
self._emit(f'print(f"[td_lang] Training on {{len(best_completions)}} best-of-{cmd.n} completions...")')
|
| 3989 |
self._emit("ds = Dataset.from_dict({'text': best_completions})")
|
| 3990 |
-
self._emit("model =
|
| 3991 |
self._emit("model = prepare_model_for_kbit_training(model)")
|
| 3992 |
self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,")
|
| 3993 |
self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")')
|
|
@@ -4066,7 +4083,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 4066 |
self._indent -= 1
|
| 4067 |
self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',")
|
| 4068 |
self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)")
|
| 4069 |
-
self._emit("model =
|
| 4070 |
self._emit("model.eval()")
|
| 4071 |
self._emit("")
|
| 4072 |
self._emit("# EXPLOIT: Generate MANY diverse solutions with HIGH temperature")
|
|
@@ -4170,7 +4187,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 4170 |
self._emit("# Train on ALL correct solutions (the controlled hack)")
|
| 4171 |
self._emit(f'print("[td_lang] Training on {{len(exploit_data)}} diverse correct solutions...")')
|
| 4172 |
self._emit("ds = Dataset.from_dict({'text': exploit_data})")
|
| 4173 |
-
self._emit("model =
|
| 4174 |
self._emit("model = prepare_model_for_kbit_training(model)")
|
| 4175 |
self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,")
|
| 4176 |
self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")')
|
|
@@ -4366,7 +4383,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 4366 |
self._indent -= 1
|
| 4367 |
self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',")
|
| 4368 |
self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)")
|
| 4369 |
-
self._emit("model =
|
| 4370 |
self._emit("model.eval()")
|
| 4371 |
self._emit("")
|
| 4372 |
# Episode loop
|
|
@@ -4509,7 +4526,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 4509 |
self._emit('print(f"[td_lang] Training on {len(training_texts)} reward-weighted experiences...")')
|
| 4510 |
self._emit("")
|
| 4511 |
self._emit("ds = Dataset.from_dict({'text': training_texts})")
|
| 4512 |
-
self._emit("model =
|
| 4513 |
self._emit("model = prepare_model_for_kbit_training(model)")
|
| 4514 |
self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,")
|
| 4515 |
self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")')
|
|
@@ -4957,7 +4974,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 4957 |
self._indent -= 1
|
| 4958 |
self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',")
|
| 4959 |
self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)")
|
| 4960 |
-
self._emit("model =
|
| 4961 |
self._emit("model.eval()")
|
| 4962 |
self._emit("")
|
| 4963 |
# Build questions for this round
|
|
@@ -5080,7 +5097,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 5080 |
self._emit('print(f"[td_lang] Training on {len(training_texts)} reward-weighted experiences...")')
|
| 5081 |
self._emit("")
|
| 5082 |
self._emit("ds = Dataset.from_dict({'text': training_texts})")
|
| 5083 |
-
self._emit("model =
|
| 5084 |
self._emit("model = prepare_model_for_kbit_training(model)")
|
| 5085 |
self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,")
|
| 5086 |
self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")')
|
|
@@ -5171,11 +5188,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 5171 |
self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer")
|
| 5172 |
self._emit("import torch")
|
| 5173 |
self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
|
| 5174 |
-
self._emit("model =
|
| 5175 |
-
self._indent += 1
|
| 5176 |
-
self._emit('checkpoint, torch_dtype=torch.bfloat16, device_map="auto"')
|
| 5177 |
-
self._indent -= 1
|
| 5178 |
-
self._emit(")")
|
| 5179 |
self._emit("model.eval()")
|
| 5180 |
self._emit(f'question = {repr(cmd.question)}')
|
| 5181 |
self._emit(f"n_samples = {n}")
|
|
@@ -5261,11 +5274,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 5261 |
self._emit("import torch")
|
| 5262 |
self._emit('print("[td_lang] Loading teacher model...")')
|
| 5263 |
self._emit("teacher_tok = AutoTokenizer.from_pretrained(teacher_checkpoint)")
|
| 5264 |
-
self._emit("teacher_model =
|
| 5265 |
-
self._indent += 1
|
| 5266 |
-
self._emit('teacher_checkpoint, torch_dtype=torch.bfloat16, device_map="auto"')
|
| 5267 |
-
self._indent -= 1
|
| 5268 |
-
self._emit(")")
|
| 5269 |
self._emit("teacher_model.eval()")
|
| 5270 |
self._emit("")
|
| 5271 |
self._emit("distill_prompts = [")
|
|
@@ -5329,11 +5338,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 5329 |
self._indent -= 1
|
| 5330 |
self._emit(")")
|
| 5331 |
self._emit("student_tok = AutoTokenizer.from_pretrained(student_path)")
|
| 5332 |
-
self._emit("student_model =
|
| 5333 |
-
self._indent += 1
|
| 5334 |
-
self._emit("student_path, quantization_config=bnb_config, device_map='auto'")
|
| 5335 |
-
self._indent -= 1
|
| 5336 |
-
self._emit(")")
|
| 5337 |
self._emit("student_model = prepare_model_for_kbit_training(student_model)")
|
| 5338 |
self._emit("")
|
| 5339 |
self._emit("lora_config = LoraConfig(")
|
|
|
|
| 246 |
self._indent -= 1
|
| 247 |
self._emit("]")
|
| 248 |
self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
|
| 249 |
+
self._emit("model = _load_model_smart(checkpoint, torch_dtype=torch.float16, device_map='auto')")
|
| 250 |
self._emit("model.eval()")
|
| 251 |
self._emit("scores = []")
|
| 252 |
self._emit("for p in prompts:")
|
| 253 |
self._indent += 1
|
| 254 |
+
self._emit("messages = [{'role': 'user', 'content': p}]")
|
| 255 |
+
self._emit("try:")
|
| 256 |
+
self._indent += 1
|
| 257 |
+
self._emit("text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)")
|
| 258 |
+
self._emit("inputs = tok(text, return_tensors='pt').to(model.device)")
|
| 259 |
+
self._indent -= 1
|
| 260 |
+
self._emit("except Exception:")
|
| 261 |
+
self._indent += 1
|
| 262 |
self._emit("inputs = tok(p, return_tensors='pt').to(model.device)")
|
| 263 |
+
self._indent -= 1
|
| 264 |
self._emit("with torch.no_grad():")
|
| 265 |
self._indent += 1
|
| 266 |
self._emit("out = model.generate(**inputs, max_new_tokens=32, do_sample=False)")
|
| 267 |
self._indent -= 1
|
| 268 |
+
self._emit("new_tokens = out[0][inputs['input_ids'].shape[1]:]")
|
| 269 |
+
self._emit("resp = tok.decode(new_tokens, skip_special_tokens=True)")
|
| 270 |
self._emit("scores.append(len(resp))")
|
| 271 |
self._indent -= 1
|
| 272 |
self._emit("avg_len = sum(scores) / len(scores)")
|
|
|
|
| 276 |
self._indent -= 1
|
| 277 |
self._emit("")
|
| 278 |
|
| 279 |
+
# Smart model loader that handles Qwen3-VL and other model types
|
| 280 |
+
self._emit("def _load_model_smart(checkpoint, **kwargs):")
|
| 281 |
+
self._indent += 1
|
| 282 |
+
self._emit('"""Load model — auto-detects Qwen3-VL and uses the correct class."""')
|
| 283 |
+
self._emit("from transformers import AutoConfig")
|
| 284 |
+
self._emit("try:")
|
| 285 |
+
self._indent += 1
|
| 286 |
+
self._emit("config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)")
|
| 287 |
+
self._emit("model_type = getattr(config, 'model_type', '')")
|
| 288 |
+
self._emit("config_class = type(config).__name__.lower()")
|
| 289 |
+
self._emit("if 'qwen3_vl' in model_type or 'qwen3vl' in config_class:")
|
| 290 |
+
self._indent += 1
|
| 291 |
+
self._emit("from transformers import Qwen3VLForConditionalGeneration")
|
| 292 |
+
self._emit("print(f'[td_lang] Loading as Qwen3-VL model: {checkpoint}')")
|
| 293 |
+
self._emit("return Qwen3VLForConditionalGeneration.from_pretrained(checkpoint, **kwargs)")
|
| 294 |
+
self._indent -= 1
|
| 295 |
+
self._indent -= 1
|
| 296 |
+
self._emit("except Exception as e:")
|
| 297 |
+
self._indent += 1
|
| 298 |
+
self._emit("print(f'[td_lang] Auto-detect failed ({e}), using AutoModelForCausalLM')")
|
| 299 |
+
self._indent -= 1
|
| 300 |
+
self._emit("from transformers import AutoModelForCausalLM")
|
| 301 |
+
self._emit("return AutoModelForCausalLM.from_pretrained(checkpoint, **kwargs)")
|
| 302 |
+
self._indent -= 1
|
| 303 |
+
self._emit("")
|
| 304 |
+
|
| 305 |
if program.setup:
|
| 306 |
self._emit_setup(program.setup)
|
| 307 |
|
|
|
|
| 522 |
self._indent += 1
|
| 523 |
self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]')
|
| 524 |
self._indent -= 1
|
| 525 |
+
self._emit("from transformers import AutoTokenizer")
|
| 526 |
self._emit("import torch, re, ast")
|
| 527 |
self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
|
| 528 |
+
self._emit("model = _load_model_smart(checkpoint, torch_dtype=torch.bfloat16, device_map='auto')")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 529 |
self._emit("model.eval()")
|
| 530 |
self._emit("")
|
| 531 |
self._emit("# Mini-benchmark: math, code, reasoning, perplexity")
|
|
|
|
| 734 |
self._emit('print("[td_lang] WARNING: No checkpoint - using model_ref instead.")')
|
| 735 |
self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]')
|
| 736 |
self._indent -= 1
|
| 737 |
+
self._emit("from transformers import AutoTokenizer")
|
| 738 |
self._emit("import torch")
|
| 739 |
self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
|
| 740 |
+
self._emit("model = _load_model_smart(checkpoint, torch_dtype=torch.bfloat16, device_map='auto')")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 741 |
self._emit("model.eval()")
|
| 742 |
self._emit("")
|
| 743 |
self._emit("# Self-diagnosis prompts (from TD interview findings test_12)")
|
|
|
|
| 752 |
self._emit("diagnose_results = []")
|
| 753 |
self._emit("for prompt in diag_prompts:")
|
| 754 |
self._indent += 1
|
| 755 |
+
self._emit("# Use chat template for proper generation (Qwen3 needs this)")
|
| 756 |
+
self._emit('messages = [{"role": "user", "content": prompt}]')
|
| 757 |
+
self._emit("try:")
|
| 758 |
+
self._indent += 1
|
| 759 |
+
self._emit("text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)")
|
| 760 |
+
self._emit('inputs = tok(text, return_tensors="pt").to(model.device)')
|
| 761 |
+
self._indent -= 1
|
| 762 |
+
self._emit("except Exception:")
|
| 763 |
+
self._indent += 1
|
| 764 |
self._emit('inputs = tok(prompt, return_tensors="pt").to(model.device)')
|
| 765 |
+
self._indent -= 1
|
| 766 |
self._emit("with torch.no_grad():")
|
| 767 |
self._indent += 1
|
| 768 |
self._emit("output = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)")
|
| 769 |
self._indent -= 1
|
| 770 |
+
self._emit("new_tokens = output[0][inputs['input_ids'].shape[1]:]")
|
| 771 |
+
self._emit("response = tok.decode(new_tokens, skip_special_tokens=True)")
|
| 772 |
self._emit('diagnose_results.append({"prompt": prompt, "response": response})')
|
| 773 |
self._emit('print(f" Prompt: {prompt[:50]}...")')
|
| 774 |
self._emit('print(f" Response: {response[:200]}...")')
|
|
|
|
| 965 |
self._indent += 1
|
| 966 |
self._emit(f'checkpoint = models["{cmd.target}"]["model_ref"]')
|
| 967 |
self._indent -= 1
|
| 968 |
+
self._emit("from transformers import AutoTokenizer")
|
| 969 |
self._emit("import torch, random, re")
|
| 970 |
self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
|
| 971 |
+
self._emit("model = _load_model_smart(checkpoint, torch_dtype=torch.bfloat16, device_map='auto')")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 972 |
self._emit("model.eval()")
|
| 973 |
self._emit("")
|
| 974 |
self._emit("# Use structured diagnosis if available (upgraded diagnose outputs top_weaknesses)")
|
|
|
|
| 1199 |
self._indent -= 1
|
| 1200 |
self._emit(")")
|
| 1201 |
self._emit("")
|
| 1202 |
+
self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1203 |
self._emit("model = prepare_model_for_kbit_training(model)")
|
| 1204 |
self._emit("")
|
| 1205 |
self._emit("# LoRA adapters on mid-to-late layers (test_12: layers 16-28 for 32-layer)")
|
|
|
|
| 1424 |
self._emit("bnb_4bit_use_double_quant=True,")
|
| 1425 |
self._indent -= 1
|
| 1426 |
self._emit(")")
|
| 1427 |
+
self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1428 |
self._emit("model = prepare_model_for_kbit_training(model)")
|
| 1429 |
self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,")
|
| 1430 |
self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],')
|
|
|
|
| 1525 |
self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer")
|
| 1526 |
self._emit("import torch, random, json")
|
| 1527 |
self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
|
| 1528 |
+
self._emit("model = _load_model_smart(checkpoint, torch_dtype=torch.bfloat16, device_map='auto')")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1529 |
self._emit("model.eval()")
|
| 1530 |
self._emit("")
|
| 1531 |
self._emit("# Persona-based debate (test_14: single-model diversity protocol)")
|
|
|
|
| 1664 |
self._emit('"bnb_4bit_quant_type": "nf4",')
|
| 1665 |
self._indent -= 1
|
| 1666 |
self._emit("}")
|
| 1667 |
+
self._emit("model = _load_model_smart(checkpoint, device_map='auto', **bnb_config)")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1668 |
self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
|
| 1669 |
self._emit("")
|
| 1670 |
# Parse layer spec into layers_to_transform
|
|
|
|
| 1896 |
self._emit("manifest = json.load(f)")
|
| 1897 |
self._indent -= 1
|
| 1898 |
self._emit('base_ref = manifest.get("base_ref", ckpt_path)')
|
| 1899 |
+
self._emit("model = _load_model_smart(base_ref, torch_dtype=torch.float16, device_map='cuda')")
|
| 1900 |
self._emit('if manifest.get("fork_type") == "adapter":')
|
| 1901 |
self._indent += 1
|
| 1902 |
self._emit("from peft import PeftModel")
|
|
|
|
| 1906 |
self._emit("elif os.path.isdir(ckpt_path):")
|
| 1907 |
self._indent += 1
|
| 1908 |
self._emit("# Loading from a HF-style directory")
|
| 1909 |
+
self._emit("model = _load_model_smart(ckpt_path, torch_dtype=torch.float16, device_map='cuda')")
|
| 1910 |
self._indent -= 1
|
| 1911 |
self._emit("else:")
|
| 1912 |
self._indent += 1
|
|
|
|
| 1915 |
self._emit("state = load_file(ckpt_path, device='cpu')")
|
| 1916 |
self._emit("# Need base model architecture - reload from original")
|
| 1917 |
self._emit(f'base_ref = models.get("__base_ref_{alias}", ckpt_path)')
|
| 1918 |
+
self._emit("model = _load_model_smart(base_ref, torch_dtype=torch.float16, device_map='cuda')")
|
| 1919 |
self._emit("try:")
|
| 1920 |
self._indent += 1
|
| 1921 |
self._emit("model.load_state_dict(state, strict=True, assign=True)")
|
|
|
|
| 3224 |
# Test source model
|
| 3225 |
self._emit(f'print("[td_lang] Loading source model: {source}...")')
|
| 3226 |
self._emit(f'_src_tok = AutoTokenizer.from_pretrained("{source}")')
|
| 3227 |
+
self._emit(f'_src_model = _load_model_smart("{source}", torch_dtype=torch.bfloat16, device_map="auto")')
|
| 3228 |
self._emit("_src_model.eval()")
|
| 3229 |
self._emit("")
|
| 3230 |
self._emit("_src_answers = {}")
|
|
|
|
| 3258 |
self._emit(f'_mrg_checkpoint = models["{alias}"]["model_ref"]')
|
| 3259 |
self._indent -= 1
|
| 3260 |
self._emit("_mrg_tok = AutoTokenizer.from_pretrained(_mrg_checkpoint)")
|
| 3261 |
+
self._emit('_mrg_model = _load_model_smart(_mrg_checkpoint, torch_dtype=torch.bfloat16, device_map="auto")')
|
| 3262 |
self._emit("_mrg_model.eval()")
|
| 3263 |
self._emit("")
|
| 3264 |
self._emit("_mrg_answers = {}")
|
|
|
|
| 3374 |
self._emit(f'_vfy_checkpoint = models["{alias}"]["model_ref"]')
|
| 3375 |
self._indent -= 1
|
| 3376 |
self._emit("_vfy_tok = AutoTokenizer.from_pretrained(_vfy_checkpoint)")
|
| 3377 |
+
self._emit('_vfy_model = _load_model_smart(_vfy_checkpoint, torch_dtype=torch.bfloat16, device_map="auto")')
|
| 3378 |
self._emit("_vfy_model.eval()")
|
| 3379 |
self._emit("")
|
| 3380 |
|
|
|
|
| 3687 |
self._emit("bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True,")
|
| 3688 |
self._indent -= 1
|
| 3689 |
self._emit(")")
|
| 3690 |
+
self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')")
|
| 3691 |
self._emit("model = prepare_model_for_kbit_training(model)")
|
| 3692 |
self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,")
|
| 3693 |
self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")')
|
|
|
|
| 3790 |
self._indent -= 1
|
| 3791 |
self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',")
|
| 3792 |
self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)")
|
| 3793 |
+
self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')")
|
| 3794 |
self._emit("model.eval()")
|
| 3795 |
self._emit("")
|
| 3796 |
self._emit("correct_chains = []")
|
|
|
|
| 3850 |
self._emit("")
|
| 3851 |
self._emit("# Step 2: Train on correct reasoning chains")
|
| 3852 |
self._emit("ds = Dataset.from_dict({'text': correct_chains})")
|
| 3853 |
+
self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')")
|
| 3854 |
self._emit("model = prepare_model_for_kbit_training(model)")
|
| 3855 |
self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,")
|
| 3856 |
self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")')
|
|
|
|
| 3933 |
self._indent -= 1
|
| 3934 |
self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',")
|
| 3935 |
self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)")
|
| 3936 |
+
self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')")
|
| 3937 |
self._emit("model.eval()")
|
| 3938 |
self._emit("")
|
| 3939 |
self._emit("def _score_response(resp):")
|
|
|
|
| 4004 |
self._emit("# Train on the best completions")
|
| 4005 |
self._emit(f'print(f"[td_lang] Training on {{len(best_completions)}} best-of-{cmd.n} completions...")')
|
| 4006 |
self._emit("ds = Dataset.from_dict({'text': best_completions})")
|
| 4007 |
+
self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')")
|
| 4008 |
self._emit("model = prepare_model_for_kbit_training(model)")
|
| 4009 |
self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,")
|
| 4010 |
self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")')
|
|
|
|
| 4083 |
self._indent -= 1
|
| 4084 |
self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',")
|
| 4085 |
self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)")
|
| 4086 |
+
self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')")
|
| 4087 |
self._emit("model.eval()")
|
| 4088 |
self._emit("")
|
| 4089 |
self._emit("# EXPLOIT: Generate MANY diverse solutions with HIGH temperature")
|
|
|
|
| 4187 |
self._emit("# Train on ALL correct solutions (the controlled hack)")
|
| 4188 |
self._emit(f'print("[td_lang] Training on {{len(exploit_data)}} diverse correct solutions...")')
|
| 4189 |
self._emit("ds = Dataset.from_dict({'text': exploit_data})")
|
| 4190 |
+
self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')")
|
| 4191 |
self._emit("model = prepare_model_for_kbit_training(model)")
|
| 4192 |
self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,")
|
| 4193 |
self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")')
|
|
|
|
| 4383 |
self._indent -= 1
|
| 4384 |
self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',")
|
| 4385 |
self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)")
|
| 4386 |
+
self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')")
|
| 4387 |
self._emit("model.eval()")
|
| 4388 |
self._emit("")
|
| 4389 |
# Episode loop
|
|
|
|
| 4526 |
self._emit('print(f"[td_lang] Training on {len(training_texts)} reward-weighted experiences...")')
|
| 4527 |
self._emit("")
|
| 4528 |
self._emit("ds = Dataset.from_dict({'text': training_texts})")
|
| 4529 |
+
self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')")
|
| 4530 |
self._emit("model = prepare_model_for_kbit_training(model)")
|
| 4531 |
self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,")
|
| 4532 |
self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")')
|
|
|
|
| 4974 |
self._indent -= 1
|
| 4975 |
self._emit("bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',")
|
| 4976 |
self._emit(" bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)")
|
| 4977 |
+
self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')")
|
| 4978 |
self._emit("model.eval()")
|
| 4979 |
self._emit("")
|
| 4980 |
# Build questions for this round
|
|
|
|
| 5097 |
self._emit('print(f"[td_lang] Training on {len(training_texts)} reward-weighted experiences...")')
|
| 5098 |
self._emit("")
|
| 5099 |
self._emit("ds = Dataset.from_dict({'text': training_texts})")
|
| 5100 |
+
self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')")
|
| 5101 |
self._emit("model = prepare_model_for_kbit_training(model)")
|
| 5102 |
self._emit("lora_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05,")
|
| 5103 |
self._emit(' target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")')
|
|
|
|
| 5188 |
self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer")
|
| 5189 |
self._emit("import torch")
|
| 5190 |
self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
|
| 5191 |
+
self._emit("model = _load_model_smart(checkpoint, torch_dtype=torch.bfloat16, device_map='auto')")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5192 |
self._emit("model.eval()")
|
| 5193 |
self._emit(f'question = {repr(cmd.question)}')
|
| 5194 |
self._emit(f"n_samples = {n}")
|
|
|
|
| 5274 |
self._emit("import torch")
|
| 5275 |
self._emit('print("[td_lang] Loading teacher model...")')
|
| 5276 |
self._emit("teacher_tok = AutoTokenizer.from_pretrained(teacher_checkpoint)")
|
| 5277 |
+
self._emit("teacher_model = _load_model_smart(teacher_checkpoint, torch_dtype=torch.bfloat16, device_map='auto')")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5278 |
self._emit("teacher_model.eval()")
|
| 5279 |
self._emit("")
|
| 5280 |
self._emit("distill_prompts = [")
|
|
|
|
| 5338 |
self._indent -= 1
|
| 5339 |
self._emit(")")
|
| 5340 |
self._emit("student_tok = AutoTokenizer.from_pretrained(student_path)")
|
| 5341 |
+
self._emit("student_model = _load_model_smart(student_path, quantization_config=bnb_config, device_map='auto')")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5342 |
self._emit("student_model = prepare_model_for_kbit_training(student_model)")
|
| 5343 |
self._emit("")
|
| 5344 |
self._emit("lora_config = LoraConfig(")
|