# WaveOrder / demo_utils.py
# Last change: "updated datasets" (commit ec6b668, by srivarra)
"""
Utility functions for Gradio demos
Provides reusable components for:
- Data loading from OME-Zarr stores
- Image normalization and processing
- Slice extraction from xarray DataArrays
- Phase reconstruction and optimization
Design Notes
------------
All image processing functions work with xarray.DataArray to maintain
labeled dimensions and coordinate information as long as possible.
Only convert to numpy arrays at the final display step.
"""
from pathlib import Path
from typing import Generator
import numpy as np
import torch
import xarray as xr
from numpy.typing import NDArray
from xarray_ome import open_ome_dataset
from waveorder import util
from waveorder.models import isotropic_thin_3d
from waveorder.cli.compute_transfer_function import (
_position_list_from_shape_scale_offset,
)
# Type alias for device specification
Device = torch.device | str | None
def get_device(device: Device = None) -> torch.device:
"""
Get torch device with smart defaults.
Parameters
----------
device : torch.device | str | None
If None, auto-selects cuda if available, else cpu.
If str, converts to torch.device.
If torch.device, returns as-is.
Returns
-------
torch.device
Validated device ready for use
Examples
--------
>>> get_device() # Auto-detect
device(type='cuda', index=0) # if GPU available
>>> get_device("cpu") # Force CPU
device(type='cpu')
>>> get_device(torch.device("cuda:1")) # Specific GPU
device(type='cuda', index=1)
"""
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
print(f"🚀 Using GPU: {torch.cuda.get_device_name(device)}")
gpu_mem_gb = torch.cuda.get_device_properties(device).total_memory / 1e9
print(f" GPU Memory: {gpu_mem_gb:.2f} GB")
else:
print("💻 Using CPU (GPU not available)")
return device
if isinstance(device, str):
return torch.device(device)
return device
# === HCS Plate Loading with iohub ===
def get_plate_metadata(zarr_path: Path | str, allowed_fovs: list[str]) -> dict:
    """
    Open an HCS plate with iohub and describe the FOVs offered by the demo.

    The plate layout is hard-coded (row ``A``, columns ``1``-``3``) rather
    than discovered, so thousands of positions never have to be iterated;
    every well simply exposes the FOV names in ``allowed_fovs``.

    Parameters
    ----------
    zarr_path : Path | str
        Path to the HCS plate zarr store
    allowed_fovs : list[str]
        FOV names to expose in every well (e.g., ['002026', '002027', '002028'])

    Returns
    -------
    dict
        Metadata with keys:
        - 'rows': list of row names (``['A']``)
        - 'columns': list of column names (``['1', '2', '3']``)
        - 'wells': dict mapping (row, col) to that well's own list of field
          names (an independent copy of ``allowed_fovs``)
        - 'plate': the iohub plate object, left open for later access
        - 'zarr_path': ``zarr_path`` as a string, for data loading
    """
    from iohub import open_ome_zarr

    # Open HCS plate with iohub (fast - doesn't load data)
    plate = open_ome_zarr(str(zarr_path), mode="r", layout="hcs")

    # Hardcoded metadata for known structure (avoids iterating 1000s of positions)
    rows = ["A"]
    columns = ["1", "2", "3"]

    # A fresh list per well: sharing one list object (the caller's own) would
    # let a change to one well's FOVs leak into every well and back to the
    # caller's ``allowed_fovs``.
    wells = {(row, col): list(allowed_fovs) for row in rows for col in columns}

    return {
        "rows": rows,
        "columns": columns,
        "wells": wells,
        "plate": plate,
        "zarr_path": str(zarr_path),
    }
def load_fov_from_plate(
    plate, row: str, column: str, field: str, resolution: int = 0
) -> xr.DataArray:
    """
    Load one FOV of an HCS plate as a labeled DataArray.

    iohub is used only to locate the position group; the pixel data itself
    is opened with xarray-ome from the position's on-disk path.

    Parameters
    ----------
    plate : iohub.Plate
        Plate opened with ``open_ome_zarr(..., layout="hcs")``
    row : str
        Row name (e.g., 'A')
    column : str
        Column name (e.g., '1')
    field : str
        Field/position name (e.g., '002026')
    resolution : int, optional
        Pyramid level passed to ``open_ome_dataset``, by default 0

    Returns
    -------
    xr.DataArray
        The ``"image"`` variable of the opened FOV dataset
        (expected dims T, C, Z, Y, X).

    Raises
    ------
    RuntimeError
        If the underlying zarr store exposes neither ``path`` nor ``root``.
    """
    position = plate[f"{row}/{column}/{field}"]

    # The store root lives under a different attribute depending on the
    # zarr version: DirectoryStore.path (V2) or LocalStore.root (V3).
    store = position.zgroup.store
    for root_attr in ("path", "root"):
        if hasattr(store, root_attr):
            base_path = Path(getattr(store, root_attr))
            break
    else:
        raise RuntimeError(f"Unknown store type: {type(store)}")

    fov_dataset = open_ome_dataset(
        base_path / position.zgroup.path, resolution=resolution, validate=False
    )
    return fov_dataset["image"]
# === Data Loading ===
def load_ome_zarr_fov(
    zarr_path: Path | str, fov_path: Path | str, resolution: int = 0
) -> xr.DataArray:
    """
    Open a single FOV of an OME-Zarr store as a labeled DataArray.

    Progress and a short description of the loaded array (sizes, dims,
    dtype) are printed to stdout.

    Parameters
    ----------
    zarr_path : Path | str
        Root of the OME-Zarr store
    fov_path : Path | str
        FOV location relative to the root (e.g., "A/1/001007")
    resolution : int, optional
        Pyramid level to open (0 is full resolution), by default 0

    Returns
    -------
    xr.DataArray
        The ``"image"`` variable of the FOV dataset (expected dims T, C, Z, Y, X).
    """
    store_root = Path(zarr_path)
    fov_rel = Path(fov_path)
    print(f"Loading zarr store from: {store_root}")
    print(f"Accessing FOV: {fov_rel}")

    image = open_ome_dataset(
        store_root / fov_rel, resolution=resolution, validate=False
    )["image"]

    for message in (
        f"Loaded data shape: {dict(image.sizes)}",
        f"Dimensions: {list(image.dims)}",
        f"Data type: {image.dtype}",
    ):
        print(message)
    return image
# === Image Processing ===
def normalize_for_display(
    img_2d: xr.DataArray,
    percentiles: tuple[float, float] = (1, 99),
    clip_to_uint8: bool = True,
) -> np.ndarray:
    """
    Normalize a 2D microscopy image using percentile clipping.

    Intensities are clipped to the ``percentiles`` range and rescaled to
    [0, 1], so a handful of outlier pixels cannot dominate the contrast.

    Parameters
    ----------
    img_2d : xr.DataArray
        2D image to normalize. A plain array-like (e.g. ``np.ndarray``) is
        also accepted and wrapped in an ``xr.DataArray`` first.
    percentiles : tuple[float, float], optional
        Lower and upper percentiles (0-100) used as clip limits,
        by default (1, 99)
    clip_to_uint8 : bool, optional
        If True, return uint8 in [0, 255]; otherwise float32 in [0, 1],
        by default True

    Returns
    -------
    np.ndarray
        Normalized image with the same shape as ``img_2d``.

    Notes
    -----
    xarray's ``quantile`` skips NaNs for float data by default, so NaN
    pixels do not affect the clip limits; they are written as 0 in the
    output (casting NaN to uint8 would otherwise be undefined). If the
    percentile range is (near) zero or not finite -- e.g. a constant or
    all-NaN image -- an all-zero image is returned.
    """
    if not isinstance(img_2d, xr.DataArray):
        img_2d = xr.DataArray(np.asarray(img_2d))
    out_dtype = np.uint8 if clip_to_uint8 else np.float32

    # Both clip limits from a single quantile call
    limits = img_2d.quantile(
        [percentiles[0] / 100.0, percentiles[1] / 100.0]
    ).values
    p_low, p_high = float(limits[0]), float(limits[1])
    value_range = p_high - p_low

    # Constant or all-NaN image: nothing to stretch (NaN would slip past a
    # plain ``< 1e-10`` comparison, hence the isfinite check)
    if not np.isfinite(value_range) or value_range < 1e-10:
        return np.zeros(img_2d.shape, dtype=out_dtype)

    img_normalized = (img_2d.clip(min=p_low, max=p_high) - p_low) / value_range
    # NaN pixels survive clipping; zero them before any integer cast
    result = np.nan_to_num(img_normalized.values, nan=0.0)

    if clip_to_uint8:
        return (result * 255).astype(np.uint8)
    return result.astype(np.float32)
# === Slice Extraction ===
def extract_2d_slice(
    data_xr: xr.DataArray,
    t: int | None = None,
    c: int | None = None,
    z: int | None = None,
    normalize: bool = True,
    verbose: bool = True,
) -> np.ndarray:
    """
    Pull a single 2D plane out of a T/C/Z/Y/X DataArray.

    Each of ``t``, ``c`` and ``z`` that is given is applied as a positional
    index (``isel``) on the ``"T"``, ``"C"`` and ``"Z"`` dimension. Every
    remaining size-1 dimension is then squeezed away; if the result is not
    2D a ValueError is raised.

    Parameters
    ----------
    data_xr : xr.DataArray
        Image data with dimensions (T, C, Z, Y, X)
    t : int | None, optional
        Timepoint index, by default None
    c : int | None, optional
        Channel index, by default None
    z : int | None, optional
        Z-slice index, by default None
    normalize : bool, optional
        Pass the plane through ``normalize_for_display``, by default True
    verbose : bool, optional
        Print the selection, shape and value range, by default True

    Returns
    -------
    np.ndarray
        The 2D plane: normalized uint8 if ``normalize`` else the raw values.

    Raises
    ------
    ValueError
        If the selection is empty or does not reduce to exactly 2 dimensions.
    """
    # Only the indices the caller actually supplied take part in the selection
    sel_dict = {
        dim: int(index)
        for dim, index in (("T", t), ("C", c), ("Z", z))
        if index is not None
    }
    slice_xr = data_xr.isel(**sel_dict) if sel_dict else data_xr

    # Dask-backed arrays are read from disk here
    if hasattr(slice_xr.data, "compute"):
        slice_xr = slice_xr.compute()
    slice_xr = slice_xr.squeeze()

    if slice_xr.size == 0:
        raise ValueError(
            f"Empty array after slicing. Selection: {sel_dict}, "
            f"Original shape: {data_xr.shape}"
        )
    if slice_xr.ndim != 2:
        raise ValueError(
            f"Expected 2D array after slicing, got shape {slice_xr.shape}. "
            f"Selection: {sel_dict}"
        )

    if verbose:
        if sel_dict:
            sel_str = ", ".join(f"{k}={v}" for k, v in sel_dict.items())
        else:
            sel_str = "full array"
        print(
            f"Extracted slice: {sel_str}, Shape={slice_xr.shape}, "
            f"Range=[{float(slice_xr.min()):.1f}, {float(slice_xr.max()):.1f}]"
        )

    return normalize_for_display(slice_xr) if normalize else slice_xr.values
# === Slice Extraction Factory ===
def create_slice_extractor(
    data_xr: xr.DataArray,
    normalize: bool = True,
    channel: int = 0,
):
    """
    Bind a dataset and channel into a ``(t, z) -> image`` callback.

    Handy for Gradio, where the data is loaded once and the same slider
    callback is invoked repeatedly.

    Parameters
    ----------
    data_xr : xr.DataArray
        Image data to extract slices from
    normalize : bool, optional
        Whether returned slices are normalized for display, by default True
    channel : int, optional
        Channel index used for every slice, by default 0

    Returns
    -------
    callable
        ``get_slice(t, z)`` returning the 2D slice from ``extract_2d_slice``
        (called with ``verbose=True``, so each call prints a summary).
    """
    def get_slice(t: int, z: int) -> np.ndarray:
        """Extract and normalize a 2D slice at timepoint t and z-slice z."""
        indices = dict(t=int(t), c=channel, z=int(z))
        return extract_2d_slice(
            data_xr, normalize=normalize, verbose=True, **indices
        )

    return get_slice
# === Metadata Helpers ===
def get_dimension_info(data_xr: xr.DataArray) -> dict:
    """
    Collect a plain-Python description of a DataArray's dimensions.

    Parameters
    ----------
    data_xr : xr.DataArray
        Image data with dimensions

    Returns
    -------
    dict
        ``'sizes'`` (dim -> length), ``'dims'`` (ordered dim names),
        ``'coords'`` (dim -> list of coordinate values) and ``'dtype'``
        (dtype as a string).
    """
    coords_by_dim = {}
    for dim in data_xr.dims:
        coords_by_dim[dim] = data_xr.coords[dim].values.tolist()

    return {
        "sizes": dict(data_xr.sizes),
        "dims": list(data_xr.dims),
        "coords": coords_by_dim,
        "dtype": str(data_xr.dtype),
    }
def print_data_summary(data_xr: xr.DataArray) -> None:
    """
    Print a formatted summary of xarray DataArray.

    Shows sizes, dimension names, dtype, the first/last coordinate value of
    each dimension, and a size estimate computed from the dimension sizes
    and the dtype's item size.

    Parameters
    ----------
    data_xr : xr.DataArray
        Image data to summarize
    """
    def _fmt(value) -> str:
        # Numeric coordinates keep the two-decimal format; values that do not
        # support ``:.2f`` (e.g. channel-name strings) are printed as-is
        # instead of raising ValueError/TypeError.
        try:
            return f"{value:.2f}"
        except (TypeError, ValueError):
            return str(value)

    info = get_dimension_info(data_xr)
    rule = "=" * 60
    print("\n" + rule)
    print("DATA SUMMARY")
    print(rule)
    print(f"Shape: {info['sizes']}")
    print(f"Dimensions: {info['dims']}")
    print(f"Data type: {info['dtype']}")

    # Print coordinate ranges
    print("\nCoordinate Ranges:")
    for dim in info["dims"]:
        coords = info["coords"][dim]
        if len(coords) > 0:
            print(
                f" {dim}: [{_fmt(coords[0])} ... {_fmt(coords[-1])}] "
                f"(n={len(coords)})"
            )

    # Print memory size estimate
    total_elements = np.prod(list(info["sizes"].values()))
    dtype_size = np.dtype(data_xr.dtype).itemsize
    size_mb = (total_elements * dtype_size) / (1024**2)
    print(f"\nEstimated size: {size_mb:.1f} MB")
    print(rule + "\n")
# === Phase Reconstruction Functions ===
def run_reconstruction(zyx_tile: torch.Tensor, recon_args: dict) -> torch.Tensor:
    """
    Run phase reconstruction on a Z-stack.

    Z positions follow the formula used by waveorder's
    ``_position_list_from_shape_scale_offset``,
    ``(-arange(Z) + Z // 2 + z_offset) * z_scale``, so the phase sign
    matches the waveorder CLI.

    Parameters
    ----------
    zyx_tile : torch.Tensor
        Input Z-stack data with shape (Z, Y, X). Can be on CPU or GPU.
    recon_args : dict
        Keyword arguments for ``isotropic_thin_3d.calculate_transfer_function``
        plus keys consumed here:

        - ``z_offset``: scalar or tensor. If it is a tensor with
          ``requires_grad=True``, the Z positions are built from it in torch
          so gradients flow back to it.
        - ``z_scale``: spacing between Z slices.
        - ``regularization_strength`` (optional): forwarded to
          ``apply_inverse_transfer_function``; defaults to 1e-2.

        Tensor values should be on the same device as ``zyx_tile``.
        The dict itself is not modified.

    Returns
    -------
    torch.Tensor
        Reconstructed 2D phase image with shape (Y, X), on same device as input.

    Raises
    ------
    KeyError
        If ``z_offset`` or ``z_scale`` is missing from ``recon_args``.
    """
    device = zyx_tile.device
    num_z, _, _ = zyx_tile.shape

    # Work on a copy; strip keys that are not transfer-function arguments.
    tf_args = dict(recon_args)
    z_offset = tf_args.pop("z_offset")
    z_scale = tf_args.pop("z_scale")
    regularization_strength = tf_args.pop("regularization_strength", 1e-2)

    if torch.is_tensor(z_offset) and z_offset.requires_grad:
        # Same formula as waveorder, evaluated in torch on the target device
        # so the autograd graph reaches z_offset.
        z_position_list = (
            -torch.arange(num_z, dtype=torch.float32, device=device)
            + (num_z // 2)
            + z_offset
        ) * z_scale
    else:
        # No gradient needed: use waveorder's official helper directly.
        z_offset_scalar = z_offset.item() if torch.is_tensor(z_offset) else z_offset
        positions = _position_list_from_shape_scale_offset(
            shape=num_z,
            scale=z_scale,
            offset=z_offset_scalar,
        )
        # as_tensor accepts a tensor, ndarray or list return value alike
        z_position_list = torch.as_tensor(positions, device=device)
    tf_args["z_position_list"] = z_position_list

    # Core reconstruction calls (all on same device)
    tf_abs, tf_phase = isotropic_thin_3d.calculate_transfer_function(**tf_args)
    system = isotropic_thin_3d.calculate_singular_system(tf_abs, tf_phase)
    _, yx_phase_recon = isotropic_thin_3d.apply_inverse_transfer_function(
        zyx_tile, system, regularization_strength=regularization_strength
    )
    return yx_phase_recon
def compute_midband_power(
    yx_array: torch.Tensor,
    NA_det: float,
    lambda_ill: float,
    pixel_size: float,
    band: tuple[float, float] = (0.125, 0.25),
) -> torch.Tensor:
    """
    Sum the Fourier magnitude of an image inside an annular frequency band.

    This is the metric used for the optimization loss. The band is relative
    to the cutoff ``2 * NA_det / lambda_ill``: only radial frequencies
    strictly between ``band[0] * cutoff`` and ``band[1] * cutoff`` count.

    Parameters
    ----------
    yx_array : torch.Tensor
        2D reconstructed image (on CPU or GPU)
    NA_det : float
        Numerical aperture of detection
    lambda_ill : float
        Illumination wavelength
    pixel_size : float
        Pixel size in same units as wavelength
    band : tuple[float, float], optional
        Band edges as fractions of the cutoff, by default (0.125, 0.25)

    Returns
    -------
    torch.Tensor
        Scalar band power on the same device as ``yx_array``.
    """
    # waveorder's helper yields NumPy frequency grids; move the radius grid
    # onto the image's device once.
    _, _, fxx, fyy = util.gen_coordinate(yx_array.shape, pixel_size)
    radial_freq = torch.tensor(
        np.sqrt(fxx**2 + fyy**2), dtype=torch.float32, device=yx_array.device
    )

    spectrum_mag = torch.fft.fftn(yx_array).abs()
    cutoff = 2 * NA_det / lambda_ill
    lower_edge = cutoff * band[0]
    upper_edge = cutoff * band[1]
    in_band = (radial_freq > lower_edge) & (radial_freq < upper_edge)
    return spectrum_mag[in_band].sum()
def prepare_optimizer(
    optimizable_params: dict[str, tuple[bool, float, float]],
    device: torch.device,
) -> tuple[dict[str, torch.nn.Parameter], torch.optim.Optimizer]:
    """
    Prepare optimization parameters and Adam optimizer.

    Parameters
    ----------
    optimizable_params : dict[str, tuple[bool, float, float]]
        Maps each parameter name to ``(enabled, initial_value, learning_rate)``.
        Only enabled entries become trainable parameters.
    device : torch.device
        Device to create parameters on (CPU or GPU)

    Returns
    -------
    tuple[dict[str, torch.nn.Parameter], torch.optim.Optimizer]
        The enabled parameters keyed by name (each a 1-element float32
        tensor) and an Adam optimizer with one parameter group per
        parameter, each using its own learning rate.

    Raises
    ------
    ValueError
        If no parameter is enabled.
    """
    optimization_params: dict[str, torch.nn.Parameter] = {}
    param_groups = []
    for name, (enabled, initial, lr) in optimizable_params.items():
        if not enabled:
            continue
        param = torch.nn.Parameter(
            torch.tensor([initial], dtype=torch.float32, device=device),
            requires_grad=True,
        )
        optimization_params[name] = param
        # Separate group per parameter so each keeps its own learning rate
        param_groups.append({"params": [param], "lr": lr})

    if not param_groups:
        # Adam would reject an empty list with a generic "empty parameter
        # list" message; say what the caller actually has to change.
        raise ValueError(
            "No optimizable parameters are enabled; enable at least one of "
            f"{list(optimizable_params)} to run the optimization."
        )

    optimizer = torch.optim.Adam(param_groups)
    return optimization_params, optimizer
def run_reconstruction_single(
    zyx_stack: np.ndarray,
    pixel_scales: tuple[float, float, float],
    fixed_params: dict,
    param_values: dict,
    device: Device = None,
) -> np.ndarray:
    """
    Reconstruct phase once with the given parameter values (no optimization).

    Parameters
    ----------
    zyx_stack : np.ndarray
        Input Z-stack with shape (Z, Y, X)
    pixel_scales : tuple[float, float, float]
        (z_scale, y_scale, x_scale) in micrometers. Only ``z_scale`` and
        ``y_scale`` are used; ``y_scale`` becomes the in-plane pixel size
        (presumably square pixels are assumed -- ``x_scale`` is ignored).
    fixed_params : dict
        Fixed reconstruction parameters (wavelength, index, etc.). The demo
        keys ``num_iterations``, ``use_tiling`` and ``device`` are dropped.
    param_values : dict
        Parameter values to use (z_offset, numerical_aperture_detection, etc.);
        each becomes a 1-element float32 tensor on the compute device.
    device : torch.device | str | None, optional
        Computing device; None auto-selects GPU if available, else CPU.

    Returns
    -------
    np.ndarray
        Reconstructed phase image normalized to uint8 for display.
    """
    device = get_device(device)
    zyx_tile = torch.tensor(zyx_stack, dtype=torch.float32, device=device)
    z_scale, y_scale, _x_scale = pixel_scales

    # Keys that configure the demo rather than the reconstruction
    demo_only_keys = ("num_iterations", "use_tiling", "device")
    recon_args = {
        key: value for key, value in fixed_params.items() if key not in demo_only_keys
    }
    recon_args.update(
        yx_shape=zyx_tile.shape[1:],
        yx_pixel_size=y_scale,
        z_scale=z_scale,
    )
    recon_args.update(
        {
            name: torch.tensor([value], dtype=torch.float32, device=device)
            for name, value in param_values.items()
        }
    )

    yx_recon = run_reconstruction(zyx_tile, recon_args)

    # Back to host memory, then scale for display
    host_image = yx_recon.detach().cpu().numpy()
    return normalize_for_display(xr.DataArray(host_image))
def run_optimization_streaming(
    zyx_stack: np.ndarray,
    pixel_scales: tuple[float, float, float],
    fixed_params: dict,
    optimizable_params: dict,
    num_iterations: int = 10,
    device: Device = None,
) -> Generator[dict, None, None]:
    """
    Optimize reconstruction parameters, yielding progress after every step.

    Each iteration reconstructs phase, uses the negative mid-band power
    (band (0.1, 0.2) of a cutoff computed with a fixed ``NA_det=0.15``) as
    the loss, and takes one Adam step.

    Parameters
    ----------
    zyx_stack : np.ndarray
        Input Z-stack with shape (Z, Y, X)
    pixel_scales : tuple[float, float, float]
        (z_scale, y_scale, x_scale) in micrometers; ``x_scale`` is unused.
    fixed_params : dict
        Fixed reconstruction parameters (wavelength, index, etc.). The demo
        keys ``num_iterations``, ``use_tiling`` and ``device`` are dropped;
        ``wavelength_illumination`` must be present.
    optimizable_params : dict
        Name -> (enabled, initial, lr). Disabled entries stay fixed at
        ``initial``; enabled ones are optimized.
    num_iterations : int, optional
        Number of optimization iterations, by default 10
    device : torch.device | str | None, optional
        Computing device, e.g. "cuda", "cpu", "cuda:0", torch.device("cuda").
        None auto-selects GPU if available, else CPU.

    Yields
    ------
    dict
        - 'reconstructed_image': normalized uint8 array (on CPU for display)
        - 'loss': float loss of this iteration's reconstruction
        - 'iteration': int iteration number (1-indexed)
        - 'params': enabled parameter values read *after* this iteration's
          ``optimizer.step()``, i.e. the values the next iteration will use

    Notes
    -----
    The image and loss reflect the parameters in effect before the step.
    Only the image and scalars are copied to the CPU each iteration.
    """
    device = get_device(device)
    zyx_tile = torch.tensor(zyx_stack, dtype=torch.float32, device=device)
    z_scale, y_scale, _x_scale = pixel_scales

    demo_only_keys = ("num_iterations", "use_tiling", "device")
    recon_args = {
        key: value for key, value in fixed_params.items() if key not in demo_only_keys
    }
    recon_args.update(
        yx_shape=zyx_tile.shape[1:],
        yx_pixel_size=y_scale,
        z_scale=z_scale,
    )

    # Every listed parameter starts from its initial value; disabled ones
    # keep this constant tensor for the whole run.
    for name, (_enabled, initial, _lr) in optimizable_params.items():
        recon_args[name] = torch.tensor([initial], dtype=torch.float32, device=device)

    optimization_params, optimizer = prepare_optimizer(optimizable_params, device)

    for iteration in range(1, num_iterations + 1):
        # Trainable tensors replace the constants for enabled parameters
        recon_args.update(optimization_params)

        yx_recon = run_reconstruction(zyx_tile, recon_args)

        # Maximize mid-band power by minimizing its negative
        midband_power = compute_midband_power(
            yx_recon,
            NA_det=0.15,
            lambda_ill=recon_args["wavelength_illumination"],
            pixel_size=recon_args["yx_pixel_size"],
            band=(0.1, 0.2),
        )
        loss = -midband_power

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        display_image = normalize_for_display(
            xr.DataArray(yx_recon.detach().cpu().numpy())
        )
        current_params = {
            name: param.item() for name, param in optimization_params.items()
        }

        yield {
            "reconstructed_image": display_image,
            "loss": loss.item(),
            "iteration": iteration,
            "params": current_params,
        }
def _tile_spans(length: int, n_tiles: int, overlap_pct: float) -> list[tuple[int, int]]:
    """Return the (start, stop) span of each of ``n_tiles`` tiles along one axis."""
    # Tile size such that n_tiles tiles overlapping by overlap_pct span `length`
    tile_size = int(np.ceil(length / (n_tiles - (n_tiles - 1) * overlap_pct)))
    stride = int(tile_size * (1 - overlap_pct))
    spans = []
    for i in range(n_tiles):
        start = i * stride
        if i == n_tiles - 1 and start + tile_size < length:
            # Truncating the stride to an int can leave the last rows/columns
            # uncovered; anchor the final tile to the far edge instead.
            start = length - tile_size
        spans.append((start, min(start + tile_size, length)))
    return spans


def extract_tiles(
    zyx_data: np.ndarray, num_tiles: tuple[int, int], overlap_pct: float
) -> tuple[dict[str, np.ndarray], dict[str, tuple[int, int, int]]]:
    """
    Split a Z-stack into a grid of overlapping YX tiles.

    The tiles together cover every Y/X pixel of ``zyx_data``.

    Parameters
    ----------
    zyx_data : np.ndarray
        Input data with shape (Z, Y, X)
    num_tiles : tuple[int, int]
        Number of tiles in (Y, X) dimensions, each >= 1
    overlap_pct : float
        Fractional overlap between neighbouring tiles, in [0.0, 1.0]
        (e.g. 0.1 for 10 %)

    Returns
    -------
    tuple[dict, dict]
        tiles: tile name (``"0/0/{yi:03d}{xi:03d}"``) -> slice of
            ``zyx_data`` with shape (Z, tile_y, tile_x); tiles on the bottom
            or right edge may be smaller than the others.
        translations: tile name -> (0, y0, x0) offset of the tile's origin.

    Raises
    ------
    ValueError
        If a tile count is < 1 or ``overlap_pct`` lies outside [0, 1]
        (which would otherwise divide by zero or produce negative offsets).
    """
    _, size_y, size_x = zyx_data.shape
    n_y, n_x = num_tiles
    if n_y < 1 or n_x < 1:
        raise ValueError(f"num_tiles must be >= 1 in each dimension, got {num_tiles}")
    if not 0.0 <= overlap_pct <= 1.0:
        raise ValueError(f"overlap_pct must be in [0, 1], got {overlap_pct}")

    y_spans = _tile_spans(size_y, n_y, overlap_pct)
    x_spans = _tile_spans(size_x, n_x, overlap_pct)

    tiles = {}
    translations = {}
    for yi, (y0, y1) in enumerate(y_spans):
        for xi, (x0, x1) in enumerate(x_spans):
            tile_name = f"0/0/{yi:03d}{xi:03d}"
            tiles[tile_name] = zyx_data[:, y0:y1, x0:x1]
            translations[tile_name] = (0, y0, x0)
    return tiles, translations