from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Optional, Type, Union

import torch
import yaml
from typing_extensions import Self

import litgpt.model
from litgpt.utils import find_multiple


@dataclass
class Config:
    name: str = ""
    hf_config: dict = field(default_factory=dict)
    scale_embeddings: bool = False
    block_size: int = 4096
    vocab_size: int = 50254
    padding_multiple: int = 512
    padded_vocab_size: Optional[int] = None
    n_layer: int = 16
    n_head: int = 32
    head_size: Optional[int] = None
    n_embd: int = 4096
    rotary_percentage: float = 0.25
    parallel_residual: bool = True
    bias: bool = True
    lm_head_bias: bool = False
    n_query_groups: Optional[int] = None
    shared_attention_norm: bool = False
    norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
    norm_eps: float = 1e-5
    mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = (
        "GptNeoxMLP"
    )
    gelu_approximate: str = "none"
    intermediate_size: Optional[int] = None
    rope_condense_ratio: int = 1
    rope_base: int = 10000
    n_expert: int = 0
    n_expert_per_token: int = 0

    add_qkv_bias: Optional[bool] = None
    prompt_vocab_size: Optional[int] = None
    attn_dropout: float = 0.0
    pos_type: str = "rope"
    force_align: bool = False
    use_pretrain_phoneme_emb: bool = False
    tie_word_embeddings: bool = False

    text_vocab_size: int = 152000
    cat_audio_vocab_size: int = 29120
    audio_vocab_size: int = 4160
    whisper_adapter_dim: int = 768
    vision_adapter_dim: int = 512

    post_adapter: bool = False
    post_adapter_layers: int = 6
    asr_adapter: str = "llamamlp"

    def __post_init__(self):
        if not self.name:
            self.name = self.hf_config.get("name", self.name)

        if self.head_size is None:
            assert self.n_embd % self.n_head == 0
            self.head_size = self.n_embd // self.n_head

        # Pad the vocabulary up to a multiple of `padding_multiple` unless an
        # explicit padded size was provided.
        if self.padded_vocab_size is None:
            self.padded_vocab_size = find_multiple(
                self.vocab_size, self.padding_multiple
            )
        else:
            # The vocab size should never exceed the padded vocab size.
            self.vocab_size = min(self.vocab_size, self.padded_vocab_size)

        # Number of query groups for grouped-query attention; defaults to
        # multi-head attention (one group per head).
        if self.n_query_groups is not None:
            assert self.n_head % self.n_query_groups == 0
        else:
            self.n_query_groups = self.n_head

        # Derive the intermediate MLP size if it was not explicitly set.
        if self.intermediate_size is None:
            if self.mlp_class_name == "LLaMAMLP":
                raise ValueError(
                    f"The config {self.name!r} needs to set the `intermediate_size`"
                )
            self.intermediate_size = 4 * self.n_embd

        self.rope_n_elem = int(self.rotary_percentage * self.head_size)

        if self.add_qkv_bias is None:
            self.add_qkv_bias = self.bias

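    # Example (illustrative only, not used at runtime): how the derived fields
    # above resolve for the default values.
    #
    #     cfg = Config()            # n_embd=4096, n_head=32, vocab_size=50254
    #     cfg.head_size             # 4096 // 32 == 128
    #     cfg.padded_vocab_size     # find_multiple(50254, 512) == 50688
    #     cfg.n_query_groups        # 32 (falls back to n_head, i.e. plain MHA)
    #     cfg.rope_n_elem           # int(0.25 * 128) == 32
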
    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
        if name not in name_to_config:
            # Fall back to searching the `hf_config` entries, so both the bare
            # HF name and the "org/name" form are accepted.
            try:
                conf_dict = next(
                    config
                    for config in configs
                    if name == config["hf_config"]["name"]
                    or config["hf_config"]["org"] + "/" + config["hf_config"]["name"]
                    == name
                )
            except StopIteration:
                raise ValueError(f"{name!r} is not a supported config name")
        else:
            conf_dict = name_to_config[name]

        # Copy so that `kwargs` overrides do not mutate the shared registry entry.
        conf_dict = conf_dict.copy()
        conf_dict.update(kwargs)
        return cls(**conf_dict)

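    # Example (illustrative only): look a config up by name and override a field.
    # The name must exist in `configs` / `name_to_config` below; "pythia-14m" is
    # just a placeholder here.
    #
    #     config = Config.from_name("pythia-14m", block_size=2048)
    #     assert config.block_size == 2048
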
    @classmethod
    def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
        with open(path, encoding="utf-8") as fp:
            file_kwargs = yaml.safe_load(fp)
            if file_kwargs is None:
                raise ValueError(f"{path} is empty which is likely unexpected.")
        file_kwargs.update(kwargs)
        return cls(**file_kwargs)

    @classmethod
    def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
        """Automatically load `model_config.yaml` and, if it doesn't exist, a matching config from `litgpt/config.py`."""
        if (config_path := path / "model_config.yaml").is_file():
            return cls.from_file(config_path, **kwargs)
        if (model_name := path.name) in name_to_config:
            return cls.from_name(model_name, **kwargs)
        raise FileNotFoundError(
            f"For {str(path)!r} neither 'model_config.yaml' nor a matching config exists."
        )

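    # Example (illustrative only; the checkpoint directory is hypothetical):
    # `from_checkpoint` prefers a `model_config.yaml` inside the directory and
    # otherwise matches the directory name against `name_to_config`.
    #
    #     config = Config.from_checkpoint(Path("checkpoints/my-org/my-model"))
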
    @property
    def mlp_class(self) -> Type:
        # The MLP implementation is stored as a string and resolved lazily so the
        # config itself remains easy to serialize.
        return getattr(litgpt.model, self.mlp_class_name)

    @property
    def norm_class(self) -> Type:
        # Same idea as `mlp_class`: resolve the normalization layer from its name.
        if self.norm_class_name == "RMSNorm":
            from functools import partial

            from litgpt.model import RMSNorm

            # Gemma checkpoints expect RMSNorm with a unit offset added to the
            # learned scale.
            return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
        return getattr(torch.nn, self.norm_class_name)


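# Example (illustrative only): `mlp_class` / `norm_class` resolve the string names
# to actual classes; this is roughly how model code might instantiate them (the
# exact constructor signatures live in `litgpt.model`).
#
#     cfg = Config(norm_class_name="RMSNorm", mlp_class_name="LLaMAMLP",
#                  intermediate_size=11008)
#     norm = cfg.norm_class(cfg.n_embd, eps=cfg.norm_eps)
#     mlp = cfg.mlp_class(cfg)
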
# Registry of known model configurations, consumed by `from_name` and
# `from_checkpoint`.
configs = []
name_to_config = {config["name"]: config for config in configs}
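
# Example (illustrative only; the entry below is hypothetical): configs are plain
# dicts whose keys mirror the `Config` fields, plus an `hf_config` block used by
# `from_name` for "org/name" lookups.
#
#     configs.append(
#         dict(
#             name="my-audio-llm",
#             hf_config=dict(org="my-org", name="my-audio-llm"),
#             n_layer=24,
#             n_head=16,
#             n_embd=2048,
#         )
#     )
#     name_to_config = {config["name"]: config for config in configs}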