arshenoy committed on
Commit
92dc792
·
verified ·
1 Parent(s): 6a7e0d5

Update cerebAI.py

Browse files
Files changed (1) hide show
  1. cerebAI.py +60 -88
cerebAI.py CHANGED
@@ -11,40 +11,34 @@ import albumentations as A
11
  from albumentations.pytorch import ToTensorV2
12
  import os
13
  import requests
14
- import pydicom
15
  import io
16
- import gc
17
 
18
  # --- CONFIGURATION ---
19
  HF_MODEL_URL = "https://huggingface.co/arshenoy/cerebAI-stroke-model/resolve/main/best_model.pth"
20
  DOWNLOAD_MODEL_PATH = "best_model_cache.pth"
21
  CLASS_LABELS = ['No Stroke', 'Ischemic Stroke', 'Hemorrhagic Stroke']
22
  IMAGE_SIZE = 224
23
- DEVICE = torch.device("cpu")
24
 
25
- # --- INITIALIZE SESSION STATE (CRITICAL FOR UI STABILITY) ---
26
- if 'results_ready' not in st.session_state:
27
- st.session_state.results_ready = False
28
- st.session_state.results = {}
29
- st.session_state.input_bytes = None
30
- st.session_state.file_name = None
31
-
32
-
33
- # --- MODEL LOADING (UNCHANGED) ---
34
  @st.cache_resource
35
  def load_model(model_url, local_path):
36
- # ... (Model download and loading logic) ...
 
37
  if not os.path.exists(local_path):
38
  st.info(f"Model not found locally. Downloading from remote repository...")
39
  try:
40
  response = requests.get(model_url, stream=True)
41
  response.raise_for_status()
 
42
  with open(local_path, "wb") as f:
43
  for chunk in response.iter_content(chunk_size=8192):
44
  f.write(chunk)
45
  st.success("Model download complete!")
46
- except Exception:
47
- st.error(f"FATAL ERROR: Could not download model.")
48
  return None
49
 
50
  try:
@@ -54,12 +48,14 @@ def load_model(model_url, local_path):
54
  model.to(DEVICE)
55
  model.eval()
56
  return model
57
- except Exception:
58
- st.error(f"Failed to load model weights.")
59
  return None
60
 
61
- # --- CORE LOGIC FUNCTIONS (UNCHANGED) ---
 
62
  def denormalize_image(tensor: torch.Tensor) -> np.ndarray:
 
63
  if tensor.ndim == 4:
64
  tensor = tensor.squeeze(0).detach()
65
  else:
@@ -71,20 +67,26 @@ def denormalize_image(tensor: torch.Tensor) -> np.ndarray:
71
  return np.clip(img, 0, 1)
72
 
73
  def preprocess_image(image_bytes: bytes, file_name: str) -> Tuple[Optional[torch.Tensor], Optional[np.ndarray]]:
74
- # ... (DICOM and preprocessing logic remains the same) ...
 
 
75
  if file_name.lower().endswith(('.dcm', '.dicom')):
76
  try:
77
  dcm = pydicom.dcmread(io.BytesIO(image_bytes))
 
 
78
  pixel_array = dcm.pixel_array.astype(np.int16)
79
  slope = dcm.RescaleSlope
80
  intercept = dcm.RescaleIntercept
81
  pixel_array = pixel_array * slope + intercept
82
 
 
83
  window_center = 40
84
  window_width = 150
85
  min_hu = window_center - (window_width / 2)
86
  max_hu = window_center + (window_width / 2)
87
 
 
88
  pixel_array[pixel_array < min_hu] = min_hu
89
  pixel_array[pixel_array > max_hu] = max_hu
90
  image_grayscale = ((pixel_array - min_hu) / (max_hu - min_hu) * 255).astype(np.uint8)
@@ -92,17 +94,20 @@ def preprocess_image(image_bytes: bytes, file_name: str) -> Tuple[Optional[torch
92
  except Exception:
93
  return None, None
94
  else:
 
95
  image_grayscale = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_GRAYSCALE)
96
  if image_grayscale is None: return None, None
97
 
 
98
  image_rgb = cv2.cvtColor(cv2.resize(image_grayscale, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_GRAY2RGB)
 
99
  image_norm = (image_rgb.astype(np.float32) / 255.0 - 0.5) / 0.5
100
  input_tensor = torch.tensor(image_norm, dtype=torch.float).permute(2, 0, 1).unsqueeze(0)
101
 
102
  return input_tensor.to(DEVICE), image_rgb
103
 
104
  def generate_attribution(model: nn.Module, input_tensor: torch.Tensor, predicted_class_idx: int, n_steps: int = 20) -> np.ndarray:
105
- # ... (Attribution logic remains the same) ...
106
  target_class_int = int(predicted_class_idx)
107
  input_tensor.requires_grad_(True)
108
 
@@ -124,7 +129,7 @@ def generate_attribution(model: nn.Module, input_tensor: torch.Tensor, predicted
124
  return attributions_ig_vis
125
 
126
  def plot_heatmap_and_original(original_image: np.ndarray, heatmap: np.ndarray, predicted_label: str):
127
- # ... (Plotting logic remains the same) ...
128
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
129
  original_image_vis = (original_image.astype(np.float32) / 255.0)
130
 
@@ -142,46 +147,6 @@ def plot_heatmap_and_original(original_image: np.ndarray, heatmap: np.ndarray, p
142
  plt.tight_layout()
143
  return fig
144
 
145
- # -------------------- CORE DIAGNOSIS FUNCTION (TO BE CALLED BY BUTTON) --------------------
146
-
147
- def perform_full_analysis(model, image_bytes, file_name, n_steps_slider):
148
- """Function called by the button to perform heavy computation and update state."""
149
-
150
- # 1. PREPARE INPUTS (The Jitter-Causing step)
151
- input_tensor, original_image_rgb = preprocess_image(image_bytes, file_name)
152
-
153
- if input_tensor is None:
154
- st.error("Could not process file. Ensure it is a valid DICOM/PNG/JPG.")
155
- st.session_state.results_ready = False
156
- return
157
-
158
- with st.spinner(f'Calculating Integrated Gradients ({n_steps_slider} steps)...'):
159
-
160
- # 2. PREDICT & GENERATE ATTRIBUTION
161
- with torch.no_grad():
162
- output = model(input_tensor)
163
- probabilities = torch.softmax(output, dim=1).squeeze(0).cpu().numpy()
164
- predicted_class_idx = np.argmax(probabilities)
165
-
166
- heatmap = generate_attribution(model, input_tensor, predicted_class_idx, n_steps=n_steps_slider)
167
-
168
- # 3. CRITICAL MEMORY MANAGEMENT
169
- del input_tensor
170
- del output
171
- gc.collect()
172
-
173
- # 4. STORE FINAL RESULTS TO STATE AND RERUN
174
- st.session_state.results = {
175
- 'label': CLASS_LABELS[predicted_class_idx],
176
- 'confidence': probabilities[predicted_class_idx],
177
- 'probabilities': probabilities,
178
- 'image_rgb': original_image_rgb,
179
- 'heatmap': heatmap
180
- }
181
- st.session_state.results_ready = True
182
- st.rerun()
183
-
184
-
185
  # ==============================================================================
186
  # -------------------- STREAMLIT FRONTEND --------------------
187
  # ==============================================================================
@@ -190,19 +155,19 @@ st.set_page_config(page_title="CerebAI: Stroke Prediction Dashboard", layout="wi
190
  st.title("CerebAI: AI-Powered Stroke Detection")
191
  st.markdown("---")
192
 
 
193
  model = load_model(HF_MODEL_URL, DOWNLOAD_MODEL_PATH)
194
 
195
  if model is not None:
196
- # --- INTERACTIVE CONTROLS ---
197
  st.markdown("### Analysis Controls")
198
 
199
  n_steps_slider = st.slider(
200
  'Integration Steps (Affects Accuracy & Speed)',
201
  min_value=5,
202
  max_value=50,
203
- value=10,
204
  step=5,
205
- key="n_steps_slider_key",
206
  help="Higher steps (up to 50) provide a smoother, more accurate heatmap but use more CPU."
207
  )
208
  st.markdown("---")
@@ -212,43 +177,50 @@ if model is not None:
212
  st.markdown("### Upload CT Scan Image")
213
  uploaded_file = st.file_uploader(
214
  "Choose a Dicom, PNG, JPG, or JPEG file",
215
- type=["dcm", "dicom", "png", "jpg", "jpeg"],
216
- key="file_uploader"
217
  )
218
-
219
- # --- UI FLOW MANAGEMENT ---
220
  if uploaded_file is not None:
 
 
221
 
222
- # Cache file bytes and name to session state to prevent corruption on rerun
223
- st.session_state.input_bytes = uploaded_file.read()
224
- st.session_state.file_name = uploaded_file.name
225
-
226
- # Preprocess image for DISPLAY ONLY (fast operation, no lag)
227
- _, original_image_rgb = preprocess_image(st.session_state.input_bytes, st.session_state.file_name)
228
 
229
- # --- DISPLAY UPLOADED IMAGE AND BUTTON ---
230
  col1, col2 = st.columns(2)
231
 
232
  with col1:
233
  st.subheader("Uploaded Image")
234
- st.image(original_image_rgb, use_container_width=True, caption=st.session_state.file_name)
 
 
 
 
 
 
 
 
 
235
 
236
- # --- THE GATED BUTTON ---
237
- # This button will trigger the heavy computation
238
- if st.button("▶️ RUN FULL DIAGNOSIS & ATTRIBUTION", type="primary", use_container_width=True, key="run_button"):
239
- perform_full_analysis(model, st.session_state.input_bytes, st.session_state.file_name, n_steps_slider)
240
-
241
- # --- RESULTS DISPLAY (STABLE) ---
242
- if st.session_state.results_ready:
243
- data = st.session_state.results
244
 
 
 
 
 
 
245
  with col2:
246
  st.subheader("Prediction Summary")
247
 
248
  st.metric(
249
  label="Diagnosis",
250
- value=data['label'],
251
- delta=f"{data['confidence']*100:.2f}% Confidence",
252
  delta_color='normal'
253
  )
254
 
@@ -257,7 +229,7 @@ if model is not None:
257
 
258
  prob_data = {
259
  'Class': CLASS_LABELS,
260
- 'Confidence': [f"{p:.4f}" for p in data['probabilities']]
261
  }
262
  st.dataframe(prob_data, hide_index=True, use_container_width=True)
263
 
@@ -265,7 +237,7 @@ if model is not None:
265
  st.markdown("---")
266
  st.subheader("Model Interpretation (Integrated Gradients)")
267
 
268
- fig = plot_heatmap_and_original(data['image_rgb'], data['heatmap'], data['label'])
269
  st.pyplot(fig, clear_figure=True, use_container_width=True)
270
 
271
  st.success("Analysis Complete: The heatmap highlights the regions most critical to the diagnosis.")
 
11
  from albumentations.pytorch import ToTensorV2
12
  import os
13
  import requests
14
+ import pydicom # REQUIRED FOR DICOM SUPPORT
15
  import io
16
+ import gc # For memory management
17
 
18
  # --- CONFIGURATION ---
19
  HF_MODEL_URL = "https://huggingface.co/arshenoy/cerebAI-stroke-model/resolve/main/best_model.pth"
20
  DOWNLOAD_MODEL_PATH = "best_model_cache.pth"
21
  CLASS_LABELS = ['No Stroke', 'Ischemic Stroke', 'Hemorrhagic Stroke']
22
  IMAGE_SIZE = 224
23
+ DEVICE = torch.device("cpu") # For Streamlit Cloud stability
24
 
25
+ # --- MODEL LOADING ---
 
 
 
 
 
 
 
 
26
  @st.cache_resource
27
  def load_model(model_url, local_path):
28
+ """Downloads model from URL if not cached, and loads the weights."""
29
+
30
  if not os.path.exists(local_path):
31
  st.info(f"Model not found locally. Downloading from remote repository...")
32
  try:
33
  response = requests.get(model_url, stream=True)
34
  response.raise_for_status()
35
+
36
  with open(local_path, "wb") as f:
37
  for chunk in response.iter_content(chunk_size=8192):
38
  f.write(chunk)
39
  st.success("Model download complete!")
40
+ except Exception as e:
41
+ st.error(f"FATAL ERROR: Could not download model. Check the URL. Error: {e}")
42
  return None
43
 
44
  try:
 
48
  model.to(DEVICE)
49
  model.eval()
50
  return model
51
+ except Exception as e:
52
+ st.error(f"Failed to load model weights from cache. Error: {e}")
53
  return None
54
 
55
+ # --- HELPER FUNCTIONS ---
56
+
57
  def denormalize_image(tensor: torch.Tensor) -> np.ndarray:
58
+ """Denormalizes a PyTorch tensor for matplotlib visualization."""
59
  if tensor.ndim == 4:
60
  tensor = tensor.squeeze(0).detach()
61
  else:
 
67
  return np.clip(img, 0, 1)
68
 
69
  def preprocess_image(image_bytes: bytes, file_name: str) -> Tuple[Optional[torch.Tensor], Optional[np.ndarray]]:
70
+ """Loads, processes, and normalizes image, handling DICOM or JPG/PNG."""
71
+
72
+ # 1. READ IMAGE DATA (Handles DICOM vs Standard formats)
73
  if file_name.lower().endswith(('.dcm', '.dicom')):
74
  try:
75
  dcm = pydicom.dcmread(io.BytesIO(image_bytes))
76
+
77
+ # FIX: Convert to Hounsfield Units (HU)
78
  pixel_array = dcm.pixel_array.astype(np.int16)
79
  slope = dcm.RescaleSlope
80
  intercept = dcm.RescaleIntercept
81
  pixel_array = pixel_array * slope + intercept
82
 
83
+ # Apply Standard Brain Window (center 40, width 150 → −35 HU to 115 HU)
84
  window_center = 40
85
  window_width = 150
86
  min_hu = window_center - (window_width / 2)
87
  max_hu = window_center + (window_width / 2)
88
 
89
+ # Apply the windowing transformation and scale to 0-255
90
  pixel_array[pixel_array < min_hu] = min_hu
91
  pixel_array[pixel_array > max_hu] = max_hu
92
  image_grayscale = ((pixel_array - min_hu) / (max_hu - min_hu) * 255).astype(np.uint8)
 
94
  except Exception:
95
  return None, None
96
  else:
97
+ # Read standard image (PNG/JPG)
98
  image_grayscale = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_GRAYSCALE)
99
  if image_grayscale is None: return None, None
100
 
101
+ # 2. STANDARD PREPROCESSING
102
  image_rgb = cv2.cvtColor(cv2.resize(image_grayscale, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_GRAY2RGB)
103
+
104
  image_norm = (image_rgb.astype(np.float32) / 255.0 - 0.5) / 0.5
105
  input_tensor = torch.tensor(image_norm, dtype=torch.float).permute(2, 0, 1).unsqueeze(0)
106
 
107
  return input_tensor.to(DEVICE), image_rgb
108
 
109
  def generate_attribution(model: nn.Module, input_tensor: torch.Tensor, predicted_class_idx: int, n_steps: int = 20) -> np.ndarray:
110
+ """Computes Integrated Gradients for the given input and class."""
111
  target_class_int = int(predicted_class_idx)
112
  input_tensor.requires_grad_(True)
113
 
 
129
  return attributions_ig_vis
130
 
131
  def plot_heatmap_and_original(original_image: np.ndarray, heatmap: np.ndarray, predicted_label: str):
132
+ """Creates a Matplotlib figure for visualization."""
133
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
134
  original_image_vis = (original_image.astype(np.float32) / 255.0)
135
 
 
147
  plt.tight_layout()
148
  return fig
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  # ==============================================================================
151
  # -------------------- STREAMLIT FRONTEND --------------------
152
  # ==============================================================================
 
155
  st.title("CerebAI: AI-Powered Stroke Detection")
156
  st.markdown("---")
157
 
158
+ # FIX: Load the model using the download mechanism
159
  model = load_model(HF_MODEL_URL, DOWNLOAD_MODEL_PATH)
160
 
161
  if model is not None:
162
+ # --- INTERACTIVE CONTROLS (Sidebar or Main Area) ---
163
  st.markdown("### Analysis Controls")
164
 
165
  n_steps_slider = st.slider(
166
  'Integration Steps (Affects Accuracy & Speed)',
167
  min_value=5,
168
  max_value=50,
169
+ value=20,
170
  step=5,
 
171
  help="Higher steps (up to 50) provide a smoother, more accurate heatmap but use more CPU."
172
  )
173
  st.markdown("---")
 
177
  st.markdown("### Upload CT Scan Image")
178
  uploaded_file = st.file_uploader(
179
  "Choose a Dicom, PNG, JPG, or JPEG file",
180
+ type=["dcm", "dicom", "png", "jpg", "jpeg"]
 
181
  )
182
+
 
183
  if uploaded_file is not None:
184
+ image_bytes = uploaded_file.read()
185
+ file_name = uploaded_file.name
186
 
187
+ # 1. PROCESS IMAGE FIRST (Defines original_image_rgb)
188
+ input_tensor, original_image_rgb = preprocess_image(image_bytes, file_name)
 
 
 
 
189
 
190
+ # --- DISPLAY AND RESULTS LAYOUT ---
191
  col1, col2 = st.columns(2)
192
 
193
  with col1:
194
  st.subheader("Uploaded Image")
195
+ # Display the processed NumPy array
196
+ st.image(original_image_rgb, use_container_width=True, caption=file_name)
197
+
198
+ # Run Prediction and Attribution
199
+ if input_tensor is not None:
200
+ # Predict
201
+ with torch.no_grad():
202
+ output = model(input_tensor)
203
+ probabilities = torch.softmax(output, dim=1).squeeze(0).cpu().numpy()
204
+ predicted_class_idx = np.argmax(probabilities)
205
 
206
+ predicted_label = CLASS_LABELS[predicted_class_idx]
207
+ confidence_score = probabilities[predicted_class_idx]
208
+
209
+ # Generate Attribution
210
+ heatmap = generate_attribution(model, input_tensor, predicted_class_idx, n_steps=n_steps_slider)
 
 
 
211
 
212
+ # CRITICAL MEMORY MANAGEMENT
213
+ del input_tensor
214
+ del output
215
+ gc.collect()
216
+
217
  with col2:
218
  st.subheader("Prediction Summary")
219
 
220
  st.metric(
221
  label="Diagnosis",
222
+ value=predicted_label,
223
+ delta=f"{confidence_score*100:.2f}% Confidence",
224
  delta_color='normal'
225
  )
226
 
 
229
 
230
  prob_data = {
231
  'Class': CLASS_LABELS,
232
+ 'Confidence': [f"{p:.4f}" for p in probabilities]
233
  }
234
  st.dataframe(prob_data, hide_index=True, use_container_width=True)
235
 
 
237
  st.markdown("---")
238
  st.subheader("Model Interpretation (Integrated Gradients)")
239
 
240
+ fig = plot_heatmap_and_original(original_image_rgb, heatmap, predicted_label)
241
  st.pyplot(fig, clear_figure=True, use_container_width=True)
242
 
243
  st.success("Analysis Complete: The heatmap highlights the regions most critical to the diagnosis.")