sm89 committed on
Commit
09d4627
·
verified ·
1 Parent(s): f5f1e24

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -74
app.py CHANGED
@@ -2,16 +2,23 @@ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  import torch
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
- import gradio as gr
6
 
7
  # ----------------------------------
8
- # Load Model
 
 
 
 
 
9
  # ----------------------------------
10
  MODEL_NAME = "sm89/Symptom2Disease"
11
 
12
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
13
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
14
- model.eval()
 
 
 
15
 
16
  # ----------------------------------
17
  # Label Mapping
@@ -29,86 +36,57 @@ id_to_label = {
29
  }
30
 
31
  # ----------------------------------
32
- # Core Prediction Logic
33
  # ----------------------------------
34
- def predict_logic(text: str):
 
35
 
36
- if not text.strip():
37
- raise ValueError("Text input cannot be empty")
 
 
 
 
38
 
39
- inputs = tokenizer(
40
- text,
41
- return_tensors="pt",
42
- truncation=True,
43
- padding=True,
44
- max_length=128
45
- )
46
 
47
- with torch.no_grad():
48
- outputs = model(**inputs)
49
- probabilities = torch.softmax(outputs.logits, dim=1)
50
 
51
- top_probs, top_indices = torch.topk(probabilities, 3)
 
 
 
 
 
 
 
52
 
53
- results = []
 
 
54
 
55
- for prob, idx in zip(top_probs[0], top_indices[0]):
56
- label_index = int(idx.item())
57
- results.append({
58
- "department": id_to_label.get(label_index, f"LABEL_{label_index}"),
59
- "confidence": round(float(prob.item()), 4)
60
- })
61
 
62
- return {
63
- "input_text": text,
64
- "top_predictions": results,
65
- "final_prediction": results[0]
66
- }
67
 
68
- # ----------------------------------
69
- # FastAPI App
70
- # ----------------------------------
71
- app = FastAPI(title="Medical Symptom Prediction API")
72
 
73
- class PredictionRequest(BaseModel):
74
- text: str
 
 
75
 
76
- @app.get("/health")
77
- def health():
78
- return {"status": "running"}
 
 
79
 
80
- @app.post("/predict")
81
- def predict(request: PredictionRequest):
82
- try:
83
- return predict_logic(request.text)
84
- except ValueError as e:
85
- raise HTTPException(status_code=400, detail=str(e))
86
  except Exception as e:
87
  raise HTTPException(status_code=500, detail=str(e))
88
-
89
- # ----------------------------------
90
- # Gradio UI
91
- # ----------------------------------
92
- def gradio_predict(text):
93
- try:
94
- result = predict_logic(text)
95
- output = ""
96
-
97
- for item in result["top_predictions"]:
98
- output += f"{item['department']} ({item['confidence']})\n"
99
-
100
- return output
101
-
102
- except Exception as e:
103
- return str(e)
104
-
105
- demo = gr.Interface(
106
- fn=gradio_predict,
107
- inputs=gr.Textbox(lines=3, placeholder="Enter symptoms here"),
108
- outputs=gr.Textbox(label="Prediction"),
109
- title="Medical Symptom Predictor",
110
- description="Enter symptoms to get top predicted medical departments"
111
- )
112
-
113
- # Mount Gradio at /ui (IMPORTANT)
114
- app = gr.mount_gradio_app(app, demo, path="/ui")
 
2
  from pydantic import BaseModel
3
  import torch
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
5
 
6
# ----------------------------------
# Initialize FastAPI
# ----------------------------------
app = FastAPI(title="Medical Symptom Prediction API")

# ----------------------------------
# Load Model from Hugging Face Hub
# ----------------------------------
MODEL_NAME = "sm89/Symptom2Disease"

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    model.eval()  # inference mode: disables dropout/batch-norm updates
except Exception as e:
    # Chain the original exception (`from e`) so startup logs keep the
    # root-cause traceback (network error, missing repo, bad weights, ...).
    raise RuntimeError(f"Model loading failed: {e}") from e
22
 
23
  # ----------------------------------
24
  # Label Mapping
 
36
  }
37
 
38
# ----------------------------------
# Request Schema
# ----------------------------------
class PredictionRequest(BaseModel):
    """Request body for the /predict endpoint.

    A single free-text symptom description; validated by pydantic.
    The field name and type define the public JSON schema — do not rename.
    """
    text: str
43
 
44
# ----------------------------------
# Health Check Endpoint
# ----------------------------------
@app.get("/")
def health_check():
    """Liveness probe: report that the API process is up and serving."""
    status_payload = {"message": "Medical Symptom API Running"}
    return status_payload
50
 
51
# ----------------------------------
# Prediction Endpoint
# ----------------------------------
@app.post("/predict")
def predict(request: PredictionRequest):
    """Classify a free-text symptom description into medical departments.

    Args:
        request: validated body containing the symptom text.

    Returns:
        dict with the echoed input, the top-k (k <= 3) predictions
        (department name + confidence rounded to 4 decimals), and the
        single best prediction.

    Raises:
        HTTPException 400: input text is empty or whitespace-only.
        HTTPException 500: any unexpected tokenization/inference failure.
    """
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text input cannot be empty")

    try:
        inputs = tokenizer(
            request.text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128,
        )

        with torch.no_grad():
            outputs = model(**inputs)
            probabilities = torch.softmax(outputs.logits, dim=-1)

        # Guard: a model with fewer than 3 labels would make topk(…, 3) raise.
        k = min(3, probabilities.shape[-1])
        top_probs, top_indices = torch.topk(probabilities, k)

        results = []
        for prob, idx in zip(top_probs[0], top_indices[0]):
            label_index = int(idx.item())
            results.append({
                "department": id_to_label.get(label_index, f"LABEL_{label_index}"),
                "confidence": round(float(prob.item()), 4),
            })

        return {
            "input_text": request.text,
            "top_predictions": results,
            "final_prediction": results[0],  # results is non-empty: k >= 1
        }

    except Exception as e:
        # Chain (`from e`) so server logs preserve the original traceback;
        # NOTE(review): str(e) in `detail` may leak internals to clients —
        # consider a generic message in production.
        raise HTTPException(status_code=500, detail=str(e)) from e