# infer.py
"""Run sentiment inference with a locally saved classifier.

Loads a model checkpoint (``config.json`` + ``model.safetensors``) from a
directory, tokenizes input texts with the base model's tokenizer, and prints
a predicted label plus class probabilities for each text.
"""
import json
import os
import sys

import torch
import torch.nn.functional as F
from safetensors.torch import load_file
from transformers import AutoTokenizer

# Shared model architectures live in ./common
sys.path.append(os.path.join(os.path.dirname(__file__), "common"))
from models import create_model_by_name


def load_model(model_dir: str):
    """Load a saved classifier, its tokenizer, and its config from *model_dir*.

    The directory must contain ``config.json`` (with keys ``base_model``,
    ``arch`` and ``max_len``) and ``model.safetensors``.

    Returns:
        (model, tokenizer, cfg): the model is in eval mode and placed on
        CUDA when available, otherwise CPU.

    Raises:
        FileNotFoundError: if either required file is missing.
    """
    cfg_path = os.path.join(model_dir, "config.json")
    w_path = os.path.join(model_dir, "model.safetensors")
    if not (os.path.exists(cfg_path) and os.path.exists(w_path)):
        raise FileNotFoundError("config.json หรือ model.safetensors ไม่ครบ")
    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)
    tok = AutoTokenizer.from_pretrained(cfg["base_model"])
    model = create_model_by_name(cfg["arch"])
    # safetensors loads tensors to CPU; move the whole model afterwards.
    state = load_file(w_path)
    model.load_state_dict(state)
    # Fix: original left the model on CPU even when a GPU was available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    return model, tok, cfg


def predict(texts, model, tok, cfg):
    """Classify *texts* (list of str) with *model*.

    Returns:
        (pred, prob): ``pred`` is a numpy array of argmax class indices,
        ``prob`` is a (len(texts), num_classes) numpy array of softmax
        probabilities.
    """
    enc = tok(
        texts,
        padding=True,
        truncation=True,
        max_length=cfg["max_len"],
        return_tensors="pt",
    )
    # Fix: send inputs to the same device as the model's parameters so
    # inference works whether the model sits on CPU or GPU.
    device = next(model.parameters()).device
    with torch.no_grad():
        logits = model(
            enc["input_ids"].to(device), enc["attention_mask"].to(device)
        )
    prob = F.softmax(logits, dim=-1).cpu().numpy()
    pred = prob.argmax(1)
    return pred, prob


if __name__ == "__main__":
    # Pick the model folder: "baseline" or "cnn_bilstm"
    MODEL_DIR = sys.argv[1] if len(sys.argv) > 1 else "cnn_bilstm"
    model, tok, cfg = load_model(MODEL_DIR)
    xs = ["อาหารอร่อยมาก บริการดี", "ไม่ประทับใจเลย ช้ามาก"]
    y, p = predict(xs, model, tok, cfg)
    # Generalized: honor label names from config when present; the original
    # hard-coded list remains the fallback, so behavior is unchanged.
    labels = cfg.get("labels", ["negative", "positive"])
    for t, yy, pp in zip(xs, y, p):
        print(f"{t} => {labels[yy]} | prob={pp}")