# main.py
"""MedGuard API: text-relevance classifier with LIME word attributions.

Serves a fine-tuned sequence-classification model (BanglaBERT family) over
HTTP. Falls back to the base pretrained checkpoint if no local model exists.
"""
from typing import Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from lime.lime_text import LimeTextExplainer
import numpy as np
import os

app = FastAPI(title="MedGuard API")

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers under the CORS spec; pin concrete origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

MODEL_PATH = "./model"
DEVICE = "cpu"

print(f"🔄 Loading Model from {MODEL_PATH}...")

model = None
tokenizer = None

# --- CRITICAL FIX: MATCH TRAINING LABEL MAP ---
# Training Map: {'Not Relevant': 0, 'Partially Relevant': 1, 'Highly Relevant': 2}
# This list MUST follow the index order: [Index 0, Index 1, Index 2]
LABELS = ["Not Relevant", "Partially Relevant", "Highly Relevant"]

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.to(DEVICE)
    model.eval()

    # Sanity check: the saved config's id2label should agree with LABELS.
    # We enforce our manual list because configs sometimes get corrupted during
    # saving, but visually verify this printed mapping matches LABELS.
    if model.config.id2label:
        print(f"â„šī¸ Model config labels: {model.config.id2label}")

    print(f"✅ Model Loaded! Label Mapping: {LABELS}")
except Exception as e:
    # Fall back to the base checkpoint (randomly initialized 3-way head) so
    # the service still boots; predictions are meaningless until a fine-tuned
    # model is placed at MODEL_PATH.
    print(f"❌ Error loading local model: {e}")
    MODEL_NAME = "csebuetnlp/banglabert"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)


class QueryRequest(BaseModel):
    """Classification request. `genre` and `prompt` are optional context."""
    genre: str = ""
    prompt: str = ""
    text: str


class PredictionResponse(BaseModel):
    """Prediction plus per-class probabilities and LIME word attributions."""
    label: str
    confidence: float
    probs: dict
    # FIX: was `list = None` — a None default contradicts a bare `list` annotation.
    explanation: Optional[list] = None


def predict_proba_lime(texts):
    """Batch probability function in the (n_samples, n_classes) shape LIME expects."""
    inputs = tokenizer(
        texts, return_tensors="pt", padding=True, truncation=True, max_length=128
    ).to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    # Use the F alias consistently (was torch.nn.functional.softmax).
    return F.softmax(outputs.logits, dim=-1).cpu().numpy()


@app.get("/")
def health_check():
    """Liveness probe."""
    return {"status": "active", "model": "MedGuard v2.3 (Fixed Labels)"}


@app.post("/predict", response_model=PredictionResponse)
def predict(request: QueryRequest):
    """Classify the concatenated (genre, prompt, text) and attach a LIME explanation.

    Raises:
        HTTPException 503: model failed to load at startup.
        HTTPException 500: any inference/explanation failure.
    """
    # FIX: identity check instead of truthiness — the intent is "did loading fail".
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # Simple space concatenation, skipping empty fields.
        parts = [part for part in (request.genre, request.prompt, request.text) if part]
        full_input = " ".join(parts)
        print(f"đŸ“Ĩ Analyzing: {full_input[:50]}...")

        inputs = tokenizer(
            full_input, return_tensors="pt", truncation=True, max_length=128
        ).to(DEVICE)
        with torch.no_grad():
            outputs = model(**inputs)

        probs = F.softmax(outputs.logits, dim=-1).cpu().numpy()[0]
        # FIX: cast to plain int — np.int64 can trip LIME label lookup and
        # JSON serialization downstream.
        pred_idx = int(np.argmax(probs))

        # Guard against a head wider than our label list.
        label_str = LABELS[pred_idx] if pred_idx < len(LABELS) else "Unknown"

        explainer = LimeTextExplainer(
            class_names=LABELS,
            split_expression=lambda x: x.split(),
        )
        # num_samples=40 keeps latency low at the cost of attribution stability.
        exp = explainer.explain_instance(
            full_input,
            predict_proba_lime,
            num_features=6,
            num_samples=40,
            labels=[pred_idx],
        )
        lime_features = exp.as_list(label=pred_idx)

        return {
            "label": label_str,
            "confidence": round(float(probs[pred_idx]) * 100, 2),
            "probs": {l: round(float(p), 4) for l, p in zip(LABELS, probs)},
            "explanation": lime_features,
        }
    except Exception as e:
        print(f"đŸ”Ĩ Server Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)