arshenoy commited on
Commit
3322876
·
verified ·
1 Parent(s): a44d9db

rollback to benchmark

Browse files
Files changed (1) hide show
  1. cerebAI.py +99 -141
cerebAI.py CHANGED
@@ -9,43 +9,36 @@ from captum.attr import IntegratedGradients
9
  from typing import Tuple, Optional
10
  import albumentations as A
11
  from albumentations.pytorch import ToTensorV2
12
- import os
13
- import requests
14
- import pydicom
15
- import io
16
- import gc
17
 
18
  # --- CONFIGURATION ---
19
- HF_MODEL_URL = "https://huggingface.co/arshenoy/cerebAI-stroke-model/resolve/main/best_model.pth"
20
  DOWNLOAD_MODEL_PATH = "best_model_cache.pth"
21
  CLASS_LABELS = ['No Stroke', 'Ischemic Stroke', 'Hemorrhagic Stroke']
22
  IMAGE_SIZE = 224
23
- DEVICE = torch.device("cpu")
24
-
25
- # --- INITIALIZE SESSION STATE (NEW: for stable UI) ---
26
- if 'results_ready' not in st.session_state:
27
- st.session_state.results_ready = False
28
- st.session_state.results = {}
29
- st.session_state.file_bytes = None
30
-
31
 
32
  # --- MODEL LOADING ---
33
  @st.cache_resource
34
  def load_model(model_url, local_path):
35
- # ... (Model loading logic remains the same) ...
36
  if not os.path.exists(local_path):
37
  st.info(f"Model not found locally. Downloading from remote repository...")
38
  try:
39
  response = requests.get(model_url, stream=True)
40
- response.raise_for_status()
 
41
  with open(local_path, "wb") as f:
42
  for chunk in response.iter_content(chunk_size=8192):
43
  f.write(chunk)
44
  st.success("Model download complete!")
45
- except Exception:
46
- st.error(f"FATAL ERROR: Could not download model.")
47
  return None
48
-
49
  try:
50
  model = timm.create_model('convnext_base', pretrained=False)
51
  model.reset_classifier(num_classes=len(CLASS_LABELS))
@@ -53,19 +46,18 @@ def load_model(model_url, local_path):
53
  model.to(DEVICE)
54
  model.eval()
55
  return model
56
- except Exception:
57
- st.error(f"Failed to load model weights from cache.")
58
  return None
59
 
60
- # --- CORE HELPER FUNCTIONS (UNCHANGED) ---
61
-
62
  def denormalize_image(tensor: torch.Tensor) -> np.ndarray:
63
  """Denormalizes a PyTorch tensor for matplotlib visualization."""
64
  if tensor.ndim == 4:
65
- tensor = tensor.squeeze(0).detach()
66
  else:
67
- tensor = tensor.detach()
68
-
69
  mean, std = np.array([0.5, 0.5, 0.5]), np.array([0.5, 0.5, 0.5])
70
  img = tensor.cpu().permute(1, 2, 0).numpy()
71
  img = (img * std) + mean
@@ -73,117 +65,81 @@ def denormalize_image(tensor: torch.Tensor) -> np.ndarray:
73
 
74
  def preprocess_image(image_bytes: bytes, file_name: str) -> Tuple[Optional[torch.Tensor], Optional[np.ndarray]]:
75
  """Loads, processes, and normalizes image, handling DICOM or JPG/PNG."""
76
-
77
  if file_name.lower().endswith(('.dcm', '.dicom')):
78
  try:
79
  dcm = pydicom.dcmread(io.BytesIO(image_bytes))
 
 
80
  pixel_array = dcm.pixel_array.astype(np.int16)
81
  slope = dcm.RescaleSlope
82
  intercept = dcm.RescaleIntercept
83
  pixel_array = pixel_array * slope + intercept
84
-
85
- window_center = 40
 
86
  window_width = 150
87
  min_hu = window_center - (window_width / 2)
88
  max_hu = window_center + (window_width / 2)
89
-
 
90
  pixel_array[pixel_array < min_hu] = min_hu
91
  pixel_array[pixel_array > max_hu] = max_hu
92
  image_grayscale = ((pixel_array - min_hu) / (max_hu - min_hu) * 255).astype(np.uint8)
93
-
94
  except Exception:
95
  return None, None
96
  else:
 
97
  image_grayscale = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_GRAYSCALE)
98
- if image_grayscale is None: return None, None
99
-
 
 
100
  image_rgb = cv2.cvtColor(cv2.resize(image_grayscale, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_GRAY2RGB)
101
  image_norm = (image_rgb.astype(np.float32) / 255.0 - 0.5) / 0.5
102
  input_tensor = torch.tensor(image_norm, dtype=torch.float).permute(2, 0, 1).unsqueeze(0)
103
-
104
  return input_tensor.to(DEVICE), image_rgb
105
 
106
  def generate_attribution(model: nn.Module, input_tensor: torch.Tensor, predicted_class_idx: int, n_steps: int = 20) -> np.ndarray:
107
  """Computes Integrated Gradients for the given input and class."""
108
- target_class_int = int(predicted_class_idx)
109
- input_tensor.requires_grad_(True)
110
-
 
111
  ig = IntegratedGradients(model)
112
  baseline = torch.zeros_like(input_tensor).to(DEVICE)
113
-
114
  attributions_ig = ig.attribute(
115
  inputs=input_tensor,
116
  baselines=baseline,
117
  target=target_class_int,
118
  n_steps=n_steps
119
  )
120
-
121
  attributions_ig_vis = attributions_ig.squeeze(0).sum(dim=0).abs().cpu().detach().numpy()
122
-
123
  if attributions_ig_vis.max() > 0:
124
  attributions_ig_vis = attributions_ig_vis / attributions_ig_vis.max()
125
-
126
  return attributions_ig_vis
127
 
128
  def plot_heatmap_and_original(original_image: np.ndarray, heatmap: np.ndarray, predicted_label: str):
129
  """Creates a Matplotlib figure for visualization."""
130
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
131
- original_image_vis = (original_image.astype(np.float32) / 255.0)
132
 
 
133
  ax1.imshow(original_image_vis)
134
  ax1.set_title("Original CT Scan", fontsize=14)
135
  ax1.axis('off')
136
 
137
  ax2.imshow(original_image_vis)
138
- alpha_mask = heatmap * 0.7 + 0.3
139
-
140
  ax2.imshow(heatmap, cmap='jet', alpha=alpha_mask, vmin=0, vmax=1)
141
  ax2.set_title(f"Interpretation: {predicted_label}", fontsize=14)
142
  ax2.axis('off')
143
-
144
  plt.tight_layout()
145
  return fig
146
 
147
- # -------------------- CORE DIAGNOSIS FUNCTION (TO BE CALLED BY BUTTON) --------------------
148
-
149
- def perform_full_analysis(model, image_bytes, file_name, n_steps_slider):
150
- """Function called by the button to perform heavy computation and update state."""
151
-
152
- # 1. PREPARE INPUTS
153
- input_tensor, original_image_rgb = preprocess_image(image_bytes, file_name)
154
-
155
- if input_tensor is None:
156
- st.error("Could not process file. Ensure it is a valid DICOM/PNG/JPG.")
157
- st.session_state.results_ready = False
158
- return
159
-
160
- # 2. PREDICT
161
- with torch.no_grad():
162
- output = model(input_tensor)
163
- probabilities = torch.softmax(output, dim=1).squeeze(0).cpu().numpy()
164
- predicted_class_idx = np.argmax(probabilities)
165
-
166
- # 3. GENERATE ATTRIBUTION (HEAVY PART)
167
- heatmap = generate_attribution(model, input_tensor, predicted_class_idx, n_steps=n_steps_slider)
168
-
169
- # 4. CRITICAL MEMORY MANAGEMENT
170
- del input_tensor
171
- del output
172
- gc.collect()
173
-
174
- # 5. SAVE FINAL RESULTS TO STATE
175
- st.session_state.results = {
176
- 'label': CLASS_LABELS[predicted_class_idx],
177
- 'confidence': probabilities[predicted_class_idx],
178
- 'probabilities': probabilities,
179
- 'image_rgb': original_image_rgb,
180
- 'heatmap': heatmap
181
- }
182
- st.session_state.results_ready = True
183
- st.rerun() # Force a single, clean render of the results
184
-
185
-
186
  # ==============================================================================
 
187
  # -------------------- STREAMLIT FRONTEND --------------------
188
  # ==============================================================================
189
 
@@ -191,76 +147,78 @@ st.set_page_config(page_title="CerebAI: Stroke Prediction Dashboard", layout="wi
191
  st.title("CerebAI: AI-Powered Stroke Detection")
192
  st.markdown("---")
193
 
 
194
  model = load_model(HF_MODEL_URL, DOWNLOAD_MODEL_PATH)
195
 
196
  if model is not None:
197
- # --- INPUT CONTROLS ---
198
  st.markdown("### Analysis Controls")
199
-
200
  n_steps_slider = st.slider(
201
  'Integration Steps (Affects Accuracy & Speed)',
202
- min_value=5,
203
- max_value=50,
204
- value=10,
205
  step=5,
206
- key="n_steps_slider_key", # Added key to prevent slider reset on rerun
207
  help="Higher steps (up to 50) provide a smoother, more accurate heatmap but use more CPU."
208
  )
209
  st.markdown("---")
210
 
211
-
212
  # --- FILE UPLOAD ---
213
  st.markdown("### Upload CT Scan Image")
214
  uploaded_file = st.file_uploader(
215
- "Choose a Dicom, PNG, JPG, or JPEG file",
216
- type=["dcm", "dicom", "png", "jpg", "jpeg"],
217
- key="file_uploader" # Added key to manage file state
218
  )
219
-
220
-
221
- # --- UI FLOW MANAGEMENT ---
222
  if uploaded_file is not None:
223
  image_bytes = uploaded_file.read()
224
- file_name = uploaded_file.name
225
-
226
- # Display the image immediately (fast operation, no lag)
227
- st.subheader("Input Image")
228
- st.image(image_bytes, use_container_width=True, caption=file_name)
229
-
230
- # --- GATED BUTTON ---
231
- if st.button("▶️ RUN DIAGNOSIS & ATTRIBUTION", type="primary", use_container_width=True, key="run_button"):
232
- with st.spinner(f'Calculating Integrated Gradients ({n_steps_slider} steps)...'):
233
- # Execute the heavy logic and save to session state
234
- perform_full_analysis(model, image_bytes, file_name, n_steps_slider)
235
-
236
- # --- RESULTS DISPLAY (ONLY RUNS AFTER BUTTON CLICK) ---
237
- if st.session_state.results_ready:
238
- data = st.session_state.results
239
-
240
- st.markdown("---")
241
- st.markdown("### 2. Diagnosis and Interpretation")
242
-
243
- col_metric, col_plot = st.columns([1, 2])
244
-
245
- with col_metric:
246
- st.subheader("Prediction Summary")
247
- st.metric(
248
- label="Diagnosis",
249
- value=data['label'],
250
- delta=f"{data['confidence']*100:.2f}% Confidence",
251
- delta_color='normal'
252
- )
253
- st.markdown("---")
254
- st.subheader("Confidence Breakdown")
255
- prob_data = {
256
- 'Class': CLASS_LABELS,
257
- 'Confidence': [f"{p:.4f}" for p in data['probabilities']]
258
- }
259
- st.dataframe(prob_data, hide_index=True, use_container_width=True)
260
-
261
- with col_plot:
262
- st.subheader("Visual Explanation")
263
- fig = plot_heatmap_and_original(data['image_rgb'], data['heatmap'], data['label'])
264
- st.pyplot(fig, clear_figure=True, use_container_width=True)
265
-
266
- st.success("Analysis Complete: The heatmap highlights the regions most critical to the diagnosis.")
 
 
 
 
 
 
 
 
9
  from typing import Tuple, Optional
10
  import albumentations as A
11
  from albumentations.pytorch import ToTensorV2
12
+ import os
13
+ import requests
14
+ import pydicom # REQUIRED FOR DICOM SUPPORT
15
+ import io
16
+ import gc # For memory management
17
 
18
  # --- CONFIGURATION ---
19
+ HF_MODEL_URL = "https://huggingface.co/arshenoy/cerebAI-stroke-model/resolve/main/best_model.pth"
20
  DOWNLOAD_MODEL_PATH = "best_model_cache.pth"
21
  CLASS_LABELS = ['No Stroke', 'Ischemic Stroke', 'Hemorrhagic Stroke']
22
  IMAGE_SIZE = 224
23
+ DEVICE = torch.device("cpu") # For Streamlit Cloud stability
 
 
 
 
 
 
 
24
 
25
  # --- MODEL LOADING ---
26
  @st.cache_resource
27
  def load_model(model_url, local_path):
28
+ """Downloads model from URL if not cached, and loads the weights."""
29
  if not os.path.exists(local_path):
30
  st.info(f"Model not found locally. Downloading from remote repository...")
31
  try:
32
  response = requests.get(model_url, stream=True)
33
+ response.raise_for_status()
34
+
35
  with open(local_path, "wb") as f:
36
  for chunk in response.iter_content(chunk_size=8192):
37
  f.write(chunk)
38
  st.success("Model download complete!")
39
+ except Exception as e:
40
+ st.error(f"FATAL ERROR: Could not download model. Check the URL. Error: {e}")
41
  return None
 
42
  try:
43
  model = timm.create_model('convnext_base', pretrained=False)
44
  model.reset_classifier(num_classes=len(CLASS_LABELS))
 
46
  model.to(DEVICE)
47
  model.eval()
48
  return model
49
+ except Exception as e:
50
+ st.error(f"Failed to load model weights from cache. Error: {e}")
51
  return None
52
 
53
+ # --- HELPER FUNCTIONS ---
 
54
  def denormalize_image(tensor: torch.Tensor) -> np.ndarray:
55
  """Denormalizes a PyTorch tensor for matplotlib visualization."""
56
  if tensor.ndim == 4:
57
+ tensor = tensor.squeeze(0).detach()
58
  else:
59
+ tensor = tensor.detach()
60
+
61
  mean, std = np.array([0.5, 0.5, 0.5]), np.array([0.5, 0.5, 0.5])
62
  img = tensor.cpu().permute(1, 2, 0).numpy()
63
  img = (img * std) + mean
 
65
 
66
  def preprocess_image(image_bytes: bytes, file_name: str) -> Tuple[Optional[torch.Tensor], Optional[np.ndarray]]:
67
  """Loads, processes, and normalizes image, handling DICOM or JPG/PNG."""
68
+ # 1. READ IMAGE DATA (Handles DICOM vs Standard formats)
69
  if file_name.lower().endswith(('.dcm', '.dicom')):
70
  try:
71
  dcm = pydicom.dcmread(io.BytesIO(image_bytes))
72
+
73
+ # FIX: Convert to Hounsfield Units (HU)
74
  pixel_array = dcm.pixel_array.astype(np.int16)
75
  slope = dcm.RescaleSlope
76
  intercept = dcm.RescaleIntercept
77
  pixel_array = pixel_array * slope + intercept
78
+
79
+ # Apply Standard Brain Window (center 40 HU, width 150 HU → −35 to +115 HU)
80
+ window_center = 40
81
  window_width = 150
82
  min_hu = window_center - (window_width / 2)
83
  max_hu = window_center + (window_width / 2)
84
+
85
+ # Apply the windowing transformation and scale to 0-255
86
  pixel_array[pixel_array < min_hu] = min_hu
87
  pixel_array[pixel_array > max_hu] = max_hu
88
  image_grayscale = ((pixel_array - min_hu) / (max_hu - min_hu) * 255).astype(np.uint8)
 
89
  except Exception:
90
  return None, None
91
  else:
92
+ # Read standard image (PNG/JPG)
93
  image_grayscale = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_GRAYSCALE)
94
+ if image_grayscale is None:
95
+ return None, None
96
+
97
+ # 2. STANDARD PREPROCESSING
98
  image_rgb = cv2.cvtColor(cv2.resize(image_grayscale, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_GRAY2RGB)
99
  image_norm = (image_rgb.astype(np.float32) / 255.0 - 0.5) / 0.5
100
  input_tensor = torch.tensor(image_norm, dtype=torch.float).permute(2, 0, 1).unsqueeze(0)
 
101
  return input_tensor.to(DEVICE), image_rgb
102
 
103
  def generate_attribution(model: nn.Module, input_tensor: torch.Tensor, predicted_class_idx: int, n_steps: int = 20) -> np.ndarray:
104
  """Computes Integrated Gradients for the given input and class."""
105
+ target_class_int = int(predicted_class_idx)
106
+
107
+ input_tensor.requires_grad_(True)
108
+
109
  ig = IntegratedGradients(model)
110
  baseline = torch.zeros_like(input_tensor).to(DEVICE)
 
111
  attributions_ig = ig.attribute(
112
  inputs=input_tensor,
113
  baselines=baseline,
114
  target=target_class_int,
115
  n_steps=n_steps
116
  )
117
+
118
  attributions_ig_vis = attributions_ig.squeeze(0).sum(dim=0).abs().cpu().detach().numpy()
 
119
  if attributions_ig_vis.max() > 0:
120
  attributions_ig_vis = attributions_ig_vis / attributions_ig_vis.max()
 
121
  return attributions_ig_vis
122
 
123
  def plot_heatmap_and_original(original_image: np.ndarray, heatmap: np.ndarray, predicted_label: str):
124
  """Creates a Matplotlib figure for visualization."""
125
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
 
126
 
127
+ original_image_vis = (original_image.astype(np.float32) / 255.0)
128
  ax1.imshow(original_image_vis)
129
  ax1.set_title("Original CT Scan", fontsize=14)
130
  ax1.axis('off')
131
 
132
  ax2.imshow(original_image_vis)
133
+ alpha_mask = heatmap * 0.7 + 0.3
 
134
  ax2.imshow(heatmap, cmap='jet', alpha=alpha_mask, vmin=0, vmax=1)
135
  ax2.set_title(f"Interpretation: {predicted_label}", fontsize=14)
136
  ax2.axis('off')
137
+
138
  plt.tight_layout()
139
  return fig
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  # ==============================================================================
142
+
143
  # -------------------- STREAMLIT FRONTEND --------------------
144
  # ==============================================================================
145
 
 
147
  st.title("CerebAI: AI-Powered Stroke Detection")
148
  st.markdown("---")
149
 
150
+ # FIX: Load the model using the download mechanism
151
  model = load_model(HF_MODEL_URL, DOWNLOAD_MODEL_PATH)
152
 
153
  if model is not None:
154
+ # --- INTERACTIVE CONTROLS (Sidebar or Main Area) ---
155
  st.markdown("### Analysis Controls")
 
156
  n_steps_slider = st.slider(
157
  'Integration Steps (Affects Accuracy & Speed)',
158
+ min_value=5,
159
+ max_value=50,
160
+ value=20,
161
  step=5,
 
162
  help="Higher steps (up to 50) provide a smoother, more accurate heatmap but use more CPU."
163
  )
164
  st.markdown("---")
165
 
 
166
  # --- FILE UPLOAD ---
167
  st.markdown("### Upload CT Scan Image")
168
  uploaded_file = st.file_uploader(
169
+ "Choose a Dicom, PNG, JPG, or JPEG file",
170
+ type=["dcm", "dicom", "png", "jpg", "jpeg"]
 
171
  )
172
+
 
 
173
  if uploaded_file is not None:
174
  image_bytes = uploaded_file.read()
175
+ file_name = uploaded_file.name
176
+
177
+ # 1. PROCESS IMAGE FIRST (Defines original_image_rgb)
178
+ input_tensor, original_image_rgb = preprocess_image(image_bytes, file_name)
179
+
180
+ # --- DISPLAY AND RESULTS LAYOUT ---
181
+ col1, col2 = st.columns(2)
182
+
183
+ with col1:
184
+ st.subheader("Uploaded Image")
185
+ # Display the processed NumPy array
186
+ st.image(original_image_rgb, use_container_width=True, caption=file_name)
187
+
188
+ # Run Prediction and Attribution
189
+ if input_tensor is not None:
190
+ # Predict
191
+ with torch.no_grad():
192
+ output = model(input_tensor)
193
+ probabilities = torch.softmax(output, dim=1).squeeze(0).cpu().numpy()
194
+ predicted_class_idx = np.argmax(probabilities)
195
+
196
+ predicted_label = CLASS_LABELS[predicted_class_idx]
197
+ confidence_score = probabilities[predicted_class_idx]
198
+
199
+ # Generate Attribution
200
+ heatmap = generate_attribution(model, input_tensor, predicted_class_idx, n_steps=n_steps_slider)
201
+
202
+ # CRITICAL MEMORY MANAGEMENT
203
+ del input_tensor
204
+ del output
205
+ gc.collect()
206
+
207
+ with col2:
208
+ st.subheader("Prediction Summary")
209
+
210
+ st.metric(
211
+ label="Diagnosis",
212
+ value=predicted_label,
213
+ delta=f"{confidence_score*100:.2f}% Confidence",
214
+ delta_color='normal'
215
+ )
216
+
217
+ st.markdown("---")
218
+ st.subheader("Confidence Breakdown")
219
+
220
+ prob_data = {
221
+ 'Class': CLASS_LABELS,
222
+ 'Confidence': [f"{p:.4f}" for p in probabilities]
223
+ }
224
+ st.dataframe(prob_data, hide_index=True, use_container_width