saba2000 commited on
Commit
5322bcf
·
verified ·
1 Parent(s): 0ddd5ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -27
app.py CHANGED
@@ -4,49 +4,36 @@ import torch.nn.functional as F
4
  from PIL import Image
5
  import gradio as gr
6
 
7
- # -----------------------------
8
- # 1. Load pretrained model
9
- # -----------------------------
10
- model_name = "microsoft/resnet-50-finetuned-chestxray14"
11
  model = AutoModelForImageClassification.from_pretrained(model_name)
12
  processor = AutoImageProcessor.from_pretrained(model_name)
13
  model.eval()
14
 
15
- # Get labels from config
16
- id2label = model.config.id2label
 
 
 
 
17
 
18
- # Focus only on 3 diseases
19
- target_diseases = ["Pneumonia", "Effusion", "Atelectasis"]
20
-
21
- # -----------------------------
22
- # 2. Prediction function
23
- # -----------------------------
24
  def predict(image):
25
  img = image.convert("RGB").resize((224, 224))
26
  inputs = processor(images=img, return_tensors="pt")
27
-
28
  with torch.no_grad():
29
  logits = model(**inputs).logits
30
-
31
- probs = F.softmax(logits, dim=1).squeeze()
32
-
33
  results = []
34
- for idx, label in id2label.items():
35
- if label in target_diseases:
36
- prob = probs[idx].item()
37
- results.append(f"{label}: {'YES' if prob > 0.5 else 'NO'} ({prob:.2f})")
38
-
39
  return "\n".join(results)
40
 
41
- # -----------------------------
42
- # 3. Gradio interface
43
- # -----------------------------
44
  iface = gr.Interface(
45
  fn=predict,
46
  inputs=gr.Image(type="pil"),
47
  outputs="text",
48
- title="Chest X-ray: Pneumonia / Effusion / Atelectasis",
49
- description="Upload a chest X-ray. Model predicts YES/NO with probabilities for Pneumonia, Effusion, and Atelectasis."
50
  )
51
-
52
  iface.launch()
 
4
  from PIL import Image
5
  import gradio as gr
6
 
7
+ # 1️⃣ Load fine-tuned vit-chest-xray model
8
+ model_name = "codewithdark/vit-chest-xray"
 
 
9
  model = AutoModelForImageClassification.from_pretrained(model_name)
10
  processor = AutoImageProcessor.from_pretrained(model_name)
11
  model.eval()
12
 
13
+ # 2️⃣ Define disease indices based on CheXpert labels
14
+ # The model expects: ['Cardiomegaly', 'Edema', 'Consolidation', 'No Finding', 'Pneumonia']
15
+ label_list = ['Cardiomegaly', 'Edema', 'Consolidation', 'No Finding', 'Pneumonia']
16
+ # We only care about Pneumonia, Consolidation, Edema
17
+ target_labels = ['Pneumonia', 'Consolidation', 'Edema']
18
+ target_idxs = [label_list.index(lbl) for lbl in target_labels]
19
 
 
 
 
 
 
 
20
  def predict(image):
21
  img = image.convert("RGB").resize((224, 224))
22
  inputs = processor(images=img, return_tensors="pt")
 
23
  with torch.no_grad():
24
  logits = model(**inputs).logits
25
+ probs = torch.sigmoid(logits).squeeze() # multi-label => sigmoid
 
 
26
  results = []
27
+ for idx, lbl in zip(target_idxs, target_labels):
28
+ prob = probs[idx].item()
29
+ results.append(f"{lbl}: {'YES' if prob > 0.5 else 'NO'} ({prob:.2f})")
 
 
30
  return "\n".join(results)
31
 
 
 
 
32
  iface = gr.Interface(
33
  fn=predict,
34
  inputs=gr.Image(type="pil"),
35
  outputs="text",
36
+ title="Chest X-ray Multi-Disease Detector",
37
+ description="Upload a chest X-ray. Predicts Pneumonia, Consolidation, and Edema."
38
  )
 
39
  iface.launch()