calbors committed on
Commit
f2b5c7b
·
verified ·
1 Parent(s): 054035d

Upload model

Browse files
Files changed (2) hide show
  1. huggingface.py +21 -14
  2. model.safetensors +1 -1
huggingface.py CHANGED
@@ -1,5 +1,4 @@
1
- from __future__ import annotations
2
- from pathlib import Path
3
  from torch import nn
4
  from transformers import AutoConfig, AutoModel, AutoTokenizer
5
  from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
@@ -19,23 +18,31 @@ def make_config_class(model_args: dict, model_type: str) -> PretrainedConfig:
19
  return Config
20
 
21
 
22
- def make_model_class(base_class: nn.Module, config_to_args: callable = None) -> PreTrainedModel:
23
- if config_to_args is None:
24
- def config_to_args(config):
25
- args_dict = {}
26
-
27
- for attr_name in dir(config):
28
- if not attr_name.startswith("_") and not callable(getattr(config, attr_name)):
29
- args_dict[attr_name] = getattr(config, attr_name)
30
-
31
- return args_dict
32
-
33
  class Model(PreTrainedModel):
34
  config_class: PretrainedConfig
35
 
36
  def __init__(self, config, **kwargs):
37
  super().__init__(config, **kwargs)
38
- self._model = base_class(**config_to_args(config), **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  def forward(self, *args, **kwargs):
41
  return self._model(*args, **kwargs)
 
1
+ import inspect
 
2
  from torch import nn
3
  from transformers import AutoConfig, AutoModel, AutoTokenizer
4
  from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
 
18
  return Config
19
 
20
 
21
+ def make_model_class(base_class: nn.Module, config_attributes: list[str] = None) -> PreTrainedModel:
22
+ base_init_signature = inspect.signature(base_class.__init__)
23
+ base_params = set(base_init_signature.parameters.keys()) - {"self"}
24
+
 
 
 
 
 
 
 
25
  class Model(PreTrainedModel):
26
  config_class: PretrainedConfig
27
 
28
  def __init__(self, config, **kwargs):
29
  super().__init__(config, **kwargs)
30
+
31
+ if config_attributes is not None:
32
+ model_kwargs = {a: getattr(config, a) for a in config_attributes if hasattr(config, a)}
33
+ else:
34
+ model_kwargs = {}
35
+
36
+ for param_name in base_params:
37
+ if hasattr(config, param_name):
38
+ model_kwargs[param_name] = getattr(config, param_name)
39
+
40
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in base_params}
41
+
42
+ if "config" in base_params:
43
+ self._model = base_class(config, **model_kwargs, **filtered_kwargs)
44
+ else:
45
+ self._model = base_class(**model_kwargs, **filtered_kwargs)
46
 
47
  def forward(self, *args, **kwargs):
48
  return self._model(*args, **kwargs)
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a9bd62368812f60243bbcb3ce4b2999a16bd2726a093923f48c7193995ab6f84
3
  size 228
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5906761113ebf5b48e15bb2a1143f7d02c8fd959c946be09b2b3bcb2ca195a6e
3
  size 228