# Provenance (Hugging Face Space page header): author sm89,
# commit 09d4627 "Update app.py" (verified).
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# ----------------------------------
# FastAPI application instance
# ----------------------------------
# Single ASGI app object; the route decorators below register on it.
app = FastAPI(
    title="Medical Symptom Prediction API",
)
# ----------------------------------
# Load Model from Hugging Face Hub
# ----------------------------------
# Hub repository id for both the tokenizer and the classification model.
MODEL_NAME = "sm89/Symptom2Disease"
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    # Put the model in evaluation mode (training=False) so training-only
    # layers such as dropout are inactive during request-time inference.
    model.eval()
except Exception as e:
    # Fail fast at import time: the API is useless without the model.
    # Chain the cause so the underlying download/parse error stays visible
    # in the traceback.
    raise RuntimeError(f"Model loading failed: {e}") from e
# ----------------------------------
# Label Mapping
# ----------------------------------
# Model output index -> department name. The tuple order defines the index,
# and is presumed to match the label order the model was trained with.
# TODO confirm against the model's training config.
id_to_label = dict(
    enumerate(
        (
            "Cardiology",
            "Dermatology",
            "Endocrinology",
            "Gastroenterology",
            "Infectious",
            "Neurology",
            "Orthopedics",
            "Pulmonology",
            "Urology",
        )
    )
)
# ----------------------------------
# Request Schema
# ----------------------------------
class PredictionRequest(BaseModel):
    """JSON body accepted by the ``POST /predict`` endpoint.

    Attributes:
        text: Free-text description of the patient's symptoms.
    """

    text: str
# ----------------------------------
# Health Check Endpoint
# ----------------------------------
@app.get("/")
def health_check():
    """Liveness probe: always returns the same fixed status message."""
    status = {"message": "Medical Symptom API Running"}
    return status
# ----------------------------------
# Prediction Endpoint
# ----------------------------------
@app.post("/predict")
def predict(request: PredictionRequest):
    """Rank medical departments for a free-text symptom description.

    Tokenizes ``request.text`` (truncated to 128 tokens), runs the
    sequence-classification model, and returns up to three departments
    ordered by softmax probability, highest first.

    Args:
        request: Body whose ``text`` field holds the symptoms to classify.

    Returns:
        A dict with:
          - ``input_text``: the text exactly as received;
          - ``top_predictions``: list of ``{"department", "confidence"}``
            dicts, highest confidence first, confidence rounded to 4 places;
            unknown indices are reported as ``"LABEL_<n>"``;
          - ``final_prediction``: the first entry of ``top_predictions``.

    Raises:
        HTTPException: 400 if ``text`` is empty or whitespace-only;
            500 if tokenization, inference or post-processing fails.
    """
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text input cannot be empty")
    try:
        inputs = tokenizer(
            request.text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128,
        )
        # Inference only: skip autograd graph construction.
        with torch.no_grad():
            outputs = model(**inputs)
        # Logits have one row per input; a single string was sent, so only
        # row 0 is used.
        probabilities = torch.softmax(outputs.logits, dim=-1)[0]
        # torch.topk raises when k exceeds the number of classes, so cap it
        # in case the loaded model has fewer than three labels.
        k = min(3, probabilities.numel())
        top_probs, top_indices = torch.topk(probabilities, k)
        results = [
            {
                "department": id_to_label.get(idx, f"LABEL_{idx}"),
                "confidence": round(prob, 4),
            }
            for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
        ]
    except Exception as e:
        # NOTE(review): str(e) exposes internal error text to API clients;
        # consider a generic message plus server-side logging.
        # Chain the cause so the original traceback is kept server-side.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "input_text": request.text,
        "top_predictions": results,
        "final_prediction": results[0],
    }