# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: BUSL-1.1
import logging
from typing import Dict, List, Optional

import torch

from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.tokenizer.build import BuildTokenizer, TokenizerInfo
from mergekit.tokenizer.config import (
ModelTokenEmbedding,
TokenEmbeddingConfig,
ZeroEmbedding,
)


class PermutedEmbeddings(Task[Dict[ModelReference, torch.Tensor]]):
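    """Remap each model's embedding matrix onto a shared output vocabulary.

    Rows are copied according to the per-model token permutations produced by
    BuildTokenizer; tokens a model lacks fall back to averaged embeddings or to
    an explicitly configured source.
    """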
gather_tensors: GatherTensors
tokenizer_task: BuildTokenizer
tokens: Optional[ImmutableMap[str, TokenEmbeddingConfig]]
pad_to_multiple_of: Optional[int]
    base_model: Optional[ModelReference]

    def arguments(self) -> Dict[str, Task]:
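        """Declare the upstream tasks whose outputs are fed to execute()."""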
        return {"tokenizer_info": self.tokenizer_task, "tensors": self.gather_tensors}

    def execute(
self, tokenizer_info: TokenizerInfo, tensors: Dict[ModelReference, torch.Tensor]
) -> Dict[ModelReference, torch.Tensor]:
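        """Return a (vocab_size, embed_size) embedding matrix for each model."""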
tokenizer = tokenizer_info.tokenizer
permutations = tokenizer_info.permutations
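        # permutations[model] maps output-vocab token IDs to row indices in that
        # model's original embedding matrix, with -1 marking absent tokens.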
models = set(tensors.keys())
if self.base_model:
models.add(self.base_model)
models = list(models)
vocab = tokenizer.get_vocab()
vocab_size = len(vocab)
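        # Optionally round the output vocabulary up to a multiple for more
        # hardware-friendly tensor shapes.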
if self.pad_to_multiple_of and vocab_size % self.pad_to_multiple_of:
vocab_size = (
vocab_size // self.pad_to_multiple_of + 1
) * self.pad_to_multiple_of
embed_size = tensors[models[0]].shape[1]
assert all(
t.shape[1] == embed_size for t in tensors.values()
), "Embedding sizes must match"
dtype = tensors[models[0]].dtype
device = tensors[models[0]].device
token_configs = dict(**(self.tokens or {}))
tokens_to_average = self.assign_embedding_sources(
permutations, models, vocab, token_configs
)
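        # Precompute fallback embeddings for tokens that are averaged across
        # models or come from an explicitly configured source.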
default_embeds = {}
for token, token_id in vocab.items():
embed = torch.zeros(embed_size, dtype=dtype, device=device)
if token in tokens_to_average:
count = 0
for model in models:
p = permutations[model]
if p[token_id] < 0:
continue
embed += tensors[model][p[token_id]]
count += 1
embed /= count
elif cfg := token_configs.get(token, None):
cfg: TokenEmbeddingConfig
embed = self.compute_default_embedding(
tokenizer_info, tensors, permutations, token, token_id, cfg
)
else:
continue
default_embeds[token] = embed
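        # Assemble each model's permuted embedding matrix, copying rows the
        # model already has and substituting fallbacks elsewhere.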
result = {}
for model in models:
p = permutations[model]
old_embed = tensors[model]
new_embed = torch.zeros(
(vocab_size, embed_size), dtype=dtype, device=device
)
for token, token_id in vocab.items():
force = False
if token in token_configs:
force = token_configs[token].force
if p[token_id] >= 0 and not force:
new_embed[token_id, :] = old_embed[p[token_id]]
elif token in default_embeds:
new_embed[token_id, :] = default_embeds[token]
else:
logging.error(
f"No embedding for token {repr(token)} in model {model}!"
)
if vocab_size > len(vocab):
# as suggested by https://nlp.stanford.edu/~johnhew/vocab-expansion.html
avg_embed = torch.mean(new_embed[: len(vocab), :], dim=0)
new_embed[len(vocab) :, :] = avg_embed
result[model] = new_embed
        return result

    def assign_embedding_sources(
self,
permutations: Dict[ModelReference, Dict[int, int]],
        models: List[ModelReference],
vocab: Dict[str, int],
token_configs: Dict[str, TokenEmbeddingConfig],
):
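        """Choose an embedding source for every token in the shared vocabulary.

        Mutates token_configs in place: a token present in exactly one model is
        sourced from that model, a token present in none gets a zero embedding,
        and a token the base model knows is sourced from the base model.
        Returns the set of tokens whose embeddings should instead be averaged
        over all models that contain them.
        """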
permutation_list = [permutations[model] for model in models]
tokens_to_average = set()
# find tokens that are only present in one model
for token, token_id in vocab.items():
if token in token_configs:
continue
has_token = [p[token_id] >= 0 for p in permutation_list]
num_present = sum(int(x) for x in has_token)
if num_present == 1:
donor_model = models[has_token.index(True)]
token_configs[token] = TokenEmbeddingConfig(source=donor_model)
continue
if num_present == 0:
token_configs[token] = TokenEmbeddingConfig(source=ZeroEmbedding())
logging.warning(f"Token {repr(token)} not found in any model")
continue
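            # Token appears in two or more models; prefer the base model's
            # embedding when it has one, otherwise average across donors.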
if num_present > 0 and self.base_model is not None:
if permutations[self.base_model][token_id] >= 0:
token_configs[token] = TokenEmbeddingConfig(source=self.base_model)
continue
tokens_to_average.add(token)
        return tokens_to_average

    def compute_default_embedding(
self,
tokenizer_info: TokenizerInfo,
tensors: Dict[ModelReference, torch.Tensor],
permutations: Dict[ModelReference, Dict[int, int]],
token: str,
token_id: int,
cfg: TokenEmbeddingConfig,
) -> torch.Tensor:
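        """Build a token's embedding from its configured source: a zero vector,
        a specific token in a donor model, or the token's own row in a donor
        model's embedding matrix.
        """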
        if isinstance(cfg.source, ZeroEmbedding):
            # Zero vector matching the merge's embedding width, dtype, and
            # device. (The original `pass` left `embed` unbound and raised a
            # NameError at the return below.)
            ref_weight = next(iter(tensors.values()))
            embed = torch.zeros(
                ref_weight.shape[1], dtype=ref_weight.dtype, device=ref_weight.device
            )
elif isinstance(cfg.source, ModelTokenEmbedding):
model = cfg.source.model
assert (
model in permutations
), f"Model {model} referenced but not part of merge"
src_token_id = cfg.source.token_id
if src_token_id is None:
src_token = cfg.source.token
assert (
src_token in tokenizer_info.original_vocabs[model]
), f"Token {repr(src_token)} not found in model {model}"
src_token_id = tokenizer_info.original_vocabs[model][src_token]
assert (
src_token_id >= 0 and src_token_id < tensors[model].shape[0]
), f"Token ID {src_token_id} out of range for model {model}"
embed = tensors[model][src_token_id]
elif isinstance(cfg.source, ModelReference):
model = cfg.source
p = permutations[model]
assert p[token_id] >= 0, f"Token {repr(token)} not found in model {model}"
embed = tensors[model][p[token_id]]
else:
raise NotImplementedError(cfg)
return embed