srivarra commited on
Commit
ec6b668
Β·
1 Parent(s): 0b5ee75

updated datasets

Browse files
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: waveOrder Phase Reconstruction Viewer
3
  emoji: πŸ”¬
4
  python_version: 3.13
5
  colorFrom: blue
@@ -17,7 +17,7 @@ tags:
17
  - scientific-visualization
18
  ---
19
 
20
- # waveOrder: Phase Reconstruction Interactive Viewer
21
 
22
  <div align="center">
23
 
@@ -29,13 +29,13 @@ tags:
29
 
30
  ## πŸ“„ Paper
31
 
32
- **waveOrder: generalist framework for label-agnostic computational microscopy**
33
  Chandler T., Ivanov I.E., Hirata-Miyasaki E., et al. "WaveOrder: Physics-informed ML for auto-tuned multi-contrast computational microscopy from cells to
34
  organisms." [arXiv:2412.09775](https://arxiv.org/abs/2412.09775) (2025)
35
 
36
  ## πŸ”¬ About
37
 
38
- Interactive web interface for exploring phase reconstruction from quantitative label-free microscopy data. This demo showcases the waveOrder framework's capabilities for reconstructing phase contrast images with interactive parameter optimization.
39
 
40
  ### Features
41
 
@@ -67,7 +67,7 @@ This demo uses concatenated 20x objective microscopy data from high-content scre
67
 
68
  - **Paper**: [arXiv:2412.09775](https://arxiv.org/abs/2412.09775)
69
  - **GitHub Repository**: [mehta-lab/waveorder](https://github.com/mehta-lab/waveorder)
70
- - **Documentation**: [waveOrder Docs](https://github.com/mehta-lab/waveorder/tree/main/docs)
71
 
72
  ## πŸ“ Citation
73
 
 
1
  ---
2
+ title: WaveOrder
3
  emoji: πŸ”¬
4
  python_version: 3.13
5
  colorFrom: blue
 
17
  - scientific-visualization
18
  ---
19
 
20
+ # WaveOrder
21
 
22
  <div align="center">
23
 
 
29
 
30
  ## πŸ“„ Paper
31
 
32
+ **WaveOrder: generalist framework for label-agnostic computational microscopy**
33
  Chandler T., Ivanov I.E., Hirata-Miyasaki E., et al. "WaveOrder: Physics-informed ML for auto-tuned multi-contrast computational microscopy from cells to
34
  organisms." [arXiv:2412.09775](https://arxiv.org/abs/2412.09775) (2025)
35
 
36
  ## πŸ”¬ About
37
 
38
+ Interactive web interface for exploring phase reconstruction from quantitative label-free microscopy data. This demo showcases the WaveOrder framework's capabilities for reconstructing phase contrast images with interactive parameter optimization.
39
 
40
  ### Features
41
 
 
67
 
68
  - **Paper**: [arXiv:2412.09775](https://arxiv.org/abs/2412.09775)
69
  - **GitHub Repository**: [mehta-lab/waveorder](https://github.com/mehta-lab/waveorder)
70
+ - **Documentation**: [WaveOrder Docs](https://github.com/mehta-lab/waveorder/tree/main/docs)
71
 
72
  ## πŸ“ Citation
73
 
data/20x.zarr/A/1/{002026 β†’ 005028}/0/c/0/0/0/0/0 RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d195a11058e39269e8fe901a940ceac4dfdf2da2b529c7567fc5be994e4739e1
3
- size 10762964
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:59b45f07dbdd34d19cb7a5be46263cc4a0c2e24724edd9b3de9003055a0308ba
3
+ size 2767879
data/20x.zarr/A/1/{002027 β†’ 005028}/0/zarr.json RENAMED
@@ -3,8 +3,8 @@
3
  1,
4
  1,
5
  7,
6
- 1024,
7
- 1024
8
  ],
9
  "data_type": "uint16",
10
  "chunk_grid": {
@@ -14,8 +14,8 @@
14
  1,
15
  1,
16
  7,
17
- 1024,
18
- 1024
19
  ]
20
  }
21
  },
 
3
  1,
4
  1,
5
  7,
6
+ 512,
7
+ 512
8
  ],
9
  "data_type": "uint16",
10
  "chunk_grid": {
 
14
  1,
15
  1,
16
  7,
17
+ 512,
18
+ 512
19
  ]
20
  }
21
  },
data/20x.zarr/A/1/{002028 β†’ 005028}/zarr.json RENAMED
@@ -53,7 +53,7 @@
53
  "omero": {
54
  "version": "0.5",
55
  "id": 0,
56
- "name": "002028",
57
  "channels": [
58
  {
59
  "active": true,
 
53
  "omero": {
54
  "version": "0.5",
55
  "id": 0,
56
+ "name": "005028",
57
  "channels": [
58
  {
59
  "active": true,
data/20x.zarr/A/1/{002028 β†’ 005029}/0/c/0/0/0/0/0 RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:cd651a8f2de408b6c9f0827627a0428a34e9b79c0ec49d491f9c91e3ae45e4a7
3
- size 10796944
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae859d743adcc5ca961730062f51e186b2b81c6a476b03cb026c91484773c48e
3
+ size 2782178
data/20x.zarr/A/1/{002026 β†’ 005029}/0/zarr.json RENAMED
@@ -3,8 +3,8 @@
3
  1,
4
  1,
5
  7,
6
- 1024,
7
- 1024
8
  ],
9
  "data_type": "uint16",
10
  "chunk_grid": {
@@ -14,8 +14,8 @@
14
  1,
15
  1,
16
  7,
17
- 1024,
18
- 1024
19
  ]
20
  }
21
  },
 
3
  1,
4
  1,
5
  7,
6
+ 512,
7
+ 512
8
  ],
9
  "data_type": "uint16",
10
  "chunk_grid": {
 
14
  1,
15
  1,
16
  7,
17
+ 512,
18
+ 512
19
  ]
20
  }
21
  },
data/20x.zarr/A/1/{002026 β†’ 005029}/zarr.json RENAMED
@@ -53,7 +53,7 @@
53
  "omero": {
54
  "version": "0.5",
55
  "id": 0,
56
- "name": "002026",
57
  "channels": [
58
  {
59
  "active": true,
 
53
  "omero": {
54
  "version": "0.5",
55
  "id": 0,
56
+ "name": "005029",
57
  "channels": [
58
  {
59
  "active": true,
data/20x.zarr/A/1/{002027 β†’ 005030}/0/c/0/0/0/0/0 RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b7193f55c6f1e53fef903483bbcb9b38c51624b8807b7179997147f5d4b1d7f3
3
- size 10768295
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7090092f581a783fcc89c31a2da439d4ce06e5f4c0ef9caa8ce1657ff12edf5d
3
+ size 2786497
data/20x.zarr/A/1/{002028 β†’ 005030}/0/zarr.json RENAMED
@@ -3,8 +3,8 @@
3
  1,
4
  1,
5
  7,
6
- 1024,
7
- 1024
8
  ],
9
  "data_type": "uint16",
10
  "chunk_grid": {
@@ -14,8 +14,8 @@
14
  1,
15
  1,
16
  7,
17
- 1024,
18
- 1024
19
  ]
20
  }
21
  },
 
3
  1,
4
  1,
5
  7,
6
+ 512,
7
+ 512
8
  ],
9
  "data_type": "uint16",
10
  "chunk_grid": {
 
14
  1,
15
  1,
16
  7,
17
+ 512,
18
+ 512
19
  ]
20
  }
21
  },
data/20x.zarr/A/1/{002027 β†’ 005030}/zarr.json RENAMED
@@ -53,7 +53,7 @@
53
  "omero": {
54
  "version": "0.5",
55
  "id": 0,
56
- "name": "002027",
57
  "channels": [
58
  {
59
  "active": true,
 
53
  "omero": {
54
  "version": "0.5",
55
  "id": 0,
56
+ "name": "005030",
57
  "channels": [
58
  {
59
  "active": true,
data/20x.zarr/A/1/zarr.json CHANGED
@@ -6,15 +6,15 @@
6
  "images": [
7
  {
8
  "acquisition": 0,
9
- "path": "002026"
10
  },
11
  {
12
  "acquisition": 0,
13
- "path": "002027"
14
  },
15
  {
16
  "acquisition": 0,
17
- "path": "002028"
18
  }
19
  ]
20
  },
 
6
  "images": [
7
  {
8
  "acquisition": 0,
9
+ "path": "005028"
10
  },
11
  {
12
  "acquisition": 0,
13
+ "path": "005029"
14
  },
15
  {
16
  "acquisition": 0,
17
+ "path": "005030"
18
  }
19
  ]
20
  },
demo_utils.py CHANGED
@@ -25,6 +25,9 @@ from xarray_ome import open_ome_dataset
25
 
26
  from waveorder import util
27
  from waveorder.models import isotropic_thin_3d
 
 
 
28
 
29
  # Type alias for device specification
30
  Device = torch.device | str | None
@@ -467,8 +470,8 @@ def run_reconstruction(zyx_tile: torch.Tensor, recon_args: dict) -> torch.Tensor
467
  """
468
  Run phase reconstruction on a Z-stack.
469
 
470
- Takes a 3D stack (Z, Y, X) and produces a 2D phase reconstruction (Y, X).
471
- Device is inferred from the input tensor.
472
 
473
  Parameters
474
  ----------
@@ -491,18 +494,37 @@ def run_reconstruction(zyx_tile: torch.Tensor, recon_args: dict) -> torch.Tensor
491
  # Infer device from input tensor
492
  device = zyx_tile.device
493
 
494
- # Prepare transfer function arguments - ensure all tensors are on the same device
495
- tf_args = {}
496
- for key, value in recon_args.items():
497
- if isinstance(value, torch.Tensor):
498
- tf_args[key] = value.to(device)
499
- else:
500
- tf_args[key] = value
501
-
502
  Z, _, _ = zyx_tile.shape
503
- tf_args["z_position_list"] = (
504
- torch.arange(Z, device=device) - (Z // 2) + tf_args["z_offset"]
505
- ) * tf_args["z_scale"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
506
  tf_args.pop("z_offset")
507
  tf_args.pop("z_scale")
508
 
 
25
 
26
  from waveorder import util
27
  from waveorder.models import isotropic_thin_3d
28
+ from waveorder.cli.compute_transfer_function import (
29
+ _position_list_from_shape_scale_offset,
30
+ )
31
 
32
  # Type alias for device specification
33
  Device = torch.device | str | None
 
470
  """
471
  Run phase reconstruction on a Z-stack.
472
 
473
+ Uses waveorder's official _position_list_from_shape_scale_offset
474
+ to ensure proper z-position calculation and correct phase sign.
475
 
476
  Parameters
477
  ----------
 
494
  # Infer device from input tensor
495
  device = zyx_tile.device
496
 
497
+ # Prepare transfer function arguments
498
+ tf_args = recon_args.copy()
 
 
 
 
 
 
499
  Z, _, _ = zyx_tile.shape
500
+
501
+ # Extract z_offset value (keep as tensor if it is one, for gradient flow)
502
+ z_offset_value = recon_args["z_offset"]
503
+ if torch.is_tensor(z_offset_value):
504
+ # For optimization: extract scalar value for _position_list function
505
+ z_offset_scalar = z_offset_value.item()
506
+ else:
507
+ z_offset_scalar = z_offset_value
508
+
509
+ # Use waveorder's official function (returns torch.Tensor on CPU)
510
+ z_position_list_cpu = _position_list_from_shape_scale_offset(
511
+ shape=Z,
512
+ scale=recon_args["z_scale"],
513
+ offset=z_offset_scalar,
514
+ )
515
+
516
+ # Move to device and ensure gradient connection if z_offset is a parameter
517
+ if torch.is_tensor(z_offset_value) and z_offset_value.requires_grad:
518
+ # Recompute on device to maintain gradient connection
519
+ # Uses same formula as waveorder: -arange(Z) + (Z // 2) + offset
520
+ z_position_list = (
521
+ -torch.arange(Z, dtype=torch.float32, device=device) + (Z // 2) + z_offset_value
522
+ ) * recon_args["z_scale"]
523
+ else:
524
+ # No gradient needed, just move to device
525
+ z_position_list = z_position_list_cpu.to(device)
526
+
527
+ tf_args["z_position_list"] = z_position_list
528
  tf_args.pop("z_offset")
529
  tf_args.pop("z_scale")
530
 
optimize_demo.py CHANGED
@@ -33,10 +33,10 @@ class Config:
33
  # Default FOV selection
34
  DEFAULT_ROW = "A"
35
  DEFAULT_COLUMN = "1"
36
- DEFAULT_FIELD = "002026"
37
 
38
- # Restrict to specific FOVs (filter large plate)
39
- ALLOWED_FOVS = ['002026', '002027', '002028']
40
 
41
  # Channel selection (only BF channel in concatenated data)
42
  CHANNEL = 0 # BF is now channel 0 (GFP was filtered out during concatenation)
@@ -71,11 +71,11 @@ class Config:
71
 
72
  # UI slider ranges
73
  SLIDER_RANGES = {
74
- "z_offset": (-0.5, 0.5, 0.01),
75
  "na_detection": (0.05, 0.65, 0.001), # Max 0.65 to accommodate optimization
76
  "na_illumination": (0.05, 0.65, 0.001), # Max 0.65 (but constrained <= NA_detection)
77
- "tilt_zenith": (0.0, np.pi / 2, 0.005),
78
- "tilt_azimuth": (0.0, 2 * np.pi, 0.001),
79
  }
80
 
81
  # UI configuration
@@ -195,6 +195,23 @@ def get_slice_for_preview(z: int, data_xr_state):
195
  return (slice_img, slice_img) # Preview mode: both sides show same image
196
 
197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  # ============================================================================
199
  # RECONSTRUCTION CALLBACKS
200
  # ============================================================================
@@ -236,15 +253,25 @@ def run_reconstruction_ui(
236
  zyx_stack, pixel_scales_state, Config.RECON_CONFIG, param_values
237
  )
238
 
239
- # Return updated image slider (no optimization results)
240
- return (original_normalized, reconstructed_image)
241
 
242
 
243
- def run_optimization_ui(z: int, data_xr_state, pixel_scales_state):
 
 
 
 
 
 
 
 
 
 
244
  """
245
  Run OPTIMIZATION and stream updates to UI with iteration caching.
246
 
247
- Uses OPTIMIZABLE_PARAMS as initial guesses, runs full optimization loop.
248
  Yields progressive updates for ImageSlider, loss plot, status,
249
  iteration history, iteration slider, and SLIDER UPDATES.
250
  """
@@ -256,6 +283,35 @@ def run_optimization_ui(z: int, data_xr_state, pixel_scales_state):
256
  data_xr_state, t=0, c=Config.CHANNEL, z=int(z), normalize=True, verbose=False
257
  )
258
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  # Initialize tracking
260
  loss_history = []
261
  iteration_cache = []
@@ -273,15 +329,16 @@ def run_optimization_ui(z: int, data_xr_state, pixel_scales_state):
273
  gr.skip(), # na_ill
274
  gr.skip(), # tilt_zenith
275
  gr.skip(), # tilt_azimuth
 
276
  )
277
 
278
- # Run optimization with streaming
279
  for result in run_optimization_streaming(
280
  zyx_stack,
281
  pixel_scales_state,
282
  Config.RECON_CONFIG,
283
- Config.OPTIMIZABLE_PARAMS,
284
- num_iterations=Config.RECON_CONFIG["num_iterations"],
285
  ):
286
  # Current iteration number
287
  n = result["iteration"]
@@ -301,9 +358,39 @@ def run_optimization_ui(z: int, data_xr_state, pixel_scales_state):
301
  loss_history.append({"iteration": int(n), "loss": result["loss"]})
302
 
303
  # Format iteration info
304
- info_md = f"**Iteration {n}/{Config.RECON_CONFIG['num_iterations']}** | Loss: `{result['loss']:.2e}`"
305
-
306
- # Yield updates - update ImageSlider AND sliders with latest params
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  yield (
308
  (original_normalized, result["reconstructed_image"]), # Update ImageSlider
309
  pd.DataFrame(loss_history), # Loss plot
@@ -317,12 +404,13 @@ def run_optimization_ui(z: int, data_xr_state, pixel_scales_state):
317
  interactive=True,
318
  ),
319
  gr.Markdown(value=info_md, visible=True), # Show iteration info
320
- # Update parameter sliders with optimized values:
321
- result["params"].get("z_offset", gr.skip()),
322
- result["params"].get("numerical_aperture_detection", gr.skip()),
323
- result["params"].get("numerical_aperture_illumination", gr.skip()),
324
- result["params"].get("tilt_angle_zenith", gr.skip()),
325
- result["params"].get("tilt_angle_azimuth", gr.skip()),
 
326
  )
327
 
328
  # Final yield (keep last state)
@@ -340,6 +428,7 @@ def run_optimization_ui(z: int, data_xr_state, pixel_scales_state):
340
  gr.skip(), # Keep na_ill
341
  gr.skip(), # Keep tilt_zenith
342
  gr.skip(), # Keep tilt_azimuth
 
343
  )
344
 
345
 
@@ -361,13 +450,34 @@ def scrub_iterations(iteration_idx: int, history: list):
361
  # Update info display
362
  info_md = f"**Iteration {selected['iteration']}/{len(history)}** | Loss: `{selected['loss']:.2e}`"
363
 
364
- # Extract parameter values at this iteration
 
365
  params = selected["params"]
366
- z_offset = params.get("z_offset", 0.0)
367
- na_det = params.get("numerical_aperture_detection", 0.55)
368
- na_ill = params.get("numerical_aperture_illumination", 0.54)
369
- tilt_zenith = params.get("tilt_angle_zenith", 0.0)
370
- tilt_azimuth = params.get("tilt_angle_azimuth", 0.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
 
372
  return comparison, info_md, z_offset, na_det, na_ill, tilt_zenith, tilt_azimuth
373
 
@@ -389,7 +499,7 @@ def create_gradio_interface(plate_metadata, default_fields, data_xr, pixel_scale
389
  """Build the Gradio interface with all components and event wiring."""
390
 
391
  with gr.Blocks() as demo:
392
- gr.Markdown("# waveOrder Phase Reconstruction Demo")
393
  gr.Markdown(
394
  "**Paper:** Chandler T., Ivanov I.E., Hirata-Miyasaki E., et al. \"WaveOrder: Physics-informed ML for auto-tuned multi-contrast computational microscopy from cells to organisms.\" "
395
  "[arXiv:2412.09775](https://arxiv.org/abs/2412.09775) (2025)\n\n"
@@ -500,6 +610,16 @@ def create_gradio_interface(plate_metadata, default_fields, data_xr, pixel_scale
500
  # Section 4: Reconstruction Actions
501
  gr.Markdown("### πŸ”¬ Phase Reconstruction")
502
 
 
 
 
 
 
 
 
 
 
 
503
  with gr.Row():
504
  optimize_btn = gr.Button(
505
  "⚑ Optimize Parameters", variant="secondary", size="lg"
@@ -539,6 +659,7 @@ def create_gradio_interface(plate_metadata, default_fields, data_xr, pixel_scale
539
  iteration_history = gr.State(value=[])
540
  current_data_xr = gr.State(value=data_xr)
541
  current_pixel_scales = gr.State(value=pixel_scales)
 
542
 
543
  gr.Markdown("---")
544
 
@@ -554,6 +675,7 @@ def create_gradio_interface(plate_metadata, default_fields, data_xr, pixel_scale
554
  na_ill_slider,
555
  tilt_zenith_slider,
556
  tilt_azimuth_slider,
 
557
  optimize_btn,
558
  reconstruct_btn,
559
  loss_plot,
@@ -562,6 +684,7 @@ def create_gradio_interface(plate_metadata, default_fields, data_xr, pixel_scale
562
  iteration_history,
563
  current_data_xr,
564
  current_pixel_scales,
 
565
  plate_metadata,
566
  )
567
 
@@ -579,6 +702,7 @@ def _wire_event_handlers(
579
  na_ill_slider,
580
  tilt_zenith_slider,
581
  tilt_azimuth_slider,
 
582
  optimize_btn,
583
  reconstruct_btn,
584
  loss_plot,
@@ -587,6 +711,7 @@ def _wire_event_handlers(
587
  iteration_history,
588
  current_data_xr,
589
  current_pixel_scales,
 
590
  plate_metadata,
591
  ):
592
  """Wire all Gradio event handlers."""
@@ -610,22 +735,34 @@ def _wire_event_handlers(
610
  outputs=[na_ill_slider],
611
  )
612
 
613
- # Image viewer for Z navigation (preview mode: same image twice)
 
614
  demo.load(
615
  fn=get_slice_for_preview,
616
  inputs=[z_slider, current_data_xr],
617
  outputs=image_viewer,
618
  )
 
619
  z_slider.change(
620
- fn=get_slice_for_preview,
621
- inputs=[z_slider, current_data_xr],
622
  outputs=image_viewer,
623
  )
624
 
625
  # Reconstruction buttons
626
  optimize_btn.click(
627
  fn=run_optimization_ui,
628
- inputs=[z_slider, current_data_xr, current_pixel_scales],
 
 
 
 
 
 
 
 
 
 
629
  outputs=[
630
  image_viewer,
631
  loss_plot,
@@ -637,6 +774,7 @@ def _wire_event_handlers(
637
  na_ill_slider,
638
  tilt_zenith_slider,
639
  tilt_azimuth_slider,
 
640
  ],
641
  )
642
 
@@ -652,7 +790,7 @@ def _wire_event_handlers(
652
  current_data_xr,
653
  current_pixel_scales,
654
  ],
655
- outputs=[image_viewer],
656
  )
657
 
658
  # Iteration scrubbing - updates image AND all parameter sliders
 
33
  # Default FOV selection
34
  DEFAULT_ROW = "A"
35
  DEFAULT_COLUMN = "1"
36
+ DEFAULT_FIELD = "005029" # Center FOV
37
 
38
+ # Restrict to specific FOVs (center of well A/1 for better quality)
39
+ ALLOWED_FOVS = ['005028', '005029', '005030']
40
 
41
  # Channel selection (only BF channel in concatenated data)
42
  CHANNEL = 0 # BF is now channel 0 (GFP was filtered out during concatenation)
 
71
 
72
  # UI slider ranges
73
  SLIDER_RANGES = {
74
+ "z_offset": (-3.0, 3.0, 0.01), # Β±3 Β΅m (1.5x Z-slice spacing for focus correction)
75
  "na_detection": (0.05, 0.65, 0.001), # Max 0.65 to accommodate optimization
76
  "na_illumination": (0.05, 0.65, 0.001), # Max 0.65 (but constrained <= NA_detection)
77
+ "tilt_zenith": (0.0, np.pi / 4, 0.005),
78
+ "tilt_azimuth": (0.0, np.pi / 4, 0.001),
79
  }
80
 
81
  # UI configuration
 
195
  return (slice_img, slice_img) # Preview mode: both sides show same image
196
 
197
 
198
+ def update_original_slice_only(z: int, data_xr_state, current_reconstructed_state):
199
+ """
200
+ Update only the left (original) image when Z changes, keep reconstruction on right.
201
+
202
+ If no reconstruction exists yet, shows the original on both sides.
203
+ """
204
+ slice_img = extract_2d_slice(
205
+ data_xr_state, t=0, c=Config.CHANNEL, z=int(z), normalize=True, verbose=False
206
+ )
207
+
208
+ # If there's a reconstruction, keep it on the right; otherwise show original on both sides
209
+ if current_reconstructed_state is not None:
210
+ return (slice_img, current_reconstructed_state)
211
+ else:
212
+ return (slice_img, slice_img)
213
+
214
+
215
  # ============================================================================
216
  # RECONSTRUCTION CALLBACKS
217
  # ============================================================================
 
253
  zyx_stack, pixel_scales_state, Config.RECON_CONFIG, param_values
254
  )
255
 
256
+ # Return updated image slider AND reconstructed state
257
+ return (original_normalized, reconstructed_image), reconstructed_image
258
 
259
 
260
+ def run_optimization_ui(
261
+ z: int,
262
+ num_iterations: int,
263
+ z_offset: float,
264
+ na_det: float,
265
+ na_ill: float,
266
+ tilt_zenith: float,
267
+ tilt_azimuth: float,
268
+ data_xr_state,
269
+ pixel_scales_state,
270
+ ):
271
  """
272
  Run OPTIMIZATION and stream updates to UI with iteration caching.
273
 
274
+ Uses current slider values as initial guesses, runs full optimization loop.
275
  Yields progressive updates for ImageSlider, loss plot, status,
276
  iteration history, iteration slider, and SLIDER UPDATES.
277
  """
 
283
  data_xr_state, t=0, c=Config.CHANNEL, z=int(z), normalize=True, verbose=False
284
  )
285
 
286
+ # Build optimizable params with current slider values as initial values
287
+ optimizable_params_with_slider_values = {
288
+ "z_offset": (
289
+ Config.OPTIMIZABLE_PARAMS["z_offset"][0], # enabled flag
290
+ z_offset, # initial value from slider
291
+ Config.OPTIMIZABLE_PARAMS["z_offset"][2], # learning rate
292
+ ),
293
+ "numerical_aperture_detection": (
294
+ Config.OPTIMIZABLE_PARAMS["numerical_aperture_detection"][0],
295
+ na_det,
296
+ Config.OPTIMIZABLE_PARAMS["numerical_aperture_detection"][2],
297
+ ),
298
+ "numerical_aperture_illumination": (
299
+ Config.OPTIMIZABLE_PARAMS["numerical_aperture_illumination"][0],
300
+ na_ill,
301
+ Config.OPTIMIZABLE_PARAMS["numerical_aperture_illumination"][2],
302
+ ),
303
+ "tilt_angle_zenith": (
304
+ Config.OPTIMIZABLE_PARAMS["tilt_angle_zenith"][0],
305
+ tilt_zenith,
306
+ Config.OPTIMIZABLE_PARAMS["tilt_angle_zenith"][2],
307
+ ),
308
+ "tilt_angle_azimuth": (
309
+ Config.OPTIMIZABLE_PARAMS["tilt_angle_azimuth"][0],
310
+ tilt_azimuth,
311
+ Config.OPTIMIZABLE_PARAMS["tilt_angle_azimuth"][2],
312
+ ),
313
+ }
314
+
315
  # Initialize tracking
316
  loss_history = []
317
  iteration_cache = []
 
329
  gr.skip(), # na_ill
330
  gr.skip(), # tilt_zenith
331
  gr.skip(), # tilt_azimuth
332
+ None, # No reconstructed image yet
333
  )
334
 
335
+ # Run optimization with streaming (using slider values as initial values)
336
  for result in run_optimization_streaming(
337
  zyx_stack,
338
  pixel_scales_state,
339
  Config.RECON_CONFIG,
340
+ optimizable_params_with_slider_values,
341
+ num_iterations=num_iterations,
342
  ):
343
  # Current iteration number
344
  n = result["iteration"]
 
358
  loss_history.append({"iteration": int(n), "loss": result["loss"]})
359
 
360
  # Format iteration info
361
+ info_md = f"**Iteration {n}/{num_iterations}** | Loss: `{result['loss']:.2e}`"
362
+
363
+ # Clip optimized parameters to slider ranges (avoid Gradio validation errors)
364
+ # Convert to float to ensure Gradio compatibility
365
+ clipped_params = {
366
+ "z_offset": float(np.clip(
367
+ result["params"].get("z_offset", 0.0),
368
+ Config.SLIDER_RANGES["z_offset"][0],
369
+ Config.SLIDER_RANGES["z_offset"][1],
370
+ )),
371
+ "numerical_aperture_detection": float(np.clip(
372
+ result["params"].get("numerical_aperture_detection", 0.55),
373
+ Config.SLIDER_RANGES["na_detection"][0],
374
+ Config.SLIDER_RANGES["na_detection"][1],
375
+ )),
376
+ "numerical_aperture_illumination": float(np.clip(
377
+ result["params"].get("numerical_aperture_illumination", 0.54),
378
+ Config.SLIDER_RANGES["na_illumination"][0],
379
+ Config.SLIDER_RANGES["na_illumination"][1],
380
+ )),
381
+ "tilt_angle_zenith": float(np.clip(
382
+ result["params"].get("tilt_angle_zenith", 0.0),
383
+ Config.SLIDER_RANGES["tilt_zenith"][0],
384
+ Config.SLIDER_RANGES["tilt_zenith"][1],
385
+ )),
386
+ "tilt_angle_azimuth": float(np.clip(
387
+ result["params"].get("tilt_angle_azimuth", 0.0),
388
+ Config.SLIDER_RANGES["tilt_azimuth"][0],
389
+ Config.SLIDER_RANGES["tilt_azimuth"][1],
390
+ )),
391
+ }
392
+
393
+ # Yield updates - update ImageSlider AND sliders with clipped params
394
  yield (
395
  (original_normalized, result["reconstructed_image"]), # Update ImageSlider
396
  pd.DataFrame(loss_history), # Loss plot
 
404
  interactive=True,
405
  ),
406
  gr.Markdown(value=info_md, visible=True), # Show iteration info
407
+ # Update parameter sliders with clipped optimized values:
408
+ clipped_params["z_offset"],
409
+ clipped_params["numerical_aperture_detection"],
410
+ clipped_params["numerical_aperture_illumination"],
411
+ clipped_params["tilt_angle_zenith"],
412
+ clipped_params["tilt_angle_azimuth"],
413
+ result["reconstructed_image"], # Update reconstructed image state
414
  )
415
 
416
  # Final yield (keep last state)
 
428
  gr.skip(), # Keep na_ill
429
  gr.skip(), # Keep tilt_zenith
430
  gr.skip(), # Keep tilt_azimuth
431
+ gr.skip(), # Keep reconstructed image state
432
  )
433
 
434
 
 
450
  # Update info display
451
  info_md = f"**Iteration {selected['iteration']}/{len(history)}** | Loss: `{selected['loss']:.2e}`"
452
 
453
+ # Extract parameter values at this iteration and clip to slider ranges
454
+ # Convert to float to ensure Gradio compatibility
455
  params = selected["params"]
456
+ z_offset = float(np.clip(
457
+ params.get("z_offset", 0.0),
458
+ Config.SLIDER_RANGES["z_offset"][0],
459
+ Config.SLIDER_RANGES["z_offset"][1],
460
+ ))
461
+ na_det = float(np.clip(
462
+ params.get("numerical_aperture_detection", 0.55),
463
+ Config.SLIDER_RANGES["na_detection"][0],
464
+ Config.SLIDER_RANGES["na_detection"][1],
465
+ ))
466
+ na_ill = float(np.clip(
467
+ params.get("numerical_aperture_illumination", 0.54),
468
+ Config.SLIDER_RANGES["na_illumination"][0],
469
+ Config.SLIDER_RANGES["na_illumination"][1],
470
+ ))
471
+ tilt_zenith = float(np.clip(
472
+ params.get("tilt_angle_zenith", 0.0),
473
+ Config.SLIDER_RANGES["tilt_zenith"][0],
474
+ Config.SLIDER_RANGES["tilt_zenith"][1],
475
+ ))
476
+ tilt_azimuth = float(np.clip(
477
+ params.get("tilt_angle_azimuth", 0.0),
478
+ Config.SLIDER_RANGES["tilt_azimuth"][0],
479
+ Config.SLIDER_RANGES["tilt_azimuth"][1],
480
+ ))
481
 
482
  return comparison, info_md, z_offset, na_det, na_ill, tilt_zenith, tilt_azimuth
483
 
 
499
  """Build the Gradio interface with all components and event wiring."""
500
 
501
  with gr.Blocks() as demo:
502
+ gr.Markdown("# WaveOrder")
503
  gr.Markdown(
504
  "**Paper:** Chandler T., Ivanov I.E., Hirata-Miyasaki E., et al. \"WaveOrder: Physics-informed ML for auto-tuned multi-contrast computational microscopy from cells to organisms.\" "
505
  "[arXiv:2412.09775](https://arxiv.org/abs/2412.09775) (2025)\n\n"
 
610
  # Section 4: Reconstruction Actions
611
  gr.Markdown("### πŸ”¬ Phase Reconstruction")
612
 
613
+ # Number of optimization iterations control
614
+ num_iterations_slider = gr.Slider(
615
+ minimum=1,
616
+ maximum=50,
617
+ value=Config.RECON_CONFIG["num_iterations"],
618
+ step=1,
619
+ label="Optimization Iterations",
620
+ info="Number of gradient descent iterations (more = better quality, slower)",
621
+ )
622
+
623
  with gr.Row():
624
  optimize_btn = gr.Button(
625
  "⚑ Optimize Parameters", variant="secondary", size="lg"
 
659
  iteration_history = gr.State(value=[])
660
  current_data_xr = gr.State(value=data_xr)
661
  current_pixel_scales = gr.State(value=pixel_scales)
662
+ current_reconstructed = gr.State(value=None) # Stores the current reconstructed image
663
 
664
  gr.Markdown("---")
665
 
 
675
  na_ill_slider,
676
  tilt_zenith_slider,
677
  tilt_azimuth_slider,
678
+ num_iterations_slider,
679
  optimize_btn,
680
  reconstruct_btn,
681
  loss_plot,
 
684
  iteration_history,
685
  current_data_xr,
686
  current_pixel_scales,
687
+ current_reconstructed,
688
  plate_metadata,
689
  )
690
 
 
702
  na_ill_slider,
703
  tilt_zenith_slider,
704
  tilt_azimuth_slider,
705
+ num_iterations_slider,
706
  optimize_btn,
707
  reconstruct_btn,
708
  loss_plot,
 
711
  iteration_history,
712
  current_data_xr,
713
  current_pixel_scales,
714
+ current_reconstructed,
715
  plate_metadata,
716
  ):
717
  """Wire all Gradio event handlers."""
 
735
  outputs=[na_ill_slider],
736
  )
737
 
738
+ # Image viewer for Z navigation
739
+ # On load: show preview mode (no reconstruction yet)
740
  demo.load(
741
  fn=get_slice_for_preview,
742
  inputs=[z_slider, current_data_xr],
743
  outputs=image_viewer,
744
  )
745
+ # On Z change: update only left (original) image, keep reconstruction on right
746
  z_slider.change(
747
+ fn=update_original_slice_only,
748
+ inputs=[z_slider, current_data_xr, current_reconstructed],
749
  outputs=image_viewer,
750
  )
751
 
752
  # Reconstruction buttons
753
  optimize_btn.click(
754
  fn=run_optimization_ui,
755
+ inputs=[
756
+ z_slider,
757
+ num_iterations_slider,
758
+ z_offset_slider,
759
+ na_det_slider,
760
+ na_ill_slider,
761
+ tilt_zenith_slider,
762
+ tilt_azimuth_slider,
763
+ current_data_xr,
764
+ current_pixel_scales,
765
+ ],
766
  outputs=[
767
  image_viewer,
768
  loss_plot,
 
774
  na_ill_slider,
775
  tilt_zenith_slider,
776
  tilt_azimuth_slider,
777
+ current_reconstructed, # Update reconstructed state
778
  ],
779
  )
780
 
 
790
  current_data_xr,
791
  current_pixel_scales,
792
  ],
793
+ outputs=[image_viewer, current_reconstructed], # Update both viewer and state
794
  )
795
 
796
  # Iteration scrubbing - updates image AND all parameter sliders