Gokulavanan
Deploy toxic comment classifier
030f249
Raw
History Blame Contribute Delete
5.68 kB
import torch
import os
import joblib
import logging
import sys
from typing import List, Union
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Add project root
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.data_loader import clean_text
# =============================
# CONFIG
# =============================
MODEL_DIR = "model_output"
MAX_LEN = 192
THRESHOLDS = {
"severe_toxic": 0.85,
"toxic": 0.70
}
# =============================
# LOGGING
# =============================
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
# =============================
# PREDICTOR CLASS
# =============================
class ToxicityPredictor:
def __init__(self, model_dir: str = MODEL_DIR):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Inference device: {self.device}")
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
self.model.to(self.device)
self.model.eval()
le_path = os.path.join(model_dir, "label_encoder.joblib")
self.label_encoder = joblib.load(le_path)
self.labels = list(self.label_encoder.classes_)
logger.info("Model, tokenizer, and label encoder loaded successfully.")
except Exception as e:
logger.error(f"Failed to load model resources: {e}")
raise
# =============================
# CORE PREDICTION LOGIC
# =============================
def _apply_thresholds(self, scores):
top_idx = scores.argmax()
top_label = self.labels[top_idx]
confidence = float(scores[top_idx])
final_label = "review"
if top_label == "severe_toxic" and confidence >= THRESHOLDS["severe_toxic"]:
final_label = "severe_toxic"
elif top_label == "toxic" and confidence >= THRESHOLDS["toxic"]:
final_label = "toxic"
elif top_label == "non_toxic":
final_label = "non_toxic"
return final_label, top_label, confidence
def predict(self, text: str) -> dict:
if not isinstance(text, str) or not text.strip():
raise ValueError("Input text must be a non-empty string.")
cleaned_text = clean_text(text)
inputs = self.tokenizer(
cleaned_text,
return_tensors="pt",
truncation=True,
padding=True,
max_length=MAX_LEN
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
scores = probs.cpu().numpy()[0]
final_label, raw_label, confidence = self._apply_thresholds(scores)
score_dict = {
label: float(score)
for label, score in zip(self.labels, scores)
}
return {
"text": text,
"cleaned_text": cleaned_text,
"prediction": final_label,
"raw_prediction": raw_label,
"confidence": confidence,
"scores": score_dict
}
# =============================
# BATCH PREDICTION
# =============================
def predict_batch(self, texts: List[str]) -> List[dict]:
if not isinstance(texts, list) or not texts:
raise ValueError("Input must be a non-empty list of strings.")
cleaned_texts = [clean_text(t) for t in texts]
inputs = self.tokenizer(
cleaned_texts,
return_tensors="pt",
truncation=True,
padding=True,
max_length=MAX_LEN
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
results = []
for i, scores in enumerate(probs.cpu().numpy()):
final_label, raw_label, confidence = self._apply_thresholds(scores)
score_dict = {
label: float(score)
for label, score in zip(self.labels, scores)
}
results.append({
"text": texts[i],
"cleaned_text": cleaned_texts[i],
"prediction": final_label,
"raw_prediction": raw_label,
"confidence": confidence,
"scores": score_dict
})
return results
# =============================
# TEST BLOCK
# =============================
if __name__ == "__main__":
predictor = ToxicityPredictor()
test_texts = [
"Unakku arive illaya?",
"You are a very good person.",
"Nee oru loosu",
"Enda ipdi panra?",
"This is a wonderful place.",
"k1ll urself",
"nee saavu da 😂",
"People like you ruin everything.",
"I strongly disagree with your opinion."
]
print("\n" + "=" * 60)
print("Production Inference Test")
print("=" * 60)
for result in predictor.predict_batch(test_texts):
print(f"\nInput: {result['text']}")
print(f"Final Prediction: {result['prediction']}")
print(f"Raw Prediction: {result['raw_prediction']}")
print(f"Confidence: {result['confidence']:.4f}")
print(f"Scores: {result['scores']}")
print("\n" + "=" * 60)