arshenoy committed on
Commit
4a43619
·
verified ·
1 Parent(s): 70c5c31

Update cerebAI.py

Browse files
Files changed (1) hide show
  1. cerebAI.py +56 -82
cerebAI.py CHANGED
@@ -11,37 +11,34 @@ import albumentations as A
11
  from albumentations.pytorch import ToTensorV2
12
  import os
13
  import requests
14
- import pydicom
15
  import io
16
- import gc
17
 
18
- # --- CONFIGURATION & SESSION STATE ---
19
  HF_MODEL_URL = "https://huggingface.co/arshenoy/cerebAI-stroke-model/resolve/main/best_model.pth"
20
  DOWNLOAD_MODEL_PATH = "best_model_cache.pth"
21
  CLASS_LABELS = ['No Stroke', 'Ischemic Stroke', 'Hemorrhagic Stroke']
22
  IMAGE_SIZE = 224
23
- DEVICE = torch.device("cpu")
24
-
25
- # CRITICAL: Initialize Session State for UI stability
26
- if 'results_ready' not in st.session_state:
27
- st.session_state.results_ready = False
28
- st.session_state.results = {}
29
 
30
  # --- MODEL LOADING ---
31
  @st.cache_resource
32
  def load_model(model_url, local_path):
33
- # ... (Model loading logic remains the same) ...
 
34
  if not os.path.exists(local_path):
35
  st.info(f"Model not found locally. Downloading from remote repository...")
36
  try:
37
  response = requests.get(model_url, stream=True)
38
  response.raise_for_status()
 
39
  with open(local_path, "wb") as f:
40
  for chunk in response.iter_content(chunk_size=8192):
41
  f.write(chunk)
42
  st.success("Model download complete!")
43
- except Exception:
44
- st.error(f"FATAL ERROR: Could not download model.")
45
  return None
46
 
47
  try:
@@ -51,13 +48,14 @@ def load_model(model_url, local_path):
51
  model.to(DEVICE)
52
  model.eval()
53
  return model
54
- except Exception:
55
- st.error(f"Failed to load model weights.")
56
  return None
57
 
58
- # --- CORE HELPER FUNCTIONS (PREPROCESSING/PLOTTING) ---
59
 
60
  def denormalize_image(tensor: torch.Tensor) -> np.ndarray:
 
61
  if tensor.ndim == 4:
62
  tensor = tensor.squeeze(0).detach()
63
  else:
@@ -69,20 +67,26 @@ def denormalize_image(tensor: torch.Tensor) -> np.ndarray:
69
  return np.clip(img, 0, 1)
70
 
71
  def preprocess_image(image_bytes: bytes, file_name: str) -> Tuple[Optional[torch.Tensor], Optional[np.ndarray]]:
 
72
 
 
73
  if file_name.lower().endswith(('.dcm', '.dicom')):
74
  try:
75
  dcm = pydicom.dcmread(io.BytesIO(image_bytes))
 
 
76
  pixel_array = dcm.pixel_array.astype(np.int16)
77
  slope = dcm.RescaleSlope
78
  intercept = dcm.RescaleIntercept
79
  pixel_array = pixel_array * slope + intercept
80
 
 
81
  window_center = 40
82
  window_width = 150
83
  min_hu = window_center - (window_width / 2)
84
  max_hu = window_center + (window_width / 2)
85
 
 
86
  pixel_array[pixel_array < min_hu] = min_hu
87
  pixel_array[pixel_array > max_hu] = max_hu
88
  image_grayscale = ((pixel_array - min_hu) / (max_hu - min_hu) * 255).astype(np.uint8)
@@ -90,16 +94,20 @@ def preprocess_image(image_bytes: bytes, file_name: str) -> Tuple[Optional[torch
90
  except Exception:
91
  return None, None
92
  else:
 
93
  image_grayscale = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_GRAYSCALE)
94
  if image_grayscale is None: return None, None
95
 
 
96
  image_rgb = cv2.cvtColor(cv2.resize(image_grayscale, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_GRAY2RGB)
 
97
  image_norm = (image_rgb.astype(np.float32) / 255.0 - 0.5) / 0.5
98
  input_tensor = torch.tensor(image_norm, dtype=torch.float).permute(2, 0, 1).unsqueeze(0)
99
 
100
  return input_tensor.to(DEVICE), image_rgb
101
 
102
  def generate_attribution(model: nn.Module, input_tensor: torch.Tensor, predicted_class_idx: int, n_steps: int = 20) -> np.ndarray:
 
103
  target_class_int = int(predicted_class_idx)
104
  input_tensor.requires_grad_(True)
105
 
@@ -121,6 +129,7 @@ def generate_attribution(model: nn.Module, input_tensor: torch.Tensor, predicted
121
  return attributions_ig_vis
122
 
123
  def plot_heatmap_and_original(original_image: np.ndarray, heatmap: np.ndarray, predicted_label: str):
 
124
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
125
  original_image_vis = (original_image.astype(np.float32) / 255.0)
126
 
@@ -138,46 +147,6 @@ def plot_heatmap_and_original(original_image: np.ndarray, heatmap: np.ndarray, p
138
  plt.tight_layout()
139
  return fig
140
 
141
- # -------------------- CORE DIAGNOSIS FUNCTION (GATED LOGIC) --------------------
142
-
143
- def perform_full_analysis(model, image_bytes, file_name, n_steps_slider):
144
- """Function called by the button to perform heavy computation and update state."""
145
-
146
- # 1. PREPARE INPUTS (Run preprocessing inside the spinner)
147
- input_tensor, original_image_rgb = preprocess_image(image_bytes, file_name)
148
-
149
- if input_tensor is None:
150
- st.error("Could not process file. Ensure it is a valid DICOM/PNG/JPG.")
151
- st.session_state.results_ready = False
152
- return
153
-
154
- with st.spinner(f'Calculating Integrated Gradients ({n_steps_slider} steps)...'):
155
-
156
- # 2. PREDICT & GENERATE ATTRIBUTION
157
- with torch.no_grad():
158
- output = model(input_tensor)
159
- probabilities = torch.softmax(output, dim=1).squeeze(0).cpu().numpy()
160
- predicted_class_idx = np.argmax(probabilities)
161
-
162
- heatmap = generate_attribution(model, input_tensor, predicted_class_idx, n_steps=n_steps_slider)
163
-
164
- # 3. CRITICAL MEMORY MANAGEMENT
165
- del input_tensor
166
- del output
167
- gc.collect()
168
-
169
- # 4. STORE FINAL RESULTS TO STATE AND RERUN
170
- st.session_state.results = {
171
- 'label': CLASS_LABELS[predicted_class_idx],
172
- 'confidence': probabilities[predicted_class_idx],
173
- 'probabilities': probabilities,
174
- 'image_rgb': original_image_rgb,
175
- 'heatmap': heatmap
176
- }
177
- st.session_state.results_ready = True
178
- st.rerun()
179
-
180
-
181
  # ==============================================================================
182
  # -------------------- STREAMLIT FRONTEND --------------------
183
  # ==============================================================================
@@ -186,20 +155,20 @@ st.set_page_config(page_title="CerebAI: Stroke Prediction Dashboard", layout="wi
186
  st.title("CerebAI: AI-Powered Stroke Detection")
187
  st.markdown("---")
188
 
 
189
  model = load_model(HF_MODEL_URL, DOWNLOAD_MODEL_PATH)
190
 
191
  if model is not None:
192
- # --- INPUT CONTROLS ---
193
  st.markdown("### Analysis Controls")
194
 
195
  n_steps_slider = st.slider(
196
  'Integration Steps (Affects Accuracy & Speed)',
197
  min_value=5,
198
  max_value=50,
199
- value=10,
200
  step=5,
201
- key="n_steps_slider_key",
202
- help="Lower steps (5-15) provide a smoother, more accurate heatmap but use more CPU."
203
  )
204
  st.markdown("---")
205
 
@@ -208,46 +177,51 @@ if model is not None:
208
  st.markdown("### Upload CT Scan Image")
209
  uploaded_file = st.file_uploader(
210
  "Choose a Dicom, PNG, JPG, or JPEG file",
211
- type=["dcm", "dicom", "png", "jpg", "jpeg"],
212
- key="file_uploader",
213
- on_change=lambda: setattr(st.session_state, 'results_ready', False) # Reset results on new file upload
214
  )
215
 
216
  if uploaded_file is not None:
217
-
218
- # 1. READ AND CACHE INPUT BYTES (Fastest Operation)
219
  image_bytes = uploaded_file.read()
220
  file_name = uploaded_file.name
221
 
222
- # Preprocess image for DISPLAY ONLY (fast operation, no lag)
223
- # This part ensures the image appears instantly without waiting for the model
224
- _, original_image_rgb = preprocess_image(image_bytes, file_name)
225
 
226
- # --- DISPLAY UPLOADED IMAGE AND BUTTON ---
227
  col1, col2 = st.columns(2)
228
 
229
  with col1:
230
  st.subheader("Uploaded Image")
 
231
  st.image(original_image_rgb, use_container_width=True, caption=file_name)
 
 
 
 
 
 
 
 
232
 
233
- # --- THE GATED BUTTON (CRITICAL FOR STABILITY) ---
234
- st.markdown("---")
235
- if st.button("▶️ RUN FULL DIAGNOSIS & ATTRIBUTION", type="primary", use_container_width=True, key="run_button"):
236
- # Execute the heavy logic and save to session state
237
- perform_full_analysis(model, image_bytes, file_name, n_steps_slider)
238
-
239
- # --- RESULTS DISPLAY (STABLE) ---
240
- if st.session_state.results_ready:
241
- data = st.session_state.results
242
 
 
 
 
 
 
 
 
 
243
  with col2:
244
  st.subheader("Prediction Summary")
245
 
246
  st.metric(
247
  label="Diagnosis",
248
- value=data['label'],
249
- delta=f"{data['confidence']*100:.2f}% Confidence",
250
- delta_color='normal'
251
  )
252
 
253
  st.markdown("---")
@@ -255,7 +229,7 @@ if model is not None:
255
 
256
  prob_data = {
257
  'Class': CLASS_LABELS,
258
- 'Confidence': [f"{p:.4f}" for p in data['probabilities']]
259
  }
260
  st.dataframe(prob_data, hide_index=True, use_container_width=True)
261
 
@@ -263,7 +237,7 @@ if model is not None:
263
  st.markdown("---")
264
  st.subheader("Model Interpretation (Integrated Gradients)")
265
 
266
- fig = plot_heatmap_and_original(data['image_rgb'], data['heatmap'], data['label'])
267
  st.pyplot(fig, clear_figure=True, use_container_width=True)
268
 
269
  st.success("Analysis Complete: The heatmap highlights the regions most critical to the diagnosis.")
 
11
  from albumentations.pytorch import ToTensorV2
12
  import os
13
  import requests
14
+ import pydicom # REQUIRED FOR DICOM SUPPORT
15
  import io
16
+ import gc # For memory management
17
 
18
+ # --- CONFIGURATION ---
19
  HF_MODEL_URL = "https://huggingface.co/arshenoy/cerebAI-stroke-model/resolve/main/best_model.pth"
20
  DOWNLOAD_MODEL_PATH = "best_model_cache.pth"
21
  CLASS_LABELS = ['No Stroke', 'Ischemic Stroke', 'Hemorrhagic Stroke']
22
  IMAGE_SIZE = 224
23
+ DEVICE = torch.device("cpu") # For Streamlit Cloud stability
 
 
 
 
 
24
 
25
  # --- MODEL LOADING ---
26
  @st.cache_resource
27
  def load_model(model_url, local_path):
28
+ """Downloads model from URL if not cached, and loads the weights."""
29
+
30
  if not os.path.exists(local_path):
31
  st.info(f"Model not found locally. Downloading from remote repository...")
32
  try:
33
  response = requests.get(model_url, stream=True)
34
  response.raise_for_status()
35
+
36
  with open(local_path, "wb") as f:
37
  for chunk in response.iter_content(chunk_size=8192):
38
  f.write(chunk)
39
  st.success("Model download complete!")
40
+ except Exception as e:
41
+ st.error(f"FATAL ERROR: Could not download model. Check the URL. Error: {e}")
42
  return None
43
 
44
  try:
 
48
  model.to(DEVICE)
49
  model.eval()
50
  return model
51
+ except Exception as e:
52
+ st.error(f"Failed to load model weights from cache. Error: {e}")
53
  return None
54
 
55
+ # --- HELPER FUNCTIONS ---
56
 
57
  def denormalize_image(tensor: torch.Tensor) -> np.ndarray:
58
+ """Denormalizes a PyTorch tensor for matplotlib visualization."""
59
  if tensor.ndim == 4:
60
  tensor = tensor.squeeze(0).detach()
61
  else:
 
67
  return np.clip(img, 0, 1)
68
 
69
  def preprocess_image(image_bytes: bytes, file_name: str) -> Tuple[Optional[torch.Tensor], Optional[np.ndarray]]:
70
+ """Loads, processes, and normalizes image, handling DICOM or JPG/PNG."""
71
 
72
+ # 1. READ IMAGE DATA (Handles DICOM vs Standard formats)
73
  if file_name.lower().endswith(('.dcm', '.dicom')):
74
  try:
75
  dcm = pydicom.dcmread(io.BytesIO(image_bytes))
76
+
77
+ # FIX: Convert to Hounsfield Units (HU)
78
  pixel_array = dcm.pixel_array.astype(np.int16)
79
  slope = dcm.RescaleSlope
80
  intercept = dcm.RescaleIntercept
81
  pixel_array = pixel_array * slope + intercept
82
 
83
+ # Apply Standard Brain Window (center 40, width 150 → -35 HU to 115 HU)
84
  window_center = 40
85
  window_width = 150
86
  min_hu = window_center - (window_width / 2)
87
  max_hu = window_center + (window_width / 2)
88
 
89
+ # Apply the windowing transformation and scale to 0-255
90
  pixel_array[pixel_array < min_hu] = min_hu
91
  pixel_array[pixel_array > max_hu] = max_hu
92
  image_grayscale = ((pixel_array - min_hu) / (max_hu - min_hu) * 255).astype(np.uint8)
 
94
  except Exception:
95
  return None, None
96
  else:
97
+ # Read standard image (PNG/JPG)
98
  image_grayscale = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_GRAYSCALE)
99
  if image_grayscale is None: return None, None
100
 
101
+ # 2. STANDARD PREPROCESSING
102
  image_rgb = cv2.cvtColor(cv2.resize(image_grayscale, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_GRAY2RGB)
103
+
104
  image_norm = (image_rgb.astype(np.float32) / 255.0 - 0.5) / 0.5
105
  input_tensor = torch.tensor(image_norm, dtype=torch.float).permute(2, 0, 1).unsqueeze(0)
106
 
107
  return input_tensor.to(DEVICE), image_rgb
108
 
109
  def generate_attribution(model: nn.Module, input_tensor: torch.Tensor, predicted_class_idx: int, n_steps: int = 20) -> np.ndarray:
110
+ """Computes Integrated Gradients for the given input and class."""
111
  target_class_int = int(predicted_class_idx)
112
  input_tensor.requires_grad_(True)
113
 
 
129
  return attributions_ig_vis
130
 
131
  def plot_heatmap_and_original(original_image: np.ndarray, heatmap: np.ndarray, predicted_label: str):
132
+ """Creates a Matplotlib figure for visualization."""
133
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
134
  original_image_vis = (original_image.astype(np.float32) / 255.0)
135
 
 
147
  plt.tight_layout()
148
  return fig
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  # ==============================================================================
151
  # -------------------- STREAMLIT FRONTEND --------------------
152
  # ==============================================================================
 
155
  st.title("CerebAI: AI-Powered Stroke Detection")
156
  st.markdown("---")
157
 
158
+ # FIX: Load the model using the download mechanism
159
  model = load_model(HF_MODEL_URL, DOWNLOAD_MODEL_PATH)
160
 
161
  if model is not None:
162
+ # --- INTERACTIVE CONTROLS (Sidebar or Main Area) ---
163
  st.markdown("### Analysis Controls")
164
 
165
  n_steps_slider = st.slider(
166
  'Integration Steps (Affects Accuracy & Speed)',
167
  min_value=5,
168
  max_value=50,
169
+ value=20,
170
  step=5,
171
+ help="Higher steps (up to 50) provide a smoother, more accurate heatmap but use more CPU."
 
172
  )
173
  st.markdown("---")
174
 
 
177
  st.markdown("### Upload CT Scan Image")
178
  uploaded_file = st.file_uploader(
179
  "Choose a Dicom, PNG, JPG, or JPEG file",
180
+ type=["dcm", "dicom", "png", "jpg", "jpeg"]
 
 
181
  )
182
 
183
  if uploaded_file is not None:
 
 
184
  image_bytes = uploaded_file.read()
185
  file_name = uploaded_file.name
186
 
187
+ # 1. PROCESS IMAGE FIRST (Defines original_image_rgb)
188
+ input_tensor, original_image_rgb = preprocess_image(image_bytes, file_name)
 
189
 
190
+ # --- DISPLAY AND RESULTS LAYOUT ---
191
  col1, col2 = st.columns(2)
192
 
193
  with col1:
194
  st.subheader("Uploaded Image")
195
+ # Display the processed NumPy array
196
  st.image(original_image_rgb, use_container_width=True, caption=file_name)
197
+
198
+ # Run Prediction and Attribution
199
+ if input_tensor is not None:
200
+ # Predict
201
+ with torch.no_grad():
202
+ output = model(input_tensor)
203
+ probabilities = torch.softmax(output, dim=1).squeeze(0).cpu().numpy()
204
+ predicted_class_idx = np.argmax(probabilities)
205
 
206
+ predicted_label = CLASS_LABELS[predicted_class_idx]
207
+ confidence_score = probabilities[predicted_class_idx]
 
 
 
 
 
 
 
208
 
209
+ # Generate Attribution
210
+ heatmap = generate_attribution(model, input_tensor, predicted_class_idx, n_steps=n_steps_slider)
211
+
212
+ # CRITICAL MEMORY MANAGEMENT
213
+ del input_tensor
214
+ del output
215
+ gc.collect()
216
+
217
  with col2:
218
  st.subheader("Prediction Summary")
219
 
220
  st.metric(
221
  label="Diagnosis",
222
+ value=predicted_label,
223
+ delta=f"{confidence_score*100:.2f}% Confidence",
224
+ delta_color='normal'
225
  )
226
 
227
  st.markdown("---")
 
229
 
230
  prob_data = {
231
  'Class': CLASS_LABELS,
232
+ 'Confidence': [f"{p:.4f}" for p in probabilities]
233
  }
234
  st.dataframe(prob_data, hide_index=True, use_container_width=True)
235
 
 
237
  st.markdown("---")
238
  st.subheader("Model Interpretation (Integrated Gradients)")
239
 
240
+ fig = plot_heatmap_and_original(original_image_rgb, heatmap, predicted_label)
241
  st.pyplot(fig, clear_figure=True, use_container_width=True)
242
 
243
  st.success("Analysis Complete: The heatmap highlights the regions most critical to the diagnosis.")