Vocab mismatch fix for cross-arch merging
Browse files- td_fuse/merge.py +19 -4
- td_fuse/transport.py +30 -1
td_fuse/merge.py
CHANGED
|
@@ -39,6 +39,7 @@ from .canary import inject_canary, test_all_canaries
|
|
| 39 |
from .transport import (
|
| 40 |
setup_tm_repo,
|
| 41 |
load_calibration_data,
|
|
|
|
| 42 |
extract_activations,
|
| 43 |
compute_transport_plans,
|
| 44 |
fuse_weights,
|
|
@@ -662,6 +663,7 @@ def run_single_merge(
|
|
| 662 |
protection: MergeProtection,
|
| 663 |
residual_bank: ResidualBank = None,
|
| 664 |
calibration_data: list = None,
|
|
|
|
| 665 |
baseline_perplexity: float = None,
|
| 666 |
merged_sources: list = None,
|
| 667 |
) -> dict:
|
|
@@ -717,14 +719,26 @@ def run_single_merge(
|
|
| 717 |
print(f"\n[merge] Step 3/10: Loading calibration data..."); sys.stdout.flush()
|
| 718 |
step_t = time.time()
|
| 719 |
if calibration_data is None:
|
| 720 |
-
calibration_data = load_calibration_data(cfg, target_tokenizer)
|
| 721 |
print(f"[merge] Step 3/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
|
| 722 |
|
| 723 |
# --- Step 4: Extract activations ---
|
| 724 |
print(f"\n[merge] Step 4/10: Extracting activations (both models)..."); sys.stdout.flush()
|
| 725 |
step_t = time.time()
|
| 726 |
-
|
| 727 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 728 |
|
| 729 |
print(f"[merge] Extracting target activations...")
|
| 730 |
pre_merge_target_activations = extract_activations(target_model, calibration_data)
|
|
@@ -1101,7 +1115,7 @@ def run_pipeline(
|
|
| 1101 |
print(f"[pipeline] Baseline perplexity: {baseline_ppl:.2f}")
|
| 1102 |
|
| 1103 |
# --- Load calibration data once ---
|
| 1104 |
-
calibration_data = load_calibration_data(cfg, target_tokenizer)
|
| 1105 |
|
| 1106 |
# --- Initialize merge protection + residual bank ---
|
| 1107 |
protection = MergeProtection(cfg)
|
|
@@ -1138,6 +1152,7 @@ def run_pipeline(
|
|
| 1138 |
protection,
|
| 1139 |
residual_bank=residual_bank,
|
| 1140 |
calibration_data=calibration_data,
|
|
|
|
| 1141 |
baseline_perplexity=baseline_ppl,
|
| 1142 |
merged_sources=merged_sources,
|
| 1143 |
)
|
|
|
|
| 39 |
from .transport import (
|
| 40 |
setup_tm_repo,
|
| 41 |
load_calibration_data,
|
| 42 |
+
retokenize_calibration,
|
| 43 |
extract_activations,
|
| 44 |
compute_transport_plans,
|
| 45 |
fuse_weights,
|
|
|
|
| 663 |
protection: MergeProtection,
|
| 664 |
residual_bank: ResidualBank = None,
|
| 665 |
calibration_data: list = None,
|
| 666 |
+
calibration_raw_texts: list = None,
|
| 667 |
baseline_perplexity: float = None,
|
| 668 |
merged_sources: list = None,
|
| 669 |
) -> dict:
|
|
|
|
| 719 |
print(f"\n[merge] Step 3/10: Loading calibration data..."); sys.stdout.flush()
|
| 720 |
step_t = time.time()
|
| 721 |
if calibration_data is None:
|
| 722 |
+
calibration_data, calibration_raw_texts = load_calibration_data(cfg, target_tokenizer)
|
| 723 |
print(f"[merge] Step 3/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
|
| 724 |
|
| 725 |
# --- Step 4: Extract activations ---
|
| 726 |
print(f"\n[merge] Step 4/10: Extracting activations (both models)..."); sys.stdout.flush()
|
| 727 |
step_t = time.time()
|
| 728 |
+
# Check if source model has a different vocabulary size than target.
|
| 729 |
+
source_vocab_size = source_model.config.vocab_size if hasattr(source_model.config, 'vocab_size') else None
|
| 730 |
+
target_vocab_size = target_model.config.vocab_size if hasattr(target_model.config, 'vocab_size') else None
|
| 731 |
+
print(f"[merge] Vocab sizes -- target: {target_vocab_size}, source: {source_vocab_size}")
|
| 732 |
+
|
| 733 |
+
if source_vocab_size and target_vocab_size and source_vocab_size != target_vocab_size:
|
| 734 |
+
print(f"[merge] VOCAB MISMATCH detected! Re-tokenizing calibration data for {source_config.name}...")
|
| 735 |
+
source_calibration = retokenize_calibration(calibration_raw_texts, source_tokenizer, cfg)
|
| 736 |
+
print(f"[merge] Extracting source activations (with source-tokenized data)...")
|
| 737 |
+
source_activations = extract_activations(source_model, source_calibration)
|
| 738 |
+
del source_calibration
|
| 739 |
+
else:
|
| 740 |
+
print(f"[merge] Extracting source activations...")
|
| 741 |
+
source_activations = extract_activations(source_model, calibration_data)
|
| 742 |
|
| 743 |
print(f"[merge] Extracting target activations...")
|
| 744 |
pre_merge_target_activations = extract_activations(target_model, calibration_data)
|
|
|
|
| 1115 |
print(f"[pipeline] Baseline perplexity: {baseline_ppl:.2f}")
|
| 1116 |
|
| 1117 |
# --- Load calibration data once ---
|
| 1118 |
+
calibration_data, calibration_raw_texts = load_calibration_data(cfg, target_tokenizer)
|
| 1119 |
|
| 1120 |
# --- Initialize merge protection + residual bank ---
|
| 1121 |
protection = MergeProtection(cfg)
|
|
|
|
| 1152 |
protection,
|
| 1153 |
residual_bank=residual_bank,
|
| 1154 |
calibration_data=calibration_data,
|
| 1155 |
+
calibration_raw_texts=calibration_raw_texts,
|
| 1156 |
baseline_perplexity=baseline_ppl,
|
| 1157 |
merged_sources=merged_sources,
|
| 1158 |
)
|
td_fuse/transport.py
CHANGED
|
@@ -105,7 +105,7 @@ def setup_tm_repo(cfg: MergeConfig):
|
|
| 105 |
print(f"[transport] Added T&M core to path: {core_path}")
|
| 106 |
|
| 107 |
|
| 108 |
-
def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) ->
|
| 109 |
"""
|
| 110 |
Load calibration data for activation extraction.
|
| 111 |
|
|
@@ -118,6 +118,7 @@ def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
|
|
| 118 |
print(f"[transport] Loading calibration data ({cfg.calibration_samples} samples)...")
|
| 119 |
|
| 120 |
samples = []
|
|
|
|
| 121 |
|
| 122 |
# --- Pile: general text (600 samples) ---
|
| 123 |
try:
|
|
@@ -140,6 +141,7 @@ def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
|
|
| 140 |
return_tensors="pt",
|
| 141 |
)
|
| 142 |
samples.append(tokens)
|
|
|
|
| 143 |
count += 1
|
| 144 |
if count % 100 == 0:
|
| 145 |
print(f" Pile: {count}/600 samples loaded...")
|
|
@@ -171,6 +173,7 @@ def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
|
|
| 171 |
return_tensors="pt",
|
| 172 |
)
|
| 173 |
samples.append(tokens)
|
|
|
|
| 174 |
count += 1
|
| 175 |
if count % 100 == 0:
|
| 176 |
print(f" neuralmagic: {count}/{remaining} samples loaded...")
|
|
@@ -182,6 +185,32 @@ def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
|
|
| 182 |
tracker.done()
|
| 183 |
print(f"[transport] Total calibration samples: {len(samples)}")
|
| 184 |
sys.stdout.flush()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
return samples
|
| 186 |
|
| 187 |
|
|
|
|
| 105 |
print(f"[transport] Added T&M core to path: {core_path}")
|
| 106 |
|
| 107 |
|
| 108 |
+
def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> tuple:
|
| 109 |
"""
|
| 110 |
Load calibration data for activation extraction.
|
| 111 |
|
|
|
|
| 118 |
print(f"[transport] Loading calibration data ({cfg.calibration_samples} samples)...")
|
| 119 |
|
| 120 |
samples = []
|
| 121 |
+
raw_texts = [] # Store raw text for cross-vocab re-tokenization
|
| 122 |
|
| 123 |
# --- Pile: general text (600 samples) ---
|
| 124 |
try:
|
|
|
|
| 141 |
return_tensors="pt",
|
| 142 |
)
|
| 143 |
samples.append(tokens)
|
| 144 |
+
raw_texts.append(text)
|
| 145 |
count += 1
|
| 146 |
if count % 100 == 0:
|
| 147 |
print(f" Pile: {count}/600 samples loaded...")
|
|
|
|
| 173 |
return_tensors="pt",
|
| 174 |
)
|
| 175 |
samples.append(tokens)
|
| 176 |
+
raw_texts.append(str(text))
|
| 177 |
count += 1
|
| 178 |
if count % 100 == 0:
|
| 179 |
print(f" neuralmagic: {count}/{remaining} samples loaded...")
|
|
|
|
| 185 |
tracker.done()
|
| 186 |
print(f"[transport] Total calibration samples: {len(samples)}")
|
| 187 |
sys.stdout.flush()
|
| 188 |
+
return samples, raw_texts
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def retokenize_calibration(raw_texts: list, tokenizer: AutoTokenizer, cfg: MergeConfig) -> list:
|
| 192 |
+
"""
|
| 193 |
+
Re-tokenize calibration texts with a different tokenizer.
|
| 194 |
+
|
| 195 |
+
Used when the source model has a different vocabulary than the target.
|
| 196 |
+
For example, Llama (128K vocab) vs Qwen (152K vocab).
|
| 197 |
+
"""
|
| 198 |
+
print(f"[transport] Re-tokenizing {len(raw_texts)} samples for source model vocabulary...")
|
| 199 |
+
sys.stdout.flush()
|
| 200 |
+
samples = []
|
| 201 |
+
for i, text in enumerate(raw_texts):
|
| 202 |
+
tokens = tokenizer(
|
| 203 |
+
text,
|
| 204 |
+
truncation=True,
|
| 205 |
+
max_length=cfg.calibration_seq_len,
|
| 206 |
+
return_tensors="pt",
|
| 207 |
+
)
|
| 208 |
+
samples.append(tokens)
|
| 209 |
+
if (i + 1) % 500 == 0:
|
| 210 |
+
print(f" Re-tokenized {i + 1}/{len(raw_texts)} samples...")
|
| 211 |
+
sys.stdout.flush()
|
| 212 |
+
print(f"[transport] Re-tokenized {len(samples)} samples for source model")
|
| 213 |
+
sys.stdout.flush()
|
| 214 |
return samples
|
| 215 |
|
| 216 |
|