| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Any, Dict, List, Optional |
|
|
| import torch |
| from pydantic import BaseModel |
|
|
| from mergekit.common import ImmutableMap, ModelReference |
| from mergekit.graph import Task |
| from mergekit.io.tasks import GatherTensors |
| from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod |
| from mergekit.merge_methods.slerp import slerp |
| from mergekit.tokenizer import BuildTokenizer, TokenizerInfo |
|
|
|
|
class TokenizerPermutationMergeTask(Task[torch.Tensor]):
    """Merge embedding-style tensors whose rows must first be permuted into a
    common vocabulary order produced by the tokenizer build step.

    Each input tensor's rows are scattered into the merged vocabulary order
    using the per-model permutation from ``tokenizer_info``; rows a model does
    not provide are masked out. The merged result is a masked, weight-normalized
    linear combination, or (for exactly two models) SLERP with a linear
    fallback for tokens not present in both vocabularies.
    """

    tokenizer_task: BuildTokenizer
    gather_tensors: GatherTensors
    base_model: Optional[ModelReference]
    use_slerp: bool
    slerp_t: Optional[float]
    tensor_parameters: ImmutableMap[ModelReference, Any]

    def uses_accelerator(self) -> bool:
        # Row scatter and the blended reductions benefit from GPU execution.
        return True

    def arguments(self) -> Dict[str, Task]:
        return {"tokenizer_info": self.tokenizer_task, "tensors": self.gather_tensors}

    def execute(
        self, tokenizer_info: TokenizerInfo, tensors: Dict[ModelReference, torch.Tensor]
    ) -> torch.Tensor:
        """Permute and merge the gathered tensors.

        Args:
            tokenizer_info: Carries ``permutations[model]``, a mapping from
                merged-vocab row index to the model's own row index (negative
                when the token is absent from that model).
            tensors: One 2-D tensor per model (rows indexed by the model's own
                vocabulary — assumed; verify against GatherTensors callers).

        Returns:
            The merged tensor in merged-vocabulary row order, or ``None`` when
            no tensors were gathered.

        Raises:
            RuntimeError: if SLERP is requested without ``slerp_t``, or with a
                number of models other than two.
        """
        if not tensors:
            return None
        if len(tensors) == 1:
            # Single input: nothing to merge. NOTE(review): returned in the
            # model's original row order, not permuted — assumed intentional.
            return next(iter(tensors.values()))

        if self.use_slerp and self.slerp_t is None:
            raise RuntimeError("Must set t to use embed_slerp")

        models = []
        expanded = []
        masks = []
        weights = []
        for model, x in tensors.items():
            models.append(model)
            p = tokenizer_info.permutations[model]

            # Scatter this model's rows into merged-vocab order; mask marks
            # which output rows this model actually provides.
            xp = torch.zeros((len(p), x.shape[-1]), dtype=x.dtype, device=x.device)
            mask = torch.zeros((len(p),), dtype=torch.bool, device=x.device)
            for out_idx in p:
                in_idx = p[out_idx]
                if in_idx < 0:
                    # Token absent from this model's vocabulary.
                    continue

                xp[out_idx, :] = x[in_idx, :]
                mask[out_idx] = True

            expanded.append(xp)
            masks.append(mask)

            # Under SLERP the base model implicitly gets weight (1 - t) and
            # the other model weight t; otherwise use the configured weight.
            is_base = model == self.base_model
            if self.use_slerp:
                weight = (1.0 - self.slerp_t) if is_base else self.slerp_t
            else:
                weight = self.tensor_parameters[model]["weight"]

            weights.append(weight)

        expanded = torch.stack(expanded, dim=0)
        masks = torch.stack(masks, dim=0).unsqueeze(-1)
        weights = (
            torch.tensor(weights, dtype=expanded.dtype, device=expanded.device)
            .unsqueeze(-1)
            .unsqueeze(-1)
        )

        # Normalize per output row by the total weight of the models that
        # provide it; rows no model provides get scale 0 (stay zero).
        total_weight = (masks * weights).sum(dim=0)
        scale = 1 / total_weight
        scale[total_weight.abs() < 1e-8] = 0

        linear_merged = (expanded * weights * masks).sum(dim=0) * scale

        if self.use_slerp:
            if expanded.shape[0] != 2:
                raise RuntimeError("SLERP takes exactly two models")

            # slerp(t, base, other): orient so the base model is v0.
            if models[0] == self.base_model:
                v0 = expanded[0, ...]
                v1 = expanded[1, ...]
            else:
                v0 = expanded[1, ...]
                v1 = expanded[0, ...]

            res = slerp(self.slerp_t, v0, v1)
            # Rows not present in both vocabularies can't be slerped; fall
            # back to the masked linear blend for those rows.
            need_linear = (masks.sum(dim=0) != 2).squeeze(dim=-1)
            res[need_linear, :] = linear_merged[need_linear, :].to(
                device=res.device, dtype=res.dtype
            )
            return res

        return linear_merged
|
|
|
|
class TokenizerPermutationMerge(MergeMethod, BaseModel):
    """Merge method that first unifies model vocabularies via a built
    tokenizer, then blends the permuted tensors linearly or with SLERP."""

    tokenizer_task: BuildTokenizer

    def parameters(self) -> List[ConfigParameterDef]:
        """Global config: SLERP factor ``t`` and the ``embed_slerp`` toggle."""
        t_def = ConfigParameterDef(name="t", required=False)
        slerp_def = ConfigParameterDef(
            name="embed_slerp", required=False, default_value=False
        )
        return [t_def, slerp_def]

    def tensor_parameters(self) -> List[ConfigParameterDef]:
        """Per-model config: optional linear blend ``weight``."""
        return [ConfigParameterDef(name="weight", required=False)]

    def make_task(
        self,
        *,
        tensors: GatherTensors,
        parameters: Dict[str, Any],
        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
        base_model: Optional[ModelReference],
        **_kwargs,
    ) -> Task:
        """Build the permutation-merge task for one tensor group."""
        task = TokenizerPermutationMergeTask(
            tokenizer_task=self.tokenizer_task,
            gather_tensors=tensors,
            base_model=base_model,
            use_slerp=parameters["embed_slerp"],
            slerp_t=parameters["t"],
            tensor_parameters=tensor_parameters,
        )
        return task
|
|