sm89 commited on
Commit
ce4c783
·
verified ·
1 Parent(s): dd6d826

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+
6
+ app = FastAPI()
7
+
8
+ # ----------------------------------
9
+ # Load Model from Hugging Face Hub
10
+ # ----------------------------------
11
+ MODEL_NAME = "sm89/Symptom2Disease"
12
+
13
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
14
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
15
+ model.eval()
16
+
17
+ # ----------------------------------
18
+ # Request Schema
19
+ # ----------------------------------
20
+ class PredictionRequest(BaseModel):
21
+ text: str
22
+
23
+ # ----------------------------------
24
+ # Health Check
25
+ # ----------------------------------
26
+ @app.get("/")
27
+ def home():
28
+ return {"message": "Medical Symptom API Running"}
29
+
30
+ # ----------------------------------
31
+ # Prediction Endpoint
32
+ # ----------------------------------
33
+ @app.post("/predict")
34
+ def predict(request: PredictionRequest):
35
+
36
+ inputs = tokenizer(
37
+ request.text,
38
+ return_tensors="pt",
39
+ truncation=True,
40
+ padding=True,
41
+ max_length=128
42
+ )
43
+
44
+ with torch.no_grad():
45
+ outputs = model(**inputs)
46
+ probabilities = torch.softmax(outputs.logits, dim=1)
47
+
48
+ top_probs, top_indices = torch.topk(probabilities, 3)
49
+
50
+ results = []
51
+
52
+ for prob, idx in zip(top_probs[0], top_indices[0]):
53
+ results.append({
54
+ "label": int(idx.item()),
55
+ "confidence": round(float(prob.item()), 4)
56
+ })
57
+
58
+ return {
59
+ "input_text": request.text,
60
+ "top_predictions": results,
61
+ "final_prediction": results[0]
62
+ }