Spaces:
Sleeping
Sleeping
| """ | |
| Utility functions for Gradio demos | |
| Provides reusable components for: | |
| - Data loading from OME-Zarr stores | |
| - Image normalization and processing | |
| - Slice extraction from xarray DataArrays | |
| - Phase reconstruction and optimization | |
| Design Notes | |
| ------------ | |
| All image processing functions work with xarray.DataArray to maintain | |
| labeled dimensions and coordinate information as long as possible. | |
| Only convert to numpy arrays at the final display step. | |
| """ | |
| from pathlib import Path | |
| from typing import Generator | |
| import numpy as np | |
| import torch | |
| import xarray as xr | |
| from numpy.typing import NDArray | |
| from xarray_ome import open_ome_dataset | |
| from waveorder import util | |
| from waveorder.models import isotropic_thin_3d | |
| from waveorder.cli.compute_transfer_function import ( | |
| _position_list_from_shape_scale_offset, | |
| ) | |
| # Type alias for device specification | |
| Device = torch.device | str | None | |
| def get_device(device: Device = None) -> torch.device: | |
| """ | |
| Get torch device with smart defaults. | |
| Parameters | |
| ---------- | |
| device : torch.device | str | None | |
| If None, auto-selects cuda if available, else cpu. | |
| If str, converts to torch.device. | |
| If torch.device, returns as-is. | |
| Returns | |
| ------- | |
| torch.device | |
| Validated device ready for use | |
| Examples | |
| -------- | |
| >>> get_device() # Auto-detect | |
| device(type='cuda', index=0) # if GPU available | |
| >>> get_device("cpu") # Force CPU | |
| device(type='cpu') | |
| >>> get_device(torch.device("cuda:1")) # Specific GPU | |
| device(type='cuda', index=1) | |
| """ | |
| if device is None: | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| if device.type == "cuda": | |
| print(f"🚀 Using GPU: {torch.cuda.get_device_name(device)}") | |
| gpu_mem_gb = torch.cuda.get_device_properties(device).total_memory / 1e9 | |
| print(f" GPU Memory: {gpu_mem_gb:.2f} GB") | |
| else: | |
| print("💻 Using CPU (GPU not available)") | |
| return device | |
| if isinstance(device, str): | |
| return torch.device(device) | |
| return device | |
| # === HCS Plate Loading with iohub === | |
def get_plate_metadata(zarr_path: Path | str, allowed_fovs: list[str]) -> dict:
    """
    Extract HCS plate metadata for FOV selection using iohub.

    Optimized to only load metadata for specified FOVs.

    Parameters
    ----------
    zarr_path : Path | str
        Path to the HCS plate zarr store
    allowed_fovs : list[str]
        List of allowed FOV names (e.g., ['002026', '002027', '002028'])

    Returns
    -------
    dict
        Metadata with keys:
        - 'rows': list of row names (e.g., ['A'])
        - 'columns': list of column names (e.g., ['1', '2', '3'])
        - 'wells': dict mapping (row, col) to list of field names
        - 'plate': iohub Plate object for later access
        - 'zarr_path': stored path for data loading
    """
    from iohub import open_ome_zarr

    # Opening the plate with iohub is cheap: no pixel data is read here.
    plate = open_ome_zarr(str(zarr_path), mode="r", layout="hcs")

    # Plate layout is hardcoded for the known dataset structure; this avoids
    # iterating thousands of positions just to enumerate wells.
    rows = ["A"]
    columns = ["1", "2", "3"]

    # Every well exposes the same whitelist of FOVs.
    wells = {("A", col): allowed_fovs for col in columns}

    return {
        "rows": rows,
        "columns": columns,
        "wells": wells,
        "plate": plate,
        "zarr_path": str(zarr_path),
    }
def load_fov_from_plate(
    plate, row: str, column: str, field: str, resolution: int = 0
) -> xr.DataArray:
    """
    Load a specific FOV from HCS plate using hybrid iohub + xarray-ome approach.

    Uses iohub for navigation, then xarray-ome for fast data loading.

    Parameters
    ----------
    plate : iohub.Plate
        Plate loaded with open_ome_zarr(..., layout="hcs")
    row : str
        Row name (e.g., 'A')
    column : str
        Column name (e.g., '1')
    field : str
        Field/position name (e.g., '002026')
    resolution : int, optional
        Resolution level to load, by default 0

    Returns
    -------
    xr.DataArray
        Image data with labeled dimensions (T, C, Z, Y, X)
    """
    # iohub navigation is fast: indexing the plate does not read pixel data.
    position = plate[f"{row}/{column}/{field}"]

    # Resolve the on-disk location of the position group. Zarr V2
    # DirectoryStore exposes `.path`; Zarr V3 LocalStore exposes `.root`.
    store = position.zgroup.store
    store_root = getattr(store, "path", None)
    if store_root is None:
        store_root = getattr(store, "root", None)
    if store_root is None:
        raise RuntimeError(f"Unknown store type: {type(store)}")
    position_path = Path(store_root) / position.zgroup.path

    # Hand the resolved path to xarray-ome for the actual (lazy) data load.
    fov_dataset = open_ome_dataset(position_path, resolution=resolution, validate=False)
    return fov_dataset["image"]
| # === Data Loading === | |
def load_ome_zarr_fov(
    zarr_path: Path | str, fov_path: Path | str, resolution: int = 0
) -> xr.DataArray:
    """
    Load a field of view from an OME-Zarr store as an xarray DataArray.

    Parameters
    ----------
    zarr_path : Path | str
        Path to the root OME-Zarr store
    fov_path : Path | str
        Relative path to the FOV (e.g., "A/1/001007")
    resolution : int, optional
        Resolution level to load (0 is full resolution), by default 0

    Returns
    -------
    xr.DataArray
        Image data with labeled dimensions (T, C, Z, Y, X)
    """
    zarr_path = Path(zarr_path)
    fov_path = Path(fov_path)

    print(f"Loading zarr store from: {zarr_path}")
    print(f"Accessing FOV: {fov_path}")

    # Open the FOV as an xarray Dataset (lazy / Dask-backed load).
    fov_dataset: xr.Dataset = open_ome_dataset(
        zarr_path / fov_path, resolution=resolution, validate=False
    )

    # The pixel data lives under the "image" variable.
    image = fov_dataset["image"]

    print(f"Loaded data shape: {dict(image.sizes)}")
    print(f"Dimensions: {list(image.dims)}")
    print(f"Data type: {image.dtype}")

    return image
| # === Image Processing === | |
def normalize_for_display(
    img_2d: xr.DataArray,
    percentiles: tuple[float, float] = (1, 99),
    clip_to_uint8: bool = True,
) -> np.ndarray:
    """
    Normalize a 2D microscopy image using percentile clipping.

    Uses robust percentile-based normalization to handle outliers
    common in microscopy data. Works with xarray DataArrays to maintain
    labeled dimensions through the processing pipeline.

    Parameters
    ----------
    img_2d : xr.DataArray
        2D image DataArray to normalize
    percentiles : tuple[float, float], optional
        Lower and upper percentiles for clipping, by default (1, 99)
    clip_to_uint8 : bool, optional
        If True, convert to uint8 (0-255), otherwise keep as float (0-1),
        by default True

    Returns
    -------
    np.ndarray
        Normalized numpy array (uint8 if clip_to_uint8=True, else float32)

    Notes
    -----
    Expects xarray.DataArray input. For raw numpy arrays,
    wrap in xarray first: xr.DataArray(array, dims=["Y", "X"])
    """
    lo_pct, hi_pct = percentiles
    p_low = float(img_2d.quantile(lo_pct / 100.0).values)
    p_high = float(img_2d.quantile(hi_pct / 100.0).values)
    span = p_high - p_low

    # Degenerate image (no intensity variation): return an all-zero frame
    # rather than dividing by ~0.
    if span < 1e-10:
        return np.zeros(img_2d.shape, dtype=np.uint8 if clip_to_uint8 else np.float32)

    # Clip outliers, then rescale to [0, 1] — all as xarray ops.
    scaled = (img_2d.clip(min=p_low, max=p_high) - p_low) / span

    # Materialize to numpy only at this final step.
    arr = scaled.values
    if clip_to_uint8:
        return (arr * 255).astype(np.uint8)
    return arr.astype(np.float32)
| # === Slice Extraction === | |
def extract_2d_slice(
    data_xr: xr.DataArray,
    t: int | None = None,
    c: int | None = None,
    z: int | None = None,
    normalize: bool = True,
    verbose: bool = True,
) -> np.ndarray:
    """
    Extract and optionally normalize a 2D slice from xarray data.

    Flexibly handles different dimension specifications. If a dimension
    index is None, it will be squeezed out if size=1 or raise an error
    if size>1.

    Parameters
    ----------
    data_xr : xr.DataArray
        Image data with dimensions (T, C, Z, Y, X)
    t : int | None, optional
        Timepoint index, by default None
    c : int | None, optional
        Channel index, by default None
    z : int | None, optional
        Z-slice index, by default None
    normalize : bool, optional
        Whether to normalize for display, by default True
    verbose : bool, optional
        Whether to print slice information, by default True

    Returns
    -------
    np.ndarray
        2D numpy array (normalized uint8 if normalize=True, else raw values)

    Raises
    ------
    ValueError
        If result is empty or not 2D after slicing and squeezing
    """
    # Gather requested integer indices per named axis, skipping Nones.
    requested = {"T": t, "C": c, "Z": z}
    sel_dict = {axis: int(idx) for axis, idx in requested.items() if idx is not None}

    # Positional slicing via labeled dims; no-op when nothing was requested.
    slice_xr = data_xr.isel(**sel_dict) if sel_dict else data_xr

    # Dask-backed arrays are lazy — force the read from disk here.
    if hasattr(slice_xr.data, "compute"):
        slice_xr = slice_xr.compute()

    # Drop any remaining singleton axes (e.g. single channel or single Z).
    slice_xr = slice_xr.squeeze()

    # Guard: selection must not produce an empty array.
    if slice_xr.size == 0:
        raise ValueError(
            f"Empty array after slicing. Selection: {sel_dict}, "
            f"Original shape: {data_xr.shape}"
        )

    # Guard: the result must be exactly 2D for display.
    if slice_xr.ndim != 2:
        raise ValueError(
            f"Expected 2D array after slicing, got shape {slice_xr.shape}. "
            f"Selection: {sel_dict}"
        )

    if verbose:
        if sel_dict:
            sel_str = ", ".join(f"{k}={v}" for k, v in sel_dict.items())
        else:
            sel_str = "full array"
        print(
            f"Extracted slice: {sel_str}, Shape={slice_xr.shape}, "
            f"Range=[{float(slice_xr.min()):.1f}, {float(slice_xr.max()):.1f}]"
        )

    # Either normalize for display or hand back the raw values.
    if normalize:
        return normalize_for_display(slice_xr)
    return slice_xr.values
| # === Slice Extraction Factory === | |
def create_slice_extractor(
    data_xr: xr.DataArray,
    normalize: bool = True,
    channel: int = 0,
):
    """
    Create a closure function for extracting slices from a specific dataset.

    This factory function is useful for Gradio callbacks where the data
    is loaded once and the same extraction function is called multiple times.

    Parameters
    ----------
    data_xr : xr.DataArray
        Image data to extract slices from
    normalize : bool, optional
        Whether to normalize for display, by default True
    channel : int, optional
        Default channel to use, by default 0

    Returns
    -------
    callable
        Function with signature (t: int, z: int) -> np.ndarray that extracts
        and normalizes 2D slices
    """

    def _slice_at(t: int, z: int) -> np.ndarray:
        """Extract and normalize the 2D slice at timepoint t and z-slice z."""
        # data_xr / normalize / channel are captured from the factory call.
        return extract_2d_slice(
            data_xr, t=int(t), c=channel, z=int(z), normalize=normalize, verbose=True
        )

    return _slice_at
| # === Metadata Helpers === | |
def get_dimension_info(data_xr: xr.DataArray) -> dict:
    """
    Extract dimension information from xarray DataArray.

    Parameters
    ----------
    data_xr : xr.DataArray
        Image data with dimensions

    Returns
    -------
    dict
        Dictionary with keys: 'sizes', 'dims', 'coords', 'dtype'.
        'coords' maps each dim name to a list of coordinate values;
        dims without an attached coordinate map to an empty list.
    """
    return {
        "sizes": dict(data_xr.sizes),
        "dims": list(data_xr.dims),
        # In xarray, a dimension need not have a coordinate attached;
        # indexing data_xr.coords[dim] would raise KeyError for such dims,
        # so fall back to an empty list.
        "coords": {
            dim: (
                data_xr.coords[dim].values.tolist()
                if dim in data_xr.coords
                else []
            )
            for dim in data_xr.dims
        },
        "dtype": str(data_xr.dtype),
    }
def print_data_summary(data_xr: xr.DataArray) -> None:
    """
    Print a formatted summary of xarray DataArray.

    Parameters
    ----------
    data_xr : xr.DataArray
        Image data to summarize
    """
    info = get_dimension_info(data_xr)
    print("\n" + "=" * 60)
    print("DATA SUMMARY")
    print("=" * 60)
    print(f"Shape: {info['sizes']}")
    print(f"Dimensions: {info['dims']}")
    print(f"Data type: {info['dtype']}")
    # Print coordinate ranges
    print("\nCoordinate Ranges:")
    for dim in info["dims"]:
        coords = info["coords"][dim]
        if len(coords) > 0:
            first, last = coords[0], coords[-1]
            # Fix: the ':.2f' format spec raises ValueError for non-numeric
            # coordinates (e.g. string channel names), so only apply it to
            # numeric values.
            if isinstance(first, (int, float)) and isinstance(last, (int, float)):
                print(f" {dim}: [{first:.2f} ... {last:.2f}] (n={len(coords)})")
            else:
                print(f" {dim}: [{first} ... {last}] (n={len(coords)})")
    # Print memory size estimate (elements x bytes-per-element)
    total_elements = np.prod(list(info["sizes"].values()))
    dtype_size = np.dtype(data_xr.dtype).itemsize
    size_mb = (total_elements * dtype_size) / (1024**2)
    print(f"\nEstimated size: {size_mb:.1f} MB")
    print("=" * 60 + "\n")
| # === Phase Reconstruction Functions === | |
def run_reconstruction(zyx_tile: torch.Tensor, recon_args: dict) -> torch.Tensor:
    """
    Run phase reconstruction on a Z-stack.

    Uses waveorder's official _position_list_from_shape_scale_offset
    to ensure proper z-position calculation and correct phase sign.

    Parameters
    ----------
    zyx_tile : torch.Tensor
        Input Z-stack data with shape (Z, Y, X). Can be on CPU or GPU.
    recon_args : dict
        Reconstruction arguments including wavelength, NA, pixel sizes, etc.
        Must contain "z_offset" and "z_scale" (both consumed here, not
        forwarded); the remaining keys are passed straight through to
        isotropic_thin_3d.calculate_transfer_function.
        All tensor values should be on the same device as zyx_tile.

    Returns
    -------
    torch.Tensor
        Reconstructed 2D phase image with shape (Y, X), on same device as input.

    Notes
    -----
    All intermediate tensors are created on the same device as the input
    to ensure efficient computation without device transfers.
    """
    # Infer device from input tensor
    device = zyx_tile.device
    # Copy so the pops below never mutate the caller's recon_args dict
    tf_args = recon_args.copy()
    Z, _, _ = zyx_tile.shape
    # Extract z_offset value (keep as tensor if it is one, for gradient flow)
    z_offset_value = recon_args["z_offset"]
    if torch.is_tensor(z_offset_value):
        # For optimization: extract scalar value for _position_list function
        # (.item() detaches from the graph — the gradient path is restored
        # in the requires_grad branch below)
        z_offset_scalar = z_offset_value.item()
    else:
        z_offset_scalar = z_offset_value
    # Use waveorder's official function (returns torch.Tensor on CPU)
    z_position_list_cpu = _position_list_from_shape_scale_offset(
        shape=Z,
        scale=recon_args["z_scale"],
        offset=z_offset_scalar,
    )
    # Move to device and ensure gradient connection if z_offset is a parameter
    if torch.is_tensor(z_offset_value) and z_offset_value.requires_grad:
        # Recompute on device to maintain gradient connection
        # Uses same formula as waveorder: -arange(Z) + (Z // 2) + offset
        # NOTE(review): assumed to mirror _position_list_from_shape_scale_offset
        # exactly — re-verify if waveorder changes that helper
        z_position_list = (
            -torch.arange(Z, dtype=torch.float32, device=device) + (Z // 2) + z_offset_value
        ) * recon_args["z_scale"]
    else:
        # No gradient needed, just move to device
        z_position_list = z_position_list_cpu.to(device)
    tf_args["z_position_list"] = z_position_list
    # These two keys were consumed above; calculate_transfer_function
    # does not accept them
    tf_args.pop("z_offset")
    tf_args.pop("z_scale")
    # Core reconstruction calls (all on same device)
    tf_abs, tf_phase = isotropic_thin_3d.calculate_transfer_function(**tf_args)
    system = isotropic_thin_3d.calculate_singular_system(tf_abs, tf_phase)
    _, yx_phase_recon = isotropic_thin_3d.apply_inverse_transfer_function(
        zyx_tile, system, regularization_strength=1e-2
    )
    return yx_phase_recon
def compute_midband_power(
    yx_array: torch.Tensor,
    NA_det: float,
    lambda_ill: float,
    pixel_size: float,
    band: tuple[float, float] = (0.125, 0.25),
) -> torch.Tensor:
    """
    Compute midband power metric for optimization loss.

    Parameters
    ----------
    yx_array : torch.Tensor
        2D reconstructed image (on CPU or GPU)
    NA_det : float
        Numerical aperture of detection
    lambda_ill : float
        Illumination wavelength
    pixel_size : float
        Pixel size in same units as wavelength
    band : tuple[float, float], optional
        Frequency band as fraction of cutoff, by default (0.125, 0.25)

    Returns
    -------
    torch.Tensor
        Scalar power value in the specified frequency band, on same device as input.

    Notes
    -----
    All operations are performed on the same device as the input tensor
    for efficient GPU computation.
    """
    device = yx_array.device

    # waveorder's coordinate generator returns numpy frequency grids.
    _, _, fxx, fyy = util.gen_coordinate(yx_array.shape, pixel_size)

    # Radial frequency magnitude, moved onto the input's device.
    radial_freq = torch.tensor(
        np.sqrt(fxx**2 + fyy**2), dtype=torch.float32, device=device
    )

    # Magnitude spectrum of the image (device-resident FFT).
    spectrum_mag = torch.abs(torch.fft.fftn(yx_array))

    # Keep only frequencies within the requested fraction of the cutoff.
    f_cutoff = 2 * NA_det / lambda_ill
    lo_frac, hi_frac = band
    in_band = (radial_freq > f_cutoff * lo_frac) & (radial_freq < f_cutoff * hi_frac)

    return torch.sum(spectrum_mag[in_band])
def prepare_optimizer(
    optimizable_params: dict[str, tuple[bool, float, float]],
    device: torch.device,
) -> tuple[dict[str, torch.nn.Parameter], torch.optim.Optimizer]:
    """
    Prepare optimization parameters and Adam optimizer.

    Parameters
    ----------
    optimizable_params : dict
        Dict mapping param names to (enabled, initial_value, learning_rate)
    device : torch.device
        Device to create parameters on (CPU or GPU)

    Returns
    -------
    tuple[dict, Optimizer]
        optimization_params dict and configured optimizer

    Notes
    -----
    All parameters are created on the specified device for efficient
    GPU-accelerated optimization if available.
    """
    optimization_params: dict[str, torch.nn.Parameter] = {}
    param_groups = []
    for name, (enabled, initial, lr) in optimizable_params.items():
        # Disabled entries are skipped entirely — they get neither a
        # Parameter nor an optimizer group.
        if not enabled:
            continue
        seed = torch.tensor([initial], dtype=torch.float32, device=device)
        param = torch.nn.Parameter(seed, requires_grad=True)
        optimization_params[name] = param
        # One group per parameter so each keeps its own learning rate.
        param_groups.append({"params": [param], "lr": lr})
    return optimization_params, torch.optim.Adam(param_groups)
def run_reconstruction_single(
    zyx_stack: np.ndarray,
    pixel_scales: tuple[float, float, float],
    fixed_params: dict,
    param_values: dict,
    device: Device = None,
) -> np.ndarray:
    """
    Run a single phase reconstruction with specified parameters (no optimization).

    Parameters
    ----------
    zyx_stack : np.ndarray
        Input Z-stack with shape (Z, Y, X)
    pixel_scales : tuple[float, float, float]
        (z_scale, y_scale, x_scale) in micrometers
    fixed_params : dict
        Fixed reconstruction parameters (wavelength, index, etc.)
    param_values : dict
        Parameter values to use (z_offset, numerical_aperture_detection, etc.)
    device : torch.device | str | None, optional
        Computing device. If None, auto-selects GPU if available, else CPU.

    Returns
    -------
    np.ndarray
        Normalized uint8 array of reconstructed phase image (for display)
    """
    # Resolve device (prints GPU info when one is available).
    device = get_device(device)

    # Single host-to-device transfer of the raw stack.
    zyx_tile = torch.tensor(zyx_stack, dtype=torch.float32, device=device)

    # Assemble reconstruction arguments from the fixed parameters,
    # dropping keys that are not reconstruction inputs.
    z_scale, y_scale, _x_scale = pixel_scales
    recon_args = fixed_params.copy()
    for extraneous in ("num_iterations", "use_tiling", "device"):
        recon_args.pop(extraneous, None)
    recon_args["yx_shape"] = zyx_tile.shape[1:]
    recon_args["yx_pixel_size"] = y_scale
    recon_args["z_scale"] = z_scale

    # Inject the caller-supplied parameter values as device tensors.
    for name, value in param_values.items():
        recon_args[name] = torch.tensor([value], dtype=torch.float32, device=device)

    yx_recon = run_reconstruction(zyx_tile, recon_args)

    # Bring the result back to CPU and normalize for display
    # (normalize_for_display expects an xr.DataArray).
    return normalize_for_display(xr.DataArray(yx_recon.detach().cpu().numpy()))
def run_optimization_streaming(
    zyx_stack: np.ndarray,
    pixel_scales: tuple[float, float, float],
    fixed_params: dict,
    optimizable_params: dict,
    num_iterations: int = 10,
    device: Device = None,
) -> Generator[dict, None, None]:
    """
    Run phase reconstruction optimization with streaming updates.

    Generator that yields reconstruction results and loss after each iteration.
    Supports GPU acceleration for significant speedup (15-25x on typical hardware).

    Parameters
    ----------
    zyx_stack : np.ndarray
        Input Z-stack with shape (Z, Y, X)
    pixel_scales : tuple[float, float, float]
        (z_scale, y_scale, x_scale) in micrometers
    fixed_params : dict
        Fixed reconstruction parameters (wavelength, index, etc.)
    optimizable_params : dict
        Parameters to optimize with (enabled, initial, lr) tuples
    num_iterations : int, optional
        Number of optimization iterations, by default 10
    device : torch.device | str | None, optional
        Computing device. If None, auto-selects GPU if available, else CPU.
        Examples: "cuda", "cpu", "cuda:0", torch.device("cuda")
        By default None

    Yields
    ------
    dict
        Dictionary with keys:
        - 'reconstructed_image': normalized uint8 array (on CPU for display)
        - 'loss': float loss value
        - 'iteration': int iteration number (1-indexed)
        - 'params': dict of current parameter values

    Notes
    -----
    All computation is performed on the specified device (GPU if available).
    Only final results are transferred to CPU for display, minimizing
    transfer overhead.
    """
    # Resolve device (will print GPU info if available)
    device = get_device(device)
    # Convert to torch tensor on target device (single transfer)
    zyx_tile = torch.tensor(zyx_stack, dtype=torch.float32, device=device)
    # Prepare reconstruction arguments
    z_scale, y_scale, x_scale = pixel_scales
    recon_args = fixed_params.copy()
    # Remove non-reconstruction parameters from fixed_params
    recon_args.pop("num_iterations", None)
    recon_args.pop("use_tiling", None)
    recon_args.pop("device", None)  # Remove device if present
    recon_args["yx_shape"] = zyx_tile.shape[1:]
    recon_args["yx_pixel_size"] = y_scale
    recon_args["z_scale"] = z_scale
    # Initialize optimizable parameters on device (these placeholder tensors
    # are replaced by the live Parameters inside the loop below)
    for name, (enabled, initial, lr) in optimizable_params.items():
        recon_args[name] = torch.tensor([initial], dtype=torch.float32, device=device)
    # Prepare optimizer with parameters on device
    optimization_params, optimizer = prepare_optimizer(optimizable_params, device)
    # Optimization loop (all on device)
    for step in range(num_iterations):
        # Rebind the current Parameter objects so run_reconstruction sees
        # the values updated by the previous optimizer.step()
        for name, param in optimization_params.items():
            recon_args[name] = param
        # Run reconstruction (all on device)
        yx_recon = run_reconstruction(zyx_tile, recon_args)
        # Compute loss (all on device, negative midband power - we want to maximize)
        # NOTE(review): NA_det=0.15 and band=(0.1, 0.2) are hardcoded here and
        # differ from compute_midband_power's defaults — confirm intentional
        loss = -compute_midband_power(
            yx_recon,
            NA_det=0.15,
            lambda_ill=recon_args["wavelength_illumination"],
            pixel_size=recon_args["yx_pixel_size"],
            band=(0.1, 0.2),
        )
        # Backward pass and optimizer step (on device); grads are zeroed
        # after stepping so the next iteration starts clean
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Transfer to CPU ONLY for display (single transfer per iteration)
        recon_numpy = yx_recon.detach().cpu().numpy()
        # Wrap in xarray for normalize_for_display (expects xr.DataArray)
        recon_normalized = normalize_for_display(xr.DataArray(recon_numpy))
        # Extract current parameter values (scalars, already on CPU)
        param_values = {
            name: param.item() for name, param in optimization_params.items()
        }
        # Yield results
        yield {
            "reconstructed_image": recon_normalized,
            "loss": loss.item(),
            "iteration": step + 1,
            "params": param_values,
        }
def extract_tiles(
    zyx_data: np.ndarray, num_tiles: tuple[int, int], overlap_pct: float
) -> tuple[dict[str, np.ndarray], dict[str, tuple[int, int, int]]]:
    """
    Extract overlapping tiles from a Z-stack for processing.

    Parameters
    ----------
    zyx_data : np.ndarray
        Input data with shape (Z, Y, X)
    num_tiles : tuple[int, int]
        Number of tiles in (Y, X) dimensions
    overlap_pct : float
        Overlap percentage between tiles (0.0 to 1.0)

    Returns
    -------
    tuple[dict, dict]
        tiles: dict mapping tile names to arrays
        translations: dict mapping tile names to (z, y, x) positions

    Notes
    -----
    The last tile in each dimension is anchored to the image edge. Without
    this, the floored stride could leave a strip of pixels at the bottom/right
    border uncovered (e.g. Y=10, 3 tiles, 50% overlap gave tiles ending at
    row 9, missing the last row).
    """
    Z, Y, X = zyx_data.shape
    n_y, n_x = num_tiles

    # Tile size such that n tiles with the requested fractional overlap
    # span the full extent (ceil so coverage is never undershot by size).
    tile_height = int(np.ceil(Y / (n_y - (n_y - 1) * overlap_pct)))
    tile_width = int(np.ceil(X / (n_x - (n_x - 1) * overlap_pct)))
    stride_y = int(tile_height * (1 - overlap_pct))
    stride_x = int(tile_width * (1 - overlap_pct))

    tiles: dict[str, np.ndarray] = {}
    translations: dict[str, tuple[int, int, int]] = {}
    for yi in range(n_y):
        # Fix: snap the final tile to the image edge so flooring the stride
        # cannot leave the bottom border uncovered.
        y0 = max(0, Y - tile_height) if yi == n_y - 1 else yi * stride_y
        y1 = min(y0 + tile_height, Y)
        for xi in range(n_x):
            # Same edge-snapping for the right border.
            x0 = max(0, X - tile_width) if xi == n_x - 1 else xi * stride_x
            x1 = min(x0 + tile_width, X)
            tile_name = f"0/0/{yi:03d}{xi:03d}"
            tiles[tile_name] = zyx_data[:, y0:y1, x0:x1]
            translations[tile_name] = (0, y0, x0)
    return tiles, translations