| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from typing import Any, Dict, List |
| |
|
| | import torch |
| |
|
| | from mergekit.architecture import WeightInfo |
| | from mergekit.common import ImmutableMap, ModelReference, rectify_embed_sizes |
| | from mergekit.graph import Task |
| | from mergekit.io.tasks import GatherTensors |
| | from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod |
| |
|
| |
|
class LinearMergeTask(Task[torch.Tensor]):
    """Merge one parameter across models as a weighted sum of their tensors.

    The result is ``sum_i(w_i * t_i)``. When ``normalize`` is true it is
    divided by ``sum_i(w_i)``, giving a weighted average.
    """

    # Task whose output (a mapping of model -> tensor) is fed to ``execute``
    # as the ``tensors`` argument.
    gather_tensors: GatherTensors
    # Per-model parameter maps; each one must provide a "weight" entry.
    tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]]
    # Divide the weighted sum by the total weight when True.
    normalize: bool
    # Name of the parameter being merged; used in error messages and passed
    # to ``rectify_embed_sizes``.
    parameter_name: str

    def uses_accelerator(self) -> bool:
        """Report that this task may run on an accelerator device."""
        return True

    def arguments(self) -> Dict[str, Task]:
        """Declare the dependency whose result arrives as ``tensors``."""
        return {"tensors": self.gather_tensors}

    def execute(
        self, tensors: Dict[ModelReference, torch.Tensor], **_kwargs
    ) -> torch.Tensor:
        """Return the (optionally normalized) weighted sum of ``tensors``.

        Args:
            tensors: mapping from each model to its tensor for
                ``parameter_name``.

        Returns:
            The weighted sum over models, divided by the total weight when
            ``normalize`` is set.

        Raises:
            RuntimeError: if ``tensors`` is empty, if the tensors still
                differ in shape after ``rectify_embed_sizes``, or if
                ``normalize`` is set and the weights sum to zero (the
                division would otherwise yield inf/NaN silently).
        """
        if not tensors:
            # Without this guard the shape check below fires with the
            # misleading message "sizes: []".
            raise RuntimeError(f"No tensors to merge for {self.parameter_name}")

        keys = list(tensors.keys())
        tensor_list = [tensors[key] for key in keys]
        weight_list = [self.tensor_parameters[key]["weight"] for key in keys]

        # The return value is ignored and ``tensor_list`` is used afterwards.
        # The helper therefore presumably adjusts the list/tensors in place
        # (e.g. aligning embedding sizes) -- confirm in mergekit.common.
        rectify_embed_sizes(self.parameter_name, tensor_list)

        unique_shapes = set(t.shape for t in tensor_list)
        if len(unique_shapes) != 1:
            raise RuntimeError(
                f"Tensor size mismatch for {self.parameter_name}, sizes: {list(unique_shapes)}"
            )

        stacked = torch.stack(tensor_list, dim=0)
        weights = torch.tensor(weight_list, dtype=stacked.dtype, device=stacked.device)
        # Append trailing singleton dims so the weights broadcast over
        # every non-model dimension of ``stacked``.
        while weights.dim() < stacked.dim():
            weights = weights.unsqueeze(-1)

        res = (weights * stacked).sum(dim=0)
        if self.normalize:
            weight_sum = weights.sum(dim=0)
            if bool((weight_sum == 0).any()):
                raise RuntimeError(
                    f"Cannot normalize {self.parameter_name}: weights sum to zero"
                )
            res = res / weight_sum

        return res
| |
|
| |
|
class LinearMerge(MergeMethod):
    """Merge method producing a weighted sum of model tensors.

    Method-wide option:
        normalize (optional, defaults to True): forwarded to the task as its
            ``normalize`` flag.

    Per-model option:
        weight (required): the coefficient applied to that model's tensor.
    """

    def parameters(self) -> List[ConfigParameterDef]:
        """List the method-wide configuration parameters."""
        normalize_def = ConfigParameterDef(
            name="normalize", required=False, default_value=True
        )
        return [normalize_def]

    def tensor_parameters(self) -> List[ConfigParameterDef]:
        """List the per-model configuration parameters."""
        weight_def = ConfigParameterDef(name="weight", required=True)
        return [weight_def]

    def make_task(
        self,
        *,
        output_weight: WeightInfo,
        tensors: GatherTensors,
        parameters: Dict[str, Any],
        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
        **_kwargs,
    ) -> Task:
        """Build the task that merges the tensors for ``output_weight``."""
        gather = tensors
        per_model = tensor_parameters
        should_normalize = parameters["normalize"]
        target_name = output_weight.name
        return LinearMergeTask(
            gather_tensors=gather,
            tensor_parameters=per_model,
            normalize=should_normalize,
            parameter_name=target_name,
        )
| |
|