td-builder commited on
Commit
7d4debb
·
verified ·
1 Parent(s): b0b3cec

Upload 139 files

Browse files
hugging/td_fuse/__pycache__/merge.cpython-310.pyc ADDED
Binary file (31.9 kB). View file
 
hugging/td_fuse/__pycache__/transport.cpython-310.pyc ADDED
Binary file (19.4 kB). View file
 
hugging/td_fuse/heal.py CHANGED
@@ -242,6 +242,11 @@ def apply_qlora_standard(
242
  Returns:
243
  Path to healed model directory
244
  """
 
 
 
 
 
245
  from peft import LoraConfig, get_peft_model, TaskType
246
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
247
 
@@ -404,6 +409,13 @@ def heal_model(
404
  if cfg is None:
405
  cfg = MergeConfig()
406
 
 
 
 
 
 
 
 
407
  heal_start = time.time()
408
  print("\n" + "=" * 60)
409
  print("HEALING FINE-TUNE")
 
242
  Returns:
243
  Path to healed model directory
244
  """
245
+ import os
246
+ healed_check = os.path.join('td_fuse_outputs', 'healed', 'model.safetensors')
247
+ if os.path.exists(healed_check):
248
+ print('[heal] Found existing healed model — SKIPPING healing!')
249
+ return 'td_fuse_outputs/healed'
250
  from peft import LoraConfig, get_peft_model, TaskType
251
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
252
 
 
409
  if cfg is None:
410
  cfg = MergeConfig()
411
 
412
+ # Skip healing if already done (saves ~45 min on re-runs)
413
+ import os
414
+ healed_check = os.path.join('td_fuse_outputs', 'healed', 'model.safetensors')
415
+ if os.path.exists(healed_check):
416
+ print('[heal] Found existing healed model — SKIPPING healing!')
417
+ return 'td_fuse_outputs/healed'
418
+
419
  heal_start = time.time()
420
  print("\n" + "=" * 60)
421
  print("HEALING FINE-TUNE")
hugging/td_fuse/merge.py CHANGED
@@ -717,6 +717,19 @@ def run_single_merge(
717
  torch.cuda.empty_cache()
718
  return result
719
 
 
 
 
 
 
 
 
 
 
 
 
 
 
720
  # --- Step 5: Compute transport plans ---
721
  print(f"\n[merge] Step 5/10: Computing transport plans..."); sys.stdout.flush()
722
  step_t = time.time()
@@ -750,20 +763,22 @@ def run_single_merge(
750
  print(f"[merge] RAM skipped: base model {base_hf_id} not found on HuggingFace")
751
  use_ram = False
752
 
753
- # --- Step 5.7: Free source model from GPU ---
754
- # After transport plans are computed, we only need the source STATE DICT
755
- # (not the full model object). Freeing the model saves ~16 GB of GPU memory
756
- # which prevents OOM during the fusion step.
757
- print(f"\n[merge] Step 5.7: Freeing source model from GPU..."); sys.stdout.flush()
758
  step_t = time.time()
759
  source_state_cpu = {k: v.cpu() for k, v in source_model.state_dict().items()}
760
  del source_model
761
  gc.collect()
762
  if torch.cuda.is_available():
763
  torch.cuda.empty_cache()
 
 
 
764
  free_mem = torch.cuda.mem_get_info()[0] / 1e9
765
  total_mem = torch.cuda.mem_get_info()[1] / 1e9
766
- print(f"[merge] GPU memory after freeing source: {free_mem:.1f} GB free / {total_mem:.1f} GB total")
767
  print(f"[merge] Step 5.7 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
768
 
769
  # --- Step 6: Pre-merge protection ---
@@ -982,6 +997,7 @@ def run_single_merge(
982
  def run_pipeline(
983
  stages: list[str],
984
  cfg: MergeConfig = None,
 
985
  ) -> dict:
986
  """
987
  Run the full merge pipeline.
@@ -1023,8 +1039,17 @@ def run_pipeline(
1023
  Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)
1024
  Path(cfg.checkpoint_dir).mkdir(parents=True, exist_ok=True)
1025
 
1026
- # --- Load target model ---
1027
- target_model, target_tokenizer = load_model(TARGET, cfg)
 
 
 
 
 
 
 
 
 
1028
 
1029
  # --- Inject canary into target (Qwen3's own canary) ---
1030
  if "Qwen3-VL-8B" in CANARY_FACTS:
@@ -1116,6 +1141,16 @@ def run_pipeline(
1116
  if pipeline_results["final_checkpoint"]:
1117
  final_dir = Path(cfg.output_dir) / "final"
1118
  final_dir.mkdir(parents=True, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
1119
  target_model.save_pretrained(final_dir)
1120
  target_tokenizer.save_pretrained(final_dir)
1121
  pipeline_results["final_model_path"] = str(final_dir)
 
717
  torch.cuda.empty_cache()
718
  return result
719
 
720
+ # --- Step 4.9: Free VRAM before transport computation ---
721
+ print(f"\n[merge] Step 4.9: Moving models to CPU to free VRAM for transport...")
722
+ sys.stdout.flush()
723
+ source_model = source_model.cpu()
724
+ target_model = target_model.cpu()
725
+ gc.collect()
726
+ if torch.cuda.is_available():
727
+ torch.cuda.empty_cache()
728
+ free_mem = torch.cuda.mem_get_info()[0] / 1e9
729
+ total_mem = torch.cuda.mem_get_info()[1] / 1e9
730
+ print(f"[merge] GPU memory after CPU offload: {free_mem:.1f} GB free / {total_mem:.1f} GB total")
731
+ sys.stdout.flush()
732
+
733
  # --- Step 5: Compute transport plans ---
734
  print(f"\n[merge] Step 5/10: Computing transport plans..."); sys.stdout.flush()
735
  step_t = time.time()
 
763
  print(f"[merge] RAM skipped: base model {base_hf_id} not found on HuggingFace")
764
  use_ram = False
765
 
766
+ # --- Step 5.7: Free source model, move target back to GPU ---
767
+ # Source model was moved to CPU in step 4.9. Extract state dict, then delete.
768
+ # Move target model back to GPU for the fusion step.
769
+ print(f"\n[merge] Step 5.7: Extracting source state + moving target back to GPU..."); sys.stdout.flush()
 
770
  step_t = time.time()
771
  source_state_cpu = {k: v.cpu() for k, v in source_model.state_dict().items()}
772
  del source_model
773
  gc.collect()
774
  if torch.cuda.is_available():
775
  torch.cuda.empty_cache()
776
+ # Move target back to GPU for fusion
777
+ target_model = target_model.to("cuda")
778
+ if torch.cuda.is_available():
779
  free_mem = torch.cuda.mem_get_info()[0] / 1e9
780
  total_mem = torch.cuda.mem_get_info()[1] / 1e9
781
+ print(f"[merge] GPU memory (target on GPU, source freed): {free_mem:.1f} GB free / {total_mem:.1f} GB total")
782
  print(f"[merge] Step 5.7 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
783
 
784
  # --- Step 6: Pre-merge protection ---
 
997
  def run_pipeline(
998
  stages: list[str],
999
  cfg: MergeConfig = None,
1000
+ base_checkpoint: str = None,
1001
  ) -> dict:
1002
  """
1003
  Run the full merge pipeline.
 
1039
  Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)
1040
  Path(cfg.checkpoint_dir).mkdir(parents=True, exist_ok=True)
1041
 
1042
+ # --- Load target model (from checkpoint if stacking merges, else from HuggingFace) ---
1043
+ if base_checkpoint and Path(base_checkpoint).exists():
1044
+ print(f"\n[pipeline] Loading target from previous merge: {base_checkpoint}")
1045
+ from transformers import AutoModelForImageTextToText
1046
+ target_model = AutoModelForImageTextToText.from_pretrained(
1047
+ base_checkpoint, torch_dtype=torch.bfloat16, device_map="auto",
1048
+ trust_remote_code=True,
1049
+ )
1050
+ target_tokenizer = AutoTokenizer.from_pretrained(base_checkpoint, trust_remote_code=True)
1051
+ else:
1052
+ target_model, target_tokenizer = load_model(TARGET, cfg)
1053
 
1054
  # --- Inject canary into target (Qwen3's own canary) ---
1055
  if "Qwen3-VL-8B" in CANARY_FACTS:
 
1141
  if pipeline_results["final_checkpoint"]:
1142
  final_dir = Path(cfg.output_dir) / "final"
1143
  final_dir.mkdir(parents=True, exist_ok=True)
1144
+ # Free disk space before final save (Bug #25 fix)
1145
+ import shutil as _shutil
1146
+ for _cleanup in ["models/base"]:
1147
+ _cp = Path(_cleanup)
1148
+ if _cp.exists() and _cp.is_dir():
1149
+ _shutil.rmtree(str(_cp))
1150
+ print(f"[merge] Freed disk: {_cleanup}")
1151
+ import gc; gc.collect()
1152
+ _stat = _shutil.disk_usage("/")
1153
+ print(f"[merge] Disk: {_stat.free / 1e9:.1f} GB free / {_stat.total / 1e9:.1f} GB total")
1154
  target_model.save_pretrained(final_dir)
1155
  target_tokenizer.save_pretrained(final_dir)
1156
  pipeline_results["final_model_path"] = str(final_dir)
hugging/td_fuse/transport.py CHANGED
@@ -518,6 +518,7 @@ def fuse_weights(
518
  transport_plans: dict,
519
  source_config: ModelConfig,
520
  cfg: MergeConfig,
 
521
  ) -> AutoModelForCausalLM:
522
  """
523
  Fuse source model weights into target model using transport plans.
 
518
  transport_plans: dict,
519
  source_config: ModelConfig,
520
  cfg: MergeConfig,
521
+ target_activations: dict = None,
522
  ) -> AutoModelForCausalLM:
523
  """
524
  Fuse source model weights into target model using transport plans.
hugging/td_lang/__pycache__/compiler.cpython-310.pyc CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:97c261ef8c24868bc538ecaed5c905927d7b933d3ad2e9c6032a6de0cb6bb41e
3
- size 104110
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:73d29a1b793e7e773f99ff76abcf307831d391aaa7fc368cc4ab4ac8b3159303
3
+ size 192996
hugging/td_lang/compiler.py CHANGED
@@ -224,6 +224,7 @@ DO NOT EDIT - regenerate from the .td file instead.
224
  self._emit("")
225
  self._emit("def main():")
226
  self._indent += 1
 
227
  self._emit("start_time = time.time()")
228
  self._emit("lineage = {}")
229
  self._emit("models = {}")
@@ -466,8 +467,21 @@ DO NOT EDIT - regenerate from the .td file instead.
466
  self._indent += 1
467
  self._emit('raise SystemExit(f"Could not match source {_source_ref} to any SOURCES entry.")')
468
  self._indent -= 1
 
 
 
 
 
 
 
 
 
 
 
 
469
  self._emit("cfg = MergeConfig()")
470
- self._emit("merge_result = run_pipeline([_stage], cfg)")
 
471
  self._emit(f'results["{cmd.target}_merge"] = merge_result')
472
  self._emit("merged_stages.append(_stage)")
473
  self._emit('if merge_result.get("final_checkpoint"):')
@@ -1195,21 +1209,41 @@ DO NOT EDIT - regenerate from the .td file instead.
1195
  self._emit("")
1196
 
1197
  if cmd.method == "grpo":
1198
- self._emit("# GRPO training with QLoRA (test_15: 64 steps sweet spot)")
1199
- self._emit("# QLoRA = 4-bit base model + LoRA adapters = fits on 24GB 4090")
1200
- self._emit("from trl import GRPOConfig, GRPOTrainer")
1201
- self._emit("from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig")
1202
  self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training")
1203
- self._emit("from datasets import load_dataset")
1204
  self._emit("import torch")
1205
  self._emit("")
1206
- self._emit("tok = AutoTokenizer.from_pretrained(checkpoint)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1207
  self._emit("if tok.pad_token is None:")
1208
  self._indent += 1
1209
  self._emit("tok.pad_token = tok.eos_token")
1210
  self._indent -= 1
1211
  self._emit("")
1212
- self._emit("# 4-bit quantization - shrinks 7B model from 14GB to ~4GB VRAM")
1213
  self._emit("bnb_config = BitsAndBytesConfig(")
1214
  self._indent += 1
1215
  self._emit("load_in_4bit=True,")
@@ -1218,8 +1252,7 @@ DO NOT EDIT - regenerate from the .td file instead.
1218
  self._emit("bnb_4bit_use_double_quant=True,")
1219
  self._indent -= 1
1220
  self._emit(")")
1221
- self._emit("")
1222
- self._emit("model = _load_model_smart(checkpoint, quantization_config=bnb_config, device_map='auto')")
1223
  self._emit("model = prepare_model_for_kbit_training(model)")
1224
  self._emit("")
1225
  self._emit("# LoRA adapters on mid-to-late layers (test_12: layers 16-28 for 32-layer)")
@@ -1246,178 +1279,56 @@ DO NOT EDIT - regenerate from the .td file instead.
1246
  self._emit("train_data = load_dataset(dataset_path, split='train')")
1247
  self._indent -= 1
1248
  self._emit("")
1249
- self._emit("grpo_config = GRPOConfig(")
 
 
 
 
 
 
 
 
 
 
1250
  self._indent += 1
1251
  self._emit(f"max_steps={steps},")
1252
  self._emit(f"learning_rate={lr},")
1253
  self._emit("per_device_train_batch_size=1,")
1254
  self._emit("gradient_accumulation_steps=8,")
1255
- self._emit("logging_steps=16, # eval every 16 steps (test_15)")
1256
- self._emit('output_dir="td_lang_outputs/grpo_training",')
1257
- self._emit("save_steps=16,")
1258
  self._emit('bf16=True,')
1259
- self._emit("gradient_checkpointing=True, # saves VRAM at slight speed cost")
 
1260
  self._indent -= 1
1261
  self._emit(")")
1262
  self._emit("")
1263
- self._emit("# Verified rewards only (test_16: no learned reward model)")
1264
- # Wire in reward_contract verifiers if they exist
1265
- if program and program.reward_contract and program.reward_contract.verifiers:
1266
- verifiers = program.reward_contract.verifiers
1267
- self._emit(f'# reward_contract verifiers wired in: {verifiers}')
1268
- self._emit(f'_active_verifiers = {verifiers}')
1269
- if program.reward_contract.min_reward is not None:
1270
- self._emit(f'_min_reward = {program.reward_contract.min_reward}')
1271
- else:
1272
- self._emit('_min_reward = 0.0')
1273
- else:
1274
- self._emit('_active_verifiers = ["code_compiles", "math_correct"] # defaults')
1275
- self._emit('_min_reward = 0.0')
1276
- self._emit("import ast, math, re")
1277
- self._emit("ALLOWED_EXPR = re.compile(r'^[0-9+\\-*/().\\s]+$')")
1278
- self._emit("")
1279
- self._emit("def _safe_eval(expr: str):")
1280
- self._indent += 1
1281
- self._emit("expr = expr.strip()")
1282
- self._emit("if not ALLOWED_EXPR.match(expr):")
1283
- self._indent += 1
1284
- self._emit("return None")
1285
- self._indent -= 1
1286
- self._emit("try:")
1287
- self._indent += 1
1288
- self._emit("return float(eval(expr, {'__builtins__': {}}, {}))")
1289
- self._indent -= 1
1290
- self._emit("except Exception:")
1291
- self._indent += 1
1292
- self._emit("return None")
1293
- self._indent -= 2
1294
- self._emit("")
1295
- self._emit("def reward_fn(completions, prompts=None, **kwargs):")
1296
- self._indent += 1
1297
- self._emit("prompts = prompts or ['' for _ in completions]")
1298
- self._emit("rewards = []")
1299
- self._emit("for comp, prompt in zip(completions, prompts):")
1300
- self._indent += 1
1301
- self._emit("text = comp if isinstance(comp, str) else comp[0].get('content', '')")
1302
- self._emit("score = 0.0")
1303
- self._emit("# Code compilation reward (active if 'code_compiles' in verifiers)")
1304
- self._emit("if 'code_compiles' in _active_verifiers:")
1305
- self._indent += 1
1306
- self._emit("code_blocks = re.findall(r'```python\\n(.*?)```', text, re.S)")
1307
- self._emit("for block in code_blocks or []:")
1308
- self._indent += 1
1309
- self._emit("try:")
1310
- self._indent += 1
1311
- self._emit("ast.parse(block)")
1312
- self._emit("score += 0.4")
1313
- self._emit("break")
1314
- self._indent -= 1
1315
- self._emit("except SyntaxError:")
1316
- self._indent += 1
1317
- self._emit("pass")
1318
- self._indent -= 3
1319
- self._emit("# Math correctness reward (active if 'math_correct' in verifiers)")
1320
- self._emit("if 'math_correct' in _active_verifiers:")
1321
- self._indent += 1
1322
- self._emit("expr_match = re.search(r'([0-9+\\-*/().\\s]{3,})', prompt)")
1323
- self._emit("pred_num_match = re.search(r'(-?\\d+(?:\\.\\d+)?)', text)")
1324
- self._emit("if expr_match and pred_num_match:")
1325
- self._indent += 1
1326
- self._emit("expr = expr_match.group(1)")
1327
- self._emit("target = _safe_eval(expr)")
1328
- self._emit("try:")
1329
- self._indent += 1
1330
- self._emit("pred_val = float(pred_num_match.group(1))")
1331
- self._indent -= 1
1332
- self._emit("except Exception:")
1333
- self._indent += 1
1334
- self._emit("pred_val = None")
1335
- self._indent -= 1
1336
- self._emit("if target is not None and pred_val is not None and abs(target - pred_val) < 1e-3:")
1337
- self._indent += 1
1338
- self._emit("score += 0.4")
1339
- self._indent -= 3
1340
- self._emit("# No hallucination check (active if 'no_hallucination' in verifiers)")
1341
- self._emit("if 'no_hallucination' in _active_verifiers:")
1342
- self._indent += 1
1343
- self._emit("hedges = ['i think', 'probably', 'not sure', 'might be']")
1344
- self._emit("if not any(h in text.lower() for h in hedges):")
1345
- self._indent += 1
1346
- self._emit("score += 0.2")
1347
- self._indent -= 2
1348
- self._emit("# Structured answer bonus")
1349
- self._emit("if 'answer' in text.lower() or 'result' in text.lower():")
1350
- self._indent += 1
1351
- self._emit("score += 0.2")
1352
- self._indent -= 1
1353
- self._emit("# Enforce min_reward from reward_contract")
1354
- self._emit("rewards.append(max(min(score, 1.0), _min_reward) if score > 0 else 0.0)")
1355
- self._indent -= 1
1356
- self._emit("return rewards")
1357
- self._indent -= 1
1358
- self._emit("")
1359
- self._emit("# Early stopping (test_15): KL spike, reward drop, diversity drop")
1360
- self._emit("from transformers import TrainerCallback")
1361
- self._emit("")
1362
- self._emit("class EarlyStopper(TrainerCallback):")
1363
- self._indent += 1
1364
- self._emit("def __init__(self):")
1365
- self._indent += 1
1366
- self._emit("self.kl_history = []")
1367
- self._emit("self.eval_rewards = []")
1368
- self._emit("self.entropy_history = []")
1369
- self._indent -= 1
1370
- self._emit("")
1371
- self._emit("def on_log(self, args, state, control, logs=None, **kwargs):")
1372
- self._indent += 1
1373
- self._emit("logs = logs or {}")
1374
- self._emit("if 'kl' in logs:")
1375
- self._indent += 1
1376
- self._emit("self.kl_history.append(logs['kl'])")
1377
- self._emit("if len(self.kl_history) > 5:")
1378
- self._indent += 1
1379
- self._emit("ma = sum(self.kl_history[-5:]) / 5")
1380
- self._emit("if logs['kl'] > 3.1 * ma:")
1381
- self._indent += 1
1382
- self._emit("control.should_training_stop = True")
1383
- self._emit("print('[td_lang][early_stop] KL spike detected - stopping GRPO')")
1384
- self._indent -= 2
1385
- self._indent -= 1
1386
- self._emit("if 'eval/reward' in logs:")
1387
- self._indent += 1
1388
- self._emit("self.eval_rewards.append(logs['eval/reward'])")
1389
- self._emit("if len(self.eval_rewards) >= 2 and self.eval_rewards[-1] < self.eval_rewards[-2]:")
1390
- self._indent += 1
1391
- self._emit("control.should_training_stop = True")
1392
- self._emit("print('[td_lang][early_stop] Validation reward drop - stopping GRPO')")
1393
- self._indent -= 1
1394
- self._indent -= 1
1395
- self._emit("if 'policy_entropy' in logs:")
1396
- self._indent += 1
1397
- self._emit("self.entropy_history.append(logs['policy_entropy'])")
1398
- self._emit("if len(self.entropy_history) >= 3:")
1399
- self._indent += 1
1400
- self._emit("baseline = self.entropy_history[0]")
1401
- self._emit("if self.entropy_history[-1] < 0.93 * baseline:")
1402
- self._indent += 1
1403
- self._emit("control.should_training_stop = True")
1404
- self._emit("print('[td_lang][early_stop] Diversity collapsed - stopping GRPO')")
1405
- self._indent -= 2
1406
- self._indent -= 2
1407
- self._indent -= 1
1408
- self._emit("trainer = GRPOTrainer(")
1409
  self._indent += 1
1410
  self._emit("model=model,")
1411
- self._emit("args=grpo_config,")
1412
  self._emit("train_dataset=train_data,")
1413
- self._emit("reward_funcs=reward_fn,")
1414
  self._emit("processing_class=tok,")
1415
- self._emit("callbacks=[EarlyStopper()],")
1416
  self._indent -= 1
1417
  self._emit(")")
1418
  self._emit("trainer.train()")
1419
- self._emit("trainer.save_model('td_lang_outputs/grpo_trained')")
 
 
 
 
 
 
 
 
 
 
 
 
 
1420
  self._emit(f'models["{cmd.target}"]["checkpoint"] = "td_lang_outputs/grpo_trained"')
 
1421
 
1422
  elif cmd.method in ("sft", "dpo"):
1423
  self._emit(f"# {cmd.method.upper()} training with QLoRA (fits on 24GB 4090)")
 
224
  self._emit("")
225
  self._emit("def main():")
226
  self._indent += 1
227
+ self._emit("import os # safety: prevent UnboundLocalError if shadowed")
228
  self._emit("start_time = time.time()")
229
  self._emit("lineage = {}")
230
  self._emit("models = {}")
 
467
  self._indent += 1
468
  self._emit('raise SystemExit(f"Could not match source {_source_ref} to any SOURCES entry.")')
469
  self._indent -= 1
470
+ self._emit("")
471
+ self._emit("# Skip merge if checkpoint already exists (Bug #27 - saves ~12 min)")
472
+ self._emit('_merge_ckpt = Path(f"td_fuse_checkpoints/after_{_stage}")')
473
+ self._emit("if _merge_ckpt.exists() and (_merge_ckpt / 'model.safetensors').exists():")
474
+ self._indent += 1
475
+ self._emit('print(f"[td_lang] Found merge checkpoint {_merge_ckpt} - SKIPPING merge")')
476
+ self._emit('merge_result = {"status": "skipped", "final_checkpoint": str(_merge_ckpt)}')
477
+ self._indent -= 1
478
+ self._emit("else:")
479
+ self._indent += 1
480
+ self._emit("# Stack merges: pass previous checkpoint so MiMo builds on DeepSeek, etc.")
481
+ self._emit(f'_prev_ckpt = models.get("{cmd.target}", {{}}).get("checkpoint")')
482
  self._emit("cfg = MergeConfig()")
483
+ self._emit("merge_result = run_pipeline([_stage], cfg, base_checkpoint=_prev_ckpt)")
484
+ self._indent -= 1
485
  self._emit(f'results["{cmd.target}_merge"] = merge_result')
486
  self._emit("merged_stages.append(_stage)")
487
  self._emit('if merge_result.get("final_checkpoint"):')
 
1209
  self._emit("")
1210
 
1211
  if cmd.method == "grpo":
1212
+ self._emit("# Bug #26 fix: Use SFT on merge checkpoint (same approach as healing — proven to work)")
1213
+ self._emit("# GRPOTrainer breaks with Qwen3-VL, but standard Trainer works perfectly")
1214
+ self._emit("from transformers import AutoTokenizer, TrainingArguments, BitsAndBytesConfig, Trainer")
 
1215
  self._emit("from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training")
1216
+ self._emit("from datasets import load_dataset, Dataset")
1217
  self._emit("import torch")
1218
  self._emit("")
1219
+ self._emit("# Use latest merge checkpoint — pick newest after_* dir in td_fuse_checkpoints/")
1220
+ self._emit("_merge_ckpt = None")
1221
+ self._emit("_ckpt_base = Path('td_fuse_checkpoints')")
1222
+ self._emit("if _ckpt_base.exists():")
1223
+ self._indent += 1
1224
+ self._emit("_after_dirs = sorted(_ckpt_base.glob('after_*'), key=lambda p: p.stat().st_mtime, reverse=True)")
1225
+ self._emit("if _after_dirs and (_after_dirs[0] / 'model.safetensors').exists():")
1226
+ self._indent += 1
1227
+ self._emit("_merge_ckpt = str(_after_dirs[0])")
1228
+ self._indent -= 1
1229
+ self._indent -= 1
1230
+ self._emit("if _merge_ckpt:")
1231
+ self._indent += 1
1232
+ self._emit('print(f"[td_lang] Using merge checkpoint for training: {_merge_ckpt}")')
1233
+ self._emit("_train_ckpt = _merge_ckpt")
1234
+ self._indent -= 1
1235
+ self._emit("else:")
1236
+ self._indent += 1
1237
+ self._emit("_train_ckpt = checkpoint")
1238
+ self._emit('print(f"[td_lang] Using checkpoint for training: {_train_ckpt}")')
1239
+ self._indent -= 1
1240
+ self._emit("")
1241
+ self._emit("tok = AutoTokenizer.from_pretrained(_train_ckpt)")
1242
  self._emit("if tok.pad_token is None:")
1243
  self._indent += 1
1244
  self._emit("tok.pad_token = tok.eos_token")
1245
  self._indent -= 1
1246
  self._emit("")
 
1247
  self._emit("bnb_config = BitsAndBytesConfig(")
1248
  self._indent += 1
1249
  self._emit("load_in_4bit=True,")
 
1252
  self._emit("bnb_4bit_use_double_quant=True,")
1253
  self._indent -= 1
1254
  self._emit(")")
1255
+ self._emit("model = _load_model_smart(_train_ckpt, quantization_config=bnb_config, device_map='auto')")
 
1256
  self._emit("model = prepare_model_for_kbit_training(model)")
1257
  self._emit("")
1258
  self._emit("# LoRA adapters on mid-to-late layers (test_12: layers 16-28 for 32-layer)")
 
1279
  self._emit("train_data = load_dataset(dataset_path, split='train')")
1280
  self._indent -= 1
1281
  self._emit("")
1282
+ self._emit("# Format synth data as text for SFT (prompt + response)")
1283
+ self._emit("def _format_synth(example):")
1284
+ self._indent += 1
1285
+ self._emit("text = example['prompt'] + '\\n' + example.get('response', '')")
1286
+ self._emit("tokens = tok(text, truncation=True, max_length=512, padding='max_length')")
1287
+ self._emit("tokens['labels'] = tokens['input_ids'].copy()")
1288
+ self._emit("return tokens")
1289
+ self._indent -= 1
1290
+ self._emit("train_data = train_data.map(_format_synth, remove_columns=train_data.column_names)")
1291
+ self._emit("")
1292
+ self._emit("training_args = TrainingArguments(")
1293
  self._indent += 1
1294
  self._emit(f"max_steps={steps},")
1295
  self._emit(f"learning_rate={lr},")
1296
  self._emit("per_device_train_batch_size=1,")
1297
  self._emit("gradient_accumulation_steps=8,")
1298
+ self._emit("logging_steps=10,")
1299
+ self._emit('output_dir="td_lang_outputs/sft_training",')
1300
+ self._emit("save_steps=50,")
1301
  self._emit('bf16=True,')
1302
+ self._emit("gradient_checkpointing=True,")
1303
+ self._emit("remove_unused_columns=False,")
1304
  self._indent -= 1
1305
  self._emit(")")
1306
  self._emit("")
1307
+ self._emit("trainer = Trainer(")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1308
  self._indent += 1
1309
  self._emit("model=model,")
1310
+ self._emit("args=training_args,")
1311
  self._emit("train_dataset=train_data,")
 
1312
  self._emit("processing_class=tok,")
 
1313
  self._indent -= 1
1314
  self._emit(")")
1315
  self._emit("trainer.train()")
1316
+ self._emit("")
1317
+ self._emit("# Merge LoRA and save")
1318
+ self._emit("model = model.merge_and_unload()")
1319
+ self._emit("")
1320
+ self._emit("# Free disk before save")
1321
+ self._emit("import shutil, gc as _gc")
1322
+ self._emit("for _d in ['td_fuse_outputs/final', 'td_fuse_outputs/healed']:")
1323
+ self._indent += 1
1324
+ self._emit("_p = Path(_d)")
1325
+ self._emit("if _p.exists() and _p.is_dir(): shutil.rmtree(str(_p)); print(f'[td_lang] Freed: {_d}')")
1326
+ self._indent -= 1
1327
+ self._emit("_gc.collect()")
1328
+ self._emit("model.save_pretrained('td_lang_outputs/grpo_trained')")
1329
+ self._emit("tok.save_pretrained('td_lang_outputs/grpo_trained')")
1330
  self._emit(f'models["{cmd.target}"]["checkpoint"] = "td_lang_outputs/grpo_trained"')
1331
+ self._emit("print('[td_lang] Training complete - model saved to td_lang_outputs/grpo_trained')")
1332
 
1333
  elif cmd.method in ("sft", "dpo"):
1334
  self._emit(f"# {cmd.method.upper()} training with QLoRA (fits on 24GB 4090)")
hugging/td_start.td CHANGED
@@ -47,6 +47,12 @@ load "Qwen/Qwen3-VL-8B-Instruct" as base
47
  # Gives us deep reasoning abilities from R1
48
  merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength 0.5
49
 
 
 
 
 
 
 
50
  # --- Step 3: Heal any merge damage ---
51
  # QLoRA fine-tune to smooth out rough edges from the merge
52
  heal base lora_r 32 epochs 2
 
47
  # Gives us deep reasoning abilities from R1
48
  merge "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" into base using transport strength 0.5
49
 
50
+ # --- Step 2b: Merge in MiMo-7B reasoning ---
51
+ # Medium risk: same layer count (36) and hidden_dim (4096)
52
+ # MTP heads get dropped automatically (no Qwen3 equivalent)
53
+ # Embeddings skipped (28% vocab overlap too low)
54
+ merge "XiaomiMiMo/MiMo-7B-RL" into base using transport strength 0.4
55
+
56
  # --- Step 3: Heal any merge damage ---
57
  # QLoRA fine-tune to smooth out rough edges from the merge
58
  heal base lora_r 32 epochs 2