Upload 137 files
Browse files- hugging/td_fuse/merge.py +71 -4
- hugging/td_fuse/transport.py +38 -10
hugging/td_fuse/merge.py
CHANGED
|
@@ -134,7 +134,7 @@ class MergeProtection:
|
|
| 134 |
4. MagMax — protect top magnitude params (extra safety layer)
|
| 135 |
"""
|
| 136 |
fused = target_state[key]
|
| 137 |
-
original = pre_merge_state[key]
|
| 138 |
delta = fused - original
|
| 139 |
|
| 140 |
# --- ARM Steering (new, replaces orthogonal projection) ---
|
|
@@ -203,7 +203,7 @@ class MergeProtection:
|
|
| 203 |
|
| 204 |
for key in current_state:
|
| 205 |
if key in pre_merge_state:
|
| 206 |
-
delta = current_state[key].float() - pre_merge_state[key].float()
|
| 207 |
if delta.abs().max() > 1e-8:
|
| 208 |
if key not in self.previous_deltas:
|
| 209 |
self.previous_deltas[key] = []
|
|
@@ -268,9 +268,65 @@ def get_source_by_stage(stage_name: str) -> Optional[ModelConfig]:
|
|
| 268 |
return None
|
| 269 |
|
| 270 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
def load_model(config: ModelConfig, cfg: MergeConfig) -> tuple:
|
| 272 |
"""Load a model and its tokenizer/processor."""
|
| 273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
|
| 275 |
# Qwen3-VL uses a processor (handles both text + vision), not just a tokenizer
|
| 276 |
if config.architecture == "transformer+vision":
|
|
@@ -298,6 +354,7 @@ def load_model(config: ModelConfig, cfg: MergeConfig) -> tuple:
|
|
| 298 |
)
|
| 299 |
lang_params = sum(p.numel() for p in model.parameters()) - vision_params
|
| 300 |
print(f"[merge] Language: {lang_params / 1e9:.1f}B | Vision: {vision_params / 1e9:.1f}B")
|
|
|
|
| 301 |
|
| 302 |
return model, tokenizer
|
| 303 |
except ImportError:
|
|
@@ -318,6 +375,7 @@ def load_model(config: ModelConfig, cfg: MergeConfig) -> tuple:
|
|
| 318 |
)
|
| 319 |
|
| 320 |
print(f"[merge] Loaded {config.name}: {sum(p.numel() for p in model.parameters()) / 1e9:.1f}B params")
|
|
|
|
| 321 |
return model, tokenizer
|
| 322 |
|
| 323 |
|
|
@@ -888,6 +946,7 @@ def run_pipeline(
|
|
| 888 |
if cfg is None:
|
| 889 |
cfg = MergeConfig()
|
| 890 |
|
|
|
|
| 891 |
print("\n" + "=" * 70)
|
| 892 |
print("TD FUSE — Transport and Merge Pipeline")
|
| 893 |
print(f"Target: {TARGET.name} ({TARGET.hf_id})")
|
|
@@ -895,13 +954,18 @@ def run_pipeline(
|
|
| 895 |
print(f"Mode: Vision-Language (merging language backbone only, vision encoder untouched)")
|
| 896 |
print(f"Stages: {', '.join(stages)}")
|
| 897 |
print(f"Output: {cfg.output_dir}")
|
|
|
|
| 898 |
print("=" * 70)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 899 |
|
| 900 |
# Setup
|
| 901 |
try:
|
| 902 |
setup_tm_repo(cfg)
|
| 903 |
except FileNotFoundError as e:
|
| 904 |
-
print(f"\n
|
| 905 |
print("Continuing with fallback implementation...")
|
| 906 |
|
| 907 |
# Create output directories
|
|
@@ -1020,10 +1084,13 @@ def run_pipeline(
|
|
| 1020 |
emoji = "✓" if status == "passed" else "✗"
|
| 1021 |
print(f" {emoji} {stage_name}: {status}")
|
| 1022 |
print(f"\n Overall: {pipeline_results['overall_status']}")
|
|
|
|
|
|
|
| 1023 |
if residual_bank.residual_index:
|
| 1024 |
print(f"\n Residuals saved for: {', '.join(residual_bank.residual_index.keys())}")
|
| 1025 |
print(f" To recover lost knowledge later:")
|
| 1026 |
print(f" python -m td_fuse.run --reinject <stage> --strength 0.2")
|
| 1027 |
print("=" * 70)
|
|
|
|
| 1028 |
|
| 1029 |
return pipeline_results
|
|
|
|
| 134 |
4. MagMax — protect top magnitude params (extra safety layer)
|
| 135 |
"""
|
| 136 |
fused = target_state[key]
|
| 137 |
+
original = pre_merge_state[key].to(fused.device)
|
| 138 |
delta = fused - original
|
| 139 |
|
| 140 |
# --- ARM Steering (new, replaces orthogonal projection) ---
|
|
|
|
| 203 |
|
| 204 |
for key in current_state:
|
| 205 |
if key in pre_merge_state:
|
| 206 |
+
delta = current_state[key].cpu().float() - pre_merge_state[key].cpu().float()
|
| 207 |
if delta.abs().max() > 1e-8:
|
| 208 |
if key not in self.previous_deltas:
|
| 209 |
self.previous_deltas[key] = []
|
|
|
|
| 268 |
return None
|
| 269 |
|
| 270 |
|
| 271 |
+
def check_model_cached(hf_id: str) -> bool:
|
| 272 |
+
"""Check if a model is already in the HuggingFace cache."""
|
| 273 |
+
try:
|
| 274 |
+
from huggingface_hub import try_to_load_from_cache, model_info
|
| 275 |
+
# Quick check: see if config.json is cached (every model has one)
|
| 276 |
+
cached = try_to_load_from_cache(hf_id, "config.json")
|
| 277 |
+
if cached is not None and isinstance(cached, str):
|
| 278 |
+
return True
|
| 279 |
+
except Exception:
|
| 280 |
+
pass
|
| 281 |
+
return False
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def check_all_models_cached(stages: list) -> dict:
|
| 285 |
+
"""
|
| 286 |
+
Pre-flight check: are all needed models already downloaded?
|
| 287 |
+
Prints a clear table so you know what's cached and what will download.
|
| 288 |
+
"""
|
| 289 |
+
print("\n" + "=" * 60)
|
| 290 |
+
print("PRE-FLIGHT CHECK: Model cache status")
|
| 291 |
+
print("=" * 60)
|
| 292 |
+
sys.stdout.flush()
|
| 293 |
+
|
| 294 |
+
status = {}
|
| 295 |
+
|
| 296 |
+
# Target model
|
| 297 |
+
cached = check_model_cached(TARGET.hf_id)
|
| 298 |
+
tag = "CACHED" if cached else "WILL DOWNLOAD"
|
| 299 |
+
print(f" {TARGET.name:25s} {tag:15s} ({TARGET.hf_id})")
|
| 300 |
+
status[TARGET.name] = cached
|
| 301 |
+
|
| 302 |
+
# Source models for requested stages
|
| 303 |
+
for stage_name in stages:
|
| 304 |
+
source = get_source_by_stage(stage_name)
|
| 305 |
+
if source:
|
| 306 |
+
cached = check_model_cached(source.hf_id)
|
| 307 |
+
tag = "CACHED" if cached else "WILL DOWNLOAD"
|
| 308 |
+
print(f" {source.name:25s} {tag:15s} ({source.hf_id})")
|
| 309 |
+
status[source.name] = cached
|
| 310 |
+
|
| 311 |
+
not_cached = [name for name, c in status.items() if not c]
|
| 312 |
+
if not_cached:
|
| 313 |
+
print(f"\n {len(not_cached)} model(s) need downloading: {', '.join(not_cached)}")
|
| 314 |
+
print(f" This may take 10-30 min per model depending on connection speed.")
|
| 315 |
+
else:
|
| 316 |
+
print(f"\n All {len(status)} models are cached -- loading will be fast!")
|
| 317 |
+
|
| 318 |
+
print("=" * 60)
|
| 319 |
+
sys.stdout.flush()
|
| 320 |
+
return status
|
| 321 |
+
|
| 322 |
+
|
| 323 |
def load_model(config: ModelConfig, cfg: MergeConfig) -> tuple:
|
| 324 |
"""Load a model and its tokenizer/processor."""
|
| 325 |
+
load_start = time.time()
|
| 326 |
+
cached = check_model_cached(config.hf_id)
|
| 327 |
+
cache_msg = "(from cache)" if cached else "(downloading -- this may take a while)"
|
| 328 |
+
print(f"\n[merge] Loading {config.name} ({config.hf_id}) {cache_msg}...")
|
| 329 |
+
sys.stdout.flush()
|
| 330 |
|
| 331 |
# Qwen3-VL uses a processor (handles both text + vision), not just a tokenizer
|
| 332 |
if config.architecture == "transformer+vision":
|
|
|
|
| 354 |
)
|
| 355 |
lang_params = sum(p.numel() for p in model.parameters()) - vision_params
|
| 356 |
print(f"[merge] Language: {lang_params / 1e9:.1f}B | Vision: {vision_params / 1e9:.1f}B")
|
| 357 |
+
print(f"[merge] Loaded in {time.time()-load_start:.0f}s"); sys.stdout.flush()
|
| 358 |
|
| 359 |
return model, tokenizer
|
| 360 |
except ImportError:
|
|
|
|
| 375 |
)
|
| 376 |
|
| 377 |
print(f"[merge] Loaded {config.name}: {sum(p.numel() for p in model.parameters()) / 1e9:.1f}B params")
|
| 378 |
+
print(f"[merge] Loaded in {time.time()-load_start:.0f}s"); sys.stdout.flush()
|
| 379 |
return model, tokenizer
|
| 380 |
|
| 381 |
|
|
|
|
| 946 |
if cfg is None:
|
| 947 |
cfg = MergeConfig()
|
| 948 |
|
| 949 |
+
pipeline_start = time.time()
|
| 950 |
print("\n" + "=" * 70)
|
| 951 |
print("TD FUSE — Transport and Merge Pipeline")
|
| 952 |
print(f"Target: {TARGET.name} ({TARGET.hf_id})")
|
|
|
|
| 954 |
print(f"Mode: Vision-Language (merging language backbone only, vision encoder untouched)")
|
| 955 |
print(f"Stages: {', '.join(stages)}")
|
| 956 |
print(f"Output: {cfg.output_dir}")
|
| 957 |
+
print(f"Started at: {time.strftime('%H:%M:%S')}")
|
| 958 |
print("=" * 70)
|
| 959 |
+
sys.stdout.flush()
|
| 960 |
+
|
| 961 |
+
# --- Pre-flight: check which models are cached ---
|
| 962 |
+
check_all_models_cached(stages)
|
| 963 |
|
| 964 |
# Setup
|
| 965 |
try:
|
| 966 |
setup_tm_repo(cfg)
|
| 967 |
except FileNotFoundError as e:
|
| 968 |
+
print(f"\n WARNING: {e}")
|
| 969 |
print("Continuing with fallback implementation...")
|
| 970 |
|
| 971 |
# Create output directories
|
|
|
|
| 1084 |
emoji = "✓" if status == "passed" else "✗"
|
| 1085 |
print(f" {emoji} {stage_name}: {status}")
|
| 1086 |
print(f"\n Overall: {pipeline_results['overall_status']}")
|
| 1087 |
+
total_pipeline_time = time.time() - pipeline_start
|
| 1088 |
+
print(f"\n Total pipeline time: {total_pipeline_time/60:.1f} min ({total_pipeline_time/3600:.1f} hours)")
|
| 1089 |
if residual_bank.residual_index:
|
| 1090 |
print(f"\n Residuals saved for: {', '.join(residual_bank.residual_index.keys())}")
|
| 1091 |
print(f" To recover lost knowledge later:")
|
| 1092 |
print(f" python -m td_fuse.run --reinject <stage> --strength 0.2")
|
| 1093 |
print("=" * 70)
|
| 1094 |
+
sys.stdout.flush()
|
| 1095 |
|
| 1096 |
return pipeline_results
|
hugging/td_fuse/transport.py
CHANGED
|
@@ -572,6 +572,10 @@ def fuse_weights(
|
|
| 572 |
source_key = _map_key(target_key, source_config)
|
| 573 |
if source_key is None or source_key not in source_state:
|
| 574 |
skipped_count += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
| 575 |
continue
|
| 576 |
|
| 577 |
target_w = target_state[target_key]
|
|
@@ -618,6 +622,10 @@ def fuse_weights(
|
|
| 618 |
def _should_skip(key: str, source_config: ModelConfig) -> bool:
|
| 619 |
"""Determine if a parameter should be skipped during merge."""
|
| 620 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 621 |
# Always skip if source model says to skip embeddings
|
| 622 |
if source_config.skip_embeddings and ("embed_tokens" in key or "lm_head" in key):
|
| 623 |
return True
|
|
@@ -640,22 +648,42 @@ def _should_skip(key: str, source_config: ModelConfig) -> bool:
|
|
| 640 |
return False
|
| 641 |
|
| 642 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 643 |
def _map_key(target_key: str, source_config: ModelConfig) -> Optional[str]:
|
| 644 |
"""Map a target model parameter name to the corresponding source name."""
|
| 645 |
|
| 646 |
-
#
|
|
|
|
|
|
|
|
|
|
| 647 |
if source_config.architecture == "transformer" and source_config.layers == 36:
|
| 648 |
-
return
|
| 649 |
|
| 650 |
# For Llama (32 layers -> 36 layers), map layer indices
|
| 651 |
if "layer_mapping_32_to_36" in source_config.special_handling:
|
| 652 |
-
if "model.layers." in
|
| 653 |
# Extract layer number
|
| 654 |
-
parts =
|
| 655 |
try:
|
| 656 |
layer_idx = int(parts[2])
|
| 657 |
except (IndexError, ValueError):
|
| 658 |
-
return
|
| 659 |
|
| 660 |
# Map 36 target layers to 32 source layers (stride)
|
| 661 |
source_layer = int(layer_idx * 32 / 36)
|
|
@@ -664,17 +692,17 @@ def _map_key(target_key: str, source_config: ModelConfig) -> Optional[str]:
|
|
| 664 |
|
| 665 |
# For MiMo (same layer count, different extras), keys mostly match
|
| 666 |
if source_config.architecture == "transformer+mtp":
|
| 667 |
-
if "mtp_head" in
|
| 668 |
return None # MTP heads don't exist in target
|
| 669 |
-
return
|
| 670 |
|
| 671 |
# For Falcon hybrid, only attention and MLP keys map
|
| 672 |
if source_config.architecture == "hybrid_ssm":
|
| 673 |
-
if any(k in
|
| 674 |
-
return
|
| 675 |
return None # Mamba components don't map
|
| 676 |
|
| 677 |
-
return
|
| 678 |
|
| 679 |
|
| 680 |
def _align_dimensions(
|
|
|
|
| 572 |
source_key = _map_key(target_key, source_config)
|
| 573 |
if source_key is None or source_key not in source_state:
|
| 574 |
skipped_count += 1
|
| 575 |
+
# Log first few misses to help debug key mapping issues
|
| 576 |
+
if skipped_count <= 5:
|
| 577 |
+
print(f" [skip] No source match for: {target_key} (mapped to: {source_key})")
|
| 578 |
+
sys.stdout.flush()
|
| 579 |
continue
|
| 580 |
|
| 581 |
target_w = target_state[target_key]
|
|
|
|
| 622 |
def _should_skip(key: str, source_config: ModelConfig) -> bool:
|
| 623 |
"""Determine if a parameter should be skipped during merge."""
|
| 624 |
|
| 625 |
+
# Skip vision encoder params (Qwen3-VL) -- these should never be merged
|
| 626 |
+
if key.startswith("visual") or key.startswith("merger") or key.startswith("model.visual") or key.startswith("model.merger"):
|
| 627 |
+
return True
|
| 628 |
+
|
| 629 |
# Always skip if source model says to skip embeddings
|
| 630 |
if source_config.skip_embeddings and ("embed_tokens" in key or "lm_head" in key):
|
| 631 |
return True
|
|
|
|
| 648 |
return False
|
| 649 |
|
| 650 |
|
| 651 |
+
def _strip_vl_prefix(key: str) -> str:
|
| 652 |
+
"""
|
| 653 |
+
Strip the 'language_model.' prefix that Qwen3-VL adds.
|
| 654 |
+
|
| 655 |
+
Qwen3-VL wraps all language params under 'model.language_model.*'
|
| 656 |
+
but source models (DeepSeek, MiMo, Llama, Falcon) use 'model.*' directly.
|
| 657 |
+
|
| 658 |
+
Example:
|
| 659 |
+
target: model.language_model.layers.0.self_attn.q_proj.weight
|
| 660 |
+
source: model.layers.0.self_attn.q_proj.weight
|
| 661 |
+
"""
|
| 662 |
+
# model.language_model.X -> model.X
|
| 663 |
+
if "language_model." in key:
|
| 664 |
+
return key.replace("language_model.", "")
|
| 665 |
+
return key
|
| 666 |
+
|
| 667 |
+
|
| 668 |
def _map_key(target_key: str, source_config: ModelConfig) -> Optional[str]:
|
| 669 |
"""Map a target model parameter name to the corresponding source name."""
|
| 670 |
|
| 671 |
+
# Step 1: Strip Qwen3-VL's language_model. prefix so we can match source keys
|
| 672 |
+
source_key = _strip_vl_prefix(target_key)
|
| 673 |
+
|
| 674 |
+
# For same-architecture models (DeepSeek), keys match directly after prefix strip
|
| 675 |
if source_config.architecture == "transformer" and source_config.layers == 36:
|
| 676 |
+
return source_key
|
| 677 |
|
| 678 |
# For Llama (32 layers -> 36 layers), map layer indices
|
| 679 |
if "layer_mapping_32_to_36" in source_config.special_handling:
|
| 680 |
+
if "model.layers." in source_key:
|
| 681 |
# Extract layer number
|
| 682 |
+
parts = source_key.split(".")
|
| 683 |
try:
|
| 684 |
layer_idx = int(parts[2])
|
| 685 |
except (IndexError, ValueError):
|
| 686 |
+
return source_key
|
| 687 |
|
| 688 |
# Map 36 target layers to 32 source layers (stride)
|
| 689 |
source_layer = int(layer_idx * 32 / 36)
|
|
|
|
| 692 |
|
| 693 |
# For MiMo (same layer count, different extras), keys mostly match
|
| 694 |
if source_config.architecture == "transformer+mtp":
|
| 695 |
+
if "mtp_head" in source_key:
|
| 696 |
return None # MTP heads don't exist in target
|
| 697 |
+
return source_key
|
| 698 |
|
| 699 |
# For Falcon hybrid, only attention and MLP keys map
|
| 700 |
if source_config.architecture == "hybrid_ssm":
|
| 701 |
+
if any(k in source_key for k in ["self_attn", "mlp", "layer_norm"]):
|
| 702 |
+
return source_key # These exist in both
|
| 703 |
return None # Mamba components don't map
|
| 704 |
|
| 705 |
+
return source_key
|
| 706 |
|
| 707 |
|
| 708 |
def _align_dimensions(
|