# -*- coding: utf-8 -*-
"""
Inference Module - Model Prediction

Provides ``EmotionClassifier``, a thin wrapper around a Hugging Face
sequence-classification model, plus a lazily created module-level instance
available through ``get_classifier()``. All public methods report failures
as ``{'error': message}`` dicts instead of raising.
"""
import os
import torch

# Model path. A relative path is tried against the current working directory
# first and, if no such directory exists there, against this module's
# directory (see EmotionClassifier._resolve_model_path).
MODEL_SAVE_PATH = '../best_model'

# Emotion labels. The code assumes index i corresponds to output logit i of
# the loaded model.
EMOTION_LABELS = [
    "Neutral",
    "Anxiety/Fear",
    "Anger/Frustration",
    "Sadness/Helplessness",
    "Confusion/Doubt",
    "Gratitude/Relief"
]

# True only if the transformers classes could be imported; load_model checks
# this before touching AutoTokenizer / AutoModelForSequenceClassification.
try:
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    MODEL_LOADED = True
except ImportError:
    MODEL_LOADED = False


class EmotionClassifier:
    """Emotion Classification Inference.

    The tokenizer and model are loaded on the first call to ``predict`` or
    ``predict_batch`` unless ``load_model`` was called explicitly beforehand.
    """

    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.device = None
        self.loaded = False

    @staticmethod
    def _resolve_model_path(model_path):
        """Return a usable local directory for *model_path*, or it unchanged.

        A relative path that is not a directory under the current working
        directory is retried relative to this module's directory. Absolute
        paths, existing directories and anything else (e.g. hub ids) are
        returned untouched.
        """
        if os.path.isabs(model_path) or os.path.isdir(model_path):
            return model_path
        module_dir = os.path.dirname(os.path.abspath(__file__))
        module_relative = os.path.join(module_dir, model_path)
        if os.path.isdir(module_relative):
            return module_relative
        return model_path

    def load_model(self, model_path=None):
        """Load the tokenizer and model and move the model to CUDA or CPU.

        Args:
            model_path: Local directory or identifier passed to
                ``from_pretrained``. Defaults to ``MODEL_SAVE_PATH``.

        Returns:
            dict: On success ``{'success': True, 'device': str,
            'num_labels': int, 'labels': list}``; otherwise
            ``{'error': message}``. On failure the instance state is left
            exactly as it was before the call.
        """
        if model_path is None:
            model_path = MODEL_SAVE_PATH

        if not MODEL_LOADED:
            return {'error': 'transformers library not installed'}

        try:
            resolved_path = self._resolve_model_path(model_path)
            tokenizer = AutoTokenizer.from_pretrained(resolved_path)
            model = AutoModelForSequenceClassification.from_pretrained(resolved_path)
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model.to(device)
            model.eval()  # disable dropout etc. for deterministic inference
        except Exception as e:
            return {'error': f'Failed to load model: {e}'}

        # Commit only after every step succeeded, so a failed (re)load never
        # leaves a half-initialised classifier behind.
        # NOTE(review): the model's own num_labels is not checked against
        # EMOTION_LABELS — confirm the saved model has exactly these outputs.
        self.tokenizer = tokenizer
        self.model = model
        self.device = device
        self.loaded = True

        return {
            'success': True,
            'device': str(device),
            'num_labels': len(EMOTION_LABELS),
            # copy so callers cannot mutate the module-level label list
            'labels': list(EMOTION_LABELS)
        }

    def predict(self, text, max_length=512):
        """Predict the emotion of a single text.

        Args:
            text: Input string.
            max_length: Token limit; longer inputs are truncated.

        Returns:
            dict: ``text`` (first 100 chars, with ``'...'`` appended if
            longer), ``predicted_label``, ``predicted_id``, ``confidence``
            (softmax probability of the prediction, rounded to 4 places) and
            ``all_probabilities`` (one ``{'label', 'probability'}`` entry per
            label). On failure ``{'error': message}``.
        """
        if not self.loaded:
            result = self.load_model()
            if 'error' in result:
                return result

        try:
            # Tokenize
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length
            )

            # Move to device
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Inference
            with torch.no_grad():
                logits = self.model(**inputs).logits

            probabilities = torch.softmax(logits, dim=-1)
            # .item() requires a single row, i.e. exactly one input text.
            predicted_class = torch.argmax(logits, dim=-1).item()
            confidence = probabilities[0][predicted_class].item()

            # Tensor.tolist() yields plain Python floats directly; no NumPy
            # round-trip is needed.
            all_probs = probabilities[0].cpu().tolist()
            label_probs = [
                {'label': label, 'probability': round(all_probs[i], 4)}
                for i, label in enumerate(EMOTION_LABELS)
            ]

            return {
                'text': (text[:100] + '...') if len(text) > 100 else text,
                'predicted_label': EMOTION_LABELS[predicted_class],
                'predicted_id': predicted_class,
                'confidence': round(confidence, 4),
                'all_probabilities': label_probs
            }
        except Exception as e:
            return {'error': f'Prediction failed: {e}'}

    def predict_batch(self, texts, max_length=512):
        """Predict emotions for several texts, one forward pass per text.

        Returns:
            list: One ``predict`` result per input (errors are per-item
            ``{'error': ...}`` dicts), or a single ``{'error': ...}`` dict if
            the model could not be loaded.
        """
        if not self.loaded:
            result = self.load_model()
            if 'error' in result:
                return result

        return [self.predict(text, max_length) for text in texts]

    def is_loaded(self):
        """Check if model is loaded"""
        return self.loaded


# Global classifier instance (created on first get_classifier() call)
_classifier_instance = None


def get_classifier():
    """Return the shared module-level ``EmotionClassifier``, creating it once."""
    global _classifier_instance
    if _classifier_instance is None:
        _classifier_instance = EmotionClassifier()
    return _classifier_instance