Upload model
Browse files- huggingface.py +21 -14
- model.safetensors +1 -1
huggingface.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
-
|
| 2 |
-
from pathlib import Path
|
| 3 |
from torch import nn
|
| 4 |
from transformers import AutoConfig, AutoModel, AutoTokenizer
|
| 5 |
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
|
|
@@ -19,23 +18,31 @@ def make_config_class(model_args: dict, model_type: str) -> PretrainedConfig:
|
|
| 19 |
return Config
|
| 20 |
|
| 21 |
|
| 22 |
-
def make_model_class(base_class: nn.Module,
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
for attr_name in dir(config):
|
| 28 |
-
if not attr_name.startswith("_") and not callable(getattr(config, attr_name)):
|
| 29 |
-
args_dict[attr_name] = getattr(config, attr_name)
|
| 30 |
-
|
| 31 |
-
return args_dict
|
| 32 |
-
|
| 33 |
class Model(PreTrainedModel):
|
| 34 |
config_class: PretrainedConfig
|
| 35 |
|
| 36 |
def __init__(self, config, **kwargs):
|
| 37 |
super().__init__(config, **kwargs)
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
def forward(self, *args, **kwargs):
|
| 41 |
return self._model(*args, **kwargs)
|
|
|
|
| 1 |
+
import inspect
|
|
|
|
| 2 |
from torch import nn
|
| 3 |
from transformers import AutoConfig, AutoModel, AutoTokenizer
|
| 4 |
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
|
|
|
|
| 18 |
return Config
|
| 19 |
|
| 20 |
|
| 21 |
+
def make_model_class(base_class: nn.Module, config_attributes: list[str] = None) -> PreTrainedModel:
|
| 22 |
+
base_init_signature = inspect.signature(base_class.__init__)
|
| 23 |
+
base_params = set(base_init_signature.parameters.keys()) - {"self"}
|
| 24 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
class Model(PreTrainedModel):
|
| 26 |
config_class: PretrainedConfig
|
| 27 |
|
| 28 |
def __init__(self, config, **kwargs):
|
| 29 |
super().__init__(config, **kwargs)
|
| 30 |
+
|
| 31 |
+
if config_attributes is not None:
|
| 32 |
+
model_kwargs = {a: getattr(config, a) for a in config_attributes if hasattr(config, a)}
|
| 33 |
+
else:
|
| 34 |
+
model_kwargs = {}
|
| 35 |
+
|
| 36 |
+
for param_name in base_params:
|
| 37 |
+
if hasattr(config, param_name):
|
| 38 |
+
model_kwargs[param_name] = getattr(config, param_name)
|
| 39 |
+
|
| 40 |
+
filtered_kwargs = {k: v for k, v in kwargs.items() if k in base_params}
|
| 41 |
+
|
| 42 |
+
if "config" in base_params:
|
| 43 |
+
self._model = base_class(config, **model_kwargs, **filtered_kwargs)
|
| 44 |
+
else:
|
| 45 |
+
self._model = base_class(**model_kwargs, **filtered_kwargs)
|
| 46 |
|
| 47 |
def forward(self, *args, **kwargs):
|
| 48 |
return self._model(*args, **kwargs)
|
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 228
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5906761113ebf5b48e15bb2a1143f7d02c8fd959c946be09b2b3bcb2ca195a6e
|
| 3 |
size 228
|