a1 / src/inference.py
opinder2906's picture
Rename .src/inference.py to src/inference.py
18acb02 verified
raw
history blame
901 Bytes
import torch
import joblib
import torch.nn.functional as F
from src.data_processing import clean_text, encode, pad_sequence
from src.model_def import EmotionTransformer
# Load label encoder and vocab.
# NOTE(review): joblib.load unpickles arbitrary objects — only load these
# artifacts from a trusted source.
le = joblib.load("label_encoder.pkl")
vocab = joblib.load("vocab.pkl")  # if you saved a vocab dict
# Initialize and load PyTorch model.
# Architecture hyperparameters (embed_dim=64, num_heads=4) must match the
# values used at training time for the state dict below to load cleanly.
model = EmotionTransformer(
    vocab_size=len(vocab), embed_dim=64, num_heads=4, num_classes=len(le.classes_)
)
# map_location="cpu" lets checkpoints saved on GPU load on CPU-only hosts.
model.load_state_dict(torch.load("emotion_transformer_model.pth", map_location="cpu"))
model.eval()  # disable dropout/batch-norm training behavior for inference
# Predict a single text sample
def predict(text, max_len=128):
    """Classify a single text string and return its emotion label.

    The text is cleaned, encoded with the module-level ``vocab``, and padded
    (or truncated) to ``max_len`` tokens before being fed through ``model``.

    Args:
        text: Raw input string to classify.
        max_len: Sequence length the encoded text is padded/truncated to.

    Returns:
        The predicted emotion label, decoded via the module-level
        label encoder ``le``.
    """
    tokens = encode(clean_text(text), vocab)
    tokens = pad_sequence(tokens, max_len)
    # Wrap in a list to form a batch of size 1.
    batch = torch.tensor([tokens])
    with torch.no_grad():
        scores = model(batch)
    # Softmax is monotonic, so the argmax matches that of the raw scores;
    # kept here to mirror a probability-based decision.
    distribution = F.softmax(scores, dim=1)
    best = distribution.argmax(dim=1).item()
    return le.inverse_transform([best])[0]