taegyeonglee commited on
Commit
b33fb90
·
verified ·
1 Parent(s): 28cc311

Add HF-standard offline package (auto_map + modeling_kbert_mtl.py)

Browse files
Files changed (1) hide show
  1. modeling_kbert_mtl.py +6 -17
modeling_kbert_mtl.py CHANGED
@@ -1,43 +1,32 @@
1
  # modeling_kbert_mtl.py
2
  import torch
3
  import torch.nn as nn
4
- from transformers import PreTrainedModel, AutoModel, AutoConfig
5
 
6
  def _config_from_base_dict(base_cfg_dict: dict):
7
  if base_cfg_dict is None:
8
  raise ValueError("config.base_model_config is required for offline load.")
9
- model_type = base_cfg_dict.get("model_type", None)
10
- if model_type is None:
11
- model_type = "bert"
12
  kwargs = {k: v for k, v in base_cfg_dict.items() if k != "model_type"}
13
  return AutoConfig.for_model(model_type, **kwargs)
14
 
15
  class KbertMTL(PreTrainedModel):
16
- """
17
- LangQuant KBERT Multi-Task Head (HF-standard, offline-friendly)
18
 
19
- Outputs (dict):
20
- - logits_senti: (B,5)
21
- - logits_act: (B,6)
22
- - logits_emo: (B,7)
23
- - pred_reg: (B,3) # [certainty, relevance, toxicity]
24
- - last_hidden_state: (B, L, H)
25
- """
26
  def __init__(self, config):
27
  super().__init__(config)
28
 
29
  base_cfg_dict = getattr(config, "base_model_config", None)
30
  base_cfg = _config_from_base_dict(base_cfg_dict)
31
-
32
- self.bert = AutoModel.from_config(base_cfg)
33
 
34
  hidden = self.bert.config.hidden_size
35
  self.head_senti = nn.Linear(hidden, 5)
36
  self.head_act = nn.Linear(hidden, 6)
37
  self.head_emo = nn.Linear(hidden, 7)
38
  self.head_reg = nn.Linear(hidden, 3)
39
-
40
  self.has_token_type = getattr(self.bert.embeddings, "token_type_embeddings", None) is not None
 
41
  self.post_init()
42
 
43
  def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, **kwargs):
@@ -45,7 +34,7 @@ class KbertMTL(PreTrainedModel):
45
  if self.has_token_type and token_type_ids is not None:
46
  kw["token_type_ids"] = token_type_ids
47
  out = self.bert(**kw)
48
- h = out.last_hidden_state[:, 0] # [CLS]
49
  return {
50
  "logits_senti": self.head_senti(h),
51
  "logits_act": self.head_act(h),
 
1
  # modeling_kbert_mtl.py
2
  import torch
3
  import torch.nn as nn
4
+ from transformers import PreTrainedModel, AutoModel, AutoConfig, BertConfig  # ← added
5
 
6
  def _config_from_base_dict(base_cfg_dict: dict):
7
  if base_cfg_dict is None:
8
  raise ValueError("config.base_model_config is required for offline load.")
9
+ model_type = base_cfg_dict.get("model_type", "bert")
 
 
10
  kwargs = {k: v for k, v in base_cfg_dict.items() if k != "model_type"}
11
  return AutoConfig.for_model(model_type, **kwargs)
12
 
13
  class KbertMTL(PreTrainedModel):
14
+ config_class = BertConfig
 
15
 
 
 
 
 
 
 
 
16
    def __init__(self, config):
        """Build the multi-task model: a base encoder plus four linear heads.

        The backbone is reconstructed offline from the dict stored in
        ``config.base_model_config`` (via ``_config_from_base_dict``), so
        no hub download is needed at load time.
        """
        super().__init__(config)

        # Rebuild the backbone config from its serialized dict snapshot,
        # then instantiate the encoder from config (weights come later
        # from the checkpoint, not from the hub).
        base_cfg_dict = getattr(config, "base_model_config", None)
        base_cfg = _config_from_base_dict(base_cfg_dict)
        self.bert = AutoModel.from_config(base_cfg)

        # One linear head per task, all fed from the encoder hidden state.
        hidden = self.bert.config.hidden_size
        self.head_senti = nn.Linear(hidden, 5)  # sentiment: 5 classes
        self.head_act = nn.Linear(hidden, 6)    # dialogue act: 6 classes
        self.head_emo = nn.Linear(hidden, 7)    # emotion: 7 classes
        self.head_reg = nn.Linear(hidden, 3)    # regression: [certainty, relevance, toxicity]

        # Some backbones have no token_type embeddings; remember whether
        # forward() should pass token_type_ids through.
        # NOTE(review): assumes the backbone exposes `.embeddings` — confirm
        # for non-BERT-family model types.
        self.has_token_type = getattr(self.bert.embeddings, "token_type_embeddings", None) is not None

        self.post_init()
31
 
32
  def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, **kwargs):
 
34
  if self.has_token_type and token_type_ids is not None:
35
  kw["token_type_ids"] = token_type_ids
36
  out = self.bert(**kw)
37
+ h = out.last_hidden_state[:, 0]
38
  return {
39
  "logits_senti": self.head_senti(h),
40
  "logits_act": self.head_act(h),