sm89 commited on
Commit
f5f1e24
·
verified ·
1 Parent(s): 0fea038

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -44
app.py CHANGED
@@ -1,24 +1,17 @@
1
- from fastapi import FastAPI, HTTPException, Request
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import gradio as gr
5
 
6
  # ----------------------------------
7
- # Initialize FastAPI
8
- # ----------------------------------
9
- app = FastAPI(title="Medical Symptom Prediction API")
10
-
11
- # ----------------------------------
12
- # Load Model from Hugging Face Hub
13
  # ----------------------------------
14
  MODEL_NAME = "sm89/Symptom2Disease"
15
 
16
- try:
17
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
18
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
19
- model.eval()
20
- except Exception as e:
21
- raise RuntimeError(f"Model loading failed: {e}")
22
 
23
  # ----------------------------------
24
  # Label Mapping
@@ -61,7 +54,6 @@ def predict_logic(text: str):
61
 
62
  for prob, idx in zip(top_probs[0], top_indices[0]):
63
  label_index = int(idx.item())
64
-
65
  results.append({
66
  "department": id_to_label.get(label_index, f"LABEL_{label_index}"),
67
  "confidence": round(float(prob.item()), 4)
@@ -74,58 +66,49 @@ def predict_logic(text: str):
74
  }
75
 
76
  # ----------------------------------
77
- # Health Endpoint
78
  # ----------------------------------
 
 
 
 
 
79
  @app.get("/health")
80
- def health_check():
81
- return {"message": "Medical Symptom API Running"}
82
 
83
- # ----------------------------------
84
- # JSON Prediction Endpoint (No 422 Issue)
85
- # ----------------------------------
86
  @app.post("/predict")
87
- async def predict_api(request: Request):
88
  try:
89
- body = await request.json()
90
- text = body.get("text", "")
91
-
92
- if not text.strip():
93
- raise HTTPException(status_code=400, detail="Text input cannot be empty")
94
-
95
- return predict_logic(text)
96
-
97
- except Exception as e:
98
  raise HTTPException(status_code=400, detail=str(e))
 
 
99
 
100
  # ----------------------------------
101
- # Gradio UI Function
102
  # ----------------------------------
103
  def gradio_predict(text):
104
  try:
105
  result = predict_logic(text)
106
-
107
- output = "Top Predictions:\n\n"
108
 
109
  for item in result["top_predictions"]:
110
- output += f"{item['department']} {item['confidence']}\n"
111
-
112
- output += f"\nFinal Prediction: {result['final_prediction']['department']}"
113
 
114
  return output
115
 
116
  except Exception as e:
117
  return str(e)
118
 
119
- # ----------------------------------
120
- # Create Gradio Interface
121
- # ----------------------------------
122
  demo = gr.Interface(
123
  fn=gradio_predict,
124
- inputs=gr.Textbox(lines=3, placeholder="Enter symptoms here..."),
125
- outputs=gr.Textbox(label="Prediction Result"),
126
- title="Medical Symptom to Department Predictor",
127
- description="Enter symptoms and get top 3 predicted departments."
128
  )
129
 
130
- # Mount Gradio at /ui (safe way)
131
  app = gr.mount_gradio_app(app, demo, path="/ui")
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
  import torch
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
  import gradio as gr
6
 
7
  # ----------------------------------
8
+ # Load Model
 
 
 
 
 
9
  # ----------------------------------
10
  MODEL_NAME = "sm89/Symptom2Disease"
11
 
12
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
13
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
14
+ model.eval()
 
 
 
15
 
16
  # ----------------------------------
17
  # Label Mapping
 
54
 
55
  for prob, idx in zip(top_probs[0], top_indices[0]):
56
  label_index = int(idx.item())
 
57
  results.append({
58
  "department": id_to_label.get(label_index, f"LABEL_{label_index}"),
59
  "confidence": round(float(prob.item()), 4)
 
66
  }
67
 
68
  # ----------------------------------
69
+ # FastAPI App
70
  # ----------------------------------
71
+ app = FastAPI(title="Medical Symptom Prediction API")
72
+
73
+ class PredictionRequest(BaseModel):
74
+ text: str
75
+
76
  @app.get("/health")
77
+ def health():
78
+ return {"status": "running"}
79
 
 
 
 
80
  @app.post("/predict")
81
+ def predict(request: PredictionRequest):
82
  try:
83
+ return predict_logic(request.text)
84
+ except ValueError as e:
 
 
 
 
 
 
 
85
  raise HTTPException(status_code=400, detail=str(e))
86
+ except Exception as e:
87
+ raise HTTPException(status_code=500, detail=str(e))
88
 
89
  # ----------------------------------
90
+ # Gradio UI
91
  # ----------------------------------
92
  def gradio_predict(text):
93
  try:
94
  result = predict_logic(text)
95
+ output = ""
 
96
 
97
  for item in result["top_predictions"]:
98
+ output += f"{item['department']} ({item['confidence']})\n"
 
 
99
 
100
  return output
101
 
102
  except Exception as e:
103
  return str(e)
104
 
 
 
 
105
  demo = gr.Interface(
106
  fn=gradio_predict,
107
+ inputs=gr.Textbox(lines=3, placeholder="Enter symptoms here"),
108
+ outputs=gr.Textbox(label="Prediction"),
109
+ title="Medical Symptom Predictor",
110
+ description="Enter symptoms to get top predicted medical departments"
111
  )
112
 
113
+ # Mount Gradio at /ui (IMPORTANT)
114
  app = gr.mount_gradio_app(app, demo, path="/ui")