emotion-classifier/app/ml/predictor.py
AfroLogicInsect's picture
initial migration from Render: FastAPI + Gradio, port 7860
811fc57
Raw
History Blame Contribute Delete
4.63 kB
import os
import json
import torch
import numpy as np
from typing import Dict, List, Optional
from transformers import AutoModelForSequenceClassification, AutoTokenizer
class EmotionPredictor:
    """Emotion classifier wrapping a sequence-classification model and its tokenizer.

    The model is loaded once, at construction time, either from the
    HuggingFace Hub (when a model ID is available) or from a local directory.
    Loading is best-effort: a failure is printed rather than raised, so
    callers should check :meth:`is_model_loaded` before predicting.
    """

    def __init__(self, model_path: Optional[str] = None, model_id: Optional[str] = None) -> None:
        """
        Initialize the emotion predictor with a pre-trained model

        Args:
            model_path: Path to the saved model directory. If None, will use the
                environment variable MODEL_PATH or the default "./model_export"
            model_id: HuggingFace Hub model ID (e.g., "your-username/emotion-classifier").
                If None, the environment variable MODEL_ID is consulted. When a
                model ID is available it takes precedence over the local path.
        """
        self.model_path = model_path or os.environ.get("MODEL_PATH", "./model_export")
        self.model_id = model_id or os.environ.get("MODEL_ID")
        self.model = None
        self.tokenizer = None
        self.label2id = None
        self.id2label = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load the model
        self._load_model()

    @staticmethod
    def _read_label_mappings(model_dir: str) -> Optional[tuple]:
        """Read ``label_mappings.json`` from *model_dir*, if it exists.

        Returns:
            ``(id2label, label2id)`` with ``id2label`` keyed by ``int``, or
            ``None`` when the directory has no mappings file.

        Raises:
            KeyError / ValueError / OSError: if the file exists but is
                unreadable, is not valid JSON, or lacks the expected keys.
        """
        mappings_path = os.path.join(model_dir, 'label_mappings.json')
        if not os.path.exists(mappings_path):
            return None

        # Explicit encoding so the result does not depend on the platform locale.
        with open(mappings_path, 'r', encoding="utf-8") as f:
            mappings = json.load(f)

        label2id = mappings['label2id']
        # JSON object keys are always strings; class ids must be ints so they
        # can be looked up with the argmax index in predict().
        id2label = {int(k): v for k, v in mappings['id2label'].items()}
        return id2label, label2id

    def _load_model(self) -> None:
        """Load the model and tokenizer from local directory or HuggingFace Hub

        Everything is built in local variables and only committed to ``self``
        once every step has succeeded. On any failure the error is printed and
        the instance attributes are left exactly as they were, so
        ``is_model_loaded()`` can never report True for a half-loaded
        predictor (e.g. weights loaded but label mappings missing).
        """
        try:
            if self.model_id:
                # If model_id is provided, load from HuggingFace Hub
                source = self.model_id
                print(f"Loading model from HuggingFace Hub: {self.model_id}")
            else:
                # Load from local path
                source = self.model_path
                print(f"Loading model from local path: {self.model_path}")

            model = AutoModelForSequenceClassification.from_pretrained(source)
            tokenizer = AutoTokenizer.from_pretrained(source)

            # Only a local export may ship its own label_mappings.json; Hub
            # models (and local exports without the file) use the model config.
            mappings = None if self.model_id else self._read_label_mappings(source)
            if mappings is not None:
                id2label, label2id = mappings
            else:
                id2label = model.config.id2label
                label2id = model.config.label2id

            # Move model to appropriate device and switch to inference mode
            model.to(self.device)
            model.eval()
        except Exception as e:
            # Deliberate best-effort boundary: callers poll is_model_loaded()
            # instead of handling an exception from the constructor.
            print(f"Error loading model: {e}")
            return

        self.model = model
        self.tokenizer = tokenizer
        self.id2label = id2label
        self.label2id = label2id

    def is_model_loaded(self) -> bool:
        """Check if the model is loaded"""
        return self.model is not None and self.tokenizer is not None

    def get_labels(self) -> List[str]:
        """Get the list of emotion labels

        Raises:
            ValueError: if the model is not loaded.
        """
        if not self.is_model_loaded():
            raise ValueError("Model is not loaded")
        return list(self.label2id.keys())

    def predict(self, text: str, max_length: int = 512) -> Dict:
        """
        Predict the emotion of the input text

        Args:
            text: Input text to classify
            max_length: Maximum number of tokens passed to the tokenizer;
                longer inputs are truncated. Defaults to 512.

        Returns:
            Dictionary with keys ``"emotion"`` (the top label), ``"confidence"``
            (its softmax probability) and ``"all_emotions"`` (label -> probability
            for every class).

        Raises:
            ValueError: if the model is not loaded.
        """
        if not self.is_model_loaded():
            raise ValueError("Model is not loaded")

        # Tokenize the input text
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        ).to(self.device)

        # Make prediction; no gradients are needed for inference
        with torch.no_grad():
            logits = self.model(**inputs).logits
            # Softmax over the class dimension, then take the first (only) row
            probabilities = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()[0]

        # Get the predicted class and confidence
        predicted_class_id = int(np.argmax(probabilities))
        predicted_emotion = self.id2label[predicted_class_id]
        confidence = float(probabilities[predicted_class_id])

        # Create the response with all emotion probabilities
        all_emotions = {
            self.id2label[i]: float(prob)
            for i, prob in enumerate(probabilities)
        }

        return {
            "emotion": predicted_emotion,
            "confidence": confidence,
            "all_emotions": all_emotions,
        }