File size: 4,046 Bytes
c843d82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#!/usr/bin/env python3
"""
Developed by Nikhil Nageshwar Inturi

Stitch tiled mask .npy files and mask PNGs into mosaics,
based on a Cellpose output root (4_cellpose_masks) containing
 'segmentation' and 'masks' subfolders, ignoring 'preview'.
"""

# imports
from PIL import Image
from pathlib import Path
import logging, numpy as np, tifffile
# local imports
from utils.constants import *


class MaskStitcher:
    """
    Stitch both .npy masks and mask PNGs from a Cellpose output root:
    - Expects root with subfolders SEGMENTATION_DIR (.npy) and MASKS_DIR (.png)
    - Outputs mosaics in STITCHED_MASKS_DIR (or the given ``output_dir``)

    Tile files must be named ``<base>_<row>_<col>`` (an optional trailing
    ``_seg`` / ``_cp_masks`` is tolerated). Tiles sharing ``<base>`` form one
    mosaic. Every non-zero label in every tile gets a fresh, unique id in the
    mosaic. Objects crossing a tile seam are NOT merged and keep one id per tile.
    """

    # Trailing stem suffixes stripped before parsing <base>_<row>_<col>.
    # Cellpose conventionally names outputs "<name>_seg.npy" / "<name>_cp_masks.png".
    _NAME_SUFFIXES = ("_cp_masks", "_seg")

    def __init__(self, input_dir: Path, output_dir: Path = None) -> None:
        """Resolve input subfolders and create the output directory.

        Args:
            input_dir: Root containing SEGMENTATION_DIR and MASKS_DIR subfolders.
            output_dir: Where mosaics are written. Defaults to STITCHED_MASKS_DIR.
        """
        self.input_dir = Path(input_dir)
        self.seg_dir = self.input_dir / SEGMENTATION_DIR
        self.png_dir = self.input_dir / MASKS_DIR
        self.output_dir = Path(output_dir) if output_dir is not None else Path(STITCHED_MASKS_DIR)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.logger = logging.getLogger(self.__class__.__name__)

    @staticmethod
    def _parse(fname: str):
        """Split a tile filename into ``(base, row, col)``.

        Raises:
            ValueError: if the stem is not ``<base>_<int>_<int>`` (after
                stripping an optional known suffix).
        """
        stem = Path(fname).stem
        for suffix in MaskStitcher._NAME_SUFFIXES:
            if stem.endswith(suffix):
                # Only names that would otherwise fail to parse end this way,
                # since a valid stem's last component is an integer.
                stem = stem[: -len(suffix)]
                break
        base, row, col = stem.rsplit("_", 2)
        return base, int(row), int(col)

    def _read_npy(self, fp: Path) -> np.ndarray:
        """Load a label array from ``fp`` as int32.

        Plain arrays are used directly. An object array holding a dict (as
        written by Cellpose) contributes its ``'masks'`` entry.

        Raises:
            KeyError: if a pickled payload has no ``'masks'`` entry.
        """
        # allow_pickle is required for dict payloads. Only load trusted files:
        # unpickling untrusted data can execute arbitrary code.
        data = np.load(fp, allow_pickle=True)
        if data.dtype == object:
            payload = data.item()
            masks = payload.get("masks") if isinstance(payload, dict) else None
            if masks is None:
                raise KeyError(f"{fp}: pickled .npy has no 'masks' entry")
            data = np.asarray(masks)
        return data.astype(np.int32, copy=False)

    def _read_png(self, fp: Path) -> np.ndarray:
        """Load a mask image from ``fp`` as an int32 array."""
        # Context manager closes the underlying file handle promptly.
        with Image.open(fp) as img:
            arr = np.array(img)
        return arr.astype(np.int32, copy=False)

    def _groups(self, directory: Path, pattern: str):
        """Group files in ``directory`` matching ``pattern`` by their base name.

        Files whose names don't follow ``<base>_<row>_<col>`` are skipped with a
        warning instead of aborting the whole run.
        """
        groups = {}
        if not directory.is_dir():
            self.logger.warning(f"Directory not found, skipping: {directory}")
            return groups
        for fp in sorted(directory.glob(pattern)):
            try:
                base, _, _ = self._parse(fp.name)
            except ValueError:
                self.logger.warning(f"Skipping file with unexpected name: {fp.name}")
                continue
            groups.setdefault(base, []).append(fp)
        return groups

    @staticmethod
    def _offsets(placements):
        """Compute tile offsets and total mosaic size.

        Args:
            placements: iterable of ``(row, col, height, width)``.

        Returns:
            ``(y_off, x_off, total_height, total_width)``. ``y_off`` maps row
            index to y offset and ``x_off`` maps column index to x offset.

        Each row is as tall as its tallest tile and each column as wide as its
        widest tile, so tiles never overlap. Absent row/column indices take no
        space (present indices are packed in sorted order).
        """
        row_h = {}
        col_w = {}
        for r, c, h, w in placements:
            row_h[r] = max(row_h.get(r, 0), h)
            col_w[c] = max(col_w.get(c, 0), w)
        y_off = {}
        y = 0
        for r in sorted(row_h):
            y_off[r] = y
            y += row_h[r]
        x_off = {}
        x = 0
        for c in sorted(col_w):
            x_off[c] = x
            x += col_w[c]
        return y_off, x_off, y, x

    def _layout(self, files, read_func):
        """Compute the mosaic layout for ``files``, reading each with ``read_func``.

        Returns the same ``(y_off, x_off, H, W)`` tuple as ``_offsets``.
        """
        placements = []
        for fp in files:
            _, r, c = self._parse(fp.name)
            h, w = read_func(fp).shape
            placements.append((r, c, h, w))
        return self._offsets(placements)

    def _stitch(self, files, read_func):
        """Assemble the tiles in ``files`` into one int32 label mosaic.

        Each tile is read once. Tiles are processed in (row, col) order so label
        ids are deterministic. Every non-zero label gets a new consecutive id
        starting at 1. Zero is background.

        Raises:
            ValueError: if a tile is not a 2-D array.
        """
        tiles = []
        for fp in files:
            _, r, c = self._parse(fp.name)
            tile = read_func(fp)
            if tile.ndim != 2:
                raise ValueError(f"{fp}: expected a 2-D label image, got shape {tile.shape}")
            tiles.append((r, c, tile))
        # Sort on indices only (never compare arrays); independent of glob order.
        tiles.sort(key=lambda t: (t[0], t[1]))

        y_off, x_off, H, W = self._offsets((r, c, *t.shape) for r, c, t in tiles)
        mosaic = np.zeros((H, W), dtype=np.int32)
        next_lbl = 1
        for r, c, tile in tiles:
            # Vectorized relabel: a lookup table from each tile label to its new id.
            # Background (0) is excluded explicitly, not by position in unique().
            labels, inverse = np.unique(tile, return_inverse=True)
            fg = labels != 0
            n_fg = int(fg.sum())
            lut = np.zeros(labels.shape, dtype=np.int32)
            lut[fg] = np.arange(next_lbl, next_lbl + n_fg, dtype=np.int32)
            # reshape handles both 1-D and input-shaped inverse across numpy versions.
            relabeled = lut[inverse.reshape(tile.shape)]

            yy, xx = y_off[r], x_off[c]
            region = mosaic[yy:yy + tile.shape[0], xx:xx + tile.shape[1]]
            fg_px = tile != 0
            region[fg_px] = relabeled[fg_px]
            next_lbl += n_fg
        return mosaic

    def stitch_all(self) -> None:
        """Stitch every tile group in the segmentation and mask-PNG folders.

        Segmentation groups produce ``<base>_stitched.npy`` (labels) and
        ``<base>_stitched.tif`` (binary 0/255 mask). Mask-PNG groups produce
        ``<base>_mask_stitched.png`` (uint16 labels). When labels exceed the
        uint16 range, ``<base>_mask_stitched.tif`` (uint32) is written instead
        so ids are not silently wrapped.
        """
        seg_groups = self._groups(self.seg_dir, "*.npy")
        for base, files in seg_groups.items():
            self.logger.info(f"Stitching segmentation for '{base}' ")
            mosaic = self._stitch(files, self._read_npy)
            out_npy = self.output_dir / f"{base}_stitched.npy"
            np.save(out_npy, mosaic)
            self.logger.info(f"Saved stitched .npy: {out_npy}")
            out_tif = self.output_dir / f"{base}_stitched.tif"
            tifffile.imwrite(out_tif, (mosaic>0).astype(np.uint8)*255, photometric="minisblack")
            self.logger.info(f"Saved stitched TIFF: {out_tif}")

        png_max = np.iinfo(np.uint16).max
        png_groups = self._groups(self.png_dir, "*.png")
        for base, files in png_groups.items():
            self.logger.info(f"Stitching mask PNGs for '{base}' ")
            mosaic = self._stitch(files, self._read_png)
            if mosaic.max(initial=0) > png_max:
                # A uint16 PNG would wrap ids > 65535 and merge unrelated cells.
                out_tif = self.output_dir / f"{base}_mask_stitched.tif"
                self.logger.warning(
                    f"'{base}' has {int(mosaic.max())} labels (> {png_max}); "
                    f"writing uint32 TIFF instead of PNG: {out_tif}"
                )
                tifffile.imwrite(out_tif, mosaic.astype(np.uint32), photometric="minisblack")
                self.logger.info(f"Saved stitched mask TIFF: {out_tif}")
                continue
            out_png = self.output_dir / f"{base}_mask_stitched.png"
            Image.fromarray(mosaic.astype(np.uint16)).save(out_png)
            self.logger.info(f"Saved stitched mask PNG: {out_png}")