VaneshDev commited on
Commit
2a74625
·
verified ·
1 Parent(s): 35a8d6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -40
app.py CHANGED
@@ -21,19 +21,18 @@ import numpy as np
21
  logging.basicConfig(level=logging.INFO)
22
  log = logging.getLogger(__name__)
23
 
24
- # Load model
25
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
- MODEL = xrv.models.get_model("densenet121-res224-all").to(DEVICE).eval()
 
27
  LABELS = MODEL.pathologies
28
 
29
- # Correct transform for TorchXRayVision (grayscale, normalized to [-1024, 1024])
 
 
30
  def preprocess_xray(pil_img: Image.Image) -> torch.Tensor:
31
  """
32
  Preprocess PIL image for TorchXRayVision model
33
- TorchXRayVision expects:
34
- - Single channel grayscale image
35
- - Values normalized to [-1024, 1024] range
36
- - Resolution of 224x224
37
  """
38
  # Convert to grayscale if needed
39
  if pil_img.mode != "L":
@@ -43,11 +42,10 @@ def preprocess_xray(pil_img: Image.Image) -> torch.Tensor:
43
  img_array = np.array(pil_img, dtype=np.float32)
44
 
45
  # Normalize to [-1024, 1024] range (TorchXRayVision standard)
46
- # Assume input is 8-bit (0-255), scale to [-1024, 1024]
47
  img_array = xrv.datasets.normalize(img_array, 255)
48
 
49
- # Add channel dimension and resize
50
- img_array = img_array[None, ...] # Add channel dimension
51
 
52
  # Use TorchXRayVision transforms
53
  transform = transforms.Compose([
@@ -57,25 +55,26 @@ def preprocess_xray(pil_img: Image.Image) -> torch.Tensor:
57
 
58
  img_array = transform(img_array)
59
 
60
- # Convert to tensor
61
  img_tensor = torch.from_numpy(img_array).unsqueeze(0).to(DEVICE)
 
62
 
63
  return img_tensor
64
 
65
- # Initialize CAM extractor with correct input shape for grayscale
66
- cam_extractor = SmoothGradCAMpp(MODEL, input_shape=(1, 224, 224)) # Single channel
67
-
68
  def analyse_xray(img: Image.Image):
69
  if img is None:
70
  return "Please upload an image.", None
71
 
72
  try:
73
- # Preprocess image for TorchXRayVision
74
  x = preprocess_xray(img)
75
 
76
- with torch.no_grad():
77
- logits = MODEL(x)
78
- probs = torch.sigmoid(logits)[0] * 100 # Convert to percentages
 
 
 
79
 
80
  # Get top 5 predictions
81
  topk = torch.topk(probs, 5)
@@ -86,13 +85,16 @@ def analyse_xray(img: Image.Image):
86
  # Generate activation map
87
  activation_map = cam_extractor(target, logits)[0]
88
 
89
- # Overlay heatmap on original image
90
  # Convert single channel to 3-channel for overlay
91
- input_for_overlay = x.squeeze(0).cpu()
92
- input_for_overlay = input_for_overlay.repeat(3, 1, 1) # Repeat single channel 3 times
93
 
 
94
  heatmap = cam_extractor.overlay(input_for_overlay, activation_map)
95
 
 
 
 
96
  # Build HTML summary table
97
  table_rows = ""
98
  for i in range(len(topk.indices)):
@@ -148,7 +150,6 @@ MEDICAL_ADVICE = {
148
  def get_medical_advice(condition: str) -> str:
149
  return MEDICAL_ADVICE.get(condition, "Consult with a radiologist or pulmonologist for proper interpretation.")
150
 
151
- # PDF report analysis (simplified - focusing on the main issue)
152
  def analyse_report(file):
153
  if file is None:
154
  return "Please upload a PDF file."
@@ -252,24 +253,6 @@ with gr.Blocks(title="🩻 RadiologyScan AI", theme=gr.themes.Soft()) as demo:
252
  outputs=[pdf_input, report_output]
253
  )
254
 
255
- gr.Markdown("""
256
- ### 📖 How to Use
257
- 1. **X-ray Analysis**: Upload a chest X-ray image (JPEG, PNG) and click "Analyze X-ray"
258
- 2. **Report Analysis**: Upload a medical report PDF and click "Analyze Report"
259
-
260
- ### 🔬 Technical Details
261
- - Uses TorchXRayVision pre-trained DenseNet-121 model
262
- - Trained on multiple chest X-ray datasets
263
- - Provides attention heatmaps for interpretability
264
- - Supports 18 different pathological conditions
265
-
266
- ### ⚠️ Limitations
267
- - For educational use only
268
- - Not a substitute for professional medical diagnosis
269
- - Results may vary based on image quality
270
- - Always consult healthcare professionals
271
- """)
272
-
273
  if __name__ == "__main__":
274
  demo.launch(
275
  server_name="0.0.0.0",
 
21
  logging.basicConfig(level=logging.INFO)
22
  log = logging.getLogger(__name__)
23
 
24
+ # Load model - IMPORTANT: Don't call .eval() here for CAM to work
25
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+ MODEL = xrv.models.get_model("densenet121-res224-all").to(DEVICE)
27
+ # Note: We'll switch between train/eval modes as needed
28
  LABELS = MODEL.pathologies
29
 
30
+ # Initialize CAM extractor with correct input shape for grayscale
31
+ cam_extractor = SmoothGradCAMpp(MODEL, input_shape=(1, 224, 224))
32
+
33
  def preprocess_xray(pil_img: Image.Image) -> torch.Tensor:
34
  """
35
  Preprocess PIL image for TorchXRayVision model
 
 
 
 
36
  """
37
  # Convert to grayscale if needed
38
  if pil_img.mode != "L":
 
42
  img_array = np.array(pil_img, dtype=np.float32)
43
 
44
  # Normalize to [-1024, 1024] range (TorchXRayVision standard)
 
45
  img_array = xrv.datasets.normalize(img_array, 255)
46
 
47
+ # Add channel dimension
48
+ img_array = img_array[None, ...]
49
 
50
  # Use TorchXRayVision transforms
51
  transform = transforms.Compose([
 
55
 
56
  img_array = transform(img_array)
57
 
58
+ # Convert to tensor with gradient enabled
59
  img_tensor = torch.from_numpy(img_array).unsqueeze(0).to(DEVICE)
60
+ img_tensor.requires_grad_(True) # Enable gradients for CAM
61
 
62
  return img_tensor
63
 
 
 
 
64
  def analyse_xray(img: Image.Image):
65
  if img is None:
66
  return "Please upload an image.", None
67
 
68
  try:
69
+ # Preprocess image
70
  x = preprocess_xray(img)
71
 
72
+ # Set model to train mode temporarily for gradient computation
73
+ MODEL.train()
74
+
75
+ # Forward pass with gradient tracking
76
+ logits = MODEL(x)
77
+ probs = torch.sigmoid(logits)[0] * 100 # Convert to percentages
78
 
79
  # Get top 5 predictions
80
  topk = torch.topk(probs, 5)
 
85
  # Generate activation map
86
  activation_map = cam_extractor(target, logits)[0]
87
 
 
88
  # Convert single channel to 3-channel for overlay
89
+ input_for_overlay = x.squeeze(0).cpu().detach()
90
+ input_for_overlay = input_for_overlay.repeat(3, 1, 1)
91
 
92
+ # Generate heatmap
93
  heatmap = cam_extractor.overlay(input_for_overlay, activation_map)
94
 
95
+ # Set model back to eval mode
96
+ MODEL.eval()
97
+
98
  # Build HTML summary table
99
  table_rows = ""
100
  for i in range(len(topk.indices)):
 
150
  def get_medical_advice(condition: str) -> str:
151
  return MEDICAL_ADVICE.get(condition, "Consult with a radiologist or pulmonologist for proper interpretation.")
152
 
 
153
  def analyse_report(file):
154
  if file is None:
155
  return "Please upload a PDF file."
 
253
  outputs=[pdf_input, report_output]
254
  )
255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  if __name__ == "__main__":
257
  demo.launch(
258
  server_name="0.0.0.0",