File size: 4,094 Bytes
46795fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e3db74
 
 
 
 
46795fe
 
 
 
 
7e3db74
 
 
 
 
 
 
 
 
46795fe
 
 
 
 
 
 
666b4cd
 
 
46795fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e3db74
46795fe
 
 
 
 
 
 
7e3db74
666b4cd
 
46795fe
7e3db74
46795fe
 
 
 
 
 
 
 
7e3db74
 
 
 
 
 
46795fe
 
 
 
 
 
 
 
 
666b4cd
46795fe
 
 
 
 
7e3db74
46795fe
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# main.py
import os
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from lime.lime_text import LimeTextExplainer
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification

app = FastAPI(title="MedGuard API")

# Fully permissive CORS: every origin, method and header is accepted.
# NOTE(review): the CORS spec forbids a wildcard origin on credentialed
# requests; combining allow_origins=["*"] with allow_credentials=True is
# either ineffective or (depending on how the middleware resolves it) lets any
# site send credentialed requests. This API reads no cookies/auth, so consider
# allow_credentials=False or an explicit origin list — confirm with frontend.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)

MODEL_PATH = "./model"
DEVICE = "cpu"

print(f"🔄 Loading Model from {MODEL_PATH}...")
model = None
tokenizer = None

# --- CRITICAL FIX: MATCH TRAINING LABEL MAP ---
# Training map: {'Not Relevant': 0, 'Partially Relevant': 1, 'Highly Relevant': 2}
# This list MUST follow the index order: [index 0, index 1, index 2].
LABELS = ["Not Relevant", "Partially Relevant", "Highly Relevant"]

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)

    # Validation check: LABELS is always enforced (configs are sometimes saved
    # with generic names), but compare it with the saved config so a mismatch
    # in ordering or count shows up in the startup log instead of silently.
    if model.config.id2label:
        print(f"ℹ️ Model config labels: {model.config.id2label}")
        config_labels = [
            model.config.id2label.get(i) for i in range(model.config.num_labels)
        ]
        if config_labels != LABELS:
            print(
                f"⚠️ Config labels {config_labels} differ from enforced "
                f"LABELS {LABELS}; verify the index order matches training."
            )

    print(f"✅ Model Loaded! Label Mapping: {LABELS}")

except Exception as e:
    print(f"❌ Error loading local model: {e}")
    MODEL_NAME = "csebuetnlp/banglabert"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Tie the head size to LABELS so indices line up with the label list.
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME, num_labels=len(LABELS)
    )
    print(
        f"⚠️ Falling back to base checkpoint {MODEL_NAME}; its classification "
        "head is not fine-tuned for these labels."
    )

# Applied to whichever model was loaded. Previously the fallback branch skipped
# both calls, leaving dropout active (nondeterministic predictions) and the
# model on its default device rather than DEVICE.
model.to(DEVICE)
model.eval()

class QueryRequest(BaseModel):
    """Request body for ``/predict``.

    ``genre`` and ``prompt`` are optional context fields that default to the
    empty string; ``text`` has no default and is therefore required.
    """
    genre: str = ""  # optional context; defaults to ""
    prompt: str = ""  # optional context; defaults to ""
    text: str  # required (no default)

class PredictionResponse(BaseModel):
    """Response body of ``/predict``."""

    label: str  # predicted class name
    confidence: float  # probability of the predicted class, as a percentage
    probs: dict  # class name -> probability
    # LIME feature/weight pairs for the predicted class. Declared Optional so an
    # explicit None validates: pydantic v2 no longer treats a `None` default as
    # making a non-Optional annotation nullable.
    explanation: Optional[list] = None

def predict_proba_lime(texts):
    """Return class probabilities for a batch of strings (LIME classifier_fn).

    LIME calls this with its perturbed variants of the input. Each string is
    tokenised with truncation to 128 tokens and the batch is padded so it forms
    a single tensor; the model runs without gradient tracking.

    Returns:
        A NumPy array of softmax probabilities, one row per input string.
    """
    encoded = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    ).to(DEVICE)

    with torch.no_grad():
        logits = model(**encoded).logits

    probabilities = F.softmax(logits, dim=-1)
    return probabilities.cpu().numpy()

@app.get("/")
def health_check():
    """Report that the service is running, plus a fixed model version string.

    The status is always "active"; it does not inspect whether a model loaded.
    """
    payload = dict(status="active", model="MedGuard v2.3 (Fixed Labels)")
    return payload

def _build_model_input(request: QueryRequest) -> str:
    """Join the non-empty genre, prompt and text fields with single spaces."""
    parts = [part for part in (request.genre, request.prompt, request.text) if part]
    return " ".join(parts)


@app.post("/predict", response_model=PredictionResponse)
def predict(request: QueryRequest):
    """Classify a query and attach a LIME word-level explanation.

    The request's genre, prompt and text are concatenated (empty fields are
    skipped), classified, and the predicted class is explained with LIME.

    Returns:
        dict with ``label`` (class name, or "Unknown" if the argmax index falls
        outside LABELS), ``confidence`` (predicted-class probability as a
        percentage, 2 decimals), ``probs`` (label -> probability, 4 decimals)
        and ``explanation`` (LIME (word, weight) pairs for the predicted class).

    Raises:
        HTTPException(503): the model or tokenizer is not loaded.
        HTTPException(422): the combined input contains no words to explain.
        HTTPException(500): any other failure during inference or explanation.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    full_input = _build_model_input(request)
    # LIME cannot perturb a document with zero words (it would raise deep
    # inside sampling). Check with the same whitespace split LIME uses below,
    # and do it outside the try so this 422 is not rewrapped as a 500.
    if not full_input.split():
        raise HTTPException(status_code=422, detail="Input text is empty")

    try:
        print(f"📥 Analyzing: {full_input[:50]}...")

        inputs = tokenizer(
            full_input, return_tensors="pt", truncation=True, max_length=128
        ).to(DEVICE)
        with torch.no_grad():
            outputs = model(**inputs)
            probs = F.softmax(outputs.logits, dim=-1).cpu().numpy()[0]

        pred_idx = int(np.argmax(probs))
        # Guard against a model head wider than LABELS.
        label_str = LABELS[pred_idx] if pred_idx < len(LABELS) else "Unknown"

        explainer = LimeTextExplainer(
            class_names=LABELS,
            split_expression=lambda x: x.split(),
        )
        # num_samples is kept small to bound per-request latency.
        exp = explainer.explain_instance(
            full_input,
            predict_proba_lime,
            num_features=6,
            num_samples=40,
            labels=[pred_idx],
        )
        lime_features = exp.as_list(label=pred_idx)

        return {
            "label": label_str,
            "confidence": round(float(probs[pred_idx]) * 100, 2),
            "probs": {l: round(float(p), 4) for l, p in zip(LABELS, probs)},
            "explanation": lime_features,
        }
    except Exception as e:
        print(f"🔥 Server Error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all interfaces.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)