Halemo commited on
Commit
367d7c1
·
verified ·
1 Parent(s): c5e5019

Cache analyze + PDF (fix vibrate-on-resize)

Browse files
Files changed (1) hide show
  1. app.py +27 -13
app.py CHANGED
@@ -312,13 +312,33 @@ def render_probability_bars(probs, top_idx: int) -> None:
312
  st.markdown("\n".join(rows), unsafe_allow_html=True)
313
 
314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  def render_results(image: Image.Image) -> None:
316
- with st.spinner("Analyzing scan…"):
317
- model = get_model()
318
- tensor = preprocess(image).to(DEVICE)
319
- top_idx, probs = predict_with_probs(model, tensor, DEVICE)
320
- target_layer = get_gradcam_target_layer(model)
321
- heatmap = compute_gradcam(model, tensor, target_layer, top_idx, image)
322
 
323
  label = LABELS[top_idx]
324
  confidence = float(probs[top_idx])
@@ -371,13 +391,7 @@ def render_results(image: Image.Image) -> None:
371
  )
372
 
373
  render_step_title(3, "Download report")
374
- pdf_bytes = make_pdf_report(
375
- original_image=image,
376
- heatmap_image=heatmap,
377
- predicted_label=label,
378
- probabilities=probs,
379
- model_version=MODEL_VERSION,
380
- )
381
  cdl, cdr = st.columns([1, 3])
382
  with cdl:
383
  st.download_button(
 
312
  st.markdown("\n".join(rows), unsafe_allow_html=True)
313
 
314
 
315
+ @st.cache_data(show_spinner="Analyzing scan…", max_entries=8)
316
+ def _analyze(mri_bytes: bytes):
317
+ image = Image.open(BytesIO(mri_bytes)).convert("RGB")
318
+ model = get_model()
319
+ tensor = preprocess(image).to(DEVICE)
320
+ top_idx, probs = predict_with_probs(model, tensor, DEVICE)
321
+ target_layer = get_gradcam_target_layer(model)
322
+ heatmap = compute_gradcam(model, tensor, target_layer, top_idx, image)
323
+ return image, top_idx, probs, heatmap
324
+
325
+
326
+ @st.cache_data(show_spinner=False, max_entries=8)
327
+ def _build_pdf(mri_bytes: bytes, predicted_label: str, probabilities_tuple: tuple) -> bytes:
328
+ import numpy as np
329
+ image, _, _, heatmap = _analyze(mri_bytes)
330
+ return make_pdf_report(
331
+ original_image=image,
332
+ heatmap_image=heatmap,
333
+ predicted_label=predicted_label,
334
+ probabilities=np.asarray(probabilities_tuple),
335
+ model_version=MODEL_VERSION,
336
+ )
337
+
338
+
339
  def render_results(image: Image.Image) -> None:
340
+ mri_bytes = st.session_state["mri_bytes"]
341
+ image, top_idx, probs, heatmap = _analyze(mri_bytes)
 
 
 
 
342
 
343
  label = LABELS[top_idx]
344
  confidence = float(probs[top_idx])
 
391
  )
392
 
393
  render_step_title(3, "Download report")
394
+ pdf_bytes = _build_pdf(mri_bytes, label, tuple(float(p) for p in probs))
 
 
 
 
 
 
395
  cdl, cdr = st.columns([1, 3])
396
  with cdl:
397
  st.download_button(