File size: 13,878 Bytes
e101805
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
"""Functional patch extraction API.

Primary entrypoint: `patchfy()`.
"""

from __future__ import annotations

import os
from typing import Optional, Tuple, Union

import numpy as np
import openslide
from PIL import Image

from .wsi_core import SlideImage

Image.MAX_IMAGE_PIXELS = None


def _slide_id_from_path(wsi_path: str) -> str:
    """Return slide_id (filename without extension) from a WSI path."""
    fname = os.path.basename(wsi_path)
    slide_id, _ = os.path.splitext(fname)
    return slide_id


def _h5_is_complete(h5_path: str) -> bool:
    if not os.path.isfile(h5_path):
        return False
    import h5py

    try:
        with h5py.File(h5_path, 'r') as f:
            complete = bool(f.attrs.get('complete', False))
            has_coords = 'coords' in f and len(f['coords']) > 0
            return complete and has_coords
    except Exception:
        return False


def _ensure_slide_image(wsi: Union[str, os.PathLike, SlideImage]) -> Tuple[SlideImage, str]:
    """Normalize WSI input into a SlideImage + path string."""
    if isinstance(wsi, SlideImage):
        return wsi, wsi.path
    wsi_path = os.fspath(wsi)
    return SlideImage(wsi_path), wsi_path


def run_pipeline(
    wsi_path: str,
    out_dir: str,
    patch_size=256,
    step_size=256,
    patch_level=0,
    save_mask=False,
    save_h5=True,
    auto_skip=True,
    seg_downsample: float = 1.0,
    max_seg_pixels: float | None = 4e10,
    # MPP normalization (always enabled)
    standard_mpp: float = 0.5,
    assumed_mpp: float = 0.5,
    # Segmentation/Filtering knobs
    sthresh: int = 8,
    mthresh: int = 7,
    close: int = 4,
    use_otsu: bool = False,
    a_t: int = 1,
    a_h: int = 1,
    max_n_holes: int = 100,
    # Visualization knob (only used when save_mask=True)
    line_thickness: int = 100,
):
    """Internal pipeline for single-WSI patch extraction."""
    # Prepare output dirs
    patch_save_dir = os.path.join(out_dir, 'patches')
    mask_save_dir = os.path.join(out_dir, 'masks')
    os.makedirs(patch_save_dir, exist_ok=True)
    os.makedirs(mask_save_dir, exist_ok=True)

    # Auto-skip existing output
    slide_id = _slide_id_from_path(wsi_path)
    canonical_path = os.path.join(patch_save_dir, slide_id + '.h5')

    if auto_skip:
        if _h5_is_complete(canonical_path):
            return None, None
        else:
            if os.path.isfile(canonical_path):
                try:
                    os.remove(canonical_path)
                except Exception:
                    pass
    else:
        # If the caller explicitly disables auto_skip, regenerate outputs.
        # This is important when parameters change (e.g., MPP-normalized patch_size/step_size).
        if os.path.isfile(canonical_path):
            try:
                os.remove(canonical_path)
            except Exception:
                pass

    WSI_object = SlideImage(wsi_path)

    # --- Always normalize patch_size/step_size by MPP (match float crop mode) ---
    wsi = WSI_object.get_slide()
    props = getattr(wsi, 'properties', {}) or {}
    this_mpp = None
    for key in (getattr(openslide, 'PROPERTY_NAME_MPP_X', 'openslide.mpp-x'), 'openslide.mpp-x', 'aperio.MPP'):
        try:
            if key in props:
                this_mpp = float(props[key])
                break
        except Exception:
            this_mpp = None
    if this_mpp is None or not np.isfinite(this_mpp) or this_mpp <= 0:
        # objective-power heuristic: 10x->1.0, 20x->0.5, 40x->0.25
        try:
            obj = props.get('openslide.objective-power')
            if obj is not None:
                obj = float(obj)
                if obj > 0:
                    this_mpp = 10.0 / obj
        except Exception:
            this_mpp = None
    if this_mpp is None or not np.isfinite(this_mpp) or this_mpp <= 0:
        this_mpp = float(assumed_mpp)

    crop_factor = float(standard_mpp) / float(this_mpp)
    # Keep behavior identical to the original float-crop branch: int() truncation
    eff_patch_size = int(int(patch_size) * crop_factor)
    eff_step_size = int(int(step_size) * crop_factor)
    if eff_patch_size <= 0:
        eff_patch_size = int(patch_size)
    if eff_step_size <= 0:
        eff_step_size = int(step_size)

    print(f"mpp: {this_mpp} (assumed_mpp={assumed_mpp}, standard_mpp={standard_mpp})")
    print(f"patch_size: {eff_patch_size}  step_size: {eff_step_size}")

    # Determine visualization/segmentation levels
    vis_level = wsi.get_best_level_for_downsample(64)
    seg_level = wsi.get_best_level_for_downsample(64)

    w, h = WSI_object.level_dim[seg_level]
    if max_seg_pixels is not None and (w * h) > float(max_seg_pixels):
        # Make the skip explicit so downstream knows this slide was not processed.
        slide_id = _slide_id_from_path(wsi_path)
        reason = (
            f"SKIP: seg_level image too large (w*h={w*h:.3g}, w={w}, h={h}, "
            f"seg_level={seg_level}, max_seg_pixels={float(max_seg_pixels):.3g})"
        )
        print(reason)
        try:
            skipped_dir = os.path.join(out_dir, 'skipped')
            os.makedirs(skipped_dir, exist_ok=True)
            with open(os.path.join(skipped_dir, slide_id + '.txt'), 'w', encoding='utf-8') as f:
                f.write(reason + "\n")
                f.write(wsi_path + "\n")
        except Exception:
            pass
        return None, None

    # Segmentation
    import time as _t

    t0 = _t.time()
    try:
        WSI_object.segment_regions(
            seg_level=seg_level,
            sthresh=sthresh,
            mthresh=mthresh,
            close=close,
            use_otsu=use_otsu,
            filter_params={'a_t': a_t, 'a_h': a_h, 'max_n_holes': max_n_holes},
            seg_downsample=seg_downsample,
        )
    except Exception:
        # fallback to a coarser level if needed
        fallback_level = min(2, len(WSI_object.level_dim) - 1)
        WSI_object.segment_regions(
            seg_level=fallback_level,
            sthresh=sthresh,
            mthresh=mthresh,
            close=close,
            use_otsu=use_otsu,
            filter_params={'a_t': a_t, 'a_h': a_h, 'max_n_holes': max_n_holes},
            seg_downsample=seg_downsample,
        )
    seg_time_elapsed = _t.time() - t0
    if len(WSI_object.contours_tissue) == 0:
        # Still allow saving a visualization (will be slide-only if no contours).
        if save_mask:
            try:
                mask = WSI_object.render_segmentation(vis_level=vis_level, line_thickness=line_thickness)
                seg_ds = float(seg_downsample or 1.0)
                if seg_ds > 1.0:
                    new_w = max(1, int(mask.width / seg_ds))
                    new_h = max(1, int(mask.height / seg_ds))
                    mask = mask.resize((new_w, new_h), resample=Image.Resampling.BILINEAR)
                mask.save(os.path.join(mask_save_dir, slide_id + '.jpg'))
            except Exception:
                pass
        return None, None

    if save_mask:
        mask = WSI_object.render_segmentation(vis_level=vis_level, line_thickness=line_thickness)
        seg_ds = float(seg_downsample or 1.0)
        if seg_ds > 1.0:
            new_w = max(1, int(mask.width / seg_ds))
            new_h = max(1, int(mask.height / seg_ds))
            mask = mask.resize((new_w, new_h), resample=Image.Resampling.BILINEAR)
        mask.save(os.path.join(mask_save_dir, slide_id + '.jpg'))

    # Patch extraction
    t0 = _t.time()
    file_path = None
    if save_h5:
        file_path = WSI_object.write_patches(
            patch_level=patch_level,
            patch_size=int(eff_patch_size),
            step_size=int(eff_step_size),
            save_path=patch_save_dir,
            use_padding=True,
            contour_fn='four_pt',
        )
        # The writer returns the concrete H5 path. Ensure it's the canonical location.
        if file_path != canonical_path:
            # Don't fail hard, but keep behavior deterministic.
            file_path = canonical_path
        if not _h5_is_complete(file_path):
            print(f"warning: generated file missing completion marker: {file_path}")
    patch_time_elapsed = _t.time() - t0

    # Print elapsed times
    print(f"segmentation time: {seg_time_elapsed:.2f}s")
    print(f"patch extraction time: {patch_time_elapsed:.2f}s")

    # Return coords/contour_index from the written HDF5.
    if file_path is None:
        return None, None

    try:
        import h5py

        with h5py.File(file_path, 'r') as f:
            coords = f['coords'][...]
            contour_indices = f['contour_index'][...] if 'contour_index' in f else None
        return coords, contour_indices
    except Exception:
        return None, None


def patchfy(
    wsi: Union[str, os.PathLike, SlideImage],
    out: str,
    step_size: int = 256,
    patch_size: int = 256,
    patch_level: int = 0,
    save_mask: bool = False,
    save_h5: bool = True,
    auto_skip: bool = True,
    seg_downsample: float = 1.0,
    max_seg_pixels: Optional[float] = 4e10,
    # Segmentation/Filtering knobs
    sthresh: int = 8,
    mthresh: int = 7,
    close: int = 4,
    use_otsu: bool = False,
    a_t: int = 1,
    a_h: int = 1,
    max_n_holes: int = 100,
    # Visualization knob
    line_thickness: int = 100,
) -> Tuple[Optional[str], Optional[np.ndarray], Optional[np.ndarray]]:
    """High-level API for single-WSI patch extraction.

    This function is designed to be imported and called from other Python code.

    Parameters
    ----------
    wsi:
        Either a WSI path (str / PathLike) or an already-opened `SlideImage`.

    out:
        Output directory. Two subfolders may be created:
        - `<out>/patches` for HDF5
        - `<out>/masks` for segmentation visualization (when `save_mask=True`)

    step_size:
        Stride between patch coordinates in level-0 pixel space.

    patch_size:
        Patch edge length in pixels (requested). The effective values written to H5
        are MPP-normalized inside `run_pipeline`.

    patch_level:
        OpenSlide pyramid level to extract patches from. 0 means full resolution.

    save_mask:
        If True, saves a segmentation visualization image to `<out>/masks/<slide_id>.jpg`.

    save_h5:
        If True, writes an HDF5 file to `<out>/patches/<slide_id>.h5`.

    auto_skip:
        If True, and a previous output HDF5 exists with `attrs['complete']=True` and non-empty coords,
        the slide is skipped.

    seg_downsample:
        Additional downsampling factor used *during segmentation only* to reduce computation.
        - 1.0 means no extra downsample.
        - >1.0 means resize segmentation image by this factor (faster, potentially less accurate).

    max_seg_pixels:
        Safety cap for segmentation image size at the chosen `seg_level`.
        If `w*h` at `seg_level` exceeds this, the slide is skipped and a reason is written to
        `<out>/skipped/<slide_id>.txt`. Set `<=0` to disable this check.

    sthresh:
        Tissue segmentation threshold on the *S channel* (saturation in HSV) used by
        `SlideImage.segment_regions`. Lower values generally mark more pixels as tissue.

    mthresh:
        Median blur kernel size applied to the saturation mask before thresholding.
        Must be an integer; even values will be made odd inside `segment_regions`.

    close:
        Morphological closing kernel size (in pixels) applied to the binary mask.
        Use 0 to disable closing.

    use_otsu:
        If True, use Otsu thresholding instead of fixed `sthresh` during saturation thresholding.

    a_t:
        Minimum tissue contour area threshold (baseline) passed via `filter_params['a_t']`.
        Note: `SlideImage.segment_regions` internally rescales this threshold based on the pyramid
        level and `seg_downsample` so this is a *baseline* knob rather than a strict level-0 area.

    a_h:
        Minimum hole area threshold (baseline) passed via `filter_params['a_h']`.
        Same rescaling behavior as `a_t`.

    max_n_holes:
        Maximum number of holes to keep per tissue contour (largest holes by area).

    line_thickness:
        Line thickness used when rendering the segmentation visualization mask.
        Only used if `save_mask=True`.

    Returns
    -------
    (h5_path, coords, contour_indices)
        If skipped/failed to segment, returns (None, None, None).

        - `h5_path`: canonical output path under `<out>/patches/<slide_id>.h5` when `save_h5=True`.
        - `coords`: Nx2 array of patch (x, y) coordinates in level-0 pixels.
        - `contour_indices`: N array with the tissue contour index for each coordinate.
    """
    os.makedirs(out, exist_ok=True)

    # Normalize disabling behavior
    if max_seg_pixels is not None and max_seg_pixels <= 0:
        max_seg_pixels = None

    WSI_object, wsi_path = _ensure_slide_image(wsi)
    slide_id = getattr(WSI_object, 'name', None) or _slide_id_from_path(wsi_path)
    patch_save_dir = os.path.join(out, 'patches')
    h5_path = os.path.join(patch_save_dir, slide_id + '.h5') if save_h5 else None

    coords, contour_indices = run_pipeline(
        wsi_path=wsi_path,
        out_dir=out,
        patch_size=patch_size,
        step_size=step_size,
        patch_level=patch_level,
        save_mask=save_mask,
        save_h5=save_h5,
        auto_skip=auto_skip,
        seg_downsample=seg_downsample,
        max_seg_pixels=max_seg_pixels,
        sthresh=sthresh,
        mthresh=mthresh,
        close=close,
        use_otsu=use_otsu,
        a_t=a_t,
        a_h=a_h,
        max_n_holes=max_n_holes,
        line_thickness=line_thickness,
    )

    if coords is None or contour_indices is None:
        return None, None, None
    return h5_path, coords, contour_indices


__all__ = [
    'patchfy',
    'run_pipeline',
    'SlideImage',
]