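# thai-bilstm-sentiment / model/load_model.py
# (Hugging Face file-page header captured along with the source; not code.)
# Committed by Dusit-P: "Update model/load_model.py" (e5833e5, verified)
# raw · history · blame · contribute · delete
# 725 Bytes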
import torch
from transformers import AutoTokenizer, AutoModel
from model.bilstm_model import BiLSTMSentiment
from config import DEVICE, MODEL_PATH, HIDDEN_SIZE, NUM_CLASSES
# Load the WangchanBERTa tokenizer and encoder.
# The checkpoint id lives in a single constant so the tokenizer and the encoder
# are always taken from the same checkpoint (previously the string was
# duplicated, and editing only one copy would mismatch vocabulary and weights).
BERT_MODEL_NAME = "airesearch/wangchanberta-base-att-spm-uncased"

tokenizer = AutoTokenizer.from_pretrained(
    BERT_MODEL_NAME,
    use_fast=False,  # use the slow (pure-Python) tokenizer, not the Rust "fast" one
    # NOTE(review): trust_remote_code=True lets the Hub repo execute its own
    # Python code at load time; confirm it is actually required for this checkpoint.
    trust_remote_code=True,
)

bert_model = AutoModel.from_pretrained(BERT_MODEL_NAME).to(DEVICE)
# from_pretrained already returns the model in eval mode; stated explicitly so
# dropout is guaranteed off when this encoder is used for inference.
bert_model.eval()
# Build the BiLSTM sentiment classifier and restore its trained weights.
# Input feature size = BERT hidden size plus one extra value per feature vector
# (per the original note, a sentiment score appended to the BERT features).
INPUT_SIZE = bert_model.config.hidden_size + 1  # BERT + sentiment score
model_lstm = BiLSTMSentiment(INPUT_SIZE, HIDDEN_SIZE, NUM_CLASSES).to(DEVICE)

# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs. It prevents the checkpoint from executing arbitrary
# pickled code and avoids torch's FutureWarning about the implicit default.
model_lstm.load_state_dict(
    torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
)
model_lstm.eval()  # inference mode (affects layers such as dropout)