taegyeonglee committed on
Commit
6940c8b
·
verified ·
1 Parent(s): b33fb90

Add HF-standard offline package (auto_map + modeling_kbert_mtl.py)

Browse files
Files changed (1) hide show
  1. modeling_kbert_mtl.py +11 -7
modeling_kbert_mtl.py CHANGED
@@ -1,24 +1,28 @@
1
  # modeling_kbert_mtl.py
2
  import torch
3
  import torch.nn as nn
4
- from transformers import PreTrainedModel, AutoModel, AutoConfig, BertConfig # ← 추가
5
 
6
  def _config_from_base_dict(base_cfg_dict: dict):
7
  if base_cfg_dict is None:
8
  raise ValueError("config.base_model_config is required for offline load.")
9
- model_type = base_cfg_dict.get("model_type", "bert")
10
- kwargs = {k: v for k, v in base_cfg_dict.items() if k != "model_type"}
11
- return AutoConfig.for_model(model_type, **kwargs)
 
 
 
 
12
 
13
  class KbertMTL(PreTrainedModel):
14
- config_class = BertConfig
15
 
16
  def __init__(self, config):
17
  super().__init__(config)
18
-
19
  base_cfg_dict = getattr(config, "base_model_config", None)
20
  base_cfg = _config_from_base_dict(base_cfg_dict)
21
- self.bert = AutoModel.from_config(base_cfg)
 
22
 
23
  hidden = self.bert.config.hidden_size
24
  self.head_senti = nn.Linear(hidden, 5)
 
1
  # modeling_kbert_mtl.py
2
  import torch
3
  import torch.nn as nn
4
+ from transformers import PreTrainedModel, AutoModel, AutoConfig, BertConfig
5
 
6
  def _config_from_base_dict(base_cfg_dict: dict):
7
  if base_cfg_dict is None:
8
  raise ValueError("config.base_model_config is required for offline load.")
9
+ model_type = "bert"
10
+ try:
11
+ kwargs = {k: v for k, v in base_cfg_dict.items() if k != "model_type"}
12
+ cfg = AutoConfig.for_model(model_type, **kwargs)
13
+ except Exception:
14
+ cfg = BertConfig(**{k: v for k, v in base_cfg_dict.items() if k != "model_type"})
15
+ return cfg
16
 
17
  class KbertMTL(PreTrainedModel):
18
+ config_class = BertConfig
19
 
20
  def __init__(self, config):
21
  super().__init__(config)
 
22
  base_cfg_dict = getattr(config, "base_model_config", None)
23
  base_cfg = _config_from_base_dict(base_cfg_dict)
24
+
25
+ self.bert = AutoModel.from_config(base_cfg)
26
 
27
  hidden = self.bert.config.hidden_size
28
  self.head_senti = nn.Linear(hidden, 5)