eumora-api / backend /src /model.py
VivDubs's picture
refactor: move backend files into backend/ directory
9eb5faa
Raw
History Blame Contribute Delete
2.35 kB
"""Lyrical emotion classification model."""
import torch
from transformers import AutoModelForSequenceClassification, AutoConfig
from pathlib import Path
from .config import config
class LyricEmotionClassifier:
    """Sequence-classification wrapper used to tag song lyrics with an emotion.

    The original author describes this as RoBERTa-based; the architecture that
    is actually loaded is whatever ``model_name`` resolves to through
    ``AutoModelForSequenceClassification``.

    The instance starts without a model. Call :meth:`load_pretrained` or
    :meth:`load_finetuned` before :meth:`save` or :meth:`predict`.
    """

    def __init__(self, model_name: str = None, num_labels: int = None):
        """Record the model settings and pick the compute device.

        Args:
            model_name: Model identifier handed to ``from_pretrained``. Falls
                back to ``config.model_name`` when omitted (or falsy).
            num_labels: Number of output classes. Falls back to
                ``config.num_labels`` when omitted (or falsy).
        """
        self.model_name = model_name or config.model_name
        self.num_labels = num_labels or config.num_labels
        # Use the GPU when one is visible to torch, otherwise stay on CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Populated by load_pretrained() / load_finetuned().
        self.model = None

    def _require_model(self):
        """Return the loaded model, or raise if no model has been loaded yet.

        Raises:
            RuntimeError: If neither ``load_pretrained`` nor ``load_finetuned``
                has been called on this instance.
        """
        if self.model is None:
            raise RuntimeError(
                "Model is not loaded; call load_pretrained() or "
                "load_finetuned() first."
            )
        return self.model

    def load_pretrained(self) -> "LyricEmotionClassifier":
        """Load pretrained model for fine-tuning.

        Builds a model config carrying ``num_labels`` and the id/label
        mappings derived from ``config.label_names``, instantiates the model
        from ``self.model_name`` and moves it to ``self.device``.

        Returns:
            This instance, so calls can be chained.
        """
        label_names = config.label_names
        model_config = AutoConfig.from_pretrained(
            self.model_name,
            num_labels=self.num_labels,
            id2label={i: label for i, label in enumerate(label_names)},
            label2id={label: i for i, label in enumerate(label_names)},
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            config=model_config,
        )
        self.model.to(self.device)
        return self

    def load_finetuned(self, path: Path) -> "LyricEmotionClassifier":
        """Load a fine-tuned model from disk.

        Args:
            path: Location passed straight to ``from_pretrained``.

        Returns:
            This instance, so calls can be chained.
        """
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        self.model.to(self.device)
        return self

    def save(self, path: Path):
        """Save model to disk.

        Args:
            path: Target directory (``Path`` or ``str``); created, including
                parents, if it does not exist.

        Raises:
            RuntimeError: If no model has been loaded.
        """
        # Check first so a failed save does not leave an empty directory behind.
        model = self._require_model()
        target = Path(path)
        target.mkdir(parents=True, exist_ok=True)
        model.save_pretrained(target)

    def predict(self, inputs: dict) -> dict:
        """Run inference on tokenized inputs.

        Args:
            inputs: Mapping of keyword name to tensor; every value is moved to
                ``self.device`` and the mapping is unpacked into the model call.

        Returns:
            A dict with:
                ``"predictions"``: numpy array of argmax class indices,
                ``"probabilities"``: numpy array of softmax probabilities,
                ``"labels"``: list of ``config.label_names`` entries, one per
                prediction.

        Raises:
            RuntimeError: If no model has been loaded.
        """
        model = self._require_model()
        model.eval()

        # Move inputs to device
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.softmax(outputs.logits, dim=-1)
            predictions = torch.argmax(probs, dim=-1)

        # Transfer to host once and reuse for both the array and the labels.
        predicted_ids = predictions.cpu().numpy()
        return {
            "predictions": predicted_ids,
            "probabilities": probs.cpu().numpy(),
            "labels": [config.label_names[p] for p in predicted_ids.tolist()],
        }
def get_model() -> LyricEmotionClassifier:
    """Create a classifier with default settings and load its pretrained model.

    Returns:
        Whatever ``LyricEmotionClassifier.load_pretrained`` returns for the
        freshly constructed instance.
    """
    classifier = LyricEmotionClassifier()
    loaded = classifier.load_pretrained()
    return loaded