td-builder commited on
Commit
bc446a5
·
verified ·
1 Parent(s): 52a6e10

Current td_fuse code with all fixes

Browse files
td_fuse/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.28 kB). View file
 
td_fuse/__pycache__/canary.cpython-312.pyc ADDED
Binary file (8.27 kB). View file
 
td_fuse/__pycache__/config.cpython-312.pyc ADDED
Binary file (8.59 kB). View file
 
td_fuse/__pycache__/heal.cpython-312.pyc ADDED
Binary file (16.2 kB). View file
 
td_fuse/__pycache__/merge.cpython-312.pyc ADDED
Binary file (58.7 kB). View file
 
td_fuse/__pycache__/techniques.cpython-312.pyc ADDED
Binary file (25.1 kB). View file
 
td_fuse/__pycache__/transport.cpython-312.pyc ADDED
Binary file (45.4 kB). View file
 
td_fuse/__pycache__/validate.cpython-312.pyc ADDED
Binary file (11.8 kB). View file
 
td_fuse/config.py CHANGED
@@ -129,7 +129,7 @@ SOURCES = [
129
  skip_embeddings=True, # Must skip — vocab too different
130
  trust_remote_code=False,
131
  merge_risk="medium",
132
- merge_alpha=0.35, # Lower alpha — layer mismatch risk
133
  special_handling=["skip_embeddings", "drop_qkv_bias", "layer_mapping_32_to_36"],
134
  notes=(
135
  "32 layers vs 36 — T&M's P matrix handles layer mapping. "
@@ -152,7 +152,7 @@ SOURCES = [
152
  skip_embeddings=True, # Must skip — vocab too different
153
  trust_remote_code=True, # Likely custom hybrid code
154
  merge_risk="high",
155
- merge_alpha=0.3, # Conservative — highest risk model
156
  special_handling=[
157
  "skip_embeddings",
158
  "drop_mamba_state_params", # A, D matrices have no Qwen3 equivalent
 
129
  skip_embeddings=True, # Must skip — vocab too different
130
  trust_remote_code=False,
131
  merge_risk="medium",
132
+ merge_alpha=0.08, # Lower alpha — layer mismatch risk
133
  special_handling=["skip_embeddings", "drop_qkv_bias", "layer_mapping_32_to_36"],
134
  notes=(
135
  "32 layers vs 36 — T&M's P matrix handles layer mapping. "
 
152
  skip_embeddings=True, # Must skip — vocab too different
153
  trust_remote_code=True, # Likely custom hybrid code
154
  merge_risk="high",
155
+ merge_alpha=0.08, # Conservative — highest risk model
156
  special_handling=[
157
  "skip_embeddings",
158
  "drop_mamba_state_params", # A, D matrices have no Qwen3 equivalent
td_fuse/heal.py CHANGED
@@ -69,11 +69,11 @@ def load_healing_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
69
  # Each entry: (dataset_id, config_name_or_None, split, count, text_field)
70
  datasets_to_load = [
71
  # General language — same calibration data source that works reliably
72
- ("neuralmagic/LLM_compression_calibration", None, "train", 500, "text"),
73
  # Math reasoning (exercises DeepSeek/MiMo contributions)
74
- ("openai/gsm8k", "main", "train", 300, "question"),
75
  # Code — bigcode/starcoderdata is a modern alternative
76
- ("bigcode/starcoderdata", "python", "train", 200, "content"),
77
  ]
78
 
79
  all_texts = []
@@ -193,7 +193,9 @@ def apply_qlora_unsloth(
193
  learning_rate=cfg.heal_learning_rate,
194
  bf16=True,
195
  logging_steps=10,
196
- save_strategy="no", max_steps=50, # Don't save intermediate checkpoints — saves ~17GB disk
 
 
197
  warmup_ratio=0.05,
198
  lr_scheduler_type="cosine",
199
  optim="adamw_8bit", # Memory-efficient optimiser
@@ -249,24 +251,15 @@ def apply_qlora_standard(
249
  return 'td_fuse_outputs/healed'
250
  import torch
251
  from peft import LoraConfig, get_peft_model, TaskType
252
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
253
 
254
  print("\n[heal] Loading model with standard PEFT...")
255
 
256
- # 4-bit quantisation config
257
- bnb_config = BitsAndBytesConfig(
258
- load_in_4bit=True,
259
- bnb_4bit_quant_type="nf4",
260
- bnb_4bit_compute_dtype=getattr(torch, cfg.dtype),
261
- bnb_4bit_use_double_quant=True,
262
- )
263
-
264
  tokenizer = AutoTokenizer.from_pretrained(model_path)
265
  model = _load_model_smart(
266
  model_path,
267
- quantization_config=bnb_config,
268
  device_map="auto",
269
- torch_dtype=getattr(torch, cfg.dtype),
270
  )
271
 
272
  # LoRA config
@@ -328,7 +321,9 @@ def apply_qlora_standard(
328
  learning_rate=cfg.heal_learning_rate,
329
  bf16=True,
330
  logging_steps=10,
331
- save_strategy="no", max_steps=50, # Don't save intermediate checkpoints — saves ~17GB disk
 
 
332
  warmup_ratio=0.05,
333
  lr_scheduler_type="cosine",
334
  optim="adamw_torch",
@@ -365,36 +360,10 @@ def apply_qlora_standard(
365
 
366
  gc.collect()
367
 
368
- # SAVE FIRSTnever delete anything until save is confirmed
369
- # save_pretrained can fail on 4-bit merged models (NotImplementedError)
370
- # So we go straight to the safe manual method
371
- print(f"[heal] Saving healed model to {healed_dir}...")
372
- try:
373
- from safetensors.torch import save_file
374
- import torch as _torch
375
- # Fixed: use named_parameters for proper dequantization
376
- clean_state = {}
377
- for k, v in merged_model.named_parameters():
378
- if hasattr(v, 'dequantize'):
379
- clean_state[k] = v.dequantize().to(_torch.bfloat16)
380
- elif v.data.dtype in (_torch.float32, _torch.float16, _torch.bfloat16):
381
- clean_state[k] = v.data.to(_torch.bfloat16)
382
- else:
383
- clean_state[k] = v.data.float().to(_torch.bfloat16)
384
- save_file(clean_state, str(healed_dir / "model.safetensors"))
385
- if hasattr(merged_model, 'config'):
386
- if hasattr(merged_model.config, "quantization_config"):
387
- merged_model.config.quantization_config = None
388
- print("[heal] Removed quantization_config from saved config (weights are bf16 now)")
389
- merged_model.config.save_pretrained(str(healed_dir))
390
- tokenizer.save_pretrained(str(healed_dir))
391
- print(f"[heal] SAVED OK: {healed_dir / 'model.safetensors'}")
392
- except Exception as e:
393
- # Emergency fallback: try save_pretrained as last resort
394
- print(f"[heal] Manual save failed ({e}), trying save_pretrained...")
395
- merged_model.save_pretrained(str(healed_dir))
396
- tokenizer.save_pretrained(str(healed_dir))
397
- print(f"[heal] SAVED OK via save_pretrained: {healed_dir}")
398
 
399
  # Verify the save actually worked before cleaning up ANYTHING
400
  saved_model = healed_dir / "model.safetensors"
 
69
  # Each entry: (dataset_id, config_name_or_None, split, count, text_field)
70
  datasets_to_load = [
71
  # General language — same calibration data source that works reliably
72
+ ("neuralmagic/LLM_compression_calibration", None, "train", 1500, "text"),
73
  # Math reasoning (exercises DeepSeek/MiMo contributions)
74
+ ("openai/gsm8k", "main", "train", 1000, "question"),
75
  # Code — bigcode/starcoderdata is a modern alternative
76
+ ("sahil2801/CodeAlpaca-20k", None, "train", 500, "output"),
77
  ]
78
 
79
  all_texts = []
 
193
  learning_rate=cfg.heal_learning_rate,
194
  bf16=True,
195
  logging_steps=10,
196
+ save_strategy="steps",
197
+ save_steps=50,
198
+ save_total_limit=2, max_steps=50, # Don't save intermediate checkpoints — saves ~17GB disk
199
  warmup_ratio=0.05,
200
  lr_scheduler_type="cosine",
201
  optim="adamw_8bit", # Memory-efficient optimiser
 
251
  return 'td_fuse_outputs/healed'
252
  import torch
253
  from peft import LoraConfig, get_peft_model, TaskType
254
+ from transformers import AutoModelForCausalLM, AutoTokenizer
255
 
256
  print("\n[heal] Loading model with standard PEFT...")
257
 
 
 
 
 
 
 
 
 
258
  tokenizer = AutoTokenizer.from_pretrained(model_path)
259
  model = _load_model_smart(
260
  model_path,
 
261
  device_map="auto",
262
+ torch_dtype=torch.bfloat16,
263
  )
264
 
265
  # LoRA config
 
321
  learning_rate=cfg.heal_learning_rate,
322
  bf16=True,
323
  logging_steps=10,
324
+ save_strategy="steps",
325
+ save_steps=50,
326
+ save_total_limit=2, max_steps=50, # Don't save intermediate checkpoints — saves ~17GB disk
327
  warmup_ratio=0.05,
328
  lr_scheduler_type="cosine",
329
  optim="adamw_torch",
 
360
 
361
  gc.collect()
362
 
363
+ # bf16 modelsave_pretrained works correctly, no dequantize needed
364
+ merged_model.save_pretrained(str(healed_dir), safe_serialization=True)
365
+ tokenizer.save_pretrained(str(healed_dir))
366
+ print(f"[heal] SAVED OK: {healed_dir}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
 
368
  # Verify the save actually worked before cleaning up ANYTHING
369
  saved_model = healed_dir / "model.safetensors"
td_fuse/merge.py CHANGED
@@ -726,11 +726,11 @@ def run_single_merge(
726
  print(f"\n[merge] Step 4/10: Extracting activations (both models)..."); sys.stdout.flush()
727
  step_t = time.time()
728
  # Check if source model has a different vocabulary size than target.
729
- source_vocab_size = source_model.config.vocab_size if hasattr(source_model.config, 'vocab_size') else None
730
- target_vocab_size = target_model.config.vocab_size if hasattr(target_model.config, 'vocab_size') else None
731
  print(f"[merge] Vocab sizes -- target: {target_vocab_size}, source: {source_vocab_size}")
732
 
733
- if source_vocab_size and target_vocab_size and source_vocab_size != target_vocab_size:
734
  print(f"[merge] VOCAB MISMATCH detected! Re-tokenizing calibration data for {source_config.name}...")
735
  source_calibration = retokenize_calibration(calibration_raw_texts, source_tokenizer, cfg)
736
  print(f"[merge] Extracting source activations (with source-tokenized data)...")
 
726
  print(f"\n[merge] Step 4/10: Extracting activations (both models)..."); sys.stdout.flush()
727
  step_t = time.time()
728
  # Check if source model has a different vocabulary size than target.
729
+ source_vocab_size = len(source_tokenizer)
730
+ target_vocab_size = len(target_tokenizer)
731
  print(f"[merge] Vocab sizes -- target: {target_vocab_size}, source: {source_vocab_size}")
732
 
733
+ if source_vocab_size != target_vocab_size:
734
  print(f"[merge] VOCAB MISMATCH detected! Re-tokenizing calibration data for {source_config.name}...")
735
  source_calibration = retokenize_calibration(calibration_raw_texts, source_tokenizer, cfg)
736
  print(f"[merge] Extracting source activations (with source-tokenized data)...")
td_fuse/transport.py CHANGED
@@ -520,7 +520,7 @@ def _compute_plans_fallback(
520
  sys.stdout.flush()
521
 
522
  # Timeout: 90 min (Sinkhorn on 4096x4096 is slow on CPU)
523
- tracker.check_timeout(timeout_seconds=5400)
524
 
525
  if permutations:
526
  print(f"[transport] Computed {len(permutations)} neuron permutations")
@@ -569,8 +569,8 @@ def _compute_plans_fallback(
569
  layer_costs[i, j] = 1.0 - sim
570
  tracker.tick(f"layer sim {i},{j}")
571
 
572
- # Timeout: 30 min for cross-arch
573
- tracker.check_timeout(timeout_seconds=1800)
574
 
575
  print(f"[transport] Step 1/3 done: {n_source}x{n_target} similarities computed")
576
  sys.stdout.flush()
@@ -579,10 +579,24 @@ def _compute_plans_fallback(
579
  print("[transport] Step 2/3: Computing neuron-level transport (top-3 per target)...")
580
  sys.stdout.flush()
581
  Q_matrices = {}
 
 
 
 
 
582
  for j, tl in enumerate(target_layers):
583
  top3 = np.argsort(layer_costs[:, j])[:3]
584
  for i in top3:
585
  sl = source_layers[i]
 
 
 
 
 
 
 
 
 
586
  S = source_act[sl].numpy()
587
  T = target_act[tl].numpy()
588
 
@@ -595,14 +609,15 @@ def _compute_plans_fallback(
595
  corr = S_norm.T @ T_norm / S.shape[0]
596
  cost = 1.0 - corr
597
  Q_matrices[(sl, tl)] = _sinkhorn(cost, reg=0.1, max_iter=50)
 
598
  tracker.tick(f"Q({sl},{tl})")
599
 
600
  if (j + 1) % 5 == 0 or j == 0:
601
  print(f" Target layer {j + 1}/{n_target}: matched to top-3 sources")
602
  sys.stdout.flush()
603
 
604
- # Timeout: 30 min for cross-arch
605
- tracker.check_timeout(timeout_seconds=1800)
606
 
607
  print(f"[transport] Step 2/3 done: {len(Q_matrices)} Q matrices computed")
608
  sys.stdout.flush()
 
520
  sys.stdout.flush()
521
 
522
  # Timeout: 90 min (Sinkhorn on 4096x4096 is slow on CPU)
523
+ tracker.check_timeout(timeout_seconds=10800)
524
 
525
  if permutations:
526
  print(f"[transport] Computed {len(permutations)} neuron permutations")
 
569
  layer_costs[i, j] = 1.0 - sim
570
  tracker.tick(f"layer sim {i},{j}")
571
 
572
+ # Timeout: 90 min for cross-arch
573
+ tracker.check_timeout(timeout_seconds=10800)
574
 
575
  print(f"[transport] Step 1/3 done: {n_source}x{n_target} similarities computed")
576
  sys.stdout.flush()
 
579
  print("[transport] Step 2/3: Computing neuron-level transport (top-3 per target)...")
580
  sys.stdout.flush()
581
  Q_matrices = {}
582
+
583
+ # Incremental cache: save each Q as we go so crashes don't lose progress
584
+ q_cache_dir = Path("td_fuse_checkpoints") / "q_cache_crossarch"
585
+ q_cache_dir.mkdir(parents=True, exist_ok=True)
586
+
587
  for j, tl in enumerate(target_layers):
588
  top3 = np.argsort(layer_costs[:, j])[:3]
589
  for i in top3:
590
  sl = source_layers[i]
591
+ cache_key = f"{sl}__{tl}".replace("/", "_").replace(".", "_")
592
+ cache_path = q_cache_dir / f"{cache_key}.npy"
593
+
594
+ # Skip if already computed in a previous run
595
+ if cache_path.exists():
596
+ Q_matrices[(sl, tl)] = np.load(str(cache_path))
597
+ tracker.tick(f"Q({sl},{tl})")
598
+ continue
599
+
600
  S = source_act[sl].numpy()
601
  T = target_act[tl].numpy()
602
 
 
609
  corr = S_norm.T @ T_norm / S.shape[0]
610
  cost = 1.0 - corr
611
  Q_matrices[(sl, tl)] = _sinkhorn(cost, reg=0.1, max_iter=50)
612
+ np.save(str(cache_path), Q_matrices[(sl, tl)])
613
  tracker.tick(f"Q({sl},{tl})")
614
 
615
  if (j + 1) % 5 == 0 or j == 0:
616
  print(f" Target layer {j + 1}/{n_target}: matched to top-3 sources")
617
  sys.stdout.flush()
618
 
619
+ # Timeout: 90 min for cross-arch (was 30, too short for 72 layers)
620
+ tracker.check_timeout(timeout_seconds=10800)
621
 
622
  print(f"[transport] Step 2/3 done: {len(Q_matrices)} Q matrices computed")
623
  sys.stdout.flush()