# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: LGPL-3.0-only
import enum
import gc
import logging
import os
import re
import shutil
from typing import Dict, List, Optional, Tuple
import click
import torch
import torch.distributions.constraints
import tqdm
import transformers
from pydantic import BaseModel
from mergekit.architecture import (
ConfiguredModelArchitecture,
WeightInfo,
arch_info_for_config,
)
from mergekit.common import ModelReference, set_config_value
from mergekit.io.tasks import (
LoaderCache,
)
from mergekit.io.tensor_writer import TensorWriter
from mergekit.options import MergeOptions, PrettyPrintHelp, add_merge_options
from mergekit.tokenizer.normalization import (
NormalizedToken,
normalized_vocabulary,
token_prefixes,
)
from mergekit.tokensurgeon import (
SubwordMethod,
WeightingScheme,
batch_mp_rope,
batch_omp,
common_interp_approximate,
compute_token_basis,
landmark_pca_approximate,
subword_approximate,
well_trained_tokens,
)
from mergekit.tokensurgeon.common_interpolation import DistanceMetric
LOG = logging.getLogger(__name__)
class TokenAssignmentStats(BaseModel):
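    """Running tally of how donor-vocabulary tokens were assigned embeddings."""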
exact_match: int = 0
byte_match: int = 0
prefix_match: int = 0
to_approximate: int = 0
def pretty_print(self):
chunks = ["Token Breakdown:"]
if self.exact_match:
chunks.append(f" Exact matches: {self.exact_match}")
if self.byte_match:
chunks.append(f" Byte matches: {self.byte_match}")
if self.prefix_match:
chunks.append(f" Prefix matches: {self.prefix_match}")
if self.to_approximate:
chunks.append(f" Tokens to approximate: {self.to_approximate}")
chunks.append(
f" Total: {self.exact_match + self.byte_match + self.prefix_match + self.to_approximate}"
)
return "\n".join(chunks)
class ApproximationMethod(enum.Enum):
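    """Methods for approximating embeddings of tokens absent from the original vocabulary."""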
COMMON_INTERPOLATION = "common_interpolation"
SUBWORD = "subword"
MEAN = "mean"
ZERO = "zero"
RANDN = "randn"
JOHN_HEWITT = "john_hewitt"
ORTHOGONAL_MATCHING_PURSUIT = "omp"
LANDMARK_PCA = "landmark_pca"
SPARSE_TOKEN_BASIS = "stb"
MATCHING_PURSUIT_ROPE = "mp_rope"
class TokenSurgeonOptions(BaseModel):
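    """Options controlling how donor-vocabulary embeddings are constructed."""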
model: ModelReference
donor: ModelReference
out_path: str
method: ApproximationMethod = ApproximationMethod.COMMON_INTERPOLATION
weight_scheme: WeightingScheme = WeightingScheme.DISTANCE_PROPORTIONAL
k: int = 64
cosine_similarity: bool = False
subword_method: SubwordMethod = SubwordMethod.MEAN
batch_size: Optional[int] = None
new_vocab_noise: Optional[float] = None
new_vocab_scale: Optional[float] = None
def get_arch_info(
model: ModelReference, options: MergeOptions
) -> ConfiguredModelArchitecture:
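    """Load `model`'s config and resolve its architecture definition."""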
cfg = model.config(trust_remote_code=options.trust_remote_code)
arch_info = arch_info_for_config(cfg)
return ConfiguredModelArchitecture(info=arch_info, config=cfg)
def get_embedding_info(
arch_info: ConfiguredModelArchitecture,
) -> Tuple[WeightInfo, WeightInfo]:
"""Get WeightInfo for the input and output embeddings of a model."""
if len(arch_info.info.modules) != 1:
raise RuntimeError("Model has multiple modules - not supported by tokensurgeon")
name = next(iter(arch_info.info.modules.keys()))
module_def = arch_info.get_module(name)
embed, lm_head = None, None
for weight_info in module_def.pre_weights():
if weight_info.is_embed:
if embed is not None:
raise RuntimeError("Multiple input embeddings found")
embed = weight_info
for weight_info in module_def.post_weights():
if weight_info.is_embed:
if lm_head is not None:
raise RuntimeError("Multiple output embeddings found")
lm_head = weight_info
return embed, lm_head
def maybe_aliases(weight_info: WeightInfo, tied: bool) -> Tuple[str, ...]:
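    """Gather a weight's alias names, including tied-weight names when `tied` is set."""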
return tuple(
list(weight_info.aliases or [])
+ list((weight_info.tied_names or []) if tied else [])
)
def get_stuff(
model: ModelReference,
options: MergeOptions,
arch_info: Optional[ConfiguredModelArchitecture] = None,
get_tied: bool = False,
device: str = "cpu",
) -> Tuple[Dict[NormalizedToken, int], Optional[torch.Tensor], Optional[torch.Tensor]]:
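    """Load the normalized vocabulary, input embedding, and LM head tensors for `model`."""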
if arch_info is None:
arch_info = get_arch_info(model, options)
tokenizer = transformers.AutoTokenizer.from_pretrained(
model.model.path,
revision=model.model.revision,
trust_remote_code=options.trust_remote_code,
)
vocab = normalized_vocabulary(tokenizer)
embed_wi, lm_head_wi = get_embedding_info(arch_info)
loader = LoaderCache().get(model)
embed = loader.get_tensor(
embed_wi.name,
device=device,
aliases=maybe_aliases(embed_wi, get_tied),
raise_on_missing=not embed_wi.optional,
)
lm_head = loader.get_tensor(
lm_head_wi.name,
device=device,
aliases=maybe_aliases(lm_head_wi, get_tied),
raise_on_missing=not lm_head_wi.optional,
)
return vocab, embed, lm_head
def match_byte_token(
token: NormalizedToken, original_vocab: Dict[NormalizedToken, int]
) -> Optional[int]:
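    """Match a single character against its <0xNN> byte-token form (and vice versa)."""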
if not isinstance(token, str):
return None
if len(token) == 1 and ord(token) < 256:
# check for matching byte tokens
byte_tok = f"<0x{ord(token):02X}>"
if byte_tok in original_vocab:
return original_vocab[byte_tok]
elif token.startswith("<0x") and token.endswith(">") and len(token) == 6:
# check for character tokens matching byte tokens
try:
byte = int(token[3:-1], 16)
except ValueError:
pass
else:
if chr(byte) in original_vocab:
return original_vocab[chr(byte)]
return None
def match_prefix(
token: NormalizedToken, original_vocab: Dict[NormalizedToken, int]
) -> Optional[int]:
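    """Return the original-vocabulary index of the first matching prefix of `token`, if any."""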
for prefix in token_prefixes(token):
if prefix in original_vocab:
return original_vocab[prefix]
return None
def get_out_arch_info(
model: ModelReference,
donor: ModelReference,
new_vocab_size: int,
common_options: MergeOptions,
) -> ConfiguredModelArchitecture:
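    """Derive the output architecture: the original model's config with the new
    vocabulary size and the donor's special-token settings applied."""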
cfg_donor = donor.config(trust_remote_code=common_options.trust_remote_code)
cfg_out = model.config(trust_remote_code=common_options.trust_remote_code)
arch_info_out = arch_info_for_config(cfg_out)
set_config_value(
cfg_out, arch_info_out.vocab_size_config_key or "vocab_size", new_vocab_size
)
for key in [
"pad_token_id",
"eos_token_id",
"bos_token_id",
"unk_token_id",
"mask_token_id",
"padding_side",
]:
if hasattr(cfg_donor, key):
set_config_value(cfg_out, key, getattr(cfg_donor, key))
return ConfiguredModelArchitecture(info=arch_info_out, config=cfg_out)
def john_hewitt_init(orig_embed: torch.Tensor, num_new_tokens: int) -> torch.Tensor:
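    """Initialize new token embeddings by sampling from a multivariate normal fit to
    the mean and covariance of the existing embeddings (John Hewitt's initialization)."""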
orig_embed_f32 = orig_embed.to(torch.float32)
mean = orig_embed_f32.mean(dim=0)
centered = orig_embed_f32 - mean
covariance = centered.T @ centered / orig_embed_f32.shape[0]
is_pd = torch.distributions.constraints.positive_definite.check(covariance).all()
if not is_pd:
LOG.warning(
"Covariance matrix is not positive definite - falling back to small randn"
)
        return (
            torch.randn(
                num_new_tokens,
                orig_embed.shape[1],
                device=orig_embed.device,
                dtype=orig_embed.dtype,
            )
            * 0.02
        )
dist = torch.distributions.multivariate_normal.MultivariateNormal(
loc=mean,
covariance_matrix=covariance,
)
new_embeds = dist.sample((num_new_tokens,))
return new_embeds.to(orig_embed.dtype)
def compute_new_embeddings(
orig_embed: torch.Tensor,
donor_embed: torch.Tensor,
orig_vocab: Dict[NormalizedToken, int],
donor_vocab: Dict[NormalizedToken, int],
target_tokens: List[NormalizedToken],
is_lm_head: bool,
token_basis: Optional[Tuple[torch.Tensor, torch.Tensor]],
orig_tokenizer: transformers.PreTrainedTokenizerBase,
options: TokenSurgeonOptions,
shared_data: Optional[Dict] = None,
compute_device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
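    """Approximate original-model embeddings for donor-only `target_tokens`,
    returning a tensor of shape (len(target_tokens), hidden_size)."""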
assert all(t in donor_vocab for t in target_tokens)
if options.method == ApproximationMethod.MEAN:
mean = orig_embed.mean(dim=0).to(compute_device)
return mean.unsqueeze(0).expand(len(target_tokens), -1)
elif options.method == ApproximationMethod.ZERO:
return torch.zeros(
len(target_tokens),
orig_embed.shape[1],
device=compute_device,
dtype=orig_embed.dtype,
)
elif options.method == ApproximationMethod.RANDN:
return torch.randn(
len(target_tokens),
orig_embed.shape[1],
device=compute_device,
dtype=orig_embed.dtype,
)
elif options.method == ApproximationMethod.JOHN_HEWITT:
return john_hewitt_init(orig_embed.to(compute_device), len(target_tokens))
elif options.method in (
ApproximationMethod.COMMON_INTERPOLATION,
ApproximationMethod.ORTHOGONAL_MATCHING_PURSUIT,
ApproximationMethod.LANDMARK_PCA,
ApproximationMethod.MATCHING_PURSUIT_ROPE,
):
if shared_data is not None:
donor_shared_embeds = shared_data["donor_shared_embeds"].to(compute_device)
orig_shared_embeds = shared_data["orig_shared_embeds"].to(compute_device)
else:
shared_vocab = list(
sorted(
set(orig_vocab.keys()) & set(donor_vocab.keys()),
key=lambda x: donor_vocab[x],
)
)
donor_shared_embeds = donor_embed[
torch.tensor([donor_vocab[t] for t in shared_vocab])
].to(compute_device)
orig_shared_embeds = orig_embed[
torch.tensor([orig_vocab[t] for t in shared_vocab])
].to(compute_device)
res = None
in_donor = None
        targets = donor_embed[
            torch.tensor([donor_vocab[t] for t in target_tokens])
        ].to(compute_device)
if options.method == ApproximationMethod.LANDMARK_PCA:
return landmark_pca_approximate(
targets,
donor_shared_embeds,
orig_shared_embeds,
)
elif options.method == ApproximationMethod.COMMON_INTERPOLATION:
indices, coeffs = common_interp_approximate(
targets,
donor_shared_embeds,
k=options.k,
metric=(
DistanceMetric.COSINE
if options.cosine_similarity
else DistanceMetric.EUCLIDEAN
),
weight_scheme=options.weight_scheme,
)
elif options.method == ApproximationMethod.MATCHING_PURSUIT_ROPE:
model_config = options.model.config(trust_remote_code=False)
donor_config = options.donor.config(trust_remote_code=False)
indices, coeffs, res, in_donor = batch_mp_rope(
targets,
donor_shared_embeds,
orig_shared_embeds,
k=options.k,
num_heads_a=donor_config.num_attention_heads,
num_heads_b=model_config.num_attention_heads,
a_rope_base=donor_config.rope_theta,
b_rope_base=model_config.rope_theta,
)
else:
indices, coeffs = batch_omp(targets, donor_shared_embeds, options.k)
if res is None:
res = (
torch.bmm(
coeffs.unsqueeze(1), orig_shared_embeds[indices].to(torch.float)
)
.squeeze(1)
.to(orig_embed.dtype)
)
return res
elif options.method == ApproximationMethod.SUBWORD:
return subword_approximate(
orig_embed.to(compute_device),
target_tokens,
is_lm_head,
orig_tokenizer,
options.subword_method,
)
elif options.method == ApproximationMethod.SPARSE_TOKEN_BASIS:
assert token_basis is not None, "Token basis must be provided for STB"
donor_basis, orig_basis = token_basis
donor_basis = donor_basis.to(compute_device).to(torch.float32)
orig_basis = orig_basis.to(compute_device).to(torch.float32)
        target_donor_embeds = donor_embed[
            torch.tensor([donor_vocab[t] for t in target_tokens])
        ].to(compute_device, torch.float32) - donor_embed.mean(dim=0).to(
            compute_device
        )
coeffs = torch.linalg.lstsq(
donor_basis.T,
target_donor_embeds.T,
).solution.T
if LOG.isEnabledFor(logging.DEBUG):
donor_rt = coeffs @ donor_basis
err = (donor_rt - target_donor_embeds).norm(dim=1)
err_rel = err / target_donor_embeds.norm(dim=1).clamp_min(1e-6)
sim = torch.nn.functional.cosine_similarity(
donor_rt, target_donor_embeds, dim=1
)
LOG.debug(f"Reconstruction error: {err.mean().item():.4f}")
LOG.debug(f"Relative reconstruction error: {err_rel.mean().item():.4f}")
LOG.debug(f"Cosine similarity: {sim.mean().item():.4f}")
return coeffs @ orig_basis + orig_embed.mean(dim=0).to(compute_device)
else:
raise ValueError(f"Unknown approximation method: {options.method}")
def build_embedding_matrix(
weight_info: WeightInfo,
orig_embed: torch.Tensor,
donor_embed: torch.Tensor,
orig_vocab: Dict[NormalizedToken, int],
donor_vocab: Dict[NormalizedToken, int],
junk_tokens: List[int],
allow_prefix: bool,
allow_byte: bool,
is_lm_head: bool,
options: TokenSurgeonOptions,
compute_device: torch.device,
) -> torch.Tensor:
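    """Assemble the new embedding matrix over the donor vocabulary.

    Exact, byte, and prefix matches are copied from `orig_embed`; remaining
    tokens are approximated in adaptively sized batches, and junk tokens are
    zero-initialized.
    """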
LOG.info(f"Building new tensor for {weight_info.name}")
stats = TokenAssignmentStats()
out_vocab_size = max(len(donor_vocab), max(donor_vocab.values()) + 1)
if options.method == ApproximationMethod.SPARSE_TOKEN_BASIS:
token_basis = compute_token_basis(
orig_embed,
donor_embed,
orig_vocab,
donor_vocab,
junk_tokens,
options,
)
else:
token_basis = None
res = torch.zeros(
out_vocab_size,
orig_embed.shape[1],
device=orig_embed.device,
dtype=orig_embed.dtype,
)
new_tokens = []
for token, donor_idx in donor_vocab.items():
if token in orig_vocab:
orig_idx = orig_vocab[token]
res[donor_idx] = orig_embed[orig_idx]
stats.exact_match += 1
elif (
allow_byte and (orig_idx := match_byte_token(token, orig_vocab)) is not None
):
res[donor_idx] = orig_embed[orig_idx]
stats.byte_match += 1
elif allow_prefix and (orig_idx := match_prefix(token, orig_vocab)) is not None:
res[donor_idx] = orig_embed[orig_idx]
stats.prefix_match += 1
else:
new_tokens.append(token)
stats.to_approximate += 1
donor_tokenizer = transformers.AutoTokenizer.from_pretrained(
options.donor.model.path,
revision=options.donor.model.revision,
trust_remote_code=True,
)
orig_tokenizer = transformers.AutoTokenizer.from_pretrained(
options.model.model.path,
revision=options.model.model.revision,
trust_remote_code=True,
)
LOG.info(stats.pretty_print())
if new_tokens:
LOG.info(f"Approximating {len(new_tokens)} tokens")
# Precompute shared embeds to avoid doing it in every batch
shared_vocab = list(
sorted(
set(orig_vocab.keys()) & set(donor_vocab.keys()),
key=lambda x: donor_vocab[x],
)
)
donor_shared_embeds = donor_embed[
torch.tensor([donor_vocab[t] for t in shared_vocab])
]
orig_shared_embeds = orig_embed[
torch.tensor([orig_vocab[t] for t in shared_vocab])
]
shared_data = {
"donor_shared_embeds": donor_shared_embeds,
"orig_shared_embeds": orig_shared_embeds,
}
        batch_size = options.batch_size
        if batch_size is None or batch_size <= 0:
            # A non-positive value disables batching (per the CLI help text)
            batch_size = len(new_tokens)
# Adaptive batching logic
i = 0
total_tokens = len(new_tokens)
oom_count = 0
pbar = tqdm.tqdm(total=total_tokens, desc="Approximating tokens")
while i < total_tokens:
end = min(i + batch_size, total_tokens)
current_batch = new_tokens[i:end]
try:
new_embeds = compute_new_embeddings(
orig_embed,
donor_embed,
orig_vocab,
donor_vocab,
target_tokens=current_batch,
is_lm_head=is_lm_head,
token_basis=token_basis,
orig_tokenizer=orig_tokenizer,
options=options,
shared_data=shared_data,
compute_device=compute_device,
)
                if options.new_vocab_noise:
                    # new_embeds may be a broadcast view (e.g. the MEAN method);
                    # use out-of-place ops so we never write through shared memory
                    new_embeds = (
                        new_embeds
                        + torch.randn_like(new_embeds) * options.new_vocab_noise
                    )
                if options.new_vocab_scale:
                    new_embeds = new_embeds * options.new_vocab_scale
for ne_idx, token in enumerate(current_batch):
res[donor_vocab[token]] = new_embeds[ne_idx].to(res.device)
# Success, move to next batch
pbar.update(end - i)
i = end
oom_count = 0
# Optional cleanup
if compute_device.type == "cuda":
torch.cuda.empty_cache()
            except torch.cuda.OutOfMemoryError:  # top-level torch.OutOfMemoryError alias is newer
oom_count += 1
if compute_device.type == "cuda":
torch.cuda.empty_cache()
                gc.collect()
old_batch = batch_size
batch_size = max(1, int(batch_size * 0.75))
if batch_size == old_batch and batch_size == 1:
LOG.error("OOM even with batch size 1. Cannot continue.")
raise
LOG.warning(f"OOM error. Reducing batch size from {old_batch} to {batch_size} (attempt {oom_count})")
if oom_count > 10:
LOG.error("Too many OOM errors, giving up.")
raise
pbar.close()
if junk_tokens:
LOG.info(f"Zero-initializing {len(junk_tokens)} junk tokens")
for token_id in junk_tokens:
res[token_id] = torch.zeros(
orig_embed.shape[1],
device=orig_embed.device,
dtype=orig_embed.dtype,
)
return res
class AllowMatch(enum.Enum):
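    """Controls which embedding matrices may use heuristic (prefix/byte) matches."""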
LM_HEAD_ONLY = "lm_head"
EMBED_ONLY = "embed"
YES = "yes"
NO = "no"
@click.command("mergekit-tokensurgeon", cls=PrettyPrintHelp)
@click.argument("model", type=str)
@click.argument("donor", type=str)
@click.argument("out_path", type=str)
@click.option(
"--k",
"-k",
type=int,
default=64,
help="Number of nearest neighbours to use for embedding interpolation",
show_default=True,
)
@click.option(
"--cosine-similarity/--no-cosine-similarity",
"-c/-nc",
is_flag=True,
default=False,
help="Use cosine similarity for nearest neighbour search",
show_default=True,
)
@click.option(
"--approximation-method",
"-a",
type=click.Choice([m.value for m in ApproximationMethod]),
default=ApproximationMethod.ORTHOGONAL_MATCHING_PURSUIT.value,
help="Method for approximating missing tokens",
show_default=True,
)
@click.option(
"--weight-scheme",
"-w",
type=click.Choice([w.value for w in WeightingScheme]),
default=WeightingScheme.DISTANCE_PROPORTIONAL.value,
help="Weighting scheme for common-vocabulary interpolation",
show_default=True,
)
@click.option(
"--subword-method",
"-s",
type=click.Choice([m.value for m in SubwordMethod]),
default=SubwordMethod.MEAN.value,
help="Method for approximating embeddings with subword tokens",
show_default=True,
)
@click.option(
"--batch-size",
type=int,
default=512,
help="Number of tokens to process in each batch. -1 for no batching.",
show_default=True,
)
@click.option(
"--prefix-match",
"-pm",
type=click.Choice([m.value for m in AllowMatch]),
default=AllowMatch.NO.value,
help="Allow prefix match for tokens",
show_default=True,
)
@click.option(
"--byte-match",
"-bm",
type=click.Choice([m.value for m in AllowMatch]),
default=AllowMatch.NO.value,
help="Allow byte match for tokens",
show_default=True,
)
@click.option(
"--magikarp/--no-magikarp",
is_flag=True,
default=False,
help="Filter out poorly trained tokens",
show_default=True,
)
@click.option(
"--new-vocab-noise",
"-nvn",
type=float,
default=None,
help="Add gaussian noise to new vocab embeddings",
show_default=True,
)
@click.option(
"--new-vocab-scale",
"-nvs",
type=float,
default=None,
help="Scale computed new vocab embeddings by this factor",
show_default=True,
)
@add_merge_options
def main(
model: str,
donor: str,
out_path: str,
k: int,
cosine_similarity: bool,
approximation_method: str,
weight_scheme: str,
subword_method: str,
batch_size: Optional[int],
prefix_match: str,
byte_match: str,
magikarp: bool,
new_vocab_noise: Optional[float],
new_vocab_scale: Optional[float],
merge_options: MergeOptions,
):
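    """Transplant DONOR's tokenizer onto MODEL, writing the result to OUT_PATH.

    Tokens shared between the two vocabularies keep their original embeddings;
    tokens new to MODEL are approximated with the selected method.
    """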
merge_options.apply_global_options()
logging.warning("This script is experimental and may produce unexpected results.")
options = TokenSurgeonOptions(
model=ModelReference.model_validate(model),
donor=ModelReference.model_validate(donor),
out_path=out_path,
k=k,
cosine_similarity=cosine_similarity,
method=ApproximationMethod(approximation_method),
weight_scheme=WeightingScheme(weight_scheme),
subword_method=SubwordMethod(subword_method),
batch_size=batch_size,
new_vocab_noise=new_vocab_noise,
new_vocab_scale=new_vocab_scale,
)
prefix_match = AllowMatch(prefix_match)
byte_match = AllowMatch(byte_match)
cache = LoaderCache()
cache.setup(options=merge_options)
    compute_device = torch.device(
        merge_options.device
        if merge_options.device
        else ("cuda" if torch.cuda.is_available() else "cpu")
    )
storage_device = "cpu"
arch_info = get_arch_info(options.model, merge_options)
embed_wi, lm_head_wi = get_embedding_info(arch_info)
orig_vocab, orig_embed, orig_lm_head = get_stuff(
options.model, merge_options, arch_info=arch_info, device=storage_device
)
donor_vocab, donor_embed, donor_lm_head = get_stuff(
options.donor, merge_options, arch_info=None, get_tied=True, device=storage_device
)
if magikarp:
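        # Filter to well-trained tokens only (cf. the "Fishing for Magikarp" work on
        # detecting under-trained tokens); leftover junk tokens get zeroed later.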
LOG.debug("Finding well-trained tokens in original model")
well_trained_orig_tokens = set(
well_trained_tokens(
orig_vocab,
orig_embed,
orig_lm_head,
)
)
LOG.debug("Finding well-trained tokens in donor model")
well_trained_donor_tokens = set(
well_trained_tokens(
donor_vocab,
donor_embed,
donor_lm_head,
)
)
common_well_trained_tokens = (
well_trained_orig_tokens & well_trained_donor_tokens
)
LOG.info(f"Found {len(common_well_trained_tokens)} common well-trained tokens")
orig_vocab = {
tok: idx
for tok, idx in orig_vocab.items()
if tok in common_well_trained_tokens
}
junk_tokens = [
idx
for tok, idx in donor_vocab.items()
if (tok not in well_trained_donor_tokens)
and (tok not in well_trained_orig_tokens)
]
else:
junk_tokens = []
if orig_embed is not None:
if donor_embed is None:
raise RuntimeError(
f"Missing tensor {embed_wi.name} in model {options.donor}"
)
new_embed = build_embedding_matrix(
embed_wi,
orig_embed,
donor_embed,
orig_vocab=orig_vocab,
donor_vocab=donor_vocab,
junk_tokens=junk_tokens,
            allow_prefix=prefix_match in (AllowMatch.YES, AllowMatch.EMBED_ONLY),
            allow_byte=byte_match in (AllowMatch.YES, AllowMatch.EMBED_ONLY),
is_lm_head=False,
options=options,
compute_device=compute_device,
)
else:
if not embed_wi.optional:
raise RuntimeError(
f"Missing tensor {embed_wi.name} in model {options.model}"
)
new_embed = None
if orig_lm_head is not None:
if donor_lm_head is None:
raise RuntimeError(
f"Missing tensor {lm_head_wi.name} in model {options.donor}"
)
new_lm_head = build_embedding_matrix(
lm_head_wi,
orig_lm_head,
donor_lm_head,
orig_vocab=orig_vocab,
donor_vocab=donor_vocab,
junk_tokens=junk_tokens,
            allow_prefix=prefix_match in (AllowMatch.YES, AllowMatch.LM_HEAD_ONLY),
            allow_byte=byte_match in (AllowMatch.YES, AllowMatch.LM_HEAD_ONLY),
is_lm_head=True,
options=options,
compute_device=compute_device,
)
else:
if not lm_head_wi.optional:
raise RuntimeError(
f"Missing tensor {lm_head_wi.name} in model {options.model}"
)
new_lm_head = None
new_vocab_size = None
if new_embed is not None:
new_vocab_size = new_embed.shape[0]
elif new_lm_head is not None:
new_vocab_size = new_lm_head.shape[0]
LOG.info(f"Saving new model to {out_path}")
out_arch_info = get_out_arch_info(
options.model, options.donor, new_vocab_size, merge_options
)
writer = TensorWriter(
out_path,
max_shard_size=merge_options.out_shard_size,
safe_serialization=merge_options.safe_serialization,
use_async=merge_options.async_write,
max_write_threads=merge_options.write_threads,
)
for weight_info in tqdm.tqdm(out_arch_info.all_weights(), desc="Saving weights"):
        if embed_wi is not None and weight_info.name == embed_wi.name:
tensor = new_embed
elif lm_head_wi is not None and weight_info.name == lm_head_wi.name:
tensor = new_lm_head
else:
tensor = cache.get(options.model).get_tensor(
weight_info.name, aliases=weight_info.aliases, raise_on_missing=False
)
if tensor is None:
if weight_info.optional:
continue
raise RuntimeError(
f"Missing tensor {weight_info.name} in model {options.model}"
)
writer.save_tensor(weight_info.name, tensor, clone=merge_options.clone_tensors)
# Force close lazy loader file handles so Windows allows deletion/renaming
cache.flush_all()
    gc.collect()
    # Delete original safetensors files to prevent FileExistsError during rename
temp_pattern = re.compile(r"^.*-\d+\.safetensors$")
for fname in os.listdir(out_path):
if fname.endswith(".safetensors") and not temp_pattern.match(fname):
try:
os.remove(os.path.join(out_path, fname))
except Exception as e:
LOG.warning(f"Could not remove old file {fname}: {e}")
elif fname == "model.safetensors.index.json":
try:
os.remove(os.path.join(out_path, fname))
except Exception:
pass
writer.finalize()
out_arch_info.config.save_pretrained(out_path)
tokenizer_out = transformers.AutoTokenizer.from_pretrained(
options.donor.model.path,
revision=options.donor.model.revision,
trust_remote_code=merge_options.trust_remote_code,
)
tokenizer_out.save_pretrained(out_path)
# Also copy generation_config.json if it exists in the donor
donor_gen_config = os.path.join(options.donor.model.path, "generation_config.json")
    if os.path.exists(donor_gen_config):
        shutil.copy(donor_gen_config, os.path.join(out_path, "generation_config.json"))
LOG.info("Done!")
if __name__ == "__main__":
main()