td-builder commited on
Commit
bd8f3d3
·
verified ·
1 Parent(s): 1e0b51b

Upload 137 files

Browse files
hugging/td_fuse/merge.py CHANGED
@@ -725,7 +725,7 @@ def run_single_merge(
725
  )
726
  print(f"[merge] Step 5/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
727
 
728
- # --- Step 5.5: RAM RL-weight disentanglement (2601.13572) ---
729
  use_ram = (
730
  cfg.use_ram_disentangle
731
  and source_config.architecture in ("transformer", "transformer+mtp")
@@ -733,6 +733,39 @@ def run_single_merge(
733
  and any(kw in source_config.name.lower() for kw in ["r1", "rl", "rlhf", "grpo"])
734
  )
735
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
736
  # --- Step 6: Pre-merge protection ---
737
  print(f"\n[merge] Step 6/10: Pre-merge protection..."); sys.stdout.flush()
738
  step_t = time.time()
@@ -753,7 +786,6 @@ def run_single_merge(
753
  # RAM path: disentangle RL weights, merge with preservation
754
  print(f"\n[merge] Using RAM RL-preservation for {stage_name}...")
755
  try:
756
- # Try loading the base (pre-RL) model for disentanglement
757
  base_hf_id = source_config.hf_id.replace("-RL", "").replace("-R1-0528", "")
758
  print(f"[merge] Loading base model for RAM: {base_hf_id}")
759
  base_model = AutoModelForCausalLM.from_pretrained(
@@ -763,34 +795,38 @@ def run_single_merge(
763
  trust_remote_code=source_config.trust_remote_code,
764
  )
765
  shared_mask, rl_mask = disentangle_rl_weights(
766
- source_model, base_model, cfg.ram_rl_threshold
767
  )
768
  # Fuse with RL preservation
769
  target_state = merge_with_rl_preservation(
770
  target_model.state_dict(),
771
- source_model.state_dict(),
772
  shared_mask, rl_mask,
773
  shared_alpha=cfg.ram_shared_alpha * (adjusted_alpha / source_config.merge_alpha),
774
  rl_alpha=cfg.ram_rl_alpha,
775
  )
776
  target_model.load_state_dict(target_state)
777
  del base_model
 
 
 
778
  print(f"[merge] RAM merge complete for {stage_name}")
779
  except Exception as e:
780
  print(f"[merge] RAM failed ({e}), falling back to standard T&M merge")
781
  target_model = fuse_weights(
782
- source_model, target_model, transport_plans,
783
  source_config_adjusted, cfg,
784
  )
785
  else:
786
- # Standard T&M path
787
  target_model = fuse_weights(
788
- source_model, target_model, transport_plans,
789
  source_config_adjusted, cfg,
790
  )
791
 
792
  # --- Step 7.5: Theseus fallback check (2602.12952) ---
793
  # If T&M merge produced poor activation alignment, try Theseus
 
794
  if cfg.use_theseus_fallback and source_config.merge_risk == "high":
795
  print(f"\n[merge] Checking if Theseus fallback needed for {stage_name}...")
796
  post_activations = extract_activations(target_model, calibration_data[:50]) # Quick check
@@ -811,6 +847,9 @@ def run_single_merge(
811
  # Restore pre-merge state and try Theseus instead
812
  target_model.load_state_dict(pre_merge_state)
813
  try:
 
 
 
814
  base_model = AutoModelForCausalLM.from_pretrained(
815
  source_config.hf_id.split("/")[0] + "/" + source_config.hf_id.split("/")[1].split("-")[0],
816
  torch_dtype=getattr(torch, cfg.dtype),
@@ -818,17 +857,20 @@ def run_single_merge(
818
  trust_remote_code=source_config.trust_remote_code,
819
  )
820
  target_model = transport_task_vector_theseus(
821
- source_model, base_model, target_model,
822
  source_activations, pre_merge_target_activations,
823
  alpha=cfg.theseus_alpha,
824
  )
825
- del base_model
 
 
 
826
  print(f"[merge] Theseus transport complete for {stage_name}")
827
  except Exception as e:
828
  print(f"[merge] Theseus also failed ({e}). Using original T&M result.")
829
- # Re-apply T&M result
830
  target_model = fuse_weights(
831
- source_model, target_model, transport_plans,
832
  source_config_adjusted, cfg,
833
  )
834
 
@@ -880,15 +922,16 @@ def run_single_merge(
880
  residual_bank.save_residuals(
881
  stage_name=stage_name,
882
  pre_merge_target_state=pre_merge_state,
883
- source_state={k: v.cpu() for k, v in source_model.state_dict().items()},
884
  post_merge_state={k: v.cpu() for k, v in target_model.state_dict().items()},
885
  source_config=source_config,
886
  )
887
 
888
  print(f"[merge] Step 9/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
889
 
890
- # --- Step 9: Free source model memory ---
891
- del source_model, source_activations, pre_merge_target_activations
 
892
  del transport_plans, post_merge_activations
893
  gc.collect()
894
  if torch.cuda.is_available():
 
725
  )
726
  print(f"[merge] Step 5/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
727
 
728
+ # --- Step 5.5: RAM RL-weight disentanglement check (2601.13572) ---
729
  use_ram = (
730
  cfg.use_ram_disentangle
731
  and source_config.architecture in ("transformer", "transformer+mtp")
 
733
  and any(kw in source_config.name.lower() for kw in ["r1", "rl", "rlhf", "grpo"])
734
  )
735
 
736
+ # Validate that the RAM base model actually exists before we try loading it
737
+ if use_ram:
738
+ base_hf_id = source_config.hf_id.replace("-RL", "").replace("-R1-0528", "")
739
+ if base_hf_id == source_config.hf_id:
740
+ # Stripping didn't change anything — no base model to compare against
741
+ print(f"[merge] RAM skipped: no base model ID derivable from {source_config.hf_id}")
742
+ use_ram = False
743
+ else:
744
+ # Check if the base model exists on HuggingFace
745
+ try:
746
+ from huggingface_hub import model_info
747
+ model_info(base_hf_id)
748
+ print(f"[merge] RAM base model verified: {base_hf_id}")
749
+ except Exception:
750
+ print(f"[merge] RAM skipped: base model {base_hf_id} not found on HuggingFace")
751
+ use_ram = False
752
+
753
+ # --- Step 5.7: Free source model from GPU ---
754
+ # After transport plans are computed, we only need the source STATE DICT
755
+ # (not the full model object). Freeing the model saves ~16 GB of GPU memory
756
+ # which prevents OOM during the fusion step.
757
+ print(f"\n[merge] Step 5.7: Freeing source model from GPU..."); sys.stdout.flush()
758
+ step_t = time.time()
759
+ source_state_cpu = {k: v.cpu() for k, v in source_model.state_dict().items()}
760
+ del source_model
761
+ gc.collect()
762
+ if torch.cuda.is_available():
763
+ torch.cuda.empty_cache()
764
+ free_mem = torch.cuda.mem_get_info()[0] / 1e9
765
+ total_mem = torch.cuda.mem_get_info()[1] / 1e9
766
+ print(f"[merge] GPU memory after freeing source: {free_mem:.1f} GB free / {total_mem:.1f} GB total")
767
+ print(f"[merge] Step 5.7 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
768
+
769
  # --- Step 6: Pre-merge protection ---
770
  print(f"\n[merge] Step 6/10: Pre-merge protection..."); sys.stdout.flush()
771
  step_t = time.time()
 
786
  # RAM path: disentangle RL weights, merge with preservation
787
  print(f"\n[merge] Using RAM RL-preservation for {stage_name}...")
788
  try:
 
789
  base_hf_id = source_config.hf_id.replace("-RL", "").replace("-R1-0528", "")
790
  print(f"[merge] Loading base model for RAM: {base_hf_id}")
791
  base_model = AutoModelForCausalLM.from_pretrained(
 
795
  trust_remote_code=source_config.trust_remote_code,
796
  )
797
  shared_mask, rl_mask = disentangle_rl_weights(
798
+ source_state_cpu, base_model, cfg.ram_rl_threshold
799
  )
800
  # Fuse with RL preservation
801
  target_state = merge_with_rl_preservation(
802
  target_model.state_dict(),
803
+ source_state_cpu,
804
  shared_mask, rl_mask,
805
  shared_alpha=cfg.ram_shared_alpha * (adjusted_alpha / source_config.merge_alpha),
806
  rl_alpha=cfg.ram_rl_alpha,
807
  )
808
  target_model.load_state_dict(target_state)
809
  del base_model
810
+ gc.collect()
811
+ if torch.cuda.is_available():
812
+ torch.cuda.empty_cache()
813
  print(f"[merge] RAM merge complete for {stage_name}")
814
  except Exception as e:
815
  print(f"[merge] RAM failed ({e}), falling back to standard T&M merge")
816
  target_model = fuse_weights(
817
+ source_state_cpu, target_model, transport_plans,
818
  source_config_adjusted, cfg,
819
  )
820
  else:
821
+ # Standard T&M path (source_state_cpu is on CPU, fuse_weights moves per-param)
822
  target_model = fuse_weights(
823
+ source_state_cpu, target_model, transport_plans,
824
  source_config_adjusted, cfg,
825
  )
826
 
827
  # --- Step 7.5: Theseus fallback check (2602.12952) ---
828
  # If T&M merge produced poor activation alignment, try Theseus
829
+ # NOTE: source_model was freed in step 5.7 — Theseus needs full model reload
830
  if cfg.use_theseus_fallback and source_config.merge_risk == "high":
831
  print(f"\n[merge] Checking if Theseus fallback needed for {stage_name}...")
832
  post_activations = extract_activations(target_model, calibration_data[:50]) # Quick check
 
847
  # Restore pre-merge state and try Theseus instead
848
  target_model.load_state_dict(pre_merge_state)
849
  try:
850
+ # Reload source model for Theseus (it was freed in step 5.7)
851
+ print(f"[merge] Reloading source model for Theseus fallback...")
852
+ source_model_reload, _ = load_model(source_config, cfg)
853
  base_model = AutoModelForCausalLM.from_pretrained(
854
  source_config.hf_id.split("/")[0] + "/" + source_config.hf_id.split("/")[1].split("-")[0],
855
  torch_dtype=getattr(torch, cfg.dtype),
 
857
  trust_remote_code=source_config.trust_remote_code,
858
  )
859
  target_model = transport_task_vector_theseus(
860
+ source_model_reload, base_model, target_model,
861
  source_activations, pre_merge_target_activations,
862
  alpha=cfg.theseus_alpha,
863
  )
864
+ del base_model, source_model_reload
865
+ gc.collect()
866
+ if torch.cuda.is_available():
867
+ torch.cuda.empty_cache()
868
  print(f"[merge] Theseus transport complete for {stage_name}")
869
  except Exception as e:
870
  print(f"[merge] Theseus also failed ({e}). Using original T&M result.")
871
+ # Re-apply T&M result using CPU state dict
872
  target_model = fuse_weights(
873
+ source_state_cpu, target_model, transport_plans,
874
  source_config_adjusted, cfg,
875
  )
876
 
 
922
  residual_bank.save_residuals(
923
  stage_name=stage_name,
924
  pre_merge_target_state=pre_merge_state,
925
+ source_state=source_state_cpu, # Already on CPU from step 5.7
926
  post_merge_state={k: v.cpu() for k, v in target_model.state_dict().items()},
927
  source_config=source_config,
928
  )
929
 
930
  print(f"[merge] Step 9/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
931
 
932
+ # --- Step 9: Free remaining memory ---
933
+ # source_model was already freed in step 5.7
934
+ del source_state_cpu, source_activations, pre_merge_target_activations
935
  del transport_plans, post_merge_activations
936
  gc.collect()
937
  if torch.cuda.is_available():
hugging/td_fuse/transport.py CHANGED
@@ -513,7 +513,7 @@ def _sinkhorn(
513
 
514
 
515
  def fuse_weights(
516
- source_model: AutoModelForCausalLM,
517
  target_model: AutoModelForCausalLM,
518
  transport_plans: dict,
519
  source_config: ModelConfig,
@@ -527,6 +527,13 @@ def fuse_weights(
527
  2. Transport source weights into target neuron basis: W_fused = Q @ W_source
528
  3. Blend with target: W_final = alpha * W_fused + (1-alpha) * W_target
529
 
 
 
 
 
 
 
 
530
  Special handling per model:
531
  - DeepSeek: Direct merge (same architecture)
532
  - MiMo: Skip MTP heads, skip embeddings
@@ -550,7 +557,7 @@ def fuse_weights(
550
  pass
551
 
552
  # --- Manual fusion using transport plans ---
553
- source_state = source_model.state_dict()
554
  target_state = target_model.state_dict()
555
  P = transport_plans["P"]
556
  Q = transport_plans["Q"]
 
513
 
514
 
515
  def fuse_weights(
516
+ source_state: dict,
517
  target_model: AutoModelForCausalLM,
518
  transport_plans: dict,
519
  source_config: ModelConfig,
 
527
  2. Transport source weights into target neuron basis: W_fused = Q @ W_source
528
  3. Blend with target: W_final = alpha * W_fused + (1-alpha) * W_target
529
 
530
+ Args:
531
+ source_state: Source model state dict (can be on CPU — will be moved per-param)
532
+ target_model: Target model (on GPU)
533
+ transport_plans: Transport plan matrices from compute_transport_plans
534
+ source_config: Source model config
535
+ cfg: Merge configuration
536
+
537
  Special handling per model:
538
  - DeepSeek: Direct merge (same architecture)
539
  - MiMo: Skip MTP heads, skip embeddings
 
557
  pass
558
 
559
  # --- Manual fusion using transport plans ---
560
+ # source_state is passed in (may be on CPU to save GPU memory)
561
  target_state = target_model.state_dict()
562
  P = transport_plans["P"]
563
  Q = transport_plans["Q"]