saba2000 commited on
Commit
ccfcbd4
·
verified ·
1 Parent(s): 929016b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -21
app.py CHANGED
@@ -4,23 +4,25 @@ import torch.nn.functional as F
4
  from PIL import Image
5
  import gradio as gr
6
 
7
- # -----------------------------
8
- # 1. Load the pretrained model
9
- # -----------------------------
10
- model_name = "microsoft/resnet-50" # fine-tuned for chest x-ray multi-disease
11
  model = AutoModelForImageClassification.from_pretrained(model_name)
12
  processor = AutoImageProcessor.from_pretrained(model_name)
13
  model.eval()
14
 
15
- # Example disease list (adjust depending on model config)
16
- diseases = ["Pneumonia", "Effusion", "Atelectasis"]
 
 
 
 
 
 
 
 
17
 
18
- # -----------------------------
19
- # 2. Prediction function
20
- # -----------------------------
21
  def predict(image):
22
  img = image.convert("RGB").resize((224, 224))
23
-
24
  inputs = processor(images=img, return_tensors="pt")
25
 
26
  with torch.no_grad():
@@ -28,25 +30,18 @@ def predict(image):
28
 
29
  probs = F.softmax(logits, dim=1).squeeze()
30
 
31
- # Get top-3 predictions
32
- top_probs, top_idxs = torch.topk(probs, k=3)
33
-
34
  results = []
35
- for idx, prob in zip(top_idxs, top_probs):
36
- disease_name = diseases[idx] if idx < len(diseases) else f"Class {idx.item()}"
37
- results.append(f"{disease_name}: {prob.item():.2f}")
38
 
39
  return "\n".join(results)
40
 
41
- # -----------------------------
42
- # 3. Gradio interface
43
- # -----------------------------
44
  iface = gr.Interface(
45
  fn=predict,
46
  inputs=gr.Image(type="pil"),
47
  outputs="text",
48
- title="Chest X-ray Detector",
49
- description="Upload a chest X-ray. The model predicts Pneumonia, Effusion, or Atelectasis with probability."
50
  )
51
 
52
  iface.launch()
 
4
  from PIL import Image
5
  import gradio as gr
6
 
7
+ # Load the properly fine-tuned chest X-ray model
8
+ model_name = "Lucario-K17/biomedclip_radiology_diagnosis"
 
 
9
  model = AutoModelForImageClassification.from_pretrained(model_name)
10
  processor = AutoImageProcessor.from_pretrained(model_name)
11
  model.eval()
12
 
13
+ # All 14 disease labels
14
+ all_diseases = [
15
+ "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass",
16
+ "Nodule", "Pneumonia", "Pneumothorax", "Consolidation", "Edema",
17
+ "Emphysema", "Fibrosis", "Pleural_Thickening", "Hernia"
18
+ ]
19
+
20
+ # Lock to desired diseases
21
+ target_diseases = ["Pneumonia", "Effusion", "Atelectasis"]
22
+ target_idxs = [all_diseases.index(d) for d in target_diseases]
23
 
 
 
 
24
  def predict(image):
25
  img = image.convert("RGB").resize((224, 224))
 
26
  inputs = processor(images=img, return_tensors="pt")
27
 
28
  with torch.no_grad():
 
30
 
31
  probs = F.softmax(logits, dim=1).squeeze()
32
 
 
 
 
33
  results = []
34
+ for i, d in zip(target_idxs, target_diseases):
35
+ results.append(f"{d}: {probs[i].item():.2f}")
 
36
 
37
  return "\n".join(results)
38
 
 
 
 
39
  iface = gr.Interface(
40
  fn=predict,
41
  inputs=gr.Image(type="pil"),
42
  outputs="text",
43
+ title="Chest X-ray: Pneumonia / Effusion / Atelectasis",
44
+ description="Upload a chest X-ray. Model predicts probability for Pneumonia, Effusion, and Atelectasis."
45
  )
46
 
47
  iface.launch()