Upload 137 files
Browse files- hugging/td_fuse/merge.py +57 -14
- hugging/td_fuse/transport.py +9 -2
hugging/td_fuse/merge.py
CHANGED
|
@@ -725,7 +725,7 @@ def run_single_merge(
|
|
| 725 |
)
|
| 726 |
print(f"[merge] Step 5/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
|
| 727 |
|
| 728 |
-
# --- Step 5.5: RAM RL-weight disentanglement (2601.13572) ---
|
| 729 |
use_ram = (
|
| 730 |
cfg.use_ram_disentangle
|
| 731 |
and source_config.architecture in ("transformer", "transformer+mtp")
|
|
@@ -733,6 +733,39 @@ def run_single_merge(
|
|
| 733 |
and any(kw in source_config.name.lower() for kw in ["r1", "rl", "rlhf", "grpo"])
|
| 734 |
)
|
| 735 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 736 |
# --- Step 6: Pre-merge protection ---
|
| 737 |
print(f"\n[merge] Step 6/10: Pre-merge protection..."); sys.stdout.flush()
|
| 738 |
step_t = time.time()
|
|
@@ -753,7 +786,6 @@ def run_single_merge(
|
|
| 753 |
# RAM path: disentangle RL weights, merge with preservation
|
| 754 |
print(f"\n[merge] Using RAM RL-preservation for {stage_name}...")
|
| 755 |
try:
|
| 756 |
-
# Try loading the base (pre-RL) model for disentanglement
|
| 757 |
base_hf_id = source_config.hf_id.replace("-RL", "").replace("-R1-0528", "")
|
| 758 |
print(f"[merge] Loading base model for RAM: {base_hf_id}")
|
| 759 |
base_model = AutoModelForCausalLM.from_pretrained(
|
|
@@ -763,34 +795,38 @@ def run_single_merge(
|
|
| 763 |
trust_remote_code=source_config.trust_remote_code,
|
| 764 |
)
|
| 765 |
shared_mask, rl_mask = disentangle_rl_weights(
|
| 766 |
-
|
| 767 |
)
|
| 768 |
# Fuse with RL preservation
|
| 769 |
target_state = merge_with_rl_preservation(
|
| 770 |
target_model.state_dict(),
|
| 771 |
-
|
| 772 |
shared_mask, rl_mask,
|
| 773 |
shared_alpha=cfg.ram_shared_alpha * (adjusted_alpha / source_config.merge_alpha),
|
| 774 |
rl_alpha=cfg.ram_rl_alpha,
|
| 775 |
)
|
| 776 |
target_model.load_state_dict(target_state)
|
| 777 |
del base_model
|
|
|
|
|
|
|
|
|
|
| 778 |
print(f"[merge] RAM merge complete for {stage_name}")
|
| 779 |
except Exception as e:
|
| 780 |
print(f"[merge] RAM failed ({e}), falling back to standard T&M merge")
|
| 781 |
target_model = fuse_weights(
|
| 782 |
-
|
| 783 |
source_config_adjusted, cfg,
|
| 784 |
)
|
| 785 |
else:
|
| 786 |
-
# Standard T&M path
|
| 787 |
target_model = fuse_weights(
|
| 788 |
-
|
| 789 |
source_config_adjusted, cfg,
|
| 790 |
)
|
| 791 |
|
| 792 |
# --- Step 7.5: Theseus fallback check (2602.12952) ---
|
| 793 |
# If T&M merge produced poor activation alignment, try Theseus
|
|
|
|
| 794 |
if cfg.use_theseus_fallback and source_config.merge_risk == "high":
|
| 795 |
print(f"\n[merge] Checking if Theseus fallback needed for {stage_name}...")
|
| 796 |
post_activations = extract_activations(target_model, calibration_data[:50]) # Quick check
|
|
@@ -811,6 +847,9 @@ def run_single_merge(
|
|
| 811 |
# Restore pre-merge state and try Theseus instead
|
| 812 |
target_model.load_state_dict(pre_merge_state)
|
| 813 |
try:
|
|
|
|
|
|
|
|
|
|
| 814 |
base_model = AutoModelForCausalLM.from_pretrained(
|
| 815 |
source_config.hf_id.split("/")[0] + "/" + source_config.hf_id.split("/")[1].split("-")[0],
|
| 816 |
torch_dtype=getattr(torch, cfg.dtype),
|
|
@@ -818,17 +857,20 @@ def run_single_merge(
|
|
| 818 |
trust_remote_code=source_config.trust_remote_code,
|
| 819 |
)
|
| 820 |
target_model = transport_task_vector_theseus(
|
| 821 |
-
|
| 822 |
source_activations, pre_merge_target_activations,
|
| 823 |
alpha=cfg.theseus_alpha,
|
| 824 |
)
|
| 825 |
-
del base_model
|
|
|
|
|
|
|
|
|
|
| 826 |
print(f"[merge] Theseus transport complete for {stage_name}")
|
| 827 |
except Exception as e:
|
| 828 |
print(f"[merge] Theseus also failed ({e}). Using original T&M result.")
|
| 829 |
-
# Re-apply T&M result
|
| 830 |
target_model = fuse_weights(
|
| 831 |
-
|
| 832 |
source_config_adjusted, cfg,
|
| 833 |
)
|
| 834 |
|
|
@@ -880,15 +922,16 @@ def run_single_merge(
|
|
| 880 |
residual_bank.save_residuals(
|
| 881 |
stage_name=stage_name,
|
| 882 |
pre_merge_target_state=pre_merge_state,
|
| 883 |
-
source_state=
|
| 884 |
post_merge_state={k: v.cpu() for k, v in target_model.state_dict().items()},
|
| 885 |
source_config=source_config,
|
| 886 |
)
|
| 887 |
|
| 888 |
print(f"[merge] Step 9/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
|
| 889 |
|
| 890 |
-
# --- Step 9: Free
|
| 891 |
-
|
|
|
|
| 892 |
del transport_plans, post_merge_activations
|
| 893 |
gc.collect()
|
| 894 |
if torch.cuda.is_available():
|
|
|
|
| 725 |
)
|
| 726 |
print(f"[merge] Step 5/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
|
| 727 |
|
| 728 |
+
# --- Step 5.5: RAM RL-weight disentanglement check (2601.13572) ---
|
| 729 |
use_ram = (
|
| 730 |
cfg.use_ram_disentangle
|
| 731 |
and source_config.architecture in ("transformer", "transformer+mtp")
|
|
|
|
| 733 |
and any(kw in source_config.name.lower() for kw in ["r1", "rl", "rlhf", "grpo"])
|
| 734 |
)
|
| 735 |
|
| 736 |
+
# Validate that the RAM base model actually exists before we try loading it
|
| 737 |
+
if use_ram:
|
| 738 |
+
base_hf_id = source_config.hf_id.replace("-RL", "").replace("-R1-0528", "")
|
| 739 |
+
if base_hf_id == source_config.hf_id:
|
| 740 |
+
# Stripping didn't change anything — no base model to compare against
|
| 741 |
+
print(f"[merge] RAM skipped: no base model ID derivable from {source_config.hf_id}")
|
| 742 |
+
use_ram = False
|
| 743 |
+
else:
|
| 744 |
+
# Check if the base model exists on HuggingFace
|
| 745 |
+
try:
|
| 746 |
+
from huggingface_hub import model_info
|
| 747 |
+
model_info(base_hf_id)
|
| 748 |
+
print(f"[merge] RAM base model verified: {base_hf_id}")
|
| 749 |
+
except Exception:
|
| 750 |
+
print(f"[merge] RAM skipped: base model {base_hf_id} not found on HuggingFace")
|
| 751 |
+
use_ram = False
|
| 752 |
+
|
| 753 |
+
# --- Step 5.7: Free source model from GPU ---
|
| 754 |
+
# After transport plans are computed, we only need the source STATE DICT
|
| 755 |
+
# (not the full model object). Freeing the model saves ~16 GB of GPU memory
|
| 756 |
+
# which prevents OOM during the fusion step.
|
| 757 |
+
print(f"\n[merge] Step 5.7: Freeing source model from GPU..."); sys.stdout.flush()
|
| 758 |
+
step_t = time.time()
|
| 759 |
+
source_state_cpu = {k: v.cpu() for k, v in source_model.state_dict().items()}
|
| 760 |
+
del source_model
|
| 761 |
+
gc.collect()
|
| 762 |
+
if torch.cuda.is_available():
|
| 763 |
+
torch.cuda.empty_cache()
|
| 764 |
+
free_mem = torch.cuda.mem_get_info()[0] / 1e9
|
| 765 |
+
total_mem = torch.cuda.mem_get_info()[1] / 1e9
|
| 766 |
+
print(f"[merge] GPU memory after freeing source: {free_mem:.1f} GB free / {total_mem:.1f} GB total")
|
| 767 |
+
print(f"[merge] Step 5.7 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
|
| 768 |
+
|
| 769 |
# --- Step 6: Pre-merge protection ---
|
| 770 |
print(f"\n[merge] Step 6/10: Pre-merge protection..."); sys.stdout.flush()
|
| 771 |
step_t = time.time()
|
|
|
|
| 786 |
# RAM path: disentangle RL weights, merge with preservation
|
| 787 |
print(f"\n[merge] Using RAM RL-preservation for {stage_name}...")
|
| 788 |
try:
|
|
|
|
| 789 |
base_hf_id = source_config.hf_id.replace("-RL", "").replace("-R1-0528", "")
|
| 790 |
print(f"[merge] Loading base model for RAM: {base_hf_id}")
|
| 791 |
base_model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
| 795 |
trust_remote_code=source_config.trust_remote_code,
|
| 796 |
)
|
| 797 |
shared_mask, rl_mask = disentangle_rl_weights(
|
| 798 |
+
source_state_cpu, base_model, cfg.ram_rl_threshold
|
| 799 |
)
|
| 800 |
# Fuse with RL preservation
|
| 801 |
target_state = merge_with_rl_preservation(
|
| 802 |
target_model.state_dict(),
|
| 803 |
+
source_state_cpu,
|
| 804 |
shared_mask, rl_mask,
|
| 805 |
shared_alpha=cfg.ram_shared_alpha * (adjusted_alpha / source_config.merge_alpha),
|
| 806 |
rl_alpha=cfg.ram_rl_alpha,
|
| 807 |
)
|
| 808 |
target_model.load_state_dict(target_state)
|
| 809 |
del base_model
|
| 810 |
+
gc.collect()
|
| 811 |
+
if torch.cuda.is_available():
|
| 812 |
+
torch.cuda.empty_cache()
|
| 813 |
print(f"[merge] RAM merge complete for {stage_name}")
|
| 814 |
except Exception as e:
|
| 815 |
print(f"[merge] RAM failed ({e}), falling back to standard T&M merge")
|
| 816 |
target_model = fuse_weights(
|
| 817 |
+
source_state_cpu, target_model, transport_plans,
|
| 818 |
source_config_adjusted, cfg,
|
| 819 |
)
|
| 820 |
else:
|
| 821 |
+
# Standard T&M path (source_state_cpu is on CPU, fuse_weights moves per-param)
|
| 822 |
target_model = fuse_weights(
|
| 823 |
+
source_state_cpu, target_model, transport_plans,
|
| 824 |
source_config_adjusted, cfg,
|
| 825 |
)
|
| 826 |
|
| 827 |
# --- Step 7.5: Theseus fallback check (2602.12952) ---
|
| 828 |
# If T&M merge produced poor activation alignment, try Theseus
|
| 829 |
+
# NOTE: source_model was freed in step 5.7 — Theseus needs full model reload
|
| 830 |
if cfg.use_theseus_fallback and source_config.merge_risk == "high":
|
| 831 |
print(f"\n[merge] Checking if Theseus fallback needed for {stage_name}...")
|
| 832 |
post_activations = extract_activations(target_model, calibration_data[:50]) # Quick check
|
|
|
|
| 847 |
# Restore pre-merge state and try Theseus instead
|
| 848 |
target_model.load_state_dict(pre_merge_state)
|
| 849 |
try:
|
| 850 |
+
# Reload source model for Theseus (it was freed in step 5.7)
|
| 851 |
+
print(f"[merge] Reloading source model for Theseus fallback...")
|
| 852 |
+
source_model_reload, _ = load_model(source_config, cfg)
|
| 853 |
base_model = AutoModelForCausalLM.from_pretrained(
|
| 854 |
source_config.hf_id.split("/")[0] + "/" + source_config.hf_id.split("/")[1].split("-")[0],
|
| 855 |
torch_dtype=getattr(torch, cfg.dtype),
|
|
|
|
| 857 |
trust_remote_code=source_config.trust_remote_code,
|
| 858 |
)
|
| 859 |
target_model = transport_task_vector_theseus(
|
| 860 |
+
source_model_reload, base_model, target_model,
|
| 861 |
source_activations, pre_merge_target_activations,
|
| 862 |
alpha=cfg.theseus_alpha,
|
| 863 |
)
|
| 864 |
+
del base_model, source_model_reload
|
| 865 |
+
gc.collect()
|
| 866 |
+
if torch.cuda.is_available():
|
| 867 |
+
torch.cuda.empty_cache()
|
| 868 |
print(f"[merge] Theseus transport complete for {stage_name}")
|
| 869 |
except Exception as e:
|
| 870 |
print(f"[merge] Theseus also failed ({e}). Using original T&M result.")
|
| 871 |
+
# Re-apply T&M result using CPU state dict
|
| 872 |
target_model = fuse_weights(
|
| 873 |
+
source_state_cpu, target_model, transport_plans,
|
| 874 |
source_config_adjusted, cfg,
|
| 875 |
)
|
| 876 |
|
|
|
|
| 922 |
residual_bank.save_residuals(
|
| 923 |
stage_name=stage_name,
|
| 924 |
pre_merge_target_state=pre_merge_state,
|
| 925 |
+
source_state=source_state_cpu, # Already on CPU from step 5.7
|
| 926 |
post_merge_state={k: v.cpu() for k, v in target_model.state_dict().items()},
|
| 927 |
source_config=source_config,
|
| 928 |
)
|
| 929 |
|
| 930 |
print(f"[merge] Step 9/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
|
| 931 |
|
| 932 |
+
# --- Step 9: Free remaining memory ---
|
| 933 |
+
# source_model was already freed in step 5.7
|
| 934 |
+
del source_state_cpu, source_activations, pre_merge_target_activations
|
| 935 |
del transport_plans, post_merge_activations
|
| 936 |
gc.collect()
|
| 937 |
if torch.cuda.is_available():
|
hugging/td_fuse/transport.py
CHANGED
|
@@ -513,7 +513,7 @@ def _sinkhorn(
|
|
| 513 |
|
| 514 |
|
| 515 |
def fuse_weights(
|
| 516 |
-
|
| 517 |
target_model: AutoModelForCausalLM,
|
| 518 |
transport_plans: dict,
|
| 519 |
source_config: ModelConfig,
|
|
@@ -527,6 +527,13 @@ def fuse_weights(
|
|
| 527 |
2. Transport source weights into target neuron basis: W_fused = Q @ W_source
|
| 528 |
3. Blend with target: W_final = alpha * W_fused + (1-alpha) * W_target
|
| 529 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 530 |
Special handling per model:
|
| 531 |
- DeepSeek: Direct merge (same architecture)
|
| 532 |
- MiMo: Skip MTP heads, skip embeddings
|
|
@@ -550,7 +557,7 @@ def fuse_weights(
|
|
| 550 |
pass
|
| 551 |
|
| 552 |
# --- Manual fusion using transport plans ---
|
| 553 |
-
source_state
|
| 554 |
target_state = target_model.state_dict()
|
| 555 |
P = transport_plans["P"]
|
| 556 |
Q = transport_plans["Q"]
|
|
|
|
| 513 |
|
| 514 |
|
| 515 |
def fuse_weights(
|
| 516 |
+
source_state: dict,
|
| 517 |
target_model: AutoModelForCausalLM,
|
| 518 |
transport_plans: dict,
|
| 519 |
source_config: ModelConfig,
|
|
|
|
| 527 |
2. Transport source weights into target neuron basis: W_fused = Q @ W_source
|
| 528 |
3. Blend with target: W_final = alpha * W_fused + (1-alpha) * W_target
|
| 529 |
|
| 530 |
+
Args:
|
| 531 |
+
source_state: Source model state dict (can be on CPU — will be moved per-param)
|
| 532 |
+
target_model: Target model (on GPU)
|
| 533 |
+
transport_plans: Transport plan matrices from compute_transport_plans
|
| 534 |
+
source_config: Source model config
|
| 535 |
+
cfg: Merge configuration
|
| 536 |
+
|
| 537 |
Special handling per model:
|
| 538 |
- DeepSeek: Direct merge (same architecture)
|
| 539 |
- MiMo: Skip MTP heads, skip embeddings
|
|
|
|
| 557 |
pass
|
| 558 |
|
| 559 |
# --- Manual fusion using transport plans ---
|
| 560 |
+
# source_state is passed in (may be on CPU to save GPU memory)
|
| 561 |
target_state = target_model.state_dict()
|
| 562 |
P = transport_plans["P"]
|
| 563 |
Q = transport_plans["Q"]
|