MohamedTry commited on
Commit
f828c3b
·
verified ·
1 Parent(s): 9c9b1be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -14
app.py CHANGED
@@ -5,40 +5,37 @@ import torch
5
 
6
  app = FastAPI(
7
  title="Medical Text Classifier",
8
- description="Simple medical text classifier using a lightweight BioBERT model.",
9
  version="1.0"
10
  )
11
 
12
- MODEL_NAME = "d4data/biobert-v1.1-finetuned-MedICAL"
13
 
14
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
15
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
 
 
 
16
 
17
- # Example labels (you can change them to cancer types)
18
- LABELS = [
19
- "disease_related",
20
- "treatment_related",
21
- "test_related",
22
- "symptom_related"
23
- ]
24
 
25
  class Input(BaseModel):
26
  text: str
27
 
28
  @app.get("/")
29
  def home():
30
- return {"status": "Medical classifier running successfully"}
31
 
32
  @app.post("/predict")
33
  def predict(data: Input):
34
  inputs = tokenizer(data.text, return_tensors="pt", truncation=True)
35
  outputs = model(**inputs)
36
-
37
  probs = torch.softmax(outputs.logits, dim=1)
38
- label_id = probs.argmax().item()
39
 
40
  return {
41
  "input": data.text,
42
- "predicted_label": LABELS[label_id] if label_id < len(LABELS) else label_id,
43
  "confidence": float(probs.max())
44
  }
 
5
 
6
  app = FastAPI(
7
  title="Medical Text Classifier",
8
+ description="Classifier using Bio_ClinicalBERT",
9
  version="1.0"
10
  )
11
 
12
+ MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"
13
 
14
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
15
+ model = AutoModelForSequenceClassification.from_pretrained(
16
+ MODEL_NAME,
17
+ num_labels=4 # تقدر تغيّر عدد اللابلز
18
+ )
19
 
20
+ LABELS = ["disease", "treatment", "test", "symptom"]
 
 
 
 
 
 
21
 
22
  class Input(BaseModel):
23
  text: str
24
 
25
  @app.get("/")
26
  def home():
27
+ return {"status": "Bio_ClinicalBERT is running"}
28
 
29
  @app.post("/predict")
30
  def predict(data: Input):
31
  inputs = tokenizer(data.text, return_tensors="pt", truncation=True)
32
  outputs = model(**inputs)
33
+
34
  probs = torch.softmax(outputs.logits, dim=1)
35
+ label_id = torch.argmax(probs).item()
36
 
37
  return {
38
  "input": data.text,
39
+ "predicted_label": LABELS[label_id],
40
  "confidence": float(probs.max())
41
  }