MohamedTry commited on
Commit
596a54e
·
verified ·
1 Parent(s): a5c9478

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -28
app.py CHANGED
@@ -3,48 +3,42 @@ from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import torch
5
 
6
- app = FastAPI()
 
 
 
 
7
 
8
- MODEL_NAME = "monologg/distilbiobert"
9
 
10
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
11
- model = AutoModelForSequenceClassification.from_pretrained(
12
- MODEL_NAME,
13
- num_labels=12 # عدد أنواع السرطان (أنت تتحكم به)
14
- )
15
 
 
16
  LABELS = [
17
- "breast_cancer",
18
- "lung_cancer",
19
- "prostate_cancer",
20
- "colon_cancer",
21
- "lymphoma",
22
- "melanoma",
23
- "thyroid_cancer",
24
- "kidney_cancer",
25
- "pancreatic_cancer",
26
- "ovarian_cancer",
27
- "cervical_cancer",
28
- "brain_tumor"
29
  ]
30
 
31
  class Input(BaseModel):
32
  text: str
33
 
 
 
 
 
34
  @app.post("/predict")
35
  def predict(data: Input):
36
  inputs = tokenizer(data.text, return_tensors="pt", truncation=True)
37
  outputs = model(**inputs)
38
-
39
- probs = torch.nn.functional.softmax(outputs.logits, dim=1)
40
- label_id = torch.argmax(probs).item()
41
- confidence = float(torch.max(probs))
42
 
43
  return {
44
- "prediction": LABELS[label_id],
45
- "confidence": confidence
 
46
  }
47
-
48
- @app.get("/")
49
- def home():
50
- return {"status": "Model is running"}
 
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import torch
5
 
6
+ app = FastAPI(
7
+ title="Medical Text Classifier",
8
+ description="Simple medical text classifier using a lightweight BioBERT model.",
9
+ version="1.0"
10
+ )
11
 
12
+ MODEL_NAME = "d4data/biobert-v1.1-finetuned-MedICAL"
13
 
14
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
15
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
 
 
 
16
 
17
+ # Example labels (you can change them to cancer types)
18
  LABELS = [
19
+ "disease_related",
20
+ "treatment_related",
21
+ "test_related",
22
+ "symptom_related"
 
 
 
 
 
 
 
 
23
  ]
24
 
25
  class Input(BaseModel):
26
  text: str
27
 
28
+ @app.get("/")
29
+ def home():
30
+ return {"status": "Medical classifier running successfully"}
31
+
32
  @app.post("/predict")
33
  def predict(data: Input):
34
  inputs = tokenizer(data.text, return_tensors="pt", truncation=True)
35
  outputs = model(**inputs)
36
+
37
+ probs = torch.softmax(outputs.logits, dim=1)
38
+ label_id = probs.argmax().item()
 
39
 
40
  return {
41
+ "input": data.text,
42
+ "predicted_label": LABELS[label_id] if label_id < len(LABELS) else label_id,
43
+ "confidence": float(probs.max())
44
  }