ArchCoder committed on
Commit fa07b35 · verified · 1 Parent(s): cd271b0

Update app.py

Files changed (1): app.py +228 -177
app.py CHANGED
@@ -6,15 +6,18 @@ import cv2
from PIL import Image
import matplotlib.pyplot as plt
import io
- from torchvision import transforms
+ import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
- import urllib.request
+ import random
import os
+ import zipfile
+ import urllib.request
+ import kagglehub

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = None

- # Define your Attention U-Net architecture (from your training code)
+ # Your Attention U-Net classes (from your code)
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
@@ -56,7 +59,7 @@ class AttentionBlock(nn.Module):
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
-         return x * psi
+         return x * psi, psi  # Return attention map as well

class AttentionUNET(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, features=[32, 64, 128, 256]):
@@ -67,7 +70,7 @@ class AttentionUNET(nn.Module):
        self.attentions = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

-         # Down part of UNET
+         # Down part
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature
@@ -75,7 +78,7 @@ class AttentionUNET(nn.Module):
        # Bottleneck
        self.bottleneck = DoubleConv(features[-1], features[-1]*2)

-         # Up part of UNET
+         # Up part
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
            self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2))
@@ -85,6 +88,7 @@ class AttentionUNET(nn.Module):

    def forward(self, x):
        skip_connections = []
+         attention_maps = []  # To store attention maps

        for down in self.downs:
            x = down(x)
@@ -92,20 +96,21 @@ class AttentionUNET(nn.Module):
            x = self.pool(x)

        x = self.bottleneck(x)
-         skip_connections = skip_connections[::-1] #reverse list
+         skip_connections = skip_connections[::-1]

-         for idx in range(0, len(self.ups), 2): #do up and double_conv
+         for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])

-             skip_connection = self.attentions[idx // 2](skip_connection, x)
-             concat_skip = torch.cat((skip_connection, x), dim=1)
+             attended_skip, att_map = self.attentions[idx // 2](x, skip_connection)  # Get attention map
+             attention_maps.append(att_map)  # Store attention map
+             concat_skip = torch.cat((attended_skip, x), dim=1)
            x = self.ups[idx+1](concat_skip)

-         return self.final_conv(x)
+         return self.final_conv(x), attention_maps

def download_model():
    """Download your trained model from HuggingFace"""
@@ -120,9 +125,6 @@ def download_model():
    except Exception as e:
        print(f"❌ Failed to download model: {e}")
        return None
-     else:
-         print("✅ Model already exists!")
-
    return model_path

def load_your_attention_model():
@@ -151,66 +153,174 @@ def load_your_attention_model():
        model = None
    return model

+ def download_dataset():
+     """Download and extract the dataset using kagglehub"""
+     dataset_path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation')
+
+     # Extract if it's a zip
+     extracted_path = "brain_tumor_dataset"
+     if not os.path.exists(extracted_path):
+         with zipfile.ZipFile(dataset_path, 'r') as zip_ref:
+             zip_ref.extractall(extracted_path)
+
+     images_path = os.path.join(extracted_path, 'images')
+     masks_path = os.path.join(extracted_path, 'masks')
+
+     return images_path, masks_path
+
+ def load_random_sample():
+     """Load a random image and mask from the dataset"""
+     images_path, masks_path = download_dataset()
+
+     image_files = [f for f in os.listdir(images_path) if f.endswith(('.png', '.jpg'))]
+     if not image_files:
+         return None, None, "No images found in dataset"
+
+     random_file = random.choice(image_files)
+     img_path = os.path.join(images_path, random_file)
+     mask_path = os.path.join(masks_path, random_file)
+
+     image = Image.open(img_path).convert("L")
+     mask = Image.open(mask_path).convert("L") if os.path.exists(mask_path) else None
+
+     return image, mask, random_file
+
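A caveat on the dataset helper above: kagglehub.dataset_download() normally returns the path of an already-extracted dataset directory, not a .zip archive, so the zipfile branch is likely to raise. A hedged alternative, assuming the dataset ships with images/ and masks/ subfolders:

    def download_dataset():
        # kagglehub returns a local directory containing the dataset files,
        # so no manual extraction should be needed.
        dataset_path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation')
        images_path = os.path.join(dataset_path, 'images')
        masks_path = os.path.join(dataset_path, 'masks')
        return images_path, masks_path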
def preprocess_for_your_model(image):
    """Preprocessing exactly like your Colab code"""
-     # Convert to grayscale (like your Colab code)
    if image.mode != 'L':
        image = image.convert('L')

-     # Use the exact same transform as your Colab code
    val_test_transform = transforms.Compose([
        transforms.Resize((256,256)),
        transforms.ToTensor()
    ])

-     return val_test_transform(image).unsqueeze(0) # Add batch dimension
+     return val_test_transform(image).unsqueeze(0)
+
+ def apply_tta(model, input_tensor):
+     """Test-Time Augmentation: Apply augmentations and average predictions"""
+     augmentations = [
+         lambda x: x,                  # Original
+         lambda x: TF.rotate(x, 90),   # 90 deg rotation
+         lambda x: TF.rotate(x, -90),  # -90 deg rotation
+         lambda x: TF.hflip(x),        # Horizontal flip
+         lambda x: TF.vflip(x)         # Vertical flip
+     ]
+
+     predictions = []
+     for aug in augmentations:
+         aug_input = aug(input_tensor)
+         pred = torch.sigmoid(model(aug_input)[0])  # Get prediction
+         # Reverse the augmentation for averaging
+         if aug == augmentations[1]:    # Reverse 90 deg
+             pred = TF.rotate(pred, -90)
+         elif aug == augmentations[2]:  # Reverse -90 deg
+             pred = TF.rotate(pred, 90)
+         elif aug == augmentations[3]:  # Reverse hflip
+             pred = TF.hflip(pred)
+         elif aug == augmentations[4]:  # Reverse vflip
+             pred = TF.vflip(pred)
+         predictions.append(pred)
+
+     # Average predictions
+     avg_pred = torch.mean(torch.stack(predictions), dim=0)
+     return avg_pred
+
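Note that apply_tta above runs five forward passes with gradient tracking still enabled, which needlessly builds autograd graphs during inference. A small hardening sketch (the wrapper name is illustrative, not part of the commit):

    def apply_tta_inference(model, input_tensor):
        # eval() fixes dropout/batchnorm behavior; no_grad() skips autograd
        # bookkeeping for the five TTA forward passes.
        model.eval()
        with torch.no_grad():
            return apply_tta(model, input_tensor)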
+ def generate_attention_heatmap(attention_maps):
+     """Generate combined attention heatmap"""
+     if not attention_maps:
+         return np.zeros((256, 256))
+
+     # Average attention maps from different levels
+     combined_att = torch.mean(torch.stack(attention_maps), dim=0).squeeze().cpu().numpy()
+     combined_att = cv2.resize(combined_att, (256, 256))
+     combined_att = (combined_att - combined_att.min()) / (combined_att.max() - combined_att.min() + 1e-8)
+     heatmap = cv2.applyColorMap((combined_att * 255).astype(np.uint8), cv2.COLORMAP_JET)
+     return heatmap

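Two caveats on generate_attention_heatmap as committed: the stored maps come from different decoder depths and so have different spatial sizes, meaning torch.stack will raise unless they are first resampled to a common resolution; and cv2.applyColorMap returns BGR, which matplotlib will display with red and blue swapped. A hedged sketch of the resampling step (helper name is illustrative):

    import torch.nn.functional as F

    def combine_attention_maps(attention_maps, size=(256, 256)):
        # Upsample each (B, 1, H, W) attention map to a common size
        # so they can be stacked and averaged.
        resized = [F.interpolate(m, size=size, mode='bilinear', align_corners=False)
                   for m in attention_maps]
        return torch.mean(torch.stack(resized), dim=0).squeeze().cpu().numpy()

    # For display, convert the colormapped result to RGB first:
    # heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)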
- def predict_tumor(image):
+ def predict_tumor(image, ground_truth=None, filename=None):
    current_model = load_your_attention_model()

    if current_model is None:
        return None, "Failed to load your trained model."

    if image is None:
-         return None, "⚠️ Please upload an image first."
+         return None, "Please upload or load an image first."

    try:
-         print("🧠 Processing with YOUR trained Attention U-Net...")
-
-         # Use the exact preprocessing from your Colab code
+         # Preprocess
        input_tensor = preprocess_for_your_model(image).to(device)

-         # Predict using your model (exactly like your Colab code)
-         with torch.no_grad():
-             pred_mask = torch.sigmoid(current_model(input_tensor))
-             pred_mask_binary = (pred_mask > 0.5).float()
-
-         # Convert to numpy (like your Colab code)
-         pred_mask_np = pred_mask_binary.cpu().squeeze().numpy()
-         original_np = np.array(image.convert('L').resize((256, 256)))
-
-         # Create inverted mask for visualization (like your Colab code)
-         inv_pred_mask_np = np.where(pred_mask_np == 1, 0, 255)
-
-         # Create tumor-only image (like your Colab code)
-         tumor_only = np.where(pred_mask_np == 1, original_np, 255)
-
-         # Create visualization (matching your Colab 4-panel layout)
-         fig, axes = plt.subplots(1, 4, figsize=(20, 5))
-         fig.suptitle('🧠 Your Attention U-Net Results', fontsize=16, fontweight='bold')
-
-         titles = ["Original Image", "Tumor Segmentation", "Inverted Mask", "Tumor Only"]
-         images = [original_np, pred_mask_np * 255, inv_pred_mask_np, tumor_only]
-         cmaps = ['gray', 'hot', 'gray', 'gray']
-
-         for i, ax in enumerate(axes):
-             ax.imshow(images[i], cmap=cmaps[i])
-             ax.set_title(titles[i], fontsize=12, fontweight='bold')
-             ax.axis('off')
+         # Apply TTA
+         avg_pred = apply_tta(current_model, input_tensor)
+
+         # Get binary mask
+         binary_mask = (avg_pred > 0.5).float().squeeze().cpu().numpy()
+
+         # Post-processing
+         kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5,5))
+         binary_mask = cv2.morphologyEx(binary_mask.astype(np.uint8), cv2.MORPH_OPEN, kernel)
+         binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
+
+         # Extract attention maps
+         _, attention_maps = current_model(input_tensor)
+         att_heatmap = generate_attention_heatmap(attention_maps)
+
+         # Create visualization
+         fig, axes = plt.subplots(2, 3, figsize=(18, 12))
+         fig.suptitle('Brain Tumor Segmentation Analysis', fontsize=20)
+
+         # Original
+         axes[0,0].imshow(image, cmap='gray')
+         axes[0,0].set_title('Original Image')
+         axes[0,0].axis('off')
+
+         # Attention Heatmap
+         axes[0,1].imshow(np.array(image), cmap='gray')
+         axes[0,1].imshow(att_heatmap, alpha=0.5)
+         axes[0,1].set_title('Attention Heatmap')
+         axes[0,1].axis('off')
+
+         # Predicted Mask
+         axes[0,2].imshow(binary_mask, cmap='gray')
+         axes[0,2].set_title('Predicted Mask')
+         axes[0,2].axis('off')
+
+         # Ground Truth (if available)
+         if ground_truth is not None:
+             gt_np = np.array(ground_truth.resize((256, 256)))
+             axes[1,0].imshow(gt_np, cmap='gray')
+             axes[1,0].set_title('Ground Truth Mask')
+             axes[1,0].axis('off')
+
+             # Comparison Overlay
+             overlay = np.array(image.convert('RGB'))
+             overlay[binary_mask > 0] = [0, 255, 0]  # Green for prediction
+             overlay[gt_np > 0] = [255, 0, 0]  # Red for ground truth
+             axes[1,1].imshow(overlay)
+             axes[1,1].set_title('Prediction (Green) vs GT (Red)')
+             axes[1,1].axis('off')
+
+             # IoU Calculation
+             intersection = np.sum(binary_mask * (gt_np > 0))
+             union = np.sum(binary_mask) + np.sum(gt_np > 0) - intersection
+             iou = intersection / (union + 1e-8)
+
+             axes[1,2].text(0.1, 0.5, f'IoU Score: {iou:.4f}', fontsize=20)
+             axes[1,2].axis('off')
+         else:
+             # Overlay for prediction only
+             overlay = np.array(image.convert('RGB'))
+             overlay[binary_mask > 0] = [255, 0, 0]
+             axes[1,0].imshow(overlay)
+             axes[1,0].set_title('Prediction Overlay')
+             axes[1,0].axis('off')
+
+             axes[1,1].axis('off')
+             axes[1,2].axis('off')

        plt.tight_layout()

-         # Save result
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
        buf.seek(0)
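One sizing caveat in the overlay code above: binary_mask is always 256×256 (the network's input size), while the uploaded PIL image keeps its native resolution, so overlay[binary_mask > 0] only lines up if the image is resized to match first, e.g.:

    # Bring the display image to the mask's resolution before boolean indexing;
    # otherwise the overlay fails (or misaligns) for non-256x256 uploads.
    display_image = image.convert('RGB').resize((256, 256))
    overlay = np.array(display_image)
    overlay[binary_mask > 0] = [255, 0, 0]  # mark predicted tumor pixels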
@@ -218,187 +328,128 @@ def predict_tumor(image):

        result_image = Image.open(buf)

-         # Calculate statistics (like your Colab code)
-         tumor_pixels = np.sum(pred_mask_np)
-         total_pixels = pred_mask_np.size
+         # Statistics
+         tumor_pixels = np.sum(binary_mask)
+         total_pixels = binary_mask.size
        tumor_percentage = (tumor_pixels / total_pixels) * 100

-         # Calculate confidence metrics
-         max_confidence = torch.max(pred_mask).item()
-         mean_confidence = torch.mean(pred_mask).item()
-
        analysis_text = f"""
- ## 🧠 Your Attention U-Net Analysis Results
-
- ### 📊 Detection Summary:
- **Status**: {'🔴 TUMOR DETECTED' if tumor_pixels > 50 else '🟢 NO SIGNIFICANT TUMOR'}
- **Tumor Area**: {tumor_percentage:.2f}% of brain region
- **Tumor Pixels**: {tumor_pixels:,} pixels
- **Max Confidence**: {max_confidence:.4f}
- **Mean Confidence**: {mean_confidence:.4f}
-
- ### 🔬 Your Model Information:
- **Architecture**: YOUR trained Attention U-Net
- **Training Performance**: Dice: 0.8420, IoU: 0.7297
- **Input**: Grayscale (single channel)
- **Output**: Binary segmentation mask
- **Device**: {device.type.upper()}
-
- ### 🎯 Model Performance:
- **Training Accuracy**: 98.90%
- **Best Dice Score**: 0.8420
- **Best IoU Score**: 0.7297
- **Training Dataset**: Brain tumor segmentation dataset
-
- ### 📈 Processing Details:
- **Preprocessing**: Resize(256×256) + ToTensor (your exact method)
- **Threshold**: 0.5 (sigmoid > 0.5)
- **Architecture**: Attention gates + Skip connections
- **Features**: [32, 64, 128, 256] channels
-
- ### ⚠️ Medical Disclaimer:
- This is YOUR trained AI model for **research and educational purposes only**.
- Results should be validated by medical professionals. Not for clinical diagnosis.
-
- ### 🏆 Model Quality:
- ✅ This is your own trained model with proven {tumor_percentage:.2f}% detection capability!
+ ## Brain Tumor Segmentation Results
+
+ ### Detection Summary
+ - Tumor Percentage: {tumor_percentage:.2f}%
+ - Tumor Pixels: {tumor_pixels}
+ - File: {filename if filename else 'Uploaded Image'}
+
+ ### Model Information
+ - Your Attention U-Net Model
+ - Test-Time Augmentation: Applied
+ - Attention Visualization: Included
        """

-         print(f"✅ Your model analysis completed! Tumor area: {tumor_percentage:.2f}%")
-
+         if ground_truth is not None:
+             analysis_text += f"\n- IoU with Ground Truth: {iou:.4f}"
+
        return result_image, analysis_text

    except Exception as e:
-         error_msg = f"Error with your model: {str(e)}"
-         print(error_msg)
-         return None, error_msg
+         return None, f"Error: {str(e)}"

def clear_all():
-     return None, None, "Upload a brain MRI image to test YOUR trained Attention U-Net model"
+     return None, None, None, "Upload or load an image for analysis"

- # Enhanced CSS for your model
+ # Professional CSS (white, clean, professional)
css = """
.gradio-container {
    max-width: 1400px !important;
    margin: auto !important;
+     background-color: white !important;
+     font-family: 'Arial', sans-serif !important;
+ }
+ h1, h2, h3, h4 {
+     color: #333333 !important;
+ }
+ button {
+     background-color: #f0f0f0 !important;
+     color: #333333 !important;
+     border: 1px solid #dddddd !important;
+     border-radius: 4px !important;
+ }
+ button.primary {
+     background-color: #007bff !important;
+     color: white !important;
+ }
+ .output-image {
+     border: 1px solid #dddddd !important;
+     border-radius: 4px !important;
}
- #title {
-     text-align: center;
-     background: linear-gradient(135deg, #8B5CF6 0%, #7C3AED 100%);
-     color: white;
-     padding: 30px;
-     border-radius: 15px;
-     margin-bottom: 25px;
-     box-shadow: 0 8px 16px rgba(139, 92, 246, 0.3);
+ .markdown {
+     line-height: 1.6 !important;
+     color: #555555 !important;
}
"""

- # Create Gradio interface for your model
- with gr.Blocks(css=css, title="🧠 Your Attention U-Net Model", theme=gr.themes.Soft()) as app:
+ # Create professional Gradio interface
+ with gr.Blocks(css=css, title="Brain Tumor Segmentation Application") as app:

-     gr.HTML("""
-     <div id="title">
-         <h1>🧠 YOUR Attention U-Net Model</h1>
-         <p style="font-size: 18px; margin-top: 15px;">
-             Using Your Own Trained Model • Dice: 0.8420 • IoU: 0.7297
-         </p>
-         <p style="font-size: 14px; margin-top: 10px; opacity: 0.9;">
-             Loaded from: ArchCoder/the-op-segmenter HuggingFace Space
-         </p>
-     </div>
+     gr.Markdown("""
+     # Brain Tumor Segmentation Using Attention U-Net
+     A professional tool for medical image analysis
    """)

    with gr.Row():
        with gr.Column(scale=1):
-             gr.Markdown("### 📤 Upload Brain MRI")
+             gr.Markdown("### Input Selection")

            image_input = gr.Image(
-                 label="Brain MRI Scan",
+                 label="Upload Brain MRI",
                type="pil",
                sources=["upload", "webcam"],
-                 height=350
+                 height=300
            )

-             with gr.Row():
-                 analyze_btn = gr.Button("🔍 Analyze with YOUR Model", variant="primary", scale=2, size="lg")
-                 clear_btn = gr.Button("🗑️ Clear", variant="secondary", scale=1)
+             load_random_btn = gr.Button("Load Random Sample from Dataset", variant="primary")
+
+             with gr.Row():
+                 analyze_btn = gr.Button("Analyze Image", variant="primary", scale=2)
+                 clear_btn = gr.Button("Clear", scale=1)

-             gr.HTML("""
-             <div style="margin-top: 20px; padding: 20px; background: linear-gradient(135deg, #F3E8FF 0%, #EDE9FE 100%); border-radius: 10px; border-left: 4px solid #8B5CF6;">
-                 <h4 style="color: #8B5CF6; margin-bottom: 15px;">🏆 Your Model Features:</h4>
-                 <ul style="margin: 10px 0; padding-left: 20px; line-height: 1.6;">
-                     <li><strong>Personal Model:</strong> Your own trained Attention U-Net</li>
-                     <li><strong>Proven Performance:</strong> 84.2% Dice Score, 72.97% IoU</li>
-                     <li><strong>Attention Gates:</strong> Advanced feature selection</li>
-                     <li><strong>Clean Output:</strong> Binary segmentation masks</li>
-                     <li><strong>4-Panel View:</strong> Complete analysis like your Colab</li>
-                 </ul>
-             </div>
-             """)

        with gr.Column(scale=2):
-             gr.Markdown("### 📊 Your Model Results")
+             gr.Markdown("### Analysis Results")

            output_image = gr.Image(
-                 label="Your Attention U-Net Analysis",
+                 label="Segmentation Results",
                type="pil",
-                 height=500
+                 height=400
            )

            analysis_output = gr.Markdown(
-                 value="Upload a brain MRI image to test YOUR trained Attention U-Net model.",
-                 elem_id="analysis"
+                 value="Select an input method to begin analysis."
            )
-
-     # Footer highlighting your model
-     gr.HTML("""
-     <div style="margin-top: 30px; padding: 25px; background-color: #F8FAFC; border-radius: 15px; border: 2px solid #8B5CF6;">
-         <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 30px;">
-             <div>
-                 <h4 style="color: #8B5CF6; margin-bottom: 15px;">🏆 Your Personal AI Model</h4>
-                 <p><strong>Architecture:</strong> Attention U-Net with skip connections</p>
-                 <p><strong>Performance:</strong> Dice: 0.8420, IoU: 0.7297, Accuracy: 98.90%</p>
-                 <p><strong>Training:</strong> Your own dataset-specific training</p>
-                 <p><strong>Features:</strong> [32, 64, 128, 256] channel progression</p>
-             </div>
-             <div>
-                 <h4 style="color: #DC2626; margin-bottom: 15px;">⚠️ Your Model Disclaimer</h4>
-                 <p style="color: #DC2626; font-weight: 600; line-height: 1.4;">
-                     This is YOUR personally trained AI model for <strong>research purposes only</strong>.<br>
-                     Results reflect your model's training performance.<br>
-                     Always validate with medical professionals for any clinical application.
-                 </p>
-             </div>
-         </div>
-         <hr style="margin: 20px 0; border: none; border-top: 2px solid #E5E7EB;">
-         <p style="text-align: center; color: #6B7280; margin: 10px 0; font-weight: 600;">
-             🚀 Your Personal Attention U-Net • Downloaded from HuggingFace • Research-Grade Performance
-         </p>
-     </div>
-     """)
+     # Hidden state for ground truth and filename
+     ground_truth_state = gr.State()
+     filename_state = gr.State()
+
    # Event handlers
    analyze_btn.click(
        fn=predict_tumor,
-         inputs=[image_input],
-         outputs=[output_image, analysis_output],
-         show_progress=True
+         inputs=[image_input, ground_truth_state, filename_state],
+         outputs=[output_image, analysis_output]
+     )
+
+     load_random_btn.click(
+         fn=load_random_sample,
+         inputs=[],
+         outputs=[image_input, ground_truth_state, filename_state, analysis_output]
    )

    clear_btn.click(
        fn=clear_all,
        inputs=[],
-         outputs=[image_input, output_image, analysis_output]
+         outputs=[image_input, output_image, ground_truth_state, analysis_output]
    )

if __name__ == "__main__":
-     print("🚀 Starting YOUR Attention U-Net Model System...")
-     print("🏆 Using your personally trained model")
-     print("📥 Auto-downloading from HuggingFace...")
-     print("🎯 Expected performance: Dice 0.8420, IoU 0.7297")
-
-     app.launch(
-         server_name="0.0.0.0",
-         server_port=7860,
-         show_error=True,
-         share=False
-     )
+     print("Starting Brain Tumor Segmentation Application...")
+     app.launch()
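A wiring note on the new handlers: load_random_sample returns three values (image, mask, filename), while load_random_btn.click lists four output components, so Gradio is likely to complain about the mismatched return count at runtime. A thin adapter that also fills the status Markdown would reconcile the two (the adapter name is illustrative, not part of the commit):

    def load_random_sample_ui():
        # Pad load_random_sample's 3 return values to the 4 wired outputs:
        # (image_input, ground_truth_state, filename_state, analysis_output).
        image, mask, filename = load_random_sample()
        status = f"Loaded sample: {filename}" if image is not None else str(filename)
        return image, mask, filename, status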
 