# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: LGPL-3.0-only

from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple

from pydantic import BaseModel, Field
from transformers import PretrainedConfig

from mergekit.common import get_config_value


class WeightInfo(BaseModel, frozen=True):
    """Information about an individual weight tensor in a model.

    Attributes:
        name (str): The name of the tensor representing the weight.
        is_embed (bool): Indicates whether the weight is for an embedding
            or language model head.
        optional (bool): Indicates whether the weight can be omitted from
            a model.
        aliases (Optional[Tuple[str, ...]]): Alternative names for the
            weight, if applicable.
        force_dtype (Optional[str]): Mandatory dtype for the weight, if
            applicable.
        tied_names (Optional[Tuple[str, ...]]): Names of other weights that
            may be tied to this one, if applicable.
    """

    name: str
    is_embed: bool = False
    optional: bool = False
    aliases: Optional[Tuple[str, ...]] = None
    force_dtype: Optional[str] = None
    tied_names: Optional[Tuple[str, ...]] = None


def _prefix_weight(weight: WeightInfo, prefix: Optional[str] = None) -> WeightInfo:
    """Return a copy of `weight` with `prefix` prepended to its name,
    aliases, and tied names."""
    if prefix is None:
        return weight
    return WeightInfo(
        name=prefix + weight.name,
        aliases=tuple(prefix + alias for alias in weight.aliases or ()) or None,
        tied_names=tuple(prefix + tied_name for tied_name in weight.tied_names or ())
        or None,
        **weight.model_dump(exclude={"name", "aliases", "tied_names"}),
    )
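
# Illustrative example (hypothetical weight names): with prefix="decoder.",
# _prefix_weight rewrites
#     WeightInfo(name="embed_tokens.weight", aliases=("embeddings.weight",))
# into
#     WeightInfo(
#         name="decoder.embed_tokens.weight",
#         aliases=("decoder.embeddings.weight",),
#     )
# leaving all other fields (is_embed, optional, force_dtype, ...) untouched.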

class ModuleArchitecture(ABC):
    @abstractmethod
    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        """Return a list of all weights preceding the first layer."""
        ...

    @abstractmethod
    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        """Return a list of all weights following the final layer."""
        ...

    @abstractmethod
    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        """Return a list of all weights associated with a given layer."""
        ...

    def num_layers_config_key(self) -> str:
        """Key in config that represents the number of layers."""
        return "num_hidden_layers"

    def num_layers(self, config: PretrainedConfig) -> int:
        """Return the number of layers in a model."""
        return get_config_value(config, self.num_layers_config_key())

    def all_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        """Return all weights associated with a model."""
        num_layers = self.num_layers(config)
        res = list(self.pre_weights(config))
        for layer_idx in range(num_layers):
            res.extend(self.layer_weights(layer_idx, config))
        res.extend(self.post_weights(config))
        return res


class ConfiguredModuleArchitecture(
    BaseModel, frozen=True, arbitrary_types_allowed=True
):
    info: ModuleArchitecture
    config: PretrainedConfig
    weight_prefix: Optional[str] = None

    def num_layers(self) -> int:
        return self.info.num_layers(self.config)

    def pre_weights(self) -> List[WeightInfo]:
        return [
            _prefix_weight(w, self.weight_prefix)
            for w in self.info.pre_weights(self.config)
        ]

    def post_weights(self) -> List[WeightInfo]:
        return [
            _prefix_weight(w, self.weight_prefix)
            for w in self.info.post_weights(self.config)
        ]

    def layer_weights(self, index: int) -> List[WeightInfo]:
        return [
            _prefix_weight(w, self.weight_prefix)
            for w in self.info.layer_weights(index, self.config)
        ]

    def all_weights(self) -> List[WeightInfo]:
        return [
            _prefix_weight(w, self.weight_prefix)
            for w in self.info.all_weights(self.config)
        ]


class ModuleDefinition(BaseModel, frozen=True, arbitrary_types_allowed=True):
    architecture: ModuleArchitecture
    weight_prefix: Optional[str] = None
    subfolder: Optional[str] = None


class ModelArchitecture(BaseModel, frozen=True):
    modules: Dict[str, ModuleDefinition]
    architectures: List[str]
    expected_model_type: str = Field(alias="model_type")
    tagalong_files: Optional[List[str]] = None
    vocab_size_config_key: Optional[str] = None

    def all_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        res = []
        for module in self.modules.values():
            for weight_info in module.architecture.all_weights(config=config):
                res.append(_prefix_weight(weight_info, module.weight_prefix))
        return res


class ConfiguredModelArchitecture(
    BaseModel, frozen=True, arbitrary_types_allowed=True
):
    info: ModelArchitecture
    config: PretrainedConfig

    def all_weights(self) -> List[WeightInfo]:
        return self.info.all_weights(self.config)

    def get_module(self, module_name: str) -> ConfiguredModuleArchitecture:
        return ConfiguredModuleArchitecture(
            info=self.info.modules[module_name].architecture,
            config=self.config,
            weight_prefix=self.info.modules[module_name].weight_prefix,
        )


# Runpod Fix
# Manually rebuild Pydantic models to resolve forward references
# This fixes the "not fully defined" error with Pydantic v2
ConfiguredModuleArchitecture.model_rebuild()
ConfiguredModelArchitecture.model_rebuild()
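
# Usage sketch (illustrative only): a minimal ModuleArchitecture for a
# hypothetical LLaMA-style decoder. The weight names and the "toy" model
# type below are assumptions for demonstration, not real mergekit
# architecture definitions.
if __name__ == "__main__":

    class _ToyDecoder(ModuleArchitecture):
        # Hypothetical module: one embedding, N decoder layers, a final norm,
        # and an optional LM head tied to the embedding.
        def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
            return [WeightInfo(name="model.embed_tokens.weight", is_embed=True)]

        def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
            return [
                WeightInfo(name="model.norm.weight"),
                WeightInfo(
                    name="lm_head.weight",
                    is_embed=True,
                    optional=True,
                    tied_names=("model.embed_tokens.weight",),
                ),
            ]

        def layer_weights(
            self, index: int, config: PretrainedConfig
        ) -> Optional[List[WeightInfo]]:
            prefix = f"model.layers.{index}."
            return [
                WeightInfo(name=prefix + "self_attn.q_proj.weight"),
                WeightInfo(name=prefix + "mlp.down_proj.weight"),
            ]

    arch = ModelArchitecture(
        modules={"decoder": ModuleDefinition(architecture=_ToyDecoder())},
        architectures=["ToyForCausalLM"],
        model_type="toy",
    )
    # The layer count is read from the config via num_layers_config_key().
    config = PretrainedConfig(num_hidden_layers=2)
    for weight in ConfiguredModelArchitecture(info=arch, config=config).all_weights():
        print(weight.name)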