Reyall committed on
Commit
76cd822
·
verified ·
1 Parent(s): 633be3e

Update src/api.py

Browse files
Files changed (1) hide show
  1. src/api.py +48 -36
src/api.py CHANGED
@@ -1,36 +1,48 @@
1
- from fastapi import FastAPI
2
- from pydantic import BaseModel
3
- from transformers import BertTokenizer, BertForSequenceClassification
4
- import torch
5
- import pickle
6
-
7
- app = FastAPI()
8
-
9
- # Label encoder yüklənməsi
10
- with open("label_encoder.pkl", "rb") as f:
11
- label_encoder = pickle.load(f)
12
-
13
- # Model və tokenizer yüklənməsi
14
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
15
- model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=len(label_encoder.classes_))
16
- model.eval()
17
-
18
- # Request modeli
19
- class TextRequest(BaseModel):
20
- text: str
21
-
22
- @app.get("/")
23
- def home():
24
- return {"message": "Disease prediction API is running!"}
25
-
26
- @app.post("/predict")
27
- async def predict_endpoint(request: TextRequest):
28
- # Tokenize giriş mətni
29
- inputs = tokenizer(request.text, return_tensors="pt", truncation=True, padding=True, max_length=128)
30
- with torch.no_grad():
31
- outputs = model(**inputs)
32
- probs = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze().tolist()
33
-
34
- # Label-ları geri çevir
35
- labels = label_encoder.classes_ # 'classes_' ilə etiketləri alırıq
36
- return {"predictions": dict(zip(labels, probs))}
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # api.py
2
+ from fastapi import FastAPI, Query
3
+ import uvicorn
4
+ import pickle, torch, random, os
5
+ from transformers import BertTokenizer, BertForSequenceClassification
6
+ from collections import defaultdict
7
+
8
+ app = FastAPI()
9
+
10
+ # Load model
11
+ label_encoder = pickle.load(open("best_model/label_encoder.pkl", "rb"))
12
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
13
+ model = BertForSequenceClassification.from_pretrained(
14
+ "best_model", num_labels=len(label_encoder.classes_)
15
+ )
16
+ model.eval()
17
+
18
def predict_disease(symptoms_text, n_runs=10, top_k=3):
    """Predict the most likely diseases for a comma-separated symptom string.

    The symptoms are shuffled and re-scored ``n_runs`` times and the softmax
    probabilities are averaged, so the result does not depend on the order in
    which the caller lists the symptoms.

    Args:
        symptoms_text: Comma-separated symptom names, e.g. "fever, cough".
        n_runs: Number of shuffled inference passes to average over
            (default 10, matching the original hard-coded behavior).
        top_k: Number of top predictions to return (default 3).

    Returns:
        A list of up to ``top_k`` dicts ``{"disease": str, "probability": float}``
        sorted by probability in descending order.
    """
    # Split on commas and drop empty fragments / surrounding whitespace.
    symptoms = [s.strip() for s in symptoms_text.split(",") if s.strip()]
    agg_probs = defaultdict(float)
    for _ in range(n_runs):
        random.shuffle(symptoms)
        inputs = tokenizer(", ".join(symptoms), return_tensors="pt",
                           truncation=True, padding=True, max_length=128)
        with torch.no_grad():
            outputs = model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze()
        for i, p in enumerate(probs):
            agg_probs[i] += p.item()
    # Average the accumulated probabilities (was a duplicated magic 10).
    for k in agg_probs:
        agg_probs[k] /= n_runs
    return sorted(
        [{"disease": label_encoder.classes_[i], "probability": p}
         for i, p in agg_probs.items()],
        key=lambda x: x["probability"], reverse=True
    )[:top_k]
37
+
38
@app.get("/")
def api(symptoms: str = Query(...)):
    """Root endpoint: return the top disease predictions for the required
    ``symptoms`` query parameter (comma-separated symptom names)."""
    predictions = predict_disease(symptoms)
    response = {
        "status": "success",
        "input": symptoms,
        "predictions": predictions,
    }
    return response
46
+
47
# Start the ASGI server only when this file is executed directly
# (importing the module — e.g. `uvicorn api:app` — skips this).
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)