# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: LGPL-3.0-only
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple

from pydantic import BaseModel, Field
from transformers import PretrainedConfig

from mergekit.common import get_config_value


class WeightInfo(BaseModel, frozen=True):
"""Information about an individual weight tensor in a model.
Attributes:
name (str):
The name of the tensor representing the weight.
is_embed (bool):
Indicates whether the weight is for an embedding or language model head.
optional (bool):
Indicates whether the weight can be omitted from a model.
aliases (Optional[List[str]]):
List of alternative names for the weight, if applicable.
force_dtype (Optional[str]):
Mandatory dtype for the weight, if applicable.
"""
name: str
is_embed: bool = False
optional: bool = False
aliases: Optional[Tuple[str, ...]] = None
force_dtype: Optional[str] = None
tied_names: Optional[Tuple[str, ...]] = None
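

# Example (illustrative; the tensor names below are assumptions for
# demonstration, not taken from any particular architecture definition):
#
#   WeightInfo(
#       name="model.embed_tokens.weight",
#       is_embed=True,
#       tied_names=("lm_head.weight",),
#   )
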
def _prefix_weight(weight: WeightInfo, prefix: Optional[str] = None) -> WeightInfo:
if prefix is None:
return weight
return WeightInfo(
name=prefix + weight.name,
aliases=tuple(prefix + alias for alias in weight.aliases or ()) or None,
tied_names=tuple(prefix + tied_name for tied_name in weight.tied_names or ())
or None,
**weight.model_dump(exclude={"name", "aliases", "tied_names"}),
)
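

# Example (hypothetical names): prefixing with "decoder." renames the weight
# and all of its aliases and tied names in one step:
#
#   w = WeightInfo(
#       name="layers.0.mlp.up_proj.weight",
#       aliases=("h.0.mlp.up_proj.weight",),
#   )
#   _prefix_weight(w, "decoder.").name
#   # -> "decoder.layers.0.mlp.up_proj.weight"
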
class ModuleArchitecture(ABC):
@abstractmethod
def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
"""Return a list of all weights preceding the first layer."""
...
@abstractmethod
def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
"""Return a list of all weights following the final layer."""
...
@abstractmethod
def layer_weights(
self, index: int, config: PretrainedConfig
) -> Optional[List[WeightInfo]]:
"""Return a list of all weights associated with a given layer."""
...
def num_layers_config_key(self) -> str:
"""Key in config that represents number of layers"""
return "num_hidden_layers"
def num_layers(self, config: PretrainedConfig) -> int:
"""Return the number of layers in a model."""
return get_config_value(config, self.num_layers_config_key())
def all_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
"""Return all weights associated with a model."""
num_layers = self.num_layers(config)
res = list(self.pre_weights(config))
        for layer_idx in range(num_layers):
            layer_ws = self.layer_weights(layer_idx, config)
            if layer_ws:
                # layer_weights may return None; skip such layers
                res.extend(layer_ws)
res.extend(self.post_weights(config))
return res
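

# A minimal concrete sketch (illustrative; the class name and tensor names
# are assumptions, not part of mergekit):
#
#   class TinyDecoderArchitecture(ModuleArchitecture):
#       def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
#           return [WeightInfo(name="embed_tokens.weight", is_embed=True)]
#
#       def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
#           return [
#               WeightInfo(name="norm.weight"),
#               WeightInfo(name="lm_head.weight", is_embed=True, optional=True),
#           ]
#
#       def layer_weights(
#           self, index: int, config: PretrainedConfig
#       ) -> Optional[List[WeightInfo]]:
#           return [
#               WeightInfo(name=f"layers.{index}.self_attn.q_proj.weight"),
#               WeightInfo(name=f"layers.{index}.mlp.up_proj.weight"),
#           ]
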
class ConfiguredModuleArchitecture(
BaseModel, frozen=True, arbitrary_types_allowed=True
):
info: ModuleArchitecture
config: PretrainedConfig
weight_prefix: Optional[str] = None
def num_layers(self) -> int:
return self.info.num_layers(self.config)
def pre_weights(self) -> List[WeightInfo]:
return [
_prefix_weight(w, self.weight_prefix)
for w in self.info.pre_weights(self.config)
]
def post_weights(self) -> List[WeightInfo]:
return [
_prefix_weight(w, self.weight_prefix)
for w in self.info.post_weights(self.config)
]
    def layer_weights(self, index: int) -> Optional[List[WeightInfo]]:
        # The underlying layer_weights may return None; propagate that
        # rather than iterating over it.
        layer_ws = self.info.layer_weights(index, self.config)
        if layer_ws is None:
            return None
        return [_prefix_weight(w, self.weight_prefix) for w in layer_ws]
def all_weights(self) -> List[WeightInfo]:
return [
_prefix_weight(w, self.weight_prefix)
for w in self.info.all_weights(self.config)
]
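

# Usage sketch (illustrative; assumes the hypothetical TinyDecoderArchitecture
# from the sketch above and an already-loaded `config`):
#
#   module = ConfiguredModuleArchitecture(
#       info=TinyDecoderArchitecture(),
#       config=config,
#       weight_prefix="decoder.",
#   )
#   module.all_weights()  # every WeightInfo, names prefixed with "decoder."
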
class ModuleDefinition(BaseModel, frozen=True, arbitrary_types_allowed=True):
architecture: ModuleArchitecture
weight_prefix: Optional[str] = None
subfolder: Optional[str] = None
class ModelArchitecture(BaseModel, frozen=True):
modules: Dict[str, ModuleDefinition]
architectures: List[str]
expected_model_type: str = Field(alias="model_type")
tagalong_files: Optional[List[str]] = None
vocab_size_config_key: Optional[str] = None
def all_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
res = []
for module in self.modules.values():
for weight_info in module.architecture.all_weights(config=config):
res.append(_prefix_weight(weight_info, module.weight_prefix))
return res
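

# Construction sketch (illustrative; the module layout and all names are
# assumptions):
#
#   model_arch = ModelArchitecture(
#       modules={
#           "decoder": ModuleDefinition(
#               architecture=TinyDecoderArchitecture(),
#               weight_prefix="model.",
#           )
#       },
#       architectures=["TinyDecoderForCausalLM"],
#       model_type="tiny_decoder",
#   )
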
class ConfiguredModelArchitecture(
    BaseModel, frozen=True, arbitrary_types_allowed=True
):
info: ModelArchitecture
config: PretrainedConfig
def all_weights(self) -> List[WeightInfo]:
return self.info.all_weights(self.config)
def get_module(self, module_name: str) -> ConfiguredModuleArchitecture:
return ConfiguredModuleArchitecture(
info=self.info.modules[module_name].architecture,
config=self.config,
weight_prefix=self.info.modules[module_name].weight_prefix,
)
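

# Usage sketch (illustrative; reuses the hypothetical `model_arch` and
# `config` from the sketches above):
#
#   configured = ConfiguredModelArchitecture(info=model_arch, config=config)
#   for w in configured.get_module("decoder").all_weights():
#       print(w.name)
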
# Fix for Pydantic v2 (observed when running on RunPod): manually rebuild
# the models so forward references are resolved, avoiding the
# "class is not fully defined" error at instantiation time.
ConfiguredModuleArchitecture.model_rebuild()
ConfiguredModelArchitecture.model_rebuild()