# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: LGPL-3.0-only
import enum
import gc
import logging
import os
import re
import shutil
from typing import Dict, List, Optional, Tuple

import click
import torch
import torch.distributions.constraints
import tqdm
import transformers
from pydantic import BaseModel

from mergekit.architecture import (
    ConfiguredModelArchitecture,
    WeightInfo,
    arch_info_for_config,
)
from mergekit.common import ModelReference, set_config_value
from mergekit.io.tasks import (
    LoaderCache,
)
from mergekit.io.tensor_writer import TensorWriter
from mergekit.options import MergeOptions, PrettyPrintHelp, add_merge_options
from mergekit.tokenizer.normalization import (
    NormalizedToken,
    normalized_vocabulary,
    token_prefixes,
)
from mergekit.tokensurgeon import (
    SubwordMethod,
    WeightingScheme,
    batch_mp_rope,
    batch_omp,
    common_interp_approximate,
    compute_token_basis,
    landmark_pca_approximate,
    subword_approximate,
    well_trained_tokens,
)
from mergekit.tokensurgeon.common_interpolation import DistanceMetric

LOG = logging.getLogger(__name__)


class TokenAssignmentStats(BaseModel):
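    """Counts of how each donor-vocabulary token was assigned an embedding."""
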
    exact_match: int = 0
    byte_match: int = 0
    prefix_match: int = 0
    to_approximate: int = 0

    def pretty_print(self):
        chunks = ["Token Breakdown:"]
        if self.exact_match:
            chunks.append(f"  Exact matches: {self.exact_match}")
        if self.byte_match:
            chunks.append(f"  Byte matches: {self.byte_match}")
        if self.prefix_match:
            chunks.append(f"  Prefix matches: {self.prefix_match}")
        if self.to_approximate:
            chunks.append(f"  Tokens to approximate: {self.to_approximate}")
        chunks.append(
            f"  Total: {self.exact_match + self.byte_match + self.prefix_match + self.to_approximate}"
        )
        return "\n".join(chunks)


class ApproximationMethod(enum.Enum):
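    """Strategy for approximating embeddings of tokens new to the base model."""
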
    COMMON_INTERPOLATION = "common_interpolation"
    SUBWORD = "subword"
    MEAN = "mean"
    ZERO = "zero"
    RANDN = "randn"
    JOHN_HEWITT = "john_hewitt"
    ORTHOGONAL_MATCHING_PURSUIT = "omp"
    LANDMARK_PCA = "landmark_pca"
    SPARSE_TOKEN_BASIS = "stb"
    MATCHING_PURSUIT_ROPE = "mp_rope"


class TokenSurgeonOptions(BaseModel):
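    """Configuration for a single tokenizer-transplant run."""
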
    model: ModelReference
    donor: ModelReference
    out_path: str
    method: ApproximationMethod = ApproximationMethod.COMMON_INTERPOLATION
    weight_scheme: WeightingScheme = WeightingScheme.DISTANCE_PROPORTIONAL
    k: int = 64
    cosine_similarity: bool = False
    subword_method: SubwordMethod = SubwordMethod.MEAN
    batch_size: Optional[int] = None
    new_vocab_noise: Optional[float] = None
    new_vocab_scale: Optional[float] = None


def get_arch_info(
    model: ModelReference, options: MergeOptions
) -> ConfiguredModelArchitecture:
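    """Load a model's config and resolve its architecture metadata."""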
    cfg = model.config(trust_remote_code=options.trust_remote_code)
    arch_info = arch_info_for_config(cfg)
    return ConfiguredModelArchitecture(info=arch_info, config=cfg)


def get_embedding_info(
    arch_info: ConfiguredModelArchitecture,
) -> Tuple[WeightInfo, WeightInfo]:
    """Get WeightInfo for the input and output embeddings of a model."""
    if len(arch_info.info.modules) != 1:
        raise RuntimeError("Model has multiple modules - not supported by tokensurgeon")
    name = next(iter(arch_info.info.modules.keys()))
    module_def = arch_info.get_module(name)
    embed, lm_head = None, None
    for weight_info in module_def.pre_weights():
        if weight_info.is_embed:
            if embed is not None:
                raise RuntimeError("Multiple input embeddings found")
            embed = weight_info
    for weight_info in module_def.post_weights():
        if weight_info.is_embed:
            if lm_head is not None:
                raise RuntimeError("Multiple output embeddings found")
            lm_head = weight_info
    return embed, lm_head


def maybe_aliases(weight_info: WeightInfo, tied: bool) -> Tuple[str, ...]:
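    """Gather alternate names for a weight, optionally including tied-weight names."""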
    return tuple(
        list(weight_info.aliases or [])
        + list((weight_info.tied_names or []) if tied else [])
    )


def get_stuff(
    model: ModelReference,
    options: MergeOptions,
    arch_info: Optional[ConfiguredModelArchitecture] = None,
    get_tied: bool = False,
    device: str = "cpu",
) -> Tuple[Dict[NormalizedToken, int], Optional[torch.Tensor], Optional[torch.Tensor]]:
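    """Load a model's normalized vocabulary, input embedding, and LM head tensors."""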
    if arch_info is None:
        arch_info = get_arch_info(model, options)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model.model.path,
        revision=model.model.revision,
        trust_remote_code=options.trust_remote_code,
    )
    vocab = normalized_vocabulary(tokenizer)
    embed_wi, lm_head_wi = get_embedding_info(arch_info)
    loader = LoaderCache().get(model)
    embed = loader.get_tensor(
        embed_wi.name,
        device=device,
        aliases=maybe_aliases(embed_wi, get_tied),
        raise_on_missing=not embed_wi.optional,
    )
    lm_head = loader.get_tensor(
        lm_head_wi.name,
        device=device,
        aliases=maybe_aliases(lm_head_wi, get_tied),
        raise_on_missing=not lm_head_wi.optional,
    )
    return vocab, embed, lm_head


def match_byte_token(
    token: NormalizedToken, original_vocab: Dict[NormalizedToken, int]
) -> Optional[int]:
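    """Match a single character against a `<0xNN>` byte token, or vice versa."""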
    if not isinstance(token, str):
        return None
    if len(token) == 1 and ord(token) < 256:
        # check for matching byte tokens
        byte_tok = f"<0x{ord(token):02X}>"
        if byte_tok in original_vocab:
            return original_vocab[byte_tok]
    elif token.startswith("<0x") and token.endswith(">") and len(token) == 6:
        # check for character tokens matching byte tokens
        try:
            byte = int(token[3:-1], 16)
        except ValueError:
            pass
        else:
            if chr(byte) in original_vocab:
                return original_vocab[chr(byte)]
    return None


def match_prefix(
    token: NormalizedToken, original_vocab: Dict[NormalizedToken, int]
) -> Optional[int]:
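    """Return the index of the first prefix of `token` found in the original vocabulary."""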
    for prefix in token_prefixes(token):
        if prefix in original_vocab:
            return original_vocab[prefix]
    return None


def get_out_arch_info(
    model: ModelReference,
    donor: ModelReference,
    new_vocab_size: int,
    common_options: MergeOptions,
) -> ConfiguredModelArchitecture:
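    """Build the output config: the base architecture with the donor's vocabulary size and special-token settings."""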
    cfg_donor = donor.config(trust_remote_code=common_options.trust_remote_code)
    cfg_out = model.config(trust_remote_code=common_options.trust_remote_code)
    arch_info_out = arch_info_for_config(cfg_out)
    set_config_value(
        cfg_out, arch_info_out.vocab_size_config_key or "vocab_size", new_vocab_size
    )
    for key in [
        "pad_token_id",
        "eos_token_id",
        "bos_token_id",
        "unk_token_id",
        "mask_token_id",
        "padding_side",
    ]:
        if hasattr(cfg_donor, key):
            set_config_value(cfg_out, key, getattr(cfg_donor, key))
    return ConfiguredModelArchitecture(info=arch_info_out, config=cfg_out)


def john_hewitt_init(orig_embed: torch.Tensor, num_new_tokens: int) -> torch.Tensor:
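    """Sample new embeddings from a multivariate normal fit to the existing
    embedding matrix (John Hewitt's initialization trick), falling back to
    small Gaussian noise if the empirical covariance is not positive definite.
    """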
    orig_embed_f32 = orig_embed.to(torch.float32)
    mean = orig_embed_f32.mean(dim=0)
    centered = orig_embed_f32 - mean
    covariance = centered.T @ centered / orig_embed_f32.shape[0]
    is_pd = torch.distributions.constraints.positive_definite.check(covariance).all()
    if not is_pd:
        LOG.warning(
            "Covariance matrix is not positive definite - falling back to small randn"
        )
        return (
            torch.randn(
                num_new_tokens,
                orig_embed.shape[1],
                device=orig_embed.device,
                dtype=orig_embed.dtype,
            )
            * 0.02
        )
    dist = torch.distributions.multivariate_normal.MultivariateNormal(
        loc=mean,
        covariance_matrix=covariance,
    )
    new_embeds = dist.sample((num_new_tokens,))
    return new_embeds.to(orig_embed.dtype)


def compute_new_embeddings(
    orig_embed: torch.Tensor,
    donor_embed: torch.Tensor,
    orig_vocab: Dict[NormalizedToken, int],
    donor_vocab: Dict[NormalizedToken, int],
    target_tokens: List[NormalizedToken],
    is_lm_head: bool,
    token_basis: Optional[Tuple[torch.Tensor, torch.Tensor]],
    orig_tokenizer: transformers.PreTrainedTokenizerBase,
    options: TokenSurgeonOptions,
    shared_data: Optional[Dict] = None,
    compute_device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
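    """Approximate embeddings for ``target_tokens`` using the configured method.

    For the interpolation-style methods (common interpolation, OMP, MP-RoPE),
    each target's donor embedding is expressed as a weighted combination of
    donor embeddings of tokens shared by both vocabularies; the same weights
    are then applied to the original model's embeddings of those shared tokens.
    """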
    assert all(t in donor_vocab for t in target_tokens)
    if options.method == ApproximationMethod.MEAN:
        mean = orig_embed.mean(dim=0).to(compute_device)
        return mean.unsqueeze(0).expand(len(target_tokens), -1)
    elif options.method == ApproximationMethod.ZERO:
        return torch.zeros(
            len(target_tokens),
            orig_embed.shape[1],
            device=compute_device,
            dtype=orig_embed.dtype,
        )
    elif options.method == ApproximationMethod.RANDN:
        return torch.randn(
            len(target_tokens),
            orig_embed.shape[1],
            device=compute_device,
            dtype=orig_embed.dtype,
        )
    elif options.method == ApproximationMethod.JOHN_HEWITT:
        return john_hewitt_init(orig_embed.to(compute_device), len(target_tokens))
    elif options.method in (
        ApproximationMethod.COMMON_INTERPOLATION,
        ApproximationMethod.ORTHOGONAL_MATCHING_PURSUIT,
        ApproximationMethod.LANDMARK_PCA,
        ApproximationMethod.MATCHING_PURSUIT_ROPE,
    ):
        if shared_data is not None:
            donor_shared_embeds = shared_data["donor_shared_embeds"].to(compute_device)
            orig_shared_embeds = shared_data["orig_shared_embeds"].to(compute_device)
        else:
            shared_vocab = list(
                sorted(
                    set(orig_vocab.keys()) & set(donor_vocab.keys()),
                    key=lambda x: donor_vocab[x],
                )
            )
            donor_shared_embeds = donor_embed[
                torch.tensor([donor_vocab[t] for t in shared_vocab])
            ].to(compute_device)
            orig_shared_embeds = orig_embed[
                torch.tensor([orig_vocab[t] for t in shared_vocab])
            ].to(compute_device)
        res = None
        in_donor = None
        targets = donor_embed[
            torch.tensor([donor_vocab[t] for t in target_tokens])
        ].to(compute_device)
        if options.method == ApproximationMethod.LANDMARK_PCA:
            return landmark_pca_approximate(
                targets,
                donor_shared_embeds,
                orig_shared_embeds,
            )
        elif options.method == ApproximationMethod.COMMON_INTERPOLATION:
            indices, coeffs = common_interp_approximate(
                targets,
                donor_shared_embeds,
                k=options.k,
                metric=(
                    DistanceMetric.COSINE
                    if options.cosine_similarity
                    else DistanceMetric.EUCLIDEAN
                ),
                weight_scheme=options.weight_scheme,
            )
        elif options.method == ApproximationMethod.MATCHING_PURSUIT_ROPE:
            model_config = options.model.config(trust_remote_code=False)
            donor_config = options.donor.config(trust_remote_code=False)
            indices, coeffs, res, in_donor = batch_mp_rope(
                targets,
                donor_shared_embeds,
                orig_shared_embeds,
                k=options.k,
                num_heads_a=donor_config.num_attention_heads,
                num_heads_b=model_config.num_attention_heads,
                a_rope_base=donor_config.rope_theta,
                b_rope_base=model_config.rope_theta,
            )
        else:
            indices, coeffs = batch_omp(targets, donor_shared_embeds, options.k)
        if res is None:
            # Re-apply the donor-space combination weights to the original
            # model's embeddings of the same shared tokens
            res = (
                torch.bmm(
                    coeffs.unsqueeze(1), orig_shared_embeds[indices].to(torch.float)
                )
                .squeeze(1)
                .to(orig_embed.dtype)
            )
        return res
    elif options.method == ApproximationMethod.SUBWORD:
        return subword_approximate(
            orig_embed.to(compute_device),
            target_tokens,
            is_lm_head,
            orig_tokenizer,
            options.subword_method,
        )
    elif options.method == ApproximationMethod.SPARSE_TOKEN_BASIS:
        assert token_basis is not None, "Token basis must be provided for STB"
        donor_basis, orig_basis = token_basis
        donor_basis = donor_basis.to(compute_device).to(torch.float32)
        orig_basis = orig_basis.to(compute_device).to(torch.float32)
        target_donor_embeds = donor_embed[
            torch.tensor([donor_vocab[t] for t in target_tokens])
        ].to(compute_device).to(torch.float32) - donor_embed.mean(dim=0).to(
            compute_device
        )
        # Solve for basis coefficients in the donor space, then re-express the
        # targets in the original model's basis
        coeffs = torch.linalg.lstsq(
            donor_basis.T,
            target_donor_embeds.T,
        ).solution.T
        if LOG.isEnabledFor(logging.DEBUG):
            donor_rt = coeffs @ donor_basis
            err = (donor_rt - target_donor_embeds).norm(dim=1)
            err_rel = err / target_donor_embeds.norm(dim=1).clamp_min(1e-6)
            sim = torch.nn.functional.cosine_similarity(
                donor_rt, target_donor_embeds, dim=1
            )
            LOG.debug(f"Reconstruction error: {err.mean().item():.4f}")
            LOG.debug(f"Relative reconstruction error: {err_rel.mean().item():.4f}")
            LOG.debug(f"Cosine similarity: {sim.mean().item():.4f}")
        return coeffs @ orig_basis + orig_embed.mean(dim=0).to(compute_device)
    else:
        raise ValueError(f"Unknown approximation method: {options.method}")


def build_embedding_matrix(
    weight_info: WeightInfo,
    orig_embed: torch.Tensor,
    donor_embed: torch.Tensor,
    orig_vocab: Dict[NormalizedToken, int],
    donor_vocab: Dict[NormalizedToken, int],
    junk_tokens: List[int],
    allow_prefix: bool,
    allow_byte: bool,
    is_lm_head: bool,
    options: TokenSurgeonOptions,
    compute_device: torch.device,
) -> torch.Tensor:
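    """Construct an embedding (or LM head) matrix covering the donor vocabulary.

    Tokens also present in the original vocabulary (exactly, or via byte/prefix
    matching where allowed) reuse the original rows; the rest are approximated
    in batches, shrinking the batch size adaptively on out-of-memory errors.
    """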
| LOG.info(f"Building new tensor for {weight_info.name}") | |
| stats = TokenAssignmentStats() | |
| out_vocab_size = max(len(donor_vocab), max(donor_vocab.values()) + 1) | |
| if options.method == ApproximationMethod.SPARSE_TOKEN_BASIS: | |
| token_basis = compute_token_basis( | |
| orig_embed, | |
| donor_embed, | |
| orig_vocab, | |
| donor_vocab, | |
| junk_tokens, | |
| options, | |
| ) | |
| else: | |
| token_basis = None | |
| res = torch.zeros( | |
| out_vocab_size, | |
| orig_embed.shape[1], | |
| device=orig_embed.device, | |
| dtype=orig_embed.dtype, | |
| ) | |
| new_tokens = [] | |
| for token, donor_idx in donor_vocab.items(): | |
| if token in orig_vocab: | |
| orig_idx = orig_vocab[token] | |
| res[donor_idx] = orig_embed[orig_idx] | |
| stats.exact_match += 1 | |
| elif ( | |
| allow_byte and (orig_idx := match_byte_token(token, orig_vocab)) is not None | |
| ): | |
| res[donor_idx] = orig_embed[orig_idx] | |
| stats.byte_match += 1 | |
| elif allow_prefix and (orig_idx := match_prefix(token, orig_vocab)) is not None: | |
| res[donor_idx] = orig_embed[orig_idx] | |
| stats.prefix_match += 1 | |
| else: | |
| new_tokens.append(token) | |
| stats.to_approximate += 1 | |
| donor_tokenizer = transformers.AutoTokenizer.from_pretrained( | |
| options.donor.model.path, | |
| revision=options.donor.model.revision, | |
| trust_remote_code=True, | |
| ) | |
| orig_tokenizer = transformers.AutoTokenizer.from_pretrained( | |
| options.model.model.path, | |
| revision=options.model.model.revision, | |
| trust_remote_code=True, | |
| ) | |
| LOG.info(stats.pretty_print()) | |
| if new_tokens: | |
| LOG.info(f"Approximating {len(new_tokens)} tokens") | |
| # Precompute shared embeds to avoid doing it in every batch | |
| shared_vocab = list( | |
| sorted( | |
| set(orig_vocab.keys()) & set(donor_vocab.keys()), | |
| key=lambda x: donor_vocab[x], | |
| ) | |
| ) | |
| donor_shared_embeds = donor_embed[ | |
| torch.tensor([donor_vocab[t] for t in shared_vocab]) | |
| ] | |
| orig_shared_embeds = orig_embed[ | |
| torch.tensor([orig_vocab[t] for t in shared_vocab]) | |
| ] | |
| shared_data = { | |
| "donor_shared_embeds": donor_shared_embeds, | |
| "orig_shared_embeds": orig_shared_embeds, | |
| } | |
| batch_size = options.batch_size | |
| if batch_size is None or batch_size <= 0: | |
| batch_size = 512 | |
| # Adaptive batching logic | |
| i = 0 | |
| total_tokens = len(new_tokens) | |
| oom_count = 0 | |
| pbar = tqdm.tqdm(total=total_tokens, desc="Approximating tokens") | |
| while i < total_tokens: | |
| end = min(i + batch_size, total_tokens) | |
| current_batch = new_tokens[i:end] | |
| try: | |
| new_embeds = compute_new_embeddings( | |
| orig_embed, | |
| donor_embed, | |
| orig_vocab, | |
| donor_vocab, | |
| target_tokens=current_batch, | |
| is_lm_head=is_lm_head, | |
| token_basis=token_basis, | |
| orig_tokenizer=orig_tokenizer, | |
| options=options, | |
| shared_data=shared_data, | |
| compute_device=compute_device, | |
| ) | |
| if options.new_vocab_noise: | |
| new_embeds += torch.randn_like(new_embeds) * options.new_vocab_noise | |
| if options.new_vocab_scale: | |
| new_embeds *= options.new_vocab_scale | |
| for ne_idx, token in enumerate(current_batch): | |
| res[donor_vocab[token]] = new_embeds[ne_idx].to(res.device) | |
| # Success, move to next batch | |
| pbar.update(end - i) | |
| i = end | |
| oom_count = 0 | |
| # Optional cleanup | |
| if compute_device.type == "cuda": | |
| torch.cuda.empty_cache() | |
| except torch.OutOfMemoryError: | |
| oom_count += 1 | |
| if compute_device.type == "cuda": | |
| torch.cuda.empty_cache() | |
| import gc | |
| gc.collect() | |
| old_batch = batch_size | |
| batch_size = max(1, int(batch_size * 0.75)) | |
| if batch_size == old_batch and batch_size == 1: | |
| LOG.error("OOM even with batch size 1. Cannot continue.") | |
| raise | |
| LOG.warning(f"OOM error. Reducing batch size from {old_batch} to {batch_size} (attempt {oom_count})") | |
| if oom_count > 10: | |
| LOG.error("Too many OOM errors, giving up.") | |
| raise | |
| pbar.close() | |
| if junk_tokens: | |
| LOG.info(f"Zero-initializing {len(junk_tokens)} junk tokens") | |
| for token_id in junk_tokens: | |
| res[token_id] = torch.zeros( | |
| orig_embed.shape[1], | |
| device=orig_embed.device, | |
| dtype=orig_embed.dtype, | |
| ) | |
| return res | |


class AllowMatch(enum.Enum):
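    """Which matrices a fuzzy (prefix or byte) token match may be applied to."""
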
    LM_HEAD_ONLY = "lm_head"
    EMBED_ONLY = "embed"
    YES = "yes"
    NO = "no"
def main(
    model: str,
    donor: str,
    out_path: str,
    k: int,
    cosine_similarity: bool,
    approximation_method: str,
    weight_scheme: str,
    subword_method: str,
    batch_size: Optional[int],
    prefix_match: str,
    byte_match: str,
    magikarp: bool,
    new_vocab_noise: Optional[float],
    new_vocab_scale: Optional[float],
    merge_options: MergeOptions,
):
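    """Transplant the donor model's tokenizer onto the base model, approximating
    embeddings for any tokens the base model lacks.
    """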
    merge_options.apply_global_options()
    logging.warning("This script is experimental and may produce unexpected results.")
    options = TokenSurgeonOptions(
        model=ModelReference.model_validate(model),
        donor=ModelReference.model_validate(donor),
        out_path=out_path,
        k=k,
        cosine_similarity=cosine_similarity,
        method=ApproximationMethod(approximation_method),
        weight_scheme=WeightingScheme(weight_scheme),
        subword_method=SubwordMethod(subword_method),
        batch_size=batch_size,
        new_vocab_noise=new_vocab_noise,
        new_vocab_scale=new_vocab_scale,
    )
    prefix_match = AllowMatch(prefix_match)
    byte_match = AllowMatch(byte_match)
    cache = LoaderCache()
    cache.setup(options=merge_options)
    compute_device = torch.device(
        merge_options.device
        if merge_options.device
        else ("cuda" if torch.cuda.is_available() else "cpu")
    )
    storage_device = "cpu"
    arch_info = get_arch_info(options.model, merge_options)
    embed_wi, lm_head_wi = get_embedding_info(arch_info)
    orig_vocab, orig_embed, orig_lm_head = get_stuff(
        options.model, merge_options, arch_info=arch_info, device=storage_device
    )
    donor_vocab, donor_embed, donor_lm_head = get_stuff(
        options.donor,
        merge_options,
        arch_info=None,
        get_tied=True,
        device=storage_device,
    )
    if magikarp:
        LOG.debug("Finding well-trained tokens in original model")
        well_trained_orig_tokens = set(
            well_trained_tokens(
                orig_vocab,
                orig_embed,
                orig_lm_head,
            )
        )
        LOG.debug("Finding well-trained tokens in donor model")
        well_trained_donor_tokens = set(
            well_trained_tokens(
                donor_vocab,
                donor_embed,
                donor_lm_head,
            )
        )
        common_well_trained_tokens = (
            well_trained_orig_tokens & well_trained_donor_tokens
        )
        LOG.info(f"Found {len(common_well_trained_tokens)} common well-trained tokens")
        orig_vocab = {
            tok: idx
            for tok, idx in orig_vocab.items()
            if tok in common_well_trained_tokens
        }
        junk_tokens = [
            idx
            for tok, idx in donor_vocab.items()
            if (tok not in well_trained_donor_tokens)
            and (tok not in well_trained_orig_tokens)
        ]
    else:
        junk_tokens = []
    if orig_embed is not None:
        if donor_embed is None:
            raise RuntimeError(
                f"Missing tensor {embed_wi.name} in model {options.donor}"
            )
        new_embed = build_embedding_matrix(
            embed_wi,
            orig_embed,
            donor_embed,
            orig_vocab=orig_vocab,
            donor_vocab=donor_vocab,
            junk_tokens=junk_tokens,
            allow_prefix=prefix_match in (AllowMatch.YES, AllowMatch.EMBED_ONLY),
            allow_byte=byte_match in (AllowMatch.YES, AllowMatch.EMBED_ONLY),
            is_lm_head=False,
            options=options,
            compute_device=compute_device,
        )
    else:
        if not embed_wi.optional:
            raise RuntimeError(
                f"Missing tensor {embed_wi.name} in model {options.model}"
            )
        new_embed = None
    if orig_lm_head is not None:
        if donor_lm_head is None:
            raise RuntimeError(
                f"Missing tensor {lm_head_wi.name} in model {options.donor}"
            )
        new_lm_head = build_embedding_matrix(
            lm_head_wi,
            orig_lm_head,
            donor_lm_head,
            orig_vocab=orig_vocab,
            donor_vocab=donor_vocab,
            junk_tokens=junk_tokens,
            allow_prefix=prefix_match in (AllowMatch.YES, AllowMatch.LM_HEAD_ONLY),
            allow_byte=byte_match in (AllowMatch.YES, AllowMatch.LM_HEAD_ONLY),
            is_lm_head=True,
            options=options,
            compute_device=compute_device,
        )
    else:
        if not lm_head_wi.optional:
            raise RuntimeError(
                f"Missing tensor {lm_head_wi.name} in model {options.model}"
            )
        new_lm_head = None
    new_vocab_size = None
    if new_embed is not None:
        new_vocab_size = new_embed.shape[0]
    elif new_lm_head is not None:
        new_vocab_size = new_lm_head.shape[0]
    LOG.info(f"Saving new model to {out_path}")
    out_arch_info = get_out_arch_info(
        options.model, options.donor, new_vocab_size, merge_options
    )
    writer = TensorWriter(
        out_path,
        max_shard_size=merge_options.out_shard_size,
        safe_serialization=merge_options.safe_serialization,
        use_async=merge_options.async_write,
        max_write_threads=merge_options.write_threads,
    )
    for weight_info in tqdm.tqdm(out_arch_info.all_weights(), desc="Saving weights"):
        if weight_info.name == embed_wi.name:
            tensor = new_embed
        elif lm_head_wi is not None and weight_info.name == lm_head_wi.name:
            tensor = new_lm_head
        else:
            tensor = cache.get(options.model).get_tensor(
                weight_info.name, aliases=weight_info.aliases, raise_on_missing=False
            )
        if tensor is None:
            if weight_info.optional:
                continue
            raise RuntimeError(
                f"Missing tensor {weight_info.name} in model {options.model}"
            )
        writer.save_tensor(weight_info.name, tensor, clone=merge_options.clone_tensors)
    # Force close lazy loader file handles so Windows allows deletion/renaming
    cache.flush_all()
    gc.collect()
    # Delete stale safetensors output from previous runs so the writer's
    # rename step doesn't hit FileExistsError
    temp_pattern = re.compile(r"^.*-\d+\.safetensors$")
    for fname in os.listdir(out_path):
        if fname.endswith(".safetensors") and not temp_pattern.match(fname):
            try:
                os.remove(os.path.join(out_path, fname))
            except Exception as e:
                LOG.warning(f"Could not remove old file {fname}: {e}")
        elif fname == "model.safetensors.index.json":
            try:
                os.remove(os.path.join(out_path, fname))
            except Exception:
                pass
    writer.finalize()
    out_arch_info.config.save_pretrained(out_path)
    tokenizer_out = transformers.AutoTokenizer.from_pretrained(
        options.donor.model.path,
        revision=options.donor.model.revision,
        trust_remote_code=merge_options.trust_remote_code,
    )
    tokenizer_out.save_pretrained(out_path)
    # Also copy generation_config.json from the donor if present
    donor_gen_config = os.path.join(options.donor.model.path, "generation_config.json")
    if os.path.exists(donor_gen_config):
        shutil.copy(donor_gen_config, os.path.join(out_path, "generation_config.json"))
    LOG.info("Done!")


if __name__ == "__main__":
    main()