# Uploaded by StringJammer: "Upload folder using huggingface_hub" (commit 22bad87, verified)
# -*- coding: utf-8 -*-
"""
Inference Module - Model Prediction
"""
import os
import torch
# Default checkpoint directory (relative path) used when no path is given
MODEL_SAVE_PATH = '../best_model'

# Emotion label names; position i in this list names class id i
EMOTION_LABELS = [
    "Neutral",
    "Anxiety/Fear",
    "Anger/Frustration",
    "Sadness/Helplessness",
    "Confusion/Doubt",
    "Gratitude/Relief",
]
# transformers is optional at import time: MODEL_LOADED records whether the
# import succeeded so callers can report a clean error instead of crashing.
try:
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
except ImportError:
    MODEL_LOADED = False
else:
    MODEL_LOADED = True
class EmotionClassifier:
    """Emotion classification inference wrapper.

    Holds a tokenizer and a sequence-classification model, loads them lazily
    on the first prediction, and reports every failure as a dict with an
    ``'error'`` key instead of raising.
    """

    def __init__(self, labels=None):
        """Create an unloaded classifier.

        Args:
            labels: Optional sequence of label names, ordered by class id.
                Defaults to the module-level ``EMOTION_LABELS``.
        """
        self.tokenizer = None
        self.model = None
        self.device = None
        self.loaded = False
        # Private copy so later mutation of the caller's list cannot shift
        # the class-id -> label mapping under a loaded model.
        self.labels = list(EMOTION_LABELS if labels is None else labels)

    def load_model(self, model_path=None):
        """Load the tokenizer and model and move the model to the device.

        Args:
            model_path: Checkpoint location passed to ``from_pretrained``.
                Defaults to ``MODEL_SAVE_PATH``.

        Returns:
            On success, a dict with ``'success'``, ``'device'``,
            ``'num_labels'`` and ``'labels'``. On failure, a dict with a
            single ``'error'`` message.
        """
        if model_path is None:
            model_path = MODEL_SAVE_PATH
        if not MODEL_LOADED:
            return {'error': 'transformers library not installed'}
        try:
            # Build everything in locals first: a failure half-way through
            # must not leave a new tokenizer paired with an old/absent model.
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModelForSequenceClassification.from_pretrained(model_path)
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model.to(device)
            model.eval()
        except Exception as e:
            return {'error': f'Failed to load model: {str(e)}'}

        # Commit only once every step has succeeded.
        self.tokenizer = tokenizer
        self.model = model
        self.device = device
        self.loaded = True
        return {
            'success': True,
            'device': str(self.device),
            'num_labels': len(self.labels),
            'labels': list(self.labels),
        }

    def predict(self, text, max_length=512):
        """Predict the emotion of a single text.

        Loads the model from the default path first if it is not loaded yet.

        Args:
            text: The text to classify.
            max_length: Maximum token length; longer inputs are truncated.

        Returns:
            A dict with ``'text'`` (cut to 100 characters plus ``'...'`` when
            longer), ``'predicted_label'``, ``'predicted_id'``,
            ``'confidence'`` and ``'all_probabilities'``; or a dict with an
            ``'error'`` message if loading or prediction fails.
        """
        if not self.loaded:
            result = self.load_model()
            if 'error' in result:
                return result
        try:
            # Tokenize
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            )
            # Move to device
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            # Inference
            with torch.no_grad():
                logits = self.model(**inputs).logits
            probabilities = torch.softmax(logits, dim=-1)
            # .item() raises if more than one row came back, so a multi-text
            # input is reported as an error rather than silently truncated.
            predicted_class = torch.argmax(logits, dim=-1).item()
            # Tensor.tolist() converts directly, without a NumPy round-trip.
            all_probs = probabilities[0].cpu().tolist()

            # A model/label-list mismatch would otherwise surface as an
            # opaque IndexError or silently drop some probabilities.
            if len(all_probs) != len(self.labels):
                return {
                    'error': (
                        f'Prediction failed: model returned {len(all_probs)} '
                        f'scores but {len(self.labels)} labels are configured'
                    )
                }

            confidence = all_probs[predicted_class]
            # Build result
            label_probs = [
                {'label': label, 'probability': round(prob, 4)}
                for label, prob in zip(self.labels, all_probs)
            ]
            preview = text[:100] + '...' if len(text) > 100 else text
            return {
                'text': preview,
                'predicted_label': self.labels[predicted_class],
                'predicted_id': predicted_class,
                'confidence': round(confidence, 4),
                'all_probabilities': label_probs,
            }
        except Exception as e:
            return {'error': f'Prediction failed: {str(e)}'}

    def predict_batch(self, texts, max_length=512):
        """Predict emotions for several texts, one ``predict`` call each.

        Args:
            texts: Iterable of texts. A bare string is treated as a batch of
                one text rather than being iterated character by character.
            max_length: Maximum token length forwarded to ``predict``.

        Returns:
            A list with one ``predict`` result per text (individual entries
            may be error dicts). If the model cannot be loaded, a single
            error dict is returned instead of a list.
        """
        if not self.loaded:
            result = self.load_model()
            if 'error' in result:
                return result
        if isinstance(texts, str):
            texts = [texts]
        return [self.predict(text, max_length) for text in texts]

    def is_loaded(self):
        """Check if model is loaded"""
        return self.loaded
# Module-wide shared classifier, created on first request
_classifier_instance = None


def get_classifier():
    """Return the shared EmotionClassifier, creating it on the first call."""
    global _classifier_instance
    if _classifier_instance is not None:
        return _classifier_instance
    _classifier_instance = EmotionClassifier()
    return _classifier_instance