# stroke-classification / src/streamlit_app.py
# Commit 1ca6b73 (verified) by bakhili: "Update src/streamlit_app.py"
import os
import sys

# Matplotlib resolves its config/cache directory (MPLCONFIGDIR) when it is first
# imported, so the environment must be prepared before any matplotlib import
# below. Setting it after `import matplotlib.pyplot` has no effect.
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
os.environ['STREAMLIT_SERVER_HEADLESS'] = 'true'

import numpy as np  # noqa: E402  (imports follow the environment setup on purpose)
import streamlit as st  # noqa: E402
from PIL import Image  # noqa: E402
from scipy import ndimage  # noqa: E402

# Optional dependency. Its availability is shown in the UI's status row; no
# Streamlit command is issued here because st.set_page_config() must come first
# (many Streamlit versions raise if another st.* call precedes it).
try:
    import tensorflow as tf
    TF_AVAILABLE = True
except ImportError:
    TF_AVAILABLE = False

try:
    import matplotlib
    matplotlib.use('Agg')  # Non-interactive backend; chosen before pyplot is imported
    import matplotlib.cm as cm
    MPL_AVAILABLE = True
except ImportError:
    MPL_AVAILABLE = False

# pyplot stays a hard dependency, as before (plt is also used outside the
# MPL_AVAILABLE-guarded helpers); it is imported once the backend is set.
import matplotlib.pyplot as plt  # noqa: E402
from mpl_toolkits.mplot3d import Axes3D  # noqa: E402,F401  (registers the '3d' projection)
# Page-level configuration, followed by the CSS classes used by the HTML
# snippets rendered throughout the app.
_PAGE_CONFIG = dict(page_title="Stroke Classifier", page_icon="🧠", layout="wide")
st.set_page_config(**_PAGE_CONFIG)

_APP_CSS = """
<style>
.main-header {
font-size: 2.5rem;
color: #1f77b4;
text-align: center;
margin-bottom: 2rem;
}
.prediction-box {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 2rem;
border-radius: 1rem;
text-align: center;
margin: 1rem 0;
}
.status-box {
padding: 1rem;
border-radius: 0.5rem;
margin: 1rem 0;
}
.success { background-color: #d4edda; border: 1px solid #c3e6cb; color: #155724; }
.error { background-color: #f8d7da; border: 1px solid #f5c6cb; color: #721c24; }
.info { background-color: #d1ecf1; border: 1px solid #bee5eb; color: #0c5460; }
.warning { background-color: #fff3cd; border: 1px solid #ffeaa7; color: #856404; }
.debug { background-color: #f8f9fa; border: 1px solid #dee2e6; color: #495057; font-family: monospace; }
</style>"""
st.markdown(_APP_CSS, unsafe_allow_html=True)
# Per-session defaults, written only when the session has not been initialised yet.
if 'model_loaded' not in st.session_state:
    st.session_state.update(
        model_loaded=False,
        model=None,
        model_status="Not loaded",
    )

# Class names, indexed by the argmax of the model's prediction vector.
STROKE_LABELS = ["Hemorrhagic Stroke", "Ischemic Stroke", "No Stroke"]
def find_model_file(filename="stroke_classification_model.h5", search_root='.'):
    """Find the model file in various possible locations.

    The well-known locations for *filename* are tried first, in order. If none
    of them exists, *search_root* is scanned top-down and the first existing
    ``.h5`` file is returned. Directories and files are visited in sorted order,
    so the fallback choice is deterministic. The tree is only walked when no
    preferred location matches.

    Args:
        filename: Preferred model file name.
        search_root: Directory tree scanned as a fallback.

    Returns:
        The matching path (as a string), or ``None`` if no ``.h5`` file exists.
    """
    # './<name>' and os.path.join(os.getcwd(), <name>) refer to the same file as the
    # bare relative name, so checking them again could never produce a new match.
    for path in (filename, f"/app/{filename}", f"src/{filename}"):
        if os.path.exists(path):
            return path

    for root, dirs, files in os.walk(search_root):
        dirs.sort()  # in-place: makes os.walk descend in sorted order
        for name in sorted(files):
            path = os.path.join(root, name)
            # exists() skips dangling symlinks, which os.walk still lists.
            if name.endswith('.h5') and os.path.exists(path):
                return path
    return None
@st.cache_resource
def load_stroke_model():
    """Locate and load the Keras stroke classifier.

    Decorated with ``st.cache_resource``, so the result is computed once and
    reused across reruns and sessions. Failures are cached as well: a model file
    added after a failed load is only picked up once that cache is cleared.

    Returns:
        ``(model, status)``: the loaded ``tf.keras`` model (``None`` on failure)
        and a one-line status message.
    """
    if not TF_AVAILABLE:
        return None, "❌ TensorFlow not available"
    try:
        model_path = find_model_file()
        if model_path is None:
            import itertools

            # Only the first 10 paths are reported, so stop walking the tree once
            # they are collected instead of enumerating every file under '.'.
            all_files = (
                os.path.join(root, name)
                for root, _dirs, files in os.walk('.')
                for name in files
            )
            sample = list(itertools.islice(all_files, 10))
            return None, f"❌ Model file not found. Available files: {sample}"
        st.info(f"Found model at: {model_path}")
        # compile=False skips restoring the training configuration (optimizer,
        # loss); presumably the model is only used for inference here.
        model = tf.keras.models.load_model(model_path, compile=False)
        # NOTE(review): this success prefix appears mis-encoded (presumably meant
        # to be '✅'). Callers may match on this exact text, so it is kept verbatim.
        return model, f"βœ… Model loaded successfully from: {model_path}"
    except Exception as e:
        return None, f"❌ Model loading failed: {e}"
def analyze_heatmap_distribution(heatmap, name="Heatmap"):
    """Summarise the distribution of values in a heatmap.

    Args:
        heatmap: Array-like of numbers (any shape), or ``None``.
        name: Label copied into the result under ``'name'``.

    Returns:
        ``None`` when *heatmap* is ``None``. Otherwise a dict containing:
        the array shape; pixel counts (total / zero / non-zero); min, max,
        mean, median and std; the value range; the number of distinct values;
        the 1st/5th/25th/75th/95th/99th percentiles; and a
        ``'contrast_quality'`` label derived from the range.

    Raises:
        ValueError: if *heatmap* is empty (NumPy's min/max are undefined then).
    """
    if heatmap is None:
        return None

    arr = np.asarray(heatmap)
    values = arr.ravel()
    v_min = np.min(values)
    v_max = np.max(values)
    value_range = float(v_max - v_min)

    # A single call computes every requested percentile.
    pct_levels = (1, 5, 25, 75, 95, 99)
    pct_values = np.percentile(values, pct_levels)

    # Count "non-zero" as != 0 (not > 0), so negative values are included and
    # zero_pixels + non_zero_pixels always equals total_pixels.
    non_zero = int(np.count_nonzero(values))

    analysis = {
        'name': name,
        'shape': arr.shape,
        'total_pixels': arr.size,
        'min': float(v_min),
        'max': float(v_max),
        'mean': float(np.mean(values)),
        'median': float(np.median(values)),
        'std': float(np.std(values)),
        'range': value_range,
        'unique_values': len(np.unique(values)),
        'zero_pixels': values.size - non_zero,
        'non_zero_pixels': non_zero,
        'percentiles': {f"{p}%": float(v) for p, v in zip(pct_levels, pct_values)},
    }

    # Coarse contrast rating based purely on the spread of values.
    if value_range < 0.1:
        quality = 'Very Poor (range < 0.1)'
    elif value_range < 0.3:
        quality = 'Poor (range < 0.3)'
    elif value_range < 0.7:
        quality = 'Moderate (range < 0.7)'
    else:
        quality = 'Good (range >= 0.7)'
    analysis['contrast_quality'] = quality
    return analysis
def force_contrast_enhancement(heatmap, method='aggressive'):
    """Force better contrast in a heatmap so faint structure becomes visible.

    Args:
        heatmap: Numeric array, or ``None``.
        method: One of
            ``'aggressive'``: stretch the 1st-99th percentile range to [0, 1],
            then apply gamma 0.3;
            ``'histogram_eq'``: histogram equalisation with 256 bins over [0, 1];
            ``'adaptive'``: local z-score over a 20-pixel window, clipped to
            +/-3 and mapped to [0, 1];
            ``'artificial_peaks'``: double the top 10 % of values, scale the
            (new) bottom 10 % by 0.1, then clip to [0, 1].
            Any other value returns *heatmap* unchanged.

    Returns:
        ``(enhanced, message, original_analysis, enhanced_analysis)``, where
        the analyses come from ``analyze_heatmap_distribution``. For a ``None``
        heatmap the result is ``(None, "No heatmap provided", None, None)``.
    """
    if heatmap is None:
        # Same 4-tuple arity as the normal path, so callers can always unpack it.
        return None, "No heatmap provided", None, None

    original_analysis = analyze_heatmap_distribution(heatmap, "Original")
    # Work in float: integer input would otherwise be truncated by the filters
    # and by the in-place scaling below.
    data = np.asarray(heatmap, dtype=float)

    if method == 'aggressive':
        p1, p99 = np.percentile(data, [1, 99])
        if p99 > p1:
            enhanced = np.clip((data - p1) / (p99 - p1), 0, 1)
        else:
            # Near-constant map: nothing to stretch. Clip so the fractional
            # power below cannot turn negative values into NaN.
            enhanced = np.clip(data, 0, 1)
        enhanced = np.power(enhanced, 0.3)  # gamma < 1 lifts low values
    elif method == 'histogram_eq':
        hist, bin_edges = np.histogram(data.ravel(), bins=256, range=(0, 1))
        total = hist.sum()
        if total == 0:
            # No value lies in [0, 1]; the CDF would be 0/0.
            enhanced = np.clip(data, 0, 1)
        else:
            cdf = np.cumsum(hist) / total
            enhanced = np.interp(data.ravel(), bin_edges[:-1], cdf).reshape(data.shape)
    elif method == 'adaptive':
        window = 20
        local_mean = ndimage.uniform_filter(data, size=window)
        # Local std from E[x^2] - E[x]^2 over the same window. This matches
        # generic_filter(data, np.std, size=window) but avoids one Python call
        # per pixel. Tiny negative variances from float cancellation are clamped.
        local_sq_mean = ndimage.uniform_filter(data * data, size=window)
        local_std = np.sqrt(np.maximum(local_sq_mean - local_mean ** 2, 0.0))
        enhanced = (data - local_mean) / (local_std + 1e-8)
        enhanced = np.clip(enhanced, -3, 3)  # clip outliers
        enhanced = (enhanced + 3) / 6  # map [-3, 3] to [0, 1]
    elif method == 'artificial_peaks':
        enhanced = data.copy()  # never modify the caller's array
        threshold = np.percentile(enhanced, 90)
        enhanced[enhanced >= threshold] *= 2
        threshold_low = np.percentile(enhanced, 10)
        enhanced[enhanced <= threshold_low] *= 0.1
        enhanced = np.clip(enhanced, 0, 1)
    else:
        enhanced = heatmap

    enhanced_analysis = analyze_heatmap_distribution(enhanced, f"Enhanced ({method})")
    return enhanced, f"Enhanced using {method}", original_analysis, enhanced_analysis
def create_diagnostic_heatmap_visualization(heatmap, title="Heatmap Analysis"):
    """Create a comprehensive 2x3 diagnostic figure for a 2-D heatmap.

    Panels:
        - 'hot' and 'viridis' renderings on a fixed [0, 1] scale;
        - an auto-scaled 'RdYlBu_r' rendering;
        - a value histogram;
        - a down-sampled 3-D surface;
        - a statistics summary.

    Returns:
        A Matplotlib ``Figure`` (the caller is responsible for closing it), or
        ``None`` when Matplotlib is unavailable or *heatmap* is ``None``.
    """
    if not MPL_AVAILABLE or heatmap is None:
        return None

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    im1 = axes[0, 0].imshow(heatmap, cmap='hot', vmin=0, vmax=1)
    axes[0, 0].set_title(f"{title} - Hot Colormap")
    fig.colorbar(im1, ax=axes[0, 0])

    im2 = axes[0, 1].imshow(heatmap, cmap='viridis', vmin=0, vmax=1)
    axes[0, 1].set_title(f"{title} - Viridis Colormap")
    fig.colorbar(im2, ax=axes[0, 1])

    # Auto-scaled: stretches whatever range the data actually spans.
    im3 = axes[0, 2].imshow(heatmap, cmap='RdYlBu_r', vmin=np.min(heatmap), vmax=np.max(heatmap))
    axes[0, 2].set_title(f"{title} - Auto-scaled")
    fig.colorbar(im3, ax=axes[0, 2])

    axes[1, 0].hist(np.ravel(heatmap), bins=50, alpha=0.7, color='blue')
    axes[1, 0].set_title("Value Distribution")
    axes[1, 0].set_xlabel("Heatmap Value")
    axes[1, 0].set_ylabel("Frequency")

    # Replace the 2-D axes in this slot with a 3-D one. Without remove(), the
    # empty 2-D frame stays drawn underneath the surface plot.
    axes[1, 1].remove()
    ax_3d = fig.add_subplot(2, 3, 5, projection='3d')
    # About 28 samples along the longest side (stride 8 for 224x224); never 0.
    stride = max(1, max(heatmap.shape) // 28)
    grid_y, grid_x = np.mgrid[0:heatmap.shape[0]:stride, 0:heatmap.shape[1]:stride]
    ax_3d.plot_surface(grid_x, grid_y, heatmap[::stride, ::stride], cmap='hot', alpha=0.8)
    ax_3d.set_title("3D Surface View")

    analysis = analyze_heatmap_distribution(heatmap)
    pct = analysis['percentiles']
    stats_text = "\n".join([
        f"Shape: {analysis['shape']}",
        f"Range: {analysis['range']:.4f}",
        f"Mean: {analysis['mean']:.4f}",
        f"Std: {analysis['std']:.4f}",
        f"Unique values: {analysis['unique_values']}",
        f"Contrast: {analysis['contrast_quality']}",
        "Percentiles:",
        f"  1%: {pct['1%']:.4f}",
        f"  25%: {pct['25%']:.4f}",
        f"  75%: {pct['75%']:.4f}",
        f"  99%: {pct['99%']:.4f}",
    ])
    axes[1, 2].text(0.1, 0.9, stats_text, transform=axes[1, 2].transAxes,
                    fontsize=10, verticalalignment='top', fontfamily='monospace')
    axes[1, 2].set_title("Statistics")
    axes[1, 2].axis('off')

    fig.tight_layout()
    return fig
def create_multiple_enhancement_comparison(heatmap):
    """Plot a heatmap next to each contrast-enhancement method.

    Layout (2x3 grid):
        - top-left: the original map;
        - next four slots: the 'aggressive', 'histogram_eq', 'adaptive' and
          'artificial_peaks' variants, all on a fixed [0, 1] 'hot' scale;
        - bottom-right: overlaid value histograms.

    Returns:
        A Matplotlib ``Figure``, or ``None`` when Matplotlib is unavailable or
        *heatmap* is ``None``.
    """
    if not MPL_AVAILABLE or heatmap is None:
        return None

    method_names = ('aggressive', 'histogram_eq', 'adaptive', 'artificial_peaks')
    variants = {name: force_contrast_enhancement(heatmap, name)[0] for name in method_names}

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    slots = (axes[0, 1], axes[0, 2], axes[1, 0], axes[1, 1])
    panels = [(axes[0, 0], heatmap, "Original Heatmap")]
    panels += [
        (ax, variant, f"Enhanced: {name}")
        for ax, (name, variant) in zip(slots, variants.items())
    ]
    for ax, data, caption in panels:
        image = ax.imshow(data, cmap='hot', vmin=0, vmax=1)
        ax.set_title(caption)
        fig.colorbar(image, ax=ax)

    hist_ax = axes[1, 2]
    hist_ax.hist(heatmap.flatten(), bins=30, alpha=0.5, label='Original', color='blue')
    for name, variant in variants.items():
        hist_ax.hist(variant.flatten(), bins=30, alpha=0.3, label=name)
    hist_ax.set_title("Value Distributions")
    hist_ax.legend()
    hist_ax.set_xlabel("Value")
    hist_ax.set_ylabel("Frequency")

    fig.tight_layout()
    return fig
def predict_stroke(img, model, target_size=(224, 224)):
    """Predict the stroke type from a brain-scan image.

    The image is converted to 3-channel RGB, resized to *target_size*
    ((width, height), PIL convention), scaled to [0, 1] and sent to
    ``model.predict`` as a batch of one.

    Args:
        img: A PIL image of any mode.
        model: Object exposing ``predict(batch, verbose=0)``, or ``None``.
        target_size: Input size expected by the model.

    Returns:
        ``(predictions, None)`` on success, where *predictions* is the model's
        output row for this image. Otherwise ``(None, message)``.
    """
    if model is None:
        return None, "Model not loaded"
    try:
        # Convert before resizing. RGBA/LA images would otherwise give the model
        # the wrong channel count, and palette ('P') images would be read as raw
        # palette indices. Grayscale 'L' becomes three identical channels.
        rgb = img.convert('RGB').resize(target_size)
        batch = np.asarray(rgb, dtype=np.float32)[np.newaxis] / 255.0
        predictions = model.predict(batch, verbose=0)
        return predictions[0], None
    except Exception as e:
        return None, f"Prediction error: {e}"
def create_test_heatmaps(seed=0):
    """Create 224x224 test heatmaps with known patterns, keyed by pattern name.

    Patterns (in this order):
        - ``'high_contrast'``: a square ring of 1.0 on 0.0;
        - ``'gradient'``: the product x*y of two [0, 1] ramps;
        - ``'gaussian'``: three Gaussian blobs (sigma 30), normalised to peak 1;
        - ``'low_contrast'``: N(0.5, 0.05) noise clipped to [0, 1].

    Args:
        seed: Seed for the low-contrast noise. A private generator is used, so
            the pattern is stable across Streamlit reruns and NumPy's global RNG
            is left untouched. Pass ``None`` to get fresh noise on every call.

    Returns:
        Dict mapping pattern name to a float64 array of shape (224, 224).
    """
    size = 224
    maps = {}

    ring = np.zeros((size, size))
    ring[50:150, 50:150] = 1.0
    ring[75:125, 75:125] = 0.0
    maps['high_contrast'] = ring

    ramp = np.linspace(0, 1, size)
    ramp_x, ramp_y = np.meshgrid(ramp, ramp)
    maps['gradient'] = ramp_x * ramp_y

    rows, cols = np.ogrid[:size, :size]  # computed once, not per blob
    blobs = np.zeros((size, size))
    for cx, cy in ((60, 60), (160, 160), (60, 160)):
        blobs += np.exp(-((cols - cx) ** 2 + (rows - cy) ** 2) / (2 * 30 ** 2))
    maps['gaussian'] = blobs / np.max(blobs)

    # Low contrast: the "everything looks one colour" case this app diagnoses.
    rng = np.random.default_rng(seed)
    maps['low_contrast'] = np.clip(rng.normal(0.5, 0.05, (size, size)), 0, 1)
    return maps
# Main App
def _show_figure(fig):
    """Render a Matplotlib figure in Streamlit, then close it to free its memory.

    ``None`` (returned by the figure builders when Matplotlib is unavailable)
    is ignored.
    """
    if fig is not None:
        st.pyplot(fig)
        plt.close(fig)


def _render_system_status():
    """Show one status box each for TensorFlow, Matplotlib and the model."""
    st.markdown("### 🔧 System Status")
    col1, col2, col3 = st.columns(3)
    with col1:
        if TF_AVAILABLE:
            st.markdown('<div class="status-box success">✅ TensorFlow Ready</div>', unsafe_allow_html=True)
            st.write(f"TF Version: {tf.__version__}")
        else:
            st.markdown('<div class="status-box error">❌ TensorFlow Error</div>', unsafe_allow_html=True)
    with col2:
        if MPL_AVAILABLE:
            st.markdown('<div class="status-box success">✅ Matplotlib Ready</div>', unsafe_allow_html=True)
        else:
            st.markdown('<div class="status-box error">❌ Matplotlib Error</div>', unsafe_allow_html=True)
    with col3:
        # Judge success by the model object itself rather than by searching the
        # status text for an emoji, whose encoding is fragile.
        if st.session_state.model is not None:
            st.markdown('<div class="status-box success">✅ Model Loaded</div>', unsafe_allow_html=True)
        else:
            st.markdown('<div class="status-box error">❌ Model Error</div>', unsafe_allow_html=True)
            # Show the loader's explanation (e.g. which files were found).
            st.caption(st.session_state.model_status)


def _render_test_patterns():
    """Show diagnostics and an enhancement comparison for a chosen synthetic pattern."""
    st.markdown("### 🧪 Test Heatmap Patterns")
    test_maps = create_test_heatmaps()
    col1, col2 = st.columns(2)
    with col1:
        st.write("**Test Pattern:**")
        test_pattern = st.selectbox(
            "Choose a test pattern",
            list(test_maps.keys()),
            help="Test different heatmap patterns to see how they display",
        )
        if test_pattern:
            _show_figure(create_diagnostic_heatmap_visualization(
                test_maps[test_pattern], f"Test: {test_pattern}"))
    with col2:
        st.write("**Enhancement Comparison:**")
        if test_pattern:
            _show_figure(create_multiple_enhancement_comparison(test_maps[test_pattern]))


def _render_sidebar():
    """Draw the sidebar controls.

    Returns:
        ``(uploaded_file, enhancement_method, show_diagnostics, show_comparisons)``.
    """
    with st.sidebar:
        st.header("📀 Upload Brain Scan")
        uploaded_file = st.file_uploader(
            "Choose a brain scan image...",
            type=['png', 'jpg', 'jpeg', 'bmp', 'tiff'],
            help="Upload a brain scan image for stroke classification",
        )
        st.markdown("---")
        st.header("🎨 Enhancement Options")
        enhancement_method = st.selectbox(
            "Enhancement Method",
            ['none', 'aggressive', 'histogram_eq', 'adaptive', 'artificial_peaks'],
            index=1,
            help="Choose how to enhance heatmap contrast",
        )
        show_diagnostics = st.checkbox("Show Diagnostic Analysis", value=True)
        show_comparisons = st.checkbox("Show Enhancement Comparisons", value=True)
    return uploaded_file, enhancement_method, show_diagnostics, show_comparisons


def _simulate_attention_heatmap(class_idx, confidence):
    """Build a synthetic 224x224 "attention" map for display purposes.

    This is NOT derived from the model's internals. It is a Gaussian (sigma 40)
    whose centre depends only on the predicted class, plus small fixed-seed
    noise, normalised to a peak of 1.
    """
    # 0 = hemorrhagic, 1 = ischemic; anything else (no stroke) is centred.
    centers = {0: (80, 112), 1: (150, 112)}
    center_x, center_y = centers.get(class_idx, (112, 112))
    y, x = np.ogrid[:224, :224]
    heatmap = np.exp(-((x - center_x) ** 2 + (y - center_y) ** 2) / (2 * (40 ** 2)))
    # NOTE: the confidence scaling is undone by the normalisation below, except
    # that it changes the noise level relative to the peak.
    heatmap = heatmap * (confidence / 100.0)
    # Private generator: deterministic noise without reseeding NumPy's global RNG.
    rng = np.random.default_rng(42)
    heatmap = np.maximum(heatmap + rng.normal(0, 0.02, heatmap.shape), 0)
    peak = np.max(heatmap)
    return heatmap / peak if peak > 0 else heatmap


def _render_enhancement(image, heatmap, enhancement_method):
    """Show the original vs. enhanced heatmap, an overlay on the scan, and before/after stats."""
    enhanced_heatmap, _, orig_analysis, enh_analysis = force_contrast_enhancement(heatmap, enhancement_method)
    st.write(f"**🔧 Applied Enhancement: {enhancement_method}**")

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    im1 = axes[0].imshow(heatmap, cmap='hot', vmin=0, vmax=1)
    axes[0].set_title("Original Heatmap")
    im2 = axes[1].imshow(enhanced_heatmap, cmap='hot', vmin=0, vmax=1)
    axes[1].set_title(f"Enhanced ({enhancement_method})")

    scan = np.array(image.resize((224, 224)))
    # A single-channel scan yields a 2-D array, which imshow would otherwise
    # render through its default (false-colour) colormap.
    axes[2].imshow(scan, cmap='gray' if scan.ndim == 2 else None)
    im3 = axes[2].imshow(enhanced_heatmap, cmap='hot', alpha=0.6, vmin=0, vmax=1)
    axes[2].set_title("Enhanced Overlay")

    for ax, im in zip(axes, (im1, im2, im3)):
        ax.axis('off')
        fig.colorbar(im, ax=ax)
    fig.tight_layout()
    st.pyplot(fig)
    plt.close(fig)

    col1, col2 = st.columns(2)
    for col, heading, stats in ((col1, "**Original Stats:**", orig_analysis),
                                (col2, "**Enhanced Stats:**", enh_analysis)):
        with col:
            st.write(heading)
            st.write(f"Range: {stats['range']:.4f}")
            st.write(f"Std: {stats['std']:.4f}")
            st.write(f"Contrast: {stats['contrast_quality']}")


def _render_uploaded_scan(uploaded_file, enhancement_method, show_diagnostics, show_comparisons):
    """Classify the uploaded scan and render the result and heatmap panels."""
    try:
        image = Image.open(uploaded_file)
        image.load()  # decode now, so corrupt files are reported here
    except OSError as exc:  # includes PIL.UnidentifiedImageError
        st.error(f"❌ Could not read the uploaded image: {exc}")
        return

    st.subheader("📋 Classification Results")
    if st.session_state.model is None:
        st.error("❌ Model not loaded.")
        return

    with st.spinner("🔍 Analyzing brain scan..."):
        predictions, error = predict_stroke(image, st.session_state.model)
    if error:
        st.error(error)
        return

    class_idx = int(np.argmax(predictions))
    confidence = predictions[class_idx] * 100
    # Guard against a model with more outputs than known labels.
    predicted_class = STROKE_LABELS[class_idx] if class_idx < len(STROKE_LABELS) else f"Class {class_idx}"
    st.markdown(
        f'<div class="prediction-box"><h2>{predicted_class}</h2>'
        f'<h3>Confidence: {confidence:.1f}%</h3></div>',
        unsafe_allow_html=True,
    )

    st.subheader("🎯 Simulated Attention Analysis")
    heatmap = _simulate_attention_heatmap(class_idx, confidence)

    if show_diagnostics:
        st.write("**📊 Heatmap Diagnostic Analysis:**")
        # Labelled as simulated: the map is synthetic, not the model's attention.
        _show_figure(create_diagnostic_heatmap_visualization(heatmap, "Simulated Attention"))

    if show_comparisons:
        st.write("**🎨 Enhancement Method Comparison:**")
        _show_figure(create_multiple_enhancement_comparison(heatmap))

    if enhancement_method != 'none':
        _render_enhancement(image, heatmap, enhancement_method)


def _render_welcome():
    """Explain the tool when no image has been uploaded yet."""
    import textwrap

    st.markdown(textwrap.dedent("""
        ## 👋 Welcome to the Heatmap Diagnostic System

        This system helps you understand **why your heatmaps appear as one color** and how to fix it.

        ### 🔍 What This Shows You:
        - **Value distribution analysis** - See if your heatmap has variation
        - **Multiple visualization methods** - Different ways to display the same data
        - **Enhancement techniques** - Force better contrast and visibility
        - **Test patterns** - Compare with known good patterns

        ### 🎯 Common Issues:
        - **Low variance** - All values are nearly the same
        - **Poor normalization** - Values compressed into narrow range
        - **Uniform attention** - Model doesn't focus on specific areas

        ### 🛠️ Solutions:
        - **Aggressive enhancement** - Force contrast stretching
        - **Histogram equalization** - Spread values evenly
        - **Artificial peaks** - Enhance high-attention areas

        **Try the test patterns above, then upload your image! 👆**
        """))


def main():
    """Render the whole app for one Streamlit script run."""
    st.markdown('<h1 class="main-header">🧠 Heatmap Diagnostic System</h1>', unsafe_allow_html=True)

    # Load once per session; load_stroke_model itself is also cached by Streamlit.
    if not st.session_state.model_loaded:
        with st.spinner("Loading AI model..."):
            st.session_state.model, st.session_state.model_status = load_stroke_model()
        st.session_state.model_loaded = True

    _render_system_status()
    _render_test_patterns()

    uploaded_file, enhancement_method, show_diagnostics, show_comparisons = _render_sidebar()
    if uploaded_file is not None:
        _render_uploaded_scan(uploaded_file, enhancement_method, show_diagnostics, show_comparisons)
    else:
        _render_welcome()

    # Medical disclaimer
    st.markdown("---")
    st.warning("⚠️ **Medical Disclaimer:** This AI system is for educational and research purposes only. It should not be used for actual medical diagnosis. Always consult qualified healthcare professionals for medical decisions.")


if __name__ == "__main__":
    main()