# src/inference.py — emotion classification inference
# (originally published by opinder2906; commit 03d3ea1, "Update src/inference.py")
import joblib
import torch
import torch.nn.functional as F
from src.data_processing import clean_text
from src.model_def import EmotionTransformer
# Load artifacts produced at training time.
# Token-to-index mapping; `predict` uses 0 for padding and 1 for unknown
# tokens — presumably 0 = <pad>, 1 = <unk> in this vocab; TODO confirm
# against the training pipeline.
vocab = joblib.load('vocab.pkl')
# sklearn LabelEncoder: maps class indices back to emotion label strings.
le = joblib.load('label_encoder.pkl')
# Prefer GPU when available; weights are remapped to this device below.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Recreate model with the same vocab size / class count it was trained with,
# then load the saved weights (map_location handles CPU-only machines).
model = EmotionTransformer(len(vocab), num_classes=len(le.classes_)).to(DEVICE)
model.load_state_dict(torch.load('emotion_transformer_model.pth', map_location=DEVICE))
model.eval()
# Fixed sequence length: inputs are padded/truncated to this many tokens.
MAX_LEN = 32
def predict(text):
    """Predict the emotion label for *text* via MC-dropout averaging.

    The text is cleaned, tokenized, mapped to vocab indices (1 = unknown,
    0 = padding — assumed from how the indices are used here; confirm
    against the training pipeline), and padded/truncated to ``MAX_LEN``.
    Five stochastic forward passes are run with dropout enabled and their
    softmax outputs averaged before taking the argmax.

    Parameters
    ----------
    text : str
        Raw input text; may be empty (yields an all-padding sequence).

    Returns
    -------
    str
        The decoded emotion label from the label encoder.
    """
    toks = clean_text(text).split()
    idxs = [vocab.get(tok, 1) for tok in toks]  # 1 = unknown-token index
    pad = (idxs + [0] * MAX_LEN)[:MAX_LEN]      # right-pad with 0, then clip
    x = torch.tensor([pad], dtype=torch.long).to(DEVICE)
    # MC-dropout inference: dropout must be active, so switch to train mode.
    model.train()
    try:
        with torch.no_grad():
            probs = torch.stack([F.softmax(model(x), dim=1) for _ in range(5)])
        avg = probs.mean(dim=0)
    finally:
        # BUG FIX: the original never restored eval mode, leaving the shared
        # module-level model with dropout enabled for all subsequent callers.
        model.eval()
    return le.inverse_transform([avg.argmax().item()])[0]