Spaces:
Running
Running
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# ----------------------------------
# Initialize FastAPI
# ----------------------------------
app = FastAPI(title="Medical Symptom Prediction API")

# ----------------------------------
# Load Model from Hugging Face Hub
# ----------------------------------
# Loaded once at import time so every request reuses the same weights.
MODEL_NAME = "sm89/Symptom2Disease"

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    model.eval()  # inference only: disable dropout / training-mode layers
except Exception as e:
    # Chain the cause so the underlying download/load error stays in the traceback.
    raise RuntimeError(f"Model loading failed: {e}") from e
# ----------------------------------
# Label Mapping
# ----------------------------------
# Maps the model's integer class ids (0..8) to department names.
_DEPARTMENTS = (
    "Cardiology",
    "Dermatology",
    "Endocrinology",
    "Gastroenterology",
    "Infectious",
    "Neurology",
    "Orthopedics",
    "Pulmonology",
    "Urology",
)
id_to_label = dict(enumerate(_DEPARTMENTS))
# ----------------------------------
# Request Schema
# ----------------------------------
class PredictionRequest(BaseModel):
    """JSON request body for the prediction endpoint; validated by pydantic."""
    # Free-text symptom description supplied by the client.
    text: str
# ----------------------------------
# Health Check Endpoint
# ----------------------------------
@app.get("/")
def health_check():
    """Liveness probe: confirms the API process is up and serving."""
    return {"message": "Medical Symptom API Running"}
# ----------------------------------
# Prediction Endpoint
# ----------------------------------
@app.post("/predict")
def predict(request: PredictionRequest):
    """Classify a free-text symptom description into a medical department.

    Returns the input text, the top-3 departments with softmax confidences,
    and the single best prediction.

    Raises:
        HTTPException 400: the input text is empty or whitespace only.
        HTTPException 500: tokenization or model inference failed.
    """
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text input cannot be empty")
    try:
        inputs = tokenizer(
            request.text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128,  # symptom descriptions are short; cap token count
        )
        with torch.no_grad():  # inference only -- skip autograd bookkeeping
            outputs = model(**inputs)
            probabilities = torch.softmax(outputs.logits, dim=1)
            top_probs, top_indices = torch.topk(probabilities, 3)
        results = []
        for prob, idx in zip(top_probs[0], top_indices[0]):
            label_index = int(idx.item())
            results.append({
                # Fall back to the raw label id if the mapping is incomplete.
                "department": id_to_label.get(label_index, f"LABEL_{label_index}"),
                "confidence": round(float(prob.item()), 4),
            })
        return {
            "input_text": request.text,
            "top_predictions": results,
            "final_prediction": results[0],
        }
    except Exception as e:
        # Surface unexpected inference failures as a 500; chain the cause.
        raise HTTPException(status_code=500, detail=str(e)) from e