File size: 956 Bytes
8c1cff0
70dd4f9
03d3ea1
 
8c1cff0
db008b2
03d3ea1
 
 
8c1cff0
70dd4f9
03d3ea1
 
 
 
80c2210
03d3ea1
80c2210
03d3ea1
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import joblib
import torch
import torch.nn.functional as F
from src.data_processing import clean_text
from src.model_def import EmotionTransformer

# Load artifacts
# NOTE(review): joblib.load and torch.load both unpickle their input —
# arbitrary code execution if these files are untrusted. Only load
# artifacts produced by this project's own training run.
vocab = joblib.load('vocab.pkl')           # token -> index mapping (dict-like; .get used below)
le    = joblib.load('label_encoder.pkl')   # sklearn LabelEncoder fitted on the emotion labels
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Recreate model
# Vocab size and class count must match the checkpoint's training run,
# otherwise load_state_dict raises on shape mismatch.
model = EmotionTransformer(len(vocab), num_classes=len(le.classes_)).to(DEVICE)
model.load_state_dict(torch.load('emotion_transformer_model.pth', map_location=DEVICE))
model.eval()  # inference mode by default; predict() toggles dropout explicitly

# Fixed sequence length inputs are padded/truncated to — presumably the
# same value used during training; TODO confirm against the training script.
MAX_LEN = 32

def predict(text, n_samples=5):
    """Predict an emotion label for *text* via MC-dropout averaging.

    The text is cleaned, tokenized, mapped through the vocabulary
    (index 1 for OOV — presumably <UNK>; index 0 for padding; TODO
    confirm against the training pipeline), right-padded/truncated to
    MAX_LEN, then forwarded ``n_samples`` times with dropout active.
    Softmax outputs are averaged and the argmax class is decoded with
    the label encoder.

    Args:
        text: Raw input string.
        n_samples: Number of stochastic forward passes (default 5,
            matching the original hard-coded value).

    Returns:
        The predicted label (an element of ``le.classes_``).
    """
    toks = clean_text(text).split()
    idxs = [vocab.get(tok, 1) for tok in toks]   # 1 = OOV index
    pad  = (idxs + [0] * MAX_LEN)[:MAX_LEN]      # 0 = pad; always exactly MAX_LEN long
    x = torch.tensor([pad], dtype=torch.long).to(DEVICE)

    # MC-dropout: enable ONLY Dropout modules. The original called
    # model.train(), which (a) was never undone, leaving the shared
    # module-level model in train mode after the first call, and
    # (b) would also switch any BatchNorm-style layers into
    # running-stat-update mode during these passes.
    for m in model.modules():
        if isinstance(m, torch.nn.Dropout):
            m.train()
    try:
        with torch.no_grad():
            probs = torch.stack(
                [F.softmax(model(x), dim=1) for _ in range(n_samples)]
            )
    finally:
        model.eval()  # always restore inference mode for other callers
    avg = probs.mean(dim=0)
    return le.inverse_transform([avg.argmax().item()])[0]