taegyeonglee commited on
Commit
07e43ad
·
verified ·
1 Parent(s): c24c1e1

Add HF-standard offline package (auto_map + modeling_kbert_mtl.py)

Browse files
Files changed (1) hide show
  1. modeling_kbert_mtl.py +14 -11
modeling_kbert_mtl.py CHANGED
@@ -3,6 +3,15 @@ import torch
3
  import torch.nn as nn
4
  from transformers import PreTrainedModel, AutoModel, AutoConfig
5
 
 
 
 
 
 
 
 
 
 
6
  class KbertMTL(PreTrainedModel):
7
  """
8
  LangQuant KBERT Multi-Task Head (HF-standard, offline-friendly)
@@ -11,19 +20,15 @@ class KbertMTL(PreTrainedModel):
11
  - logits_senti: (B,5)
12
  - logits_act: (B,6)
13
  - logits_emo: (B,7)
14
- - pred_reg: (B,3) # [certainty, relevance, toxicity] in 0~1 (권장)
15
- - last_hidden_state: (B, L, H) from base encoder
16
  """
17
  def __init__(self, config):
18
  super().__init__(config)
19
 
20
- if not hasattr(config, "base_model_config") or config.base_model_config is None:
21
- raise ValueError(
22
- "config.base_model_config is required for offline load. "
23
- "Make sure your config.json contains a serialized base model config."
24
- )
25
 
26
- base_cfg = AutoConfig.from_dict(config.base_model_config)
27
  self.bert = AutoModel.from_config(base_cfg)
28
 
29
  hidden = self.bert.config.hidden_size
@@ -33,16 +38,14 @@ class KbertMTL(PreTrainedModel):
33
  self.head_reg = nn.Linear(hidden, 3)
34
 
35
  self.has_token_type = getattr(self.bert.embeddings, "token_type_embeddings", None) is not None
36
- self.post_init()
37
 
38
  def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, **kwargs):
39
  kw = dict(input_ids=input_ids, attention_mask=attention_mask)
40
  if self.has_token_type and token_type_ids is not None:
41
  kw["token_type_ids"] = token_type_ids
42
-
43
  out = self.bert(**kw)
44
  h = out.last_hidden_state[:, 0] # [CLS]
45
-
46
  return {
47
  "logits_senti": self.head_senti(h),
48
  "logits_act": self.head_act(h),
 
3
  import torch.nn as nn
4
  from transformers import PreTrainedModel, AutoModel, AutoConfig
5
 
6
def _config_from_base_dict(base_cfg_dict: dict):
    """Rebuild a transformers model config from its serialized dict form.

    Supports offline loading: the base encoder's config is stored inside the
    composite config as ``base_model_config`` and is reconstructed here
    without any network access.

    Raises:
        ValueError: if ``base_cfg_dict`` is ``None`` (nothing to rebuild from).
    """
    if base_cfg_dict is None:
        raise ValueError("config.base_model_config is required for offline load.")
    # Fall back to "bert" when the serialized dict carries no explicit
    # model_type. NOTE(review): presumably the stored config is
    # BERT-compatible in that case — confirm against how config.json is built.
    model_type = base_cfg_dict.get("model_type")
    if model_type is None:
        model_type = "bert"
    # Everything except model_type is forwarded as plain config kwargs.
    cfg_kwargs = dict(base_cfg_dict)
    cfg_kwargs.pop("model_type", None)
    return AutoConfig.for_model(model_type, **cfg_kwargs)
14
+
15
  class KbertMTL(PreTrainedModel):
16
  """
17
  LangQuant KBERT Multi-Task Head (HF-standard, offline-friendly)
 
20
  - logits_senti: (B,5)
21
  - logits_act: (B,6)
22
  - logits_emo: (B,7)
23
+ - pred_reg: (B,3) # [certainty, relevance, toxicity]
24
+ - last_hidden_state: (B, L, H)
25
  """
26
  def __init__(self, config):
27
  super().__init__(config)
28
 
29
+ base_cfg_dict = getattr(config, "base_model_config", None)
30
+ base_cfg = _config_from_base_dict(base_cfg_dict)
 
 
 
31
 
 
32
  self.bert = AutoModel.from_config(base_cfg)
33
 
34
  hidden = self.bert.config.hidden_size
 
38
  self.head_reg = nn.Linear(hidden, 3)
39
 
40
  self.has_token_type = getattr(self.bert.embeddings, "token_type_embeddings", None) is not None
41
+ self.post_init()
42
 
43
  def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, **kwargs):
44
  kw = dict(input_ids=input_ids, attention_mask=attention_mask)
45
  if self.has_token_type and token_type_ids is not None:
46
  kw["token_type_ids"] = token_type_ids
 
47
  out = self.bert(**kw)
48
  h = out.last_hidden_state[:, 0] # [CLS]
 
49
  return {
50
  "logits_senti": self.head_senti(h),
51
  "logits_act": self.head_act(h),