File size: 28,707 Bytes
ed89c53
a24331f
 
dba48bd
a24331f
 
dba48bd
 
 
 
a24331f
 
 
 
 
 
 
 
 
dba48bd
 
a24331f
 
 
 
 
 
 
 
b57b086
3b501f6
 
a24331f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dba48bd
15bacdd
5c6d9a9
3b501f6
a24331f
 
 
 
 
 
 
 
 
dba48bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a24331f
 
 
 
 
 
 
dba48bd
 
 
 
 
 
 
 
 
 
 
a24331f
dba48bd
a24331f
 
 
 
dba48bd
3b501f6
a24331f
 
 
3130041
 
5c6d9a9
3130041
5c6d9a9
3130041
 
 
 
 
93c1900
3130041
 
5c6d9a9
 
 
93c1900
 
3130041
 
 
 
93c1900
 
 
3130041
 
93c1900
 
 
 
 
 
 
 
 
 
3130041
 
 
93c1900
3130041
 
 
 
 
93c1900
3130041
5c6d9a9
3130041
 
 
 
 
 
 
 
 
5c6d9a9
93c1900
 
 
 
 
 
 
 
 
 
3130041
5c6d9a9
93c1900
3130041
93c1900
3130041
93c1900
5c6d9a9
 
3130041
5c6d9a9
 
93c1900
5c6d9a9
3130041
93c1900
 
 
5c6d9a9
 
93c1900
 
 
5c6d9a9
93c1900
5c6d9a9
93c1900
3130041
5c6d9a9
3130041
93c1900
 
 
 
 
 
 
 
 
 
5c6d9a9
93c1900
 
 
 
3130041
 
 
 
93c1900
 
 
 
 
 
 
 
3130041
93c1900
 
 
 
 
 
 
 
 
 
3130041
93c1900
3130041
5c6d9a9
93c1900
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c6d9a9
 
93c1900
 
9b1ad10
93c1900
 
5c6d9a9
 
 
 
 
 
 
 
 
 
 
 
93c1900
 
5c6d9a9
93c1900
 
5c6d9a9
93c1900
 
 
5c6d9a9
93c1900
 
 
 
 
 
5c6d9a9
93c1900
 
 
 
 
 
 
 
 
 
 
9b1ad10
93c1900
 
 
5c6d9a9
93c1900
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c6d9a9
93c1900
5c6d9a9
a24331f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b501f6
a24331f
 
 
93c1900
 
a24331f
dba48bd
93c1900
 
 
 
 
 
 
 
 
 
 
3b501f6
 
93c1900
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b1ad10
 
93c1900
 
 
 
9b1ad10
 
93c1900
a24331f
9b1ad10
a24331f
93c1900
 
3b501f6
5c6d9a9
3b501f6
 
5c6d9a9
3b501f6
 
 
15bacdd
5c6d9a9
9b1ad10
93c1900
3b501f6
5c6d9a9
 
93c1900
 
 
 
 
9b1ad10
3b501f6
93c1900
15bacdd
93c1900
9b1ad10
 
 
 
 
 
3b501f6
15bacdd
93c1900
15bacdd
93c1900
9b1ad10
3b501f6
9b1ad10
 
 
 
3b501f6
93c1900
9b1ad10
 
 
 
 
93c1900
9b1ad10
 
3b501f6
5c6d9a9
 
9b1ad10
5c6d9a9
 
9b1ad10
5c6d9a9
 
9b1ad10
 
 
3b501f6
 
9b1ad10
93c1900
3b501f6
 
93c1900
3b501f6
a24331f
 
 
3b501f6
a24331f
 
 
 
 
 
3b501f6
a24331f
 
5c6d9a9
a24331f
 
 
 
5c6d9a9
a24331f
 
3b501f6
a24331f
 
 
 
 
3b501f6
a24331f
 
 
 
 
3b501f6
a24331f
dba48bd
3b501f6
3130041
5c6d9a9
93c1900
3130041
5c6d9a9
 
3130041
 
 
 
 
5c6d9a9
93c1900
 
 
 
 
15bacdd
dba48bd
3b501f6
dba48bd
 
3b501f6
a24331f
 
3b501f6
a24331f
 
 
 
 
 
 
3b501f6
15bacdd
5c6d9a9
3130041
5c6d9a9
93c1900
3b501f6
 
9b1ad10
 
 
 
 
 
 
a24331f
5c6d9a9
9b1ad10
93c1900
3b501f6
a24331f
 
 
 
 
9b1ad10
a24331f
 
3b501f6
a24331f
 
 
3b501f6
a24331f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b501f6
a24331f
 
 
dba48bd
3b501f6
 
93c1900
3b501f6
 
93c1900
 
 
15bacdd
 
 
9b1ad10
 
15bacdd
 
9b1ad10
 
 
93c1900
9b1ad10
3b501f6
 
 
 
5c6d9a9
 
 
 
3130041
5c6d9a9
3130041
93c1900
9b1ad10
 
 
 
93c1900
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15bacdd
5c6d9a9
93c1900
 
15bacdd
 
3b501f6
 
 
a24331f
 
 
93c1900
a24331f
93c1900
a24331f
93c1900
 
 
 
 
a24331f
93c1900
 
 
 
 
a24331f
93c1900
a24331f
3b501f6
a24331f
 
 
ed89c53
a24331f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
import streamlit as st
import numpy as np
import os
import sys
from PIL import Image

# Set environment variables to fix permission issues.
# MPLCONFIGDIR is assigned before matplotlib is imported further down so the
# library picks up the writable config directory.
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
os.environ['STREAMLIT_SERVER_HEADLESS'] = 'true'

# Minimal imports to avoid conflicts.
# Both heavy dependencies are optional: the app records their availability in
# TF_AVAILABLE / MPL_AVAILABLE and degrades gracefully instead of crashing.
try:
    import tensorflow as tf
    TF_AVAILABLE = True
except ImportError:
    # The user-facing error is emitted below, after st.set_page_config(),
    # because the page config has to be the first Streamlit command issued.
    TF_AVAILABLE = False

try:
    import matplotlib
    matplotlib.use('Agg')  # Use non-interactive backend
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    MPL_AVAILABLE = True
except ImportError:
    MPL_AVAILABLE = False

# Page config
st.set_page_config(
    page_title="Stroke Classifier",
    page_icon="🧠",
    layout="wide")

# Report the missing dependency only now that the page has been configured.
if not TF_AVAILABLE:
    st.error("TensorFlow not available")

# Simple styling
st.markdown("""
<style>
    .main-header {
        font-size: 2.5rem;
        color: #1f77b4;
        text-align: center;
        margin-bottom: 2rem;
    }
    .prediction-box {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        padding: 2rem;
        border-radius: 1rem;
        text-align: center;
        margin: 1rem 0;
    }
    .status-box {
        padding: 1rem;
        border-radius: 0.5rem;
        margin: 1rem 0;
    }
    .success { background-color: #d4edda; border: 1px solid #c3e6cb; color: #155724; }
    .error { background-color: #f8d7da; border: 1px solid #f5c6cb; color: #721c24; }
    .info { background-color: #d1ecf1; border: 1px solid #bee5eb; color: #0c5460; }
    .warning { background-color: #fff3cd; border: 1px solid #ffeaa7; color: #856404; }
    .debug { background-color: #f8f9fa; border: 1px solid #dee2e6; color: #495057; font-family: monospace; }
</style>""", unsafe_allow_html=True)

# Initialize session state (only on the first run of a browser session)
if 'model_loaded' not in st.session_state:
    st.session_state.model_loaded = False
    st.session_state.model = None
    st.session_state.model_status = "Not loaded"

# Class names, indexed by the position of each score in the model output.
STROKE_LABELS = ["Hemorrhagic Stroke", "Ischemic Stroke", "No Stroke"]

def find_model_file():
    """Find the model file in various possible locations.

    The well-known locations are probed first, in order. Only when none of
    them exists is the current directory tree walked for any ``.h5`` file, so
    the (potentially expensive) walk is skipped in the common case.

    Returns:
        The first existing candidate path as a string, or ``None`` when no
        candidate exists.
    """
    known_locations = (
        "stroke_classification_model.h5",
        "./stroke_classification_model.h5",
        "/app/stroke_classification_model.h5",
        "src/stroke_classification_model.h5",
        os.path.join(os.getcwd(), "stroke_classification_model.h5"),
    )
    for candidate in known_locations:
        if os.path.exists(candidate):
            return candidate

    # Fall back to any .h5 file in the current directory and subdirectories.
    for root, _dirs, files in os.walk('.'):
        for filename in files:
            if not filename.endswith('.h5'):
                continue
            discovered = os.path.join(root, filename)
            # Re-check existence so dangling symlinks are skipped.
            if os.path.exists(discovered):
                return discovered

    return None

@st.cache_resource
def load_stroke_model():
    """Locate and load the Keras model, returning ``(model, status_message)``.

    The model slot is ``None`` whenever loading is not possible (TensorFlow
    missing, no model file found, or the loader raised); the status message
    then describes the reason. The result is cached by ``st.cache_resource``.
    """
    if not TF_AVAILABLE:
        return None, "❌ TensorFlow not available"

    try:
        located = find_model_file()

        if located is None:
            # Enumerate the files that are visible to help diagnose the miss.
            visible_files = [
                os.path.join(folder, entry)
                for folder, _subdirs, entries in os.walk('.')
                for entry in entries
            ]
            return None, f"❌ Model file not found. Available files: {visible_files[:10]}"

        st.info(f"Found model at: {located}")

        # compile=False: the model is only used for inference here.
        loaded = tf.keras.models.load_model(located, compile=False)
    except Exception as exc:
        return None, f"❌ Model loading failed: {str(exc)}"

    return loaded, f"βœ… Model loaded successfully from: {located}"

def _layer_output_shape(layer):
    """Return the layer's output shape, or ``'Unknown'`` when unavailable."""
    try:
        return layer.output_shape
    except (AttributeError, RuntimeError, ValueError):
        pass
    # Some layer objects do not expose `output_shape`; fall back to the shape
    # of the symbolic output tensor when there is one.
    try:
        return tuple(layer.output.shape)
    except (AttributeError, RuntimeError, TypeError, ValueError):
        return 'Unknown'


def analyze_model_architecture(model):
    """Comprehensive analysis of model architecture.

    Args:
        model: Object exposing a ``layers`` sequence, or ``None``.

    Returns:
        ``{"error": "No model loaded"}`` when *model* is ``None``. Otherwise a
        dict with ``total_layers``, the per-category lists ``conv_layers``,
        ``dense_layers`` and ``other_layers``, the full ``all_layers_detailed``
        list, and a ``model_type`` label. Every list entry is a dict with the
        keys ``index``, ``name``, ``type``, ``output_shape``, ``trainable``
        and ``activation``.
    """
    if model is None:
        return {"error": "No model loaded"}

    conv_type_markers = (
        'Conv1D', 'Conv2D', 'Conv3D', 'SeparableConv2D', 'DepthwiseConv2D',
        'Convolution1D', 'Convolution2D', 'Convolution3D',
    )

    layer_analysis = {
        'total_layers': len(model.layers),
        'conv_layers': [],
        'dense_layers': [],
        'other_layers': [],
        'all_layers_detailed': [],
        'model_type': 'Unknown'
    }

    for i, layer in enumerate(model.layers):
        layer_type = type(layer).__name__

        activation = getattr(layer, 'activation', None)
        if activation:
            # Prefer the function name; objects without one are stringified.
            try:
                activation = activation.__name__
            except AttributeError:
                activation = str(activation)

        layer_info = {
            'index': i,
            'name': layer.name,
            'type': layer_type,
            'output_shape': _layer_output_shape(layer),
            'trainable': getattr(layer, 'trainable', 'Unknown'),
            'activation': activation
        }
        layer_analysis['all_layers_detailed'].append(layer_info)

        # A layer counts as convolutional by class name or by a 'conv'
        # substring in its own name.
        is_conv = (
            any(marker in layer_type for marker in conv_type_markers)
            or 'conv' in layer.name.lower()
        )
        if is_conv:
            layer_analysis['conv_layers'].append(layer_info)
        elif 'Dense' in layer_type or 'Linear' in layer_type:
            layer_analysis['dense_layers'].append(layer_info)
        else:
            layer_analysis['other_layers'].append(layer_info)

    # Determine model type
    if layer_analysis['conv_layers']:
        layer_analysis['model_type'] = 'CNN (Convolutional Neural Network)'
    elif layer_analysis['dense_layers']:
        layer_analysis['model_type'] = 'MLP (Multi-Layer Perceptron)'
    else:
        layer_analysis['model_type'] = 'Custom Architecture'

    return layer_analysis

def debug_gradcam_step_by_step(img_array, model, layer_name, pred_index):
    """Debug Grad-CAM computation step by step.

    Runs Grad-CAM against a single layer while recording, in ``debug_info``,
    which step is executing plus shapes and summary statistics of the
    intermediate tensors.

    Args:
        img_array: Batched model input passed straight to the model.
        model: Keras model providing ``get_layer``, ``inputs`` and ``output``.
        layer_name: Name of the layer whose output is used for the heatmap.
        pred_index: Class index to explain; when ``None`` the arg-max of the
            first prediction row is used.

    Returns:
        ``(heatmap, debug_info)``. ``heatmap`` is a NumPy array scaled so its
        maximum is 1, or ``None`` on failure, in which case
        ``debug_info['error']`` explains why. Exceptions are never propagated;
        they are reported through ``debug_info['error']`` together with the
        step that was running.
    """
    debug_info = {
        'step': 'Starting',
        'error': None,
        'layer_output_shape': None,
        'gradients_shape': None,
        'gradients_stats': None,
        'heatmap_stats': None
    }

    try:
        debug_info['step'] = 'Getting target layer'
        target_layer = model.get_layer(layer_name)
        debug_info['target_layer_type'] = type(target_layer).__name__

        debug_info['step'] = 'Creating grad model'
        # Pass `model.inputs` as-is: wrapping it in another list produces a
        # nested input structure that no longer matches a single input array.
        grad_model = tf.keras.Model(
            inputs=model.inputs,
            outputs=[target_layer.output, model.output]
        )

        debug_info['step'] = 'Computing forward pass'
        with tf.GradientTape() as tape:
            layer_output, preds = grad_model(img_array)
            debug_info['layer_output_shape'] = layer_output.shape.as_list()
            debug_info['predictions_shape'] = preds.shape.as_list()

            if pred_index is None:
                pred_index = tf.argmax(preds[0])
            debug_info['pred_index'] = int(pred_index)
            debug_info['pred_confidence'] = float(preds[0][pred_index])

            # Score of the explained class; differentiated below.
            class_channel = preds[:, pred_index]
            debug_info['class_channel_shape'] = class_channel.shape.as_list()

        debug_info['step'] = 'Computing gradients'
        grads = tape.gradient(class_channel, layer_output)

        if grads is None:
            debug_info['error'] = "Gradients are None - no backpropagation path"
            return None, debug_info

        debug_info['gradients_shape'] = grads.shape.as_list()
        debug_info['gradients_stats'] = {
            'min': float(tf.reduce_min(grads)),
            'max': float(tf.reduce_max(grads)),
            'mean': float(tf.reduce_mean(grads)),
            'std': float(tf.math.reduce_std(grads))
        }

        debug_info['step'] = 'Processing gradients based on layer type'

        if len(layer_output.shape) == 4:  # Conv layer
            debug_info['processing_type'] = 'Convolutional layer (4D)'
            # One weight per channel: gradient averaged over batch and space.
            pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
            layer_output = layer_output[0]
            # Channel-weighted sum of the feature maps.
            heatmap = layer_output @ pooled_grads[..., tf.newaxis]
            # Drop only the trailing singleton axis so a 1x1 feature map
            # stays two-dimensional instead of collapsing to a scalar.
            heatmap = tf.squeeze(heatmap, axis=-1)

        elif len(layer_output.shape) == 2:  # Dense layer
            debug_info['processing_type'] = 'Dense layer (2D)'
            # Dense outputs carry no spatial layout, so the result is a
            # uniform 14x14 map scaled by the mean absolute gradient.
            grads_magnitude = tf.reduce_mean(tf.abs(grads))
            heatmap = tf.ones((14, 14)) * grads_magnitude

        else:
            debug_info['error'] = f"Unsupported layer shape: {layer_output.shape}"
            return None, debug_info

        debug_info['step'] = 'Normalizing heatmap'
        debug_info['raw_heatmap_stats'] = {
            'min': float(tf.reduce_min(heatmap)),
            'max': float(tf.reduce_max(heatmap)),
            'mean': float(tf.reduce_mean(heatmap)),
            'std': float(tf.math.reduce_std(heatmap))
        }

        # Apply ReLU (remove negative values)
        heatmap = tf.maximum(heatmap, 0)

        # Normalize so the strongest activation maps to 1.
        heatmap_max = tf.reduce_max(heatmap)
        if heatmap_max > 0:
            heatmap = heatmap / heatmap_max
        else:
            debug_info['error'] = "All heatmap values are zero or negative"
            return None, debug_info

        debug_info['final_heatmap_stats'] = {
            'min': float(tf.reduce_min(heatmap)),
            'max': float(tf.reduce_max(heatmap)),
            'mean': float(tf.reduce_mean(heatmap)),
            'std': float(tf.math.reduce_std(heatmap))
        }

        debug_info['step'] = 'Complete'
        return heatmap.numpy(), debug_info

    except Exception as e:
        # Report the failing step instead of raising, so callers can move on
        # to the next candidate layer.
        debug_info['error'] = f"Exception in step '{debug_info['step']}': {str(e)}"
        return None, debug_info

def create_robust_gradcam_heatmap(img, model, predictions):
    """Create Grad-CAM with comprehensive debugging.

    Candidate layers are tried one after another until one yields a heatmap:
    convolutional layers first (deepest first, since Grad-CAM is normally
    computed on the last convolutional feature maps), then 4-D
    ``Activation`` / ``BatchNormalization`` layers (deepest first), and dense
    layers only when nothing else is available.

    Args:
        img: PIL image to explain.
        model: Loaded Keras model.
        predictions: Class scores for ``img``; the arg-max is the class that
            gets explained.

    Returns:
        Always a 4-tuple ``(heatmap, status_message, stats, debug_info)``.
        ``heatmap`` is a 224x224 array on success and ``None`` on failure;
        ``stats`` holds min/max/mean/std of the heatmap or ``None``;
        ``debug_info`` is the per-layer debug dict of the last attempt, an
        ``{'error': ...}`` dict if preprocessing raised, or ``None`` when no
        layer was attempted.
    """
    try:
        # Force three channels so RGBA / palette / grayscale uploads all
        # produce the same input layout.
        img_resized = img.convert('RGB').resize((224, 224))
        img_array = np.array(img_resized, dtype=np.float32)

        # Defensive: replicate a single-channel array across three channels.
        if img_array.ndim == 2:
            img_array = np.stack([img_array] * 3, axis=-1)

        # Normalize and add batch dimension
        img_array = np.expand_dims(img_array, axis=0) / 255.0

        analysis = analyze_model_architecture(model)

        # Conv layers, deepest first.
        layer_candidates = [
            (layer['name'], f"Conv layer: {layer['name']}")
            for layer in reversed(analysis['conv_layers'])
        ]

        # Other layers with a 4-D output, deepest first.
        layer_candidates.extend(
            (layer['name'], f"4D layer: {layer['name']} ({layer['type']})")
            for layer in reversed(analysis['all_layers_detailed'])
            if layer['type'] in ('Activation', 'BatchNormalization')
            and isinstance(layer['output_shape'], (list, tuple))
            and len(layer['output_shape']) == 4
        )

        # Dense layers as last resort
        if not layer_candidates:
            layer_candidates = [
                (layer['name'], f"Dense layer: {layer['name']} (experimental)")
                for layer in analysis['dense_layers']
            ]

        if not layer_candidates:
            return None, "❌ No suitable layers found", None, None

        # The explained class is the same for every candidate layer.
        pred_index = np.argmax(predictions)
        debug_info = None

        for layer_name, layer_desc in layer_candidates:
            heatmap, debug_info = debug_gradcam_step_by_step(
                img_array, model, layer_name, pred_index
            )

            if heatmap is None:
                # This layer failed; move on to the next candidate.
                continue

            if heatmap.ndim != 2:
                # Only a 2-D map can be resized and overlaid on the image.
                debug_info['error'] = f"Heatmap is not 2-D: shape {heatmap.shape}"
                continue

            # Resize heatmap to match input image size
            if heatmap.shape[0] != 224 or heatmap.shape[1] != 224:
                heatmap_resized = tf.image.resize(
                    heatmap[..., tf.newaxis],
                    (224, 224)
                ).numpy()[:, :, 0]
            else:
                heatmap_resized = heatmap

            stats = {
                'min': float(np.min(heatmap_resized)),
                'max': float(np.max(heatmap_resized)),
                'mean': float(np.mean(heatmap_resized)),
                'std': float(np.std(heatmap_resized))
            }

            return heatmap_resized, f"βœ… Grad-CAM successful using {layer_desc}", stats, debug_info

        # Every candidate failed: surface the debug info of the last attempt.
        return None, f"❌ All layers failed. Last error: {debug_info.get('error', 'Unknown')}", None, debug_info

    except Exception as e:
        return None, f"❌ Grad-CAM error: {str(e)}", None, {'error': str(e)}

def predict_stroke(img, model):
    """Predict stroke type from image.

    Args:
        img: PIL image; it is converted to RGB, resized to 224x224 and scaled
            to the 0..1 range before being passed to the model.
        model: Object with a Keras-style ``predict`` method, or ``None``.

    Returns:
        ``(scores, None)`` with the first row of the model output on success,
        or ``(None, error_message)`` when the model is missing or
        preprocessing / prediction raised.
    """
    if model is None:
        return None, "Model not loaded"

    try:
        # Force three channels so RGBA / palette / grayscale uploads all
        # produce the same input layout.
        img_resized = img.convert('RGB').resize((224, 224))
        img_array = np.array(img_resized, dtype=np.float32)

        # Defensive: replicate a single-channel array across three channels.
        if img_array.ndim == 2:
            img_array = np.stack([img_array] * 3, axis=-1)

        # Normalize and add batch dimension
        img_array = np.expand_dims(img_array, axis=0) / 255.0

        predictions = model.predict(img_array, verbose=0)

        return predictions[0], None

    except Exception as e:
        return None, f"Prediction error: {str(e)}"

def create_enhanced_simulated_heatmap(img, predictions):
    """Create a more realistic simulated heatmap.

    The map is synthetic: it is built from three Gaussian blobs placed
    according to the predicted class plus fixed-seed noise, and does not
    depend on the image content (``img`` is accepted but not read).

    Args:
        img: Unused; kept for call-site symmetry with the Grad-CAM path.
        predictions: Class scores; the arg-max selects the blob position and
            the maximum scales the blobs relative to the noise.

    Returns:
        ``(heatmap, status_message, stats)`` where ``heatmap`` is a 224x224
        array scaled so its maximum is 1 and ``stats`` holds its
        min/max/mean/std, or ``(None, error_message, None)`` on failure.
    """
    try:
        confidence = np.max(predictions)
        predicted_class = np.argmax(predictions)

        # Blob centre (x, y) per predicted class.
        if predicted_class == 0:  # Hemorrhagic
            center_x, center_y = 80, 112   # center-left region
        elif predicted_class == 1:  # Ischemic
            center_x, center_y = 150, 112  # right side
        else:  # No stroke
            center_x, center_y = 112, 112  # image centre

        # Open coordinate grids that broadcast to 224x224.
        y, x = np.ogrid[:224, :224]

        # Primary focus area
        mask1 = np.exp(-((x - center_x)**2 + (y - center_y)**2) / (2 * (40**2)))

        # Secondary areas, offset from the primary centre
        mask2 = np.exp(-((x - center_x + 30)**2 + (y - center_y + 20)**2) / (2 * (25**2)))
        mask3 = np.exp(-((x - center_x - 20)**2 + (y - center_y - 30)**2) / (2 * (30**2)))

        # Combine patterns
        heatmap = (mask1 * 0.8 + mask2 * 0.4 + mask3 * 0.3) * confidence

        # Add some noise for realism. A private generator seeded with 42
        # yields the same values as seeding the global generator would,
        # without resetting NumPy's global random state on every call.
        rng = np.random.RandomState(42)
        noise = rng.normal(0, 0.05, heatmap.shape)
        heatmap = np.maximum(heatmap + noise, 0)

        # Normalize
        peak = np.max(heatmap)
        if peak > 0:
            heatmap = heatmap / peak

        stats = {
            'min': float(np.min(heatmap)),
            'max': float(np.max(heatmap)),
            'mean': float(np.mean(heatmap)),
            'std': float(np.std(heatmap))
        }

        return heatmap, "⚠️ Using enhanced simulated heatmap", stats
    except Exception as e:
        return None, f"❌ Simulated heatmap error: {str(e)}", None

def create_comprehensive_visualization(img, predictions, model, force_gradcam=True, colormap='hot'):
    """Create comprehensive visualization with debugging.

    Draws a three-panel figure: the resized input image, the heatmap alone,
    and the heatmap overlaid on the image. Grad-CAM is attempted first (when
    ``force_gradcam`` is set and a model is given); if it yields no heatmap
    the simulated heatmap is used instead.

    Args:
        img: PIL image to display.
        predictions: Class scores for ``img``.
        model: Loaded model, or ``None`` to skip Grad-CAM.
        force_gradcam: Whether to attempt Grad-CAM before falling back.
        colormap: Matplotlib colormap name for the heatmap panels.

    Returns:
        Always a 4-tuple ``(figure, status_message, stats, debug_info)``.
        ``figure`` is ``None`` on failure, with the reason in
        ``status_message``. The caller owns the returned figure and should
        close it after rendering.
    """
    if not MPL_AVAILABLE:
        return None, "❌ Matplotlib not available", None, None

    fig = None
    try:
        # Resize image to 224x224
        img_resized = img.resize((224, 224))
        img_array = np.array(img_resized)

        heatmap = None
        status_message = ""
        stats = None
        debug_info = None

        # Try Grad-CAM first. The result may carry an optional fourth item
        # with debug information, so it is unpacked by position.
        if force_gradcam and model is not None:
            result = create_robust_gradcam_heatmap(img, model, predictions)
            if result and len(result) >= 3:
                heatmap, status_message, stats = result[0], result[1], result[2]
                if len(result) > 3:
                    debug_info = result[3]

        # Fallback to enhanced simulated if Grad-CAM failed
        if heatmap is None:
            result = create_enhanced_simulated_heatmap(img, predictions)
            if result and len(result) == 3:
                heatmap, sim_status, stats = result
                # Keep the Grad-CAM failure reason next to the fallback note.
                if status_message:
                    status_message += f" | {sim_status}"
                else:
                    status_message = sim_status

        if heatmap is None:
            return None, "❌ Could not generate any heatmap", None, None

        # Create visualization
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # 1. Original image
        axes[0].imshow(img_array)
        axes[0].set_title("Original Image", fontsize=12, fontweight='bold')
        axes[0].axis('off')

        # 2. Heatmap only
        im1 = axes[1].imshow(heatmap, cmap=colormap, vmin=0, vmax=1)
        axes[1].set_title(f"Attention Heatmap ({colormap})", fontsize=12, fontweight='bold')
        axes[1].axis('off')
        plt.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)

        # 3. Overlay
        axes[2].imshow(img_array)
        im2 = axes[2].imshow(heatmap, cmap=colormap, alpha=0.6, vmin=0, vmax=1, interpolation='bilinear')

        # The overlay title tells real Grad-CAM output apart from the
        # simulated fallback.
        if "βœ… Grad-CAM successful" in status_message:
            title = "🎯 Real AI Attention Overlay"
            title_color = 'green'
        else:
            title = "🎨 Simulated Attention Overlay"
            title_color = 'orange'

        axes[2].set_title(title, fontsize=12, fontweight='bold', color=title_color)
        axes[2].axis('off')
        plt.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04)

        plt.tight_layout()

        return fig, status_message, stats, debug_info

    except Exception as e:
        # Do not leave a half-drawn figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
        return None, f"❌ Visualization error: {str(e)}", None, None

# Main App
def main():
    """Render the Streamlit page.

    Loads the model once per session, shows dependency / model status and the
    layer listing, and, once an image is uploaded, shows the classification
    result next to the attention visualization. Without an upload a welcome
    text is shown. The medical disclaimer is always rendered last.
    """
    # Header
    st.markdown('<h1 class="main-header">🧠 AI-Powered Stroke Classification System</h1>', unsafe_allow_html=True)

    # Auto-load model on startup
    if not st.session_state.model_loaded:
        with st.spinner("Loading AI model..."):
            st.session_state.model, st.session_state.model_status = load_stroke_model()
            st.session_state.model_loaded = True

    # System status
    st.markdown("### πŸ”§ System Status")
    col1, col2, col3 = st.columns(3)

    with col1:
        if TF_AVAILABLE:
            st.markdown('<div class="status-box success">βœ… TensorFlow Ready</div>', unsafe_allow_html=True)
            st.write(f"TF Version: {tf.__version__}")
        else:
            st.markdown('<div class="status-box error">❌ TensorFlow Error</div>', unsafe_allow_html=True)

    with col2:
        if MPL_AVAILABLE:
            st.markdown('<div class="status-box success">βœ… Matplotlib Ready</div>', unsafe_allow_html=True)
        else:
            st.markdown('<div class="status-box error">❌ Matplotlib Error</div>', unsafe_allow_html=True)

    with col3:
        # The loader marks success with this exact prefix in its status text.
        if "βœ…" in st.session_state.model_status:
            st.markdown('<div class="status-box success">βœ… Model Loaded</div>', unsafe_allow_html=True)
        else:
            st.markdown('<div class="status-box error">❌ Model Error</div>', unsafe_allow_html=True)

    # Model status details
    st.markdown(f'<div class="status-box info"><strong>Model Status:</strong> {st.session_state.model_status}</div>', unsafe_allow_html=True)

    # Enhanced model architecture analysis
    if st.session_state.model is not None:
        with st.expander("πŸ” Detailed Model Architecture Analysis"):
            analysis = analyze_model_architecture(st.session_state.model)

            st.write("**πŸ“Š Model Summary:**")
            st.write(f"- **Model Type:** {analysis['model_type']}")
            st.write(f"- **Total Layers:** {analysis['total_layers']}")
            st.write(f"- **Convolutional Layers:** {len(analysis['conv_layers'])}")
            st.write(f"- **Dense Layers:** {len(analysis['dense_layers'])}")
            st.write(f"- **Other Layers:** {len(analysis['other_layers'])}")

            # Show detailed layer information
            st.write("**πŸ” All Layers (Detailed):**")
            for layer in analysis['all_layers_detailed']:
                activation_info = f" | Activation: {layer['activation']}" if layer['activation'] else ""
                st.code(f"{layer['index']:2d}: {layer['name']} ({layer['type']}) | Shape: {layer['output_shape']}{activation_info}")

    # Manual reload button
    if st.button("πŸ”„ Reload Model", help="Try to reload the model"):
        # The loader is wrapped by st.cache_resource, so its cached result
        # (including a cached failure) has to be dropped first; otherwise the
        # rerun would just hand back the same value.
        load_stroke_model.clear()
        st.session_state.model_loaded = False
        st.rerun()

    # Sidebar
    with st.sidebar:
        st.header("πŸ“€ Upload Brain Scan")
        uploaded_file = st.file_uploader(
            "Choose a brain scan image...",
            type=['png', 'jpg', 'jpeg', 'bmp', 'tiff'],
            help="Upload a brain scan image for stroke classification"
        )

        st.markdown("---")
        st.header("🎨 Visualization Options")

        force_gradcam = st.checkbox(
            "Attempt Grad-CAM",
            value=True,
            help="Try Grad-CAM with comprehensive debugging"
        )

        colormap = st.selectbox(
            "Color Scheme",
            ['hot', 'jet', 'viridis', 'plasma', 'inferno', 'magma', 'coolwarm'],
            index=0,
            help="Choose color scheme for heatmap visualization"
        )

        show_probabilities = st.checkbox("Show All Probabilities", value=True)
        show_debug = st.checkbox("Show Debug Info", value=True)
        show_stats = st.checkbox("Show Heatmap Statistics", value=True)
        show_detailed_debug = st.checkbox("Show Detailed Debug Info", value=False)

    if uploaded_file is not None:
        # Load image
        image = Image.open(uploaded_file)

        # Stays None unless classification below succeeds; the visualization
        # column checks it to decide whether there is anything to explain.
        predictions = None

        # Main content area
        col1, col2 = st.columns([1, 2])

        with col1:
            st.subheader("πŸ“‹ Classification Results")

            if st.session_state.model is not None:
                # Predict
                with st.spinner("πŸ” Analyzing brain scan..."):
                    predictions, error = predict_stroke(image, st.session_state.model)

                if error:
                    st.error(error)
                else:
                    # Get top prediction
                    class_idx = np.argmax(predictions)
                    confidence = predictions[class_idx] * 100
                    predicted_class = STROKE_LABELS[class_idx]

                    # Display main result
                    st.markdown(f"""
                    <div class="prediction-box">
                        <h2>{predicted_class}</h2>
                        <h3>Confidence: {confidence:.1f}%</h3>
                    </div>
                    """, unsafe_allow_html=True)

                    # Show all probabilities
                    if show_probabilities:
                        st.write("**πŸ“Š All Probabilities:**")
                        for label, prob in zip(STROKE_LABELS, predictions):
                            st.write(f"β€’ {label}: {prob*100:.1f}%")
            else:
                st.error("❌ Model not loaded. Check the debug information above to see available files.")

        with col2:
            st.subheader("🎯 Comprehensive AI Attention Visualization")

            if st.session_state.model is not None and predictions is not None:
                # Create comprehensive visualization
                with st.spinner("🎨 Generating comprehensive attention visualization..."):
                    result = create_comprehensive_visualization(
                        image,
                        predictions,
                        st.session_state.model,
                        force_gradcam,
                        colormap
                    )

                # The result may have between two and four items, so it is
                # unpacked by position with None for the missing ones.
                if result and len(result) >= 2:
                    overlay_fig, status_message = result[0], result[1]
                    stats = result[2] if len(result) > 2 else None
                    debug_info = result[3] if len(result) > 3 else None

                    if overlay_fig is not None:
                        st.pyplot(overlay_fig)
                        # Close the figure that was just rendered rather than
                        # whichever figure pyplot considers current.
                        plt.close(overlay_fig)

                        # Show detailed status. The message already carries
                        # its own status icon, so it is shown unchanged.
                        if show_debug:
                            if "βœ… Grad-CAM successful" in status_message:
                                st.success(status_message)
                            elif "⚠️" in status_message:
                                st.warning(status_message)
                            else:
                                st.error(status_message)

                        # Show heatmap statistics
                        if show_stats and stats:
                            st.write("**πŸ“ˆ Heatmap Statistics:**")
                            if any(np.isnan([stats['min'], stats['max'], stats['mean'], stats['std']])):
                                st.error("⚠️ NaN values detected in heatmap - this indicates a computation error")
                            else:
                                col_stats1, col_stats2 = st.columns(2)
                                with col_stats1:
                                    st.write(f"β€’ Min: {stats['min']:.3f}")
                                    st.write(f"β€’ Max: {stats['max']:.3f}")
                                with col_stats2:
                                    st.write(f"β€’ Mean: {stats['mean']:.3f}")
                                    st.write(f"β€’ Std: {stats['std']:.3f}")

                        # Show detailed debug information
                        if show_detailed_debug and debug_info:
                            with st.expander("πŸ”§ Detailed Debug Information"):
                                st.json(debug_info)
                    else:
                        st.error(f"Could not generate visualization: {status_message}")
                        if debug_info:
                            st.error(f"Debug info: {debug_info.get('error', 'No additional info')}")
                else:
                    st.error("Could not generate attention visualization")
            else:
                st.info("Upload an image and run classification to see AI attention visualization")

    else:
        # Welcome message
        st.markdown("""
        ## πŸ‘‹ Welcome to the Comprehensive Stroke Classification System
        
        This system now includes **step-by-step debugging** to identify why Grad-CAM might be failing.
        
        ### πŸ”§ New Debugging Features:
        - **Step-by-step Grad-CAM debugging** - See exactly where it fails
        - **Multiple layer attempts** - Tries different layers automatically
        - **Enhanced error messages** - Clear explanations of what went wrong
        - **NaN detection** - Identifies computation errors
        
        ### 🎯 What to Look For:
        - **Green success messages** - Grad-CAM is working
        - **Orange warnings** - Using fallback methods
        - **Red errors** - Something is broken
        - **NaN statistics** - Computation failure
        
        **Upload an image to see detailed debugging! πŸ‘ˆ**
        """)

    # Medical disclaimer
    st.markdown("---")
    st.warning("⚠️ **Medical Disclaimer:** This AI system is for educational and research purposes only. It should not be used for actual medical diagnosis. Always consult qualified healthcare professionals for medical decisions.")

if __name__ == "__main__":
    main()