VaneshDev commited on
Commit
4d6d0f1
·
verified ·
1 Parent(s): 661f3f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -28
app.py CHANGED
@@ -1,10 +1,6 @@
1
  """
2
  RadiologyScan AI – X-ray & Report analyser
3
  Author : <you>
4
- ▶ requirements.txt needs:
5
- torch torchvision torchxrayvision==1.2.0
6
- pillow gradio pymupdf torchcam==0.4.0
7
- transformers>=4.40.0 accelerate
8
  """
9
 
10
  import os, re, logging, tempfile
@@ -13,20 +9,18 @@ from PIL import Image
13
  import torch
14
  import torch.nn.functional as F
15
  from torchvision import transforms
16
- import torchxrayvision as xrv # CheXNet-style models
17
- import fitz # PyMuPDF
18
- from torchcam.methods import SmoothGradCAMpp # visual explainability
19
  from transformers import pipeline
20
 
21
  logging.basicConfig(level=logging.INFO)
22
  log = logging.getLogger(__name__)
23
 
24
- # ------------------------------------------------------------------
25
- # 1. Load model – 18-label denseNet trained on multiple X-ray sets
26
- # ------------------------------------------------------------------
27
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
  MODEL = xrv.models.get_model("densenet121-res224-all").to(DEVICE).eval()
29
- LABELS = MODEL.pathologies # 18 canonical labels
30
 
31
  TRANSFORM = transforms.Compose([
32
  transforms.Resize(224),
@@ -35,15 +29,12 @@ TRANSFORM = transforms.Compose([
35
  transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
36
  ])
37
 
38
- # ------------------------ helper ----------------------------------
39
  def preprocess(pil_img: Image.Image) -> torch.Tensor:
40
  if pil_img.mode != "RGB":
41
  pil_img = pil_img.convert("RGB")
42
  return TRANSFORM(pil_img).unsqueeze(0).to(DEVICE)
43
 
44
- # ------------------------------------------------------------------
45
- # 2. X-ray prediction with Grad-CAM + textual advice
46
- # ------------------------------------------------------------------
47
  cam_extractor = SmoothGradCAMpp(MODEL)
48
 
49
  def analyse_xray(img: Image.Image):
@@ -52,12 +43,12 @@ def analyse_xray(img: Image.Image):
52
  x = preprocess(img)
53
  with torch.no_grad():
54
  logits = MODEL(x)
55
- probs = torch.sigmoid(logits)[0] * 100 # multi-label %
56
- topk = torch.topk(probs, 3) # show best 3
57
 
58
- # Grad-CAM heat-map for the highest score
59
  target = topk.indices[0].item()
60
- activation_map = cam_extractor(target, logits)[0] # H×W
61
  heatmap = cam_extractor.overlay(torch.squeeze(x).cpu(), activation_map)
62
 
63
  # Build HTML summary
@@ -74,7 +65,7 @@ def analyse_xray(img: Image.Image):
74
 
75
  return html, Image.fromarray(heatmap)
76
 
77
- # simple rule-based advice (extend or swap for knowledge-graph)
78
  ADVICE = {
79
  "Pneumonia": "Consult a pulmonologist; antibiotics or antivirals as indicated.",
80
  "Cardiomegaly": "Recommend echocardiography; refer to cardiology.",
@@ -82,10 +73,7 @@ ADVICE = {
82
  }
83
  def medical_advice(label): return ADVICE.get(label, "Discuss with a radiologist for next steps.")
84
 
85
- # ------------------------------------------------------------------
86
- # 3. PDF report summariser (LLM pipeline fallback)
87
- # ------------------------------------------------------------------
88
- # Regex first → else call an LLM summariser (small DistilBART)
89
  summariser = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
90
 
91
  def analyse_report(file):
@@ -100,7 +88,6 @@ def analyse_report(file):
100
 
101
  disease = regex_find_disease(text)
102
  if not disease:
103
- # fallback LLM summary
104
  short = summariser(text[:4000], max_length=120, min_length=30, do_sample=False)[0]["summary_text"]
105
  return f"<h3>Report summary</h3><p>{short}</p>"
106
 
@@ -120,9 +107,7 @@ def regex_find_disease(t:str):
120
  if re.search(v, t, flags=re.I): return k
121
  return None
122
 
123
- # ------------------------------------------------------------------
124
- # 4. Gradio UI
125
- # ------------------------------------------------------------------
126
  with gr.Blocks(title="🩻 RadiologyScan AI") as demo:
127
  gr.Markdown("## 🩻 RadiologyScan AI – Chest X-ray & Report Analyser")
128
 
 
1
  """
2
  RadiologyScan AI – X-ray & Report analyser
3
  Author : <you>
 
 
 
 
4
  """
5
 
6
  import os, re, logging, tempfile
 
9
  import torch
10
  import torch.nn.functional as F
11
  from torchvision import transforms
12
+ import torchxrayvision as xrv
13
+ import fitz # PyMuPDF
14
+ from torchcam.methods import SmoothGradCAMpp
15
  from transformers import pipeline
16
 
17
  logging.basicConfig(level=logging.INFO)
18
  log = logging.getLogger(__name__)
19
 
20
+ # Load model
 
 
21
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
  MODEL = xrv.models.get_model("densenet121-res224-all").to(DEVICE).eval()
23
+ LABELS = MODEL.pathologies
24
 
25
  TRANSFORM = transforms.Compose([
26
  transforms.Resize(224),
 
29
  transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
30
  ])
31
 
 
32
  def preprocess(pil_img: Image.Image) -> torch.Tensor:
33
  if pil_img.mode != "RGB":
34
  pil_img = pil_img.convert("RGB")
35
  return TRANSFORM(pil_img).unsqueeze(0).to(DEVICE)
36
 
37
+ # X-ray prediction with Grad-CAM
 
 
38
  cam_extractor = SmoothGradCAMpp(MODEL)
39
 
40
  def analyse_xray(img: Image.Image):
 
43
  x = preprocess(img)
44
  with torch.no_grad():
45
  logits = MODEL(x)
46
+ probs = torch.sigmoid(logits)[0] * 100
47
+ topk = torch.topk(probs, 3)
48
 
49
+ # Grad-CAM heat-map
50
  target = topk.indices[0].item()
51
+ activation_map = cam_extractor(target, logits)[0]
52
  heatmap = cam_extractor.overlay(torch.squeeze(x).cpu(), activation_map)
53
 
54
  # Build HTML summary
 
65
 
66
  return html, Image.fromarray(heatmap)
67
 
68
+ # Medical advice
69
  ADVICE = {
70
  "Pneumonia": "Consult a pulmonologist; antibiotics or antivirals as indicated.",
71
  "Cardiomegaly": "Recommend echocardiography; refer to cardiology.",
 
73
  }
74
  def medical_advice(label): return ADVICE.get(label, "Discuss with a radiologist for next steps.")
75
 
76
+ # PDF report summariser
 
 
 
77
  summariser = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
78
 
79
  def analyse_report(file):
 
88
 
89
  disease = regex_find_disease(text)
90
  if not disease:
 
91
  short = summariser(text[:4000], max_length=120, min_length=30, do_sample=False)[0]["summary_text"]
92
  return f"<h3>Report summary</h3><p>{short}</p>"
93
 
 
107
  if re.search(v, t, flags=re.I): return k
108
  return None
109
 
110
+ # Gradio UI
 
 
111
  with gr.Blocks(title="🩻 RadiologyScan AI") as demo:
112
  gr.Markdown("## 🩻 RadiologyScan AI – Chest X-ray & Report Analyser")
113