saba2000 commited on
Commit
fbf1393
·
verified ·
1 Parent(s): ccfcbd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -17
app.py CHANGED
@@ -1,41 +1,45 @@
1
- from transformers import AutoModelForImageClassification, AutoImageProcessor
2
  import torch
3
  import torch.nn.functional as F
4
  from PIL import Image
5
  import gradio as gr
6
 
7
- # Load the properly fine-tuned chest X-ray model
8
- model_name = "Lucario-K17/biomedclip_radiology_diagnosis"
 
 
9
  model = AutoModelForImageClassification.from_pretrained(model_name)
10
  processor = AutoImageProcessor.from_pretrained(model_name)
11
  model.eval()
12
 
13
- # All 14 disease labels
14
- all_diseases = [
15
- "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass",
16
- "Nodule", "Pneumonia", "Pneumothorax", "Consolidation", "Edema",
17
- "Emphysema", "Fibrosis", "Pleural_Thickening", "Hernia"
18
- ]
19
 
20
- # Lock to desired diseases
21
  target_diseases = ["Pneumonia", "Effusion", "Atelectasis"]
22
- target_idxs = [all_diseases.index(d) for d in target_diseases]
23
 
 
 
 
24
  def predict(image):
25
  img = image.convert("RGB").resize((224, 224))
26
  inputs = processor(images=img, return_tensors="pt")
27
-
28
  with torch.no_grad():
29
  logits = model(**inputs).logits
30
-
31
  probs = F.softmax(logits, dim=1).squeeze()
32
-
33
  results = []
34
- for i, d in zip(target_idxs, target_diseases):
35
- results.append(f"{d}: {probs[i].item():.2f}")
36
-
 
37
  return "\n".join(results)
38
 
 
 
 
39
  iface = gr.Interface(
40
  fn=predict,
41
  inputs=gr.Image(type="pil"),
 
1
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
2
  import torch
3
  import torch.nn.functional as F
4
  from PIL import Image
5
  import gradio as gr
6
 
7
+ # -----------------------------
8
+ # 1. Load pretrained chest X-ray model
9
+ # -----------------------------
10
+ model_name = "yikuan8/resnet50_chestxray14" # real NIH ChestX-ray14 model
11
  model = AutoModelForImageClassification.from_pretrained(model_name)
12
  processor = AutoImageProcessor.from_pretrained(model_name)
13
  model.eval()
14
 
15
+ # Get labels directly from the model config
16
+ id2label = model.config.id2label
 
 
 
 
17
 
18
+ # Pick only these 3 diseases
19
  target_diseases = ["Pneumonia", "Effusion", "Atelectasis"]
 
20
 
21
+ # -----------------------------
22
+ # 2. Prediction function
23
+ # -----------------------------
24
  def predict(image):
25
  img = image.convert("RGB").resize((224, 224))
26
  inputs = processor(images=img, return_tensors="pt")
27
+
28
  with torch.no_grad():
29
  logits = model(**inputs).logits
30
+
31
  probs = F.softmax(logits, dim=1).squeeze()
32
+
33
  results = []
34
+ for idx, label in id2label.items():
35
+ if label in target_diseases:
36
+ results.append(f"{label}: {probs[idx].item():.2f}")
37
+
38
  return "\n".join(results)
39
 
40
+ # -----------------------------
41
+ # 3. Gradio interface
42
+ # -----------------------------
43
  iface = gr.Interface(
44
  fn=predict,
45
  inputs=gr.Image(type="pil"),