# main.py
import os
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from lime.lime_text import LimeTextExplainer
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# FastAPI application with fully open CORS.
# NOTE(review): allow_origins=["*"] with allow_credentials=True is permissive —
# acceptable for a demo Space, tighten before any real deployment.
app = FastAPI(title="MedGuard API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Fine-tuned model artifacts are expected in ./model; inference runs on CPU.
MODEL_PATH = "./model"
DEVICE = "cpu"
print(f"🔄 Loading Model from {MODEL_PATH}...")
model = None
tokenizer = None
# --- CRITICAL FIX: MATCH TRAINING LABEL MAP ---
# Training Map: {'Not Relevant': 0, 'Partially Relevant': 1, 'Highly Relevant': 2}
# This list MUST follow the index order: [Index 0, Index 1, Index 2]
LABELS = ["Not Relevant", "Partially Relevant", "Highly Relevant"]
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.to(DEVICE)
    model.eval()  # inference mode: disables dropout etc.
    # Validation check (Optional but good)
    if model.config.id2label:
        print(f"ℹ️ Model config labels: {model.config.id2label}")
    # We enforce our manual list because sometimes configs get messed up during saving
    # but you should visually verify if this print matches our LABELS list
    print(f"✅ Model Loaded! Label Mapping: {LABELS}")
except Exception as e:
    # Fallback: load the base checkpoint with a fresh (untrained) 3-way head.
    # NOTE(review): predictions from this fallback are meaningless until
    # fine-tuned — the service stays up, but results should not be trusted.
    print(f"❌ Error loading local model: {e}")
    MODEL_NAME = "csebuetnlp/banglabert"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)
class QueryRequest(BaseModel):
    """Request payload for /predict.

    `genre` and `prompt` are optional context strings that are prepended
    (space-joined) to `text` before tokenization; `text` is required.
    """

    # Optional context; empty strings are skipped during concatenation.
    genre: str = ""
    prompt: str = ""
    # Main text to classify (required).
    text: str
class PredictionResponse(BaseModel):
    """Response payload for /predict.

    `explanation` is a list of (token, weight) pairs produced by LIME,
    or None when no explanation is available.
    """

    label: str
    confidence: float  # percentage in [0, 100], rounded to 2 decimals
    probs: dict  # label -> probability, rounded to 4 decimals
    # FIX: was `explanation: list = None` — a None default on a field whose
    # annotation does not permit None. Optional[list] makes the declared
    # schema match the actual default.
    explanation: Optional[list] = None
def predict_proba_lime(texts):
    """Batch classifier in the shape LIME expects: list[str] -> (N, 3) ndarray of probabilities."""
    encoded = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    ).to(DEVICE)
    with torch.no_grad():
        logits = model(**encoded).logits
    return F.softmax(logits, dim=-1).cpu().numpy()
@app.get("/")
def health_check():
    """Liveness probe: confirms the API process is up and which build is running."""
    return {
        "status": "active",
        "model": "MedGuard v2.3 (Fixed Labels)",
    }
@app.post("/predict", response_model=PredictionResponse)
def predict(request: QueryRequest):
    """Classify the space-joined (genre, prompt, text) input and explain it.

    Returns the predicted label, its confidence (percent), the full
    per-label probability distribution, and a LIME token-attribution list.

    Raises:
        HTTPException 503: model/tokenizer failed to load at startup.
        HTTPException 500: any error during inference or explanation.
    """
    # FIX: was `if not model or not tokenizer` — truthiness of an nn.Module
    # is not the intended check; test for the None sentinel explicitly.
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # Use simple space concatenation, skipping empty optional fields.
        parts = [part for part in [request.genre, request.prompt, request.text] if part]
        full_input = " ".join(parts)
        print(f"📥 Analyzing: {full_input[:50]}...")

        inputs = tokenizer(full_input, return_tensors="pt", truncation=True, max_length=128).to(DEVICE)
        with torch.no_grad():
            outputs = model(**inputs)
        probs = F.softmax(outputs.logits, dim=-1).cpu().numpy()[0]
        # FIX: np.argmax returns numpy.int64; cast to a plain int so the
        # JSON response and LIME's `labels=` argument get a native Python int.
        pred_idx = int(np.argmax(probs))
        # Defensive: guard against a head-size / LABELS-length mismatch.
        label_str = LABELS[pred_idx] if pred_idx < len(LABELS) else "Unknown"

        # Explain on whitespace tokens; num_samples kept low for latency.
        explainer = LimeTextExplainer(
            class_names=LABELS,
            split_expression=lambda x: x.split(),
        )
        exp = explainer.explain_instance(
            full_input,
            predict_proba_lime,
            num_features=6,
            num_samples=40,
            labels=[pred_idx],
        )
        lime_features = exp.as_list(label=pred_idx)

        return {
            "label": label_str,
            "confidence": round(float(probs[pred_idx]) * 100, 2),
            "probs": {l: round(float(p), 4) for l, p in zip(LABELS, probs)},
            "explanation": lime_features,
        }
    except Exception as e:
        print(f"🔥 Server Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    # Local dev entry point; port 7860 is the Hugging Face Spaces convention.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)