Spaces:
Sleeping
Sleeping
| """ | |
| Gradio Phase Reconstruction Viewer | |
| Interactive web interface for viewing zarr microscopy data with T/Z navigation. | |
| Based on: docs/examples/visuals/optimize_phase_recon.py | |
| """ | |
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| from pathlib import Path | |
| from demo_utils import ( | |
| print_data_summary, | |
| run_optimization_streaming, | |
| get_plate_metadata, | |
| load_fov_from_plate, | |
| extract_2d_slice, | |
| run_reconstruction_single, | |
| ) | |
| # ============================================================================ | |
| # CONFIGURATION | |
| # ============================================================================ | |
class Config:
    """Static settings shared by the phase reconstruction viewer.

    Everything is a class attribute; the class is used as a namespace and is
    never instantiated.
    """

    # --- Data source -------------------------------------------------------
    INPUT_PATH = Path("data/20x.zarr")

    # --- Default field-of-view selection -----------------------------------
    DEFAULT_ROW = "A"
    DEFAULT_COLUMN = "1"
    DEFAULT_FIELD = "005029"  # center FOV of the well
    # Only these FOVs (center of well A/1) are offered, for better quality
    ALLOWED_FOVS = ["005028", "005029", "005030"]

    # Channel index of BF in the concatenated data (GFP was filtered out
    # during concatenation, so BF is channel 0)
    CHANNEL = 0

    # Pixel sizes in micrometers for the 20x objective; these deliberately
    # override the Zarr metadata, which is not trusted for this dataset
    PIXEL_SIZE_YX = 0.325
    PIXEL_SIZE_Z = 2.0

    # --- Reconstruction settings handed to the reconstruction helpers ------
    RECON_CONFIG = dict(
        wavelength_illumination=0.45,
        index_of_refraction_media=1.3,
        invert_phase_contrast=False,
        num_iterations=10,
        # device: None -> auto-detect (CUDA if available, else CPU)
        #         "cuda" -> force GPU (needs a CUDA-capable device)
        #         "cpu"  -> force CPU (testing/debugging)
        device=None,
        # Tiling is not implemented; the full image is always processed
        use_tiling=False,
    )

    # Optimizer setup per parameter: (optimize_flag, initial_value, learning_rate)
    OPTIMIZABLE_PARAMS = dict(
        z_offset=(True, 0.0, 0.01),
        numerical_aperture_detection=(True, 0.55, 0.001),
        numerical_aperture_illumination=(True, 0.54, 0.001),
        tilt_angle_zenith=(True, 0.0, 0.005),
        tilt_angle_azimuth=(True, 0.0, 0.001),
    )

    # UI slider bounds per parameter: (minimum, maximum, step)
    SLIDER_RANGES = dict(
        # +/-3 um, i.e. 1.5x the Z-slice spacing, for focus correction
        z_offset=(-3.0, 3.0, 0.01),
        # NA ceiling of 0.65 leaves headroom for the optimizer
        na_detection=(0.05, 0.65, 0.001),
        # Same ceiling; the UI additionally keeps NA_ill <= NA_det
        na_illumination=(0.05, 0.65, 0.001),
        tilt_zenith=(0.0, np.pi / 4, 0.005),
        tilt_azimuth=(0.0, np.pi / 4, 0.001),
    )

    # --- UI ------------------------------------------------------------------
    IMAGE_HEIGHT = 800
    SERVER_PORT = 12124
| # ============================================================================ | |
| # GLOBAL STATE INITIALIZATION | |
| # ============================================================================ | |
def initialize_plate_metadata():
    """Read the HCS plate metadata and print a short summary of it.

    Returns:
        ``(plate_metadata, default_fields)``: the mapping returned by
        ``get_plate_metadata`` and the field list of the default well
        (``[]`` if that well is not present in ``plate_metadata["wells"]``).
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Loading HCS Plate Metadata...")
    print(banner)

    # Passing the allowed FOVs lets get_plate_metadata skip the other positions
    metadata = get_plate_metadata(Config.INPUT_PATH, Config.ALLOWED_FOVS)
    print(f"Available rows: {metadata['rows']}")
    print(f"Available columns: {metadata['columns']}")
    wells = metadata["wells"]
    print(f"Total wells: {len(wells)}")

    # Wells are keyed by (row, column); the field list was already restricted
    # by the ALLOWED_FOVS argument above
    row, col = Config.DEFAULT_ROW, Config.DEFAULT_COLUMN
    fields = wells.get((row, col), [])
    print(f"Fields in {row}/{col}: {len(fields)}")
    print(f"Allowed FOVs: {Config.ALLOWED_FOVS}")
    print(banner + "\n")
    return metadata, fields
def load_default_fov(plate_metadata):
    """Open the configured default FOV and pair it with the 20x pixel scales.

    Args:
        plate_metadata: Mapping from ``get_plate_metadata``; its ``"plate"``
            entry is passed to ``load_fov_from_plate``.

    Returns:
        ``(data_xr, pixel_scales)`` where ``pixel_scales`` is the
        ``(z, y, x)`` tuple of pixel sizes in micrometers taken from Config.
    """
    row, col, field = Config.DEFAULT_ROW, Config.DEFAULT_COLUMN, Config.DEFAULT_FIELD
    print(f"Loading default FOV: {row}/{col}/{field}")
    fov = load_fov_from_plate(plate_metadata["plate"], row, col, field, resolution=0)
    print_data_summary(fov)

    # Scales come from Config instead of the Zarr attributes, which may
    # describe a different magnification
    yx_size = Config.PIXEL_SIZE_YX
    scales = (Config.PIXEL_SIZE_Z, yx_size, yx_size)
    print(f"Using pixel scales (Z, Y, X): {scales} micrometers (from config, 20x objective)")
    return fov, scales
| # ============================================================================ | |
| # FOV LOADING CALLBACKS | |
| # ============================================================================ | |
def load_selected_fov(field: str, current_z: int, plate_metadata):
    """Load a FOV of the default well and produce the matching UI updates.

    Args:
        field: FOV name inside well ``DEFAULT_ROW/DEFAULT_COLUMN``.
        current_z: Z index currently shown. It may arrive from the Gradio
            slider as a float, so it is cast to ``int`` and then clamped into
            ``[0, Z - 1]`` of the newly loaded stack.
        plate_metadata: Mapping whose ``"plate"`` entry is passed to
            ``load_fov_from_plate``.

    Returns:
        ``(z_slider_update, image_slider_value, data_xr, pixel_scales)``.
        On any exception the error and traceback are printed and four
        ``gr.skip()`` values are returned, leaving the UI unchanged.
    """
    try:
        print(f"\nLoading FOV: {Config.DEFAULT_ROW}/{Config.DEFAULT_COLUMN}/{field}")
        new_data_xr = load_fov_from_plate(
            plate_metadata["plate"],
            Config.DEFAULT_ROW,
            Config.DEFAULT_COLUMN,
            field,
            resolution=0,
        )
        # Pixel scales come from Config, not from the Zarr metadata
        new_pixel_scales = (Config.PIXEL_SIZE_Z, Config.PIXEL_SIZE_YX, Config.PIXEL_SIZE_YX)

        # Keep the current Z if the new stack is deep enough, otherwise clamp.
        # int(): the value is used as a slice index (the other callbacks also
        # int() the slider value before indexing).
        z_max = new_data_xr.sizes["Z"] - 1
        new_z = max(0, min(int(current_z), z_max))
        print(f"✅ Loaded: {dict(new_data_xr.sizes)}")

        preview_image = extract_2d_slice(
            new_data_xr, t=0, c=Config.CHANNEL, z=new_z, normalize=True, verbose=False
        )
        return (
            gr.Slider(maximum=z_max, value=new_z),  # Updated Z slider
            (preview_image, preview_image),  # ImageSlider in preview mode
            new_data_xr,  # Update state
            new_pixel_scales,  # Update state
        )
    except Exception as e:
        # UI boundary: report the failure and keep the previously loaded FOV
        print(f"❌ Error loading FOV: {str(e)}")
        import traceback

        traceback.print_exc()
        return (gr.skip(), gr.skip(), gr.skip(), gr.skip())
| # ============================================================================ | |
| # IMAGE DISPLAY CALLBACKS | |
| # ============================================================================ | |
def get_slice_for_preview(z: int, data_xr_state):
    """Return slice ``z`` as an ImageSlider pair showing the same image twice.

    This is the "preview" state used before any reconstruction exists.
    """
    image = extract_2d_slice(
        data_xr_state,
        t=0,
        c=Config.CHANNEL,
        z=int(z),
        normalize=True,
        verbose=False,
    )
    pair = (image, image)
    return pair
def update_original_slice_only(z: int, data_xr_state, current_reconstructed_state):
    """Refresh the left (raw) side of the ImageSlider for a new Z index.

    The right side keeps ``current_reconstructed_state`` when one exists.
    Without a reconstruction, the raw slice is shown on both sides.
    """
    raw = extract_2d_slice(
        data_xr_state, t=0, c=Config.CHANNEL, z=int(z), normalize=True, verbose=False
    )
    right = raw if current_reconstructed_state is None else current_reconstructed_state
    return (raw, right)
| # ============================================================================ | |
| # RECONSTRUCTION CALLBACKS | |
| # ============================================================================ | |
def run_reconstruction_ui(
    z: int,
    z_offset: float,
    na_det: float,
    na_ill: float,
    tilt_zenith: float,
    tilt_azimuth: float,
    data_xr_state,
    pixel_scales_state,
):
    """Reconstruct once from the current slider values (no optimization).

    Returns:
        ``((raw_slice, reconstruction), reconstruction)``. The first item is
        the ImageSlider value and the second updates the stored
        reconstruction state.
    """
    # The reconstruction consumes the whole Z-stack at timepoint 0
    stack = data_xr_state.isel(T=0, C=Config.CHANNEL).values
    # The selected Z-slice goes on the left of the ImageSlider for comparison
    raw_slice = extract_2d_slice(
        data_xr_state, t=0, c=Config.CHANNEL, z=int(z), normalize=True, verbose=False
    )

    slider_params = dict(
        z_offset=z_offset,
        numerical_aperture_detection=na_det,
        numerical_aperture_illumination=na_ill,
        tilt_angle_zenith=tilt_zenith,
        tilt_angle_azimuth=tilt_azimuth,
    )
    recon = run_reconstruction_single(
        stack, pixel_scales_state, Config.RECON_CONFIG, slider_params
    )

    viewer_value = (raw_slice, recon)
    return viewer_value, recon
def run_optimization_ui(
    z: int,
    num_iterations: int,
    z_offset: float,
    na_det: float,
    na_ill: float,
    tilt_zenith: float,
    tilt_azimuth: float,
    data_xr_state,
    pixel_scales_state,
):
    """Run the parameter optimization and stream its progress to the UI.

    The current slider values are the initial guesses. The per-parameter
    enable flags and learning rates come from ``Config.OPTIMIZABLE_PARAMS``.
    Every result reported by ``run_optimization_streaming`` is cached so the
    user can scrub back through the iterations afterwards.

    Yields:
        11-tuples matching the optimize button's outputs:
        ``(image_slider, loss_df, iteration_history, iteration_slider,
        iteration_info, z_offset, na_det, na_ill, tilt_zenith, tilt_azimuth,
        reconstructed_image)``. ``gr.skip()`` leaves an output unchanged.
    """
    # Slider values may arrive as floats. The optimizer iteration count and
    # the "Iteration n/N" label both expect an int.
    num_iterations = int(num_iterations)

    # The optimizer works on the full Z-stack at timepoint 0
    zyx_stack = data_xr_state.isel(T=0, C=Config.CHANNEL).values
    # The current Z-slice stays pinned on the left side of the ImageSlider
    original_normalized = extract_2d_slice(
        data_xr_state, t=0, c=Config.CHANNEL, z=int(z), normalize=True, verbose=False
    )

    # (optimizer parameter name, Config.SLIDER_RANGES key, initial value from UI)
    param_specs = (
        ("z_offset", "z_offset", z_offset),
        ("numerical_aperture_detection", "na_detection", na_det),
        ("numerical_aperture_illumination", "na_illumination", na_ill),
        ("tilt_angle_zenith", "tilt_zenith", tilt_zenith),
        ("tilt_angle_azimuth", "tilt_azimuth", tilt_azimuth),
    )

    # Keep the configured enable flag and learning rate; start from the slider value
    optimizable_params_with_slider_values = {
        name: (
            Config.OPTIMIZABLE_PARAMS[name][0],
            initial_value,
            Config.OPTIMIZABLE_PARAMS[name][2],
        )
        for name, _, initial_value in param_specs
    }

    def clip_to_slider_ranges(params):
        """Return the five parameter values, clipped to the slider bounds.

        Clipping avoids Gradio slider validation errors. A parameter missing
        from ``params`` falls back to its configured initial value. Values
        are converted to plain ``float`` for Gradio.
        """
        clipped = []
        for name, range_key, _ in param_specs:
            low, high, _step = Config.SLIDER_RANGES[range_key]
            value = params.get(name, Config.OPTIMIZABLE_PARAMS[name][1])
            clipped.append(float(np.clip(value, low, high)))
        return clipped

    loss_history = []
    iteration_cache = []

    # Initial update: raw image on both sides, empty plot/history
    yield (
        (original_normalized, original_normalized),
        pd.DataFrame({"iteration": [], "loss": []}),
        [],
        gr.skip(),  # Iteration slider untouched for now (avoid min=max=1 error)
        gr.Markdown(value="Starting optimization...", visible=True),
        gr.skip(),  # z_offset
        gr.skip(),  # na_det
        gr.skip(),  # na_ill
        gr.skip(),  # tilt_zenith
        gr.skip(),  # tilt_azimuth
        None,  # No reconstructed image yet
    )

    for result in run_optimization_streaming(
        zyx_stack,
        pixel_scales_state,
        Config.RECON_CONFIG,
        optimizable_params_with_slider_values,
        num_iterations=num_iterations,
    ):
        n = result["iteration"]
        iteration_cache.append(
            {
                "iteration": n,
                "reconstructed_image": result["reconstructed_image"],
                "loss": result["loss"],
                "params": result["params"],
                "raw_image": original_normalized,
            }
        )
        # int() so the loss plot gets an integer x-axis
        loss_history.append({"iteration": int(n), "loss": result["loss"]})

        info_md = f"**Iteration {n}/{num_iterations}** | Loss: `{result['loss']:.2e}`"
        slider_values = clip_to_slider_ranges(result["params"])

        yield (
            (original_normalized, result["reconstructed_image"]),
            pd.DataFrame(loss_history),
            iteration_cache,
            gr.Slider(  # Iteration slider grows from 1..1 up to 1..n
                minimum=1,
                maximum=n,
                value=n,
                step=1,
                visible=True,
                interactive=True,
            ),
            gr.Markdown(value=info_md, visible=True),
            *slider_values,  # z_offset, na_det, na_ill, tilt_zenith, tilt_azimuth
            result["reconstructed_image"],
        )

    # Final status. If the optimizer yielded nothing, there is no final loss to
    # report, so the loop variable must not be read here.
    if iteration_cache:
        final_loss = iteration_cache[-1]["loss"]
        final_md = f"**Optimization Complete!** Final Loss: `{final_loss:.2e}`"
    else:
        final_md = "**Optimization finished without producing any iterations.**"

    yield (
        gr.skip(),  # Keep last ImageSlider state
        gr.skip(),  # Keep last loss plot
        gr.skip(),  # Keep iteration history
        gr.skip(),  # Keep iteration slider
        gr.Markdown(value=final_md, visible=True),
        gr.skip(),  # Keep z_offset
        gr.skip(),  # Keep na_det
        gr.skip(),  # Keep na_ill
        gr.skip(),  # Keep tilt_zenith
        gr.skip(),  # Keep tilt_azimuth
        gr.skip(),  # Keep reconstructed image state
    )
| # ============================================================================ | |
| # ITERATION SCRUBBING CALLBACKS | |
| # ============================================================================ | |
def scrub_iterations(iteration_idx: int, history: list):
    """Show a cached optimization iteration and restore its parameters.

    Args:
        iteration_idx: 1-based iteration number from the iteration slider.
            Gradio may deliver it as a float, so it is converted to ``int``
            before indexing.
        history: Iteration cache built by ``run_optimization_ui``. Each entry
            holds ``"iteration"``, ``"reconstructed_image"``, ``"loss"``,
            ``"params"`` and ``"raw_image"``.

    Returns:
        ``(image_slider_value, info_markdown, z_offset, na_det, na_ill,
        tilt_zenith, tilt_azimuth)``. Parameter values are clipped to the
        slider ranges. Seven ``gr.skip()`` values are returned when there is
        no history or the index is missing or out of range.
    """
    if not history or iteration_idx is None:
        return (gr.skip(),) * 7  # image, info, and 5 parameter values

    # List indexing requires an int; slider values can arrive as floats
    idx = int(iteration_idx)
    if not 1 <= idx <= len(history):
        return (gr.skip(),) * 7

    selected = history[idx - 1]  # slider is 1-based, list is 0-based

    comparison = (selected["raw_image"], selected["reconstructed_image"])
    info_md = f"**Iteration {selected['iteration']}/{len(history)}** | Loss: `{selected['loss']:.2e}`"

    # (optimizer parameter name, Config.SLIDER_RANGES key), in output order.
    # Missing parameters fall back to their configured initial values.
    # Results are plain floats, clipped so Gradio accepts the slider update.
    params = selected["params"]
    param_specs = (
        ("z_offset", "z_offset"),
        ("numerical_aperture_detection", "na_detection"),
        ("numerical_aperture_illumination", "na_illumination"),
        ("tilt_angle_zenith", "tilt_zenith"),
        ("tilt_angle_azimuth", "tilt_azimuth"),
    )
    clipped = tuple(
        float(
            np.clip(
                params.get(name, Config.OPTIMIZABLE_PARAMS[name][1]),
                Config.SLIDER_RANGES[range_key][0],
                Config.SLIDER_RANGES[range_key][1],
            )
        )
        for name, range_key in param_specs
    )
    return (comparison, info_md, *clipped)
def clear_iteration_state():
    """Drop the cached optimization iterations after the Z coordinate changes.

    Returns:
        Updates for ``(iteration_history, iteration_slider, iteration_info)``:
        - an empty history;
        - the slider left untouched, which avoids a min == max slider error;
        - the info panel emptied and hidden.
    """
    empty_history = []
    slider_update = gr.skip()
    info_update = gr.Markdown(value="", visible=False)
    return empty_history, slider_update, info_update
| # ============================================================================ | |
| # UI CONSTRUCTION | |
| # ============================================================================ | |
def create_gradio_interface(plate_metadata, default_fields, data_xr, pixel_scales):
    """Assemble the Gradio Blocks app and connect its callbacks.

    Args:
        plate_metadata: Plate metadata, forwarded to the event wiring for FOV loading.
        default_fields: FOV names offered in the field-of-view dropdown.
        data_xr: Initially displayed FOV. Its ``Z`` size sets the Z-slider
            range, and the middle slice is shown first.
        pixel_scales: ``(z, y, x)`` pixel sizes stored as initial state.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    num_z = data_xr.sizes["Z"]
    middle_z = num_z // 2
    ranges = Config.SLIDER_RANGES
    initial_params = Config.OPTIMIZABLE_PARAMS

    def param_slider(range_key, param_key, label, info):
        """Slider bounded by SLIDER_RANGES[range_key], starting at the configured initial value."""
        low, high, step = ranges[range_key]
        return gr.Slider(
            minimum=low,
            maximum=high,
            value=initial_params[param_key][1],
            step=step,
            label=label,
            info=info,
        )

    with gr.Blocks() as demo:
        gr.Markdown("# WaveOrder")
        gr.Markdown(
            "**Paper:** Chandler T., Ivanov I.E., Hirata-Miyasaki E., et al. \"WaveOrder: Physics-informed ML for auto-tuned multi-contrast computational microscopy from cells to organisms.\" "
            "[arXiv:2412.09775](https://arxiv.org/abs/2412.09775) (2025)\n\n"
            "**GitHub Repository:** [mehta-lab/waveorder](https://github.com/mehta-lab/waveorder)"
        )
        gr.Markdown("---")

        # Field-of-view picker across the top of the page
        with gr.Row():
            fov_dropdown = gr.Dropdown(
                choices=default_fields,
                value=Config.DEFAULT_FIELD,
                label="Field of View",
                info=f"Select FOV from well {Config.DEFAULT_ROW}/{Config.DEFAULT_COLUMN}",
                scale=2,
            )
            load_fov_btn = gr.Button("🔄 Load FOV", variant="secondary", size="sm", scale=1)
        gr.Markdown("---")

        # Main area: viewer column (scale=4) beside controls column (scale=2),
        # i.e. roughly two-thirds / one-third of the row width
        with gr.Row():
            with gr.Column(scale=4):
                initial_preview = extract_2d_slice(
                    data_xr,
                    t=0,
                    c=Config.CHANNEL,
                    z=middle_z,
                    normalize=True,
                    verbose=False,
                )
                image_viewer = gr.ImageSlider(
                    label="Raw (left) vs Reconstructed (right) - Drag slider to compare",
                    type="numpy",
                    value=(initial_preview, initial_preview),
                    height=Config.IMAGE_HEIGHT,
                )
                gr.Markdown("---")

                # Z navigation below the viewer
                gr.Markdown("### 🎛️ Navigation")
                z_slider = gr.Slider(
                    minimum=0,
                    maximum=num_z - 1,
                    value=middle_z,
                    step=1,
                    label="Z-slice",
                    scale=1,
                )

            with gr.Column(scale=2):
                # Physical parameters of the reconstruction
                gr.Markdown("### ⚙️ Reconstruction Parameters")
                z_offset_slider = param_slider(
                    "z_offset", "z_offset", "Z Offset (μm)", "Axial focus offset"
                )
                na_det_slider = param_slider(
                    "na_detection",
                    "numerical_aperture_detection",
                    "NA Detection",
                    "Numerical aperture of detection objective",
                )
                na_ill_slider = param_slider(
                    "na_illumination",
                    "numerical_aperture_illumination",
                    "NA Illumination",
                    "Numerical aperture of illumination",
                )
                tilt_zenith_slider = param_slider(
                    "tilt_zenith",
                    "tilt_angle_zenith",
                    "Tilt Zenith (rad)",
                    "Zenith angle of illumination tilt",
                )
                tilt_azimuth_slider = param_slider(
                    "tilt_azimuth",
                    "tilt_angle_azimuth",
                    "Tilt Azimuth (rad)",
                    "Azimuthal angle of illumination tilt",
                )
                reset_params_btn = gr.Button(
                    "🔄 Reset Parameters", variant="secondary", size="sm"
                )
                gr.Markdown("---")

                # Actions: optimize parameters or reconstruct once
                gr.Markdown("### 🔬 Phase Reconstruction")
                num_iterations_slider = gr.Slider(
                    minimum=1,
                    maximum=50,
                    value=Config.RECON_CONFIG["num_iterations"],
                    step=1,
                    label="Optimization Iterations",
                    info="Number of gradient descent iterations (more = better quality, slower)",
                )
                with gr.Row():
                    optimize_btn = gr.Button(
                        "⚡ Optimize Parameters", variant="secondary", size="lg"
                    )
                    reconstruct_btn = gr.Button(
                        "🔬 Run Reconstruction", variant="primary", size="lg"
                    )
                gr.Markdown("---")

                # Optimization loss curve and iteration scrubbing
                gr.Markdown("### 📊 Optimization Results")
                loss_plot = gr.LinePlot(
                    x="iteration",
                    y="loss",
                    title="Optimization - Midband Spatial Frequency Loss",
                    height=200,
                    scale=2,
                    # Empty frame with the right columns so the plot renders before any data
                    value=pd.DataFrame({"iteration": [], "loss": []}),
                )
                iteration_slider = gr.Slider(
                    minimum=1,
                    maximum=1,
                    value=1,
                    step=1,
                    label="View Iteration",
                    info="Scrub through optimization history",
                    interactive=True,  # stays interactive; only hidden until an optimization runs
                    visible=False,
                )
                iteration_info = gr.Markdown(value="", visible=False)

        # Per-session state holders
        iteration_history = gr.State(value=[])
        current_data_xr = gr.State(value=data_xr)
        current_pixel_scales = gr.State(value=pixel_scales)
        current_reconstructed = gr.State(value=None)  # latest reconstructed image, if any
        gr.Markdown("---")

        _wire_event_handlers(
            demo,
            fov_dropdown,
            load_fov_btn,
            z_slider,
            image_viewer,
            z_offset_slider,
            na_det_slider,
            na_ill_slider,
            tilt_zenith_slider,
            tilt_azimuth_slider,
            reset_params_btn,
            num_iterations_slider,
            optimize_btn,
            reconstruct_btn,
            loss_plot,
            iteration_slider,
            iteration_info,
            iteration_history,
            current_data_xr,
            current_pixel_scales,
            current_reconstructed,
            plate_metadata,
        )
    return demo
def _wire_event_handlers(
    demo,
    fov_dropdown,
    load_fov_btn,
    z_slider,
    image_viewer,
    z_offset_slider,
    na_det_slider,
    na_ill_slider,
    tilt_zenith_slider,
    tilt_azimuth_slider,
    reset_params_btn,
    num_iterations_slider,
    optimize_btn,
    reconstruct_btn,
    loss_plot,
    iteration_slider,
    iteration_info,
    iteration_history,
    current_data_xr,
    current_pixel_scales,
    current_reconstructed,
    plate_metadata,
):
    """Connect every Gradio event of the viewer to its callback.

    All arguments except ``demo`` (the Blocks app) and ``plate_metadata``
    (forwarded to FOV loading) are components or ``gr.State`` holders of
    the interface.
    """

    # FOV loading. The stored reconstruction and the iteration cache belong
    # to the FOV they were computed on, so they are reset in the SAME update
    # that swaps in the new data. Otherwise, a later Z change would pair the
    # new raw slice with the previous FOV's reconstruction, and scrubbing
    # would replay the previous FOV's iterations.
    # NOTE: load_selected_fov swallows its own errors, so this reset also
    # happens when loading fails (the old FOV then stays displayed).
    def load_fov_and_reset(field, z):
        """Load ``field`` and clear the reconstruction/iteration state."""
        fov_updates = load_selected_fov(field, z, plate_metadata)
        return (
            *fov_updates,  # z_slider, image_viewer, data state, pixel-scale state
            None,  # current_reconstructed
            [],  # iteration_history
            gr.Markdown(value="", visible=False),  # iteration_info
        )

    load_fov_btn.click(
        fn=load_fov_and_reset,
        inputs=[fov_dropdown, z_slider],
        outputs=[
            z_slider,
            image_viewer,
            current_data_xr,
            current_pixel_scales,
            current_reconstructed,
            iteration_history,
            iteration_info,
        ],
    )

    # Reset parameters to initial values
    def reset_parameters():
        """Reset all reconstruction parameters to their initial config values."""
        return (
            Config.OPTIMIZABLE_PARAMS["z_offset"][1],
            Config.OPTIMIZABLE_PARAMS["numerical_aperture_detection"][1],
            Config.OPTIMIZABLE_PARAMS["numerical_aperture_illumination"][1],
            Config.OPTIMIZABLE_PARAMS["tilt_angle_zenith"][1],
            Config.OPTIMIZABLE_PARAMS["tilt_angle_azimuth"][1],
        )

    reset_params_btn.click(
        fn=reset_parameters,
        inputs=[],
        outputs=[
            z_offset_slider,
            na_det_slider,
            na_ill_slider,
            tilt_zenith_slider,
            tilt_azimuth_slider,
        ],
    )

    # NA slider linking: keep NA_illumination <= NA_detection (physical constraint).
    # Only enforced when NA_detection changes, to avoid a feedback loop.
    def enforce_na_constraint(na_det_value, na_ill_value):
        """When NA_detection decreases below NA_illumination, cap NA_illumination."""
        return min(na_ill_value, na_det_value)

    na_det_slider.change(
        fn=enforce_na_constraint,
        inputs=[na_det_slider, na_ill_slider],
        outputs=[na_ill_slider],
    )

    # On load: preview mode (no reconstruction yet)
    demo.load(
        fn=get_slice_for_preview,
        inputs=[z_slider, current_data_xr],
        outputs=image_viewer,
    )

    # On Z change: update only the left (original) image, keep the reconstruction on the right
    z_slider.change(
        fn=update_original_slice_only,
        inputs=[z_slider, current_data_xr, current_reconstructed],
        outputs=image_viewer,
    )

    # Reconstruction buttons
    optimize_btn.click(
        fn=run_optimization_ui,
        inputs=[
            z_slider,
            num_iterations_slider,
            z_offset_slider,
            na_det_slider,
            na_ill_slider,
            tilt_zenith_slider,
            tilt_azimuth_slider,
            current_data_xr,
            current_pixel_scales,
        ],
        outputs=[
            image_viewer,
            loss_plot,
            iteration_history,
            iteration_slider,
            iteration_info,
            z_offset_slider,
            na_det_slider,
            na_ill_slider,
            tilt_zenith_slider,
            tilt_azimuth_slider,
            current_reconstructed,  # Update reconstructed state
        ],
    )
    reconstruct_btn.click(
        fn=run_reconstruction_ui,
        inputs=[
            z_slider,
            z_offset_slider,
            na_det_slider,
            na_ill_slider,
            tilt_zenith_slider,
            tilt_azimuth_slider,
            current_data_xr,
            current_pixel_scales,
        ],
        outputs=[image_viewer, current_reconstructed],  # Update both viewer and state
    )

    # Iteration scrubbing - updates the image AND all parameter sliders
    iteration_slider.change(
        fn=scrub_iterations,
        inputs=[iteration_slider, iteration_history],
        outputs=[
            image_viewer,
            iteration_info,
            z_offset_slider,
            na_det_slider,
            na_ill_slider,
            tilt_zenith_slider,
            tilt_azimuth_slider,
        ],
    )

    # Clear iteration state when Z changes
    z_slider.change(
        fn=clear_iteration_state,
        inputs=[],
        outputs=[iteration_history, iteration_slider, iteration_info],
    )