# medguard-api / main.py
# sumoy47's picture
# Update main.py
# 7e3db74 verified
# main.py
import os
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from lime.lime_text import LimeTextExplainer
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# FastAPI application with fully permissive CORS so browser frontends
# hosted on any origin can reach the API.
app = FastAPI(title="MedGuard API")

_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Where the fine-tuned checkpoint lives on disk, and the inference device.
MODEL_PATH = "./model"
DEVICE = "cpu"  # CPU-only host; no GPU assumed
print(f"πŸ”„ Loading Model from {MODEL_PATH}...")
# Populated by the try/except below; /predict returns 503 while these are None.
model = None
tokenizer = None
# --- CRITICAL FIX: MATCH TRAINING LABEL MAP ---
# Training Map: {'Not Relevant': 0, 'Partially Relevant': 1, 'Highly Relevant': 2}
# This list MUST follow the index order: [Index 0, Index 1, Index 2]
LABELS = ["Not Relevant", "Partially Relevant", "Highly Relevant"]
try:
    # Preferred path: load the locally fine-tuned model and tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.to(DEVICE)
    model.eval()
    # Validation check (Optional but good)
    if model.config.id2label:
        print(f"ℹ️ Model config labels: {model.config.id2label}")
    # We enforce our manual list because sometimes configs get messed up during saving
    # but you should visually verify if this print matches our LABELS list
    print(f"βœ… Model Loaded! Label Mapping: {LABELS}")
except Exception as e:
    # Fallback: pull the base BanglaBERT checkpoint from the Hub.
    # NOTE(review): this fallback head is randomly initialized (num_labels=3,
    # no fine-tuning), so its predictions are meaningless — confirm intended.
    print(f"❌ Error loading local model: {e}")
    MODEL_NAME = "csebuetnlp/banglabert"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)
class QueryRequest(BaseModel):
    """Request payload for /predict."""
    # Optional context fields; empty strings are skipped during concatenation.
    genre: str = ""
    prompt: str = ""
    # Main passage to classify (required).
    text: str
class PredictionResponse(BaseModel):
    """Response payload for /predict."""
    # Human-readable predicted class (one of LABELS, or "Unknown").
    label: str
    # Confidence of the predicted class as a percentage (0-100).
    confidence: float
    # Per-class probabilities keyed by label name.
    probs: dict
    # LIME token attributions as (token, weight) pairs; None when absent.
    # Fix: the original `explanation: list = None` contradicted its own
    # annotation (None default on a non-optional list field).
    explanation: Optional[list] = None
def predict_proba_lime(texts):
    """Batch probability function LIME calls on its perturbed text samples.

    Returns a (batch, num_classes) numpy array of softmax probabilities.
    """
    encoded = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    ).to(DEVICE)
    with torch.no_grad():
        logits = model(**encoded).logits
    return F.softmax(logits, dim=-1).cpu().numpy()
@app.get("/")
def health_check():
    """Liveness probe: confirms the API process is up and which build runs."""
    status_payload = {"status": "active", "model": "MedGuard v2.3 (Fixed Labels)"}
    return status_payload
@app.post("/predict", response_model=PredictionResponse)
def predict(request: QueryRequest):
    """Classify a (genre, prompt, text) triple.

    Returns the predicted label, its confidence (percent), per-class
    probabilities, and a LIME token-level explanation.

    Raises:
        HTTPException 503: model/tokenizer failed to load at startup.
        HTTPException 500: any inference or explanation error.
    """
    # Fix: explicit identity check — object truthiness is not a reliable
    # loaded/unloaded signal for model objects.
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # Use simple space concatenation; empty fields are skipped so no
        # stray separators are injected.
        parts = [part for part in [request.genre, request.prompt, request.text] if part]
        full_input = " ".join(parts)
        print(f"πŸ“₯ Analyzing: {full_input[:50]}...")
        inputs = tokenizer(full_input, return_tensors="pt", truncation=True, max_length=128).to(DEVICE)
        with torch.no_grad():
            outputs = model(**inputs)
        probs = F.softmax(outputs.logits, dim=-1).cpu().numpy()[0]
        # Fix: cast numpy int64 -> int so LIME's label argument and JSON
        # serialization see a plain Python int.
        pred_idx = int(np.argmax(probs))
        # Guard against a model head with more classes than our label list.
        label_str = LABELS[pred_idx] if pred_idx < len(LABELS) else "Unknown"
        # Whitespace tokenization and a small sample count keep CPU latency low.
        explainer = LimeTextExplainer(
            class_names=LABELS,
            split_expression=lambda x: x.split()
        )
        exp = explainer.explain_instance(
            full_input,
            predict_proba_lime,
            num_features=6,
            num_samples=40,
            labels=[pred_idx]
        )
        lime_features = exp.as_list(label=pred_idx)
        return {
            "label": label_str,
            "confidence": round(float(probs[pred_idx]) * 100, 2),
            "probs": {l: round(float(p), 4) for l, p in zip(LABELS, probs)},
            "explanation": lime_features
        }
    except Exception as e:
        # Log server-side, then surface the message to the client as a 500.
        print(f"πŸ”₯ Server Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    # Dev/standalone entrypoint; port 7860 is the Hugging Face Spaces default.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)