# main.py
"""FastAPI service serving a DistilBERT sequence-classification model.

Endpoints:
    GET  /         status message
    POST /predict  sentiment label and confidence for a review
    POST /explain  prediction plus per-word SHAP attributions
    GET  /metrics  contents of the best-metrics, drift and retrain JSON files
    GET  /health   health probe
"""
import json
import os
import re
import sys
from datetime import datetime
from pathlib import Path

import numpy as np
import shap
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

# The parent of this file's directory is added to the import path and made the
# working directory, so the relative "models/..." paths below resolve no matter
# where the server process is started from.
_BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(_BASE_DIR)
os.chdir(_BASE_DIR)

# Config
MODEL_DIR = "models"
BEST_METRICS_PATH = "models/best_metrics.json"
DRIFT_LOG_PATH = "models/drift_log.json"
RETRAIN_LOG_PATH = "models/retrain_log.json"

MAX_LENGTH = 256  # token budget per review; longer inputs are truncated
MIN_IMPACT = 0.01  # attributions at or below this magnitude are not reported
SPECIAL_TOKENS = ("[CLS]", "[SEP]", "[PAD]")
TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"

_HTML_TAG_RE = re.compile(r"<.*?>")
_NON_ALNUM_RE = re.compile(r"[^a-z0-9\s]")

app = FastAPI(
    title="Sentiment ML System",
    description="Production ML system with DistilBERT",
    version="2.0.0",
)

FRONTEND_URL = os.environ.get("FRONTEND_URL")

app.add_middleware(
    CORSMiddleware,
    # When FRONTEND_URL is unset, allow no cross-origin caller rather than
    # handing the middleware a list containing None.
    allow_origins=[FRONTEND_URL] if FRONTEND_URL else [],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load model
print("Loading DistilBERT model...")
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_DIR)
model = DistilBertForSequenceClassification.from_pretrained(MODEL_DIR)
model.eval()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"✓ DistilBERT loaded on {device}")


class ReviewRequest(BaseModel):
    """Request body for /predict and /explain."""

    review: str


class PredictionResponse(BaseModel):
    """Response body of /predict."""

    sentiment: str
    confidence: float
    label: int
    timestamp: str


class ExplanationResponse(BaseModel):
    """Response body of /explain.

    `explanation` is a list of {"word": str, "impact": float} dicts.
    """

    sentiment: str
    confidence: float
    label: int
    explanation: list
    timestamp: str


def preprocess_text(text):
    """Lowercase `text`, drop <...> tags and every character that is not
    a-z, 0-9 or whitespace, then strip surrounding whitespace."""
    text = _HTML_TAG_RE.sub("", text.lower())
    text = _NON_ALNUM_RE.sub("", text)
    return text.strip()


def _timestamp():
    """Return the current local time formatted for API responses."""
    return datetime.now().strftime(TIMESTAMP_FORMAT)


def _predict_proba(texts):
    """Return softmax class probabilities as an array with one row per text.

    The texts are fed to the tokenizer unchanged. This function is also the
    callable given to SHAP, whose masked strings contain the tokenizer's mask
    token; running preprocess_text on them would strip the brackets from
    "[MASK]" and turn it into the ordinary word "mask".
    """
    batch = [str(text) for text in texts]
    if not batch:
        return np.empty((0, model.config.num_labels))
    encoded = tokenizer(
        batch,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH,
        # Pad to the longest item only; padded positions are excluded by the
        # attention mask, so padding to MAX_LENGTH adds nothing but compute.
        padding=True,
    )
    encoded = {key: value.to(device) for key, value in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits
    return torch.softmax(logits, dim=-1).cpu().numpy()


def _classify(review):
    """Return (sentiment, confidence, label) for one preprocessed review."""
    probabilities = _predict_proba([review])[0]
    label = int(np.argmax(probabilities))
    confidence = float(probabilities[label])
    sentiment = "Positive" if label == 1 else "Negative"
    return sentiment, confidence, label


def _load_json(path, default):
    """Return the parsed JSON at `path`, or `default` if the file is missing.

    Malformed JSON is not swallowed; the decode error propagates.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return default


def _merge_wordpieces(tokens, impacts):
    """Fold "##" continuation pieces into whole words, summing their impacts.

    Blank tokens are skipped: SHAP reports the leading/trailing special
    tokens as empty strings. Returns a list of {"word", "impact"} dicts with
    impacts rounded to 4 decimals.
    """
    merged = []  # [word, accumulated_impact] pairs, in reading order
    for token, impact in zip(tokens, impacts):
        token = str(token).strip()
        if not token:
            continue
        if token.startswith("##"):
            if merged:
                merged[-1][0] += token[2:]
                merged[-1][1] += float(impact)
            else:
                merged.append([token[2:], float(impact)])
        else:
            merged.append([token, float(impact)])
    return [{"word": word, "impact": round(total, 4)} for word, total in merged]


@app.get("/")
def root():
    """Report that the service is running."""
    return {"status": "running", "message": "Sentiment ML System - DistilBERT"}


@app.post("/predict", response_model=PredictionResponse)
def predict(request: ReviewRequest):
    """Classify a review as Positive (label 1) or Negative.

    Raises HTTPException 400 for a blank review and 500 if inference fails.
    """
    if not request.review.strip():
        raise HTTPException(status_code=400, detail="Review text cannot be empty")

    try:
        sentiment, confidence, label = _classify(preprocess_text(request.review))
        return PredictionResponse(
            sentiment=sentiment,
            confidence=round(confidence, 4),
            label=label,
            timestamp=_timestamp(),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/metrics")
def get_metrics():
    """Return the stored best-model metrics, drift log and retrain log.

    A missing file yields None for "best_model" and [] for the two logs.
    """
    return {
        "best_model": _load_json(BEST_METRICS_PATH, None),
        "drift_log": _load_json(DRIFT_LOG_PATH, []),
        "retrain_log": _load_json(RETRAIN_LOG_PATH, []),
    }


@app.get("/health")
def health():
    """Health probe with the model name and the current timestamp."""
    return {
        "status": "healthy",
        "model": "DistilBERT",
        "timestamp": _timestamp(),
    }


@app.post("/explain", response_model=ExplanationResponse)
def explain(request: ReviewRequest):
    """Classify a review and attribute the predicted class to its words.

    Raises HTTPException 400 for a blank review and 500 if inference or the
    SHAP computation fails.
    """
    if not request.review.strip():
        raise HTTPException(status_code=400, detail="Review text cannot be empty")

    try:
        review = preprocess_text(request.review)

        # Get prediction first. No offset mapping is requested: it was never
        # used, and the Python (non-fast) tokenizer loaded above raises
        # NotImplementedError when asked for it.
        sentiment, confidence, label = _classify(review)

        # Built per request on purpose: the explainer keeps per-call state and
        # sync endpoints run in a thread pool, so one shared instance could be
        # used by several requests at the same time.
        explainer = shap.Explainer(_predict_proba, tokenizer)
        shap_values = explainer([review])

        # Take the tokens from SHAP itself so they line up one-to-one with the
        # value rows; tokenizer.tokenize() omits the special-token positions
        # that SHAP includes, which shifts every attribution by one.
        tokens = shap_values.data[0]
        token_impacts = shap_values.values[0][:, label]

        # Filter out special tokens and very low impacts
        word_impacts = [
            entry
            for entry in _merge_wordpieces(tokens, token_impacts)
            if entry["word"] not in SPECIAL_TOKENS
            and abs(entry["impact"]) > MIN_IMPACT
        ]

        return ExplanationResponse(
            sentiment=sentiment,
            confidence=round(confidence, 4),
            label=label,
            explanation=word_impacts,
            timestamp=_timestamp(),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e