td-builder committed on
Commit
1e0b51b
·
verified ·
1 Parent(s): 1e0fbdd

Upload 137 files

Browse files
hugging/td_fuse/merge.py CHANGED
@@ -134,7 +134,7 @@ class MergeProtection:
134
  4. MagMax — protect top magnitude params (extra safety layer)
135
  """
136
  fused = target_state[key]
137
- original = pre_merge_state[key]
138
  delta = fused - original
139
 
140
  # --- ARM Steering (new, replaces orthogonal projection) ---
@@ -203,7 +203,7 @@ class MergeProtection:
203
 
204
  for key in current_state:
205
  if key in pre_merge_state:
206
- delta = current_state[key].float() - pre_merge_state[key].float()
207
  if delta.abs().max() > 1e-8:
208
  if key not in self.previous_deltas:
209
  self.previous_deltas[key] = []
@@ -268,9 +268,65 @@ def get_source_by_stage(stage_name: str) -> Optional[ModelConfig]:
268
  return None
269
 
270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  def load_model(config: ModelConfig, cfg: MergeConfig) -> tuple:
272
  """Load a model and its tokenizer/processor."""
273
- print(f"\n[merge] Loading {config.name} ({config.hf_id})...")
 
 
 
 
274
 
275
  # Qwen3-VL uses a processor (handles both text + vision), not just a tokenizer
276
  if config.architecture == "transformer+vision":
@@ -298,6 +354,7 @@ def load_model(config: ModelConfig, cfg: MergeConfig) -> tuple:
298
  )
299
  lang_params = sum(p.numel() for p in model.parameters()) - vision_params
300
  print(f"[merge] Language: {lang_params / 1e9:.1f}B | Vision: {vision_params / 1e9:.1f}B")
 
301
 
302
  return model, tokenizer
303
  except ImportError:
@@ -318,6 +375,7 @@ def load_model(config: ModelConfig, cfg: MergeConfig) -> tuple:
318
  )
319
 
320
  print(f"[merge] Loaded {config.name}: {sum(p.numel() for p in model.parameters()) / 1e9:.1f}B params")
 
321
  return model, tokenizer
322
 
323
 
@@ -888,6 +946,7 @@ def run_pipeline(
888
  if cfg is None:
889
  cfg = MergeConfig()
890
 
 
891
  print("\n" + "=" * 70)
892
  print("TD FUSE — Transport and Merge Pipeline")
893
  print(f"Target: {TARGET.name} ({TARGET.hf_id})")
@@ -895,13 +954,18 @@ def run_pipeline(
895
  print(f"Mode: Vision-Language (merging language backbone only, vision encoder untouched)")
896
  print(f"Stages: {', '.join(stages)}")
897
  print(f"Output: {cfg.output_dir}")
 
898
  print("=" * 70)
 
 
 
 
899
 
900
  # Setup
901
  try:
902
  setup_tm_repo(cfg)
903
  except FileNotFoundError as e:
904
- print(f"\n {e}")
905
  print("Continuing with fallback implementation...")
906
 
907
  # Create output directories
@@ -1020,10 +1084,13 @@ def run_pipeline(
1020
  emoji = "✓" if status == "passed" else "✗"
1021
  print(f" {emoji} {stage_name}: {status}")
1022
  print(f"\n Overall: {pipeline_results['overall_status']}")
 
 
1023
  if residual_bank.residual_index:
1024
  print(f"\n Residuals saved for: {', '.join(residual_bank.residual_index.keys())}")
1025
  print(f" To recover lost knowledge later:")
1026
  print(f" python -m td_fuse.run --reinject <stage> --strength 0.2")
1027
  print("=" * 70)
 
1028
 
1029
  return pipeline_results
 
134
  4. MagMax — protect top magnitude params (extra safety layer)
135
  """
136
  fused = target_state[key]
137
+ original = pre_merge_state[key].to(fused.device)
138
  delta = fused - original
139
 
140
  # --- ARM Steering (new, replaces orthogonal projection) ---
 
203
 
204
  for key in current_state:
205
  if key in pre_merge_state:
206
+ delta = current_state[key].cpu().float() - pre_merge_state[key].cpu().float()
207
  if delta.abs().max() > 1e-8:
208
  if key not in self.previous_deltas:
209
  self.previous_deltas[key] = []
 
268
  return None
269
 
270
 
271
def check_model_cached(hf_id: str) -> bool:
    """Return True if the model's ``config.json`` is in the local HF cache.

    This is a best-effort, offline probe used only to decide which status
    message to print; it never raises.

    Args:
        hf_id: Hugging Face repository id, e.g. ``"org/model-name"``.

    Returns:
        True when ``try_to_load_from_cache`` yields a filesystem path for
        ``config.json``; False when the file is not cached, when
        ``huggingface_hub`` is unavailable, or when the lookup fails for
        any other reason (for example a malformed repo id).
    """
    try:
        # Import only the helper that is actually used, so an unrelated
        # missing name cannot turn a cached model into a false "not cached".
        from huggingface_hub import try_to_load_from_cache
    except ImportError:
        return False

    try:
        # Quick check: see if config.json is cached (every model has one).
        cached = try_to_load_from_cache(hf_id, "config.json")
    except Exception:
        # Deliberately broad: a failed cache probe must not abort the
        # pipeline; the caller simply treats the model as "will download".
        return False

    # Only a string path means the file is present; anything else
    # (e.g. None) is treated as "not cached".
    return isinstance(cached, str)
282
+
283
+
284
def check_all_models_cached(stages: list) -> dict:
    """
    Pre-flight report of which required models are already in the local cache.

    Prints one row for the target model and one for each source model that
    `get_source_by_stage` resolves for the requested stages, followed by a
    summary line, then flushes stdout.

    Args:
        stages: Stage names whose source models should be checked.

    Returns:
        Mapping of model name -> True if cached, False if it will download.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("PRE-FLIGHT CHECK: Model cache status")
    print(banner)
    sys.stdout.flush()

    status = {}

    def _report(model) -> None:
        # Probe the cache, print one table row, and record the result.
        is_cached = check_model_cached(model.hf_id)
        label = "CACHED" if is_cached else "WILL DOWNLOAD"
        print(f" {model.name:25s} {label:15s} ({model.hf_id})")
        status[model.name] = is_cached

    # Target model first, then the source model of every requested stage.
    _report(TARGET)
    for stage_name in stages:
        source = get_source_by_stage(stage_name)
        if source:
            _report(source)

    missing = [name for name, is_cached in status.items() if not is_cached]
    if missing:
        print(f"\n {len(missing)} model(s) need downloading: {', '.join(missing)}")
        print(" This may take 10-30 min per model depending on connection speed.")
    else:
        print(f"\n All {len(status)} models are cached -- loading will be fast!")

    print(banner)
    sys.stdout.flush()
    return status
321
+
322
+
323
  def load_model(config: ModelConfig, cfg: MergeConfig) -> tuple:
324
  """Load a model and its tokenizer/processor."""
325
+ load_start = time.time()
326
+ cached = check_model_cached(config.hf_id)
327
+ cache_msg = "(from cache)" if cached else "(downloading -- this may take a while)"
328
+ print(f"\n[merge] Loading {config.name} ({config.hf_id}) {cache_msg}...")
329
+ sys.stdout.flush()
330
 
331
  # Qwen3-VL uses a processor (handles both text + vision), not just a tokenizer
332
  if config.architecture == "transformer+vision":
 
354
  )
355
  lang_params = sum(p.numel() for p in model.parameters()) - vision_params
356
  print(f"[merge] Language: {lang_params / 1e9:.1f}B | Vision: {vision_params / 1e9:.1f}B")
357
+ print(f"[merge] Loaded in {time.time()-load_start:.0f}s"); sys.stdout.flush()
358
 
359
  return model, tokenizer
360
  except ImportError:
 
375
  )
376
 
377
  print(f"[merge] Loaded {config.name}: {sum(p.numel() for p in model.parameters()) / 1e9:.1f}B params")
378
+ print(f"[merge] Loaded in {time.time()-load_start:.0f}s"); sys.stdout.flush()
379
  return model, tokenizer
380
 
381
 
 
946
  if cfg is None:
947
  cfg = MergeConfig()
948
 
949
+ pipeline_start = time.time()
950
  print("\n" + "=" * 70)
951
  print("TD FUSE — Transport and Merge Pipeline")
952
  print(f"Target: {TARGET.name} ({TARGET.hf_id})")
 
954
  print(f"Mode: Vision-Language (merging language backbone only, vision encoder untouched)")
955
  print(f"Stages: {', '.join(stages)}")
956
  print(f"Output: {cfg.output_dir}")
957
+ print(f"Started at: {time.strftime('%H:%M:%S')}")
958
  print("=" * 70)
959
+ sys.stdout.flush()
960
+
961
+ # --- Pre-flight: check which models are cached ---
962
+ check_all_models_cached(stages)
963
 
964
  # Setup
965
  try:
966
  setup_tm_repo(cfg)
967
  except FileNotFoundError as e:
968
+ print(f"\n WARNING: {e}")
969
  print("Continuing with fallback implementation...")
970
 
971
  # Create output directories
 
1084
  emoji = "✓" if status == "passed" else "✗"
1085
  print(f" {emoji} {stage_name}: {status}")
1086
  print(f"\n Overall: {pipeline_results['overall_status']}")
1087
+ total_pipeline_time = time.time() - pipeline_start
1088
+ print(f"\n Total pipeline time: {total_pipeline_time/60:.1f} min ({total_pipeline_time/3600:.1f} hours)")
1089
  if residual_bank.residual_index:
1090
  print(f"\n Residuals saved for: {', '.join(residual_bank.residual_index.keys())}")
1091
  print(f" To recover lost knowledge later:")
1092
  print(f" python -m td_fuse.run --reinject <stage> --strength 0.2")
1093
  print("=" * 70)
1094
+ sys.stdout.flush()
1095
 
1096
  return pipeline_results
hugging/td_fuse/transport.py CHANGED
@@ -572,6 +572,10 @@ def fuse_weights(
572
  source_key = _map_key(target_key, source_config)
573
  if source_key is None or source_key not in source_state:
574
  skipped_count += 1
 
 
 
 
575
  continue
576
 
577
  target_w = target_state[target_key]
@@ -618,6 +622,10 @@ def fuse_weights(
618
  def _should_skip(key: str, source_config: ModelConfig) -> bool:
619
  """Determine if a parameter should be skipped during merge."""
620
 
 
 
 
 
621
  # Always skip if source model says to skip embeddings
622
  if source_config.skip_embeddings and ("embed_tokens" in key or "lm_head" in key):
623
  return True
@@ -640,22 +648,42 @@ def _should_skip(key: str, source_config: ModelConfig) -> bool:
640
  return False
641
 
642
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
643
  def _map_key(target_key: str, source_config: ModelConfig) -> Optional[str]:
644
  """Map a target model parameter name to the corresponding source name."""
645
 
646
- # For same-architecture models (DeepSeek), keys match directly
 
 
 
647
  if source_config.architecture == "transformer" and source_config.layers == 36:
648
- return target_key
649
 
650
  # For Llama (32 layers -> 36 layers), map layer indices
651
  if "layer_mapping_32_to_36" in source_config.special_handling:
652
- if "model.layers." in target_key:
653
  # Extract layer number
654
- parts = target_key.split(".")
655
  try:
656
  layer_idx = int(parts[2])
657
  except (IndexError, ValueError):
658
- return target_key
659
 
660
  # Map 36 target layers to 32 source layers (stride)
661
  source_layer = int(layer_idx * 32 / 36)
@@ -664,17 +692,17 @@ def _map_key(target_key: str, source_config: ModelConfig) -> Optional[str]:
664
 
665
  # For MiMo (same layer count, different extras), keys mostly match
666
  if source_config.architecture == "transformer+mtp":
667
- if "mtp_head" in target_key:
668
  return None # MTP heads don't exist in target
669
- return target_key
670
 
671
  # For Falcon hybrid, only attention and MLP keys map
672
  if source_config.architecture == "hybrid_ssm":
673
- if any(k in target_key for k in ["self_attn", "mlp", "layer_norm"]):
674
- return target_key # These exist in both
675
  return None # Mamba components don't map
676
 
677
- return target_key
678
 
679
 
680
  def _align_dimensions(
 
572
  source_key = _map_key(target_key, source_config)
573
  if source_key is None or source_key not in source_state:
574
  skipped_count += 1
575
+ # Log first few misses to help debug key mapping issues
576
+ if skipped_count <= 5:
577
+ print(f" [skip] No source match for: {target_key} (mapped to: {source_key})")
578
+ sys.stdout.flush()
579
  continue
580
 
581
  target_w = target_state[target_key]
 
622
  def _should_skip(key: str, source_config: ModelConfig) -> bool:
623
  """Determine if a parameter should be skipped during merge."""
624
 
625
+ # Skip vision encoder params (Qwen3-VL) -- these should never be merged
626
+ if key.startswith("visual") or key.startswith("merger") or key.startswith("model.visual") or key.startswith("model.merger"):
627
+ return True
628
+
629
  # Always skip if source model says to skip embeddings
630
  if source_config.skip_embeddings and ("embed_tokens" in key or "lm_head" in key):
631
  return True
 
648
  return False
649
 
650
 
651
+ def _strip_vl_prefix(key: str) -> str:
652
+ """
653
+ Strip the 'language_model.' prefix that Qwen3-VL adds.
654
+
655
+ Qwen3-VL wraps all language params under 'model.language_model.*'
656
+ but source models (DeepSeek, MiMo, Llama, Falcon) use 'model.*' directly.
657
+
658
+ Example:
659
+ target: model.language_model.layers.0.self_attn.q_proj.weight
660
+ source: model.layers.0.self_attn.q_proj.weight
661
+ """
662
+ # model.language_model.X -> model.X
663
+ if "language_model." in key:
664
+ return key.replace("language_model.", "")
665
+ return key
666
+
667
+
668
  def _map_key(target_key: str, source_config: ModelConfig) -> Optional[str]:
669
  """Map a target model parameter name to the corresponding source name."""
670
 
671
+ # Step 1: Strip Qwen3-VL's language_model. prefix so we can match source keys
672
+ source_key = _strip_vl_prefix(target_key)
673
+
674
+ # For same-architecture models (DeepSeek), keys match directly after prefix strip
675
  if source_config.architecture == "transformer" and source_config.layers == 36:
676
+ return source_key
677
 
678
  # For Llama (32 layers -> 36 layers), map layer indices
679
  if "layer_mapping_32_to_36" in source_config.special_handling:
680
+ if "model.layers." in source_key:
681
  # Extract layer number
682
+ parts = source_key.split(".")
683
  try:
684
  layer_idx = int(parts[2])
685
  except (IndexError, ValueError):
686
+ return source_key
687
 
688
  # Map 36 target layers to 32 source layers (stride)
689
  source_layer = int(layer_idx * 32 / 36)
 
692
 
693
  # For MiMo (same layer count, different extras), keys mostly match
694
  if source_config.architecture == "transformer+mtp":
695
+ if "mtp_head" in source_key:
696
  return None # MTP heads don't exist in target
697
+ return source_key
698
 
699
  # For Falcon hybrid, only attention and MLP keys map
700
  if source_config.architecture == "hybrid_ssm":
701
+ if any(k in source_key for k in ["self_attn", "mlp", "layer_norm"]):
702
+ return source_key # These exist in both
703
  return None # Mamba components don't map
704
 
705
+ return source_key
706
 
707
 
708
  def _align_dimensions(