| import torch |
| import os |
| import joblib |
| import logging |
| import sys |
| from typing import List, Union |
|
|
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
| |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| from src.data_loader import clean_text |
|
|
|
|
| |
| |
| |
|
|
# Default directory passed to ``from_pretrained`` and searched for
# ``label_encoder.joblib`` when no ``model_dir`` is given to the predictor.
MODEL_DIR = "model_output"

# Upper bound on tokens per input; longer inputs are truncated by the tokenizer.
MAX_LEN = 192

# Minimum softmax confidence required before a toxic verdict is accepted.
# Predictions for these labels that fall below the bar are reported as "review".
THRESHOLDS = dict(severe_toxic=0.85, toxic=0.70)
|
|
| |
| |
| |
|
|
# Configure root logging once at import time: INFO and above, each record
# rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)

# Module-scoped logger used by everything in this file.
logger = logging.getLogger(__name__)
|
|
|
|
| |
| |
| |
|
|
class ToxicityPredictor:
    """Run toxicity classification with a saved sequence-classification model.

    The predictor loads a tokenizer, a model and a fitted label encoder from
    one directory, then turns raw text into a thresholded verdict.

    Attributes:
        device: ``cuda`` when available, otherwise ``cpu``.
        tokenizer: Tokenizer loaded from ``model_dir``.
        model: Classification model loaded from ``model_dir``, in eval mode.
        label_encoder: Object loaded from ``label_encoder.joblib``.
        labels: ``label_encoder.classes_`` as a list; index ``i`` names the
            model's ``i``-th output score.
    """

    def __init__(self, model_dir: str = MODEL_DIR):
        """Load the tokenizer, model and label encoder from ``model_dir``.

        Args:
            model_dir: Directory containing the saved model/tokenizer files
                and ``label_encoder.joblib``.

        Raises:
            ValueError: If the number of encoder classes differs from the
                number of model outputs.
            Exception: Any loading error is logged and re-raised unchanged.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Inference device: {self.device}")

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
            self.model.to(self.device)
            self.model.eval()

            # NOTE(security): joblib.load unpickles the file and can execute
            # arbitrary code; only point model_dir at trusted artifacts.
            le_path = os.path.join(model_dir, "label_encoder.joblib")
            self.label_encoder = joblib.load(le_path)
            self.labels = list(self.label_encoder.classes_)

            # Fail fast on mismatched artifacts: otherwise zip() silently
            # truncates the score dict and argmax may index past self.labels.
            num_outputs = self.model.config.num_labels
            if num_outputs != len(self.labels):
                raise ValueError(
                    f"Label encoder has {len(self.labels)} classes but the "
                    f"model produces {num_outputs} outputs."
                )

            logger.info("Model, tokenizer, and label encoder loaded successfully.")

        except Exception as e:
            logger.error(f"Failed to load model resources: {e}")
            raise

    def _apply_thresholds(self, scores):
        """Convert one row of class probabilities into a final verdict.

        Args:
            scores: 1-D array of per-class probabilities, ordered like
                ``self.labels``.

        Returns:
            ``(final_label, top_label, confidence)``. ``top_label`` is the
            arg-max class and ``confidence`` its probability. ``final_label``
            equals ``top_label`` when it is ``"non_toxic"``, or when it is
            ``"toxic"``/``"severe_toxic"`` and meets its ``THRESHOLDS`` entry;
            in every other case it is ``"review"``.
        """
        top_idx = scores.argmax()
        top_label = self.labels[top_idx]
        confidence = float(scores[top_idx])

        # Anything that is not confidently classified falls back to review.
        final_label = "review"

        if top_label == "severe_toxic" and confidence >= THRESHOLDS["severe_toxic"]:
            final_label = "severe_toxic"
        elif top_label == "toxic" and confidence >= THRESHOLDS["toxic"]:
            final_label = "toxic"
        elif top_label == "non_toxic":
            final_label = "non_toxic"

        return final_label, top_label, confidence

    def _score(self, cleaned_texts):
        """Return softmax probabilities for already-cleaned texts.

        Args:
            cleaned_texts: List of strings to tokenize as a single batch.

        Returns:
            NumPy array with one row of class probabilities per input text.
        """
        inputs = self.tokenizer(
            cleaned_texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=MAX_LEN,
        )
        # Tensors must live on the same device as the model.
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)

        return probs.cpu().numpy()

    def _build_result(self, text, cleaned_text, scores) -> dict:
        """Assemble the result dict for one text from its probability row."""
        final_label, raw_label, confidence = self._apply_thresholds(scores)
        return {
            "text": text,
            "cleaned_text": cleaned_text,
            "prediction": final_label,
            "raw_prediction": raw_label,
            "confidence": confidence,
            "scores": {
                label: float(score) for label, score in zip(self.labels, scores)
            },
        }

    def predict(self, text: str) -> dict:
        """Classify a single text.

        Args:
            text: Non-empty (after stripping whitespace) input string.

        Returns:
            Dict with keys ``text``, ``cleaned_text``, ``prediction``
            (thresholded label), ``raw_prediction`` (arg-max label),
            ``confidence`` and ``scores`` (label -> probability).

        Raises:
            ValueError: If ``text`` is not a string or is blank.
        """
        if not isinstance(text, str) or not text.strip():
            raise ValueError("Input text must be a non-empty string.")

        # clean_text comes from src.data_loader; its exact normalisation is
        # not visible in this file.
        cleaned_text = clean_text(text)
        scores = self._score([cleaned_text])[0]
        return self._build_result(text, cleaned_text, scores)

    def predict_batch(
        self, texts: List[str], batch_size: Union[int, None] = None
    ) -> List[dict]:
        """Classify several texts, returning results in input order.

        Args:
            texts: Non-empty list of strings.
            batch_size: Maximum number of texts per forward pass. ``None``
                (the default) sends the whole list through the model at once;
                set a positive integer to bound memory use on large inputs.

        Returns:
            One result dict per input text, shaped like ``predict``'s output.

        Raises:
            ValueError: If ``texts`` is not a non-empty list, if any item is
                not a string, or if ``batch_size`` is not a positive integer.
        """
        if not isinstance(texts, list) or not texts:
            raise ValueError("Input must be a non-empty list of strings.")

        # Reject bad items up front instead of failing somewhere inside
        # clean_text or the tokenizer with an unrelated error.
        for position, item in enumerate(texts):
            if not isinstance(item, str):
                raise ValueError(
                    "Input must be a non-empty list of strings; "
                    f"item {position} is {type(item).__name__}."
                )

        if batch_size is not None and (
            not isinstance(batch_size, int)
            or isinstance(batch_size, bool)
            or batch_size < 1
        ):
            raise ValueError("batch_size must be a positive integer or None.")

        cleaned_texts = [clean_text(t) for t in texts]
        step = len(cleaned_texts) if batch_size is None else batch_size

        results = []
        for start in range(0, len(cleaned_texts), step):
            stop = start + step
            chunk_scores = self._score(cleaned_texts[start:stop])
            results.extend(
                self._build_result(raw, cleaned, scores)
                for raw, cleaned, scores in zip(
                    texts[start:stop], cleaned_texts[start:stop], chunk_scores
                )
            )

        return results
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Manual smoke test: load the default model and print the verdict for a
    # handful of sample inputs.
    predictor = ToxicityPredictor()
    banner = "=" * 60

    samples = [
        "Unakku arive illaya?",
        "You are a very good person.",
        "Nee oru loosu",
        "Enda ipdi panra?",
        "This is a wonderful place.",
        "k1ll urself",
        "nee saavu da 😂",
        "People like you ruin everything.",
        "I strongly disagree with your opinion.",
    ]

    print(f"\n{banner}")
    print("Production Inference Test")
    print(banner)

    # One report per sample, each preceded by a blank line.
    for outcome in predictor.predict_batch(samples):
        report = (
            f"\nInput: {outcome['text']}",
            f"Final Prediction: {outcome['prediction']}",
            f"Raw Prediction: {outcome['raw_prediction']}",
            f"Confidence: {outcome['confidence']:.4f}",
            f"Scores: {outcome['scores']}",
        )
        print("\n".join(report))

    print(f"\n{banner}")