"""Inference helpers for the emotion classifier.

Importing this module loads the label encoder, the vocabulary and the model
weights from paths relative to the current working directory, and puts the
model in eval mode, so ``predict`` can be called right away.
"""

import joblib
import torch
import torch.nn.functional as F

from src.data_processing import clean_text, encode, pad_sequence
from src.model_def import EmotionTransformer

# Artifact locations (relative to the current working directory).
LABEL_ENCODER_PATH = "label_encoder.pkl"
VOCAB_PATH = "vocab.pkl"
MODEL_WEIGHTS_PATH = "emotion_transformer_model.pth"

# Architecture hyperparameters. These must match the configuration the
# checkpoint was trained with; a mismatch either makes load_state_dict fail
# or yields a model that does not behave like the trained one.
EMBED_DIM = 64
NUM_HEADS = 4

# NOTE(review): joblib.load and torch.load both unpickle their input, which can
# execute arbitrary code. Only load artifacts you produced yourself. On
# torch >= 1.13, consider torch.load(..., weights_only=True) for the weights.
le = joblib.load(LABEL_ENCODER_PATH)
# Presumably a token -> index mapping saved at training time; its length is
# used as the embedding vocabulary size below. TODO confirm against training.
vocab = joblib.load(VOCAB_PATH)

model = EmotionTransformer(
    vocab_size=len(vocab),
    embed_dim=EMBED_DIM,
    num_heads=NUM_HEADS,
    num_classes=len(le.classes_),
)
model.load_state_dict(torch.load(MODEL_WEIGHTS_PATH, map_location="cpu"))
model.eval()  # disable dropout etc. for inference


def predict(text, max_len=128, *, return_proba=False):
    """Predict the emotion label for a single text sample.

    The text is cleaned, encoded with the module-level ``vocab`` and padded
    (or, depending on ``pad_sequence``, truncated) to ``max_len`` tokens
    before being fed to the model as a batch of one.

    Args:
        text: Raw input text.
        max_len: Sequence length passed to ``pad_sequence``.
        return_proba: If True, also return the softmax probability of the
            predicted class.

    Returns:
        The decoded label from the label encoder, or a ``(label, probability)``
        tuple when ``return_proba`` is True.
    """
    clean = clean_text(text)
    seq = pad_sequence(encode(clean, vocab), max_len)
    # Embedding lookups need integer indices; make the dtype explicit instead
    # of relying on whatever torch infers from the padded sequence.
    batch = torch.tensor([seq], dtype=torch.long)
    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)
    idx = probs.argmax(dim=1).item()
    label = le.inverse_transform([idx])[0]
    if return_proba:
        # Batch size is 1, so row 0 holds this sample's class probabilities.
        return label, probs[0, idx].item()
    return label