from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# ----------------------------------
# Initialize FastAPI
# ----------------------------------
app = FastAPI(title="Medical Symptom Prediction API")
# ----------------------------------
# Load Model from Hugging Face Hub
# ----------------------------------
# Hub repository holding the fine-tuned sequence-classification checkpoint.
MODEL_NAME = "sm89/Symptom2Disease"
try:
    # Downloads (or reads from the local HF cache) the tokenizer and weights.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    # Inference-only service: switch off dropout / training-mode layers.
    model.eval()
except Exception as e:
    # Fail fast at startup; chain the original error so the real cause
    # (network failure, missing repo, etc.) stays in the traceback.
    raise RuntimeError(f"Model loading failed: {e}") from e
# ----------------------------------
# Label Mapping
# ----------------------------------
# Ordered department names; position in the tuple is the model's class index.
_DEPARTMENTS = (
    "Cardiology",
    "Dermatology",
    "Endocrinology",
    "Gastroenterology",
    "Infectious",
    "Neurology",
    "Orthopedics",
    "Pulmonology",
    "Urology",
)
# Model output index -> human-readable department name.
id_to_label = dict(enumerate(_DEPARTMENTS))
# ----------------------------------
# Request Schema
# ----------------------------------
class PredictionRequest(BaseModel):
    """JSON body for /predict: a free-text symptom description."""

    text: str
# ----------------------------------
# Health Check Endpoint
# ----------------------------------
@app.get("/")
def health_check():
    """Liveness probe: confirms the API process is running."""
    status = {"message": "Medical Symptom API Running"}
    return status
# ----------------------------------
# Prediction Endpoint
# ----------------------------------
@app.post("/predict")
def predict(request: PredictionRequest):
    """Classify a free-text symptom description into a medical department.

    Runs the sequence-classification model on the input text and returns
    up to the top 3 departments with softmax confidences, plus the single
    best prediction.

    Raises:
        HTTPException 400: input text is empty or whitespace-only.
        HTTPException 500: any tokenization/inference failure.
    """
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text input cannot be empty")
    try:
        inputs = tokenizer(
            request.text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128,
        )
        with torch.no_grad():
            outputs = model(**inputs)
        # dim=-1 normalizes over the class axis regardless of logits rank.
        probabilities = torch.softmax(outputs.logits, dim=-1)
        # Clamp k so a checkpoint with fewer than 3 labels cannot crash topk.
        k = min(3, probabilities.size(-1))
        top_probs, top_indices = torch.topk(probabilities, k)
        results = []
        for prob, idx in zip(top_probs[0], top_indices[0]):
            label_index = int(idx.item())
            results.append({
                # Fall back to a generic label if the index is unmapped.
                "department": id_to_label.get(label_index, f"LABEL_{label_index}"),
                "confidence": round(float(prob.item()), 4),
            })
        return {
            "input_text": request.text,
            "top_predictions": results,
            "final_prediction": results[0],
        }
    except Exception as e:
        # API boundary: surface unexpected failures as HTTP 500, chaining
        # the original exception so the server log keeps the full traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e
|