# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: BUSL-1.1

import logging
from typing import Dict, Optional

import torch

from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.io.tasks import GatherTensors
from mergekit.tokenizer.build import BuildTokenizer, TokenizerInfo
from mergekit.tokenizer.config import (
    ModelTokenEmbedding,
    TokenEmbeddingConfig,
    ZeroEmbedding,
)


class PermutedEmbeddings(Task[Dict[ModelReference, torch.Tensor]]):
    """Build per-model embedding matrices aligned to a merged tokenizer.

    For each input model, produces a new ``(vocab_size, embed_size)`` embedding
    tensor whose rows follow the merged tokenizer's token ids. Rows for tokens
    a model already has are copied through that model's permutation; rows for
    missing tokens are filled from a default embedding chosen either by an
    explicit per-token config (``tokens``) or by ``assign_embedding_sources``.
    """

    gather_tensors: GatherTensors  # loads each model's input embedding tensor
    tokenizer_task: BuildTokenizer  # builds merged tokenizer + per-model permutations
    tokens: Optional[ImmutableMap[str, TokenEmbeddingConfig]]  # explicit per-token overrides
    pad_to_multiple_of: Optional[int]  # if set, round vocab size up to this multiple
    base_model: Optional[ModelReference]  # preferred donor for tokens it contains

    def arguments(self) -> Dict[str, Task]:
        return {"tokenizer_info": self.tokenizer_task, "tensors": self.gather_tensors}

    def execute(
        self, tokenizer_info: TokenizerInfo, tensors: Dict[ModelReference, torch.Tensor]
    ) -> Dict[ModelReference, torch.Tensor]:
        """Return a permuted, hole-filled embedding matrix for every model.

        ``permutations[model][token_id]`` maps a merged-tokenizer token id to
        the model's original row index, with a negative value meaning the
        model lacks that token.
        """
        tokenizer = tokenizer_info.tokenizer
        permutations = tokenizer_info.permutations

        # Make sure the base model participates even if it is not a tensor key.
        models = set(tensors.keys())
        if self.base_model:
            models.add(self.base_model)
        models = list(models)

        vocab = tokenizer.get_vocab()
        vocab_size = len(vocab)
        if self.pad_to_multiple_of and vocab_size % self.pad_to_multiple_of:
            # Round up to the next multiple (e.g. for hardware-friendly shapes).
            vocab_size = (
                vocab_size // self.pad_to_multiple_of + 1
            ) * self.pad_to_multiple_of
        embed_size = tensors[models[0]].shape[1]
        assert all(
            t.shape[1] == embed_size for t in tensors.values()
        ), "Embedding sizes must match"

        dtype = tensors[models[0]].dtype
        device = tensors[models[0]].device

        token_configs = dict(**(self.tokens or {}))
        tokens_to_average = self.assign_embedding_sources(
            permutations, models, vocab, token_configs
        )

        # Precompute one default embedding per token that needs a fallback.
        default_embeds = {}
        for token, token_id in vocab.items():
            embed = torch.zeros(embed_size, dtype=dtype, device=device)
            if token in tokens_to_average:
                # Mean over all models containing the token.
                # assign_embedding_sources guarantees at least two do, so
                # count is never zero here.
                count = 0
                for model in models:
                    p = permutations[model]
                    if p[token_id] < 0:
                        continue
                    embed += tensors[model][p[token_id]]
                    count += 1
                embed /= count
            elif cfg := token_configs.get(token, None):
                cfg: TokenEmbeddingConfig
                embed = self.compute_default_embedding(
                    tokenizer_info, tensors, permutations, token, token_id, cfg
                )
            else:
                # Token present in every model with no override — no default needed.
                continue
            default_embeds[token] = embed

        result = {}
        for model in models:
            p = permutations[model]
            old_embed = tensors[model]
            new_embed = torch.zeros(
                (vocab_size, embed_size), dtype=dtype, device=device
            )
            for token, token_id in vocab.items():
                force = False
                if token in token_configs:
                    force = token_configs[token].force

                if p[token_id] >= 0 and not force:
                    # Model has this token: copy its original row through.
                    new_embed[token_id, :] = old_embed[p[token_id]]
                elif token in default_embeds:
                    new_embed[token_id, :] = default_embeds[token]
                else:
                    logging.error(
                        f"No embedding for token {repr(token)} in model {model}!"
                    )

            if vocab_size > len(vocab):
                # Initialize padding rows to the mean of the real rows,
                # as suggested by https://nlp.stanford.edu/~johnhew/vocab-expansion.html
                avg_embed = torch.mean(new_embed[: len(vocab), :], dim=0)
                new_embed[len(vocab) :, :] = avg_embed
            result[model] = new_embed

        return result

    def assign_embedding_sources(
        self,
        permutations: Dict[ModelReference, Dict[int, int]],
        models: list[ModelReference],
        vocab: Dict[str, int],
        token_configs: Dict[str, TokenEmbeddingConfig],
    ):
        """Decide where each unconfigured token's embedding comes from.

        Mutates ``token_configs`` in place: tokens found in exactly one model
        get that model as source, tokens found in none get a ZeroEmbedding,
        and tokens present in the base model default to the base model.
        Returns the set of remaining tokens, which must be averaged across
        the models that contain them (always two or more).
        """
        permutation_list = [permutations[model] for model in models]

        tokens_to_average = set()
        for token, token_id in vocab.items():
            # Explicit user configuration always wins.
            if token in token_configs:
                continue

            has_token = [p[token_id] >= 0 for p in permutation_list]
            num_present = sum(int(x) for x in has_token)
            if num_present == 1:
                # Only one model knows this token — it is the sole donor.
                donor_model = models[has_token.index(True)]
                token_configs[token] = TokenEmbeddingConfig(source=donor_model)
                continue

            if num_present == 0:
                token_configs[token] = TokenEmbeddingConfig(source=ZeroEmbedding())
                logging.warning(f"Token {repr(token)} not found in any model")
                continue

            # Two or more models have the token; prefer the base model if it
            # is one of them, otherwise fall back to averaging.
            if num_present > 0 and self.base_model is not None:
                if permutations[self.base_model][token_id] >= 0:
                    token_configs[token] = TokenEmbeddingConfig(source=self.base_model)
                    continue

            tokens_to_average.add(token)
        return tokens_to_average

    def compute_default_embedding(
        self,
        tokenizer_info: TokenizerInfo,
        tensors: Dict[ModelReference, torch.Tensor],
        permutations: Dict[ModelReference, Dict[int, int]],
        token: str,
        token_id: int,
        cfg: TokenEmbeddingConfig,
    ) -> torch.Tensor:
        """Compute the embedding for a token with an explicit source config.

        Raises NotImplementedError for unrecognized source types.
        """
        if isinstance(cfg.source, ZeroEmbedding):
            # BUG FIX: this branch previously fell through to `return embed`
            # with `embed` unbound, raising NameError for any token absent
            # from all models. Return an explicit zero vector matching the
            # input embeddings' width, dtype, and device.
            reference = next(iter(tensors.values()))
            embed = torch.zeros(
                reference.shape[1], dtype=reference.dtype, device=reference.device
            )
        elif isinstance(cfg.source, ModelTokenEmbedding):
            # Copy a (possibly different) token's row from a specific model.
            model = cfg.source.model
            assert (
                model in permutations
            ), f"Model {model} referenced but not part of merge"
            src_token_id = cfg.source.token_id
            if src_token_id is None:
                src_token = cfg.source.token
                assert (
                    src_token in tokenizer_info.original_vocabs[model]
                ), f"Token {repr(src_token)} not found in model {model}"
                src_token_id = tokenizer_info.original_vocabs[model][src_token]
            assert (
                src_token_id >= 0 and src_token_id < tensors[model].shape[0]
            ), f"Token ID {src_token_id} out of range for model {model}"
            embed = tensors[model][src_token_id]
        elif isinstance(cfg.source, ModelReference):
            # Copy this same token's row from the named model.
            model = cfg.source
            p = permutations[model]
            assert p[token_id] >= 0, f"Token {repr(token)} not found in model {model}"
            embed = tensors[model][p[token_id]]
        else:
            raise NotImplementedError(cfg)
        return embed