arshenoy commited on
Commit
c9190ef
·
verified ·
1 Parent(s): 2fcec1e

Update cerebAI.py

Browse files
Files changed (1) hide show
  1. cerebAI.py +26 -27
cerebAI.py CHANGED
@@ -23,17 +23,17 @@ IMAGE_SIZE = 224
23
  DEVICE = torch.device("cpu")
24
 
25
  # --- INITIALIZE SESSION STATE (CRITICAL FOR STABILITY) ---
26
- if 'results_ready' not in st.session_state:
27
- st.session_state.results_ready = False
28
  st.session_state.results = {}
29
- st.session_state.input_bytes = None
30
  st.session_state.file_name = None
 
31
 
32
 
33
- # --- MODEL LOADING (UNCHANGED) ---
34
  @st.cache_resource
35
  def load_model(model_url, local_path):
36
- """Downloads model from URL if not cached, and loads the weights."""
37
  if not os.path.exists(local_path):
38
  st.info(f"Model not found locally. Downloading from remote repository...")
39
  try:
@@ -55,13 +55,12 @@ def load_model(model_url, local_path):
55
  model.eval()
56
  return model
57
  except Exception:
58
- st.error(f"Failed to load model weights.")
59
  return None
60
 
61
- # --- CORE LOGIC FUNCTIONS ---
62
-
63
  def denormalize_image(tensor: torch.Tensor) -> np.ndarray:
64
- """Denormalizes a PyTorch tensor for matplotlib visualization."""
65
  if tensor.ndim == 4:
66
  tensor = tensor.squeeze(0).detach()
67
  else:
@@ -73,7 +72,6 @@ def denormalize_image(tensor: torch.Tensor) -> np.ndarray:
73
  return np.clip(img, 0, 1)
74
 
75
  def preprocess_image(image_bytes: bytes, file_name: str) -> Tuple[Optional[torch.Tensor], Optional[np.ndarray]]:
76
- """Loads, processes, and normalizes image, handling DICOM or JPG/PNG."""
77
 
78
  if file_name.lower().endswith(('.dcm', '.dicom')):
79
  try:
@@ -105,7 +103,6 @@ def preprocess_image(image_bytes: bytes, file_name: str) -> Tuple[Optional[torch
105
  return input_tensor.to(DEVICE), image_rgb
106
 
107
  def generate_attribution(model: nn.Module, input_tensor: torch.Tensor, predicted_class_idx: int, n_steps: int = 20) -> np.ndarray:
108
- """Computes Integrated Gradients for the given input and class."""
109
  target_class_int = int(predicted_class_idx)
110
  input_tensor.requires_grad_(True)
111
 
@@ -127,7 +124,6 @@ def generate_attribution(model: nn.Module, input_tensor: torch.Tensor, predicted
127
  return attributions_ig_vis
128
 
129
  def plot_heatmap_and_original(original_image: np.ndarray, heatmap: np.ndarray, predicted_label: str):
130
- """Creates a Matplotlib figure for visualization."""
131
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
132
  original_image_vis = (original_image.astype(np.float32) / 255.0)
133
 
@@ -145,7 +141,8 @@ def plot_heatmap_and_original(original_image: np.ndarray, heatmap: np.ndarray, p
145
  plt.tight_layout()
146
  return fig
147
 
148
- # -------------------- CORE DIAGNOSIS FUNCTION (TO BE CALLED BY BUTTON) --------------------
 
149
 
150
  def perform_full_analysis(model, image_bytes, file_name, n_steps_slider):
151
  """Function called by the button to perform heavy computation and update state."""
@@ -196,7 +193,7 @@ st.markdown("---")
196
  model = load_model(HF_MODEL_URL, DOWNLOAD_MODEL_PATH)
197
 
198
  if model is not None:
199
- # --- INTERACTIVE CONTROLS ---
200
  st.markdown("### Analysis Controls")
201
 
202
  n_steps_slider = st.slider(
@@ -206,7 +203,7 @@ if model is not None:
206
  value=10,
207
  step=5,
208
  key="n_steps_slider_key",
209
- help="Lower steps (5-15) provide a smoother, more accurate heatmap but use more CPU."
210
  )
211
  st.markdown("---")
212
 
@@ -216,32 +213,33 @@ if model is not None:
216
  uploaded_file = st.file_uploader(
217
  "Choose a Dicom, PNG, JPG, or JPEG file",
218
  type=["dcm", "dicom", "png", "jpg", "jpeg"],
219
- key="file_uploader"
 
220
  )
221
-
222
- # --- UI FLOW MANAGEMENT ---
223
  if uploaded_file is not None:
224
 
225
  # 1. READ AND CACHE INPUT BYTES (Fastest Operation)
226
- st.session_state.input_bytes = uploaded_file.read()
227
- st.session_state.file_name = uploaded_file.name
228
 
229
  # Preprocess image for DISPLAY ONLY (fast operation, no lag)
230
- # This part runs quickly and stabilizes the top half of the screen
231
- _, original_image_rgb = preprocess_image(st.session_state.input_bytes, st.session_state.file_name)
232
 
233
  # --- DISPLAY UPLOADED IMAGE AND BUTTON ---
234
  col1, col2 = st.columns(2)
235
 
236
  with col1:
237
  st.subheader("Uploaded Image")
238
- st.image(original_image_rgb, use_container_width=True, caption=st.session_state.file_name)
 
239
 
240
- # --- THE GATED BUTTON (CRITICAL) ---
241
  st.markdown("---")
242
  if st.button("▶️ RUN FULL DIAGNOSIS & ATTRIBUTION", type="primary", use_container_width=True, key="run_button"):
243
- # Call the heavy analysis function
244
- perform_full_analysis(model, st.session_state.input_bytes, st.session_state.file_name, n_steps_slider)
245
 
246
  # --- RESULTS DISPLAY (STABLE) ---
247
  if st.session_state.results_ready:
@@ -273,4 +271,5 @@ if model is not None:
273
  fig = plot_heatmap_and_original(data['image_rgb'], data['heatmap'], data['label'])
274
  st.pyplot(fig, clear_figure=True, use_container_width=True)
275
 
276
- st.success("Analysis Complete: The heatmap highlights the regions most critical to the diagnosis.")
 
 
23
  DEVICE = torch.device("cpu")
24
 
25
  # --- INITIALIZE SESSION STATE (CRITICAL FOR STABILITY) ---
26
+ if 'output_ready' not in st.session_state:
27
+ st.session_state.output_ready = False
28
  st.session_state.results = {}
 
29
  st.session_state.file_name = None
30
+ st.session_state.input_bytes = None
31
 
32
 
33
+ # --- MODEL LOADING ---
34
  @st.cache_resource
35
  def load_model(model_url, local_path):
36
+ # ... (Model loading logic remains the same) ...
37
  if not os.path.exists(local_path):
38
  st.info(f"Model not found locally. Downloading from remote repository...")
39
  try:
 
55
  model.eval()
56
  return model
57
  except Exception:
58
+ st.error(f"Failed to load model weights from cache.")
59
  return None
60
 
61
+ # --- HELPER FUNCTIONS ---
62
+ # ... (denormalize_image, preprocess_image, generate_attribution, plot_heatmap_and_original functions remain the same) ...
63
  def denormalize_image(tensor: torch.Tensor) -> np.ndarray:
 
64
  if tensor.ndim == 4:
65
  tensor = tensor.squeeze(0).detach()
66
  else:
 
72
  return np.clip(img, 0, 1)
73
 
74
  def preprocess_image(image_bytes: bytes, file_name: str) -> Tuple[Optional[torch.Tensor], Optional[np.ndarray]]:
 
75
 
76
  if file_name.lower().endswith(('.dcm', '.dicom')):
77
  try:
 
103
  return input_tensor.to(DEVICE), image_rgb
104
 
105
  def generate_attribution(model: nn.Module, input_tensor: torch.Tensor, predicted_class_idx: int, n_steps: int = 20) -> np.ndarray:
 
106
  target_class_int = int(predicted_class_idx)
107
  input_tensor.requires_grad_(True)
108
 
 
124
  return attributions_ig_vis
125
 
126
  def plot_heatmap_and_original(original_image: np.ndarray, heatmap: np.ndarray, predicted_label: str):
 
127
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
128
  original_image_vis = (original_image.astype(np.float32) / 255.0)
129
 
 
141
  plt.tight_layout()
142
  return fig
143
 
144
+
145
+ # -------------------- CORE DIAGNOSIS FUNCTION (GATED LOGIC) --------------------
146
 
147
  def perform_full_analysis(model, image_bytes, file_name, n_steps_slider):
148
  """Function called by the button to perform heavy computation and update state."""
 
193
  model = load_model(HF_MODEL_URL, DOWNLOAD_MODEL_PATH)
194
 
195
  if model is not None:
196
+ # --- INPUT CONTROLS ---
197
  st.markdown("### Analysis Controls")
198
 
199
  n_steps_slider = st.slider(
 
203
  value=10,
204
  step=5,
205
  key="n_steps_slider_key",
206
+ help="Higher steps (up to 50) provide a smoother, more accurate heatmap but use more CPU."
207
  )
208
  st.markdown("---")
209
 
 
213
  uploaded_file = st.file_uploader(
214
  "Choose a Dicom, PNG, JPG, or JPEG file",
215
  type=["dcm", "dicom", "png", "jpg", "jpeg"],
216
+ key="file_uploader",
217
+ on_change=lambda: setattr(st.session_state, 'results_ready', False) # Reset results on new file upload
218
  )
219
+
 
220
  if uploaded_file is not None:
221
 
222
  # 1. READ AND CACHE INPUT BYTES (Fastest Operation)
223
+ image_bytes = uploaded_file.read()
224
+ file_name = uploaded_file.name
225
 
226
  # Preprocess image for DISPLAY ONLY (fast operation, no lag)
227
+ # This part ensures the image appears instantly without waiting for the model
228
+ _, original_image_rgb = preprocess_image(image_bytes, file_name)
229
 
230
  # --- DISPLAY UPLOADED IMAGE AND BUTTON ---
231
  col1, col2 = st.columns(2)
232
 
233
  with col1:
234
  st.subheader("Uploaded Image")
235
+ # Display the processed NumPy array
236
+ st.image(original_image_rgb, use_container_width=True, caption=file_name)
237
 
238
+ # --- THE GATED BUTTON (CRITICAL FOR STABILITY) ---
239
  st.markdown("---")
240
  if st.button("▶️ RUN FULL DIAGNOSIS & ATTRIBUTION", type="primary", use_container_width=True, key="run_button"):
241
+ # Execute the heavy logic and save to session state
242
+ perform_full_analysis(model, image_bytes, file_name, n_steps_slider)
243
 
244
  # --- RESULTS DISPLAY (STABLE) ---
245
  if st.session_state.results_ready:
 
271
  fig = plot_heatmap_and_original(data['image_rgb'], data['heatmap'], data['label'])
272
  st.pyplot(fig, clear_figure=True, use_container_width=True)
273
 
274
+ st.success("Analysis Complete: The heatmap highlights the regions most critical to the diagnosis.")
275
+