arshenoy commited on
Commit
2fcec1e
·
verified ·
1 Parent(s): 92dc792

Update cerebAI.py

Browse files
Files changed (1) hide show
  1. cerebAI.py +88 -55
cerebAI.py CHANGED
@@ -11,34 +11,40 @@ import albumentations as A
11
  from albumentations.pytorch import ToTensorV2
12
  import os
13
  import requests
14
- import pydicom # REQUIRED FOR DICOM SUPPORT
15
  import io
16
- import gc # For memory management
17
 
18
  # --- CONFIGURATION ---
19
  HF_MODEL_URL = "https://huggingface.co/arshenoy/cerebAI-stroke-model/resolve/main/best_model.pth"
20
  DOWNLOAD_MODEL_PATH = "best_model_cache.pth"
21
  CLASS_LABELS = ['No Stroke', 'Ischemic Stroke', 'Hemorrhagic Stroke']
22
  IMAGE_SIZE = 224
23
- DEVICE = torch.device("cpu") # For Streamlit Cloud stability
24
 
25
- # --- MODEL LOADING ---
 
 
 
 
 
 
 
 
26
  @st.cache_resource
27
  def load_model(model_url, local_path):
28
  """Downloads model from URL if not cached, and loads the weights."""
29
-
30
  if not os.path.exists(local_path):
31
  st.info(f"Model not found locally. Downloading from remote repository...")
32
  try:
33
  response = requests.get(model_url, stream=True)
34
  response.raise_for_status()
35
-
36
  with open(local_path, "wb") as f:
37
  for chunk in response.iter_content(chunk_size=8192):
38
  f.write(chunk)
39
  st.success("Model download complete!")
40
- except Exception as e:
41
- st.error(f"FATAL ERROR: Could not download model. Check the URL. Error: {e}")
42
  return None
43
 
44
  try:
@@ -48,11 +54,11 @@ def load_model(model_url, local_path):
48
  model.to(DEVICE)
49
  model.eval()
50
  return model
51
- except Exception as e:
52
- st.error(f"Failed to load model weights from cache. Error: {e}")
53
  return None
54
 
55
- # --- HELPER FUNCTIONS ---
56
 
57
  def denormalize_image(tensor: torch.Tensor) -> np.ndarray:
58
  """Denormalizes a PyTorch tensor for matplotlib visualization."""
@@ -69,24 +75,19 @@ def denormalize_image(tensor: torch.Tensor) -> np.ndarray:
69
  def preprocess_image(image_bytes: bytes, file_name: str) -> Tuple[Optional[torch.Tensor], Optional[np.ndarray]]:
70
  """Loads, processes, and normalizes image, handling DICOM or JPG/PNG."""
71
 
72
- # 1. READ IMAGE DATA (Handles DICOM vs Standard formats)
73
  if file_name.lower().endswith(('.dcm', '.dicom')):
74
  try:
75
  dcm = pydicom.dcmread(io.BytesIO(image_bytes))
76
-
77
- # FIX: Convert to Hounsfield Units (HU)
78
  pixel_array = dcm.pixel_array.astype(np.int16)
79
  slope = dcm.RescaleSlope
80
  intercept = dcm.RescaleIntercept
81
  pixel_array = pixel_array * slope + intercept
82
 
83
- # Apply Standard Brain Window (-100 HU to 150 HU)
84
  window_center = 40
85
  window_width = 150
86
  min_hu = window_center - (window_width / 2)
87
  max_hu = window_center + (window_width / 2)
88
 
89
- # Apply the windowing transformation and scale to 0-255
90
  pixel_array[pixel_array < min_hu] = min_hu
91
  pixel_array[pixel_array > max_hu] = max_hu
92
  image_grayscale = ((pixel_array - min_hu) / (max_hu - min_hu) * 255).astype(np.uint8)
@@ -94,13 +95,10 @@ def preprocess_image(image_bytes: bytes, file_name: str) -> Tuple[Optional[torch
94
  except Exception:
95
  return None, None
96
  else:
97
- # Read standard image (PNG/JPG)
98
  image_grayscale = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_GRAYSCALE)
99
  if image_grayscale is None: return None, None
100
 
101
- # 2. STANDARD PREPROCESSING
102
  image_rgb = cv2.cvtColor(cv2.resize(image_grayscale, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_GRAY2RGB)
103
-
104
  image_norm = (image_rgb.astype(np.float32) / 255.0 - 0.5) / 0.5
105
  input_tensor = torch.tensor(image_norm, dtype=torch.float).permute(2, 0, 1).unsqueeze(0)
106
 
@@ -147,6 +145,46 @@ def plot_heatmap_and_original(original_image: np.ndarray, heatmap: np.ndarray, p
147
  plt.tight_layout()
148
  return fig
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  # ==============================================================================
151
  # -------------------- STREAMLIT FRONTEND --------------------
152
  # ==============================================================================
@@ -155,20 +193,20 @@ st.set_page_config(page_title="CerebAI: Stroke Prediction Dashboard", layout="wi
155
  st.title("CerebAI: AI-Powered Stroke Detection")
156
  st.markdown("---")
157
 
158
- # FIX: Load the model using the download mechanism
159
  model = load_model(HF_MODEL_URL, DOWNLOAD_MODEL_PATH)
160
 
161
  if model is not None:
162
- # --- INTERACTIVE CONTROLS (Sidebar or Main Area) ---
163
  st.markdown("### Analysis Controls")
164
 
165
  n_steps_slider = st.slider(
166
  'Integration Steps (Affects Accuracy & Speed)',
167
  min_value=5,
168
  max_value=50,
169
- value=20,
170
  step=5,
171
- help="Higher steps (up to 50) provide a smoother, more accurate heatmap but use more CPU."
 
172
  )
173
  st.markdown("---")
174
 
@@ -177,51 +215,46 @@ if model is not None:
177
  st.markdown("### Upload CT Scan Image")
178
  uploaded_file = st.file_uploader(
179
  "Choose a Dicom, PNG, JPG, or JPEG file",
180
- type=["dcm", "dicom", "png", "jpg", "jpeg"]
 
181
  )
182
-
 
183
  if uploaded_file is not None:
184
- image_bytes = uploaded_file.read()
185
- file_name = uploaded_file.name
186
 
187
- # 1. PROCESS IMAGE FIRST (Defines original_image_rgb)
188
- input_tensor, original_image_rgb = preprocess_image(image_bytes, file_name)
 
 
 
 
 
189
 
190
- # --- DISPLAY AND RESULTS LAYOUT ---
191
  col1, col2 = st.columns(2)
192
 
193
  with col1:
194
  st.subheader("Uploaded Image")
195
- # Display the processed NumPy array
196
- st.image(original_image_rgb, use_container_width=True, caption=file_name)
197
-
198
- # Run Prediction and Attribution
199
- if input_tensor is not None:
200
- # Predict
201
- with torch.no_grad():
202
- output = model(input_tensor)
203
- probabilities = torch.softmax(output, dim=1).squeeze(0).cpu().numpy()
204
- predicted_class_idx = np.argmax(probabilities)
205
-
206
- predicted_label = CLASS_LABELS[predicted_class_idx]
207
- confidence_score = probabilities[predicted_class_idx]
208
 
209
- # Generate Attribution
210
- heatmap = generate_attribution(model, input_tensor, predicted_class_idx, n_steps=n_steps_slider)
 
 
 
 
 
 
 
211
 
212
- # CRITICAL MEMORY MANAGEMENT
213
- del input_tensor
214
- del output
215
- gc.collect()
216
-
217
  with col2:
218
  st.subheader("Prediction Summary")
219
 
220
  st.metric(
221
  label="Diagnosis",
222
- value=predicted_label,
223
- delta=f"{confidence_score*100:.2f}% Confidence",
224
- delta_color='normal'
225
  )
226
 
227
  st.markdown("---")
@@ -229,7 +262,7 @@ if model is not None:
229
 
230
  prob_data = {
231
  'Class': CLASS_LABELS,
232
- 'Confidence': [f"{p:.4f}" for p in probabilities]
233
  }
234
  st.dataframe(prob_data, hide_index=True, use_container_width=True)
235
 
@@ -237,7 +270,7 @@ if model is not None:
237
  st.markdown("---")
238
  st.subheader("Model Interpretation (Integrated Gradients)")
239
 
240
- fig = plot_heatmap_and_original(original_image_rgb, heatmap, predicted_label)
241
  st.pyplot(fig, clear_figure=True, use_container_width=True)
242
 
243
  st.success("Analysis Complete: The heatmap highlights the regions most critical to the diagnosis.")
 
11
  from albumentations.pytorch import ToTensorV2
12
  import os
13
  import requests
14
+ import pydicom
15
  import io
16
+ import gc
17
 
18
  # --- CONFIGURATION ---
19
  HF_MODEL_URL = "https://huggingface.co/arshenoy/cerebAI-stroke-model/resolve/main/best_model.pth"
20
  DOWNLOAD_MODEL_PATH = "best_model_cache.pth"
21
  CLASS_LABELS = ['No Stroke', 'Ischemic Stroke', 'Hemorrhagic Stroke']
22
  IMAGE_SIZE = 224
23
+ DEVICE = torch.device("cpu")
24
 
25
+ # --- INITIALIZE SESSION STATE (CRITICAL FOR STABILITY) ---
26
+ if 'results_ready' not in st.session_state:
27
+ st.session_state.results_ready = False
28
+ st.session_state.results = {}
29
+ st.session_state.input_bytes = None
30
+ st.session_state.file_name = None
31
+
32
+
33
+ # --- MODEL LOADING (UNCHANGED) ---
34
  @st.cache_resource
35
  def load_model(model_url, local_path):
36
  """Downloads model from URL if not cached, and loads the weights."""
 
37
  if not os.path.exists(local_path):
38
  st.info(f"Model not found locally. Downloading from remote repository...")
39
  try:
40
  response = requests.get(model_url, stream=True)
41
  response.raise_for_status()
 
42
  with open(local_path, "wb") as f:
43
  for chunk in response.iter_content(chunk_size=8192):
44
  f.write(chunk)
45
  st.success("Model download complete!")
46
+ except Exception:
47
+ st.error(f"FATAL ERROR: Could not download model.")
48
  return None
49
 
50
  try:
 
54
  model.to(DEVICE)
55
  model.eval()
56
  return model
57
+ except Exception:
58
+ st.error(f"Failed to load model weights.")
59
  return None
60
 
61
+ # --- CORE LOGIC FUNCTIONS ---
62
 
63
  def denormalize_image(tensor: torch.Tensor) -> np.ndarray:
64
  """Denormalizes a PyTorch tensor for matplotlib visualization."""
 
75
  def preprocess_image(image_bytes: bytes, file_name: str) -> Tuple[Optional[torch.Tensor], Optional[np.ndarray]]:
76
  """Loads, processes, and normalizes image, handling DICOM or JPG/PNG."""
77
 
 
78
  if file_name.lower().endswith(('.dcm', '.dicom')):
79
  try:
80
  dcm = pydicom.dcmread(io.BytesIO(image_bytes))
 
 
81
  pixel_array = dcm.pixel_array.astype(np.int16)
82
  slope = dcm.RescaleSlope
83
  intercept = dcm.RescaleIntercept
84
  pixel_array = pixel_array * slope + intercept
85
 
 
86
  window_center = 40
87
  window_width = 150
88
  min_hu = window_center - (window_width / 2)
89
  max_hu = window_center + (window_width / 2)
90
 
 
91
  pixel_array[pixel_array < min_hu] = min_hu
92
  pixel_array[pixel_array > max_hu] = max_hu
93
  image_grayscale = ((pixel_array - min_hu) / (max_hu - min_hu) * 255).astype(np.uint8)
 
95
  except Exception:
96
  return None, None
97
  else:
 
98
  image_grayscale = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_GRAYSCALE)
99
  if image_grayscale is None: return None, None
100
 
 
101
  image_rgb = cv2.cvtColor(cv2.resize(image_grayscale, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_GRAY2RGB)
 
102
  image_norm = (image_rgb.astype(np.float32) / 255.0 - 0.5) / 0.5
103
  input_tensor = torch.tensor(image_norm, dtype=torch.float).permute(2, 0, 1).unsqueeze(0)
104
 
 
145
  plt.tight_layout()
146
  return fig
147
 
148
+ # -------------------- CORE DIAGNOSIS FUNCTION (TO BE CALLED BY BUTTON) --------------------
149
+
150
+ def perform_full_analysis(model, image_bytes, file_name, n_steps_slider):
151
+ """Function called by the button to perform heavy computation and update state."""
152
+
153
+ # 1. PREPARE INPUTS (The Jitter-Causing step)
154
+ input_tensor, original_image_rgb = preprocess_image(image_bytes, file_name)
155
+
156
+ if input_tensor is None:
157
+ st.error("Could not process file. Ensure it is a valid DICOM/PNG/JPG.")
158
+ st.session_state.results_ready = False
159
+ return
160
+
161
+ with st.spinner(f'Calculating Integrated Gradients ({n_steps_slider} steps)...'):
162
+
163
+ # 2. PREDICT & GENERATE ATTRIBUTION
164
+ with torch.no_grad():
165
+ output = model(input_tensor)
166
+ probabilities = torch.softmax(output, dim=1).squeeze(0).cpu().numpy()
167
+ predicted_class_idx = np.argmax(probabilities)
168
+
169
+ heatmap = generate_attribution(model, input_tensor, predicted_class_idx, n_steps=n_steps_slider)
170
+
171
+ # 3. CRITICAL MEMORY MANAGEMENT
172
+ del input_tensor
173
+ del output
174
+ gc.collect()
175
+
176
+ # 4. STORE FINAL RESULTS TO STATE AND RERUN
177
+ st.session_state.results = {
178
+ 'label': CLASS_LABELS[predicted_class_idx],
179
+ 'confidence': probabilities[predicted_class_idx],
180
+ 'probabilities': probabilities,
181
+ 'image_rgb': original_image_rgb,
182
+ 'heatmap': heatmap
183
+ }
184
+ st.session_state.results_ready = True
185
+ st.rerun()
186
+
187
+
188
  # ==============================================================================
189
  # -------------------- STREAMLIT FRONTEND --------------------
190
  # ==============================================================================
 
193
  st.title("CerebAI: AI-Powered Stroke Detection")
194
  st.markdown("---")
195
 
 
196
  model = load_model(HF_MODEL_URL, DOWNLOAD_MODEL_PATH)
197
 
198
  if model is not None:
199
+ # --- INTERACTIVE CONTROLS ---
200
  st.markdown("### Analysis Controls")
201
 
202
  n_steps_slider = st.slider(
203
  'Integration Steps (Affects Accuracy & Speed)',
204
  min_value=5,
205
  max_value=50,
206
+ value=10,
207
  step=5,
208
+ key="n_steps_slider_key",
209
+ help="Lower steps (5-15) provide a smoother, more accurate heatmap but use more CPU."
210
  )
211
  st.markdown("---")
212
 
 
215
  st.markdown("### Upload CT Scan Image")
216
  uploaded_file = st.file_uploader(
217
  "Choose a Dicom, PNG, JPG, or JPEG file",
218
+ type=["dcm", "dicom", "png", "jpg", "jpeg"],
219
+ key="file_uploader"
220
  )
221
+
222
+ # --- UI FLOW MANAGEMENT ---
223
  if uploaded_file is not None:
 
 
224
 
225
+ # 1. READ AND CACHE INPUT BYTES (Fastest Operation)
226
+ st.session_state.input_bytes = uploaded_file.read()
227
+ st.session_state.file_name = uploaded_file.name
228
+
229
+ # Preprocess image for DISPLAY ONLY (fast operation, no lag)
230
+ # This part runs quickly and stabilizes the top half of the screen
231
+ _, original_image_rgb = preprocess_image(st.session_state.input_bytes, st.session_state.file_name)
232
 
233
+ # --- DISPLAY UPLOADED IMAGE AND BUTTON ---
234
  col1, col2 = st.columns(2)
235
 
236
  with col1:
237
  st.subheader("Uploaded Image")
238
+ st.image(original_image_rgb, use_container_width=True, caption=st.session_state.file_name)
 
 
 
 
 
 
 
 
 
 
 
 
239
 
240
+ # --- THE GATED BUTTON (CRITICAL) ---
241
+ st.markdown("---")
242
+ if st.button("▶️ RUN FULL DIAGNOSIS & ATTRIBUTION", type="primary", use_container_width=True, key="run_button"):
243
+ # Call the heavy analysis function
244
+ perform_full_analysis(model, st.session_state.input_bytes, st.session_state.file_name, n_steps_slider)
245
+
246
+ # --- RESULTS DISPLAY (STABLE) ---
247
+ if st.session_state.results_ready:
248
+ data = st.session_state.results
249
 
 
 
 
 
 
250
  with col2:
251
  st.subheader("Prediction Summary")
252
 
253
  st.metric(
254
  label="Diagnosis",
255
+ value=data['label'],
256
+ delta=f"{data['confidence']*100:.2f}% Confidence",
257
+ delta_color='normal'
258
  )
259
 
260
  st.markdown("---")
 
262
 
263
  prob_data = {
264
  'Class': CLASS_LABELS,
265
+ 'Confidence': [f"{p:.4f}" for p in data['probabilities']]
266
  }
267
  st.dataframe(prob_data, hide_index=True, use_container_width=True)
268
 
 
270
  st.markdown("---")
271
  st.subheader("Model Interpretation (Integrated Gradients)")
272
 
273
+ fig = plot_heatmap_and_original(data['image_rgb'], data['heatmap'], data['label'])
274
  st.pyplot(fig, clear_figure=True, use_container_width=True)
275
 
276
  st.success("Analysis Complete: The heatmap highlights the regions most critical to the diagnosis.")