File size: 22,081 Bytes
6477883
 
c5d3869
6477883
 
 
 
 
fa07b35
5ad7df6
fa07b35
5ad7df6
fa07b35
 
c530021
71461a8
c530021
71461a8
fb04e74
c530021
 
 
6477883
c530021
 
 
 
 
c5d3869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d57f983
c5d3869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d57f983
c5d3869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd271b0
c5d3869
c530021
c5d3869
 
 
 
 
 
 
fa07b35
c5d3869
fa07b35
c5d3869
 
 
 
 
 
c530021
 
fa07b35
c5d3869
 
fa07b35
c5d3869
c530021
 
 
 
 
c5d3869
 
 
c530021
c5d3869
c530021
c5d3869
 
e1f56cb
c530021
 
fa07b35
c530021
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa07b35
c530021
 
 
 
fa07b35
c530021
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa07b35
c530021
 
 
 
fa07b35
c530021
fa07b35
 
c530021
 
 
fa07b35
c530021
 
 
 
 
 
 
fa07b35
c530021
 
 
 
 
 
fa07b35
c530021
a4f4e25
5ad7df6
 
16f55d5
c530021
cd271b0
5ad7df6
 
4f4b98a
c530021
fa07b35
 
a4f4e25
fa07b35
c530021
fa07b35
d57f983
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa07b35
d57f983
 
fa07b35
d57f983
fa07b35
5ad7df6
c530021
a4f4e25
c530021
 
90efbfd
6477883
c530021
fb04e74
6477883
bccb310
 
 
 
 
a4f4e25
c530021
bccb310
 
b70ec11
82c9eec
c530021
bccb310
82c9eec
 
bccb310
 
 
a4f4e25
82c9eec
bccb310
 
a4f4e25
82c9eec
bccb310
82c9eec
a4f4e25
 
 
 
 
 
 
 
 
 
5ad7df6
c530021
bccb310
fa07b35
bccb310
5ad7df6
a4f4e25
 
 
 
 
 
fa07b35
c530021
a4f4e25
c530021
a4f4e25
c530021
 
5ad7df6
a4f4e25
 
c530021
fa07b35
5ad7df6
a4f4e25
 
c530021
 
fa07b35
cd271b0
a4f4e25
 
 
 
 
fa07b35
a4f4e25
c530021
a4f4e25
 
 
 
 
 
 
bccb310
 
 
a4f4e25
 
 
 
82c9eec
a4f4e25
 
c530021
fa07b35
 
a4f4e25
c530021
a4f4e25
 
fa07b35
c530021
fa07b35
 
a4f4e25
 
 
 
fa07b35
c530021
a4f4e25
c530021
bccb310
 
 
 
a4f4e25
 
bccb310
c530021
 
 
 
fa07b35
a4f4e25
 
 
 
 
 
 
fa07b35
a4f4e25
 
c530021
fa07b35
 
a4f4e25
 
 
 
 
c530021
 
a4f4e25
 
 
 
fa07b35
6477883
fb04e74
c530021
6477883
5ad7df6
6477883
 
fb04e74
6477883
cad0957
c530021
a4f4e25
 
fb04e74
90efbfd
bccb310
 
 
 
fb04e74
c530021
 
 
fa07b35
c530021
 
 
fa07b35
c530021
 
a4f4e25
c530021
fb04e74
fa07b35
c530021
 
 
 
 
fa07b35
6477883
fb04e74
6477883
82c9eec
 
 
 
c530021
bccb310
c530021
 
 
 
 
 
 
 
 
6477883
c530021
5ad7df6
c530021
c5d3869
5ad7df6
c530021
5ad7df6
c530021
fa07b35
c530021
 
 
fa07b35
c530021
 
 
fa07b35
c530021
 
 
fa07b35
c530021
 
5ad7df6
c530021
 
 
5ad7df6
c5d3869
 
c530021
 
5ad7df6
fa07b35
 
c530021
 
 
82c9eec
5ad7df6
fb04e74
c530021
 
 
 
 
 
 
 
6477883
 
fa07b35
5ad7df6
c530021
 
 
cd271b0
fa07b35
5ad7df6
fb04e74
c530021
fa07b35
c530021
 
 
 
 
 
 
 
 
 
 
 
fb04e74
6477883
fa07b35
5ad7df6
c530021
 
5ad7df6
c530021
5ad7df6
 
c530021
 
cd271b0
fb04e74
c530021
 
 
 
6477883
c530021
 
 
 
 
 
 
 
 
 
 
 
 
fa07b35
 
c530021
 
 
 
6477883
fb04e74
c530021
 
 
 
6477883
 
 
c530021
 
 
 
 
 
 
 
 
82c9eec
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import io
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import random
import os
import urllib.request
import kagglehub
from glob import glob

# Module-level state shared by the loader functions and the Gradio callbacks.
# The startup code further down fills these in once, when the script runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = None               # set by download_and_load_model()
dataset_images = []        # image file paths, filled by download_and_load_dataset()
dataset_masks = []         # mask file paths, filled by download_and_load_dataset()
dataset_loaded = False     # flipped to True once the dataset has been indexed

_BANNER = "=" * 50
for _banner_line in (_BANNER, "BRAIN TUMOR SEGMENTATION APPLICATION", _BANNER):
    print(_banner_line)

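# Attention U-Net model definition. Keep module/attribute names unchanged: the
# checkpoint is restored with load_state_dict(), which matches parameters by name.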
class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> BatchNorm -> ReLU) stages.

    Padding 1 with stride 1 keeps the spatial size unchanged; only the channel
    count changes (``in_channels`` -> ``out_channels``). The layers are held in
    ``self.conv`` so parameter names stay ``conv.<index>.*``, which checkpoint
    loading by name relies on.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in->out channels, second stage maps out->out.
        for stage_input in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_input, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)


class AttentionBlock(nn.Module):
    """Additive attention gate.

    The gating signal ``g`` and the skip features ``x`` are each projected to
    ``F_int`` channels (1x1 conv + BatchNorm), summed, passed through ReLU and
    then mapped to a single-channel sigmoid coefficient map. ``g`` and ``x``
    must therefore share batch and spatial dimensions for the sum to work.

    ``forward`` returns ``(x * coefficients, coefficients)``: the gated skip
    features plus the attention map itself, so it can be visualized.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def projection(in_ch, out_ch):
            # 1x1 convolution (with bias) followed by BatchNorm.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(out_ch),
            ]

        self.W_g = nn.Sequential(*projection(F_g, F_int))
        self.W_x = nn.Sequential(*projection(F_l, F_int))
        self.psi = nn.Sequential(*projection(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        gated = self.relu(self.W_g(g) + self.W_x(x))
        coefficients = self.psi(gated)
        return x * coefficients, coefficients

class AttentionUNET(nn.Module):
    """U-Net with an additive attention gate on every skip connection.

    Args:
        in_channels: channels of the input image (1 for grayscale).
        out_channels: channels of the logit map produced by ``final_conv``.
        features: encoder widths, shallowest first; the bottleneck uses twice
            the last width. Any sequence is accepted. The default is a tuple
            so it cannot be mutated and shared between instances.

    ``forward`` returns ``(logits, attention_maps)``. ``attention_maps`` lists
    each gate's coefficient map in decoder order, deepest stage first.
    """

    def __init__(self, in_channels=1, out_channels=1, features=(32, 64, 128, 256)):
        super(AttentionUNET, self).__init__()
        features = tuple(features)
        self.out_channels = out_channels
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.attentions = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one DoubleConv per width; forward() max-pools after each.
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)

        # Decoder, deepest first. self.ups interleaves [upsample, DoubleConv]
        # pairs which forward() walks two entries at a time; the attention
        # gate for pair i is self.attentions[i].
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
            self.attentions.append(AttentionBlock(F_g=feature, F_l=feature, F_int=feature // 2))
            self.ups.append(DoubleConv(feature*2, feature))

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Return ``(logits, attention_maps)`` for the input batch ``x``."""
        skip_connections = []
        attention_maps = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            # Max-pooling floors odd sizes, so when H or W is not divisible by
            # 2**len(features) the upsampled map can be a pixel smaller than
            # its skip. Only spatial dims can differ, so compare just those.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = TF.resize(x, size=skip_connection.shape[2:])

            attended_skip, att_map = self.attentions[idx // 2](x, skip_connection)
            attention_maps.append(att_map)
            concat_skip = torch.cat((attended_skip, x), dim=1)
            x = self.ups[idx+1](concat_skip)

        return self.final_conv(x), attention_maps

def download_and_load_model():
    """Download (if not cached) and load the Attention U-Net checkpoint.

    On success the module-level ``model`` is set to the loaded network in eval
    mode and True is returned. On any failure ``model`` is left untouched and
    False is returned. This ensures ``analyze_image`` never runs a network
    whose weights failed to load.
    """
    global model
    print("Loading Attention U-Net model...")

    model_url = "https://huggingface.co/spaces/ArchCoder/the-op-segmenter/resolve/main/best_attention_model.pth.tar"
    model_path = "best_attention_model.pth.tar"

    if not os.path.exists(model_path):
        print("Downloading model weights...")
        # Download under a temporary name and rename only after success. An
        # interrupted download then cannot leave a truncated file that later
        # runs would mistake for a valid cached checkpoint.
        tmp_path = model_path + ".part"
        try:
            urllib.request.urlretrieve(model_url, tmp_path)
            os.replace(tmp_path, model_path)
        except Exception as e:
            print(f"Failed to download model: {e}")
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # best-effort cleanup; the partial file may not exist
            return False

    try:
        # Build into a local first; only publish it globally once loaded.
        net = AttentionUNET(in_channels=1, out_channels=1).to(device)
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        net.load_state_dict(checkpoint["state_dict"])
        net.eval()
    except Exception as e:
        print(f"Failed to load model: {e}")
        return False

    model = net
    print("✓ Model loaded successfully!")
    return True

def download_and_load_dataset():
    """Download the Kaggle brain-tumor dataset and index its image/mask files.

    Fills the module-level ``dataset_images`` / ``dataset_masks`` with sorted
    file paths and sets ``dataset_loaded``. Once loaded, later calls return
    True immediately without downloading again. The globals are only updated
    after indexing finishes, so a failure never leaves half-filled lists.

    Returns:
        True on success, False if the download or file indexing raised.
    """
    global dataset_images, dataset_masks, dataset_loaded

    if dataset_loaded:
        return True

    print("Loading brain tumor dataset...")

    try:
        # kagglehub returns the local directory the dataset was placed in.
        dataset_path = kagglehub.dataset_download('nikhilroxtomar/brain-tumor-segmentation')
        print(f"Dataset downloaded to: {dataset_path}")

        images_dir = os.path.join(dataset_path, 'images')
        masks_dir = os.path.join(dataset_path, 'masks')

        if not os.path.exists(images_dir):
            # Use the first directory that holds *both* images/ and masks/, so
            # the two lists come from sibling folders and pair up by name.
            for root, dirs, _files in os.walk(dataset_path):
                if 'images' in dirs and 'masks' in dirs:
                    images_dir = os.path.join(root, 'images')
                    masks_dir = os.path.join(root, 'masks')
                    break

        if os.path.exists(images_dir) and os.path.exists(masks_dir):
            images = glob(os.path.join(images_dir, "*.*"))
            masks = glob(os.path.join(masks_dir, "*.*"))
        else:
            print("Could not find images/masks directories. Searching all files...")
            all_files = glob(os.path.join(dataset_path, "**/*.png"), recursive=True) + \
                glob(os.path.join(dataset_path, "**/*.jpg"), recursive=True)

            images, masks = [], []
            for path in all_files:
                # Classify by the path *inside* the dataset, so a download
                # location that happens to contain "image"/"mask" does not
                # match every file. "mask" is checked first so a file never
                # lands in both lists.
                rel = os.path.relpath(path, dataset_path).lower()
                if 'mask' in rel:
                    masks.append(path)
                elif 'image' in rel:
                    images.append(path)

        dataset_images = sorted(images)
        dataset_masks = sorted(masks)

        print(f"✓ Found {len(dataset_images)} images and {len(dataset_masks)} masks")
        dataset_loaded = True
        return True

    except Exception as e:
        print(f"Failed to load dataset: {e}")
        return False

def get_random_sample():
    """Pick a random dataset image and the mask that shares its file name.

    Returns:
        ``(image, mask, name)``: the image and mask as grayscale PIL images
        (``mask`` is None when no mask has the same base name) and the image's
        base file name. If the dataset is unavailable or a file fails to open,
        the first two items are None and the third is an explanatory message.
    """
    if not dataset_loaded:
        return None, None, "Dataset not loaded"
    if not dataset_images:
        return None, None, "No images found in dataset"

    chosen_path = dataset_images[random.randint(0, len(dataset_images) - 1)]
    chosen_name = os.path.basename(chosen_path)
    # First mask whose base name equals the chosen image's base name, if any.
    matching_mask = next(
        (candidate for candidate in dataset_masks
         if os.path.basename(candidate) == chosen_name),
        None,
    )

    try:
        gray_image = Image.open(chosen_path).convert("L")
        gray_mask = None if not matching_mask else Image.open(matching_mask).convert("L")
    except Exception as e:
        return None, None, f"Error loading sample: {e}"
    return gray_image, gray_mask, chosen_name

def preprocess_for_model(image):
    """Convert a PIL image into the single-image batch the model expects.

    The image is converted to grayscale ('L') if needed, resized to 256x256,
    and turned into a tensor with ``ToTensor``; a leading batch dimension is
    then added, giving shape (1, 1, 256, 256).
    """
    grayscale = image if image.mode == 'L' else image.convert('L')

    pipeline = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
    single = pipeline(grayscale)
    return single.unsqueeze(0)

def generate_attention_heatmap(attention_maps, target_size=(256, 256)):
    """Average the decoder attention maps into a single RGB "jet" heatmap.

    Args:
        attention_maps: sequence of attention tensors, one per decoder stage
            (as returned by ``AttentionUNET.forward``). Each is squeezed to
            2-D, so a batch size of 1 is assumed — TODO confirm for batched use.
        target_size: ``(width, height)`` every map is resized to before
            averaging (OpenCV's dsize convention). Defaults to the 256x256
            model input size.

    Returns:
        uint8 array of shape ``(height, width, 3)`` in RGB channel order, ready
        for matplotlib's ``imshow``. All zeros when ``attention_maps`` is empty.
    """
    width, height = target_size
    if not attention_maps:
        return np.zeros((height, width, 3), dtype=np.uint8)

    # Maps come from different decoder depths (different resolutions), so
    # bring them to a common size before averaging.
    resized_maps = [
        cv2.resize(att_map.squeeze().cpu().numpy(), (width, height))
        for att_map in attention_maps
    ]
    combined_att = np.mean(resized_maps, axis=0)

    # Min-max normalise to [0, 1]; the epsilon guards against a constant map.
    combined_att = (combined_att - combined_att.min()) / (
        combined_att.max() - combined_att.min() + 1e-8
    )

    heatmap = cv2.applyColorMap((combined_att * 255).astype(np.uint8), cv2.COLORMAP_JET)
    # OpenCV colormaps are BGR but matplotlib interprets arrays as RGB.
    # Without this swap high-attention regions would render blue, not red.
    return cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

def _segmentation_metrics(pred_mask, gt_mask):
    """Return ``(iou, dice, intersection, union)`` for two same-shaped masks.

    Inputs are treated as boolean (non-zero = foreground). A 1e-7 term in each
    denominator avoids division by zero; both scores are 0.0 when both masks
    are empty.
    """
    pred = np.asarray(pred_mask, dtype=bool)
    gt = np.asarray(gt_mask, dtype=bool)
    intersection = int(np.logical_and(pred, gt).sum())
    union = int(np.logical_or(pred, gt).sum())
    iou = intersection / (union + 1e-7)
    dice = (2 * intersection) / (int(pred.sum()) + int(gt.sum()) + 1e-7)
    return iou, dice, intersection, union


def _show_panel(ax, data, title, **imshow_kwargs):
    """Draw ``data`` on ``ax`` with a bold title and hidden axes."""
    ax.imshow(data, **imshow_kwargs)
    ax.set_title(title, fontsize=12, weight='bold')
    ax.axis('off')


def _figure_to_pil(fig):
    """Render ``fig`` to an in-memory PNG and return it as a PIL image."""
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
    buf.seek(0)
    return Image.open(buf)


def analyze_image(image, ground_truth, filename):
    """Segment ``image`` with the loaded model and build a figure + Markdown report.

    Args:
        image: PIL image to segment (any mode; it is converted to grayscale).
        ground_truth: optional PIL mask. When given, IoU/Dice are computed
            against it and extra comparison panels are drawn.
        filename: name shown in the report ('Uploaded Image' when falsy).

    Returns:
        ``(result_image, markdown)``. ``result_image`` is the rendered figure
        as a PIL image. It is None when the model or image is missing or the
        analysis raised; ``markdown`` then explains why.
    """
    if model is None:
        return None, "Model not loaded. Please restart the application."

    if image is None:
        return None, "Please select an image first."

    fig = None
    try:
        print("="*50)
        print("DEBUG: Starting analysis...")
        print(f"Input image mode: {image.mode}")
        print(f"Input image size: {image.size}")

        # The model is single-channel, and the Gradio Image component can
        # deliver RGB/RGBA when a user uploads straight into it. Working from
        # one grayscale copy keeps original_np 2-D, so the np.where calls
        # against the 2-D predicted mask below broadcast instead of raising.
        gray_image = image if image.mode == 'L' else image.convert('L')

        input_tensor = preprocess_for_model(gray_image).to(device)
        print(f"Input tensor shape: {input_tensor.shape}")
        print(f"Input tensor min/max: {input_tensor.min():.4f}/{input_tensor.max():.4f}")

        with torch.no_grad():
            model_output, attention_maps = model(input_tensor)
            pred_prob = torch.sigmoid(model_output)
            binary_mask = (pred_prob > 0.5).float()
        print(f"Model output shape: {model_output.shape}")
        print(f"Model output min/max BEFORE sigmoid: {model_output.min():.4f}/{model_output.max():.4f}")
        print(f"After sigmoid min/max: {pred_prob.min():.4f}/{pred_prob.max():.4f}")

        pred_mask_np = binary_mask.cpu().squeeze().numpy()
        pred_bool = pred_mask_np == 1
        print(f"Predicted mask shape: {pred_mask_np.shape}, foreground pixels: {int(pred_bool.sum())}")

        # Display arrays: inverted mask (tumor black on white) and the
        # original intensities kept only inside the predicted tumor.
        inv_pred_mask_np = np.where(pred_bool, 0, 255)
        original_np = np.array(gray_image.resize((256, 256)))
        tumor_only = np.where(pred_bool, original_np, 255)

        print("Generating attention heatmap...")
        att_heatmap = generate_attention_heatmap(attention_maps)

        has_gt = ground_truth is not None
        if has_gt:
            # Same grayscale/Resize/ToTensor pipeline as the model input, then
            # binarise: resizing interpolates mask edges to fractional values,
            # which must not count as foreground in IoU/Dice.
            mask_np = preprocess_for_model(ground_truth).squeeze().numpy()
            gt_bool = mask_np > 0.5
            iou, dice, intersection, union = _segmentation_metrics(pred_bool, gt_bool)
            print(f"Final IoU: {iou:.4f}")
            print(f"Final Dice: {dice:.4f}")
            print(f"Intersection: {intersection}")
            print(f"Union: {union}")
            print(f"GT pixels: {int(gt_bool.sum())}")

        # Use explicit figure handles (not pyplot's implicit "current figure")
        # so concurrent requests cannot save or close each other's figures.
        if has_gt:
            fig, axes = plt.subplots(2, 4, figsize=(16, 8))
        else:
            fig, axes = plt.subplots(2, 3, figsize=(15, 8))
        fig.suptitle('Brain Tumor Segmentation Analysis', fontsize=16, weight='bold')

        _show_panel(axes[0, 0], original_np, 'Original Image', cmap='gray')
        axes[0, 1].imshow(original_np, cmap='gray')
        _show_panel(axes[0, 1], att_heatmap, 'Attention Heatmap', alpha=0.4)
        _show_panel(axes[0, 2], inv_pred_mask_np, 'Predicted Mask', cmap='gray')

        overlay = np.array(image.convert('RGB').resize((256, 256)))
        if has_gt:
            _show_panel(axes[0, 3], tumor_only, 'Tumor Only', cmap='gray')
            _show_panel(axes[1, 0], mask_np, 'Ground Truth Mask', cmap='gray')

            # Green = prediction only, red = ground truth only, yellow = both.
            overlay[pred_bool] = [0, 255, 0]
            overlay[gt_bool] = [255, 0, 0]
            overlay[pred_bool & gt_bool] = [255, 255, 0]
            _show_panel(axes[1, 1], overlay, 'Prediction (Green) vs GT (Red)\nOverlap (Yellow)')

            metrics_ax = axes[1, 2]
            metrics_ax.text(0.1, 0.6, f'IoU: {iou:.4f}', fontsize=16, weight='bold')
            metrics_ax.text(0.1, 0.4, f'Dice: {dice:.4f}', fontsize=16, weight='bold')
            metrics_ax.set_xlim(0, 1)
            metrics_ax.set_ylim(0, 1)
            metrics_ax.axis('off')
            metrics_ax.set_title('Metrics', fontsize=12, weight='bold')

            _show_panel(axes[1, 3], tumor_only, 'Segmented Tumor', cmap='gray')
        else:
            _show_panel(axes[1, 0], inv_pred_mask_np, 'Predicted Mask', cmap='gray')
            _show_panel(axes[1, 1], tumor_only, 'Tumor Only', cmap='gray')
            overlay[pred_bool] = [255, 0, 0]
            _show_panel(axes[1, 2], overlay, 'Prediction Overlay')

        fig.tight_layout()
        result_image = _figure_to_pil(fig)

        tumor_pixels = int(pred_bool.sum())
        tumor_percentage = tumor_pixels / pred_bool.size * 100

        print(f"Final tumor pixels: {tumor_pixels}")
        print(f"Final tumor percentage: {tumor_percentage:.2f}%")
        print("="*50)

        analysis_text = f"""
# Analysis Results

**File:** {filename if filename else 'Uploaded Image'}

**Tumor Detection:**
- Tumor Area: {tumor_percentage:.2f}%
- Tumor Pixels: {tumor_pixels:,}

**Model Features:**
- Attention Visualization: Generated
- Post-processing: None (sigmoid threshold at 0.5)
"""

        if has_gt:
            analysis_text += f"""
**Performance Metrics:**
- IoU Score: {iou:.4f}
- Dice Score: {dice:.4f}
"""

        return result_image, analysis_text

    except Exception as e:
        import traceback
        error_msg = f"Analysis failed: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_msg)  # For debugging
        return None, error_msg
    finally:
        # Release the figure on every path (including errors) so repeated
        # requests do not accumulate open pyplot figures.
        if fig is not None:
            plt.close(fig)

        
# Run both loaders once at import time. Each prints its own progress and
# returns True/False; failures only produce a warning here, so the UI still
# starts and reports the status to the user.
print("Initializing application components...")
model_loaded = download_and_load_model()
dataset_loaded_success = download_and_load_dataset()

_startup_checks = (
    (model_loaded, "WARNING: Model failed to load!"),
    (dataset_loaded_success, "WARNING: Dataset failed to load!"),
)
for _succeeded, _warning in _startup_checks:
    if not _succeeded:
        print(_warning)

print("Application ready!")

# Custom stylesheet passed to gr.Blocks(css=...) below: widens and centers the
# main container, sets the font, rounds/colours buttons, darkens headings and
# adds a light border to forms. `!important` is used so these rules take
# precedence over conflicting theme rules.
css = """
.gradio-container {
    max-width: 1600px !important;
    margin: auto !important;
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif !important;
}
.gr-button {
    border-radius: 6px !important;
    font-weight: 500 !important;
}
.gr-button-primary {
    background: #2563eb !important;
    border-color: #2563eb !important;
}
.gr-button-secondary {
    background: #6b7280 !important;
    border-color: #6b7280 !important;
}
h1, h2, h3 {
    color: #1f2937 !important;
}
.gr-form {
    border: 1px solid #e5e7eb !important;
    border-radius: 8px !important;
}
"""

# Create Gradio interface
with gr.Blocks(css=css, title="Brain Tumor Segmentation Analysis") as app:

    gr.Markdown("""
    # Brain Tumor Segmentation Using Attention U-Net

    **Advanced Medical Image Analysis Tool**

    Features: Attention Visualization, Dataset Integration, IoU/Dice Evaluation
    """)

    # Status display: outcome of the startup loaders.
    with gr.Row():
        with gr.Column():
            status_text = f"Model Status: {'✓ Loaded' if model_loaded else '✗ Failed'} | Dataset Status: {'✓ Loaded' if dataset_loaded_success else '✗ Failed'}"
            if dataset_loaded_success:
                status_text += f" | Images: {len(dataset_images)} | Masks: {len(dataset_masks)}"
            gr.Markdown(f"**{status_text}**")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input Selection")

            # Image display (users can also upload directly into it)
            image_display = gr.Image(
                label="Selected Image",
                type="pil",
                height=300
            )

            # Control buttons
            with gr.Row():
                load_sample_btn = gr.Button("Load Random Sample", variant="primary", scale=1)
                upload_btn = gr.UploadButton("Upload Image", file_types=["image"], scale=1)

            analyze_btn = gr.Button("Analyze Image", variant="primary", size="lg")

            # Dataset info
            gr.Markdown(f"""
            **Dataset Information:**
            - Total Images: {len(dataset_images) if dataset_loaded_success else 'N/A'}
            - Total Masks: {len(dataset_masks) if dataset_loaded_success else 'N/A'}
            - Source: nikhilroxtomar/brain-tumor-segmentation
            """)

        with gr.Column(scale=2):
            gr.Markdown("### Analysis Results")

            result_display = gr.Image(
                label="Segmentation Analysis",
                type="pil",
                height=500
            )

            analysis_text = gr.Markdown(
                value="Load an image and click 'Analyze Image' to begin."
            )

    # Hidden per-session state: ground-truth mask and file name that belong
    # to the image currently shown in image_display.
    current_ground_truth = gr.State()
    current_filename = gr.State()

    # Event handlers
    def handle_sample_load():
        """Load a random dataset image together with its mask and file name."""
        image, mask, filename = get_random_sample()
        return image, mask, filename

    def handle_upload(file):
        """Open an uploaded file as grayscale; uploads carry no ground truth."""
        if file is None:
            return None, None, ""
        # Gradio 3 passes a tempfile wrapper (path in `.name`); Gradio 4's
        # default UploadButton type passes the path as a plain str.
        path = file if isinstance(file, str) else file.name
        image = Image.open(path).convert("L")
        return image, None, os.path.basename(path)

    def clear_sample_context():
        """Forget the previous sample's mask/name after a direct image upload.

        Without this, an image dropped straight into the image box would be
        scored against the ground truth of the last random sample.
        """
        return None, None

    load_sample_btn.click(
        fn=handle_sample_load,
        outputs=[image_display, current_ground_truth, current_filename]
    )

    upload_btn.upload(
        fn=handle_upload,
        inputs=[upload_btn],
        outputs=[image_display, current_ground_truth, current_filename]
    )

    # `.upload` fires only for user uploads, not when a handler sets the
    # image, so loading a random sample keeps its ground truth.
    image_display.upload(
        fn=clear_sample_context,
        outputs=[current_ground_truth, current_filename]
    )

    analyze_btn.click(
        fn=analyze_image,
        inputs=[image_display, current_ground_truth, current_filename],
        outputs=[result_display, analysis_text]
    )

if __name__ == "__main__":
    divider = "=" * 50
    print("\n" + divider)
    print("LAUNCHING BRAIN TUMOR SEGMENTATION APPLICATION")
    print(divider)

    # Listen on all interfaces at port 7860, show errors in the UI, and do
    # not request a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
        "share": False,
    }
    app.launch(**launch_options)