puttimej committed on
Commit
a60bfa2
·
verified ·
1 Parent(s): 74a238f

Upload inference/pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference/pipeline.py +265 -0
inference/pipeline.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # inference/pipeline.py
2
+ import os
3
+ import json
4
+ import torch
5
+ import yaml
6
+ from typing import List, Dict, Any, Optional
7
+
8
+ from tokenizer.thai_tokenizer import ThaiTokenizer
9
+ from model.encoder import ThaiTransformerEncoder, ModelConfig
10
+ from model.heads.ner_head import NERHead
11
+ from model.heads.sentiment_head import SentimentHead
12
+ from model.heads.qa_head import QAHead
13
+
14
+
15
# Label maps used to decode model output ids back into strings.
# NER uses a BIO scheme over three entity types; id 0 is the "outside" tag.
NER_ID2LABEL = {
    0: "O",
    1: "B-PERSON", 2: "I-PERSON",
    3: "B-ORGANIZATION", 4: "I-ORGANIZATION",
    5: "B-LOCATION", 6: "I-LOCATION",
}
# Sentiment class ids, in the order of the sentiment head's 3 output logits.
SENTIMENT_ID2LABEL = {0: "negative", 1: "neutral", 2: "positive"}
23
+
24
+
25
class ThaiNLPModel(torch.nn.Module):
    """Bundle the encoder and the three task heads into a single module.

    The class exists so that all weights can be saved and loaded through one
    ``state_dict``. It defines no ``forward``; callers invoke ``encoder`` and
    the individual heads directly.

    Parameters
    ----------
    config : ModelConfig
        Encoder configuration; ``config.d_model`` is also passed to every
        head as its input width.
    num_ner_labels : int
        Number of NER tag classes produced by the NER head (default 7).
    """

    def __init__(self, config: ModelConfig, num_ner_labels: int = 7):
        super().__init__()
        self.encoder = ThaiTransformerEncoder(config)
        self.ner_head = NERHead(config.d_model, num_ner_labels)
        self.sentiment_head = SentimentHead(config.d_model, num_classes=3)
        self.qa_head = QAHead(config.d_model)
33
+
34
+
35
class ThaiNLPPipeline:
    """High-level inference wrapper for the Thai multi-task model.

    The config, tokenizer and checkpoint are loaded once in ``__init__``;
    afterwards ``predict()`` can be called repeatedly.
    """

    # Longest answer span (in tokens) considered by the QA decoder; longer
    # answers are treated as implausible.
    MAX_ANSWER_LEN = 30

    # SentencePiece meta symbol (U+2581) that prefixes the first subword of
    # each word. Written as an escape so it survives encoding mishaps.
    _WORD_BOUNDARY = "\u2581"

    def __init__(self, model_dir: str, device: str = "auto", checkpoint_name: str = "checkpoint_best"):
        """Load config, tokenizer and checkpoint from ``model_dir``.

        Parameters
        ----------
        model_dir : directory containing ``config.yaml``, a ``tokenizer``
            sub-directory and ``<checkpoint_name>/checkpoint.pt``
        device : ``"auto"`` picks CUDA when available, else CPU; any other
            value is passed to ``torch.device``
        checkpoint_name : name of the checkpoint sub-directory to load

        Raises
        ------
        FileNotFoundError
            If ``config.yaml`` or the checkpoint file does not exist.
        """
        # ── Device ──────────────────────────────────────────────────────
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        # ── Load config ─────────────────────────────────────────────────
        config_path = os.path.join(model_dir, "config.yaml")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"ไม่พบ config.yaml ใน {model_dir}")
        # Explicit UTF-8 so the config parses the same on every platform,
        # regardless of the locale's default encoding.
        with open(config_path, encoding="utf-8") as f:
            raw_config = yaml.safe_load(f)
        model_cfg = ModelConfig(**raw_config["model"])

        # ── Load tokenizer ──────────────────────────────────────────────
        self.tokenizer = ThaiTokenizer.from_pretrained(
            os.path.join(model_dir, "tokenizer")
        )

        # ── Load model ──────────────────────────────────────────────────
        self.model = ThaiNLPModel(model_cfg)
        ckpt_path = os.path.join(model_dir, checkpoint_name, "checkpoint.pt")
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"ไม่พบ checkpoint: {ckpt_path}")

        # NOTE(security): torch.load unpickles the file. Only load
        # checkpoints from trusted sources; consider weights_only=True if
        # the installed torch version and checkpoint contents allow it.
        ckpt = torch.load(ckpt_path, map_location=self.device)
        self.model.load_state_dict(ckpt["model_state"])
        self.model.to(self.device)
        self.model.eval()

        step = ckpt.get("global_step", "unknown")
        metric = ckpt.get("best_metric", "unknown")
        print(f"pipeline ready on {self.device} (loaded {checkpoint_name} from step {step} with best_val_loss={metric})")

    # ─────────────────────────────────────────────────────────────────────
    # predict — main entry point
    # ─────────────────────────────────────────────────────────────────────

    def predict(
        self,
        text: str,
        tasks: List[str],
        question: Optional[str] = None,  # required for QA
        context: Optional[str] = None,   # required for QA
    ) -> Dict[str, Any]:
        """Run the requested tasks and collect their results.

        Parameters
        ----------
        text     : input text for NER and sentiment
        tasks    : tasks to run, any of ["ner", "sentiment", "qa"];
                   unrecognised names are ignored
        question : question string (QA only)
        context  : context string (QA only)

        Returns
        -------
        dict with one key per requested task. If "qa" is requested without
        both ``question`` and ``context``, its value is an ``{"error": ...}``
        dict instead of a prediction.
        """
        results = {}

        with torch.no_grad():
            # ── NER ─────────────────────────────────────────────────────
            if "ner" in tasks:
                results["ner"] = self._predict_ner(text)

            # ── Sentiment ───────────────────────────────────────────────
            if "sentiment" in tasks:
                results["sentiment"] = self._predict_sentiment(text)

            # ── QA ──────────────────────────────────────────────────────
            if "qa" in tasks:
                if question is None or context is None:
                    results["qa"] = {"error": "QA ต้องการ question และ context"}
                else:
                    results["qa"] = self._predict_qa(question, context)

        return results

    # ─────────────────────────────────────────────────────────────────────
    # Task-specific predict methods
    # ─────────────────────────────────────────────────────────────────────

    def _encode(self, input_ids, attention_mask):
        """Run the shared encoder on a single (unbatched) sequence.

        Wraps ``input_ids`` and ``attention_mask`` in a batch dimension of 1,
        moves them to ``self.device`` and returns ``(hidden, mask)`` where
        ``mask`` is the batched attention-mask tensor.
        """
        ids = torch.tensor([input_ids], dtype=torch.long).to(self.device)
        mask = torch.tensor([attention_mask], dtype=torch.long).to(self.device)
        hidden, _ = self.model.encoder(ids, mask)
        return hidden, mask

    def _predict_ner(self, text: str) -> List[Dict[str, str]]:
        """Tag ``text`` and return a list of ``{"token": str, "label": str}``.

        [CLS], [SEP] and padding tokens are dropped, and subword pieces are
        merged back into words. Each word takes the label predicted for its
        first subword.
        """
        encoded = self.tokenizer.batch_encode(
            [text], max_length=512, padding=False, return_tensors=False
        )
        input_ids = encoded["input_ids"][0]
        attn_mask = encoded["attention_mask"][0]

        hidden, _ = self._encode(input_ids, attn_mask)
        logits = self.model.ner_head(hidden)          # (1, T, num_labels)
        pred_ids = logits[0].argmax(dim=-1).tolist()  # (T,)

        # Decode all ids to pieces in one call, then walk ids/pieces/labels
        # in lockstep.
        pieces = self.tokenizer.sp.id_to_piece(input_ids)
        special = {
            self.tokenizer.cls_id,
            self.tokenizer.sep_id,
            self.tokenizer.pad_id,
        }

        entities = []
        current_word = ""
        current_label = "O"

        for token_id, piece, label_id in zip(input_ids, pieces, pred_ids):
            if token_id in special:
                continue

            label = NER_ID2LABEL.get(label_id, "O")

            # A piece starting with the boundary marker opens a new word; so
            # does any piece when no word is currently being built.
            if piece.startswith(self._WORD_BOUNDARY) or not current_word:
                # Flush the previous word, if any.
                if current_word:
                    entities.append({
                        "token": current_word,
                        "label": current_label,
                    })
                current_word = piece.lstrip(self._WORD_BOUNDARY)
                current_label = label
            else:
                # Continuation subword: append to the current word.
                current_word += piece

        # Flush the final word.
        if current_word:
            entities.append({"token": current_word, "label": current_label})

        return entities

    def _predict_sentiment(self, text: str) -> Dict[str, Any]:
        """Classify ``text`` and return
        ``{"label": str, "confidence": float, "scores": dict}``.

        ``scores`` maps every sentiment label to its softmax probability;
        ``confidence`` is the probability of the predicted label. Both are
        rounded to 4 decimals.
        """
        encoded = self.tokenizer.batch_encode(
            [text], max_length=512, padding=False, return_tensors=False
        )
        # _encode already returns the batched mask tensor on self.device, so
        # it is reused for the head instead of being rebuilt.
        hidden, mask = self._encode(
            encoded["input_ids"][0],
            encoded["attention_mask"][0],
        )

        logits = self.model.sentiment_head(hidden, mask)  # (1, 3)
        probs = logits.softmax(dim=-1)[0].tolist()

        pred_id = int(logits.argmax(dim=-1).item())
        pred_label = SENTIMENT_ID2LABEL[pred_id]
        confidence = round(probs[pred_id], 4)

        return {
            "label": pred_label,
            "confidence": confidence,
            "scores": {
                SENTIMENT_ID2LABEL[i]: round(p, 4)
                for i, p in enumerate(probs)
            },
        }

    def _predict_qa(self, question: str, context: str) -> Dict[str, Any]:
        """Extract an answer span from ``context`` and return
        ``{"answer": str, "start": int, "end": int, "confidence": float}``.

        ``start`` and ``end`` are inclusive token indices into the encoded
        question+context sequence. If truncation left no context tokens, an
        empty answer with confidence 0.0 is returned.
        """
        encoded = self.tokenizer.encode_qa(question, context, max_length=512)
        input_ids = encoded["input_ids"]
        attn_mask = encoded["attention_mask"]
        context_start = encoded["context_start"]
        seq_len = len(input_ids)

        # No context token survived encoding: there is no span to score, and
        # indexing the logits at context_start would be out of range.
        if context_start >= seq_len:
            return {
                "answer": "",
                "start": context_start,
                "end": context_start,
                "confidence": 0.0,
            }

        hidden, _ = self._encode(input_ids, attn_mask)

        start_logits, end_logits = self.model.qa_head(
            hidden, context_start=context_start
        )  # (1, T) each
        start_logits = start_logits[0]  # (T,)
        end_logits = end_logits[0]      # (T,)

        # Pull the scores to Python floats once, rather than calling .item()
        # on a tensor element for every candidate span.
        start_scores = start_logits.tolist()
        end_scores = end_logits.tolist()

        # Find the (start, end) pair with the highest summed score, subject
        # to start <= end and a span of at most MAX_ANSWER_LEN tokens.
        best_score = float("-inf")
        best_start = context_start
        best_end = context_start

        for s in range(context_start, seq_len):
            for e in range(s, min(s + self.MAX_ANSWER_LEN, seq_len)):
                score = start_scores[s] + end_scores[e]
                if score > best_score:
                    best_score = score
                    best_start = s
                    best_end = e

        # Decode the answer tokens back into a string.
        answer_ids = input_ids[best_start:best_end + 1]
        answer = self.tokenizer.decode(answer_ids, skip_special_tokens=True)

        # Confidence: product of the softmax probabilities of the chosen
        # start and end positions (a rough normalisation).
        start_probs = start_logits.softmax(dim=-1)
        end_probs = end_logits.softmax(dim=-1)
        confidence = round(
            (start_probs[best_start] * end_probs[best_end]).item(), 4
        )

        return {
            "answer": answer,
            "start": best_start,
            "end": best_end,
            "confidence": confidence,
        }