|
|
| """Inference cho Vietnamese QA Stacking Ensemble v2. Load từ Hugging Face Hub.""" |
| import json |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import AutoModelForQuestionAnswering, AutoTokenizer |
|
|
| try: |
| from huggingface_hub import hf_hub_download |
| except ImportError: |
| hf_hub_download = None |
|
|
|
|
class MetaCNN(nn.Module):
    """Meta-learner for the stacking ensemble.

    A small 1D CNN that fuses four rows of base-model logits
    (start/end logits from two QA models) into final start/end logits.
    Input is ``(batch, 4, seq)``; both outputs are ``(batch, seq)``.
    """

    def __init__(self):
        super().__init__()
        # Two 3-wide convolutions with padding=1 keep the sequence length.
        self.conv = nn.Sequential(
            nn.Conv1d(4, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(32, 32, 3, padding=1),
            nn.ReLU(),
        )
        self.start_fc = nn.Linear(32, 1)
        self.end_fc = nn.Linear(32, 1)

    def forward(self, x):
        """Return ``(start_logits, end_logits)``, each of shape (batch, seq)."""
        features = self.conv(x).transpose(1, 2)  # (batch, seq, 32)
        start_logits = self.start_fc(features).squeeze(-1)
        end_logits = self.end_fc(features).squeeze(-1)
        return start_logits, end_logits
|
|
|
|
def load_ensemble(repo_id: str = None, local_dir: str = None):
    """Load the stacking ensemble from the Hugging Face Hub or a local dir.

    Args:
        repo_id: Hub repo, e.g. "username/vi-qa-stacking-ensemble-v2";
            ``config.json`` and ``meta_cnn.pth`` are downloaded from it.
        local_dir: directory already containing ``meta_cnn.pth`` and
            ``config.json``. Takes precedence over ``repo_id``.

    Returns:
        Dict with both tokenizers, both base QA models, the meta model,
        the selected device, and the per-model max sequence lengths.

    Raises:
        ImportError: ``repo_id`` was given but ``huggingface_hub`` is not installed.
        ValueError: neither ``repo_id`` nor ``local_dir`` was given.
    """
    if local_dir:
        path = Path(local_dir)
        config_path = path / "config.json"
        meta_path = path / "meta_cnn.pth"
    elif repo_id:
        # Fix: previously a missing huggingface_hub silently fell through to
        # the "need repo_id or local_dir" ValueError, which is misleading.
        if hf_hub_download is None:
            raise ImportError(
                "huggingface_hub is required to load from repo_id; "
                "install it or pass local_dir instead"
            )
        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        meta_path = hf_hub_download(repo_id=repo_id, filename="meta_cnn.pth")
        path = Path(meta_path).parent
    else:
        raise ValueError("Cần repo_id hoặc local_dir")

    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    def _load_tok(mid, use_fast=True):
        # Retry in slow mode when the fast-tokenizer conversion fails for
        # lack of sentencepiece.
        try:
            return AutoTokenizer.from_pretrained(mid, use_fast=use_fast)
        except Exception as e:
            if "sentencepiece" in str(e).lower() and use_fast:
                return AutoTokenizer.from_pretrained(mid, use_fast=False)
            raise

    tokenizer1 = _load_tok(config["base_models"][0])
    # Consistency: route tokenizer 2 through the same helper
    # (use_fast=False preserved from the original behaviour).
    tokenizer2 = _load_tok(config["base_models"][1], use_fast=False)
    model1 = AutoModelForQuestionAnswering.from_pretrained(config["base_models"][0]).to(device)
    model2 = AutoModelForQuestionAnswering.from_pretrained(config["base_models"][1]).to(device)

    meta_model = MetaCNN().to(device)
    meta_model.load_state_dict(torch.load(meta_path, map_location=device))
    meta_model.eval()
    model1.eval()
    model2.eval()

    return {
        "tokenizer1": tokenizer1,
        "tokenizer2": tokenizer2,
        "model1": model1,
        "model2": model2,
        "meta_model": meta_model,
        "device": device,
        "max_len_1": config.get("max_length", 512),
        "max_len_2": config.get("max_len_2", 256),
    }
|
|
|
|
| def _pad_to_512(x): |
| if x.size(0) < 512: |
| pad = torch.zeros(512 - x.size(0), dtype=x.dtype, device=x.device) |
| x = torch.cat([x, pad], dim=0) |
| return x[:512] |
|
|
|
|
def predict(question: str, context: str, ensemble: dict, max_answer_len: int = 30):
    """Answer *question* from *context* with the stacked ensemble.

    Args:
        question: the question string.
        context: the passage searched for the answer span.
        ensemble: dict returned by ``load_ensemble``.
        max_answer_len: maximum answer span length, in tokens of model 1.

    Returns:
        ``(answer, no_answer_probability)``; ``answer`` is ``""`` when the
        null span at position 0 outscores every candidate span.
    """
    t1, t2 = ensemble["tokenizer1"], ensemble["tokenizer2"]
    m1, m2 = ensemble["model1"], ensemble["model2"]
    meta = ensemble["meta_model"]
    dev = ensemble["device"]
    max1, max2 = ensemble["max_len_1"], ensemble["max_len_2"]

    enc1 = t1(
        question,
        context,
        return_tensors="pt",
        truncation="only_second",
        max_length=max1,
        padding="max_length",
    )
    enc2 = t2(
        question,
        context,
        return_tensors="pt",
        truncation="only_second",
        max_length=max2,
        padding="max_length",
    )

    inp1 = {k: v.to(dev) for k, v in enc1.items()}
    inp2 = {k: v.to(dev) for k, v in enc2.items()}

    # Locate the context tokens inside enc1 (sequence id 1 == second segment).
    try:
        seq_ids = enc1.sequence_ids(0)
    except Exception:
        # Slow tokenizers may not expose sequence_ids(); fall back to the
        # tokens between the first two SEP markers.
        sep_id = t1.convert_tokens_to_ids(t1.sep_token or "</s>")
        ids = enc1["input_ids"][0].tolist()
        sep_pos = [i for i, x in enumerate(ids) if x == sep_id]
        if len(sep_pos) < 2:
            return "", 1.0
        ctx_idx = list(range(sep_pos[0] + 1, sep_pos[1]))
    else:
        ctx_idx = [i for i, s in enumerate(seq_ids) if s == 1]
    # Fix: the meta model scores only 512 positions; drop indices beyond that
    # (previously they would index past the sliced softmax when max1 > 512).
    ctx_idx = [i for i in ctx_idx if i < 512]
    if not ctx_idx:
        return "", 1.0

    ctx_start, ctx_end = ctx_idx[0], ctx_idx[-1]

    with torch.no_grad():
        o1, o2 = m1(**inp1), m2(**inp2)

    # Fix: pad model 1's logits like model 2's. The original sliced with
    # [:512] only, so torch.stack failed whenever max_len_1 < 512.
    s1 = _pad_to_512(o1.start_logits[0])
    e1 = _pad_to_512(o1.end_logits[0])
    s2 = _pad_to_512(o2.start_logits[0])
    e2 = _pad_to_512(o2.end_logits[0])

    combined = torch.stack([s1, e1, s2, e2], dim=0).unsqueeze(0)
    with torch.no_grad():
        fs, fe = meta(combined)

    sp = F.softmax(fs[0], dim=-1)
    ep = F.softmax(fe[0], dim=-1)

    # Exhaustive span search over the context window.
    best = -1e9
    bs, be = ctx_start, ctx_start
    for s in range(ctx_start, ctx_end + 1):
        # Fix: cap spans at max_answer_len tokens; the original's inclusive
        # bound allowed spans of max_answer_len + 1 tokens.
        for e in range(s, min(s + max_answer_len - 1, ctx_end) + 1):
            sc = torch.log(sp[s] + 1e-12) + torch.log(ep[e] + 1e-12)
            if sc > best:
                best, bs, be = sc, s, e

    # Null (no-answer) score taken at index 0 — assumes tokenizer1 places a
    # CLS-like token there; TODO(review) confirm for non-BERT tokenizers.
    null = torch.log(sp[0] + 1e-12) + torch.log(ep[0] + 1e-12)
    no_ans = torch.sigmoid(null - best).item()
    if null > best:
        return "", no_ans

    ans = t1.decode(enc1["input_ids"][0][bs : be + 1], skip_special_tokens=True).strip()
    return ans, no_ans
|
|
|
|
if __name__ == "__main__":
    import sys

    # One CLI argument selects a Hub repo; otherwise load from the current dir.
    args = sys.argv[1:]
    ens = load_ensemble(repo_id=args[0]) if args else load_ensemble(local_dir=".")
    a, p = predict(
        "Thủ đô Việt Nam là gì?",
        "Việt Nam nằm ở Đông Nam Á. Thủ đô là Hà Nội.",
        ens,
    )
    print(f"Đáp án: {a}, no_answer_prob: {p:.4f}")
|
|