Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -11,7 +11,6 @@ import os
|
|
| 11 |
|
| 12 |
app = FastAPI(title="MedGuard API")
|
| 13 |
|
| 14 |
-
# --- CORS CONFIGURATION ---
|
| 15 |
app.add_middleware(
|
| 16 |
CORSMiddleware,
|
| 17 |
allow_origins=["*"],
|
|
@@ -20,7 +19,6 @@ app.add_middleware(
|
|
| 20 |
allow_headers=["*"],
|
| 21 |
)
|
| 22 |
|
| 23 |
-
# --- CONFIGURATION ---
|
| 24 |
MODEL_PATH = "./model"
|
| 25 |
DEVICE = "cpu"
|
| 26 |
|
|
@@ -28,20 +26,31 @@ print(f"🔄 Loading Model from {MODEL_PATH}...")
|
|
| 28 |
model = None
|
| 29 |
tokenizer = None
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
try:
|
| 32 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
|
| 33 |
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
|
| 34 |
model.to(DEVICE)
|
| 35 |
model.eval()
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
except Exception as e:
|
| 38 |
print(f"❌ Error loading local model: {e}")
|
| 39 |
-
# Fallback
|
| 40 |
MODEL_NAME = "csebuetnlp/banglabert"
|
| 41 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 42 |
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)
|
| 43 |
|
| 44 |
-
# --- DATA MODELS ---
|
| 45 |
class QueryRequest(BaseModel):
|
| 46 |
genre: str = ""
|
| 47 |
prompt: str = ""
|
|
@@ -53,8 +62,6 @@ class PredictionResponse(BaseModel):
|
|
| 53 |
probs: dict
|
| 54 |
explanation: list = None
|
| 55 |
|
| 56 |
-
LABELS = ["Highly Relevant", "Partially Relevant", "Not Relevant"]
|
| 57 |
-
|
| 58 |
def predict_proba_lime(texts):
|
| 59 |
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=128).to(DEVICE)
|
| 60 |
with torch.no_grad():
|
|
@@ -63,7 +70,7 @@ def predict_proba_lime(texts):
|
|
| 63 |
|
| 64 |
@app.get("/")
|
| 65 |
def health_check():
|
| 66 |
-
return {"status": "active", "model": "MedGuard v2.
|
| 67 |
|
| 68 |
@app.post("/predict", response_model=PredictionResponse)
|
| 69 |
def predict(request: QueryRequest):
|
|
@@ -71,15 +78,12 @@ def predict(request: QueryRequest):
|
|
| 71 |
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 72 |
|
| 73 |
try:
|
| 74 |
-
#
|
| 75 |
-
# We use simple SPACE concatenation to match standard training dataframe practices.
|
| 76 |
-
# No [SEP] tokens, just "Genre Prompt Response"
|
| 77 |
parts = [part for part in [request.genre, request.prompt, request.text] if part]
|
| 78 |
full_input = " ".join(parts)
|
| 79 |
|
| 80 |
-
print(f"📥 Analyzing: {full_input[:
|
| 81 |
|
| 82 |
-
# 1. PREDICT
|
| 83 |
inputs = tokenizer(full_input, return_tensors="pt", truncation=True, max_length=128).to(DEVICE)
|
| 84 |
with torch.no_grad():
|
| 85 |
outputs = model(**inputs)
|
|
@@ -87,7 +91,12 @@ def predict(request: QueryRequest):
|
|
| 87 |
|
| 88 |
pred_idx = np.argmax(probs)
|
| 89 |
|
| 90 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
explainer = LimeTextExplainer(
|
| 92 |
class_names=LABELS,
|
| 93 |
split_expression=lambda x: x.split()
|
|
@@ -103,7 +112,7 @@ def predict(request: QueryRequest):
|
|
| 103 |
lime_features = exp.as_list(label=pred_idx)
|
| 104 |
|
| 105 |
return {
|
| 106 |
-
"label":
|
| 107 |
"confidence": round(float(probs[pred_idx]) * 100, 2),
|
| 108 |
"probs": {l: round(float(p), 4) for l, p in zip(LABELS, probs)},
|
| 109 |
"explanation": lime_features
|
|
|
|
| 11 |
|
| 12 |
app = FastAPI(title="MedGuard API")
|
| 13 |
|
|
|
|
| 14 |
app.add_middleware(
|
| 15 |
CORSMiddleware,
|
| 16 |
allow_origins=["*"],
|
|
|
|
| 19 |
allow_headers=["*"],
|
| 20 |
)
|
| 21 |
|
|
|
|
| 22 |
# --- CONFIGURATION ---
# Inference runs on CPU (Spaces free tier has no GPU); the fine-tuned
# checkpoint is expected in ./model alongside this file.
DEVICE = "cpu"
MODEL_PATH = "./model"
|
| 24 |
|
|
|
|
| 26 |
# Sentinels so endpoint handlers can detect a failed load and return 503
# instead of raising NameError.
model = tokenizer = None
|
| 28 |
|
| 29 |
+
# --- CRITICAL FIX: MATCH TRAINING LABEL MAP ---
# The classifier head was trained with the mapping
#   {'Not Relevant': 0, 'Partially Relevant': 1, 'Highly Relevant': 2}
# so element i of this list must be the human-readable name for logit
# index i. Reordering this list silently mislabels every prediction.
LABELS = ["Not Relevant", "Partially Relevant", "Highly Relevant"]
|
| 33 |
+
|
| 34 |
# --- MODEL LOADING ---
# Try the fine-tuned local checkpoint first; if that fails (missing or
# corrupt files), fall back to the base BanglaBERT checkpoint with a
# freshly initialised 3-class head so the API can still start.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.to(DEVICE)
    model.eval()

    # Validation check (Optional but good): surface the config's own label
    # map so a human can verify it against LABELS.
    if model.config.id2label:
        print(f"ℹ️ Model config labels: {model.config.id2label}")
        # We enforce our manual list because sometimes configs get messed up
        # during saving, but you should visually verify this print matches
        # the LABELS list.

    print(f"✅ Model Loaded! Label Mapping: {LABELS}")
except Exception as e:
    print(f"❌ Error loading local model: {e}")
    # Fallback: base checkpoint; its classification head is untrained, so
    # predictions will be near-random until a real checkpoint is provided.
    MODEL_NAME = "csebuetnlp/banglabert"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)
    # FIX: keep the fallback consistent with the primary path — previously
    # the fallback model was left on the default device in train mode
    # (dropout active at inference).
    model.to(DEVICE)
    model.eval()
|
| 53 |
|
|
|
|
| 54 |
class QueryRequest(BaseModel):
|
| 55 |
genre: str = ""
|
| 56 |
prompt: str = ""
|
|
|
|
| 62 |
probs: dict
|
| 63 |
explanation: list = None
|
| 64 |
|
|
|
|
|
|
|
| 65 |
def predict_proba_lime(texts):
|
| 66 |
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=128).to(DEVICE)
|
| 67 |
with torch.no_grad():
|
|
|
|
| 70 |
|
| 71 |
@app.get("/")
def health_check():
    """Liveness probe: confirm the API is up and report the build tag."""
    payload = {
        "status": "active",
        "model": "MedGuard v2.3 (Fixed Labels)",
    }
    return payload
|
| 74 |
|
| 75 |
@app.post("/predict", response_model=PredictionResponse)
|
| 76 |
def predict(request: QueryRequest):
|
|
|
|
| 78 |
raise HTTPException(status_code=503, detail="Model not loaded")
|
| 79 |
|
| 80 |
try:
|
| 81 |
+
# Use simple space concatenation
|
|
|
|
|
|
|
| 82 |
parts = [part for part in [request.genre, request.prompt, request.text] if part]
|
| 83 |
full_input = " ".join(parts)
|
| 84 |
|
| 85 |
+
print(f"📥 Analyzing: {full_input[:50]}...")
|
| 86 |
|
|
|
|
| 87 |
inputs = tokenizer(full_input, return_tensors="pt", truncation=True, max_length=128).to(DEVICE)
|
| 88 |
with torch.no_grad():
|
| 89 |
outputs = model(**inputs)
|
|
|
|
| 91 |
|
| 92 |
pred_idx = np.argmax(probs)
|
| 93 |
|
| 94 |
+
# Ensure index is valid
|
| 95 |
+
if pred_idx >= len(LABELS):
|
| 96 |
+
label_str = "Unknown"
|
| 97 |
+
else:
|
| 98 |
+
label_str = LABELS[pred_idx]
|
| 99 |
+
|
| 100 |
explainer = LimeTextExplainer(
|
| 101 |
class_names=LABELS,
|
| 102 |
split_expression=lambda x: x.split()
|
|
|
|
| 112 |
lime_features = exp.as_list(label=pred_idx)
|
| 113 |
|
| 114 |
return {
|
| 115 |
+
"label": label_str,
|
| 116 |
"confidence": round(float(probs[pred_idx]) * 100, 2),
|
| 117 |
"probs": {l: round(float(p), 4) for l, p in zip(LABELS, probs)},
|
| 118 |
"explanation": lime_features
|