|
|
| import os, sys, json, torch
|
| import torch.nn.functional as F
|
| from transformers import AutoTokenizer
|
| from safetensors.torch import load_file
|
|
|
|
|
| sys.path.append(os.path.join(os.path.dirname(__file__), "common"))
|
| from models import create_model_by_name
|
|
|
def load_model(model_dir: str, device=None):
    """Load a trained classifier, its tokenizer and its config from *model_dir*.

    The directory must contain ``config.json`` and ``model.safetensors``.
    The config must provide ``"base_model"`` (passed to
    ``AutoTokenizer.from_pretrained``) and ``"arch"`` (passed to
    ``create_model_by_name``).

    Args:
        model_dir: Directory holding the exported model files.
        device: Optional torch device (e.g. ``"cuda"``) to move the model to
            after its weights are loaded. ``None`` (the default) leaves the
            model wherever ``load_file`` put the weights. ``load_file`` is
            called without a device argument, so this is the CPU by default.

    Returns:
        ``(model, tok, cfg)``: the model (switched to eval mode), the
        tokenizer, and the parsed config dict.

    Raises:
        FileNotFoundError: If ``config.json`` or ``model.safetensors`` is
            missing. The message lists the missing path(s).
        KeyError: If the config lacks ``"base_model"`` or ``"arch"``.
    """
    cfg_path = os.path.join(model_dir, "config.json")
    w_path = os.path.join(model_dir, "model.safetensors")

    missing = [p for p in (cfg_path, w_path) if not os.path.exists(p)]
    if missing:
        # Original message ("config.json or model.safetensors is incomplete"),
        # extended with the concrete missing path(s) so the user knows which
        # file to fix and where it was expected.
        raise FileNotFoundError(
            "config.json หรือ model.safetensors ไม่ครบ: " + ", ".join(missing)
        )

    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    tok = AutoTokenizer.from_pretrained(cfg["base_model"])
    model = create_model_by_name(cfg["arch"])
    state = load_file(w_path)
    model.load_state_dict(state)
    if device is not None:
        # Move after loading so the state dict is applied before the transfer.
        model.to(device)
    model.eval()
    return model, tok, cfg
|
|
|
def predict(texts, model, tok, cfg):
    """Classify *texts* and return predicted class ids and class probabilities.

    Args:
        texts: Input text(s), passed straight to ``tok`` as a batch.
        model: Classifier called as ``model(input_ids, attention_mask)``.
            Its output is treated as logits with classes on axis 1. This
            presumably means a ``(batch, num_classes)`` tensor — confirm
            against the model definitions.
        tok: Tokenizer called with ``padding``/``truncation``/``max_length``/
            ``return_tensors="pt"`` that returns ``input_ids`` and
            ``attention_mask`` tensors.
        cfg: Config dict. ``cfg["max_len"]`` is the truncation length.

    Returns:
        ``(pred, prob)`` as NumPy arrays. ``prob`` is the softmax over axis 1
        and ``pred`` is ``prob.argmax(1)``.
    """
    enc = tok(
        texts,
        padding=True,
        truncation=True,
        max_length=cfg["max_len"],
        return_tensors="pt",
    )
    # The tokenizer builds tensors on the CPU. Send them to wherever the model
    # lives, so prediction still works after the model is moved (e.g. to CUDA).
    device = _model_device(model)
    input_ids = enc["input_ids"].to(device)
    attention_mask = enc["attention_mask"].to(device)

    with torch.no_grad():
        logits = model(input_ids, attention_mask)
        prob = F.softmax(logits, dim=1).cpu().numpy()
    pred = prob.argmax(1)
    return pred, prob


def _model_device(model):
    """Return the device of *model*'s first parameter, or CPU if it has none."""
    params = getattr(model, "parameters", None)
    if params is None:
        # Plain callables (not nn.Module) have no parameters to inspect.
        return torch.device("cpu")
    first = next(iter(params()), None)
    return first.device if first is not None else torch.device("cpu")
|
|
|
if __name__ == "__main__":
    # Demo run: load a saved model and classify two sample reviews.
    # Usage: python <script> [MODEL_DIR]   (MODEL_DIR defaults to "cnn_bilstm")
    cli_args = sys.argv[1:]
    model_dir = cli_args[0] if cli_args else "cnn_bilstm"

    model, tok, cfg = load_model(model_dir)

    # Sample Thai reviews: "very tasty food, good service" and
    # "not impressed at all, very slow".
    samples = ["อาหารอร่อยมาก บริการดี", "ไม่ประทับใจเลย ช้ามาก"]
    pred_ids, probs = predict(samples, model, tok, cfg)

    # Class-index -> name table. Assumes the model was trained with
    # 0 = negative and 1 = positive — TODO confirm against training code.
    labels = ("negative", "positive")
    for t, yy, pp in zip(samples, pred_ids, probs):
        print(f"{t} => {labels[yy]} | prob={pp}")
|
|
|