ArchCoder commited on
Commit
a4f4e25
·
verified ·
1 Parent(s): bccb310

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -71
app.py CHANGED
@@ -226,7 +226,7 @@ def get_random_sample():
226
  return None, None, f"Error loading sample: {e}"
227
 
228
  def preprocess_for_model(image):
229
- """Preprocessing for your model"""
230
  if image.mode != 'L':
231
  image = image.convert('L')
232
 
@@ -238,7 +238,7 @@ def preprocess_for_model(image):
238
  return transform(image).unsqueeze(0)
239
 
240
  def generate_attention_heatmap(attention_maps):
241
- """Generate attention heatmap - Fixed version"""
242
  if not attention_maps:
243
  return np.zeros((256, 256, 3))
244
 
@@ -266,7 +266,7 @@ def generate_attention_heatmap(attention_maps):
266
  return heatmap
267
 
268
  def analyze_image(image, ground_truth, filename):
269
- """Main analysis function - DEBUG VERSION"""
270
  if model is None:
271
  return None, "Model not loaded. Please restart the application."
272
 
@@ -279,7 +279,7 @@ def analyze_image(image, ground_truth, filename):
279
  print(f"Input image mode: {image.mode}")
280
  print(f"Input image size: {image.size}")
281
 
282
- # Preprocess
283
  input_tensor = preprocess_for_model(image).to(device)
284
  print(f"Input tensor shape: {input_tensor.shape}")
285
  print(f"Input tensor min/max: {input_tensor.min():.4f}/{input_tensor.max():.4f}")
@@ -292,128 +292,135 @@ def analyze_image(image, ground_truth, filename):
292
  print(f"Model output shape: {model_output.shape}")
293
  print(f"Model output min/max BEFORE sigmoid: {model_output.min():.4f}/{model_output.max():.4f}")
294
 
295
- # Apply sigmoid
296
  pred_mask = torch.sigmoid(model_output)
297
  print(f"After sigmoid min/max: {pred_mask.min():.4f}/{pred_mask.max():.4f}")
298
 
299
- # Check values before thresholding
300
- unique_vals = torch.unique(pred_mask)
301
- print(f"Unique values in prediction: {unique_vals[:10]}") # Show first 10 unique values
302
-
303
- # Apply threshold
304
  binary_mask = (pred_mask > 0.5).float()
305
- print(f"Binary mask shape: {binary_mask.shape}")
306
  print(f"Binary mask sum (number of 1s): {binary_mask.sum()}")
307
 
308
- # Convert to numpy
309
- binary_mask_np = binary_mask.squeeze().cpu().numpy()
310
- print(f"Numpy binary mask shape: {binary_mask_np.shape}")
311
- print(f"Numpy binary mask unique values: {np.unique(binary_mask_np)}")
312
- print(f"Numpy binary mask sum: {np.sum(binary_mask_np)}")
313
-
314
- # Try different thresholds if 0.5 doesn't work
315
- if np.sum(binary_mask_np) == 0:
316
- print("No pixels detected with threshold 0.5, trying lower thresholds...")
317
- for thresh in [0.3, 0.2, 0.1, 0.05]:
318
- test_mask = (pred_mask > thresh).float().squeeze().cpu().numpy()
319
- pixel_count = np.sum(test_mask)
320
- print(f"Threshold {thresh}: {pixel_count} pixels")
321
- if pixel_count > 0:
322
- print(f"Using threshold {thresh} instead of 0.5")
323
- binary_mask_np = test_mask
324
- break
325
-
326
- # Post-processing (morphological operations)
327
- print("Applying morphological operations...")
328
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3))
329
- binary_mask_np = cv2.morphologyEx(binary_mask_np.astype(np.uint8), cv2.MORPH_OPEN, kernel)
330
- binary_mask_np = cv2.morphologyEx(binary_mask_np, cv2.MORPH_CLOSE, kernel)
331
- print(f"After morphological ops sum: {np.sum(binary_mask_np)}")
332
 
333
  # Generate attention heatmap
334
  print("Generating attention heatmap...")
335
  att_heatmap = generate_attention_heatmap(attention_maps)
336
  print(f"Attention heatmap shape: {att_heatmap.shape}")
337
 
 
 
 
 
 
 
338
  # Create visualization
339
  if ground_truth is not None:
340
- fig, axes = plt.subplots(2, 3, figsize=(15, 10))
341
  else:
342
- fig, axes = plt.subplots(2, 2, figsize=(12, 10))
343
 
344
  fig.suptitle('Brain Tumor Segmentation Analysis', fontsize=16, weight='bold')
345
 
346
- # Original image
347
- axes[0,0].imshow(image, cmap='gray')
348
  axes[0,0].set_title('Original Image', fontsize=12, weight='bold')
349
  axes[0,0].axis('off')
350
 
351
- # Attention heatmap
352
- axes[0,1].imshow(image, cmap='gray')
353
  axes[0,1].imshow(att_heatmap, alpha=0.4)
354
  axes[0,1].set_title('Attention Heatmap', fontsize=12, weight='bold')
355
  axes[0,1].axis('off')
356
 
357
- # Predicted mask
 
 
 
 
358
  if ground_truth is not None:
359
- axes[0,2].imshow(binary_mask_np, cmap='gray')
360
- axes[0,2].set_title('Predicted Mask', fontsize=12, weight='bold')
361
- axes[0,2].axis('off')
362
-
363
- # Ground truth
364
  gt_array = np.array(ground_truth.resize((256, 256)))
 
 
 
 
 
 
 
365
  print(f"Ground truth array shape: {gt_array.shape}")
366
  print(f"Ground truth unique values: {np.unique(gt_array)}")
367
 
368
- # Normalize ground truth to binary (0 or 1)
369
- gt_binary = (gt_array > 128).astype(np.uint8)
370
- print(f"GT binary sum: {np.sum(gt_binary)}")
 
371
 
372
- axes[1,0].imshow(gt_binary, cmap='gray')
 
373
  axes[1,0].set_title('Ground Truth Mask', fontsize=12, weight='bold')
374
  axes[1,0].axis('off')
375
 
376
- # Overlay comparison
377
  overlay = np.array(image.convert('RGB').resize((256, 256)))
378
- overlay[binary_mask_np > 0] = [0, 255, 0] # Green for prediction
379
- overlay[gt_binary > 0] = [255, 0, 0] # Red for ground truth
380
  axes[1,1].imshow(overlay)
381
  axes[1,1].set_title('Prediction (Green) vs GT (Red)', fontsize=12, weight='bold')
382
  axes[1,1].axis('off')
383
 
384
- # Calculate IoU and Dice
385
- pred_binary = binary_mask_np > 0
386
- gt_binary_bool = gt_binary > 0
387
- intersection = np.sum(pred_binary & gt_binary_bool)
388
- union = np.sum(pred_binary | gt_binary_bool)
389
- iou = intersection / (union + 1e-8)
390
 
391
  # Dice score
392
- dice = (2 * intersection) / (np.sum(pred_binary) + np.sum(gt_binary_bool) + 1e-8)
393
 
394
  print(f"Final IoU: {iou:.4f}")
395
  print(f"Final Dice: {dice:.4f}")
396
  print(f"Intersection: {intersection}")
397
  print(f"Union: {union}")
398
- print(f"Pred pixels: {np.sum(pred_binary)}")
399
- print(f"GT pixels: {np.sum(gt_binary_bool)}")
400
 
401
  axes[1,2].text(0.1, 0.6, f'IoU: {iou:.4f}', fontsize=16, weight='bold')
402
  axes[1,2].text(0.1, 0.4, f'Dice: {dice:.4f}', fontsize=16, weight='bold')
403
  axes[1,2].set_xlim(0, 1)
404
  axes[1,2].set_ylim(0, 1)
405
  axes[1,2].axis('off')
 
 
 
 
 
 
 
406
  else:
407
- axes[1,0].imshow(binary_mask_np, cmap='gray')
 
408
  axes[1,0].set_title('Predicted Mask', fontsize=12, weight='bold')
409
  axes[1,0].axis('off')
410
 
 
 
 
 
 
411
  # Overlay
412
  overlay = np.array(image.convert('RGB').resize((256, 256)))
413
- overlay[binary_mask_np > 0] = [255, 0, 0]
414
- axes[1,1].imshow(overlay)
415
- axes[1,1].set_title('Prediction Overlay', fontsize=12, weight='bold')
416
- axes[1,1].axis('off')
417
 
418
  plt.tight_layout()
419
 
@@ -426,8 +433,8 @@ def analyze_image(image, ground_truth, filename):
426
  result_image = Image.open(buf)
427
 
428
  # Generate analysis text
429
- tumor_pixels = np.sum(binary_mask_np)
430
- total_pixels = binary_mask_np.size
431
  tumor_percentage = (tumor_pixels / total_pixels) * 100
432
 
433
  print(f"Final tumor pixels: {tumor_pixels}")
@@ -445,7 +452,7 @@ def analyze_image(image, ground_truth, filename):
445
 
446
  **Model Features:**
447
  - Attention Visualization: Generated
448
- - Post-processing: Morphological cleanup
449
  """
450
 
451
  if ground_truth is not None:
 
226
  return None, None, f"Error loading sample: {e}"
227
 
228
  def preprocess_for_model(image):
229
+ """Preprocessing for your model - matches the working notebook"""
230
  if image.mode != 'L':
231
  image = image.convert('L')
232
 
 
238
  return transform(image).unsqueeze(0)
239
 
240
  def generate_attention_heatmap(attention_maps):
241
+ """Generate attention heatmap"""
242
  if not attention_maps:
243
  return np.zeros((256, 256, 3))
244
 
 
266
  return heatmap
267
 
268
  def analyze_image(image, ground_truth, filename):
269
+ """Main analysis function - FIXED VERSION matching the working notebook"""
270
  if model is None:
271
  return None, "Model not loaded. Please restart the application."
272
 
 
279
  print(f"Input image mode: {image.mode}")
280
  print(f"Input image size: {image.size}")
281
 
282
+ # Preprocess - exactly like the working notebook
283
  input_tensor = preprocess_for_model(image).to(device)
284
  print(f"Input tensor shape: {input_tensor.shape}")
285
  print(f"Input tensor min/max: {input_tensor.min():.4f}/{input_tensor.max():.4f}")
 
292
  print(f"Model output shape: {model_output.shape}")
293
  print(f"Model output min/max BEFORE sigmoid: {model_output.min():.4f}/{model_output.max():.4f}")
294
 
295
+ # Apply sigmoid and threshold - EXACTLY like the working notebook
296
  pred_mask = torch.sigmoid(model_output)
297
  print(f"After sigmoid min/max: {pred_mask.min():.4f}/{pred_mask.max():.4f}")
298
 
299
+ # Apply threshold to get binary mask
 
 
 
 
300
  binary_mask = (pred_mask > 0.5).float()
 
301
  print(f"Binary mask sum (number of 1s): {binary_mask.sum()}")
302
 
303
+ # Convert to numpy - following notebook approach
304
+ pred_mask_np = binary_mask.cpu().squeeze().numpy()
305
+ print(f"Numpy binary mask shape: {pred_mask_np.shape}")
306
+ print(f"Numpy binary mask unique values: {np.unique(pred_mask_np)}")
307
+ print(f"Numpy binary mask sum: {np.sum(pred_mask_np)}")
308
+
309
+ # Create visualization mask like in the notebook
310
+ # The notebook uses: inv_pred_mask_np = np.where(pred_mask_np == 1, 0, 255)
311
+ # This inverts the mask for better visualization
312
+ inv_pred_mask_np = np.where(pred_mask_np == 1, 0, 255)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
 
314
  # Generate attention heatmap
315
  print("Generating attention heatmap...")
316
  att_heatmap = generate_attention_heatmap(attention_maps)
317
  print(f"Attention heatmap shape: {att_heatmap.shape}")
318
 
319
+ # Prepare original image array
320
+ original_np = np.array(image.resize((256, 256)))
321
+
322
+ # Create tumor-only image (like in notebook)
323
+ tumor_only = np.where(pred_mask_np == 1, original_np, 255)
324
+
325
  # Create visualization
326
  if ground_truth is not None:
327
+ fig, axes = plt.subplots(2, 4, figsize=(16, 8))
328
  else:
329
+ fig, axes = plt.subplots(2, 3, figsize=(15, 8))
330
 
331
  fig.suptitle('Brain Tumor Segmentation Analysis', fontsize=16, weight='bold')
332
 
333
+ # Row 1: Original, Attention, Predicted Mask, Tumor Only
334
+ axes[0,0].imshow(original_np, cmap='gray')
335
  axes[0,0].set_title('Original Image', fontsize=12, weight='bold')
336
  axes[0,0].axis('off')
337
 
338
+ # Attention heatmap overlay
339
+ axes[0,1].imshow(original_np, cmap='gray')
340
  axes[0,1].imshow(att_heatmap, alpha=0.4)
341
  axes[0,1].set_title('Attention Heatmap', fontsize=12, weight='bold')
342
  axes[0,1].axis('off')
343
 
344
+ # Predicted mask (inverted for visualization)
345
+ axes[0,2].imshow(inv_pred_mask_np, cmap='gray')
346
+ axes[0,2].set_title('Predicted Mask', fontsize=12, weight='bold')
347
+ axes[0,2].axis('off')
348
+
349
  if ground_truth is not None:
350
+ # Ground truth processing - convert to binary like notebook
 
 
 
 
351
  gt_array = np.array(ground_truth.resize((256, 256)))
352
+ # Apply same preprocessing as notebook
353
+ val_test_transform = transforms.Compose([
354
+ transforms.Resize((256,256)),
355
+ transforms.ToTensor()
356
+ ])
357
+ mask_np = val_test_transform(ground_truth).cpu().squeeze().numpy()
358
+
359
  print(f"Ground truth array shape: {gt_array.shape}")
360
  print(f"Ground truth unique values: {np.unique(gt_array)}")
361
 
362
+ # Tumor only image
363
+ axes[0,3].imshow(tumor_only, cmap='gray')
364
+ axes[0,3].set_title('Tumor Only', fontsize=12, weight='bold')
365
+ axes[0,3].axis('off')
366
 
367
+ # Row 2: Ground truth, overlay comparison, metrics
368
+ axes[1,0].imshow(mask_np, cmap='gray')
369
  axes[1,0].set_title('Ground Truth Mask', fontsize=12, weight='bold')
370
  axes[1,0].axis('off')
371
 
372
+ # Overlay comparison - following notebook style
373
  overlay = np.array(image.convert('RGB').resize((256, 256)))
374
+ overlay[pred_mask_np == 1] = [0, 255, 0] # Green for prediction
375
+ overlay[mask_np > 0.5] = [255, 0, 0] # Red for ground truth
376
  axes[1,1].imshow(overlay)
377
  axes[1,1].set_title('Prediction (Green) vs GT (Red)', fontsize=12, weight='bold')
378
  axes[1,1].axis('off')
379
 
380
+ # Calculate IoU and Dice exactly like notebook
381
+ intersection = np.logical_and(pred_mask_np, mask_np).sum()
382
+ union = np.logical_or(pred_mask_np, mask_np).sum()
383
+ iou = intersection / (union + 1e-7)
 
 
384
 
385
  # Dice score
386
+ dice = (2 * intersection) / (pred_mask_np.sum() + mask_np.sum() + 1e-7)
387
 
388
  print(f"Final IoU: {iou:.4f}")
389
  print(f"Final Dice: {dice:.4f}")
390
  print(f"Intersection: {intersection}")
391
  print(f"Union: {union}")
392
+ print(f"Pred pixels: {np.sum(pred_mask_np)}")
393
+ print(f"GT pixels: {np.sum(mask_np > 0.5)}")
394
 
395
  axes[1,2].text(0.1, 0.6, f'IoU: {iou:.4f}', fontsize=16, weight='bold')
396
  axes[1,2].text(0.1, 0.4, f'Dice: {dice:.4f}', fontsize=16, weight='bold')
397
  axes[1,2].set_xlim(0, 1)
398
  axes[1,2].set_ylim(0, 1)
399
  axes[1,2].axis('off')
400
+ axes[1,2].set_title('Metrics', fontsize=12, weight='bold')
401
+
402
+ # Additional tumor statistics
403
+ axes[1,3].imshow(tumor_only, cmap='gray')
404
+ axes[1,3].set_title('Segmented Tumor', fontsize=12, weight='bold')
405
+ axes[1,3].axis('off')
406
+
407
  else:
408
+ # No ground truth case
409
+ axes[1,0].imshow(inv_pred_mask_np, cmap='gray')
410
  axes[1,0].set_title('Predicted Mask', fontsize=12, weight='bold')
411
  axes[1,0].axis('off')
412
 
413
+ # Tumor only
414
+ axes[1,1].imshow(tumor_only, cmap='gray')
415
+ axes[1,1].set_title('Tumor Only', fontsize=12, weight='bold')
416
+ axes[1,1].axis('off')
417
+
418
  # Overlay
419
  overlay = np.array(image.convert('RGB').resize((256, 256)))
420
+ overlay[pred_mask_np == 1] = [255, 0, 0]
421
+ axes[1,2].imshow(overlay)
422
+ axes[1,2].set_title('Prediction Overlay', fontsize=12, weight='bold')
423
+ axes[1,2].axis('off')
424
 
425
  plt.tight_layout()
426
 
 
433
  result_image = Image.open(buf)
434
 
435
  # Generate analysis text
436
+ tumor_pixels = np.sum(pred_mask_np)
437
+ total_pixels = pred_mask_np.size
438
  tumor_percentage = (tumor_pixels / total_pixels) * 100
439
 
440
  print(f"Final tumor pixels: {tumor_pixels}")
 
452
 
453
  **Model Features:**
454
  - Attention Visualization: Generated
455
+ - Post-processing: Applied
456
  """
457
 
458
  if ground_truth is not None: