File size: 2,538 Bytes
f5f1e24
 
ce4c783
 
 
d167b71
09d4627
 
 
 
 
 
ce4c783
 
 
09d4627
 
 
 
 
 
d167b71
 
 
 
 
f193ab8
 
 
d167b71
f193ab8
 
 
 
 
d167b71
ce4c783
aa68efd
09d4627
aa68efd
09d4627
 
aa68efd
09d4627
 
 
 
 
 
aa68efd
09d4627
 
 
 
 
aa68efd
09d4627
 
aa68efd
09d4627
 
 
 
 
 
 
 
aa68efd
09d4627
 
 
aa68efd
09d4627
aa68efd
09d4627
aa68efd
09d4627
 
f5f1e24
09d4627
 
 
 
f5f1e24
09d4627
 
 
 
 
ce4c783
f5f1e24
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ----------------------------------
# Initialize FastAPI
# ----------------------------------
app = FastAPI(title="Medical Symptom Prediction API")

# ----------------------------------
# Load Model from Hugging Face Hub
# ----------------------------------
MODEL_NAME = "sm89/Symptom2Disease"

# Tokenizer and model are loaded once, at import time, so every request
# reuses the same weights instead of reloading them per call.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    # This service only runs inference, so put the model in evaluation mode.
    model.eval()
except Exception as e:
    # Fail fast at startup rather than serving requests without a model.
    # Chain the original error so its traceback stays attached as the cause.
    raise RuntimeError(f"Model loading failed: {e}") from e

# ----------------------------------
# Label Mapping
# ----------------------------------
# Class index -> department name; the index is the position in this tuple.
# Indices come from the model's output logits, so this ordering must line up
# with the classifier head.
# NOTE(review): confirm against the model's own config.id2label.
id_to_label = dict(
    enumerate(
        (
            "Cardiology",
            "Dermatology",
            "Endocrinology",
            "Gastroenterology",
            "Infectious",
            "Neurology",
            "Orthopedics",
            "Pulmonology",
            "Urology",
        )
    )
)

# ----------------------------------
# Request Schema
# ----------------------------------
class PredictionRequest(BaseModel):
    """JSON request body accepted by the ``POST /predict`` endpoint."""

    # Free-text symptom description to classify. The schema only requires a
    # string; empty/whitespace-only values are rejected by the endpoint itself.
    text: str

# ----------------------------------
# Health Check Endpoint
# ----------------------------------
@app.get("/")
def health_check():
    """Return a fixed status message showing the API is reachable.

    The handler does not touch the model; it only confirms the app is serving.
    """
    status_payload = {"message": "Medical Symptom API Running"}
    return status_payload

# ----------------------------------
# Prediction Endpoint
# ----------------------------------
@app.post("/predict")
def predict(request: PredictionRequest):
    """Score a symptom description and return the most likely departments.

    The text is tokenized (truncated to 128 tokens) and run through the
    sequence-classification model. The softmax probabilities of the
    highest-scoring classes are mapped to department names via
    ``id_to_label``. Indices missing from that map are reported as
    ``"LABEL_<index>"``.

    Args:
        request: Request body; ``request.text`` is the raw symptom text.

    Returns:
        dict with keys:
            ``input_text``: the text exactly as received.
            ``top_predictions``: up to three ``{"department", "confidence"}``
                entries in descending confidence order. ``confidence`` is a
                probability rounded to 4 decimal places.
            ``final_prediction``: the first (highest-confidence) entry.

    Raises:
        HTTPException: 400 when ``text`` is empty or whitespace-only;
            500 when tokenization or inference fails (``detail`` carries the
            underlying error message).
    """
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text input cannot be empty")

    try:
        inputs = tokenizer(
            request.text,
            return_tensors="pt",
            truncation=True,  # longer inputs are cut at max_length tokens
            padding=True,
            max_length=128,
        )

        with torch.no_grad():  # inference only: no gradients needed
            outputs = model(**inputs)
            probabilities = torch.softmax(outputs.logits, dim=-1)

        # torch.topk raises when k exceeds the size of the dimension, so never
        # ask for more candidates than the classifier has output classes.
        k = min(3, probabilities.size(-1))
        top_probs, top_indices = torch.topk(probabilities, k)

        # A single string was tokenized, so row 0 holds this request's scores.
        results = [
            {
                "department": id_to_label.get(label_index, f"LABEL_{label_index}"),
                "confidence": round(float(prob), 4),
            }
            for prob, label_index in zip(
                top_probs[0].tolist(), top_indices[0].tolist()
            )
        ]

        return {
            "input_text": request.text,
            "top_predictions": results,
            "final_prediction": results[0],
        }

    except Exception as e:
        # Surface any tokenizer/model failure as a 500, keeping the cause chained.
        raise HTTPException(status_code=500, detail=str(e)) from e