arshenoy committed on
Commit
a44d9db
·
verified ·
1 Parent(s): 81cf5a9

flicker fix rollback

Browse files
Files changed (1) hide show
  1. cerebAI.py +165 -98
cerebAI.py CHANGED
@@ -15,31 +15,27 @@ import pydicom
15
  import io
16
  import gc
17
 
18
- # --- CONFIGURATION & CACHING ---
19
  HF_MODEL_URL = "https://huggingface.co/arshenoy/cerebAI-stroke-model/resolve/main/best_model.pth"
20
  DOWNLOAD_MODEL_PATH = "best_model_cache.pth"
21
  CLASS_LABELS = ['No Stroke', 'Ischemic Stroke', 'Hemorrhagic Stroke']
22
  IMAGE_SIZE = 224
23
  DEVICE = torch.device("cpu")
24
 
25
- # --- INITIALIZE SESSION STATE (NEW) ---
26
- # This ensures variables persist between reruns and controls the UI flow.
27
- if 'output_ready' not in st.session_state:
28
- st.session_state.output_ready = False
29
  st.session_state.results = {}
30
- st.session_state.original_image_rgb = None
31
- st.session_state.input_tensor = None
32
 
33
 
34
- # --- MODEL LOADING (UNCHANGED) ---
35
  @st.cache_resource
36
  def load_model(model_url, local_path):
37
- """Downloads model from URL if not cached, and loads the weights."""
38
  # ... (Model loading logic remains the same) ...
39
  if not os.path.exists(local_path):
40
  st.info(f"Model not found locally. Downloading from remote repository...")
41
  try:
42
- # ... (Download code) ...
43
  response = requests.get(model_url, stream=True)
44
  response.raise_for_status()
45
  with open(local_path, "wb") as f:
@@ -58,60 +54,137 @@ def load_model(model_url, local_path):
58
  model.eval()
59
  return model
60
  except Exception:
61
- st.error(f"Failed to load model weights.")
62
  return None
63
 
64
- # --- HELPER FUNCTIONS (UNCHANGED) ---
65
- # ... (denormalize_image, preprocess_image, generate_attribution, plot_heatmap_and_original functions here) ...
66
 
67
- # -------------------- CORE COMPUTATION FUNCTION (NEW) --------------------
68
-
69
- def run_diagnosis(model, input_bytes, file_name, n_steps):
70
- """
71
- Handles all computation, run inside st.spinner.
72
- Uses st.session_state to pass results out of the function.
73
- """
74
- try:
75
- # Preprocess the uploaded image (The jittering part runs here, ONCE)
76
- input_tensor, original_image_rgb = preprocess_image(input_bytes, file_name)
 
77
 
78
- if input_tensor is None:
79
- st.error("Could not process file. Ensure it is a valid DICOM/PNG/JPG.")
80
- return
 
 
 
 
 
 
 
81
 
82
- # Predict
83
- with torch.no_grad():
84
- output = model(input_tensor)
85
- probabilities = torch.softmax(output, dim=1).squeeze(0).cpu().numpy()
86
- predicted_class_idx = np.argmax(probabilities)
87
-
88
- # Generate Attribution (Heavy Calculation)
89
- heatmap = generate_attribution(model, input_tensor, predicted_class_idx, n_steps=n_steps)
90
-
91
- # --- CLEANUP (CRITICAL FOR RAM) ---
92
- del input_tensor
93
- del output
94
- gc.collect()
95
- torch.cuda.empty_cache()
96
-
97
- # --- STORE FINAL RESULTS ---
98
- st.session_state.results = {
99
- 'label': CLASS_LABELS[predicted_class_idx],
100
- 'confidence': probabilities[predicted_class_idx],
101
- 'probabilities': probabilities,
102
- 'image_rgb': original_image_rgb,
103
- 'heatmap': heatmap,
104
- 'filename': file_name
105
- }
106
- st.session_state.output_ready = True
107
-
108
- except Exception as e:
109
- st.error(f"An unexpected error occurred during analysis: {e}")
110
- st.session_state.output_ready = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
 
113
  # ==============================================================================
114
- # -------------------- STREAMLIT FRONTEND (Final Clean Flow) --------------------
115
  # ==============================================================================
116
 
117
  st.set_page_config(page_title="CerebAI: Stroke Prediction Dashboard", layout="wide")
@@ -124,62 +197,58 @@ if model is not None:
124
  # --- INPUT CONTROLS ---
125
  st.markdown("### Analysis Controls")
126
 
127
- col_upload, col_steps = st.columns([2, 1])
128
-
129
- with col_steps:
130
- n_steps_slider = st.slider(
131
- 'Integration Steps (CPU Load)',
132
- min_value=5, max_value=50, value=10, step=5,
133
- help="Lower steps (5-15) for fast diagnosis; higher steps (30-50) for smoother visualization."
134
- )
135
-
136
- with col_upload:
137
- uploaded_file = st.file_uploader(
138
- "Upload CT Scan Image",
139
- type=["dcm", "dicom", "png", "jpg", "jpeg"],
140
- key="file_uploader" # Use key to manage state changes
141
- )
142
-
143
  st.markdown("---")
 
 
 
 
 
 
 
 
 
 
144
 
145
- # --- DISPLAY INPUT & RUN BUTTON ---
146
  if uploaded_file is not None:
147
-
148
- # NOTE: Read file bytes and display immediately (fast operation)
149
  image_bytes = uploaded_file.read()
150
  file_name = uploaded_file.name
151
-
152
- # Display the file (use a generic path/bytes for fast display)
153
- with col_upload:
154
- st.subheader("Input Image")
155
- st.image(image_bytes, use_container_width=True, caption=file_name)
156
-
157
- # --- THE GATING BUTTON ---
158
- # This button calls the function that does the slow processing
159
- if st.button("▶️ RUN FULL DIAGNOSIS", type="primary", use_container_width=True):
160
- # We need to re-read the file and run the full logic inside the spinner
161
- run_diagnosis(model, image_bytes, file_name, n_steps_slider)
162
-
163
-
164
- # --- RESULTS DISPLAY ---
165
- # This entire block only runs AFTER the button is clicked and the state is updated
166
- if st.session_state.output_ready:
167
  data = st.session_state.results
168
 
169
  st.markdown("---")
170
  st.markdown("### 2. Diagnosis and Interpretation")
171
 
172
- # Stable layout for results
173
  col_metric, col_plot = st.columns([1, 2])
174
 
175
- # --- METRICS ---
176
  with col_metric:
177
  st.subheader("Prediction Summary")
178
  st.metric(
179
- label="Predicted Diagnosis",
180
  value=data['label'],
181
  delta=f"{data['confidence']*100:.2f}% Confidence",
182
- delta_color='normal'
183
  )
184
  st.markdown("---")
185
  st.subheader("Confidence Breakdown")
@@ -189,10 +258,8 @@ if model is not None:
189
  }
190
  st.dataframe(prob_data, hide_index=True, use_container_width=True)
191
 
192
- # --- PLOT ---
193
  with col_plot:
194
- st.subheader("Visual Explanation (Integrated Gradients)")
195
-
196
  fig = plot_heatmap_and_original(data['image_rgb'], data['heatmap'], data['label'])
197
  st.pyplot(fig, clear_figure=True, use_container_width=True)
198
 
 
15
  import io
16
  import gc
17
 
18
+ # --- CONFIGURATION ---
19
  HF_MODEL_URL = "https://huggingface.co/arshenoy/cerebAI-stroke-model/resolve/main/best_model.pth"
20
  DOWNLOAD_MODEL_PATH = "best_model_cache.pth"
21
  CLASS_LABELS = ['No Stroke', 'Ischemic Stroke', 'Hemorrhagic Stroke']
22
  IMAGE_SIZE = 224
23
  DEVICE = torch.device("cpu")
24
 
25
+ # --- INITIALIZE SESSION STATE (NEW: for stable UI) ---
26
+ if 'results_ready' not in st.session_state:
27
+ st.session_state.results_ready = False
 
28
  st.session_state.results = {}
29
+ st.session_state.file_bytes = None
 
30
 
31
 
32
+ # --- MODEL LOADING ---
33
  @st.cache_resource
34
  def load_model(model_url, local_path):
 
35
  # ... (Model loading logic remains the same) ...
36
  if not os.path.exists(local_path):
37
  st.info(f"Model not found locally. Downloading from remote repository...")
38
  try:
 
39
  response = requests.get(model_url, stream=True)
40
  response.raise_for_status()
41
  with open(local_path, "wb") as f:
 
54
  model.eval()
55
  return model
56
  except Exception:
57
+ st.error(f"Failed to load model weights from cache.")
58
  return None
59
 
60
+ # --- CORE HELPER FUNCTIONS (UNCHANGED) ---
 
61
 
62
def denormalize_image(tensor: torch.Tensor) -> np.ndarray:
    """Undo the (x - 0.5) / 0.5 normalization so the image can be displayed.

    Accepts either a batched (1, C, H, W) or unbatched (C, H, W) tensor and
    returns an (H, W, C) float array clipped to [0, 1] for matplotlib.
    """
    # Drop a leading batch axis if present and detach from the autograd graph.
    work = tensor.squeeze(0).detach() if tensor.ndim == 4 else tensor.detach()

    # CHW -> HWC for matplotlib, then reverse the per-channel normalization.
    pixels = work.cpu().permute(1, 2, 0).numpy()
    channel_mean = np.array([0.5, 0.5, 0.5])
    channel_std = np.array([0.5, 0.5, 0.5])
    pixels = pixels * channel_std + channel_mean
    return np.clip(pixels, 0, 1)
73
 
74
def preprocess_image(image_bytes: bytes, file_name: str) -> Tuple[Optional[torch.Tensor], Optional[np.ndarray]]:
    """Decode uploaded bytes (DICOM or PNG/JPG) into a model-ready tensor.

    Returns a (1, 3, IMAGE_SIZE, IMAGE_SIZE) normalized tensor on DEVICE plus
    the resized uint8 RGB image for display, or (None, None) on any decode
    failure.
    """
    if file_name.lower().endswith(('.dcm', '.dicom')):
        try:
            dcm = pydicom.dcmread(io.BytesIO(image_bytes))
            # Convert stored values to Hounsfield units via the DICOM rescale tags.
            hu = dcm.pixel_array.astype(np.int16) * dcm.RescaleSlope + dcm.RescaleIntercept

            # Brain window: center 40 HU, width 150 HU.
            window_center = 40
            window_width = 150
            min_hu = window_center - (window_width / 2)
            max_hu = window_center + (window_width / 2)

            hu = np.clip(hu, min_hu, max_hu)
            image_grayscale = ((hu - min_hu) / (max_hu - min_hu) * 255).astype(np.uint8)
        except Exception:
            # Unreadable/invalid DICOM — signal failure to the caller.
            return None, None
    else:
        image_grayscale = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_GRAYSCALE)
        if image_grayscale is None:
            return None, None

    # Resize, expand to 3 channels, and apply the (x - 0.5) / 0.5 normalization
    # matching denormalize_image.
    resized = cv2.resize(image_grayscale, (IMAGE_SIZE, IMAGE_SIZE))
    image_rgb = cv2.cvtColor(resized, cv2.COLOR_GRAY2RGB)
    normalized = (image_rgb.astype(np.float32) / 255.0 - 0.5) / 0.5
    input_tensor = torch.tensor(normalized, dtype=torch.float).permute(2, 0, 1).unsqueeze(0)

    return input_tensor.to(DEVICE), image_rgb
105
+
106
def generate_attribution(model: nn.Module, input_tensor: torch.Tensor, predicted_class_idx: int, n_steps: int = 20) -> np.ndarray:
    """Compute a normalized Integrated Gradients saliency map for one class.

    Returns an (H, W) array scaled to [0, 1] (all zeros if the attribution
    is identically zero).
    """
    target_class_int = int(predicted_class_idx)
    # Captum requires gradients on the input to integrate along the path.
    input_tensor.requires_grad_(True)

    explainer = IntegratedGradients(model)
    # All-zero baseline: attribute relative to a black image.
    zero_baseline = torch.zeros_like(input_tensor).to(DEVICE)

    attributions = explainer.attribute(
        inputs=input_tensor,
        baselines=zero_baseline,
        target=target_class_int,
        n_steps=n_steps
    )

    # Collapse the channel axis, take the magnitude, and move to numpy.
    saliency = attributions.squeeze(0).sum(dim=0).abs().cpu().detach().numpy()

    # Normalize to [0, 1] unless the map is entirely zero.
    peak = saliency.max()
    if peak > 0:
        saliency = saliency / peak

    return saliency
127
+
128
def plot_heatmap_and_original(original_image: np.ndarray, heatmap: np.ndarray, predicted_label: str):
    """Build a side-by-side figure: raw scan on the left, IG overlay on the right.

    `original_image` is a uint8 RGB array; `heatmap` is a [0, 1] saliency map.
    Returns the Matplotlib figure for st.pyplot.
    """
    fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(14, 5))
    scan = original_image.astype(np.float32) / 255.0

    left_ax.imshow(scan)
    left_ax.set_title("Original CT Scan", fontsize=14)
    left_ax.axis('off')

    right_ax.imshow(scan)
    # Stronger attributions get higher opacity; the 0.3 floor keeps weakly
    # attributed regions faintly visible instead of fully transparent.
    opacity = heatmap * 0.7 + 0.3
    right_ax.imshow(heatmap, cmap='jet', alpha=opacity, vmin=0, vmax=1)
    right_ax.set_title(f"Interpretation: {predicted_label}", fontsize=14)
    right_ax.axis('off')

    plt.tight_layout()
    return fig
146
+
147
+ # -------------------- CORE DIAGNOSIS FUNCTION (TO BE CALLED BY BUTTON) --------------------
148
+
149
def perform_full_analysis(model, image_bytes, file_name, n_steps_slider):
    """Run the full preprocess -> predict -> attribute pipeline for one upload.

    Called from the "RUN DIAGNOSIS" button handler. Writes its output into
    st.session_state ('results', 'results_ready') rather than returning it, so
    the results survive Streamlit reruns, then forces a rerun for a single
    clean render. On preprocessing failure it shows an error and clears
    'results_ready' so stale results are not displayed.
    """

    # 1. PREPARE INPUTS — decode DICOM/PNG/JPG bytes into a model tensor.
    input_tensor, original_image_rgb = preprocess_image(image_bytes, file_name)

    if input_tensor is None:
        st.error("Could not process file. Ensure it is a valid DICOM/PNG/JPG.")
        st.session_state.results_ready = False
        return

    # 2. PREDICT — no_grad: plain inference, gradients only needed later for IG.
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = torch.softmax(output, dim=1).squeeze(0).cpu().numpy()
        predicted_class_idx = np.argmax(probabilities)

    # 3. GENERATE ATTRIBUTION (HEAVY PART) — cost scales with the slider's n_steps.
    heatmap = generate_attribution(model, input_tensor, predicted_class_idx, n_steps=n_steps_slider)

    # 4. CRITICAL MEMORY MANAGEMENT — drop the large tensors before storing
    # results so peak RAM stays low across Streamlit reruns.
    del input_tensor
    del output
    gc.collect()

    # 5. SAVE FINAL RESULTS TO STATE — only numpy/str values are kept, so the
    # state dict holds no live torch graph references.
    st.session_state.results = {
        'label': CLASS_LABELS[predicted_class_idx],
        'confidence': probabilities[predicted_class_idx],
        'probabilities': probabilities,
        'image_rgb': original_image_rgb,
        'heatmap': heatmap
    }
    st.session_state.results_ready = True
    st.rerun() # Force a single, clean render of the results
184
 
185
 
186
  # ==============================================================================
187
+ # -------------------- STREAMLIT FRONTEND --------------------
188
  # ==============================================================================
189
 
190
  st.set_page_config(page_title="CerebAI: Stroke Prediction Dashboard", layout="wide")
 
197
  # --- INPUT CONTROLS ---
198
  st.markdown("### Analysis Controls")
199
 
200
+ n_steps_slider = st.slider(
201
+ 'Integration Steps (Affects Accuracy & Speed)',
202
+ min_value=5,
203
+ max_value=50,
204
+ value=10,
205
+ step=5,
206
+ key="n_steps_slider_key", # Added key to prevent slider reset on rerun
207
+ help="Higher steps (up to 50) provide a smoother, more accurate heatmap but use more CPU."
208
+ )
 
 
 
 
 
 
 
209
  st.markdown("---")
210
+
211
+
212
+ # --- FILE UPLOAD ---
213
+ st.markdown("### Upload CT Scan Image")
214
+ uploaded_file = st.file_uploader(
215
+ "Choose a Dicom, PNG, JPG, or JPEG file",
216
+ type=["dcm", "dicom", "png", "jpg", "jpeg"],
217
+ key="file_uploader" # Added key to manage file state
218
+ )
219
+
220
 
221
+ # --- UI FLOW MANAGEMENT ---
222
  if uploaded_file is not None:
 
 
223
  image_bytes = uploaded_file.read()
224
  file_name = uploaded_file.name
225
+
226
+ # Display the image immediately (fast operation, no lag)
227
+ st.subheader("Input Image")
228
+ st.image(image_bytes, use_container_width=True, caption=file_name)
229
+
230
+ # --- GATED BUTTON ---
231
+ if st.button("▶️ RUN DIAGNOSIS & ATTRIBUTION", type="primary", use_container_width=True, key="run_button"):
232
+ with st.spinner(f'Calculating Integrated Gradients ({n_steps_slider} steps)...'):
233
+ # Execute the heavy logic and save to session state
234
+ perform_full_analysis(model, image_bytes, file_name, n_steps_slider)
235
+
236
+ # --- RESULTS DISPLAY (ONLY RUNS AFTER BUTTON CLICK) ---
237
+ if st.session_state.results_ready:
 
 
 
238
  data = st.session_state.results
239
 
240
  st.markdown("---")
241
  st.markdown("### 2. Diagnosis and Interpretation")
242
 
 
243
  col_metric, col_plot = st.columns([1, 2])
244
 
 
245
  with col_metric:
246
  st.subheader("Prediction Summary")
247
  st.metric(
248
+ label="Diagnosis",
249
  value=data['label'],
250
  delta=f"{data['confidence']*100:.2f}% Confidence",
251
+ delta_color='normal'
252
  )
253
  st.markdown("---")
254
  st.subheader("Confidence Breakdown")
 
258
  }
259
  st.dataframe(prob_data, hide_index=True, use_container_width=True)
260
 
 
261
  with col_plot:
262
+ st.subheader("Visual Explanation")
 
263
  fig = plot_heatmap_and_original(data['image_rgb'], data['heatmap'], data['label'])
264
  st.pyplot(fig, clear_figure=True, use_container_width=True)
265