File size: 901 Bytes
70dd4f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
import joblib
import torch.nn.functional as F
from src.data_processing import clean_text, encode, pad_sequence
from src.model_def import EmotionTransformer

# Load the label encoder and vocabulary produced at training time.
# NOTE(review): joblib.load unpickles its input (and torch.load may too,
# depending on the installed torch version's `weights_only` default), which can
# execute arbitrary code — only load these artifacts from trusted sources.
# All paths below are relative, so they resolve against the current working
# directory at import time, not against this file's location.
# `le` is presumably a scikit-learn LabelEncoder (it exposes `classes_` and
# `inverse_transform`) — TODO confirm.
le = joblib.load("label_encoder.pkl")
vocab = joblib.load("vocab.pkl")  # presumably a token -> id mapping; len(vocab) sets the embedding size below
 
# Initialize and load the PyTorch model.
# vocab_size and num_classes are derived from the loaded artifacts; embed_dim
# and num_heads are hard-coded and presumably must match the hyperparameters
# the checkpoint was trained with — otherwise load_state_dict (strict by
# default) will likely fail on mismatched parameter shapes.
model = EmotionTransformer(
    vocab_size=len(vocab), embed_dim=64, num_heads=4, num_classes=len(le.classes_)
)
# map_location="cpu" remaps any tensors saved on a GPU onto the CPU, so the
# checkpoint can be loaded on a machine without CUDA.
model.load_state_dict(torch.load("emotion_transformer_model.pth", map_location="cpu"))
# Switch to inference mode (disables dropout / uses running batch-norm stats).
model.eval()

# Predict a single text sample
def predict(text, max_len=128):
    """Return the predicted emotion label for a single piece of text.

    The text is cleaned with ``clean_text``, mapped to token ids with
    ``encode`` against the module-level ``vocab``, padded/truncated by
    ``pad_sequence`` to ``max_len``, and scored by ``model`` as a batch of one.

    Args:
        text: Raw input string.
        max_len: Target length passed to ``pad_sequence``. Presumably this
            should match the sequence length used at training time — TODO
            confirm.

    Returns:
        The label decoded by ``le.inverse_transform`` for the highest-scoring
        class.
    """
    seq = pad_sequence(encode(clean_text(text), vocab), max_len)
    # as_tensor accepts a list, an ndarray or an existing tensor alike; an
    # explicit long dtype is what an embedding lookup needs (assumes the model
    # starts with nn.Embedding — TODO confirm). unsqueeze(0) adds the batch
    # dimension, giving shape (1, max_len).
    batch = torch.as_tensor(seq, dtype=torch.long).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    # softmax is monotonic, so the argmax over raw logits selects the same
    # class; computing probabilities only to discard them is wasted work.
    idx = logits.argmax(dim=-1).item()
    return le.inverse_transform([idx])[0]