Current td_fuse code with all fixes
Browse files- td_fuse/__pycache__/__init__.cpython-312.pyc +0 -0
- td_fuse/__pycache__/canary.cpython-312.pyc +0 -0
- td_fuse/__pycache__/config.cpython-312.pyc +0 -0
- td_fuse/__pycache__/heal.cpython-312.pyc +0 -0
- td_fuse/__pycache__/merge.cpython-312.pyc +0 -0
- td_fuse/__pycache__/techniques.cpython-312.pyc +0 -0
- td_fuse/__pycache__/transport.cpython-312.pyc +0 -0
- td_fuse/__pycache__/validate.cpython-312.pyc +0 -0
- td_fuse/config.py +2 -2
- td_fuse/heal.py +15 -46
- td_fuse/merge.py +3 -3
- td_fuse/transport.py +20 -5
td_fuse/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (1.28 kB). View file
|
|
|
td_fuse/__pycache__/canary.cpython-312.pyc
ADDED
|
Binary file (8.27 kB). View file
|
|
|
td_fuse/__pycache__/config.cpython-312.pyc
ADDED
|
Binary file (8.59 kB). View file
|
|
|
td_fuse/__pycache__/heal.cpython-312.pyc
ADDED
|
Binary file (16.2 kB). View file
|
|
|
td_fuse/__pycache__/merge.cpython-312.pyc
ADDED
|
Binary file (58.7 kB). View file
|
|
|
td_fuse/__pycache__/techniques.cpython-312.pyc
ADDED
|
Binary file (25.1 kB). View file
|
|
|
td_fuse/__pycache__/transport.cpython-312.pyc
ADDED
|
Binary file (45.4 kB). View file
|
|
|
td_fuse/__pycache__/validate.cpython-312.pyc
ADDED
|
Binary file (11.8 kB). View file
|
|
|
td_fuse/config.py
CHANGED
|
@@ -129,7 +129,7 @@ SOURCES = [
|
|
| 129 |
skip_embeddings=True, # Must skip — vocab too different
|
| 130 |
trust_remote_code=False,
|
| 131 |
merge_risk="medium",
|
| 132 |
-
merge_alpha=0.
|
| 133 |
special_handling=["skip_embeddings", "drop_qkv_bias", "layer_mapping_32_to_36"],
|
| 134 |
notes=(
|
| 135 |
"32 layers vs 36 — T&M's P matrix handles layer mapping. "
|
|
@@ -152,7 +152,7 @@ SOURCES = [
|
|
| 152 |
skip_embeddings=True, # Must skip — vocab too different
|
| 153 |
trust_remote_code=True, # Likely custom hybrid code
|
| 154 |
merge_risk="high",
|
| 155 |
-
merge_alpha=0.
|
| 156 |
special_handling=[
|
| 157 |
"skip_embeddings",
|
| 158 |
"drop_mamba_state_params", # A, D matrices have no Qwen3 equivalent
|
|
|
|
| 129 |
skip_embeddings=True, # Must skip — vocab too different
|
| 130 |
trust_remote_code=False,
|
| 131 |
merge_risk="medium",
|
| 132 |
+
merge_alpha=0.08, # Lower alpha — layer mismatch risk
|
| 133 |
special_handling=["skip_embeddings", "drop_qkv_bias", "layer_mapping_32_to_36"],
|
| 134 |
notes=(
|
| 135 |
"32 layers vs 36 — T&M's P matrix handles layer mapping. "
|
|
|
|
| 152 |
skip_embeddings=True, # Must skip — vocab too different
|
| 153 |
trust_remote_code=True, # Likely custom hybrid code
|
| 154 |
merge_risk="high",
|
| 155 |
+
merge_alpha=0.08, # Conservative — highest risk model
|
| 156 |
special_handling=[
|
| 157 |
"skip_embeddings",
|
| 158 |
"drop_mamba_state_params", # A, D matrices have no Qwen3 equivalent
|
td_fuse/heal.py
CHANGED
|
@@ -69,11 +69,11 @@ def load_healing_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
|
|
| 69 |
# Each entry: (dataset_id, config_name_or_None, split, count, text_field)
|
| 70 |
datasets_to_load = [
|
| 71 |
# General language — same calibration data source that works reliably
|
| 72 |
-
("neuralmagic/LLM_compression_calibration", None, "train",
|
| 73 |
# Math reasoning (exercises DeepSeek/MiMo contributions)
|
| 74 |
-
("openai/gsm8k", "main", "train",
|
| 75 |
# Code — bigcode/starcoderdata is a modern alternative
|
| 76 |
-
("
|
| 77 |
]
|
| 78 |
|
| 79 |
all_texts = []
|
|
@@ -193,7 +193,9 @@ def apply_qlora_unsloth(
|
|
| 193 |
learning_rate=cfg.heal_learning_rate,
|
| 194 |
bf16=True,
|
| 195 |
logging_steps=10,
|
| 196 |
-
save_strategy="
|
|
|
|
|
|
|
| 197 |
warmup_ratio=0.05,
|
| 198 |
lr_scheduler_type="cosine",
|
| 199 |
optim="adamw_8bit", # Memory-efficient optimiser
|
|
@@ -249,24 +251,15 @@ def apply_qlora_standard(
|
|
| 249 |
return 'td_fuse_outputs/healed'
|
| 250 |
import torch
|
| 251 |
from peft import LoraConfig, get_peft_model, TaskType
|
| 252 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 253 |
|
| 254 |
print("\n[heal] Loading model with standard PEFT...")
|
| 255 |
|
| 256 |
-
# 4-bit quantisation config
|
| 257 |
-
bnb_config = BitsAndBytesConfig(
|
| 258 |
-
load_in_4bit=True,
|
| 259 |
-
bnb_4bit_quant_type="nf4",
|
| 260 |
-
bnb_4bit_compute_dtype=getattr(torch, cfg.dtype),
|
| 261 |
-
bnb_4bit_use_double_quant=True,
|
| 262 |
-
)
|
| 263 |
-
|
| 264 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 265 |
model = _load_model_smart(
|
| 266 |
model_path,
|
| 267 |
-
quantization_config=bnb_config,
|
| 268 |
device_map="auto",
|
| 269 |
-
torch_dtype=
|
| 270 |
)
|
| 271 |
|
| 272 |
# LoRA config
|
|
@@ -328,7 +321,9 @@ def apply_qlora_standard(
|
|
| 328 |
learning_rate=cfg.heal_learning_rate,
|
| 329 |
bf16=True,
|
| 330 |
logging_steps=10,
|
| 331 |
-
save_strategy="
|
|
|
|
|
|
|
| 332 |
warmup_ratio=0.05,
|
| 333 |
lr_scheduler_type="cosine",
|
| 334 |
optim="adamw_torch",
|
|
@@ -365,36 +360,10 @@ def apply_qlora_standard(
|
|
| 365 |
|
| 366 |
gc.collect()
|
| 367 |
|
| 368 |
-
#
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
print(f"[heal]
|
| 372 |
-
try:
|
| 373 |
-
from safetensors.torch import save_file
|
| 374 |
-
import torch as _torch
|
| 375 |
-
# Fixed: use named_parameters for proper dequantization
|
| 376 |
-
clean_state = {}
|
| 377 |
-
for k, v in merged_model.named_parameters():
|
| 378 |
-
if hasattr(v, 'dequantize'):
|
| 379 |
-
clean_state[k] = v.dequantize().to(_torch.bfloat16)
|
| 380 |
-
elif v.data.dtype in (_torch.float32, _torch.float16, _torch.bfloat16):
|
| 381 |
-
clean_state[k] = v.data.to(_torch.bfloat16)
|
| 382 |
-
else:
|
| 383 |
-
clean_state[k] = v.data.float().to(_torch.bfloat16)
|
| 384 |
-
save_file(clean_state, str(healed_dir / "model.safetensors"))
|
| 385 |
-
if hasattr(merged_model, 'config'):
|
| 386 |
-
if hasattr(merged_model.config, "quantization_config"):
|
| 387 |
-
merged_model.config.quantization_config = None
|
| 388 |
-
print("[heal] Removed quantization_config from saved config (weights are bf16 now)")
|
| 389 |
-
merged_model.config.save_pretrained(str(healed_dir))
|
| 390 |
-
tokenizer.save_pretrained(str(healed_dir))
|
| 391 |
-
print(f"[heal] SAVED OK: {healed_dir / 'model.safetensors'}")
|
| 392 |
-
except Exception as e:
|
| 393 |
-
# Emergency fallback: try save_pretrained as last resort
|
| 394 |
-
print(f"[heal] Manual save failed ({e}), trying save_pretrained...")
|
| 395 |
-
merged_model.save_pretrained(str(healed_dir))
|
| 396 |
-
tokenizer.save_pretrained(str(healed_dir))
|
| 397 |
-
print(f"[heal] SAVED OK via save_pretrained: {healed_dir}")
|
| 398 |
|
| 399 |
# Verify the save actually worked before cleaning up ANYTHING
|
| 400 |
saved_model = healed_dir / "model.safetensors"
|
|
|
|
| 69 |
# Each entry: (dataset_id, config_name_or_None, split, count, text_field)
|
| 70 |
datasets_to_load = [
|
| 71 |
# General language — same calibration data source that works reliably
|
| 72 |
+
("neuralmagic/LLM_compression_calibration", None, "train", 1500, "text"),
|
| 73 |
# Math reasoning (exercises DeepSeek/MiMo contributions)
|
| 74 |
+
("openai/gsm8k", "main", "train", 1000, "question"),
|
| 75 |
# Code — bigcode/starcoderdata is a modern alternative
|
| 76 |
+
("sahil2801/CodeAlpaca-20k", None, "train", 500, "output"),
|
| 77 |
]
|
| 78 |
|
| 79 |
all_texts = []
|
|
|
|
| 193 |
learning_rate=cfg.heal_learning_rate,
|
| 194 |
bf16=True,
|
| 195 |
logging_steps=10,
|
| 196 |
+
save_strategy="steps",
|
| 197 |
+
save_steps=50,
|
| 198 |
+
save_total_limit=2, max_steps=50, # save_steps == max_steps, so only the final-step checkpoint is written — no intermediate checkpoints (saves ~17GB disk)
|
| 199 |
warmup_ratio=0.05,
|
| 200 |
lr_scheduler_type="cosine",
|
| 201 |
optim="adamw_8bit", # Memory-efficient optimiser
|
|
|
|
| 251 |
return 'td_fuse_outputs/healed'
|
| 252 |
import torch
|
| 253 |
from peft import LoraConfig, get_peft_model, TaskType
|
| 254 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 255 |
|
| 256 |
print("\n[heal] Loading model with standard PEFT...")
|
| 257 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 259 |
model = _load_model_smart(
|
| 260 |
model_path,
|
|
|
|
| 261 |
device_map="auto",
|
| 262 |
+
torch_dtype=torch.bfloat16,
|
| 263 |
)
|
| 264 |
|
| 265 |
# LoRA config
|
|
|
|
| 321 |
learning_rate=cfg.heal_learning_rate,
|
| 322 |
bf16=True,
|
| 323 |
logging_steps=10,
|
| 324 |
+
save_strategy="steps",
|
| 325 |
+
save_steps=50,
|
| 326 |
+
save_total_limit=2, max_steps=50, # save_steps == max_steps, so only the final-step checkpoint is written — no intermediate checkpoints (saves ~17GB disk)
|
| 327 |
warmup_ratio=0.05,
|
| 328 |
lr_scheduler_type="cosine",
|
| 329 |
optim="adamw_torch",
|
|
|
|
| 360 |
|
| 361 |
gc.collect()
|
| 362 |
|
| 363 |
+
# bf16 model — save_pretrained works correctly, no dequantize needed
|
| 364 |
+
merged_model.save_pretrained(str(healed_dir), safe_serialization=True)
|
| 365 |
+
tokenizer.save_pretrained(str(healed_dir))
|
| 366 |
+
print(f"[heal] SAVED OK: {healed_dir}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
|
| 368 |
# Verify the save actually worked before cleaning up ANYTHING
|
| 369 |
saved_model = healed_dir / "model.safetensors"
|
td_fuse/merge.py
CHANGED
|
@@ -726,11 +726,11 @@ def run_single_merge(
|
|
| 726 |
print(f"\n[merge] Step 4/10: Extracting activations (both models)..."); sys.stdout.flush()
|
| 727 |
step_t = time.time()
|
| 728 |
# Check if source model has a different vocabulary size than target.
|
| 729 |
-
source_vocab_size =
|
| 730 |
-
target_vocab_size =
|
| 731 |
print(f"[merge] Vocab sizes -- target: {target_vocab_size}, source: {source_vocab_size}")
|
| 732 |
|
| 733 |
-
if source_vocab_size
|
| 734 |
print(f"[merge] VOCAB MISMATCH detected! Re-tokenizing calibration data for {source_config.name}...")
|
| 735 |
source_calibration = retokenize_calibration(calibration_raw_texts, source_tokenizer, cfg)
|
| 736 |
print(f"[merge] Extracting source activations (with source-tokenized data)...")
|
|
|
|
| 726 |
print(f"\n[merge] Step 4/10: Extracting activations (both models)..."); sys.stdout.flush()
|
| 727 |
step_t = time.time()
|
| 728 |
# Check if source model has a different vocabulary size than target.
|
| 729 |
+
source_vocab_size = len(source_tokenizer)
|
| 730 |
+
target_vocab_size = len(target_tokenizer)
|
| 731 |
print(f"[merge] Vocab sizes -- target: {target_vocab_size}, source: {source_vocab_size}")
|
| 732 |
|
| 733 |
+
if source_vocab_size != target_vocab_size:
|
| 734 |
print(f"[merge] VOCAB MISMATCH detected! Re-tokenizing calibration data for {source_config.name}...")
|
| 735 |
source_calibration = retokenize_calibration(calibration_raw_texts, source_tokenizer, cfg)
|
| 736 |
print(f"[merge] Extracting source activations (with source-tokenized data)...")
|
td_fuse/transport.py
CHANGED
|
@@ -520,7 +520,7 @@ def _compute_plans_fallback(
|
|
| 520 |
sys.stdout.flush()
|
| 521 |
|
| 522 |
# Timeout: 90 min (Sinkhorn on 4096x4096 is slow on CPU)
|
| 523 |
-
tracker.check_timeout(timeout_seconds=
|
| 524 |
|
| 525 |
if permutations:
|
| 526 |
print(f"[transport] Computed {len(permutations)} neuron permutations")
|
|
@@ -569,8 +569,8 @@ def _compute_plans_fallback(
|
|
| 569 |
layer_costs[i, j] = 1.0 - sim
|
| 570 |
tracker.tick(f"layer sim {i},{j}")
|
| 571 |
|
| 572 |
-
# Timeout:
|
| 573 |
-
tracker.check_timeout(timeout_seconds=
|
| 574 |
|
| 575 |
print(f"[transport] Step 1/3 done: {n_source}x{n_target} similarities computed")
|
| 576 |
sys.stdout.flush()
|
|
@@ -579,10 +579,24 @@ def _compute_plans_fallback(
|
|
| 579 |
print("[transport] Step 2/3: Computing neuron-level transport (top-3 per target)...")
|
| 580 |
sys.stdout.flush()
|
| 581 |
Q_matrices = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 582 |
for j, tl in enumerate(target_layers):
|
| 583 |
top3 = np.argsort(layer_costs[:, j])[:3]
|
| 584 |
for i in top3:
|
| 585 |
sl = source_layers[i]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 586 |
S = source_act[sl].numpy()
|
| 587 |
T = target_act[tl].numpy()
|
| 588 |
|
|
@@ -595,14 +609,15 @@ def _compute_plans_fallback(
|
|
| 595 |
corr = S_norm.T @ T_norm / S.shape[0]
|
| 596 |
cost = 1.0 - corr
|
| 597 |
Q_matrices[(sl, tl)] = _sinkhorn(cost, reg=0.1, max_iter=50)
|
|
|
|
| 598 |
tracker.tick(f"Q({sl},{tl})")
|
| 599 |
|
| 600 |
if (j + 1) % 5 == 0 or j == 0:
|
| 601 |
print(f" Target layer {j + 1}/{n_target}: matched to top-3 sources")
|
| 602 |
sys.stdout.flush()
|
| 603 |
|
| 604 |
-
# Timeout:
|
| 605 |
-
tracker.check_timeout(timeout_seconds=
|
| 606 |
|
| 607 |
print(f"[transport] Step 2/3 done: {len(Q_matrices)} Q matrices computed")
|
| 608 |
sys.stdout.flush()
|
|
|
|
| 520 |
sys.stdout.flush()
|
| 521 |
|
| 522 |
# Timeout: 180 min (10800 s) (Sinkhorn on 4096x4096 is slow on CPU)
|
| 523 |
+
tracker.check_timeout(timeout_seconds=10800)
|
| 524 |
|
| 525 |
if permutations:
|
| 526 |
print(f"[transport] Computed {len(permutations)} neuron permutations")
|
|
|
|
| 569 |
layer_costs[i, j] = 1.0 - sim
|
| 570 |
tracker.tick(f"layer sim {i},{j}")
|
| 571 |
|
| 572 |
+
# Timeout: 180 min (10800 s) for cross-arch
|
| 573 |
+
tracker.check_timeout(timeout_seconds=10800)
|
| 574 |
|
| 575 |
print(f"[transport] Step 1/3 done: {n_source}x{n_target} similarities computed")
|
| 576 |
sys.stdout.flush()
|
|
|
|
| 579 |
print("[transport] Step 2/3: Computing neuron-level transport (top-3 per target)...")
|
| 580 |
sys.stdout.flush()
|
| 581 |
Q_matrices = {}
|
| 582 |
+
|
| 583 |
+
# Incremental cache: save each Q as we go so crashes don't lose progress
|
| 584 |
+
q_cache_dir = Path("td_fuse_checkpoints") / "q_cache_crossarch"
|
| 585 |
+
q_cache_dir.mkdir(parents=True, exist_ok=True)
|
| 586 |
+
|
| 587 |
for j, tl in enumerate(target_layers):
|
| 588 |
top3 = np.argsort(layer_costs[:, j])[:3]
|
| 589 |
for i in top3:
|
| 590 |
sl = source_layers[i]
|
| 591 |
+
cache_key = f"{sl}__{tl}".replace("/", "_").replace(".", "_")
|
| 592 |
+
cache_path = q_cache_dir / f"{cache_key}.npy"
|
| 593 |
+
|
| 594 |
+
# Skip if already computed in a previous run
|
| 595 |
+
if cache_path.exists():
|
| 596 |
+
Q_matrices[(sl, tl)] = np.load(str(cache_path))
|
| 597 |
+
tracker.tick(f"Q({sl},{tl})")
|
| 598 |
+
continue
|
| 599 |
+
|
| 600 |
S = source_act[sl].numpy()
|
| 601 |
T = target_act[tl].numpy()
|
| 602 |
|
|
|
|
| 609 |
corr = S_norm.T @ T_norm / S.shape[0]
|
| 610 |
cost = 1.0 - corr
|
| 611 |
Q_matrices[(sl, tl)] = _sinkhorn(cost, reg=0.1, max_iter=50)
|
| 612 |
+
np.save(str(cache_path), Q_matrices[(sl, tl)])
|
| 613 |
tracker.tick(f"Q({sl},{tl})")
|
| 614 |
|
| 615 |
if (j + 1) % 5 == 0 or j == 0:
|
| 616 |
print(f" Target layer {j + 1}/{n_target}: matched to top-3 sources")
|
| 617 |
sys.stdout.flush()
|
| 618 |
|
| 619 |
+
# Timeout: 180 min (10800 s) for cross-arch (was 30, too short for 72 layers)
|
| 620 |
+
tracker.check_timeout(timeout_seconds=10800)
|
| 621 |
|
| 622 |
print(f"[transport] Step 2/3 done: {len(Q_matrices)} Q matrices computed")
|
| 623 |
sys.stdout.flush()
|