saba2000 commited on
Commit
0ddd5ad
·
verified ·
1 Parent(s): fbf1393

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -1,21 +1,21 @@
1
- from transformers import AutoImageProcessor, AutoModelForImageClassification
2
  import torch
3
  import torch.nn.functional as F
4
  from PIL import Image
5
  import gradio as gr
6
 
7
  # -----------------------------
8
- # 1. Load pretrained chest X-ray model
9
  # -----------------------------
10
- model_name = "yikuan8/resnet50_chestxray14" # real NIH ChestX-ray14 model
11
  model = AutoModelForImageClassification.from_pretrained(model_name)
12
  processor = AutoImageProcessor.from_pretrained(model_name)
13
  model.eval()
14
 
15
- # Get labels directly from the model config
16
  id2label = model.config.id2label
17
 
18
- # Pick only these 3 diseases
19
  target_diseases = ["Pneumonia", "Effusion", "Atelectasis"]
20
 
21
  # -----------------------------
@@ -33,7 +33,8 @@ def predict(image):
33
  results = []
34
  for idx, label in id2label.items():
35
  if label in target_diseases:
36
- results.append(f"{label}: {probs[idx].item():.2f}")
 
37
 
38
  return "\n".join(results)
39
 
@@ -45,7 +46,7 @@ iface = gr.Interface(
45
  inputs=gr.Image(type="pil"),
46
  outputs="text",
47
  title="Chest X-ray: Pneumonia / Effusion / Atelectasis",
48
- description="Upload a chest X-ray. Model predicts probability for Pneumonia, Effusion, and Atelectasis."
49
  )
50
 
51
  iface.launch()
 
1
+ from transformers import AutoModelForImageClassification, AutoImageProcessor
2
  import torch
3
  import torch.nn.functional as F
4
  from PIL import Image
5
  import gradio as gr
6
 
7
  # -----------------------------
8
+ # 1. Load pretrained model
9
  # -----------------------------
10
+ model_name = "microsoft/resnet-50-finetuned-chestxray14"
11
  model = AutoModelForImageClassification.from_pretrained(model_name)
12
  processor = AutoImageProcessor.from_pretrained(model_name)
13
  model.eval()
14
 
15
+ # Get labels from config
16
  id2label = model.config.id2label
17
 
18
+ # Focus only on 3 diseases
19
  target_diseases = ["Pneumonia", "Effusion", "Atelectasis"]
20
 
21
  # -----------------------------
 
33
  results = []
34
  for idx, label in id2label.items():
35
  if label in target_diseases:
36
+ prob = probs[idx].item()
37
+ results.append(f"{label}: {'YES' if prob > 0.5 else 'NO'} ({prob:.2f})")
38
 
39
  return "\n".join(results)
40
 
 
46
  inputs=gr.Image(type="pil"),
47
  outputs="text",
48
  title="Chest X-ray: Pneumonia / Effusion / Atelectasis",
49
+ description="Upload a chest X-ray. Model predicts YES/NO with probabilities for Pneumonia, Effusion, and Atelectasis."
50
  )
51
 
52
  iface.launch()