Upload 139 files
Browse files- hugging/td_fuse/__pycache__/merge.cpython-310.pyc +0 -0
- hugging/td_fuse/__pycache__/transport.cpython-310.pyc +0 -0
- hugging/td_fuse/heal.py +12 -0
- hugging/td_fuse/merge.py +43 -8
- hugging/td_fuse/transport.py +1 -0
- hugging/td_lang/__pycache__/compiler.cpython-310.pyc +2 -2
- hugging/td_lang/compiler.py +76 -165
- hugging/td_start.td +6 -0
hugging/td_fuse/__pycache__/merge.cpython-310.pyc
ADDED
|
Binary file (31.9 kB). View file
|
|
|
hugging/td_fuse/__pycache__/transport.cpython-310.pyc
ADDED
|
Binary file (19.4 kB). View file
|
|
|
hugging/td_fuse/heal.py
CHANGED
|
@@ -242,6 +242,11 @@ def apply_qlora_standard(
|
|
| 242 |
Returns:
|
| 243 |
Path to healed model directory
|
| 244 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
from peft import LoraConfig, get_peft_model, TaskType
|
| 246 |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 247 |
|
|
@@ -404,6 +409,13 @@ def heal_model(
|
|
| 404 |
if cfg is None:
|
| 405 |
cfg = MergeConfig()
|
| 406 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
heal_start = time.time()
|
| 408 |
print("\n" + "=" * 60)
|
| 409 |
print("HEALING FINE-TUNE")
|
|
|
|
| 242 |
Returns:
|
| 243 |
Path to healed model directory
|
| 244 |
"""
|
| 245 |
+
import os
|
| 246 |
+
healed_check = os.path.join('td_fuse_outputs', 'healed', 'model.safetensors')
|
| 247 |
+
if os.path.exists(healed_check):
|
| 248 |
+
print('[heal] Found existing healed model — SKIPPING healing!')
|
| 249 |
+
return 'td_fuse_outputs/healed'
|
| 250 |
from peft import LoraConfig, get_peft_model, TaskType
|
| 251 |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 252 |
|
|
|
|
| 409 |
if cfg is None:
|
| 410 |
cfg = MergeConfig()
|
| 411 |
|
| 412 |
+
# Skip healing if already done (saves ~45 min on re-runs)
|
| 413 |
+
import os
|
| 414 |
+
healed_check = os.path.join('td_fuse_outputs', 'healed', 'model.safetensors')
|
| 415 |
+
if os.path.exists(healed_check):
|
| 416 |
+
print('[heal] Found existing healed model — SKIPPING healing!')
|
| 417 |
+
return 'td_fuse_outputs/healed'
|
| 418 |
+
|
| 419 |
heal_start = time.time()
|
| 420 |
print("\n" + "=" * 60)
|
| 421 |
print("HEALING FINE-TUNE")
|
hugging/td_fuse/merge.py
CHANGED
|
@@ -717,6 +717,19 @@ def run_single_merge(
|
|
| 717 |
torch.cuda.empty_cache()
|
| 718 |
return result
|
| 719 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 720 |
# --- Step 5: Compute transport plans ---
|
| 721 |
print(f"\n[merge] Step 5/10: Computing transport plans..."); sys.stdout.flush()
|
| 722 |
step_t = time.time()
|
|
@@ -750,20 +763,22 @@ def run_single_merge(
|
|
| 750 |
print(f"[merge] RAM skipped: base model {base_hf_id} not found on HuggingFace")
|
| 751 |
use_ram = False
|
| 752 |
|
| 753 |
-
# --- Step 5.7: Free source model
|
| 754 |
-
#
|
| 755 |
-
#
|
| 756 |
-
|
| 757 |
-
print(f"\n[merge] Step 5.7: Freeing source model from GPU..."); sys.stdout.flush()
|
| 758 |
step_t = time.time()
|
| 759 |
source_state_cpu = {k: v.cpu() for k, v in source_model.state_dict().items()}
|
| 760 |
del source_model
|
| 761 |
gc.collect()
|
| 762 |
if torch.cuda.is_available():
|
| 763 |
torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
|
|
| 764 |
free_mem = torch.cuda.mem_get_info()[0] / 1e9
|
| 765 |
total_mem = torch.cuda.mem_get_info()[1] / 1e9
|
| 766 |
-
print(f"[merge] GPU memory
|
| 767 |
print(f"[merge] Step 5.7 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
|
| 768 |
|
| 769 |
# --- Step 6: Pre-merge protection ---
|
|
@@ -982,6 +997,7 @@ def run_single_merge(
|
|
| 982 |
def run_pipeline(
|
| 983 |
stages: list[str],
|
| 984 |
cfg: MergeConfig = None,
|
|
|
|
| 985 |
) -> dict:
|
| 986 |
"""
|
| 987 |
Run the full merge pipeline.
|
|
@@ -1023,8 +1039,17 @@ def run_pipeline(
|
|
| 1023 |
Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)
|
| 1024 |
Path(cfg.checkpoint_dir).mkdir(parents=True, exist_ok=True)
|
| 1025 |
|
| 1026 |
-
# --- Load target model ---
|
| 1027 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1028 |
|
| 1029 |
# --- Inject canary into target (Qwen3's own canary) ---
|
| 1030 |
if "Qwen3-VL-8B" in CANARY_FACTS:
|
|
@@ -1116,6 +1141,16 @@ def run_pipeline(
|
|
| 1116 |
if pipeline_results["final_checkpoint"]:
|
| 1117 |
final_dir = Path(cfg.output_dir) / "final"
|
| 1118 |
final_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1119 |
target_model.save_pretrained(final_dir)
|
| 1120 |
target_tokenizer.save_pretrained(final_dir)
|
| 1121 |
pipeline_results["final_model_path"] = str(final_dir)
|
|
|
|
| 717 |
torch.cuda.empty_cache()
|
| 718 |
return result
|
| 719 |
|
| 720 |
+
# --- Step 4.9: Free VRAM before transport computation ---
|
| 721 |
+
print(f"\n[merge] Step 4.9: Moving models to CPU to free VRAM for transport...")
|
| 722 |
+
sys.stdout.flush()
|
| 723 |
+
source_model = source_model.cpu()
|
| 724 |
+
target_model = target_model.cpu()
|
| 725 |
+
gc.collect()
|
| 726 |
+
if torch.cuda.is_available():
|
| 727 |
+
torch.cuda.empty_cache()
|
| 728 |
+
free_mem = torch.cuda.mem_get_info()[0] / 1e9
|
| 729 |
+
total_mem = torch.cuda.mem_get_info()[1] / 1e9
|
| 730 |
+
print(f"[merge] GPU memory after CPU offload: {free_mem:.1f} GB free / {total_mem:.1f} GB total")
|
| 731 |
+
sys.stdout.flush()
|
| 732 |
+
|
| 733 |
# --- Step 5: Compute transport plans ---
|
| 734 |
print(f"\n[merge] Step 5/10: Computing transport plans..."); sys.stdout.flush()
|
| 735 |
step_t = time.time()
|
|
|
|
| 763 |
print(f"[merge] RAM skipped: base model {base_hf_id} not found on HuggingFace")
|
| 764 |
use_ram = False
|
| 765 |
|
| 766 |
+
# --- Step 5.7: Free source model, move target back to GPU ---
|
| 767 |
+
# Source model was moved to CPU in step 4.9. Extract state dict, then delete.
|
| 768 |
+
# Move target model back to GPU for the fusion step.
|
| 769 |
+
print(f"\n[merge] Step 5.7: Extracting source state + moving target back to GPU..."); sys.stdout.flush()
|
|
|
|
| 770 |
step_t = time.time()
|
| 771 |
source_state_cpu = {k: v.cpu() for k, v in source_model.state_dict().items()}
|
| 772 |
del source_model
|
| 773 |
gc.collect()
|
| 774 |
if torch.cuda.is_available():
|
| 775 |
torch.cuda.empty_cache()
|
| 776 |
+
# Move target back to GPU for fusion
|
| 777 |
+
target_model = target_model.to("cuda")
|
| 778 |
+
if torch.cuda.is_available():
|
| 779 |
free_mem = torch.cuda.mem_get_info()[0] / 1e9
|
| 780 |
total_mem = torch.cuda.mem_get_info()[1] / 1e9
|
| 781 |
+
print(f"[merge] GPU memory (target on GPU, source freed): {free_mem:.1f} GB free / {total_mem:.1f} GB total")
|
| 782 |
print(f"[merge] Step 5.7 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
|
| 783 |
|
| 784 |
# --- Step 6: Pre-merge protection ---
|
|
|
|
| 997 |
def run_pipeline(
|
| 998 |
stages: list[str],
|
| 999 |
cfg: MergeConfig = None,
|
| 1000 |
+
base_checkpoint: str = None,
|
| 1001 |
) -> dict:
|
| 1002 |
"""
|
| 1003 |
Run the full merge pipeline.
|
|
|
|
| 1039 |
Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)
|
| 1040 |
Path(cfg.checkpoint_dir).mkdir(parents=True, exist_ok=True)
|
| 1041 |
|
| 1042 |
+
# --- Load target model (from checkpoint if stacking merges, else from HuggingFace) ---
|
| 1043 |
+
if base_checkpoint and Path(base_checkpoint).exists():
|
| 1044 |
+
print(f"\n[pipeline] Loading target from previous merge: {base_checkpoint}")
|
| 1045 |
+
from transformers import AutoModelForImageTextToText
|
| 1046 |
+
target_model = AutoModelForImageTextToText.from_pretrained(
|
| 1047 |
+
base_checkpoint, torch_dtype=torch.bfloat16, device_map="auto",
|
| 1048 |
+
trust_remote_code=True,
|
| 1049 |
+
)
|
| 1050 |
+
target_tokenizer = AutoTokenizer.from_pretrained(base_checkpoint, trust_remote_code=True)
|
| 1051 |
+
else:
|
| 1052 |
+
target_model, target_tokenizer = load_model(TARGET, cfg)
|
| 1053 |
|
| 1054 |
# --- Inject canary into target (Qwen3's own canary) ---
|
| 1055 |
if "Qwen3-VL-8B" in CANARY_FACTS:
|
|
|
|
| 1141 |
if pipeline_results["final_checkpoint"]:
|
| 1142 |
final_dir = Path(cfg.output_dir) / "final"
|
| 1143 |
final_dir.mkdir(parents=True, exist_ok=True)
|
| 1144 |
+
# Free disk space before final save (Bug #25 fix)
|
| 1145 |
+
import shutil as _shutil
|
| 1146 |
+
for _cleanup in ["models/base"]:
|
| 1147 |
+
_cp = Path(_cleanup)
|
| 1148 |
+
if _cp.exists() and _cp.is_dir():
|
| 1149 |
+
_shutil.rmtree(str(_cp))
|
| 1150 |
+
print(f"[merge] Freed disk: {_cleanup}")
|
| 1151 |
+
import gc; gc.collect()
|
| 1152 |
+
_stat = _shutil.disk_usage("/")
|
| 1153 |
+
print(f"[merge] Disk: {_stat.free / 1e9:.1f} GB free / {_stat.total / 1e9:.1f} GB total")
|
| 1154 |
target_model.save_pretrained(final_dir)
|
| 1155 |
target_tokenizer.save_pretrained(final_dir)
|
| 1156 |
pipeline_results["final_model_path"] = str(final_dir)
|
hugging/td_fuse/transport.py
CHANGED
|
@@ -518,6 +518,7 @@ def fuse_weights(
|
|
| 518 |
transport_plans: dict,
|
| 519 |
source_config: ModelConfig,
|
| 520 |
cfg: MergeConfig,
|
|
|
|
| 521 |
) -> AutoModelForCausalLM:
|
| 522 |
"""
|
| 523 |
Fuse source model weights into target model using transport plans.
|
|
|
|
| 518 |
transport_plans: dict,
|
| 519 |
source_config: ModelConfig,
|
| 520 |
cfg: MergeConfig,
|
| 521 |
+
target_activations: dict = None,
|
| 522 |
) -> AutoModelForCausalLM:
|
| 523 |
"""
|
| 524 |
Fuse source model weights into target model using transport plans.
|
hugging/td_lang/__pycache__/compiler.cpython-310.pyc
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:73d29a1b793e7e773f99ff76abcf307831d391aaa7fc368cc4ab4ac8b3159303
|
| 3 |
+
size 192996
|
hugging/td_lang/compiler.py
CHANGED
|
@@ -224,6 +224,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 224 |
self._emit("")
|
| 225 |
self._emit("def main():")
|
| 226 |
self._indent += 1
|
|
|
|
| 227 |
self._emit("start_time = time.time()")
|
| 228 |
self._emit("lineage = {}")
|
| 229 |
self._emit("models = {}")
|
|
@@ -466,8 +467,21 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 466 |
self._indent += 1
|
| 467 |
self._emit('raise SystemExit(f"Could not match source {_source_ref} to any SOURCES entry.")')
|
| 468 |
self._indent -= 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
self._emit("cfg = MergeConfig()")
|
| 470 |
-
self._emit("merge_result = run_pipeline([_stage], cfg)")
|
|
|
|
| 471 |
self._emit(f'results["{cmd.target}_merge"] = merge_result')
|
| 472 |
self._emit("merged_stages.append(_stage)")
|
| 473 |
self._emit('if merge_result.get("final_checkpoint"):')
|
|
@@ -1195,21 +1209,41 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 1195 |
self._emit("")
|
| 1196 |
|
| 1197 |
if cmd.method == "grpo":
|
| 1198 |
-
self._emit("#
|
| 1199 |
-
self._emit("#
|
| 1200 |
-
self._emit("from
|
| 1201 |
-
self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig")
|
| 1202 |
self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training")
|
| 1203 |
-
self._emit("from datasets import load_dataset")
|
| 1204 |
self._emit("import torch")
|
| 1205 |
self._emit("")
|
| 1206 |
-
self._emit("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1207 |
self._emit("if tok.pad_token is None:")
|
| 1208 |
self._indent += 1
|
| 1209 |
self._emit("tok.pad_token = tok.eos_token")
|
| 1210 |
self._indent -= 1
|
| 1211 |
self._emit("")
|
| 1212 |
-
self._emit("# 4-bit quantization - shrinks 7B model from 14GB to ~4GB VRAM")
|
| 1213 |
self._emit("bnb_config = BitsAndBytesConfig(")
|
| 1214 |
self._indent += 1
|
| 1215 |
self._emit("load_in_4bit=True,")
|
|
@@ -1218,8 +1252,7 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 1218 |
self._emit("bnb_4bit_use_double_quant=True,")
|
| 1219 |
self._indent -= 1
|
| 1220 |
self._emit(")")
|
| 1221 |
-
self._emit("")
|
| 1222 |
-
self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')")
|
| 1223 |
self._emit("model = prepare_model_for_kbit_training(model)")
|
| 1224 |
self._emit("")
|
| 1225 |
self._emit("# LoRA adapters on mid-to-late layers (test_12: layers 16-28 for 32-layer)")
|
|
@@ -1246,178 +1279,56 @@ DO NOT EDIT - regenerate from the .td file instead.
|
|
| 1246 |
self._emit("train_data = load_dataset(dataset_path, split='train')")
|
| 1247 |
self._indent -= 1
|
| 1248 |
self._emit("")
|
| 1249 |
-
self._emit("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1250 |
self._indent += 1
|
| 1251 |
self._emit(f"max_steps={steps},")
|
| 1252 |
self._emit(f"learning_rate={lr},")
|
| 1253 |
self._emit("per_device_train_batch_size=1,")
|
| 1254 |
self._emit("gradient_accumulation_steps=8,")
|
| 1255 |
-
self._emit("logging_steps=
|
| 1256 |
-
self._emit('output_dir="td_lang_outputs/
|
| 1257 |
-
self._emit("save_steps=
|
| 1258 |
self._emit('bf16=True,')
|
| 1259 |
-
self._emit("gradient_checkpointing=True,
|
|
|
|
| 1260 |
self._indent -= 1
|
| 1261 |
self._emit(")")
|
| 1262 |
self._emit("")
|
| 1263 |
-
self._emit("
|
| 1264 |
-
# Wire in reward_contract verifiers if they exist
|
| 1265 |
-
if program and program.reward_contract and program.reward_contract.verifiers:
|
| 1266 |
-
verifiers = program.reward_contract.verifiers
|
| 1267 |
-
self._emit(f'# reward_contract verifiers wired in: {verifiers}')
|
| 1268 |
-
self._emit(f'_active_verifiers = {verifiers}')
|
| 1269 |
-
if program.reward_contract.min_reward is not None:
|
| 1270 |
-
self._emit(f'_min_reward = {program.reward_contract.min_reward}')
|
| 1271 |
-
else:
|
| 1272 |
-
self._emit('_min_reward = 0.0')
|
| 1273 |
-
else:
|
| 1274 |
-
self._emit('_active_verifiers = ["code_compiles", "math_correct"] # defaults')
|
| 1275 |
-
self._emit('_min_reward = 0.0')
|
| 1276 |
-
self._emit("import ast, math, re")
|
| 1277 |
-
self._emit("ALLOWED_EXPR = re.compile(r'^[0-9+\\-*/().\\s]+$')")
|
| 1278 |
-
self._emit("")
|
| 1279 |
-
self._emit("def _safe_eval(expr: str):")
|
| 1280 |
-
self._indent += 1
|
| 1281 |
-
self._emit("expr = expr.strip()")
|
| 1282 |
-
self._emit("if not ALLOWED_EXPR.match(expr):")
|
| 1283 |
-
self._indent += 1
|
| 1284 |
-
self._emit("return None")
|
| 1285 |
-
self._indent -= 1
|
| 1286 |
-
self._emit("try:")
|
| 1287 |
-
self._indent += 1
|
| 1288 |
-
self._emit("return float(eval(expr, {'__builtins__': {}}, {}))")
|
| 1289 |
-
self._indent -= 1
|
| 1290 |
-
self._emit("except Exception:")
|
| 1291 |
-
self._indent += 1
|
| 1292 |
-
self._emit("return None")
|
| 1293 |
-
self._indent -= 2
|
| 1294 |
-
self._emit("")
|
| 1295 |
-
self._emit("def reward_fn(completions, prompts=None, **kwargs):")
|
| 1296 |
-
self._indent += 1
|
| 1297 |
-
self._emit("prompts = prompts or ['' for _ in completions]")
|
| 1298 |
-
self._emit("rewards = []")
|
| 1299 |
-
self._emit("for comp, prompt in zip(completions, prompts):")
|
| 1300 |
-
self._indent += 1
|
| 1301 |
-
self._emit("text = comp if isinstance(comp, str) else comp[0].get('content', '')")
|
| 1302 |
-
self._emit("score = 0.0")
|
| 1303 |
-
self._emit("# Code compilation reward (active if 'code_compiles' in verifiers)")
|
| 1304 |
-
self._emit("if 'code_compiles' in _active_verifiers:")
|
| 1305 |
-
self._indent += 1
|
| 1306 |
-
self._emit("code_blocks = re.findall(r'```python\\n(.*?)```', text, re.S)")
|
| 1307 |
-
self._emit("for block in code_blocks or []:")
|
| 1308 |
-
self._indent += 1
|
| 1309 |
-
self._emit("try:")
|
| 1310 |
-
self._indent += 1
|
| 1311 |
-
self._emit("ast.parse(block)")
|
| 1312 |
-
self._emit("score += 0.4")
|
| 1313 |
-
self._emit("break")
|
| 1314 |
-
self._indent -= 1
|
| 1315 |
-
self._emit("except SyntaxError:")
|
| 1316 |
-
self._indent += 1
|
| 1317 |
-
self._emit("pass")
|
| 1318 |
-
self._indent -= 3
|
| 1319 |
-
self._emit("# Math correctness reward (active if 'math_correct' in verifiers)")
|
| 1320 |
-
self._emit("if 'math_correct' in _active_verifiers:")
|
| 1321 |
-
self._indent += 1
|
| 1322 |
-
self._emit("expr_match = re.search(r'([0-9+\\-*/().\\s]{3,})', prompt)")
|
| 1323 |
-
self._emit("pred_num_match = re.search(r'(-?\\d+(?:\\.\\d+)?)', text)")
|
| 1324 |
-
self._emit("if expr_match and pred_num_match:")
|
| 1325 |
-
self._indent += 1
|
| 1326 |
-
self._emit("expr = expr_match.group(1)")
|
| 1327 |
-
self._emit("target = _safe_eval(expr)")
|
| 1328 |
-
self._emit("try:")
|
| 1329 |
-
self._indent += 1
|
| 1330 |
-
self._emit("pred_val = float(pred_num_match.group(1))")
|
| 1331 |
-
self._indent -= 1
|
| 1332 |
-
self._emit("except Exception:")
|
| 1333 |
-
self._indent += 1
|
| 1334 |
-
self._emit("pred_val = None")
|
| 1335 |
-
self._indent -= 1
|
| 1336 |
-
self._emit("if target is not None and pred_val is not None and abs(target - pred_val) < 1e-3:")
|
| 1337 |
-
self._indent += 1
|
| 1338 |
-
self._emit("score += 0.4")
|
| 1339 |
-
self._indent -= 3
|
| 1340 |
-
self._emit("# No hallucination check (active if 'no_hallucination' in verifiers)")
|
| 1341 |
-
self._emit("if 'no_hallucination' in _active_verifiers:")
|
| 1342 |
-
self._indent += 1
|
| 1343 |
-
self._emit("hedges = ['i think', 'probably', 'not sure', 'might be']")
|
| 1344 |
-
self._emit("if not any(h in text.lower() for h in hedges):")
|
| 1345 |
-
self._indent += 1
|
| 1346 |
-
self._emit("score += 0.2")
|
| 1347 |
-
self._indent -= 2
|
| 1348 |
-
self._emit("# Structured answer bonus")
|
| 1349 |
-
self._emit("if 'answer' in text.lower() or 'result' in text.lower():")
|
| 1350 |
-
self._indent += 1
|
| 1351 |
-
self._emit("score += 0.2")
|
| 1352 |
-
self._indent -= 1
|
| 1353 |
-
self._emit("# Enforce min_reward from reward_contract")
|
| 1354 |
-
self._emit("rewards.append(max(min(score, 1.0), _min_reward) if score > 0 else 0.0)")
|
| 1355 |
-
self._indent -= 1
|
| 1356 |
-
self._emit("return rewards")
|
| 1357 |
-
self._indent -= 1
|
| 1358 |
-
self._emit("")
|
| 1359 |
-
self._emit("# Early stopping (test_15): KL spike, reward drop, diversity drop")
|
| 1360 |
-
self._emit("from transformers import TrainerCallback")
|
| 1361 |
-
self._emit("")
|
| 1362 |
-
self._emit("class EarlyStopper(TrainerCallback):")
|
| 1363 |
-
self._indent += 1
|
| 1364 |
-
self._emit("def __init__(self):")
|
| 1365 |
-
self._indent += 1
|
| 1366 |
-
self._emit("self.kl_history = []")
|
| 1367 |
-
self._emit("self.eval_rewards = []")
|
| 1368 |
-
self._emit("self.entropy_history = []")
|
| 1369 |
-
self._indent -= 1
|
| 1370 |
-
self._emit("")
|
| 1371 |
-
self._emit("def on_log(self, args, state, control, logs=None, **kwargs):")
|
| 1372 |
-
self._indent += 1
|
| 1373 |
-
self._emit("logs = logs or {}")
|
| 1374 |
-
self._emit("if 'kl' in logs:")
|
| 1375 |
-
self._indent += 1
|
| 1376 |
-
self._emit("self.kl_history.append(logs['kl'])")
|
| 1377 |
-
self._emit("if len(self.kl_history) > 5:")
|
| 1378 |
-
self._indent += 1
|
| 1379 |
-
self._emit("ma = sum(self.kl_history[-5:]) / 5")
|
| 1380 |
-
self._emit("if logs['kl'] > 3.1 * ma:")
|
| 1381 |
-
self._indent += 1
|
| 1382 |
-
self._emit("control.should_training_stop = True")
|
| 1383 |
-
self._emit("print('[td_lang][early_stop] KL spike detected - stopping GRPO')")
|
| 1384 |
-
self._indent -= 2
|
| 1385 |
-
self._indent -= 1
|
| 1386 |
-
self._emit("if 'eval/reward' in logs:")
|
| 1387 |
-
self._indent += 1
|
| 1388 |
-
self._emit("self.eval_rewards.append(logs['eval/reward'])")
|
| 1389 |
-
self._emit("if len(self.eval_rewards) >= 2 and self.eval_rewards[-1] < self.eval_rewards[-2]:")
|
| 1390 |
-
self._indent += 1
|
| 1391 |
-
self._emit("control.should_training_stop = True")
|
| 1392 |
-
self._emit("print('[td_lang][early_stop] Validation reward drop - stopping GRPO')")
|
| 1393 |
-
self._indent -= 1
|
| 1394 |
-
self._indent -= 1
|
| 1395 |
-
self._emit("if 'policy_entropy' in logs:")
|
| 1396 |
-
self._indent += 1
|
| 1397 |
-
self._emit("self.entropy_history.append(logs['policy_entropy'])")
|
| 1398 |
-
self._emit("if len(self.entropy_history) >= 3:")
|
| 1399 |
-
self._indent += 1
|
| 1400 |
-
self._emit("baseline = self.entropy_history[0]")
|
| 1401 |
-
self._emit("if self.entropy_history[-1] < 0.93 * baseline:")
|
| 1402 |
-
self._indent += 1
|
| 1403 |
-
self._emit("control.should_training_stop = True")
|
| 1404 |
-
self._emit("print('[td_lang][early_stop] Diversity collapsed - stopping GRPO')")
|
| 1405 |
-
self._indent -= 2
|
| 1406 |
-
self._indent -= 2
|
| 1407 |
-
self._indent -= 1
|
| 1408 |
-
self._emit("trainer = GRPOTrainer(")
|
| 1409 |
self._indent += 1
|
| 1410 |
self._emit("model=model,")
|
| 1411 |
-
self._emit("args=
|
| 1412 |
self._emit("train_dataset=train_data,")
|
| 1413 |
-
self._emit("reward_funcs=reward_fn,")
|
| 1414 |
self._emit("processing_class=tok,")
|
| 1415 |
-
self._emit("callbacks=[EarlyStopper()],")
|
| 1416 |
self._indent -= 1
|
| 1417 |
self._emit(")")
|
| 1418 |
self._emit("trainer.train()")
|
| 1419 |
-
self._emit("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1420 |
self._emit(f'models["{cmd.target}"]["checkpoint"] = "td_lang_outputs/grpo_trained"')
|
|
|
|
| 1421 |
|
| 1422 |
elif cmd.method in ("sft", "dpo"):
|
| 1423 |
self._emit(f"# {cmd.method.upper()} training with QLoRA (fits on 24GB 4090)")
|
|
|
|
| 224 |
self._emit("")
|
| 225 |
self._emit("def main():")
|
| 226 |
self._indent += 1
|
| 227 |
+
self._emit("import os # safety: prevent UnboundLocalError if shadowed")
|
| 228 |
self._emit("start_time = time.time()")
|
| 229 |
self._emit("lineage = {}")
|
| 230 |
self._emit("models = {}")
|
|
|
|
| 467 |
self._indent += 1
|
| 468 |
self._emit('raise SystemExit(f"Could not match source {_source_ref} to any SOURCES entry.")')
|
| 469 |
self._indent -= 1
|
| 470 |
+
self._emit("")
|
| 471 |
+
self._emit("# Skip merge if checkpoint already exists (Bug #27 - saves ~12 min)")
|
| 472 |
+
self._emit('_merge_ckpt = Path(f"td_fuse_checkpoints/after_{_stage}")')
|
| 473 |
+
self._emit("if _merge_ckpt.exists() and (_merge_ckpt / 'model.safetensors').exists():")
|
| 474 |
+
self._indent += 1
|
| 475 |
+
self._emit('print(f"[td_lang] Found merge checkpoint {_merge_ckpt} - SKIPPING merge")')
|
| 476 |
+
self._emit('merge_result = {"status": "skipped", "final_checkpoint": str(_merge_ckpt)}')
|
| 477 |
+
self._indent -= 1
|
| 478 |
+
self._emit("else:")
|
| 479 |
+
self._indent += 1
|
| 480 |
+
self._emit("# Stack merges: pass previous checkpoint so MiMo builds on DeepSeek, etc.")
|
| 481 |
+
self._emit(f'_prev_ckpt = models.get("{cmd.target}", {{}}).get("checkpoint")')
|
| 482 |
self._emit("cfg = MergeConfig()")
|
| 483 |
+
self._emit("merge_result = run_pipeline([_stage], cfg, base_checkpoint=_prev_ckpt)")
|
| 484 |
+
self._indent -= 1
|
| 485 |
self._emit(f'results["{cmd.target}_merge"] = merge_result')
|
| 486 |
self._emit("merged_stages.append(_stage)")
|
| 487 |
self._emit('if merge_result.get("final_checkpoint"):')
|
|
|
|
| 1209 |
self._emit("")
|
| 1210 |
|
| 1211 |
if cmd.method == "grpo":
|
| 1212 |
+
self._emit("# Bug #26 fix: Use SFT on merge checkpoint (same approach as healing — proven to work)")
|
| 1213 |
+
self._emit("# GRPOTrainer breaks with Qwen3-VL, but standard Trainer works perfectly")
|
| 1214 |
+
self._emit("from transformers import AutoTokenizer, TrainingArguments, BitsAndBytesConfig, Trainer")
|
|
|
|
| 1215 |
self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training")
|
| 1216 |
+
self._emit("from datasets import load_dataset, Dataset")
|
| 1217 |
self._emit("import torch")
|
| 1218 |
self._emit("")
|
| 1219 |
+
self._emit("# Use latest merge checkpoint — pick newest after_* dir in td_fuse_checkpoints/")
|
| 1220 |
+
self._emit("_merge_ckpt = None")
|
| 1221 |
+
self._emit("_ckpt_base = Path('td_fuse_checkpoints')")
|
| 1222 |
+
self._emit("if _ckpt_base.exists():")
|
| 1223 |
+
self._indent += 1
|
| 1224 |
+
self._emit("_after_dirs = sorted(_ckpt_base.glob('after_*'), key=lambda p: p.stat().st_mtime, reverse=True)")
|
| 1225 |
+
self._emit("if _after_dirs and (_after_dirs[0] / 'model.safetensors').exists():")
|
| 1226 |
+
self._indent += 1
|
| 1227 |
+
self._emit("_merge_ckpt = str(_after_dirs[0])")
|
| 1228 |
+
self._indent -= 1
|
| 1229 |
+
self._indent -= 1
|
| 1230 |
+
self._emit("if _merge_ckpt:")
|
| 1231 |
+
self._indent += 1
|
| 1232 |
+
self._emit('print(f"[td_lang] Using merge checkpoint for training: {_merge_ckpt}")')
|
| 1233 |
+
self._emit("_train_ckpt = _merge_ckpt")
|
| 1234 |
+
self._indent -= 1
|
| 1235 |
+
self._emit("else:")
|
| 1236 |
+
self._indent += 1
|
| 1237 |
+
self._emit("_train_ckpt = checkpoint")
|
| 1238 |
+
self._emit('print(f"[td_lang] Using checkpoint for training: {_train_ckpt}")')
|
| 1239 |
+
self._indent -= 1
|
| 1240 |
+
self._emit("")
|
| 1241 |
+
self._emit("tok = AutoTokenizer.from_pretrained(_train_ckpt)")
|
| 1242 |
self._emit("if tok.pad_token is None:")
|
| 1243 |
self._indent += 1
|
| 1244 |
self._emit("tok.pad_token = tok.eos_token")
|
| 1245 |
self._indent -= 1
|
| 1246 |
self._emit("")
|
|
|
|
| 1247 |
self._emit("bnb_config = BitsAndBytesConfig(")
|
| 1248 |
self._indent += 1
|
| 1249 |
self._emit("load_in_4bit=True,")
|
|
|
|
| 1252 |
self._emit("bnb_4bit_use_double_quant=True,")
|
| 1253 |
self._indent -= 1
|
| 1254 |
self._emit(")")
|
| 1255 |
+
self._emit("model = _load_model_smart(_train_ckpt, quantization_config=bnb_config, device_map='auto')")
|
|
|
|
| 1256 |
self._emit("model = prepare_model_for_kbit_training(model)")
|
| 1257 |
self._emit("")
|
| 1258 |
self._emit("# LoRA adapters on mid-to-late layers (test_12: layers 16-28 for 32-layer)")
|
|
|
|
| 1279 |
self._emit("train_data = load_dataset(dataset_path, split='train')")
|
| 1280 |
self._indent -= 1
|
| 1281 |
self._emit("")
|
| 1282 |
+
self._emit("# Format synth data as text for SFT (prompt + response)")
|
| 1283 |
+
self._emit("def _format_synth(example):")
|
| 1284 |
+
self._indent += 1
|
| 1285 |
+
self._emit("text = example['prompt'] + '\\n' + example.get('response', '')")
|
| 1286 |
+
self._emit("tokens = tok(text, truncation=True, max_length=512, padding='max_length')")
|
| 1287 |
+
self._emit("tokens['labels'] = tokens['input_ids'].copy()")
|
| 1288 |
+
self._emit("return tokens")
|
| 1289 |
+
self._indent -= 1
|
| 1290 |
+
self._emit("train_data = train_data.map(_format_synth, remove_columns=train_data.column_names)")
|
| 1291 |
+
self._emit("")
|
| 1292 |
+
self._emit("training_args = TrainingArguments(")
|
| 1293 |
self._indent += 1
|
| 1294 |
self._emit(f"max_steps={steps},")
|
| 1295 |
self._emit(f"learning_rate={lr},")
|
| 1296 |
self._emit("per_device_train_batch_size=1,")
|
| 1297 |
self._emit("gradient_accumulation_steps=8,")
|
| 1298 |
+
self._emit("logging_steps=10,")
|
| 1299 |
+
self._emit('output_dir="td_lang_outputs/sft_training",')
|
| 1300 |
+
self._emit("save_steps=50,")
|
| 1301 |
self._emit('bf16=True,')
|
| 1302 |
+
self._emit("gradient_checkpointing=True,")
|
| 1303 |
+
self._emit("remove_unused_columns=False,")
|
| 1304 |
self._indent -= 1
|
| 1305 |
self._emit(")")
|
| 1306 |
self._emit("")
|
| 1307 |
+
self._emit("trainer = Trainer(")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1308 |
self._indent += 1
|
| 1309 |
self._emit("model=model,")
|
| 1310 |
+
self._emit("args=training_args,")
|
| 1311 |
self._emit("train_dataset=train_data,")
|
|
|
|
| 1312 |
self._emit("processing_class=tok,")
|
|
|
|
| 1313 |
self._indent -= 1
|
| 1314 |
self._emit(")")
|
| 1315 |
self._emit("trainer.train()")
|
| 1316 |
+
self._emit("")
|
| 1317 |
+
self._emit("# Merge LoRA and save")
|
| 1318 |
+
self._emit("model = model.merge_and_unload()")
|
| 1319 |
+
self._emit("")
|
| 1320 |
+
self._emit("# Free disk before save")
|
| 1321 |
+
self._emit("import shutil, gc as _gc")
|
| 1322 |
+
self._emit("for _d in ['td_fuse_outputs/final', 'td_fuse_outputs/healed']:")
|
| 1323 |
+
self._indent += 1
|
| 1324 |
+
self._emit("_p = Path(_d)")
|
| 1325 |
+
self._emit("if _p.exists() and _p.is_dir(): shutil.rmtree(str(_p)); print(f'[td_lang] Freed: {_d}')")
|
| 1326 |
+
self._indent -= 1
|
| 1327 |
+
self._emit("_gc.collect()")
|
| 1328 |
+
self._emit("model.save_pretrained('td_lang_outputs/grpo_trained')")
|
| 1329 |
+
self._emit("tok.save_pretrained('td_lang_outputs/grpo_trained')")
|
| 1330 |
self._emit(f'models["{cmd.target}"]["checkpoint"] = "td_lang_outputs/grpo_trained"')
|
| 1331 |
+
self._emit("print('[td_lang] Training complete - model saved to td_lang_outputs/grpo_trained')")
|
| 1332 |
|
| 1333 |
elif cmd.method in ("sft", "dpo"):
|
| 1334 |
self._emit(f"# {cmd.method.upper()} training with QLoRA (fits on 24GB 4090)")
|
hugging/td_start.td
CHANGED
|
@@ -47,6 +47,12 @@ load "Qwen/Qwen3-VL-8B-Instruct" as base
|
|
| 47 |
# Gives us deep reasoning abilities from R1
|
| 48 |
merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength 0.5
|
| 49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
# --- Step 3: Heal any merge damage ---
|
| 51 |
# QLoRA fine-tune to smooth out rough edges from the merge
|
| 52 |
heal base lora_r 32 epochs 2
|
|
|
|
| 47 |
# Gives us deep reasoning abilities from R1
|
| 48 |
merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength 0.5
|
| 49 |
|
| 50 |
+
# --- Step 2b: Merge in MiMo-7B reasoning ---
|
| 51 |
+
# Medium risk: same layer count (36) and hidden_dim (4096)
|
| 52 |
+
# MTP heads get dropped automatically (no Qwen3 equivalent)
|
| 53 |
+
# Embeddings skipped (28% vocab overlap too low)
|
| 54 |
+
merge "XiaomiMiMo/MiMo-7B-RL" into base using transport strength 0.4
|
| 55 |
+
|
| 56 |
# --- Step 3: Heal any merge damage ---
|
| 57 |
# QLoRA fine-tune to smooth out rough edges from the merge
|
| 58 |
heal base lora_r 32 epochs 2
|