td-builder committed on
Commit
52a6e10
·
verified ·
1 Parent(s): 5d61448

Vocab mismatch fix for cross-arch merging

Browse files
Files changed (2) hide show
  1. td_fuse/merge.py +19 -4
  2. td_fuse/transport.py +30 -1
td_fuse/merge.py CHANGED
@@ -39,6 +39,7 @@ from .canary import inject_canary, test_all_canaries
39
  from .transport import (
40
  setup_tm_repo,
41
  load_calibration_data,
 
42
  extract_activations,
43
  compute_transport_plans,
44
  fuse_weights,
@@ -662,6 +663,7 @@ def run_single_merge(
662
  protection: MergeProtection,
663
  residual_bank: ResidualBank = None,
664
  calibration_data: list = None,
 
665
  baseline_perplexity: float = None,
666
  merged_sources: list = None,
667
  ) -> dict:
@@ -717,14 +719,26 @@ def run_single_merge(
717
  print(f"\n[merge] Step 3/10: Loading calibration data..."); sys.stdout.flush()
718
  step_t = time.time()
719
  if calibration_data is None:
720
- calibration_data = load_calibration_data(cfg, target_tokenizer)
721
  print(f"[merge] Step 3/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
722
 
723
  # --- Step 4: Extract activations ---
724
  print(f"\n[merge] Step 4/10: Extracting activations (both models)..."); sys.stdout.flush()
725
  step_t = time.time()
726
- print(f"[merge] Extracting source activations...")
727
- source_activations = extract_activations(source_model, calibration_data)
 
 
 
 
 
 
 
 
 
 
 
 
728
 
729
  print(f"[merge] Extracting target activations...")
730
  pre_merge_target_activations = extract_activations(target_model, calibration_data)
@@ -1101,7 +1115,7 @@ def run_pipeline(
1101
  print(f"[pipeline] Baseline perplexity: {baseline_ppl:.2f}")
1102
 
1103
  # --- Load calibration data once ---
1104
- calibration_data = load_calibration_data(cfg, target_tokenizer)
1105
 
1106
  # --- Initialize merge protection + residual bank ---
1107
  protection = MergeProtection(cfg)
@@ -1138,6 +1152,7 @@ def run_pipeline(
1138
  protection,
1139
  residual_bank=residual_bank,
1140
  calibration_data=calibration_data,
 
1141
  baseline_perplexity=baseline_ppl,
1142
  merged_sources=merged_sources,
1143
  )
 
39
  from .transport import (
40
  setup_tm_repo,
41
  load_calibration_data,
42
+ retokenize_calibration,
43
  extract_activations,
44
  compute_transport_plans,
45
  fuse_weights,
 
663
  protection: MergeProtection,
664
  residual_bank: ResidualBank = None,
665
  calibration_data: list = None,
666
+ calibration_raw_texts: list = None,
667
  baseline_perplexity: float = None,
668
  merged_sources: list = None,
669
  ) -> dict:
 
719
  print(f"\n[merge] Step 3/10: Loading calibration data..."); sys.stdout.flush()
720
  step_t = time.time()
721
  if calibration_data is None:
722
+ calibration_data, calibration_raw_texts = load_calibration_data(cfg, target_tokenizer)
723
  print(f"[merge] Step 3/10 done in {time.time()-step_t:.0f}s"); sys.stdout.flush()
724
 
725
  # --- Step 4: Extract activations ---
726
  print(f"\n[merge] Step 4/10: Extracting activations (both models)..."); sys.stdout.flush()
727
  step_t = time.time()
728
+ # Check if source model has a different vocabulary size than target.
729
+ source_vocab_size = source_model.config.vocab_size if hasattr(source_model.config, 'vocab_size') else None
730
+ target_vocab_size = target_model.config.vocab_size if hasattr(target_model.config, 'vocab_size') else None
731
+ print(f"[merge] Vocab sizes -- target: {target_vocab_size}, source: {source_vocab_size}")
732
+
733
+ if source_vocab_size and target_vocab_size and source_vocab_size != target_vocab_size:
734
+ print(f"[merge] VOCAB MISMATCH detected! Re-tokenizing calibration data for {source_config.name}...")
735
+ source_calibration = retokenize_calibration(calibration_raw_texts, source_tokenizer, cfg)
736
+ print(f"[merge] Extracting source activations (with source-tokenized data)...")
737
+ source_activations = extract_activations(source_model, source_calibration)
738
+ del source_calibration
739
+ else:
740
+ print(f"[merge] Extracting source activations...")
741
+ source_activations = extract_activations(source_model, calibration_data)
742
 
743
  print(f"[merge] Extracting target activations...")
744
  pre_merge_target_activations = extract_activations(target_model, calibration_data)
 
1115
  print(f"[pipeline] Baseline perplexity: {baseline_ppl:.2f}")
1116
 
1117
  # --- Load calibration data once ---
1118
+ calibration_data, calibration_raw_texts = load_calibration_data(cfg, target_tokenizer)
1119
 
1120
  # --- Initialize merge protection + residual bank ---
1121
  protection = MergeProtection(cfg)
 
1152
  protection,
1153
  residual_bank=residual_bank,
1154
  calibration_data=calibration_data,
1155
+ calibration_raw_texts=calibration_raw_texts,
1156
  baseline_perplexity=baseline_ppl,
1157
  merged_sources=merged_sources,
1158
  )
td_fuse/transport.py CHANGED
@@ -105,7 +105,7 @@ def setup_tm_repo(cfg: MergeConfig):
105
  print(f"[transport] Added T&M core to path: {core_path}")
106
 
107
 
108
- def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
109
  """
110
  Load calibration data for activation extraction.
111
 
@@ -118,6 +118,7 @@ def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
118
  print(f"[transport] Loading calibration data ({cfg.calibration_samples} samples)...")
119
 
120
  samples = []
 
121
 
122
  # --- Pile: general text (600 samples) ---
123
  try:
@@ -140,6 +141,7 @@ def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
140
  return_tensors="pt",
141
  )
142
  samples.append(tokens)
 
143
  count += 1
144
  if count % 100 == 0:
145
  print(f" Pile: {count}/600 samples loaded...")
@@ -171,6 +173,7 @@ def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
171
  return_tensors="pt",
172
  )
173
  samples.append(tokens)
 
174
  count += 1
175
  if count % 100 == 0:
176
  print(f" neuralmagic: {count}/{remaining} samples loaded...")
@@ -182,6 +185,32 @@ def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> list:
182
  tracker.done()
183
  print(f"[transport] Total calibration samples: {len(samples)}")
184
  sys.stdout.flush()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  return samples
186
 
187
 
 
105
  print(f"[transport] Added T&M core to path: {core_path}")
106
 
107
 
108
+ def load_calibration_data(cfg: MergeConfig, tokenizer: AutoTokenizer) -> tuple:
109
  """
110
  Load calibration data for activation extraction.
111
 
 
118
  print(f"[transport] Loading calibration data ({cfg.calibration_samples} samples)...")
119
 
120
  samples = []
121
+ raw_texts = [] # Store raw text for cross-vocab re-tokenization
122
 
123
  # --- Pile: general text (600 samples) ---
124
  try:
 
141
  return_tensors="pt",
142
  )
143
  samples.append(tokens)
144
+ raw_texts.append(text)
145
  count += 1
146
  if count % 100 == 0:
147
  print(f" Pile: {count}/600 samples loaded...")
 
173
  return_tensors="pt",
174
  )
175
  samples.append(tokens)
176
+ raw_texts.append(str(text))
177
  count += 1
178
  if count % 100 == 0:
179
  print(f" neuralmagic: {count}/{remaining} samples loaded...")
 
185
  tracker.done()
186
  print(f"[transport] Total calibration samples: {len(samples)}")
187
  sys.stdout.flush()
188
+ return samples, raw_texts
189
+
190
+
191
+ def retokenize_calibration(raw_texts: list, tokenizer: AutoTokenizer, cfg: MergeConfig) -> list:
192
+ """
193
+ Re-tokenize calibration texts with a different tokenizer.
194
+
195
+ Used when the source model has a different vocabulary than the target.
196
+ For example, Llama (128K vocab) vs Qwen (152K vocab).
197
+ """
198
+ print(f"[transport] Re-tokenizing {len(raw_texts)} samples for source model vocabulary...")
199
+ sys.stdout.flush()
200
+ samples = []
201
+ for i, text in enumerate(raw_texts):
202
+ tokens = tokenizer(
203
+ text,
204
+ truncation=True,
205
+ max_length=cfg.calibration_seq_len,
206
+ return_tensors="pt",
207
+ )
208
+ samples.append(tokens)
209
+ if (i + 1) % 500 == 0:
210
+ print(f" Re-tokenized {i + 1}/{len(raw_texts)} samples...")
211
+ sys.stdout.flush()
212
+ print(f"[transport] Re-tokenized {len(samples)} samples for source model")
213
+ sys.stdout.flush()
214
  return samples
215
 
216