saba2000 commited on
Commit
8ad1547
·
verified ·
1 Parent(s): 14b8261

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -24
app.py CHANGED
@@ -1,39 +1,48 @@
1
  from transformers import AutoModelForImageClassification, AutoImageProcessor
2
  import torch
3
- import torch.nn.functional as F
4
  from PIL import Image
5
  import gradio as gr
6
 
7
- # 1️⃣ Load fine-tuned vit-chest-xray model
8
  model_name = "codewithdark/vit-chest-xray"
9
  model = AutoModelForImageClassification.from_pretrained(model_name)
10
  processor = AutoImageProcessor.from_pretrained(model_name)
11
  model.eval()
12
 
13
- # 2️⃣ Define disease indices based on CheXpert labels
14
- # The model expects: ['Cardiomegaly', 'Edema', 'Consolidation', 'No Finding', 'Pneumonia']
15
- label_list = ['Cardiomegaly', 'Edema', 'Consolidation', 'No Finding', 'Pneumonia']
16
- # We only care about Pneumonia, Consolidation, Edema
17
- target_labels = ['Pneumonia', 'Consolidation', 'Cardiomegaly', 'No Finding', 'Edema']
18
- target_idxs = [label_list.index(lbl) for lbl in target_labels]
19
 
20
  def predict(image):
21
- img = image.convert("RGB").resize((224, 224))
22
- inputs = processor(images=img, return_tensors="pt")
23
- with torch.no_grad():
24
- logits = model(**inputs).logits
25
- probs = torch.sigmoid(logits).squeeze() # multi-label => sigmoid
26
- results = []
27
- for idx, lbl in zip(target_idxs, target_labels):
28
- prob = probs[idx].item()
29
- results.append(f"{lbl}: {'YES' if prob > 0.5 else 'NO'} ({prob:.2f})")
30
- return "\n".join(results)
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  iface = gr.Interface(
33
- fn=predict,
34
- inputs=gr.Image(type="pil"),
35
- outputs="text",
36
- title="Chest X-ray Multi-Disease Detector",
37
- description="Upload a chest X-ray. Predicts Pneumonia, Consolidation, and Edema."
38
  )
39
- iface.launch()
 
 
1
  from transformers import AutoModelForImageClassification, AutoImageProcessor
2
  import torch
 
3
  from PIL import Image
4
  import gradio as gr
5
 
 
6
  model_name = "codewithdark/vit-chest-xray"
7
  model = AutoModelForImageClassification.from_pretrained(model_name)
8
  processor = AutoImageProcessor.from_pretrained(model_name)
9
  model.eval()
10
 
11
+ labels = ['Cardiomegaly', 'Edema', 'Consolidation', 'No Finding', 'Pneumonia']
12
+ target_labels = ['Pneumonia', 'Consolidation', 'Edema']
13
+ target_idxs = [labels.index(lbl) for lbl in target_labels]
 
 
 
14
 
15
  def predict(image):
16
+ image = image.convert("RGB").resize((224, 224))
17
+ inputs = processor(images=image, return_tensors="pt")
18
+ with torch.no_grad():
19
+ logits = model(**inputs).logits
20
+ probs = torch.sigmoid(logits).squeeze()
21
+
22
+ ```
23
+ detected = []
24
+ results = []
25
+ for idx, lbl in zip(target_idxs, target_labels):
26
+ prob = probs[idx].item()
27
+ status = "YES" if prob > 0.5 else "NO"
28
+ results.append(f"{lbl}: {status} ({prob:.2f})")
29
+ if status == "YES":
30
+ detected.append(lbl)
31
+
32
+ if detected:
33
+ summary = f"⚠️ Patient shows signs of: {', '.join(detected)}."
34
+ else:
35
+ summary = "✅ Patient appears healthy — no major lung issues detected."
36
+
37
+ return "\n".join(results + ["\n" + summary])
38
+ ```
39
 
40
  iface = gr.Interface(
41
+ fn=predict,
42
+ inputs=gr.Image(type="pil"),
43
+ outputs="text",
44
+ title="Chest X-ray Disease Detector",
45
+ description="Upload a chest X-ray to detect Pneumonia, Consolidation, and Edema. Gives clear patient health summary."
46
  )
47
+
48
+ iface.launch()