#!/usr/bin/env python3 """ Developed by Nikhil Nageshwar Inturi Stitch tiled mask .npy files and mask PNGs into mosaics, based on a Cellpose output root (4_cellpose_masks) containing 'segmentation' and 'masks' subfolders, ignoring 'preview'. """ # imports from PIL import Image from pathlib import Path import logging, numpy as np, tifffile # local imports from utils.constants import * class MaskStitcher: """ Stitch both .npy masks and mask PNGs from a Cellpose output root: - Expects root with subfolders SEGMENTATION_DIR (.npy) and MASKS_DIR (.png) - Outputs mosaics in STITCHED_MASKS_DIR """ def __init__(self, input_dir: Path, output_dir: Path = None) -> None: self.input_dir = Path(input_dir) self.seg_dir = self.input_dir / SEGMENTATION_DIR self.png_dir = self.input_dir / MASKS_DIR self.output_dir = Path(output_dir) if output_dir is not None else Path(STITCHED_MASKS_DIR) self.output_dir.mkdir(parents=True, exist_ok=True) self.logger = logging.getLogger(self.__class__.__name__) @staticmethod def _parse(fname: str): stem = Path(fname).stem base, row, col = stem.rsplit("_", 2) return base, int(row), int(col) def _read_npy(self, fp: Path) -> np.ndarray: data = np.load(fp, allow_pickle=True) if data.dtype == object: data = data.item().get('masks') return data.astype(np.int32, copy=False) def _read_png(self, fp: Path) -> np.ndarray: arr = np.array(Image.open(fp)) return arr.astype(np.int32, copy=False) def _groups(self, directory: Path, pattern: str): groups = {} for fp in directory.glob(pattern): base, _, _ = self._parse(fp.name) groups.setdefault(base, []).append(fp) return groups def _layout(self, files, read_func): row_h = {} col_w = {} for fp in files: _, r, c = self._parse(fp.name) h, w = read_func(fp).shape row_h[r] = max(row_h.get(r, 0), h) col_w[c] = max(col_w.get(c, 0), w) y_off = {} x_off = {} y = x = 0 for r in sorted(row_h): y_off[r] = y y += row_h[r] for c in sorted(col_w): x_off[c] = x x += col_w[c] return y_off, x_off, y, x def _stitch(self, files, read_func): y_off, x_off, H, W = self._layout(files, read_func) mosaic = np.zeros((H, W), dtype=np.int32) next_lbl = 1 for fp in files: _, r, c = self._parse(fp.name) tile = read_func(fp) for lbl in np.unique(tile)[1:]: mask_region = (tile == lbl) yy = y_off[r] xx = x_off[c] region = mosaic[yy:yy+tile.shape[0], xx:xx+tile.shape[1]] region[mask_region] = next_lbl next_lbl += 1 return mosaic def stitch_all(self) -> None: seg_groups = self._groups(self.seg_dir, "*.npy") for base, files in seg_groups.items(): self.logger.info(f"Stitching segmentation for '{base}' ") mosaic = self._stitch(files, self._read_npy) out_npy = self.output_dir / f"{base}_stitched.npy" np.save(out_npy, mosaic) self.logger.info(f"Saved stitched .npy: {out_npy}") out_tif = self.output_dir / f"{base}_stitched.tif" tifffile.imwrite(out_tif, (mosaic>0).astype(np.uint8)*255, photometric="minisblack") self.logger.info(f"Saved stitched TIFF: {out_tif}") png_groups = self._groups(self.png_dir, "*.png") for base, files in png_groups.items(): self.logger.info(f"Stitching mask PNGs for '{base}' ") mosaic = self._stitch(files, self._read_png) out_png = self.output_dir / f"{base}_mask_stitched.png" Image.fromarray(mosaic.astype(np.uint16)).save(out_png) self.logger.info(f"Saved stitched mask PNG: {out_png}")