Spaces:
Runtime error
Runtime error
feat(model): support using JSON config
Browse files- litgpt/config.py +8 -49
litgpt/config.py
CHANGED
|
@@ -1,12 +1,10 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
from copy import deepcopy
|
| 4 |
from dataclasses import dataclass, field
|
| 5 |
from pathlib import Path
|
| 6 |
from typing import Any, Literal, Optional, Type, Union
|
| 7 |
|
| 8 |
import torch
|
| 9 |
-
import yaml
|
| 10 |
from typing_extensions import Self
|
| 11 |
|
| 12 |
import litgpt.model
|
|
@@ -30,33 +28,11 @@ class Config:
|
|
| 30 |
parallel_residual: bool = True
|
| 31 |
bias: bool = True
|
| 32 |
lm_head_bias: bool = False
|
| 33 |
-
# to use multi-head attention (MHA), set this to `n_head` (default)
|
| 34 |
-
# to use multi-query attention (MQA), set this to 1
|
| 35 |
-
# to use grouped-query attention (GQA), set this to a value in between
|
| 36 |
-
# Example with `n_head=4`
|
| 37 |
-
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
|
| 38 |
-
# │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │
|
| 39 |
-
# └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
|
| 40 |
-
# │ │ │ │ │ │ │
|
| 41 |
-
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
|
| 42 |
-
# │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │
|
| 43 |
-
# └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
|
| 44 |
-
# │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐
|
| 45 |
-
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐
|
| 46 |
-
# │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │
|
| 47 |
-
# └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘
|
| 48 |
-
# ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶
|
| 49 |
-
# MHA GQA MQA
|
| 50 |
-
# n_query_groups=4 n_query_groups=2 n_query_groups=1
|
| 51 |
-
#
|
| 52 |
-
# credit https://arxiv.org/pdf/2305.13245.pdf
|
| 53 |
n_query_groups: Optional[int] = None
|
| 54 |
shared_attention_norm: bool = False
|
| 55 |
norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
|
| 56 |
norm_eps: float = 1e-5
|
| 57 |
-
mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] =
|
| 58 |
-
"GptNeoxMLP"
|
| 59 |
-
)
|
| 60 |
gelu_approximate: str = "none"
|
| 61 |
intermediate_size: Optional[int] = None
|
| 62 |
rope_condense_ratio: int = 1
|
|
@@ -90,27 +66,19 @@ class Config:
|
|
| 90 |
assert self.n_embd % self.n_head == 0
|
| 91 |
self.head_size = self.n_embd // self.n_head
|
| 92 |
|
| 93 |
-
# vocab size should be a power of 2 to be optimal on hardware. compute the closest value
|
| 94 |
if self.padded_vocab_size is None:
|
| 95 |
-
self.padded_vocab_size = find_multiple(
|
| 96 |
-
self.vocab_size, self.padding_multiple
|
| 97 |
-
)
|
| 98 |
else:
|
| 99 |
-
# vocab size shouldn't be larger than padded vocab size
|
| 100 |
self.vocab_size = min(self.vocab_size, self.padded_vocab_size)
|
| 101 |
|
| 102 |
-
# compute the number of query groups
|
| 103 |
if self.n_query_groups is not None:
|
| 104 |
assert self.n_head % self.n_query_groups == 0
|
| 105 |
else:
|
| 106 |
self.n_query_groups = self.n_head
|
| 107 |
|
| 108 |
-
# compute the intermediate size for MLP if not set
|
| 109 |
if self.intermediate_size is None:
|
| 110 |
if self.mlp_class_name == "LLaMAMLP":
|
| 111 |
-
raise ValueError(
|
| 112 |
-
f"The config {self.name!r}, needs to set the `intermediate_size`"
|
| 113 |
-
)
|
| 114 |
self.intermediate_size = 4 * self.n_embd
|
| 115 |
|
| 116 |
self.rope_n_elem = int(self.rotary_percentage * self.head_size)
|
|
@@ -121,14 +89,12 @@ class Config:
|
|
| 121 |
@classmethod
|
| 122 |
def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
|
| 123 |
if name not in name_to_config:
|
| 124 |
-
# search through all `config['hf_config']['name']`
|
| 125 |
try:
|
| 126 |
conf_dict = next(
|
| 127 |
config
|
| 128 |
for config in configs
|
| 129 |
if name == config["hf_config"]["name"]
|
| 130 |
-
or config["hf_config"]["org"] + "/" + config["hf_config"]["name"]
|
| 131 |
-
== name
|
| 132 |
)
|
| 133 |
except StopIteration:
|
| 134 |
raise ValueError(f"{name!r} is not a supported config name")
|
|
@@ -142,7 +108,7 @@ class Config:
|
|
| 142 |
@classmethod
|
| 143 |
def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
|
| 144 |
with open(path, encoding="utf-8") as fp:
|
| 145 |
-
file_kwargs =
|
| 146 |
if file_kwargs is None:
|
| 147 |
raise ValueError(f"{path} is empty which is likely unexpected.")
|
| 148 |
file_kwargs.update(kwargs)
|
|
@@ -150,28 +116,21 @@ class Config:
|
|
| 150 |
|
| 151 |
@classmethod
|
| 152 |
def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
|
| 153 |
-
|
| 154 |
-
if (config_path := path / "model_config.yaml").is_file():
|
| 155 |
return cls.from_file(config_path, **kwargs)
|
| 156 |
if (model_name := path.name) in name_to_config:
|
| 157 |
return cls.from_name(model_name, **kwargs)
|
| 158 |
-
raise FileNotFoundError(
|
| 159 |
-
f"For {str(path)!r} neither 'model_config.yaml' nor matching config exists."
|
| 160 |
-
)
|
| 161 |
|
| 162 |
@property
|
| 163 |
def mlp_class(self) -> Type:
|
| 164 |
-
# `self.mlp_class_name` cannot be the type to keep the config serializable
|
| 165 |
return getattr(litgpt.model, self.mlp_class_name)
|
| 166 |
|
| 167 |
@property
|
| 168 |
def norm_class(self) -> Type:
|
| 169 |
-
# `self.norm_class_name` cannot be the type to keep the config serializable
|
| 170 |
if self.norm_class_name == "RMSNorm":
|
| 171 |
from functools import partial
|
| 172 |
-
|
| 173 |
from litgpt.model import RMSNorm
|
| 174 |
-
|
| 175 |
return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
|
| 176 |
return getattr(torch.nn, self.norm_class_name)
|
| 177 |
|
|
|
|
| 1 |
+
import json
|
|
|
|
| 2 |
from copy import deepcopy
|
| 3 |
from dataclasses import dataclass, field
|
| 4 |
from pathlib import Path
|
| 5 |
from typing import Any, Literal, Optional, Type, Union
|
| 6 |
|
| 7 |
import torch
|
|
|
|
| 8 |
from typing_extensions import Self
|
| 9 |
|
| 10 |
import litgpt.model
|
|
|
|
| 28 |
parallel_residual: bool = True
|
| 29 |
bias: bool = True
|
| 30 |
lm_head_bias: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
n_query_groups: Optional[int] = None
|
| 32 |
shared_attention_norm: bool = False
|
| 33 |
norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
|
| 34 |
norm_eps: float = 1e-5
|
| 35 |
+
mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP"
|
|
|
|
|
|
|
| 36 |
gelu_approximate: str = "none"
|
| 37 |
intermediate_size: Optional[int] = None
|
| 38 |
rope_condense_ratio: int = 1
|
|
|
|
| 66 |
assert self.n_embd % self.n_head == 0
|
| 67 |
self.head_size = self.n_embd // self.n_head
|
| 68 |
|
|
|
|
| 69 |
if self.padded_vocab_size is None:
|
| 70 |
+
self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple)
|
|
|
|
|
|
|
| 71 |
else:
|
|
|
|
| 72 |
self.vocab_size = min(self.vocab_size, self.padded_vocab_size)
|
| 73 |
|
|
|
|
| 74 |
if self.n_query_groups is not None:
|
| 75 |
assert self.n_head % self.n_query_groups == 0
|
| 76 |
else:
|
| 77 |
self.n_query_groups = self.n_head
|
| 78 |
|
|
|
|
| 79 |
if self.intermediate_size is None:
|
| 80 |
if self.mlp_class_name == "LLaMAMLP":
|
| 81 |
+
raise ValueError(f"The config {self.name!r}, needs to set the `intermediate_size`")
|
|
|
|
|
|
|
| 82 |
self.intermediate_size = 4 * self.n_embd
|
| 83 |
|
| 84 |
self.rope_n_elem = int(self.rotary_percentage * self.head_size)
|
|
|
|
| 89 |
@classmethod
|
| 90 |
def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
|
| 91 |
if name not in name_to_config:
|
|
|
|
| 92 |
try:
|
| 93 |
conf_dict = next(
|
| 94 |
config
|
| 95 |
for config in configs
|
| 96 |
if name == config["hf_config"]["name"]
|
| 97 |
+
or config["hf_config"]["org"] + "/" + config["hf_config"]["name"] == name
|
|
|
|
| 98 |
)
|
| 99 |
except StopIteration:
|
| 100 |
raise ValueError(f"{name!r} is not a supported config name")
|
|
|
|
| 108 |
@classmethod
|
| 109 |
def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
|
| 110 |
with open(path, encoding="utf-8") as fp:
|
| 111 |
+
file_kwargs = json.load(fp)
|
| 112 |
if file_kwargs is None:
|
| 113 |
raise ValueError(f"{path} is empty which is likely unexpected.")
|
| 114 |
file_kwargs.update(kwargs)
|
|
|
|
| 116 |
|
| 117 |
@classmethod
|
| 118 |
def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
|
| 119 |
+
if (config_path := path / "config.json").is_file():
|
|
|
|
| 120 |
return cls.from_file(config_path, **kwargs)
|
| 121 |
if (model_name := path.name) in name_to_config:
|
| 122 |
return cls.from_name(model_name, **kwargs)
|
| 123 |
+
raise FileNotFoundError(f"For {str(path)!r} neither 'config.json' nor matching config exists.")
|
|
|
|
|
|
|
| 124 |
|
| 125 |
@property
|
| 126 |
def mlp_class(self) -> Type:
|
|
|
|
| 127 |
return getattr(litgpt.model, self.mlp_class_name)
|
| 128 |
|
| 129 |
@property
|
| 130 |
def norm_class(self) -> Type:
|
|
|
|
| 131 |
if self.norm_class_name == "RMSNorm":
|
| 132 |
from functools import partial
|
|
|
|
| 133 |
from litgpt.model import RMSNorm
|
|
|
|
| 134 |
return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
|
| 135 |
return getattr(torch.nn, self.norm_class_name)
|
| 136 |
|