Spaces:
Runtime error
Runtime error
File size: 26,370 Bytes
7f5c4ef ec6b668 7f5c4ef ec6b668 7f5c4ef ec6b668 7f5c4ef ec6b668 7f5c4ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 |
"""
Utility functions for Gradio demos
Provides reusable components for:
- Data loading from OME-Zarr stores
- Image normalization and processing
- Slice extraction from xarray DataArrays
- Phase reconstruction and optimization
Design Notes
------------
All image processing functions work with xarray.DataArray to maintain
labeled dimensions and coordinate information as long as possible.
Only convert to numpy arrays at the final display step.
"""
from pathlib import Path
from typing import Generator
import numpy as np
import torch
import xarray as xr
from numpy.typing import NDArray
from xarray_ome import open_ome_dataset
from waveorder import util
from waveorder.models import isotropic_thin_3d
from waveorder.cli.compute_transfer_function import (
_position_list_from_shape_scale_offset,
)
# Type alias for device specification
Device = torch.device | str | None
def get_device(device: Device = None) -> torch.device:
"""
Get torch device with smart defaults.
Parameters
----------
device : torch.device | str | None
If None, auto-selects cuda if available, else cpu.
If str, converts to torch.device.
If torch.device, returns as-is.
Returns
-------
torch.device
Validated device ready for use
Examples
--------
>>> get_device() # Auto-detect
device(type='cuda', index=0) # if GPU available
>>> get_device("cpu") # Force CPU
device(type='cpu')
>>> get_device(torch.device("cuda:1")) # Specific GPU
device(type='cuda', index=1)
"""
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
print(f"🚀 Using GPU: {torch.cuda.get_device_name(device)}")
gpu_mem_gb = torch.cuda.get_device_properties(device).total_memory / 1e9
print(f" GPU Memory: {gpu_mem_gb:.2f} GB")
else:
print("💻 Using CPU (GPU not available)")
return device
if isinstance(device, str):
return torch.device(device)
return device
# === HCS Plate Loading with iohub ===
def get_plate_metadata(
    zarr_path: Path | str,
    allowed_fovs: list[str],
    *,
    rows: list[str] | None = None,
    columns: list[str] | None = None,
) -> dict:
    """
    Build FOV-selection metadata for an HCS plate opened with iohub.

    The plate is opened read-only (no pixel data is loaded), but the
    row/column layout is NOT discovered from the plate: it defaults to a
    known layout (row ``A``, columns ``1``-``3``) so that thousands of
    positions never have to be iterated. Every well is assigned the same
    set of allowed FOV names.

    Parameters
    ----------
    zarr_path : Path | str
        Path to the HCS plate zarr store.
    allowed_fovs : list[str]
        FOV names to expose for every well (e.g. ``['002026', '002027']``).
        Each well receives its own copy, so later mutation of one well's
        list (or of the caller's list) does not leak into the others.
    rows : list[str] | None, optional
        Row names of the plate layout. Defaults to ``['A']``.
    columns : list[str] | None, optional
        Column names of the plate layout. Defaults to ``['1', '2', '3']``.

    Returns
    -------
    dict
        Metadata with keys:
        - 'rows': list of row names
        - 'columns': list of column names
        - 'wells': dict mapping (row, col) to a list of field names
        - 'plate': iohub Plate object for later access
        - 'zarr_path': the store path as a string
    """
    from iohub import open_ome_zarr

    # Opening the plate only reads metadata; data stays on disk.
    plate = open_ome_zarr(str(zarr_path), mode="r", layout="hcs")

    # Known layout by default (avoids iterating 1000s of positions).
    rows = ["A"] if rows is None else list(rows)
    columns = ["1", "2", "3"] if columns is None else list(columns)

    # One independent copy of the FOV list per well (previously every well
    # aliased the caller's list object).
    wells = {(row, col): list(allowed_fovs) for row in rows for col in columns}

    return {
        "rows": rows,
        "columns": columns,
        "wells": wells,
        "plate": plate,
        "zarr_path": str(zarr_path),
    }
def load_fov_from_plate(
    plate, row: str, column: str, field: str, resolution: int = 0
) -> xr.DataArray:
    """
    Load one FOV of an HCS plate: iohub for navigation, xarray-ome for data.

    Parameters
    ----------
    plate : iohub.Plate
        Plate opened with ``open_ome_zarr(..., layout="hcs")``.
    row : str
        Row name (e.g. ``'A'``).
    column : str
        Column name (e.g. ``'1'``).
    field : str
        Field/position name (e.g. ``'002026'``).
    resolution : int, optional
        Resolution level passed to ``open_ome_dataset``, by default 0.

    Returns
    -------
    xr.DataArray
        The ``"image"`` variable of the opened dataset.

    Raises
    ------
    RuntimeError
        If the position's store exposes neither a ``path`` nor a ``root``
        attribute, so its on-disk location cannot be determined.
    """
    position = plate[f"{row}/{column}/{field}"]
    zgroup = position.zgroup
    store = zgroup.store

    # Zarr v2 DirectoryStore exposes ``path``; Zarr v3 LocalStore exposes
    # ``root``. ``path`` is checked first, matching the original priority.
    for location_attr in ("path", "root"):
        if hasattr(store, location_attr):
            base_path = Path(getattr(store, location_attr))
            break
    else:
        raise RuntimeError(f"Unknown store type: {type(store)}")

    dataset = open_ome_dataset(
        base_path / zgroup.path, resolution=resolution, validate=False
    )
    return dataset["image"]
# === Data Loading ===
def load_ome_zarr_fov(
    zarr_path: Path | str, fov_path: Path | str, resolution: int = 0
) -> xr.DataArray:
    """
    Open one field of view of an OME-Zarr store as a labeled DataArray.

    Progress and a short description of the loaded array (sizes, dimension
    names, dtype) are printed to stdout.

    Parameters
    ----------
    zarr_path : Path | str
        Root of the OME-Zarr store.
    fov_path : Path | str
        FOV location relative to the root (e.g. ``"A/1/001007"``).
    resolution : int, optional
        Resolution level passed to ``open_ome_dataset`` (0 = full
        resolution), by default 0.

    Returns
    -------
    xr.DataArray
        The ``"image"`` variable of the opened dataset.
    """
    store_root, fov_relative = Path(zarr_path), Path(fov_path)

    print(f"Loading zarr store from: {store_root}")
    print(f"Accessing FOV: {fov_relative}")

    image = open_ome_dataset(
        store_root / fov_relative, resolution=resolution, validate=False
    )["image"]

    print(f"Loaded data shape: {dict(image.sizes)}")
    print(f"Dimensions: {list(image.dims)}")
    print(f"Data type: {image.dtype}")
    return image
# === Image Processing ===
def normalize_for_display(
    img_2d: "xr.DataArray | np.ndarray",
    percentiles: tuple[float, float] = (1, 99),
    clip_to_uint8: bool = True,
) -> np.ndarray:
    """
    Normalize a 2D microscopy image using percentile clipping.

    Robust percentile-based contrast stretching: values are clipped to the
    ``percentiles`` range (NaN pixels ignored when computing it) and
    rescaled to [0, 1]. This is the final display step, so the result is a
    plain numpy array.

    Parameters
    ----------
    img_2d : xr.DataArray | np.ndarray
        2D image to normalize. DataArrays (including Dask-backed ones) are
        materialized with ``np.asarray``; plain numpy arrays are accepted
        directly.
    percentiles : tuple[float, float], optional
        Lower and upper percentiles (0-100) for clipping, by default (1, 99).
    clip_to_uint8 : bool, optional
        If True, return uint8 in 0-255; otherwise float32 in 0-1,
        by default True.

    Returns
    -------
    np.ndarray
        Normalized array (uint8 if ``clip_to_uint8`` else float32). An
        all-zero array is returned for empty input, constant images, and
        images with no finite values. In the uint8 output, NaN pixels
        become 0 (casting NaN to uint8 is otherwise undefined).

    Raises
    ------
    ValueError
        If a percentile lies outside [0, 100].
    """
    import warnings

    out_dtype = np.uint8 if clip_to_uint8 else np.float32
    values = np.asarray(img_2d)
    if values.size == 0:
        return np.zeros(values.shape, dtype=out_dtype)

    # Both percentiles in one pass; linear interpolation with NaNs skipped
    # matches xarray's default ``quantile``. An all-NaN image yields NaN
    # bounds (with a RuntimeWarning), which the finiteness check handles.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)
        bounds = np.nanquantile(
            values, [percentiles[0] / 100.0, percentiles[1] / 100.0]
        )
    p_low, p_high = float(bounds[0]), float(bounds[1])

    # No usable intensity range: NaN bounds or (nearly) constant image.
    if not (np.isfinite(p_low) and np.isfinite(p_high)) or p_high - p_low < 1e-10:
        return np.zeros(values.shape, dtype=out_dtype)

    normalized = (np.clip(values, p_low, p_high) - p_low) / (p_high - p_low)

    if clip_to_uint8:
        return (np.nan_to_num(normalized, nan=0.0) * 255).astype(np.uint8)
    return normalized.astype(np.float32)
# === Slice Extraction ===
def extract_2d_slice(
    data_xr: xr.DataArray,
    t: int | None = None,
    c: int | None = None,
    z: int | None = None,
    normalize: bool = True,
    verbose: bool = True,
) -> np.ndarray:
    """
    Pull a single 2D plane out of a (T, C, Z, Y, X) DataArray.

    Each non-None index selects positionally along its dimension ("T", "C",
    "Z"). Dask-backed selections are computed, then every remaining
    size-1 dimension is squeezed away. The result must be exactly 2D.

    Parameters
    ----------
    data_xr : xr.DataArray
        Image data with dimensions (T, C, Z, Y, X).
    t, c, z : int | None, optional
        Positional indices for the T, C and Z dimensions; None leaves the
        dimension unselected, by default None.
    normalize : bool, optional
        Pass the slice through ``normalize_for_display``, by default True.
    verbose : bool, optional
        Print the selection, shape and value range, by default True.

    Returns
    -------
    np.ndarray
        2D array: normalized uint8 if ``normalize`` else the raw values.

    Raises
    ------
    ValueError
        If the slice is empty, or is not 2D after squeezing.
    """
    requested = (("T", t), ("C", c), ("Z", z))
    sel_dict = {dim: int(index) for dim, index in requested if index is not None}

    if sel_dict:
        slice_xr = data_xr.isel(sel_dict)
    else:
        slice_xr = data_xr

    # Dask-backed arrays expose .compute(); load the selected data now.
    if hasattr(slice_xr.data, "compute"):
        slice_xr = slice_xr.compute()

    slice_xr = slice_xr.squeeze()

    if slice_xr.size == 0:
        raise ValueError(
            f"Empty array after slicing. Selection: {sel_dict}, "
            f"Original shape: {data_xr.shape}"
        )
    if slice_xr.ndim != 2:
        raise ValueError(
            f"Expected 2D array after slicing, got shape {slice_xr.shape}. "
            f"Selection: {sel_dict}"
        )

    if verbose:
        # An empty selection joins to "", which falls back to "full array".
        sel_str = ", ".join(f"{k}={v}" for k, v in sel_dict.items()) or "full array"
        lowest = float(slice_xr.min())
        highest = float(slice_xr.max())
        print(
            f"Extracted slice: {sel_str}, Shape={slice_xr.shape}, "
            f"Range=[{lowest:.1f}, {highest:.1f}]"
        )

    return normalize_for_display(slice_xr) if normalize else slice_xr.values
# === Slice Extraction Factory ===
def create_slice_extractor(
    data_xr: xr.DataArray,
    normalize: bool = True,
    channel: int = 0,
):
    """
    Bind a dataset and channel into a reusable ``(t, z) -> image`` callback.

    Handy for Gradio: the data is loaded once and the returned function is
    wired to slider callbacks.

    Parameters
    ----------
    data_xr : xr.DataArray
        Image data to slice.
    normalize : bool, optional
        Forwarded to ``extract_2d_slice``, by default True.
    channel : int, optional
        Channel index used for every extraction, by default 0.

    Returns
    -------
    callable
        ``get_slice(t, z) -> np.ndarray``, which calls ``extract_2d_slice``
        with ``verbose=True``.
    """
    def get_slice(t: int, z: int) -> np.ndarray:
        """Extract and normalize a 2D slice at timepoint t and z-slice z."""
        # Slider values may arrive as floats; coerce to int indices.
        options = dict(
            t=int(t),
            c=channel,
            z=int(z),
            normalize=normalize,
            verbose=True,
        )
        return extract_2d_slice(data_xr, **options)

    return get_slice
# === Metadata Helpers ===
def get_dimension_info(data_xr: xr.DataArray) -> dict:
    """
    Collect basic shape/coordinate metadata from a DataArray.

    Parameters
    ----------
    data_xr : xr.DataArray
        Image data with dimensions.

    Returns
    -------
    dict
        ``'sizes'`` (dim -> length), ``'dims'`` (ordered dim names),
        ``'coords'`` (dim -> coordinate values as a Python list) and
        ``'dtype'`` (dtype as a string).
    """
    dim_names = list(data_xr.dims)
    coordinate_values = {}
    for name in dim_names:
        coordinate_values[name] = data_xr.coords[name].values.tolist()

    return dict(
        sizes=dict(data_xr.sizes),
        dims=dim_names,
        coords=coordinate_values,
        dtype=str(data_xr.dtype),
    )
def _format_coord_value(value) -> str:
    """Format a coordinate value: 2-decimal for numbers, str() otherwise."""
    # Non-numeric coordinates (e.g. string channel labels) cannot take ".2f".
    if isinstance(value, (int, float)):
        return f"{value:.2f}"
    return str(value)


def print_data_summary(data_xr: xr.DataArray) -> None:
    """
    Print a formatted summary of a DataArray.

    Shows sizes, dimension names, dtype, the first/last coordinate value of
    every dimension, and an in-memory size estimate
    (element count x itemsize).

    Parameters
    ----------
    data_xr : xr.DataArray
        Image data to summarize.
    """
    info = get_dimension_info(data_xr)
    print("\n" + "=" * 60)
    print("DATA SUMMARY")
    print("=" * 60)
    print(f"Shape: {info['sizes']}")
    print(f"Dimensions: {info['dims']}")
    print(f"Data type: {info['dtype']}")

    # Coordinate ranges; numeric output is unchanged, other types use str().
    print("\nCoordinate Ranges:")
    for dim in info["dims"]:
        coords = info["coords"][dim]
        if len(coords) > 0:
            first = _format_coord_value(coords[0])
            last = _format_coord_value(coords[-1])
            print(f" {dim}: [{first} ... {last}] (n={len(coords)})")

    # Memory estimate for the full array if it were loaded.
    total_elements = np.prod(list(info["sizes"].values()))
    dtype_size = np.dtype(data_xr.dtype).itemsize
    size_mb = (total_elements * dtype_size) / (1024**2)
    print(f"\nEstimated size: {size_mb:.1f} MB")
    print("=" * 60 + "\n")
# === Phase Reconstruction Functions ===
def run_reconstruction(
    zyx_tile: torch.Tensor,
    recon_args: dict,
    regularization_strength: float = 1e-2,
) -> torch.Tensor:
    """
    Run thin-sample phase reconstruction on a Z-stack.

    Z positions come from waveorder's ``_position_list_from_shape_scale_offset``
    unless ``recon_args["z_offset"]`` is a tensor that requires grad. In that
    case the same formula, ``(-arange(Z) + Z // 2 + offset) * scale``, is
    evaluated in torch so gradients flow back into the offset.

    Parameters
    ----------
    zyx_tile : torch.Tensor
        Input Z-stack with shape (Z, Y, X), on CPU or GPU.
    recon_args : dict
        Must contain ``"z_offset"`` and ``"z_scale"``. All other entries are
        forwarded as keyword arguments to
        ``isotropic_thin_3d.calculate_transfer_function``. Tensor values
        should live on the same device as ``zyx_tile``. The dict is not
        modified.
    regularization_strength : float, optional
        Regularization strength passed to
        ``isotropic_thin_3d.apply_inverse_transfer_function``,
        by default 1e-2.

    Returns
    -------
    torch.Tensor
        Reconstructed phase image (the second output of
        ``apply_inverse_transfer_function``).
    """
    device = zyx_tile.device
    Z, _, _ = zyx_tile.shape

    z_offset = recon_args["z_offset"]
    z_scale = recon_args["z_scale"]

    if torch.is_tensor(z_offset) and z_offset.requires_grad:
        # Differentiable path: keep z_offset in the autograd graph.
        z_position_list = (
            -torch.arange(Z, dtype=torch.float32, device=device) + (Z // 2) + z_offset
        ) * z_scale
    else:
        # No gradient needed: use waveorder's own helper with a plain scalar.
        # It is only computed here, where its result is actually used.
        z_offset_scalar = z_offset.item() if torch.is_tensor(z_offset) else z_offset
        positions = _position_list_from_shape_scale_offset(
            shape=Z,
            scale=z_scale,
            offset=z_offset_scalar,
        )
        # as_tensor accepts a tensor (moved to device) or a plain list.
        z_position_list = torch.as_tensor(positions, device=device)

    # Transfer-function kwargs: everything except the z bookkeeping keys.
    tf_args = {
        key: value
        for key, value in recon_args.items()
        if key not in ("z_offset", "z_scale")
    }
    tf_args["z_position_list"] = z_position_list

    tf_abs, tf_phase = isotropic_thin_3d.calculate_transfer_function(**tf_args)
    system = isotropic_thin_3d.calculate_singular_system(tf_abs, tf_phase)
    _, yx_phase_recon = isotropic_thin_3d.apply_inverse_transfer_function(
        zyx_tile, system, regularization_strength=regularization_strength
    )
    return yx_phase_recon
def compute_midband_power(
    yx_array: torch.Tensor,
    NA_det: float,
    lambda_ill: float,
    pixel_size: float,
    band: tuple[float, float] = (0.125, 0.25),
) -> torch.Tensor:
    """
    Sum the FFT magnitude inside an annular spatial-frequency band.

    The band is ``(band[0] * cutoff, band[1] * cutoff)`` (exclusive), with
    ``cutoff = 2 * NA_det / lambda_ill``. Radial frequencies come from
    waveorder's ``util.gen_coordinate``.

    Parameters
    ----------
    yx_array : torch.Tensor
        2D image (CPU or GPU).
    NA_det : float
        Detection numerical aperture used to define the cutoff.
    lambda_ill : float
        Illumination wavelength.
    pixel_size : float
        Pixel size in the same units as the wavelength.
    band : tuple[float, float], optional
        Band edges as fractions of the cutoff, by default (0.125, 0.25).

    Returns
    -------
    torch.Tensor
        Scalar band power on the same device as ``yx_array``.
    """
    target = yx_array.device

    # gen_coordinate returns numpy grids; only the frequency grids are needed.
    _, _, fxx, fyy = util.gen_coordinate(yx_array.shape, pixel_size)
    radial_freq = torch.tensor(
        np.sqrt(fxx**2 + fyy**2), dtype=torch.float32, device=target
    )

    spectrum_magnitude = torch.fft.fftn(yx_array).abs()

    cutoff = 2 * NA_det / lambda_ill
    in_band = (radial_freq > cutoff * band[0]) & (radial_freq < cutoff * band[1])
    return spectrum_magnitude[in_band].sum()
def prepare_optimizer(
    optimizable_params: dict[str, tuple[bool, float, float]],
    device: torch.device,
) -> tuple[dict[str, torch.nn.Parameter], torch.optim.Optimizer]:
    """
    Create trainable scalar parameters and an Adam optimizer over them.

    Parameters
    ----------
    optimizable_params : dict
        Maps a parameter name to ``(enabled, initial_value, learning_rate)``.
        Only enabled entries become parameters.
    device : torch.device
        Device on which the parameters are allocated.

    Returns
    -------
    tuple[dict, Optimizer]
        ``(name -> nn.Parameter of shape (1,), Adam optimizer)``. Each
        parameter sits in its own param group with its own learning rate,
        in the order of ``optimizable_params``.
    """
    active = [
        (name, initial, lr)
        for name, (enabled, initial, lr) in optimizable_params.items()
        if enabled
    ]

    optimization_params: dict[str, torch.nn.Parameter] = {
        name: torch.nn.Parameter(
            torch.tensor([initial], dtype=torch.float32, device=device),
            requires_grad=True,
        )
        for name, initial, _ in active
    }

    # One param group per parameter so each keeps its own learning rate.
    param_groups = [
        {"params": [optimization_params[name]], "lr": lr} for name, _, lr in active
    ]
    return optimization_params, torch.optim.Adam(param_groups)
def run_reconstruction_single(
    zyx_stack: np.ndarray,
    pixel_scales: tuple[float, float, float],
    fixed_params: dict,
    param_values: dict,
    device: Device = None,
) -> np.ndarray:
    """
    Reconstruct phase once with fixed parameter values (no optimization).

    Parameters
    ----------
    zyx_stack : np.ndarray
        Input Z-stack with shape (Z, Y, X).
    pixel_scales : tuple[float, float, float]
        ``(z_scale, y_scale, x_scale)``; the Y scale is used as the YX
        pixel size and the X scale is not used.
    fixed_params : dict
        Fixed reconstruction parameters. The keys ``num_iterations``,
        ``use_tiling`` and ``device`` are ignored if present.
    param_values : dict
        Values for tunable parameters (e.g. ``z_offset``); each becomes a
        shape-(1,) float32 tensor on the target device.
    device : torch.device | str | None, optional
        Computing device; None auto-selects GPU when available.

    Returns
    -------
    np.ndarray
        Reconstructed phase, normalized to uint8 for display.
    """
    device = get_device(device)
    zyx_tile = torch.tensor(zyx_stack, dtype=torch.float32, device=device)
    z_scale, y_scale, _ = pixel_scales

    recon_args = dict(fixed_params)
    # Drop UI/runtime settings that the reconstruction does not accept.
    for ignored_key in ("num_iterations", "use_tiling", "device"):
        recon_args.pop(ignored_key, None)

    recon_args.update(
        yx_shape=zyx_tile.shape[1:],
        yx_pixel_size=y_scale,
        z_scale=z_scale,
    )
    recon_args.update(
        {
            name: torch.tensor([value], dtype=torch.float32, device=device)
            for name, value in param_values.items()
        }
    )

    phase = run_reconstruction(zyx_tile, recon_args)
    # normalize_for_display is fed an xr.DataArray wrapping the CPU copy.
    return normalize_for_display(xr.DataArray(phase.detach().cpu().numpy()))
def run_optimization_streaming(
    zyx_stack: np.ndarray,
    pixel_scales: tuple[float, float, float],
    fixed_params: dict,
    optimizable_params: dict,
    num_iterations: int = 10,
    device: Device = None,
    *,
    loss_na_det: float = 0.15,
    loss_band: tuple[float, float] = (0.1, 0.2),
) -> Generator[dict, None, None]:
    """
    Optimize reconstruction parameters, yielding a result every iteration.

    Each iteration reconstructs with the current parameter values. The loss
    is the negative midband power of the reconstruction, so minimizing it
    maximizes midband power. One Adam step is then taken.

    Parameters
    ----------
    zyx_stack : np.ndarray
        Input Z-stack with shape (Z, Y, X).
    pixel_scales : tuple[float, float, float]
        ``(z_scale, y_scale, x_scale)``; the Y scale is used as the YX
        pixel size.
    fixed_params : dict
        Fixed reconstruction parameters. Must include
        ``wavelength_illumination`` (used by the loss). The keys
        ``num_iterations``, ``use_tiling`` and ``device`` are ignored.
    optimizable_params : dict
        ``name -> (enabled, initial, lr)``. Disabled entries are held
        fixed at their initial value.
    num_iterations : int, optional
        Number of optimization iterations, by default 10.
    device : torch.device | str | None, optional
        Computing device, e.g. "cuda", "cpu", "cuda:0". None auto-selects
        GPU when available.
    loss_na_det : float, optional
        Detection NA used only to set the loss frequency cutoff,
        by default 0.15.
    loss_band : tuple[float, float], optional
        Loss frequency band as fractions of that cutoff, by default (0.1, 0.2).

    Yields
    ------
    dict
        - 'reconstructed_image': uint8 display image (numpy, CPU)
        - 'loss': float loss of that reconstruction
        - 'iteration': 1-based iteration number
        - 'params': parameter values AFTER this iteration's optimizer step,
          i.e. the values the next iteration will use

    Raises
    ------
    ValueError
        If no entry of ``optimizable_params`` is enabled.
    """
    # Adam rejects an empty parameter list; fail early with a clear message.
    if not any(spec[0] for spec in optimizable_params.values()):
        raise ValueError(
            "No optimizable parameters are enabled; enable at least one "
            "parameter to run the optimization."
        )

    device = get_device(device)
    # Single host-to-device transfer of the input stack.
    zyx_tile = torch.tensor(zyx_stack, dtype=torch.float32, device=device)

    z_scale, y_scale, _ = pixel_scales
    recon_args = fixed_params.copy()
    for ignored_key in ("num_iterations", "use_tiling", "device"):
        recon_args.pop(ignored_key, None)
    recon_args["yx_shape"] = zyx_tile.shape[1:]
    recon_args["yx_pixel_size"] = y_scale
    recon_args["z_scale"] = z_scale

    # Every tunable parameter gets a tensor value; enabled ones are replaced
    # by trainable Parameters inside the loop.
    for name, (_enabled, initial, _lr) in optimizable_params.items():
        recon_args[name] = torch.tensor([initial], dtype=torch.float32, device=device)

    optimization_params, optimizer = prepare_optimizer(optimizable_params, device)

    for step in range(num_iterations):
        recon_args.update(optimization_params)

        yx_recon = run_reconstruction(zyx_tile, recon_args)

        # Negative midband power: minimizing the loss maximizes the power.
        loss = -compute_midband_power(
            yx_recon,
            NA_det=loss_na_det,
            lambda_ill=recon_args["wavelength_illumination"],
            pixel_size=recon_args["yx_pixel_size"],
            band=loss_band,
        )

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Only the display image leaves the device.
        recon_normalized = normalize_for_display(
            xr.DataArray(yx_recon.detach().cpu().numpy())
        )
        param_values = {
            name: param.item() for name, param in optimization_params.items()
        }

        yield {
            "reconstructed_image": recon_normalized,
            "loss": loss.item(),
            "iteration": step + 1,
            "params": param_values,
        }
def extract_tiles(
    zyx_data: np.ndarray, num_tiles: tuple[int, int], overlap_pct: float
) -> tuple[dict[str, np.ndarray], dict[str, tuple[int, int, int]]]:
    """
    Split a Z-stack into overlapping YX tiles covering the whole image.

    Nominal tile size is ``ceil(N / (n - (n - 1) * overlap_pct))`` per axis.
    Tiles start at multiples of ``int(tile_size * (1 - overlap_pct))``.
    Because the stride is truncated to an int, the last tile in each axis
    is extended to the image edge so no border rows or columns are dropped.

    Parameters
    ----------
    zyx_data : np.ndarray
        Input data with shape (Z, Y, X).
    num_tiles : tuple[int, int]
        Number of tiles along (Y, X); each must be >= 1.
    overlap_pct : float
        Fractional overlap between neighbouring tiles, in [0, 1).

    Returns
    -------
    tuple[dict, dict]
        tiles: tile name (``"0/0/{yi:03d}{xi:03d}"``) -> view of
        ``zyx_data[:, y0:y1, x0:x1]``.
        translations: tile name -> ``(0, y0, x0)`` offset.

    Raises
    ------
    ValueError
        If a tile count is < 1 or ``overlap_pct`` is outside [0, 1).
    """
    n_y, n_x = num_tiles
    if n_y < 1 or n_x < 1:
        raise ValueError(f"num_tiles must be >= 1 along each axis, got {num_tiles}")
    if not 0.0 <= overlap_pct < 1.0:
        raise ValueError(f"overlap_pct must be in [0, 1), got {overlap_pct}")

    _, Y, X = zyx_data.shape
    tile_height = int(np.ceil(Y / (n_y - (n_y - 1) * overlap_pct)))
    tile_width = int(np.ceil(X / (n_x - (n_x - 1) * overlap_pct)))
    stride_y = int(tile_height * (1 - overlap_pct))
    stride_x = int(tile_width * (1 - overlap_pct))

    tiles = {}
    translations = {}
    for yi in range(n_y):
        y0 = yi * stride_y
        # Truncated strides can leave the last tile short of the edge
        # (e.g. Y=10, 3 tiles, 0.5 overlap), so it always runs to Y.
        y1 = Y if yi == n_y - 1 else min(y0 + tile_height, Y)
        for xi in range(n_x):
            x0 = xi * stride_x
            x1 = X if xi == n_x - 1 else min(x0 + tile_width, X)
            tile_name = f"0/0/{yi:03d}{xi:03d}"
            tiles[tile_name] = zyx_data[:, y0:y1, x0:x1]
            translations[tile_name] = (0, y0, x0)
    return tiles, translations
|