arshenoy commited on
Commit
0ea30c3
·
0 Parent(s):

Deploying cerebAI to streamlit

Browse files
Files changed (3) hide show
  1. README.md +0 -0
  2. cerebAI.py +200 -0
  3. requirements.txt +9 -0
README.md ADDED
File without changes
cerebAI.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ import cv2
6
+ import timm
7
+ import matplotlib.pyplot as plt
8
+ from captum.attr import IntegratedGradients
9
+ import albumentations as A
10
+ from albumentations.pytorch import ToTensorV2
11
+ from typing import Tuple, Optional
12
+
13
+ # --- CONFIGURATION ---
14
+ MODEL_PATH = "best_model.pth"
15
+ CLASS_LABELS = ['No Stroke', 'Ischemic Stroke', 'Hemorrhagic Stroke']
16
+ IMAGE_SIZE = 224
17
+ # Use CPU by default for stability in free deployment, but change this locally to 'cuda' for speed!
18
+ DEVICE = torch.device("cpu")
19
+
20
+ # --- MODEL LOADING ---
21
+ @st.cache_resource
22
+ def load_model(model_path):
23
+ """Loads the model architecture and saved weights."""
24
+ try:
25
+ model = timm.create_model('convnext_base', pretrained=False)
26
+ model.reset_classifier(num_classes=len(CLASS_LABELS))
27
+ model.load_state_dict(torch.load(model_path, map_location=DEVICE))
28
+ model.to(DEVICE)
29
+ model.eval()
30
+ return model
31
+ except Exception as e:
32
+ st.error(f"Failed to load model. Check model file and path. Error: {e}")
33
+ return None
34
+
35
+ # --- HELPER FUNCTIONS ---
36
+
37
+ def denormalize_image(tensor: torch.Tensor) -> np.ndarray:
38
+ """Denormalizes a PyTorch tensor for matplotlib visualization."""
39
+ if tensor.ndim == 4:
40
+ tensor = tensor.squeeze(0)
41
+
42
+ mean, std = np.array([0.5, 0.5, 0.5]), np.array([0.5, 0.5, 0.5])
43
+ img = tensor.cpu().permute(1, 2, 0).numpy()
44
+ img = (img * std) + mean
45
+ return np.clip(img, 0, 1)
46
+
47
+ def preprocess_image(image_bytes: bytes) -> Tuple[Optional[torch.Tensor], Optional[np.ndarray]]:
48
+ """Loads, resizes, and normalizes the image for model input."""
49
+ image = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_GRAYSCALE)
50
+ if image is None: return None, None
51
+ image_rgb = cv2.cvtColor(cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_GRAY2RGB)
52
+
53
+ image_norm = (image_rgb.astype(np.float32) / 255.0 - 0.5) / 0.5
54
+ input_tensor = torch.tensor(image_norm, dtype=torch.float).permute(2, 0, 1).unsqueeze(0)
55
+
56
+ return input_tensor.to(DEVICE), image_rgb
57
+
58
+ def generate_attribution(model: nn.Module, input_tensor: torch.Tensor, predicted_class_idx: int, n_steps: int = 20) -> np.ndarray:
59
+ """Computes Integrated Gradients for the given input and class."""
60
+
61
+ # CRITICAL FIX: Captum requires standard Python int, not numpy.int64
62
+ target_class_int = int(predicted_class_idx)
63
+
64
+ # CRITICAL: Enables gradient tracking for Captum
65
+ input_tensor.requires_grad_(True)
66
+
67
+ ig = IntegratedGradients(model)
68
+ baseline = torch.zeros_like(input_tensor).to(DEVICE)
69
+
70
+ attributions_ig = ig.attribute(
71
+ inputs=input_tensor,
72
+ baselines=baseline,
73
+ target=target_class_int,
74
+ n_steps=n_steps # Using dynamic or default steps
75
+ )
76
+
77
+ # Process Attributions: Sum across color channels and normalize the heatmap
78
+ attributions_ig_vis = attributions_ig.squeeze(0).sum(dim=0).abs().cpu().detach().numpy()
79
+
80
+ if attributions_ig_vis.max() > 0:
81
+ attributions_ig_vis = attributions_ig_vis / attributions_ig_vis.max()
82
+
83
+ return attributions_ig_vis
84
+
85
+ def plot_heatmap_and_original(original_image: np.ndarray, heatmap: np.ndarray, predicted_label: str):
86
+ """Creates a Matplotlib figure for visualization."""
87
+
88
+ # Use dynamic sizing for better responsiveness
89
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
90
+
91
+ # Convert image to 0-1 range for plotting
92
+ original_image_vis = (original_image.astype(np.float32) / 255.0)
93
+
94
+ # --- Plot 1: Original Image ---
95
+ ax1.imshow(original_image_vis)
96
+ ax1.set_title("Original CT Scan", fontsize=14)
97
+ ax1.axis('off')
98
+
99
+ # --- Plot 2: Integrated Gradients ---
100
+ ax2.imshow(original_image_vis)
101
+
102
+ # Dynamic alpha mask: fades out non-contributing regions
103
+ alpha_mask = heatmap * 0.7 + 0.3
104
+
105
+ # Aesthetic Fix: Use 'jet' colormap for clinical highlight (red/yellow)
106
+ ax2.imshow(heatmap, cmap='jet', alpha=alpha_mask, vmin=0, vmax=1)
107
+ ax2.set_title(f"Interpretation: {predicted_label}", fontsize=14)
108
+ ax2.axis('off')
109
+
110
+ plt.tight_layout()
111
+ return fig
112
+
113
+ # ==============================================================================
114
+ # -------------------- STREAMLIT FRONTEND --------------------
115
+ # ==============================================================================
116
+
117
+ st.set_page_config(page_title="CerebAI: Stroke Prediction Dashboard", layout="wide")
118
+ st.title("CerebAI: AI-Powered Stroke Detection")
119
+ st.markdown("---")
120
+
121
+ # Load the model
122
+ model = load_model(MODEL_PATH)
123
+
124
+ if model is not None:
125
+ # --- INTERACTIVE CONTROLS (Sidebar or Main Area) ---
126
+ st.markdown("### Analysis Controls")
127
+
128
+ n_steps_slider = st.slider(
129
+ 'Integration Steps (Affects Accuracy & Speed)',
130
+ min_value=5,
131
+ max_value=50,
132
+ value=20, # Default to a safe, medium-speed value
133
+ step=5,
134
+ help="Higher steps (up to 50) provide a smoother, more accurate heatmap but use more CPU."
135
+ )
136
+ st.markdown("---")
137
+
138
+
139
+ # --- FILE UPLOAD ---
140
+ st.markdown("### Upload CT Scan Image")
141
+ uploaded_file = st.file_uploader(
142
+ "Choose a PNG, JPG, or JPEG file",
143
+ type=["png", "jpg", "jpeg"]
144
+ )
145
+
146
+ if uploaded_file is not None:
147
+ image_bytes = uploaded_file.read()
148
+
149
+ # --- DISPLAY AND RESULTS LAYOUT ---
150
+ col1, col2 = st.columns(2) # Retaining old columns structure for familiar look
151
+
152
+ with col1:
153
+ st.subheader("Uploaded Image")
154
+ st.image(image_bytes, use_container_width=True) # Responsive fix
155
+
156
+ # Run Prediction and Attribution
157
+ input_tensor, original_image_rgb = preprocess_image(image_bytes)
158
+
159
+ if input_tensor is not None:
160
+ # Predict
161
+ with torch.no_grad():
162
+ output = model(input_tensor)
163
+ probabilities = torch.softmax(output, dim=1).squeeze(0).cpu().numpy()
164
+ predicted_class_idx = np.argmax(probabilities)
165
+
166
+ predicted_label = CLASS_LABELS[predicted_class_idx]
167
+ confidence_score = probabilities[predicted_class_idx]
168
+
169
+ # Generate Attribution
170
+ heatmap = generate_attribution(model, input_tensor, predicted_class_idx, n_steps=n_steps_slider)
171
+
172
+ with col2:
173
+ st.subheader("Prediction Summary")
174
+
175
+ # Metric based on prediction
176
+ st.metric(
177
+ label="Diagnosis",
178
+ value=predicted_label,
179
+ delta=f"{confidence_score*100:.2f}% Confidence",
180
+ delta_color='normal' # Let Streamlit choose color
181
+ )
182
+
183
+ st.markdown("---")
184
+ st.subheader("Confidence Breakdown")
185
+
186
+ # Display probabilities in a clean, professional table
187
+ prob_data = {
188
+ 'Class': CLASS_LABELS,
189
+ 'Confidence': [f"{p:.4f}" for p in probabilities]
190
+ }
191
+ st.dataframe(prob_data, hide_index=True, use_container_width=True)
192
+
193
+ # --- PLOT INTERPRETATION ---
194
+ st.markdown("---")
195
+ st.subheader("Model Interpretation (Integrated Gradients)")
196
+
197
+ fig = plot_heatmap_and_original(original_image_rgb, heatmap, predicted_label)
198
+ st.pyplot(fig, clear_figure=True, use_container_width=True) # Responsive Plot
199
+
200
+ st.success("Analysis Complete: The heatmap highlights the regions most critical to the diagnosis.")
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ torch==2.3.0
3
+ timm==1.0.20
4
+ numpy==1.24.4
5
+ opencv-python-headless
6
+ albumentations==1.3.1
7
+ captum
8
+ scikit-learn
9
+ matplotlib