ArchCoder commited on
Commit
c322805
·
verified ·
1 Parent(s): a4f4e25

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -131
app.py CHANGED
@@ -266,181 +266,191 @@ def generate_attention_heatmap(attention_maps):
266
  return heatmap
267
 
268
  def analyze_image(image, ground_truth, filename):
269
- """Main analysis function - FIXED VERSION matching the working notebook"""
 
 
 
 
 
 
270
  if model is None:
271
  return None, "Model not loaded. Please restart the application."
272
-
273
  if image is None:
274
  return None, "Please select an image first."
275
-
276
  try:
277
- print("="*50)
278
  print("DEBUG: Starting analysis...")
279
  print(f"Input image mode: {image.mode}")
280
  print(f"Input image size: {image.size}")
281
-
282
- # Preprocess - exactly like the working notebook
283
  input_tensor = preprocess_for_model(image).to(device)
284
  print(f"Input tensor shape: {input_tensor.shape}")
285
  print(f"Input tensor min/max: {input_tensor.min():.4f}/{input_tensor.max():.4f}")
286
-
287
  # Get prediction and attention maps
288
  with torch.no_grad():
289
  print("Getting model output...")
290
  model_output, attention_maps = model(input_tensor)
291
-
 
292
  print(f"Model output shape: {model_output.shape}")
293
  print(f"Model output min/max BEFORE sigmoid: {model_output.min():.4f}/{model_output.max():.4f}")
294
-
295
- # Apply sigmoid and threshold - EXACTLY like the working notebook
296
- pred_mask = torch.sigmoid(model_output)
297
- print(f"After sigmoid min/max: {pred_mask.min():.4f}/{pred_mask.max():.4f}")
298
-
299
- # Apply threshold to get binary mask
300
- binary_mask = (pred_mask > 0.5).float()
301
- print(f"Binary mask sum (number of 1s): {binary_mask.sum()}")
302
-
303
- # Convert to numpy - following notebook approach
304
- pred_mask_np = binary_mask.cpu().squeeze().numpy()
305
  print(f"Numpy binary mask shape: {pred_mask_np.shape}")
306
  print(f"Numpy binary mask unique values: {np.unique(pred_mask_np)}")
307
  print(f"Numpy binary mask sum: {np.sum(pred_mask_np)}")
308
-
309
- # Create visualization mask like in the notebook
310
- # The notebook uses: inv_pred_mask_np = np.where(pred_mask_np == 1, 0, 255)
311
- # This inverts the mask for better visualization
312
- inv_pred_mask_np = np.where(pred_mask_np == 1, 0, 255)
313
-
314
- # Generate attention heatmap
315
  print("Generating attention heatmap...")
316
- att_heatmap = generate_attention_heatmap(attention_maps)
317
- print(f"Attention heatmap shape: {att_heatmap.shape}")
318
-
319
- # Prepare original image array
320
- original_np = np.array(image.resize((256, 256)))
321
-
322
- # Create tumor-only image (like in notebook)
323
- tumor_only = np.where(pred_mask_np == 1, original_np, 255)
324
-
325
- # Create visualization
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
  if ground_truth is not None:
327
  fig, axes = plt.subplots(2, 4, figsize=(16, 8))
328
  else:
329
  fig, axes = plt.subplots(2, 3, figsize=(15, 8))
330
-
331
  fig.suptitle('Brain Tumor Segmentation Analysis', fontsize=16, weight='bold')
332
-
333
- # Row 1: Original, Attention, Predicted Mask, Tumor Only
334
- axes[0,0].imshow(original_np, cmap='gray')
335
- axes[0,0].set_title('Original Image', fontsize=12, weight='bold')
336
- axes[0,0].axis('off')
337
-
338
- # Attention heatmap overlay
339
- axes[0,1].imshow(original_np, cmap='gray')
340
- axes[0,1].imshow(att_heatmap, alpha=0.4)
341
- axes[0,1].set_title('Attention Heatmap', fontsize=12, weight='bold')
342
- axes[0,1].axis('off')
343
-
 
344
  # Predicted mask (inverted for visualization)
345
- axes[0,2].imshow(inv_pred_mask_np, cmap='gray')
346
- axes[0,2].set_title('Predicted Mask', fontsize=12, weight='bold')
347
- axes[0,2].axis('off')
348
-
349
  if ground_truth is not None:
 
 
 
 
350
  # Ground truth processing - convert to binary like notebook
351
- gt_array = np.array(ground_truth.resize((256, 256)))
352
- # Apply same preprocessing as notebook
353
  val_test_transform = transforms.Compose([
354
- transforms.Resize((256,256)),
355
  transforms.ToTensor()
356
  ])
357
  mask_np = val_test_transform(ground_truth).cpu().squeeze().numpy()
358
-
359
- print(f"Ground truth array shape: {gt_array.shape}")
360
- print(f"Ground truth unique values: {np.unique(gt_array)}")
361
-
362
- # Tumor only image
363
- axes[0,3].imshow(tumor_only, cmap='gray')
364
- axes[0,3].set_title('Tumor Only', fontsize=12, weight='bold')
365
- axes[0,3].axis('off')
366
-
367
- # Row 2: Ground truth, overlay comparison, metrics
368
- axes[1,0].imshow(mask_np, cmap='gray')
369
- axes[1,0].set_title('Ground Truth Mask', fontsize=12, weight='bold')
370
- axes[1,0].axis('off')
371
-
372
- # Overlay comparison - following notebook style
373
- overlay = np.array(image.convert('RGB').resize((256, 256)))
374
- overlay[pred_mask_np == 1] = [0, 255, 0] # Green for prediction
375
- overlay[mask_np > 0.5] = [255, 0, 0] # Red for ground truth
376
- axes[1,1].imshow(overlay)
377
- axes[1,1].set_title('Prediction (Green) vs GT (Red)', fontsize=12, weight='bold')
378
- axes[1,1].axis('off')
379
-
380
- # Calculate IoU and Dice exactly like notebook
381
- intersection = np.logical_and(pred_mask_np, mask_np).sum()
382
- union = np.logical_or(pred_mask_np, mask_np).sum()
383
  iou = intersection / (union + 1e-7)
384
-
385
- # Dice score
386
- dice = (2 * intersection) / (pred_mask_np.sum() + mask_np.sum() + 1e-7)
387
-
388
  print(f"Final IoU: {iou:.4f}")
389
  print(f"Final Dice: {dice:.4f}")
390
  print(f"Intersection: {intersection}")
391
  print(f"Union: {union}")
392
- print(f"Pred pixels: {np.sum(pred_mask_np)}")
393
- print(f"GT pixels: {np.sum(mask_np > 0.5)}")
394
-
395
- axes[1,2].text(0.1, 0.6, f'IoU: {iou:.4f}', fontsize=16, weight='bold')
396
- axes[1,2].text(0.1, 0.4, f'Dice: {dice:.4f}', fontsize=16, weight='bold')
397
- axes[1,2].set_xlim(0, 1)
398
- axes[1,2].set_ylim(0, 1)
399
- axes[1,2].axis('off')
400
- axes[1,2].set_title('Metrics', fontsize=12, weight='bold')
401
-
402
- # Additional tumor statistics
403
- axes[1,3].imshow(tumor_only, cmap='gray')
404
- axes[1,3].set_title('Segmented Tumor', fontsize=12, weight='bold')
405
- axes[1,3].axis('off')
406
-
407
  else:
408
  # No ground truth case
409
- axes[1,0].imshow(inv_pred_mask_np, cmap='gray')
410
- axes[1,0].set_title('Predicted Mask', fontsize=12, weight='bold')
411
- axes[1,0].axis('off')
412
-
413
- # Tumor only
414
- axes[1,1].imshow(tumor_only, cmap='gray')
415
- axes[1,1].set_title('Tumor Only', fontsize=12, weight='bold')
416
- axes[1,1].axis('off')
417
-
418
- # Overlay
419
- overlay = np.array(image.convert('RGB').resize((256, 256)))
420
- overlay[pred_mask_np == 1] = [255, 0, 0]
421
- axes[1,2].imshow(overlay)
422
- axes[1,2].set_title('Prediction Overlay', fontsize=12, weight='bold')
423
- axes[1,2].axis('off')
424
 
425
  plt.tight_layout()
426
-
427
- # Save plot
428
  buf = io.BytesIO()
429
  plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
430
  buf.seek(0)
431
  plt.close()
432
-
433
- result_image = Image.open(buf)
434
-
435
- # Generate analysis text
436
- tumor_pixels = np.sum(pred_mask_np)
437
- total_pixels = pred_mask_np.size
438
- tumor_percentage = (tumor_pixels / total_pixels) * 100
439
-
440
  print(f"Final tumor pixels: {tumor_pixels}")
441
  print(f"Final tumor percentage: {tumor_percentage:.2f}%")
442
- print("="*50)
443
-
444
  analysis_text = f"""
445
  # Analysis Results
446
 
@@ -454,20 +464,20 @@ def analyze_image(image, ground_truth, filename):
454
  - Attention Visualization: Generated
455
  - Post-processing: Applied
456
  """
457
-
458
  if ground_truth is not None:
459
  analysis_text += f"""
460
  **Performance Metrics:**
461
  - IoU Score: {iou:.4f}
462
  - Dice Score: {dice:.4f}
463
  """
464
-
465
  return result_image, analysis_text
466
-
467
  except Exception as e:
468
  import traceback
469
  error_msg = f"Analysis failed: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
470
- print(error_msg) # For debugging
471
  return None, error_msg
472
 
473
 
 
266
  return heatmap
267
 
268
  def analyze_image(image, ground_truth, filename):
269
+ """
270
+ Robust replacement for the original analyze_image.
271
+ - Fixes broadcasting issues between 2D masks and 3-channel images.
272
+ - Converts attention heatmap (BGR from OpenCV) to RGB for correct plotting.
273
+ - Ensures masks are strict binary uint8 arrays.
274
+ - Returns (PIL.Image result_plot, markdown_text).
275
+ """
276
  if model is None:
277
  return None, "Model not loaded. Please restart the application."
278
+
279
  if image is None:
280
  return None, "Please select an image first."
281
+
282
  try:
283
+ print("=" * 50)
284
  print("DEBUG: Starting analysis...")
285
  print(f"Input image mode: {image.mode}")
286
  print(f"Input image size: {image.size}")
287
+
288
+ # Preprocess - keeps same behavior as notebook
289
  input_tensor = preprocess_for_model(image).to(device)
290
  print(f"Input tensor shape: {input_tensor.shape}")
291
  print(f"Input tensor min/max: {input_tensor.min():.4f}/{input_tensor.max():.4f}")
292
+
293
  # Get prediction and attention maps
294
  with torch.no_grad():
295
  print("Getting model output...")
296
  model_output, attention_maps = model(input_tensor)
297
+
298
+ # model_output shape expected: [1, 1, 256, 256]
299
  print(f"Model output shape: {model_output.shape}")
300
  print(f"Model output min/max BEFORE sigmoid: {model_output.min():.4f}/{model_output.max():.4f}")
301
+
302
+ pred_prob = torch.sigmoid(model_output) # probabilities in [0,1]
303
+ print(f"After sigmoid min/max: {pred_prob.min():.4f}/{pred_prob.max():.4f}")
304
+
305
+ # DEFAULT THRESHOLD: 0.5 (same as your notebook). Change if debugging low-confidence.
306
+ pred_mask = (pred_prob > 0.5).float()
307
+ print(f"Binary mask sum (number of 1s): {pred_mask.sum():.4f}")
308
+
309
+ # Convert prediction to numpy
310
+ pred_mask_np = pred_mask.cpu().squeeze().numpy() # shape: (H, W)
 
311
  print(f"Numpy binary mask shape: {pred_mask_np.shape}")
312
  print(f"Numpy binary mask unique values: {np.unique(pred_mask_np)}")
313
  print(f"Numpy binary mask sum: {np.sum(pred_mask_np)}")
314
+
315
+ # Create attention heatmap (the helper resizes & returns a 3-channel BGR heatmap)
 
 
 
 
 
316
  print("Generating attention heatmap...")
317
+ att_heatmap = generate_attention_heatmap(attention_maps) # likely BGR (cv2)
318
+ print(f"Raw attention heatmap shape: {att_heatmap.shape}")
319
+
320
+ # Convert heatmap to RGB (OpenCV returns BGR)
321
+ if att_heatmap is not None and att_heatmap.size != 0:
322
+ try:
323
+ att_heatmap = cv2.cvtColor(att_heatmap, cv2.COLOR_BGR2RGB)
324
+ except Exception:
325
+ # if conversion fails, proceed with what we have
326
+ pass
327
+
328
+ # Prepare original image arrays:
329
+ original_gray = np.array(image.convert('L').resize((256, 256))).astype(np.uint8) # 2D
330
+ original_rgb = np.array(image.convert('RGB').resize((256, 256))).astype(np.uint8) # 3D
331
+
332
+ # Ensure pred_mask_np is strict binary 0/1 uint8
333
+ pred_mask_bin = (pred_mask_np > 0.5).astype(np.uint8) # shape: (256,256), dtype: uint8
334
+
335
+ # Inverted predicted mask for visualization (white background, tumor black)
336
+ inv_pred_mask_np = np.where(pred_mask_bin == 1, 0, 255).astype(np.uint8)
337
+
338
+ # Tumor-only images:
339
+ tumor_only_gray = np.where(pred_mask_bin == 1, original_gray, 255).astype(np.uint8)
340
+ tumor_only_rgb = original_rgb.copy()
341
+ tumor_only_rgb[pred_mask_bin == 0] = 255
342
+
343
+ # Begin plotting (match existing layout: 2x4 with GT or 2x3 without)
344
  if ground_truth is not None:
345
  fig, axes = plt.subplots(2, 4, figsize=(16, 8))
346
  else:
347
  fig, axes = plt.subplots(2, 3, figsize=(15, 8))
348
+
349
  fig.suptitle('Brain Tumor Segmentation Analysis', fontsize=16, weight='bold')
350
+
351
+ # Row 1: Original, Attention, Predicted Mask, Tumor Only (if GT exists show 4th)
352
+ axes[0, 0].imshow(original_gray, cmap='gray')
353
+ axes[0, 0].set_title('Original Image', fontsize=12, weight='bold')
354
+ axes[0, 0].axis('off')
355
+
356
+ # Attention overlay on RGB original (blend)
357
+ axes[0, 1].imshow(original_rgb)
358
+ if att_heatmap is not None and att_heatmap.size != 0:
359
+ axes[0, 1].imshow(att_heatmap, alpha=0.4)
360
+ axes[0, 1].set_title('Attention Heatmap', fontsize=12, weight='bold')
361
+ axes[0, 1].axis('off')
362
+
363
  # Predicted mask (inverted for visualization)
364
+ axes[0, 2].imshow(inv_pred_mask_np, cmap='gray')
365
+ axes[0, 2].set_title('Predicted Mask', fontsize=12, weight='bold')
366
+ axes[0, 2].axis('off')
367
+
368
  if ground_truth is not None:
369
+ axes[0, 3].imshow(tumor_only_rgb)
370
+ axes[0, 3].set_title('Tumor Only', fontsize=12, weight='bold')
371
+ axes[0, 3].axis('off')
372
+
373
  # Ground truth processing - convert to binary like notebook
 
 
374
  val_test_transform = transforms.Compose([
375
+ transforms.Resize((256, 256)),
376
  transforms.ToTensor()
377
  ])
378
  mask_np = val_test_transform(ground_truth).cpu().squeeze().numpy()
379
+ mask_bin = (mask_np > 0.5).astype(np.uint8)
380
+
381
+ print(f"Ground truth array shape: {np.array(ground_truth.resize((256,256))).shape}")
382
+ print(f"Ground truth unique values: {np.unique(np.array(ground_truth.resize((256,256))))}")
383
+
384
+ # Row 2: Ground truth, overlay comparison, metrics, segmented tumor
385
+ axes[1, 0].imshow(mask_bin, cmap='gray')
386
+ axes[1, 0].set_title('Ground Truth Mask', fontsize=12, weight='bold')
387
+ axes[1, 0].axis('off')
388
+
389
+ overlay = original_rgb.copy()
390
+ overlay[pred_mask_bin == 1] = [0, 255, 0] # predicted green
391
+ overlay[mask_bin == 1] = [255, 0, 0] # ground truth red
392
+ axes[1, 1].imshow(overlay)
393
+ axes[1, 1].set_title('Prediction (Green) vs GT (Red)', fontsize=12, weight='bold')
394
+ axes[1, 1].axis('off')
395
+
396
+ # Metrics calculation (IoU and Dice)
397
+ intersection = np.logical_and(pred_mask_bin, mask_bin).sum()
398
+ union = np.logical_or(pred_mask_bin, mask_bin).sum()
 
 
 
 
 
399
  iou = intersection / (union + 1e-7)
400
+ dice = (2 * intersection) / (pred_mask_bin.sum() + mask_bin.sum() + 1e-7)
401
+
 
 
402
  print(f"Final IoU: {iou:.4f}")
403
  print(f"Final Dice: {dice:.4f}")
404
  print(f"Intersection: {intersection}")
405
  print(f"Union: {union}")
406
+ print(f"Pred pixels: {np.sum(pred_mask_bin)}")
407
+ print(f"GT pixels: {np.sum(mask_bin)}")
408
+
409
+ axes[1, 2].text(0.1, 0.6, f'IoU: {iou:.4f}', fontsize=16, weight='bold')
410
+ axes[1, 2].text(0.1, 0.4, f'Dice: {dice:.4f}', fontsize=16, weight='bold')
411
+ axes[1, 2].set_xlim(0, 1)
412
+ axes[1, 2].set_ylim(0, 1)
413
+ axes[1, 2].axis('off')
414
+ axes[1, 2].set_title('Metrics', fontsize=12, weight='bold')
415
+
416
+ axes[1, 3].imshow(tumor_only_gray, cmap='gray')
417
+ axes[1, 3].set_title('Segmented Tumor', fontsize=12, weight='bold')
418
+ axes[1, 3].axis('off')
419
+
 
420
  else:
421
  # No ground truth case
422
+ axes[1, 0].imshow(inv_pred_mask_np, cmap='gray')
423
+ axes[1, 0].set_title('Predicted Mask', fontsize=12, weight='bold')
424
+ axes[1, 0].axis('off')
425
+
426
+ axes[1, 1].imshow(tumor_only_gray, cmap='gray')
427
+ axes[1, 1].set_title('Tumor Only', fontsize=12, weight='bold')
428
+ axes[1, 1].axis('off')
429
+
430
+ overlay = original_rgb.copy()
431
+ overlay[pred_mask_bin == 1] = [255, 0, 0] # red for prediction overlay
432
+ axes[1, 2].imshow(overlay)
433
+ axes[1, 2].set_title('Prediction Overlay', fontsize=12, weight='bold')
434
+ axes[1, 2].axis('off')
 
 
435
 
436
  plt.tight_layout()
437
+
438
+ # Save plot to buffer and return as PIL image
439
  buf = io.BytesIO()
440
  plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
441
  buf.seek(0)
442
  plt.close()
443
+ result_image = Image.open(buf).convert("RGB")
444
+
445
+ # Analysis text: tumor area
446
+ tumor_pixels = int(np.sum(pred_mask_bin))
447
+ total_pixels = int(pred_mask_bin.size)
448
+ tumor_percentage = (tumor_pixels / total_pixels) * 100 if total_pixels > 0 else 0.0
449
+
 
450
  print(f"Final tumor pixels: {tumor_pixels}")
451
  print(f"Final tumor percentage: {tumor_percentage:.2f}%")
452
+ print("=" * 50)
453
+
454
  analysis_text = f"""
455
  # Analysis Results
456
 
 
464
  - Attention Visualization: Generated
465
  - Post-processing: Applied
466
  """
467
+
468
  if ground_truth is not None:
469
  analysis_text += f"""
470
  **Performance Metrics:**
471
  - IoU Score: {iou:.4f}
472
  - Dice Score: {dice:.4f}
473
  """
474
+
475
  return result_image, analysis_text
476
+
477
  except Exception as e:
478
  import traceback
479
  error_msg = f"Analysis failed: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
480
+ print(error_msg)
481
  return None, error_msg
482
 
483