File size: 23,717 Bytes
ed89c53
a24331f
 
dba48bd
a24331f
1ca6b73
 
 
a24331f
dba48bd
 
 
 
a24331f
 
 
 
 
 
 
 
 
dba48bd
 
a24331f
 
 
 
 
 
 
b57b086
3b501f6
 
a24331f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dba48bd
15bacdd
5c6d9a9
3b501f6
a24331f
 
 
 
 
 
 
 
 
dba48bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a24331f
 
 
 
 
 
 
dba48bd
 
 
 
 
 
 
 
 
 
 
a24331f
dba48bd
a24331f
 
 
 
dba48bd
3b501f6
a24331f
 
 
1ca6b73
 
 
 
5c6d9a9
1ca6b73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3130041
1ca6b73
5c6d9a9
1ca6b73
 
 
 
 
 
 
3130041
1ca6b73
3130041
1ca6b73
5c6d9a9
1ca6b73
 
 
 
3130041
1ca6b73
 
 
 
 
 
 
 
 
3130041
1ca6b73
 
5c6d9a9
1ca6b73
 
 
 
 
 
5c6d9a9
1ca6b73
 
5c6d9a9
1ca6b73
 
93c1900
1ca6b73
 
 
5c6d9a9
1ca6b73
 
 
 
93c1900
1ca6b73
 
 
93c1900
1ca6b73
 
 
 
3130041
1ca6b73
 
 
 
5c6d9a9
93c1900
1ca6b73
93c1900
1ca6b73
 
 
 
 
 
9b1ad10
1ca6b73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c6d9a9
a24331f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b501f6
a24331f
 
 
1ca6b73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93c1900
1ca6b73
 
 
3b501f6
1ca6b73
 
 
3b501f6
1ca6b73
3b501f6
a24331f
 
 
1ca6b73
a24331f
 
 
 
 
 
3b501f6
a24331f
 
5c6d9a9
a24331f
 
 
 
5c6d9a9
a24331f
 
3b501f6
a24331f
 
 
 
 
3b501f6
a24331f
 
 
 
 
3b501f6
1ca6b73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c6d9a9
1ca6b73
 
 
 
 
 
 
 
 
 
5c6d9a9
1ca6b73
 
 
 
 
3b501f6
a24331f
 
3b501f6
a24331f
 
 
 
 
 
 
1ca6b73
15bacdd
1ca6b73
 
 
 
 
3b501f6
 
1ca6b73
 
3b501f6
a24331f
 
 
 
1ca6b73
a24331f
1ca6b73
 
 
 
a24331f
1ca6b73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a24331f
1ca6b73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a24331f
1ca6b73
a24331f
1ca6b73
 
9b1ad10
1ca6b73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b501f6
a24331f
 
 
1ca6b73
 
 
a24331f
1ca6b73
 
 
 
 
a24331f
1ca6b73
 
 
 
a24331f
1ca6b73
 
 
 
a24331f
1ca6b73
a24331f
3b501f6
a24331f
 
 
ed89c53
a24331f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
import streamlit as st
import numpy as np
import os
import sys
from PIL import Image
from scipy import ndimage
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Environment tweaks for headless / read-only hosting.
# NOTE(review): both assignments probably come too late to matter:
# matplotlib.pyplot is imported at the top of this file, and matplotlib
# resolves its config directory when it is first imported; the Streamlit
# server is presumably already running when this script executes. To be
# effective they likely need to be set before those imports or in the launch
# environment -- confirm.
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
os.environ['STREAMLIT_SERVER_HEADLESS'] = 'true'

# Optional heavy dependencies: record whether they imported instead of failing.
try:
    import tensorflow as tf
    TF_AVAILABLE = True
except ImportError:
    TF_AVAILABLE = False

try:
    import matplotlib
    matplotlib.use('Agg')  # Non-interactive backend: figures are rendered to images only.
    import matplotlib.cm as cm
    MPL_AVAILABLE = True
except ImportError:
    MPL_AVAILABLE = False

# Page config must precede any other Streamlit output: many Streamlit versions
# raise StreamlitAPIException if set_page_config is not the first st.* call,
# so the missing-TensorFlow error is only reported after it.
st.set_page_config(
    page_title="Stroke Classifier",
    page_icon="🧠",
    layout="wide",
)

if not TF_AVAILABLE:
    st.error("TensorFlow not available")

# App-wide stylesheet for the custom HTML snippets: the page header
# (.main-header), the prediction banner (.prediction-box) and the status boxes
# (.status-box plus the success/error/info/warning/debug colour modifiers).
_APP_CSS = """
<style>
    .main-header {
        font-size: 2.5rem;
        color: #1f77b4;
        text-align: center;
        margin-bottom: 2rem;
    }
    .prediction-box {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        padding: 2rem;
        border-radius: 1rem;
        text-align: center;
        margin: 1rem 0;
    }
    .status-box {
        padding: 1rem;
        border-radius: 0.5rem;
        margin: 1rem 0;
    }
    .success { background-color: #d4edda; border: 1px solid #c3e6cb; color: #155724; }
    .error { background-color: #f8d7da; border: 1px solid #f5c6cb; color: #721c24; }
    .info { background-color: #d1ecf1; border: 1px solid #bee5eb; color: #0c5460; }
    .warning { background-color: #fff3cd; border: 1px solid #ffeaa7; color: #856404; }
    .debug { background-color: #f8f9fa; border: 1px solid #dee2e6; color: #495057; font-family: monospace; }
</style>"""
st.markdown(_APP_CSS, unsafe_allow_html=True)

# Seed this session's state on its first script run only, so later reruns keep
# whatever model was already loaded.
if "model_loaded" not in st.session_state:
    st.session_state["model_loaded"] = False
    st.session_state["model"] = None
    st.session_state["model_status"] = "Not loaded"

# Class names in model output order: prediction index i maps to STROKE_LABELS[i].
STROKE_LABELS = ["Hemorrhagic Stroke", "Ischemic Stroke", "No Stroke"]

def find_model_file():
    """Locate the stroke classification model file on disk.

    Well-known locations are checked first, in order. Only if none of them is
    an existing file is the current directory tree searched for any ``*.h5``
    file.

    Returns:
        The first matching path, or ``None`` if no model file is found.
    """
    candidates = [
        "stroke_classification_model.h5",
        "./stroke_classification_model.h5",
        "/app/stroke_classification_model.h5",
        "src/stroke_classification_model.h5",
        os.path.join(os.getcwd(), "stroke_classification_model.h5"),
    ]
    for path in candidates:
        if os.path.isfile(path):
            return path

    # Fall back to a recursive search. Walk lazily and stop at the first hit
    # rather than scanning the whole tree up front; sort directory entries so
    # the chosen file does not depend on filesystem listing order.
    for root, dirs, files in os.walk('.'):
        dirs.sort()
        for name in sorted(files):
            if name.endswith('.h5'):
                path = os.path.join(root, name)
                if os.path.isfile(path):
                    return path

    return None

@st.cache_resource
def _load_keras_model(model_path):
    """Load the Keras model at *model_path* (uncompiled), cached per path.

    Exceptions propagate to the caller. Streamlit only stores a cached value
    when the function returns normally, so a failed load is retried on the
    next call instead of being remembered for the life of the server.
    """
    return tf.keras.models.load_model(model_path, compile=False)


def load_stroke_model():
    """Find and load the stroke classification model.

    Returns:
        ``(model, status)``. On success *model* is the loaded Keras model and
        *status* names the file it was loaded from. On any failure *model* is
        ``None`` and *status* is a ``"❌ ..."`` message explaining why:
        TensorFlow missing, no model file found, or the loading error.
    """
    if not TF_AVAILABLE:
        return None, "❌ TensorFlow not available"

    try:
        model_path = find_model_file()

        if model_path is None:
            # List a few files from the working tree to help debug the
            # deployment layout; islice stops the walk after 10 files.
            from itertools import islice
            sample = list(islice(
                (os.path.join(root, name)
                 for root, _dirs, files in os.walk('.')
                 for name in files),
                10,
            ))
            return None, f"❌ Model file not found. Available files: {sample}"

        st.info(f"Found model at: {model_path}")
        model = _load_keras_model(model_path)

        # NOTE: the success prefix below is a mis-encoded "✅". It is kept
        # byte-for-byte because the app's status check may match on it.
        return model, f"βœ… Model loaded successfully from: {model_path}"

    except Exception as e:
        return None, f"❌ Model loading failed: {str(e)}"

def analyze_heatmap_distribution(heatmap, name="Heatmap"):
    """Summarise the value distribution of a heatmap.

    Args:
        heatmap: array-like of numeric values (any shape), or ``None``.
        name: label stored under the ``'name'`` key of the result.

    Returns:
        ``None`` if *heatmap* is ``None``. Otherwise a dict containing:
        shape, pixel count, min/max/mean/median/std, range (max - min),
        number of unique values, zero / non-zero pixel counts, the
        1/5/25/75/95/99th percentiles, and a coarse ``'contrast_quality'``
        label derived from the range.

    Raises:
        ValueError: if *heatmap* is empty (from the NumPy reductions).
    """
    if heatmap is None:
        return None

    heatmap = np.asarray(heatmap)
    values = heatmap.ravel()

    # Reduce once and reuse; the subtraction stays in the input dtype.
    lo = np.min(values)
    hi = np.max(values)
    value_range = float(hi - lo)

    pct_labels = ('1%', '5%', '25%', '75%', '95%', '99%')
    pct_values = np.percentile(values, [1, 5, 25, 75, 95, 99])

    analysis = {
        'name': name,
        'shape': heatmap.shape,
        'total_pixels': heatmap.size,
        'min': float(lo),
        'max': float(hi),
        'mean': float(np.mean(values)),
        'median': float(np.median(values)),
        'std': float(np.std(values)),
        'range': value_range,
        'unique_values': len(np.unique(values)),
        'zero_pixels': int(np.count_nonzero(values == 0)),
        # Every pixel that is not exactly zero, negatives included, so that
        # zero_pixels + non_zero_pixels == total_pixels for NaN-free input.
        'non_zero_pixels': int(np.count_nonzero(values)),
        'percentiles': {label: float(v) for label, v in zip(pct_labels, pct_values)},
    }

    # Coarse contrast grading by value range (first matching upper bound wins).
    if value_range < 0.1:
        analysis['contrast_quality'] = 'Very Poor (range < 0.1)'
    elif value_range < 0.3:
        analysis['contrast_quality'] = 'Poor (range < 0.3)'
    elif value_range < 0.7:
        analysis['contrast_quality'] = 'Moderate (range < 0.7)'
    else:
        analysis['contrast_quality'] = 'Good (range >= 0.7)'

    return analysis

def force_contrast_enhancement(heatmap, method='aggressive'):
    """Stretch the contrast of a heatmap using one of several strategies.

    Args:
        heatmap: NumPy array of heatmap values, or ``None``. The methods are
            tuned for values in ``[0, 1]``.
        method: one of

            * ``'aggressive'`` -- stretch the 1st..99th percentile to [0, 1]
              (clipping the rest), then apply gamma 0.3;
            * ``'histogram_eq'`` -- 256-bin histogram equalisation;
            * ``'adaptive'`` -- z-score against a 20x20 local mean/std,
              clipped to +/-3 and mapped to [0, 1];
            * ``'artificial_peaks'`` -- double the top 10 %, scale the bottom
              10 % by 0.1, clip to [0, 1].

            Any other value returns the heatmap unchanged.

    Returns:
        A 4-tuple ``(enhanced, message, original_analysis, enhanced_analysis)``.
        The analyses come from ``analyze_heatmap_distribution``. When
        *heatmap* is ``None`` the tuple is
        ``(None, "No heatmap provided", None, None)``.
    """
    if heatmap is None:
        # Same 4-tuple arity as the normal path so callers can always unpack.
        return None, "No heatmap provided", None, None

    original_analysis = analyze_heatmap_distribution(heatmap, "Original")

    if method == 'aggressive':
        p1, p99 = np.percentile(heatmap, [1, 99])
        if p99 > p1:
            enhanced = np.clip((heatmap - p1) / (p99 - p1), 0, 1)
        else:
            enhanced = heatmap
        # Gamma < 1 lifts low and mid values, spreading them over the scale.
        enhanced = np.power(enhanced, 0.3)

    elif method == 'histogram_eq':
        flat = heatmap.flatten()
        # Bin over [0, 1], widened when the data falls outside it. Otherwise
        # out-of-range values are dropped from the histogram and, in the worst
        # case, the CDF is all zeros and the normalisation divides by zero.
        lo = min(0.0, float(flat.min()))
        hi = max(1.0, float(flat.max()))
        hist, bins = np.histogram(flat, bins=256, range=(lo, hi))
        cdf = hist.cumsum()
        cdf = cdf / cdf[-1]
        # Map each value to the cumulative share of values at or below its bin.
        enhanced = np.interp(flat, bins[:-1], cdf).reshape(heatmap.shape)

    elif method == 'adaptive':
        # Local mean/std over a 20x20 window. The std comes from two box
        # filters, sqrt(E[x^2] - E[x]^2), with the same window size and
        # default 'reflect' border mode. The previous
        # ndimage.generic_filter(np.std) called back into Python once per
        # pixel and was orders of magnitude slower. Results can differ
        # slightly in near-constant windows, where the 1e-8 epsilon dominates.
        data = np.asarray(heatmap, dtype=np.float64)
        local_mean = ndimage.uniform_filter(data, size=20)
        local_sq_mean = ndimage.uniform_filter(data * data, size=20)
        # Clamp tiny negative variances caused by floating-point cancellation.
        local_std = np.sqrt(np.maximum(local_sq_mean - local_mean ** 2, 0.0))

        enhanced = (data - local_mean) / (local_std + 1e-8)
        enhanced = np.clip(enhanced, -3, 3)  # Limit outliers to +/-3 sigma.
        enhanced = (enhanced + 3) / 6        # Map [-3, 3] onto [0, 1].

    elif method == 'artificial_peaks':
        enhanced = heatmap.copy()  # Never modify the caller's array.

        # Boost the top 10% of values...
        threshold = np.percentile(enhanced, 90)
        mask = enhanced >= threshold
        enhanced[mask] = enhanced[mask] * 2

        # ...and suppress the bottom 10%.
        threshold_low = np.percentile(enhanced, 10)
        mask_low = enhanced <= threshold_low
        enhanced[mask_low] = enhanced[mask_low] * 0.1

        enhanced = np.clip(enhanced, 0, 1)

    else:
        enhanced = heatmap

    enhanced_analysis = analyze_heatmap_distribution(enhanced, f"Enhanced ({method})")

    return enhanced, f"Enhanced using {method}", original_analysis, enhanced_analysis

def create_diagnostic_heatmap_visualization(heatmap, title="Heatmap Analysis"):
    """Build a 2x3 diagnostic figure for a heatmap.

    The six panels are:

    * 'hot' and 'viridis' renderings on a fixed [0, 1] scale;
    * an 'RdYlBu_r' rendering auto-scaled to the data's own min/max;
    * a 50-bin value histogram;
    * a 3-D surface of every 8th sample in each direction;
    * a text panel of statistics from ``analyze_heatmap_distribution``.

    Returns:
        The matplotlib Figure, or ``None`` if matplotlib is unavailable or
        *heatmap* is ``None``. The caller is responsible for closing it.
    """
    if not MPL_AVAILABLE or heatmap is None:
        return None

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # Fixed-scale renderings, so different heatmaps are directly comparable.
    im1 = axes[0, 0].imshow(heatmap, cmap='hot', vmin=0, vmax=1)
    axes[0, 0].set_title(f"{title} - Hot Colormap")
    plt.colorbar(im1, ax=axes[0, 0])

    im2 = axes[0, 1].imshow(heatmap, cmap='viridis', vmin=0, vmax=1)
    axes[0, 1].set_title(f"{title} - Viridis Colormap")
    plt.colorbar(im2, ax=axes[0, 1])

    # Auto-scaled to the data range so small variations become visible.
    im3 = axes[0, 2].imshow(heatmap, cmap='RdYlBu_r', vmin=np.min(heatmap), vmax=np.max(heatmap))
    axes[0, 2].set_title(f"{title} - Auto-scaled")
    plt.colorbar(im3, ax=axes[0, 2])

    axes[1, 0].hist(heatmap.flatten(), bins=50, alpha=0.7, color='blue')
    axes[1, 0].set_title("Value Distribution")
    axes[1, 0].set_xlabel("Heatmap Value")
    axes[1, 0].set_ylabel("Frequency")

    # plt.subplots already created a 2-D axes in grid slot 5. Remove it so its
    # frame and tick labels do not show around the 3-D axes placed there.
    axes[1, 1].remove()
    ax_3d = fig.add_subplot(2, 3, 5, projection='3d')
    X, Y = np.meshgrid(np.arange(heatmap.shape[1]), np.arange(heatmap.shape[0]))
    step = 8  # Plot every 8th sample to keep the surface light-weight.
    ax_3d.plot_surface(X[::step, ::step], Y[::step, ::step], heatmap[::step, ::step],
                       cmap='hot', alpha=0.8)
    ax_3d.set_title("3D Surface View")

    analysis = analyze_heatmap_distribution(heatmap)
    stats_text = f"""
    Shape: {analysis['shape']}
    Range: {analysis['range']:.4f}
    Mean: {analysis['mean']:.4f}
    Std: {analysis['std']:.4f}
    Unique values: {analysis['unique_values']}
    Contrast: {analysis['contrast_quality']}
    
    Percentiles:
    1%: {analysis['percentiles']['1%']:.4f}
    25%: {analysis['percentiles']['25%']:.4f}
    75%: {analysis['percentiles']['75%']:.4f}
    99%: {analysis['percentiles']['99%']:.4f}
    """

    axes[1, 2].text(0.1, 0.9, stats_text, transform=axes[1, 2].transAxes,
                    fontsize=10, verticalalignment='top', fontfamily='monospace')
    axes[1, 2].set_title("Statistics")
    axes[1, 2].axis('off')

    plt.tight_layout()
    return fig

def create_multiple_enhancement_comparison(heatmap):
    """Render the original heatmap next to each enhancement method.

    Layout (2x3 grid):

    * the original in the top-left;
    * the 'aggressive', 'histogram_eq', 'adaptive' and 'artificial_peaks'
      results in the next four slots, all on a fixed [0, 1] 'hot' scale;
    * an overlaid 30-bin histogram of all five value distributions in the
      bottom-right.

    Returns:
        The matplotlib Figure, or ``None`` if matplotlib is unavailable or
        *heatmap* is ``None``.
    """
    if heatmap is None or not MPL_AVAILABLE:
        return None

    method_names = ('aggressive', 'histogram_eq', 'adaptive', 'artificial_peaks')
    results = {name: force_contrast_enhancement(heatmap, name)[0] for name in method_names}

    fig, grid = plt.subplots(2, 3, figsize=(18, 12))

    def _draw(ax, data, caption):
        # One fixed-scale panel with its own colorbar.
        img = ax.imshow(data, cmap='hot', vmin=0, vmax=1)
        ax.set_title(caption)
        plt.colorbar(img, ax=ax)

    _draw(grid[0, 0], heatmap, "Original Heatmap")
    panel_slots = (grid[0, 1], grid[0, 2], grid[1, 0], grid[1, 1])
    for ax, (name, data) in zip(panel_slots, results.items()):
        _draw(ax, data, f"Enhanced: {name}")

    hist_ax = grid[1, 2]
    hist_ax.hist(heatmap.flatten(), bins=30, alpha=0.5, label='Original', color='blue')
    for name, data in results.items():
        hist_ax.hist(data.flatten(), bins=30, alpha=0.3, label=name)
    hist_ax.set_title("Value Distributions")
    hist_ax.legend()
    hist_ax.set_xlabel("Value")
    hist_ax.set_ylabel("Frequency")

    plt.tight_layout()
    return fig

def predict_stroke(img, model):
    """Classify a brain-scan image with the loaded model.

    The image is converted to 3-channel RGB, resized to 224x224, scaled to
    [0, 1] as float32, and given a leading batch dimension before prediction.

    Args:
        img: a PIL image. Any mode that PIL can convert to RGB is accepted
            (grayscale, palette, RGBA, ...).
        model: an object with a Keras-style ``predict(batch, verbose=0)``
            method, or ``None``.

    Returns:
        ``(scores, None)`` on success, where *scores* is the first row of
        ``model.predict``. ``(None, message)`` if the model is missing or
        preprocessing/prediction raises.
    """
    if model is None:
        return None, "Model not loaded"

    try:
        # convert('RGB') replicates a single gray channel into three (as the
        # old np.stack path did). It also drops alpha and expands palettes,
        # which previously reached the model as 4-channel or index-valued
        # input.
        rgb = img.convert('RGB').resize((224, 224))
        batch = np.asarray(rgb, dtype=np.float32)[np.newaxis, ...] / 255.0

        predictions = model.predict(batch, verbose=0)
        return predictions[0], None

    except Exception as e:
        return None, f"Prediction error: {str(e)}"

def create_test_heatmaps(seed=0):
    """Build synthetic 224x224 heatmaps with known structure for comparison.

    Args:
        seed: seed for the noise in the ``'low_contrast'`` pattern. The
            default makes that pattern identical on every Streamlit rerun.
            Pass ``None`` for fresh noise on each call. The global NumPy RNG
            is never touched.

    Returns:
        dict (in this order) mapping pattern name to a float array of shape
        (224, 224):

        * ``'high_contrast'`` -- a square ring of 1.0 on 0.0;
        * ``'gradient'`` -- the product ramp x * y from 0 to 1;
        * ``'gaussian'`` -- three Gaussian blobs (sigma 30 px), peak-normalised
          to 1;
        * ``'low_contrast'`` -- N(0.5, 0.05) noise clipped to [0, 1].
    """
    size = 224
    maps = {}

    # Square ring: rows/cols 50..149 set to 1, inner 75..124 reset to 0.
    ring = np.zeros((size, size))
    ring[50:150, 50:150] = 1.0
    ring[75:125, 75:125] = 0.0
    maps['high_contrast'] = ring

    # Element [r, c] is axis[r] * axis[c].
    axis = np.linspace(0, 1, size)
    maps['gradient'] = np.outer(axis, axis)

    # Open grids: rows varies down axis 0, cols along axis 1.
    rows, cols = np.ogrid[:size, :size]
    blobs = np.zeros((size, size))
    for cx, cy in [(60, 60), (160, 160), (60, 160)]:
        blobs += np.exp(-((cols - cx) ** 2 + (rows - cy) ** 2) / (2 * 30 ** 2))
    maps['gaussian'] = blobs / np.max(blobs)

    # Narrow-range noise that mimics a heatmap rendering as "one colour".
    rng = np.random.default_rng(seed)
    maps['low_contrast'] = np.clip(rng.normal(0.5, 0.05, (size, size)), 0, 1)

    return maps

# Main App
def _render_status_box(ok, ok_label, error_label):
    """Render a green ✅ status box with *ok_label* if *ok*, else a red ❌ one."""
    if ok:
        st.markdown(f'<div class="status-box success">✅ {ok_label}</div>', unsafe_allow_html=True)
    else:
        st.markdown(f'<div class="status-box error">❌ {error_label}</div>', unsafe_allow_html=True)


def _show_figure(fig):
    """Render a matplotlib figure with st.pyplot and close it; no-op for ``None``."""
    if fig is not None:
        st.pyplot(fig)
        plt.close(fig)


def _simulated_attention_heatmap(class_idx, confidence_normalized):
    """Return a synthetic 224x224 "attention" map for the predicted class.

    This is a placeholder, not a real saliency map. It is a Gaussian blob
    (sigma 40 px) centred at a class-dependent position and scaled by the
    confidence. Fixed-seed N(0, 0.02) noise is added, the result is clipped
    at 0, and it is peak-normalised to 1.
    """
    # (x, y) blob centres: 0 = Hemorrhagic, 1 = Ischemic, anything else = centre.
    centers = {0: (80, 112), 1: (150, 112)}
    center_x, center_y = centers.get(class_idx, (112, 112))

    y, x = np.ogrid[:224, :224]
    heatmap = np.exp(-((x - center_x) ** 2 + (y - center_y) ** 2) / (2 * (40 ** 2)))
    heatmap = heatmap * confidence_normalized

    # A private RandomState(42) yields exactly the noise that
    # np.random.seed(42) + np.random.normal did, without resetting the
    # process-wide RNG.
    noise = np.random.RandomState(42).normal(0, 0.02, heatmap.shape)
    heatmap = np.maximum(heatmap + noise, 0)

    peak = np.max(heatmap)
    if peak > 0:
        heatmap = heatmap / peak
    return heatmap


def _render_enhancement(image, heatmap, enhancement_method):
    """Show *heatmap* before/after *enhancement_method*, an overlay on *image*, and stats."""
    enhanced_heatmap, _, orig_analysis, enh_analysis = force_contrast_enhancement(heatmap, enhancement_method)

    st.write(f"**🔧 Applied Enhancement: {enhancement_method}**")

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    im1 = axes[0].imshow(heatmap, cmap='hot', vmin=0, vmax=1)
    axes[0].set_title("Original Heatmap")
    axes[0].axis('off')
    plt.colorbar(im1, ax=axes[0])

    im2 = axes[1].imshow(enhanced_heatmap, cmap='hot', vmin=0, vmax=1)
    axes[1].set_title(f"Enhanced ({enhancement_method})")
    axes[1].axis('off')
    plt.colorbar(im2, ax=axes[1])

    # Convert to RGB first: a single-channel scan would otherwise be
    # false-coloured by imshow's default colormap, and alpha is dropped.
    backdrop = np.asarray(image.convert('RGB').resize((224, 224)))
    axes[2].imshow(backdrop)
    im3 = axes[2].imshow(enhanced_heatmap, cmap='hot', alpha=0.6, vmin=0, vmax=1)
    axes[2].set_title("Enhanced Overlay")
    axes[2].axis('off')
    plt.colorbar(im3, ax=axes[2])

    plt.tight_layout()
    _show_figure(fig)

    col1, col2 = st.columns(2)
    for column, heading, stats in ((col1, "**Original Stats:**", orig_analysis),
                                   (col2, "**Enhanced Stats:**", enh_analysis)):
        with column:
            st.write(heading)
            st.write(f"Range: {stats['range']:.4f}")
            st.write(f"Std: {stats['std']:.4f}")
            st.write(f"Contrast: {stats['contrast_quality']}")


def _render_uploaded_scan(uploaded_file, enhancement_method, show_diagnostics, show_comparisons):
    """Classify the uploaded scan and render the result and simulated-attention views."""
    try:
        image = Image.open(uploaded_file)
        image.load()  # Decode now so truncated/corrupt files fail here, not mid-render.
    except Exception as exc:  # UI boundary: report unreadable uploads instead of crashing.
        st.error(f"❌ Could not read the uploaded image: {exc}")
        return

    st.subheader("📋 Classification Results")

    model = st.session_state.model
    if model is None:
        st.error("❌ Model not loaded.")
        return

    with st.spinner("🔍 Analyzing brain scan..."):
        predictions, error = predict_stroke(image, model)

    if error:
        st.error(error)
        return

    if len(predictions) != len(STROKE_LABELS):
        st.error(f"❌ Model returned {len(predictions)} scores; expected {len(STROKE_LABELS)}.")
        return

    class_idx = int(np.argmax(predictions))
    confidence = predictions[class_idx] * 100
    predicted_class = STROKE_LABELS[class_idx]

    st.markdown(f"""
    <div class="prediction-box">
        <h2>{predicted_class}</h2>
        <h3>Confidence: {confidence:.1f}%</h3>
    </div>
    """, unsafe_allow_html=True)

    st.subheader("🎯 Simulated Attention Analysis")
    heatmap = _simulated_attention_heatmap(class_idx, confidence / 100.0)

    if show_diagnostics:
        st.write("**📊 Heatmap Diagnostic Analysis:**")
        # Labelled as simulated: this map is synthesised, not taken from the model.
        _show_figure(create_diagnostic_heatmap_visualization(heatmap, "Simulated Attention"))

    if show_comparisons:
        st.write("**🎨 Enhancement Method Comparison:**")
        _show_figure(create_multiple_enhancement_comparison(heatmap))

    if enhancement_method != 'none':
        _render_enhancement(image, heatmap, enhancement_method)


def main():
    """Render the page.

    Sections, in order: system status, test-pattern explorer, sidebar
    controls, then either the uploaded-scan analysis or a welcome message,
    followed by a medical disclaimer.
    """
    st.markdown('<h1 class="main-header">🧠 Heatmap Diagnostic System</h1>', unsafe_allow_html=True)

    # Load the model once per session; reruns reuse the stored result.
    if not st.session_state.model_loaded:
        with st.spinner("Loading AI model..."):
            st.session_state.model, st.session_state.model_status = load_stroke_model()
            st.session_state.model_loaded = True

    st.markdown("### 🔧 System Status")
    col1, col2, col3 = st.columns(3)

    with col1:
        _render_status_box(TF_AVAILABLE, "TensorFlow Ready", "TensorFlow Error")
        if TF_AVAILABLE:
            st.write(f"TF Version: {tf.__version__}")

    with col2:
        _render_status_box(MPL_AVAILABLE, "Matplotlib Ready", "Matplotlib Error")

    with col3:
        # Judge success by the loaded object, not by matching an emoji prefix
        # in the status text.
        _render_status_box(st.session_state.model is not None, "Model Loaded", "Model Error")

    st.markdown("### 🧪 Test Heatmap Patterns")

    test_maps = create_test_heatmaps()

    col1, col2 = st.columns(2)

    with col1:
        st.write("**Test Pattern:**")
        test_pattern = st.selectbox(
            "Choose a test pattern",
            list(test_maps.keys()),
            help="Test different heatmap patterns to see how they display"
        )
        if test_pattern:
            _show_figure(create_diagnostic_heatmap_visualization(test_maps[test_pattern], f"Test: {test_pattern}"))

    with col2:
        st.write("**Enhancement Comparison:**")
        if test_pattern:
            _show_figure(create_multiple_enhancement_comparison(test_maps[test_pattern]))

    with st.sidebar:
        st.header("📤 Upload Brain Scan")
        uploaded_file = st.file_uploader(
            "Choose a brain scan image...",
            type=['png', 'jpg', 'jpeg', 'bmp', 'tiff'],
            help="Upload a brain scan image for stroke classification"
        )

        st.markdown("---")
        st.header("🎨 Enhancement Options")

        enhancement_method = st.selectbox(
            "Enhancement Method",
            ['none', 'aggressive', 'histogram_eq', 'adaptive', 'artificial_peaks'],
            index=1,
            help="Choose how to enhance heatmap contrast"
        )

        show_diagnostics = st.checkbox("Show Diagnostic Analysis", value=True)
        show_comparisons = st.checkbox("Show Enhancement Comparisons", value=True)

    if uploaded_file is not None:
        _render_uploaded_scan(uploaded_file, enhancement_method, show_diagnostics, show_comparisons)
    else:
        st.markdown("""
        ## 👋 Welcome to the Heatmap Diagnostic System

        This system helps you understand **why your heatmaps appear as one color** and how to fix it.

        ### 🔍 What This Shows You:
        - **Value distribution analysis** - See if your heatmap has variation
        - **Multiple visualization methods** - Different ways to display the same data
        - **Enhancement techniques** - Force better contrast and visibility
        - **Test patterns** - Compare with known good patterns

        ### 🎯 Common Issues:
        - **Low variance** - All values are nearly the same
        - **Poor normalization** - Values compressed into narrow range
        - **Uniform attention** - Model doesn't focus on specific areas

        ### 🛠️ Solutions:
        - **Aggressive enhancement** - Force contrast stretching
        - **Histogram equalization** - Spread values evenly
        - **Artificial peaks** - Enhance high-attention areas

        **Try the test patterns above, then upload your image! 👆**
        """)

    st.markdown("---")
    st.warning("⚠️ **Medical Disclaimer:** This AI system is for educational and research purposes only. It should not be used for actual medical diagnosis. Always consult qualified healthcare professionals for medical decisions.")


if __name__ == "__main__":
    # Render the app when this file is executed as a script.
    main()