"""Medical Symptom Prediction API.

Serves a Hugging Face sequence-classification model behind FastAPI and
returns the top-ranked department predictions for a free-text symptom
description.
"""

import logging

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

logger = logging.getLogger(__name__)

# ----------------------------------
# Initialize FastAPI
# ----------------------------------
app = FastAPI(title="Medical Symptom Prediction API")

# ----------------------------------
# Load Model from Hugging Face Hub
# ----------------------------------
MODEL_NAME = "sm89/Symptom2Disease"

# Maximum number of tokens passed to the model; longer inputs are truncated.
MAX_TOKENS = 128

# Number of ranked predictions returned per request (clamped at request
# time to the number of classes the model actually outputs).
TOP_K = 3

# Loaded once at import time so every request reuses the same weights.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    model.eval()  # switch to inference mode (disables dropout etc.)
except Exception as e:
    # Chain the original error so its traceback is preserved in logs.
    raise RuntimeError(f"Model loading failed: {e}") from e

# ----------------------------------
# Label Mapping
# ----------------------------------
# NOTE(review): this department mapping is hard-coded, not read from the
# model. Confirm it matches the label order the checkpoint was trained with.
# Indices not listed here fall back to ``model.config.id2label``.
id_to_label = {
    0: "Cardiology",
    1: "Dermatology",
    2: "Endocrinology",
    3: "Gastroenterology",
    4: "Infectious",
    5: "Neurology",
    6: "Orthopedics",
    7: "Pulmonology",
    8: "Urology",
}

if model.config.num_labels != len(id_to_label):
    # Surface a silent mismatch between the checkpoint and the mapping above.
    logger.warning(
        "Model %s outputs %d labels but id_to_label defines %d; "
        "unmapped indices fall back to model.config.id2label.",
        MODEL_NAME,
        model.config.num_labels,
        len(id_to_label),
    )


def _label_for(index: int) -> str:
    """Return the display name for class *index*.

    Prefers the hand-written ``id_to_label`` mapping, then the model
    config's ``id2label``, then a generic ``LABEL_<index>`` placeholder.
    """
    if index in id_to_label:
        return id_to_label[index]
    config_labels = getattr(model.config, "id2label", None) or {}
    return config_labels.get(index, f"LABEL_{index}")


# ----------------------------------
# Request Schema
# ----------------------------------
class PredictionRequest(BaseModel):
    """Body of a ``POST /predict`` request."""

    text: str  # free-text symptom description to classify


# ----------------------------------
# Health Check Endpoint
# ----------------------------------
@app.get("/")
def health_check():
    """Report that the service is up."""
    return {"message": "Medical Symptom API Running"}


# ----------------------------------
# Prediction Endpoint
# ----------------------------------
@app.post("/predict")
def predict(request: PredictionRequest):
    """Classify ``request.text`` and return the top-ranked departments.

    Returns:
        A dict with the echoed ``input_text``, a ``top_predictions`` list of
        ``{"department", "confidence"}`` entries in descending confidence
        order, and ``final_prediction`` (the first of those entries).

    Raises:
        HTTPException: 400 if the text is empty or whitespace-only;
            500 if tokenization or model inference fails.
    """
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text input cannot be empty")

    try:
        inputs = tokenizer(
            request.text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=MAX_TOKENS,
        )
        with torch.no_grad():
            outputs = model(**inputs)
    except Exception as e:
        # Log the real cause server-side; do not leak internals to clients.
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail="Prediction failed") from e

    # Single input, so take row 0 of the (batch, num_labels) probabilities.
    probabilities = torch.softmax(outputs.logits, dim=-1)[0]
    # topk raises if k exceeds the number of classes, so clamp it.
    k = min(TOP_K, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)

    results = [
        {"department": _label_for(idx), "confidence": round(prob, 4)}
        for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
    ]

    return {
        "input_text": request.text,
        "top_predictions": results,
        "final_prediction": results[0],
    }