File size: 1,843 Bytes
48e0979
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# infer.py
import os, sys, json, torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from safetensors.torch import load_file

# ใช้สถาปัตยกรรมร่วม
sys.path.append(os.path.join(os.path.dirname(__file__), "common"))
from models import create_model_by_name

def load_model(model_dir: str, device=None):
    """Load a trained classifier, its tokenizer and its config from *model_dir*.

    The folder must contain ``config.json`` and ``model.safetensors``.
    ``config.json`` is read for ``"base_model"`` (the tokenizer to load) and
    ``"arch"`` (the architecture name passed to ``create_model_by_name``).

    Args:
        model_dir: Folder holding the saved model files.
        device: Optional device (e.g. ``"cuda"`` or a ``torch.device``) to
            move the model to after its weights are loaded. ``None`` (the
            default) leaves the model where ``create_model_by_name`` put it.

    Returns:
        ``(model, tok, cfg)``: the model in eval mode, the tokenizer for
        ``cfg["base_model"]``, and the parsed config dict.

    Raises:
        FileNotFoundError: If either required file is missing. The message
            names the missing file(s) and the folder that was checked.
    """
    cfg_path = os.path.join(model_dir, "config.json")
    w_path = os.path.join(model_dir, "model.safetensors")

    # isfile (not exists) so a stray directory with one of these names is
    # reported as missing instead of failing later with a less clear error.
    missing = [os.path.basename(p) for p in (cfg_path, w_path) if not os.path.isfile(p)]
    if missing:
        raise FileNotFoundError(
            f"config.json หรือ model.safetensors ไม่ครบ "
            f"(missing {', '.join(missing)} in {model_dir!r})"
        )

    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    tok = AutoTokenizer.from_pretrained(cfg["base_model"])
    model = create_model_by_name(cfg["arch"])
    state = load_file(w_path)
    model.load_state_dict(state)
    if device is not None:
        # Move only after the weights are in place so parameters and the
        # loaded state end up on the same device.
        model.to(device)
    model.eval()
    return model, tok, cfg

def predict(texts, model, tok, cfg):
    """Classify *texts* and return predicted class indices and probabilities.

    Args:
        texts: Input text(s), passed straight to ``tok``.
        model: Classifier called as ``model(input_ids, attention_mask)``.
            Its output is treated as logits with classes on axis 1
            (presumably shape ``(batch, num_classes)`` — confirm per arch).
        tok: Tokenizer used to encode *texts* (padded, truncated, PyTorch
            tensors).
        cfg: Config dict; ``cfg["max_len"]`` caps the tokenized length.

    Returns:
        ``(pred, prob)`` as NumPy arrays: ``pred`` is the argmax over axis 1
        of ``prob``, and ``prob`` is the softmax of the logits over axis 1.
    """
    enc = tok(texts, padding=True, truncation=True, max_length=cfg["max_len"], return_tensors="pt")
    input_ids = enc["input_ids"]
    attention_mask = enc["attention_mask"]

    # The tokenizer always returns CPU tensors; move them to the device that
    # holds the model's weights so inference also works on a GPU model.
    first_param = next(iter(model.parameters()), None) if hasattr(model, "parameters") else None
    if first_param is not None:
        input_ids = input_ids.to(first_param.device)
        attention_mask = attention_mask.to(first_param.device)

    with torch.no_grad():
        logits = model(input_ids, attention_mask)
        prob = F.softmax(logits, dim=1).cpu().numpy()
        pred = prob.argmax(1)
    return pred, prob

if __name__ == "__main__":
    # Model folder comes from the first CLI argument ("baseline" or
    # "cnn_bilstm"); with no argument, "cnn_bilstm" is used.
    cli_args = sys.argv[1:]
    MODEL_DIR = cli_args[0] if cli_args else "cnn_bilstm"

    model, tok, cfg = load_model(MODEL_DIR)

    # Class index -> human-readable name used for printing predictions.
    labels = ["negative", "positive"]

    # Two Thai sample reviews: roughly "very tasty food, good service" and
    # "not impressed at all, very slow".
    samples = ["อาหารอร่อยมาก บริการดี", "ไม่ประทับใจเลย ช้ามาก"]
    preds, probs = predict(samples, model, tok, cfg)

    for text, cls_idx, cls_prob in zip(samples, preds, probs):
        print(f"{text} => {labels[cls_idx]} | prob={cls_prob}")