Spaces:
Sleeping
Sleeping
| # main.py | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from lime.lime_text import LimeTextExplainer | |
| import numpy as np | |
| import os | |
# FastAPI application object plus permissive CORS so a browser frontend can call us.
app = FastAPI(title="MedGuard API")

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is very
# permissive (Starlette will echo the caller's origin back). Restrict to the real
# frontend origin(s) before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Location of the fine-tuned checkpoint and the device inference runs on.
MODEL_PATH = "./model"
DEVICE = "cpu"

print(f"π Loading Model from {MODEL_PATH}...")

# Populated by the loading block below; remain None only if loading fails outright.
model = None
tokenizer = None

# --- CRITICAL FIX: MATCH TRAINING LABEL MAP ---
# Training map: {'Not Relevant': 0, 'Partially Relevant': 1, 'Highly Relevant': 2}
# This list MUST follow the index order: [Index 0, Index 1, Index 2]
LABELS = ["Not Relevant", "Partially Relevant", "Highly Relevant"]
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.to(DEVICE)
    model.eval()
    # Sanity check: the saved config's label map should visually match LABELS.
    # We enforce our manual list because configs sometimes get mangled on save.
    if model.config.id2label:
        print(f"βΉοΈ Model config labels: {model.config.id2label}")
    print(f"β Model Loaded! Label Mapping: {LABELS}")
except Exception as e:
    print(f"β Error loading local model: {e}")
    # Fallback: base checkpoint with a freshly initialized 3-way head.
    MODEL_NAME = "csebuetnlp/banglabert"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)
    # BUG FIX: the fallback model was never moved to DEVICE nor switched to
    # eval mode (unlike the primary path above), so dropout layers stayed
    # active during inference and gave nondeterministic predictions.
    model.to(DEVICE)
    model.eval()
class QueryRequest(BaseModel):
    """Inbound payload: optional genre/prompt context plus the required text."""

    genre: str = ""
    prompt: str = ""
    text: str
class PredictionResponse(BaseModel):
    """Shape of a prediction response (label, confidence %, per-class probs, LIME tokens)."""

    label: str
    confidence: float
    probs: dict
    # BUG FIX: was `explanation: list = None` — a None default on a
    # non-optional annotation, which pydantic rejects when None is actually
    # supplied. Declare the field explicitly optional instead.
    explanation: list | None = None
def predict_proba_lime(texts):
    """Batch-probability callback handed to LIME.

    Tokenizes `texts` (list of perturbed strings), runs the global classifier,
    and returns an (n_samples, n_classes) numpy array of softmax probabilities.
    """
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=128).to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    # Consistency fix: use the `F` alias imported at the top of the file
    # (as the predict handler does) instead of the fully spelled-out path.
    return F.softmax(outputs.logits, dim=-1).cpu().numpy()
# BUG FIX(review): this handler was defined but never registered with FastAPI —
# without a route decorator it is unreachable dead code. The decorator was most
# likely lost when this file was copied; path "/" assumed — confirm.
@app.get("/")
def health_check():
    """Liveness probe reporting service status and model version string."""
    return {"status": "active", "model": "MedGuard v2.3 (Fixed Labels)"}
# BUG FIX(review): like health_check, this handler had no route decorator and
# was therefore unreachable. POST /predict assumed — confirm against frontend.
@app.post("/predict")
def predict(request: QueryRequest):
    """Classify the request text and attach a LIME token-level explanation.

    Returns a dict matching PredictionResponse. Raises HTTP 503 when the model
    never loaded, HTTP 500 on any inference/explanation failure.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # Join whichever of genre/prompt/text are non-empty with single spaces.
        parts = [part for part in [request.genre, request.prompt, request.text] if part]
        full_input = " ".join(parts)
        print(f"π₯ Analyzing: {full_input[:50]}...")

        inputs = tokenizer(full_input, return_tensors="pt", truncation=True, max_length=128).to(DEVICE)
        with torch.no_grad():
            outputs = model(**inputs)
        probs = F.softmax(outputs.logits, dim=-1).cpu().numpy()[0]

        # BUG FIX: np.argmax returns a numpy integer scalar; cast to a plain
        # int so it indexes and serializes without numpy types leaking out.
        pred_idx = int(np.argmax(probs))
        # Guard against a head with more classes than our label list.
        label_str = LABELS[pred_idx] if pred_idx < len(LABELS) else "Unknown"

        # Explain only the predicted class; num_samples kept low (40) for latency.
        explainer = LimeTextExplainer(
            class_names=LABELS,
            split_expression=lambda x: x.split(),
        )
        exp = explainer.explain_instance(
            full_input,
            predict_proba_lime,
            num_features=6,
            num_samples=40,
            labels=[pred_idx],
        )
        lime_features = exp.as_list(label=pred_idx)

        return {
            "label": label_str,
            "confidence": round(float(probs[pred_idx]) * 100, 2),
            "probs": {l: round(float(p), 4) for l, p in zip(LABELS, probs)},
            "explanation": lime_features,
        }
    except Exception as e:
        print(f"π₯ Server Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    # Local entry point: bind all interfaces on port 7860 (HF Spaces default).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)