# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: LGPL-3.0-only

from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple

from pydantic import BaseModel, Field
from transformers import PretrainedConfig

from mergekit.common import get_config_value


class WeightInfo(BaseModel, frozen=True):
    """Information about an individual weight tensor in a model.

    Attributes:
        name (str):
            The name of the tensor representing the weight.
        is_embed (bool):
            Indicates whether the weight is for an embedding or language model head.
        optional (bool):
            Indicates whether the weight can be omitted from a model.
        aliases (Optional[List[str]]):
            List of alternative names for the weight, if applicable.
        force_dtype (Optional[str]):
            Mandatory dtype for the weight, if applicable.
    """

    name: str
    is_embed: bool = False
    optional: bool = False
    aliases: Optional[Tuple[str, ...]] = None
    force_dtype: Optional[str] = None
    tied_names: Optional[Tuple[str, ...]] = None
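
# Illustration only: a hypothetical embedding weight with a legacy alias and a
# tied output head. The tensor names are placeholders, not taken from any real
# architecture definition.
#
#   WeightInfo(
#       name="embed_tokens.weight",
#       is_embed=True,
#       aliases=("wte.weight",),
#       tied_names=("lm_head.weight",),
#   )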


def _prefix_weight(weight: WeightInfo, prefix: Optional[str] = None) -> WeightInfo:
    """Return a copy of ``weight`` with ``prefix`` prepended to its name,
    aliases, and tied names; if ``prefix`` is None, return ``weight`` as-is."""
    if prefix is None:
        return weight
    return WeightInfo(
        name=prefix + weight.name,
        aliases=tuple(prefix + alias for alias in weight.aliases or ()) or None,
        tied_names=tuple(prefix + tied_name for tied_name in weight.tied_names or ())
        or None,
        **weight.model_dump(exclude={"name", "aliases", "tied_names"}),
    )


class ModuleArchitecture(ABC):
    @abstractmethod
    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        """Return a list of all weights preceding the first layer."""
        ...

    @abstractmethod
    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        """Return a list of all weights following the final layer."""
        ...

    @abstractmethod
    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        """Return a list of all weights associated with a given layer."""
        ...

    def num_layers_config_key(self) -> str:
        """Key in config that represents number of layers"""
        return "num_hidden_layers"

    def num_layers(self, config: PretrainedConfig) -> int:
        """Return the number of layers in a model."""
        return get_config_value(config, self.num_layers_config_key())

    def all_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        """Return all weights associated with a model."""
        num_layers = self.num_layers(config)
        res = list(self.pre_weights(config))
        for layer_idx in range(num_layers):
            # layer_weights may return None; treat that as a layer with no weights
            res.extend(self.layer_weights(layer_idx, config) or [])
        res.extend(self.post_weights(config))
        return res
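

# A minimal sketch of a concrete ModuleArchitecture, assuming a Llama-style
# decoder. The tensor names and the abbreviated per-layer weight list are
# illustrative; a real definition would enumerate every projection and norm.
class _ExampleDecoderArchitecture(ModuleArchitecture):
    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return [WeightInfo(name="model.embed_tokens.weight", is_embed=True)]

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return [
            WeightInfo(name="model.norm.weight"),
            WeightInfo(
                name="lm_head.weight",
                is_embed=True,
                optional=True,
                tied_names=("model.embed_tokens.weight",),
            ),
        ]

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        prefix = f"model.layers.{index}."
        return [
            WeightInfo(name=prefix + "self_attn.q_proj.weight"),
            WeightInfo(name=prefix + "mlp.down_proj.weight"),
        ]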


class ConfiguredModuleArchitecture(
    BaseModel, frozen=True, arbitrary_types_allowed=True
):
    info: ModuleArchitecture
    config: PretrainedConfig
    weight_prefix: Optional[str] = None

    def num_layers(self) -> int:
        return self.info.num_layers(self.config)

    def pre_weights(self) -> List[WeightInfo]:
        return [
            _prefix_weight(w, self.weight_prefix)
            for w in self.info.pre_weights(self.config)
        ]

    def post_weights(self) -> List[WeightInfo]:
        return [
            _prefix_weight(w, self.weight_prefix)
            for w in self.info.post_weights(self.config)
        ]

    def layer_weights(self, index: int) -> List[WeightInfo]:
        return [
            _prefix_weight(w, self.weight_prefix)
            for w in self.info.layer_weights(index, self.config) or []
        ]

    def all_weights(self) -> List[WeightInfo]:
        return [
            _prefix_weight(w, self.weight_prefix)
            for w in self.info.all_weights(self.config)
        ]
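
# Sketch of configured use, assuming the example architecture above; the
# prefix and layer count are arbitrary placeholders.
#
#   module = ConfiguredModuleArchitecture(
#       info=_ExampleDecoderArchitecture(),
#       config=PretrainedConfig(num_hidden_layers=2),
#       weight_prefix="decoder.",
#   )
#   module.layer_weights(0)  # every returned name starts with "decoder."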


class ModuleDefinition(BaseModel, frozen=True, arbitrary_types_allowed=True):
    architecture: ModuleArchitecture
    weight_prefix: Optional[str] = None
    subfolder: Optional[str] = None


class ModelArchitecture(BaseModel, frozen=True):
    modules: Dict[str, ModuleDefinition]
    architectures: List[str]
    expected_model_type: str = Field(alias="model_type")
    tagalong_files: Optional[List[str]] = None
    vocab_size_config_key: Optional[str] = None

    def all_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        res = []
        for module in self.modules.values():
            for weight_info in module.architecture.all_weights(config=config):
                res.append(_prefix_weight(weight_info, module.weight_prefix))
        return res
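

# Illustrative composition: a single "decoder" module built from the example
# architecture above, namespaced under "decoder."; the field values are
# placeholders. Note that `expected_model_type` is populated via its alias:
#
#   ModelArchitecture(
#       modules={
#           "decoder": ModuleDefinition(
#               architecture=_ExampleDecoderArchitecture(),
#               weight_prefix="decoder.",
#           )
#       },
#       architectures=["ExampleForCausalLM"],
#       model_type="example",
#   )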


class ConfiguredModelArchitecture(BaseModel, frozen=True, arbitrary_types_allowed=True):
    info: ModelArchitecture
    config: PretrainedConfig

    def all_weights(self) -> List[WeightInfo]:
        return self.info.all_weights(self.config)

    def get_module(self, module_name: str) -> ConfiguredModuleArchitecture:
        return ConfiguredModuleArchitecture(
            info=self.info.modules[module_name].architecture,
            config=self.config,
            weight_prefix=self.info.modules[module_name].weight_prefix,
        )


# Manually rebuild the Pydantic models to resolve forward references; without
# this, Pydantic v2 can raise a "class not fully defined" error at use time.
ConfiguredModuleArchitecture.model_rebuild()
ConfiguredModelArchitecture.model_rebuild()
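

# A small self-checking demo, runnable by executing this module directly; it
# exercises only names defined in this file.
if __name__ == "__main__":
    w = WeightInfo(
        name="embed_tokens.weight",
        is_embed=True,
        aliases=("wte.weight",),
    )
    pw = _prefix_weight(w, prefix="model.")
    # Prefixing applies to the name and every alias (and tied name, if any).
    assert pw.name == "model.embed_tokens.weight"
    assert pw.aliases == ("model.wte.weight",)
    assert pw.is_embed
    print("ok:", pw.name, pw.aliases)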