taegyeonglee committed on
Commit
730736b
·
verified ·
1 Parent(s): 07e43ad

Update README.md

Files changed (1)
  1. README.md +37 -88
README.md CHANGED
@@ -84,102 +84,51 @@ base_model:
  ---
  ## How to use the model
  ```
- import torch
- import torch.nn as nn
- import numpy as np
+ import torch, json
  from transformers import AutoTokenizer, AutoModel
- from huggingface_hub import hf_hub_download
-
- # ---- Constants ----
- REPO_ID = "langquant/LQ-Kbert-base"
- CKPT_RELPATH = "model/lq-kbert-base.pt"
-
- SENTI_MAP = {'strong_pos':0,'weak_pos':1,'neutral':2,'weak_neg':3,'strong_neg':4}
- ACT_MAP = {'buy':0,'hold':1,'sell':2,'avoid':3,'info_only':4,'ask_info':5}
- EMO_LIST = ['greed','fear','confidence','doubt','anger','hope','sarcasm']
- IDX2SENTI = {v:k for k,v in SENTI_MAP.items()}
- IDX2ACT = {v:k for k,v in ACT_MAP.items()}
-
- def sigmoid(x): return 1/(1+np.exp(-x))
-
- # ---- Model definition ----
- class KbertMTL(nn.Module):
-     def __init__(self, base_model, hidden=768):
-         super().__init__()
-         self.bert = base_model
-         self.head_senti = nn.Linear(hidden, 5)  # 5 sentiment-strength classes
-         self.head_act = nn.Linear(hidden, 6)    # 6 action-signal classes
-         self.head_emo = nn.Linear(hidden, 7)    # 7 multi-label emotions
-         self.head_reg = nn.Linear(hidden, 3)    # certainty / relevance / toxicity
-         self.has_token_type = getattr(self.bert.embeddings, "token_type_embeddings", None) is not None
-
-     def forward(self, input_ids, attention_mask, token_type_ids=None):
-         kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
-         if self.has_token_type and token_type_ids is not None:
-             kwargs["token_type_ids"] = token_type_ids
-         out = self.bert(**kwargs)
-         h = out.last_hidden_state[:, 0]  # [CLS]
-         return {
-             "logits_senti": self.head_senti(h),
-             "logits_act": self.head_act(h),
-             "logits_emo": self.head_emo(h),
-             "pred_reg": self.head_reg(h)
-         }
-
- # ---- Load checkpoint ----
- def load_ckpt_from_hub():
-     path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_RELPATH)
-     obj = torch.load(path, map_location="cpu")
-     return obj
-
- # ---- Build model and tokenizer ----
- def build_model_and_tokenizer(ckpt_obj, hidden=768):
-     model_name = ckpt_obj.get("model_name", "klue/bert-base")
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-     base = AutoModel.from_pretrained(model_name)
-     model = KbertMTL(base_model=base, hidden=hidden)
-     state_dict = ckpt_obj["state_dict"] if "state_dict" in ckpt_obj else ckpt_obj
-     model.load_state_dict(state_dict, strict=False)
-     emo_thr = float(ckpt_obj.get("emo_threshold", 0.5))
-     return model, tokenizer, emo_thr
-
- # ---- Inference ----
- @torch.no_grad()
- def predict(text, model, tokenizer, device="cpu", max_len=200, emo_threshold=0.5):
-     model.to(device).eval()
-     enc = tokenizer([text], padding=True, truncation=True, max_length=max_len, return_tensors="pt").to(device)
+
+ repo_or_dir = "LangQuant/LQ-Kbert-base"
+ texts = [
+     "๋น„ํŠธ์ฝ”์ธ ์กฐ์ • ํ›„ ๋ฐ˜๋“ฑ, ํˆฌ์ž์‹ฌ๋ฆฌ ๊ฐœ์„ ",  # "Bitcoin rebounds after the correction; investor sentiment improves"
+     "ํ™˜์œจ ๊ธ‰๋“ฑ์— ์ฆ์‹œ ๋ณ€๋™์„ฑ ํ™•๋Œ€",  # "Stock-market volatility widens as the exchange rate surges"
+     "๋น„ํŠธ ๊ทธ๋งŒ ์ข€ ๋‚ด๋ ค๋ผ ์ง„์งœ..",  # "BTC, stop falling already, seriously.."
+     "ํญ๋ฝใ… ใ… ใ…œใ… ใ…œ ๋‹ค ํŒ”์•„์•ผํ• ๊นŒ์š”?"  # "It's crashing T_T, should I sell everything?"
+ ]
+
+ tokenizer = AutoTokenizer.from_pretrained(repo_or_dir)
+ model = AutoModel.from_pretrained(repo_or_dir, trust_remote_code=True)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device).eval()
+
+ enc = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=200).to(device)
+ with torch.inference_mode():
      out = model(**enc)
 
-     senti = out["logits_senti"].argmax(-1).item()
-     act = out["logits_act"].argmax(-1).item()
-     emo_p = sigmoid(out["logits_emo"].cpu().numpy())[0]
-     reg = out["pred_reg"].cpu().numpy()[0]
+ IDX2SENTI = {0:"strong_pos",1:"weak_pos",2:"neutral",3:"weak_neg",4:"strong_neg"}
+ IDX2ACT = {0:"buy",1:"hold",2:"sell",3:"avoid",4:"info_only",5:"ask_info"}
+ EMO_LIST = ["greed","fear","confidence","doubt","anger","hope","sarcasm"]
 
-     emos = [EMO_LIST[i] for i,p in enumerate(emo_p) if p >= emo_threshold]
+ for i, t in enumerate(texts):
+     senti = int(out["logits_senti"][i].argmax().item())
+     act = int(out["logits_act"][i].argmax().item())
+     emo_p = torch.sigmoid(out["logits_emo"][i]).tolist()  # per-emotion probabilities
+     reg = torch.clamp(out["pred_reg"][i], 0, 1).tolist()  # clip regression heads to [0, 1]
+     emos = [EMO_LIST[j] for j,p in enumerate(emo_p) if p >= 0.5]
 
-     return {
-         "text": text,
+     result = {
+         "text": t,
          "pred_sentiment_strength": IDX2SENTI[senti],
-         "pred_action_signal": IDX2ACT[act],
-         "pred_emotions": emos,
-         "pred_certainty": float(np.clip(reg[0], 0, 1)),
-         "pred_relevance": float(np.clip(reg[1], 0, 1)),
-         "pred_toxicity": float(np.clip(reg[2], 0, 1)),
+         "pred_action_signal": IDX2ACT[act],
+         "pred_emotions": emos,
+         "pred_certainty": float(reg[0]),
+         "pred_relevance": float(reg[1]),
+         "pred_toxicity": float(reg[2]),
      }
+     print(json.dumps(result, ensure_ascii=False))
 
- # ---- Main ----
- if __name__ == "__main__":
-     text = input("Enter a sentence to analyze: ").strip()
-     print("[Loading model...]")
-     ckpt = load_ckpt_from_hub()
-     model, tokenizer, emo_thr = build_model_and_tokenizer(ckpt)
-
-     print("[Running inference...]")
-     result = predict(text, model, tokenizer, device="cuda" if torch.cuda.is_available() else "cpu", emo_threshold=emo_thr)
-
-     print("\n=== Results ===")
-     for k,v in result.items():
-         print(f"{k}: {v}")
  ```
  ---
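For reference, the updated quickstart can be wrapped into a small reusable helper. This is a minimal sketch, not part of the commit: it assumes, as the snippet above does, that the repo's remote-code model returns a dict with `logits_senti`, `logits_act`, `logits_emo`, and `pred_reg` heads, and it reuses the snippet's 0.5 emotion threshold; the `load` and `analyze` function names are illustrative.

```python
import json
import torch
from transformers import AutoTokenizer, AutoModel

REPO_ID = "LangQuant/LQ-Kbert-base"
IDX2SENTI = {0: "strong_pos", 1: "weak_pos", 2: "neutral", 3: "weak_neg", 4: "strong_neg"}
IDX2ACT = {0: "buy", 1: "hold", 2: "sell", 3: "avoid", 4: "info_only", 5: "ask_info"}
EMO_LIST = ["greed", "fear", "confidence", "doubt", "anger", "hope", "sarcasm"]

def load(device=None):
    # Load the tokenizer and remote-code model once; reuse across calls.
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
    model = AutoModel.from_pretrained(REPO_ID, trust_remote_code=True).to(device).eval()
    return tokenizer, model, device

def analyze(texts, tokenizer, model, device, emo_threshold=0.5, max_len=200):
    # Batch-encode, run all four heads in one forward pass, then decode per text.
    enc = tokenizer(texts, return_tensors="pt", padding=True,
                    truncation=True, max_length=max_len).to(device)
    with torch.inference_mode():
        out = model(**enc)  # assumed to return the dict of heads shown in the README
    results = []
    for i, t in enumerate(texts):
        emo_p = torch.sigmoid(out["logits_emo"][i]).tolist()  # independent per-emotion probabilities
        reg = torch.clamp(out["pred_reg"][i], 0, 1).tolist()  # regression heads clipped to [0, 1]
        results.append({
            "text": t,
            "pred_sentiment_strength": IDX2SENTI[int(out["logits_senti"][i].argmax())],
            "pred_action_signal": IDX2ACT[int(out["logits_act"][i].argmax())],
            "pred_emotions": [e for e, p in zip(EMO_LIST, emo_p) if p >= emo_threshold],
            "pred_certainty": float(reg[0]),
            "pred_relevance": float(reg[1]),
            "pred_toxicity": float(reg[2]),
        })
    return results

if __name__ == "__main__":
    tokenizer, model, device = load()
    for r in analyze(["๋น„ํŠธ์ฝ”์ธ ์กฐ์ • ํ›„ ๋ฐ˜๋“ฑ, ํˆฌ์ž์‹ฌ๋ฆฌ ๊ฐœ์„ "], tokenizer, model, device):
        print(json.dumps(r, ensure_ascii=False))
```

Note the decoding choice the snippet implies: the emotion head is read with a per-class sigmoid and a threshold rather than a softmax, since a post can carry several emotions at once (e.g. fear and hope), while the sentiment-strength and action-signal heads are single-label argmax decisions.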