File size: 4,591 Bytes
59d1475 54c2a57 5ed3596 54c2a57 025ec54 54c2a57 5ed3596 54c2a57 5ed3596 54c2a57 5ed3596 54c2a57 5ed3596 59d66b9 54c2a57 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | import os
from typing import List, Dict, Any
from app.core.config import settings
import torch
import numpy as np
from scipy.special import softmax
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
class SentimentService:
    """
    A service for loading the sentiment analysis model and performing predictions.
    """

    def __init__(self) -> None:
        """
        Initialize the service by loading the sentiment analysis model and tokenizer.

        Raises:
            FileNotFoundError: If the local model directory does not exist.
        """
        # Select device (GPU if available, otherwise CPU)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if self.device.type == "cuda":
            print(
                f"GPU found: {torch.cuda.get_device_name(0)}. Loading model onto GPU."
            )
        else:
            print("GPU not found. Loading model onto CPU.")
        # Resolve the model directory relative to this file's parent package.
        BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        MODEL_DIR = os.path.join(
            BASE_DIR, "models", "twitter-roberta-base-sentiment-latest"
        )
        print(MODEL_DIR)
        if not os.path.exists(MODEL_DIR):
            raise FileNotFoundError(f"Model folder not found: {MODEL_DIR}")
        # Load tokenizer, config (for id2label mapping), and model weights.
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
        self.config = AutoConfig.from_pretrained(MODEL_DIR)
        self.model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR).to(
            self.device
        )
        self.model.eval()  # inference mode: disables dropout / trains-only layers
        print("Sentiment model loaded successfully.")

    def _preprocess_text(self, text: str) -> str:
        """
        Replace @user mentions and http links with placeholders.

        Non-string input (e.g. None) yields an empty string so callers can
        filter it out safely.
        """
        if not isinstance(text, str):
            return ""
        tokens = []
        for token in text.split(" "):
            # A lone "@" is kept as-is; only actual mentions (len > 1) are masked.
            token = "@user" if token.startswith("@") and len(token) > 1 else token
            token = "http" if token.startswith("http") else token
            tokens.append(token)
        return " ".join(tokens)

    def predict(self, texts: List[str]) -> List[Dict[str, Any]]:
        """
        Predict sentiment for a batch of texts, splitting into sub-batches
        for efficiency on CPU.

        Always returns exactly one result per input text: empty or invalid
        texts receive a default neutral prediction, so the output aligns
        positionally with ``texts``.
        """
        default_prediction = {"label": "neutral", "score": 1.0}
        # Preprocess all texts
        preprocessed_texts = [self._preprocess_text(text) for text in texts]
        # Keep only non-empty texts and remember their original indices
        non_empty_texts_with_indices = [
            (i, text) for i, text in enumerate(preprocessed_texts) if text.strip()
        ]
        if not non_empty_texts_with_indices:
            # BUGFIX: previously returned [], violating the one-result-per-input
            # contract when every text was empty/blank. Return one independent
            # default prediction per input instead.
            return [dict(default_prediction) for _ in texts]
        indices, texts_to_predict = zip(*non_empty_texts_with_indices)
        # --- Define batch size for CPU ---
        batch_size = settings.INFERENCE_BATCH_SIZE
        predictions: List[Dict[str, Any]] = []
        # --- Process in chunks to bound peak memory ---
        for start in range(0, len(texts_to_predict), batch_size):
            sub_texts = texts_to_predict[start : start + batch_size]
            # Tokenize the sub-batch and move tensors to the model's device.
            encoded_inputs = self.tokenizer(
                list(sub_texts),
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            ).to(self.device)
            # Inference without gradient tracking.
            with torch.no_grad():
                outputs = self.model(**encoded_inputs)
            logits = outputs.logits.detach().cpu().numpy()
            # Release tensors promptly to keep (GPU) memory pressure low.
            del encoded_inputs, outputs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            # Softmax over classes, then map the argmax index to its label.
            probs = softmax(logits, axis=1)
            for prob in probs:
                max_idx = int(np.argmax(prob))
                predictions.append(
                    {
                        "label": self.config.id2label[max_idx],
                        "score": float(prob[max_idx]),
                    }
                )
        # Map predictions back to their original positions
        final_results: List[Dict[str, Any] | None] = [None] * len(texts)
        for original_index, prediction in zip(indices, predictions):
            final_results[original_index] = prediction
        # BUGFIX: fill empty-text slots with independent copies of the default
        # (the original aliased one shared dict into every empty slot, so
        # mutating one result would mutate them all).
        return [
            res if res is not None else dict(default_prediction)
            for res in final_results
        ]
|