"""WSI patch extraction utilities.

Provides `SlideImage` with:
- tissue segmentation (`segment_regions`)
- segmentation visualization (`render_segmentation`)
- patch-coordinate extraction to HDF5 (`write_patches`)
"""
from __future__ import annotations

import math
import os
import pickle
from typing import List

import cv2
import h5py
import numpy as np
import openslide
from PIL import Image
|
|
| |
|
|
|
|
class SlideImage:
    """Wrap an OpenSlide whole-slide image for tissue segmentation and patch export.

    The instance keeps the tissue contours (``contours_tissue``) and, in a
    parallel list, the hole contours belonging to each tissue contour
    (``holes_tissue``). Contours produced by :meth:`segment_regions` are
    ``(N, 1, 2)`` int32 arrays expressed in level-0 pixel coordinates.
    """

    def __init__(self, path: str):
        """Open the slide at *path* and cache its pyramid geometry.

        Args:
            path: Path to the slide file. The file stem becomes ``name`` and
                the name of the containing directory becomes ``patient_id``.
        """
        self.path = path
        self.name = os.path.splitext(os.path.basename(path))[0]
        self.patient_id = os.path.basename(os.path.dirname(path))
        self.wsi = openslide.open_slide(path)
        try:
            self.wsi.set_cache(openslide.OpenSlideCache(0))
        except Exception:
            # Deliberately best-effort: replacing the tile cache is an
            # optimisation only, so any failure here must not prevent the
            # slide from being used.
            pass

        self.level_dim = self.wsi.level_dimensions
        self.level_downsamples = self._calc_level_downsamples()
        self.contours_tissue: List[np.ndarray] = []
        self.holes_tissue: List[List[np.ndarray]] = []

    def _calc_level_downsamples(self):
        """Compute per-level ``(x, y)`` downsample pairs.

        Some slides report ``wsi.level_downsamples`` as scalars, while the true
        X/Y downsample inferred from the level dimensions can be slightly
        non-integer. The rule is:

        - if the inferred ``(sx, sy)`` equals ``(ds, ds)``, use ``(ds, ds)``
        - otherwise use the inferred pair

        Returns:
            A list with one ``(sx, sy)`` tuple per pyramid level.
        """
        base_w, base_h = self.wsi.level_dimensions[0]
        pairs = []
        for ds, (lvl_w, lvl_h) in zip(self.wsi.level_downsamples, self.wsi.level_dimensions):
            inferred = (base_w / float(lvl_w), base_h / float(lvl_h))
            pairs.append((ds, ds) if inferred == (ds, ds) else inferred)
        return pairs

    def get_slide(self):
        """Return the underlying OpenSlide handle."""
        return self.wsi

    def initSegmentation(self, mask_file):
        """Load tissue and hole contours from a pickle written by `saveSegmentation`.

        Missing ``'holes'`` / ``'tissue'`` keys fall back to empty lists.

        SECURITY: ``pickle.load`` can execute arbitrary code while
        deserialising; only pass mask files that come from a trusted source.
        """
        with open(mask_file, 'rb') as f:
            asset = pickle.load(f)
        self.holes_tissue = asset.get('holes', [])
        self.contours_tissue = asset.get('tissue', [])

    def saveSegmentation(self, mask_file):
        """Pickle the current tissue and hole contours to *mask_file*."""
        asset = {'holes': self.holes_tissue, 'tissue': self.contours_tissue}
        with open(mask_file, 'wb') as f:
            pickle.dump(asset, f)

    @staticmethod
    def _filter_contours(contours, hierarchy_2col, a_t=100, a_h=16, max_n_holes=8):
        """Filter contours by area and keep up to N holes per kept contour.

        `hierarchy_2col` is expected to be `hierarchy[:, 2:]` from OpenCV
        RETR_CCOMP, i.e. columns: [child, parent].

        Args:
            contours: Sequence of OpenCV contours.
            hierarchy_2col: ``(len(contours), 2)`` array as described above.
            a_t: Minimum net area (outer area minus hole areas) for a
                top-level contour to be kept.
            a_h: A hole is kept only if its area is strictly greater than this.
            max_n_holes: Only the ``max_n_holes`` largest holes are considered.

        Returns:
            ``(kept_contours, hole_groups)`` where ``hole_groups[i]`` holds the
            holes of ``kept_contours[i]``; both lists have the same length.
        """
        kept_idx = []
        hole_groups = []

        # Top-level contours are the ones without a parent.
        for pi in np.flatnonzero(hierarchy_2col[:, 1] == -1):
            hole_ids = np.flatnonzero(hierarchy_2col[:, 1] == pi)
            net_area = cv2.contourArea(contours[pi])
            if hole_ids.size:
                net_area -= np.sum([cv2.contourArea(contours[c]) for c in hole_ids])
            if net_area <= 0 or net_area < a_t:
                continue

            # Holes are recorded only for contours that are kept, so the two
            # returned lists stay index-aligned.
            kept_idx.append(pi)
            by_area = sorted((contours[c] for c in hole_ids), key=cv2.contourArea, reverse=True)
            hole_groups.append([h for h in by_area[:max_n_holes] if cv2.contourArea(h) > a_h])
        return [contours[i] for i in kept_idx], hole_groups

    @staticmethod
    def _scale_contours(contours, scale):
        """Multiply contour coordinates by ``scale = (sx, sy)``.

        Returns a list of ``(N, 1, 2)`` int32 arrays (values are truncated by
        the int32 cast); ``None`` input yields an empty list.
        """
        if contours is None:
            return []
        factor = np.array([scale[0], scale[1]])
        return [
            (c.reshape(-1, 2) * factor).astype(np.int32).reshape(-1, 1, 2)
            for c in contours
        ]

    @staticmethod
    def _whiten_border(rgb, thickness=20):
        """Paint a white band along all four image edges, in place, and return *rgb*."""
        h, w = rgb.shape[:2]
        white = (255, 255, 255)
        cv2.rectangle(rgb, (0, 0), (w, thickness), white, -1)
        cv2.rectangle(rgb, (0, h - thickness), (w, h), white, -1)
        cv2.rectangle(rgb, (0, 0), (thickness, h), white, -1)
        cv2.rectangle(rgb, (w - thickness, 0), (w, h), white, -1)
        return rgb

    @staticmethod
    def _select_ids(n, keep_ids=None, exclude_ids=None):
        """Return the sorted contour indices selected by keep/exclude lists.

        With a non-empty ``keep_ids`` the result is ``keep_ids`` minus
        ``exclude_ids``; otherwise it is ``range(n)`` minus ``exclude_ids``.
        """
        excluded = set(exclude_ids or [])
        candidates = set(keep_ids) if keep_ids else set(range(n))
        return sorted(candidates - excluded)

    def segment_regions(self,
                        seg_level=0,
                        sthresh=20,
                        sthresh_up=255,
                        mthresh=7,
                        close=0,
                        use_otsu=False,
                        filter_params=None,
                        ref_patch_size=512,
                        exclude_ids=None,
                        keep_ids=None,
                        seg_downsample=1.0):
        """Segment tissue regions to build foreground contours.

        The slide is read at ``seg_level`` (optionally shrunk by
        ``seg_downsample``), low-detail pixels and a border band are painted
        white, and the median-blurred HSV saturation channel is thresholded.
        The resulting contours are area-filtered and stored, scaled to level-0
        coordinates, in ``self.contours_tissue`` / ``self.holes_tissue``.

        Args:
            seg_level: Pyramid level to read for segmentation.
            sthresh: Saturation threshold (ignored when ``use_otsu`` is set).
            sthresh_up: Value assigned to pixels above the threshold.
            mthresh: Median-blur kernel size; even values are bumped to the
                next odd number.
            close: Side of the square kernel for morphological closing;
                ``0`` disables closing.
            use_otsu: Use Otsu's method instead of the fixed ``sthresh``.
            filter_params: Optional dict with ``a_t``, ``a_h`` and
                ``max_n_holes``; missing keys (or ``None``) fall back to
                100, 16 and 8.
            ref_patch_size: Reference patch edge (level-0 pixels) used to scale
                the area thresholds to the segmentation resolution.
            exclude_ids: Contour indices to drop after filtering.
            keep_ids: If non-empty, only these contour indices are kept.
            seg_downsample: Extra shrink factor applied to the image that was
                read; values ``<= 1`` (or ``None``) disable it.
        """
        params = {} if filter_params is None else filter_params

        raw = np.array(self.wsi.read_region((0, 0), seg_level, self.level_dim[seg_level]))
        scale_ds = float(seg_downsample) if seg_downsample and seg_downsample > 1.0 else 1.0
        if scale_ds > 1.0:
            new_w = max(1, int(raw.shape[1] / scale_ds))
            new_h = max(1, int(raw.shape[0] / scale_ds))
            img = cv2.resize(raw, (new_w, new_h), interpolation=cv2.INTER_AREA)
        else:
            img = raw

        img_rgb = self._whiten_border(cv2.cvtColor(img, cv2.COLOR_RGBA2RGB), thickness=20)

        # Pixels with a weak Laplacian response (flat, texture-less areas) are
        # painted white so they do not survive the saturation threshold.
        img_gray = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
        lap_abs = cv2.convertScaleAbs(cv2.Laplacian(img_gray, cv2.CV_64F))
        img_rgb[lap_abs <= 15] = (255, 255, 255)

        saturation = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2HSV)[:, :, 1]
        ksize = int(mthresh)
        if ksize % 2 == 0:
            ksize += 1  # medianBlur needs an odd kernel size
        s_med = cv2.medianBlur(saturation, ksize)

        if use_otsu:
            _, mask = cv2.threshold(s_med, 0, sthresh_up, cv2.THRESH_OTSU + cv2.THRESH_BINARY)
        else:
            _, mask = cv2.threshold(s_med, int(sthresh), sthresh_up, cv2.THRESH_BINARY)

        if close and close > 0:
            kernel = np.ones((int(close), int(close)), np.uint8)
            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

        # OpenCV 3.x returns (image, contours, hierarchy) while 4.x returns
        # (contours, hierarchy); the last two items are the same in both.
        contours, hierarchy = cv2.findContours(mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)[-2:]
        if not contours or hierarchy is None:
            self.contours_tissue = []
            self.holes_tissue = []
            return

        hierarchy = np.squeeze(hierarchy, axis=0)[:, 2:]

        # Factor that maps segmentation-image pixels back to level-0 pixels.
        base_scale = self.level_downsamples[seg_level]
        coord_scale = (base_scale[0] * scale_ds, base_scale[1] * scale_ds)

        # Area of one reference patch measured in segmentation-image pixels.
        scaled_area = max(1, int(ref_patch_size ** 2 / (coord_scale[0] * coord_scale[1])))
        fg_contours, hole_groups = self._filter_contours(
            contours,
            hierarchy,
            a_t=int(params.get('a_t', 100)) * scaled_area,
            a_h=int(params.get('a_h', 16)) * scaled_area,
            max_n_holes=int(params.get('max_n_holes', 8)),
        )

        tissue = self._scale_contours(fg_contours, coord_scale)
        holes = [self._scale_contours(hs, coord_scale) for hs in hole_groups]

        selected = self._select_ids(len(tissue), keep_ids, exclude_ids)
        self.contours_tissue = [tissue[i] for i in selected]
        self.holes_tissue = [holes[i] for i in selected]

    def render_segmentation(
        self,
        vis_level=0,
        color=(0, 255, 0),
        hole_color=(0, 0, 255),
        line_thickness=250,
        top_left=None,
        bot_right=None,
        view_slide_only=False,
        seg_display=True,
    ):
        """Render the slide (or a region of it) with the segmentation drawn on top.

        Args:
            vis_level: Pyramid level to render.
            color: RGB colour for tissue contours.
            hole_color: RGB colour for hole contours.
            line_thickness: Line thickness in level-0 pixels; it is scaled to
                ``vis_level`` and never drops below 1.
            top_left, bot_right: Optional level-0 corners of the region to
                render. The whole level is rendered unless both are given.
            view_slide_only: Skip all contour drawing.
            seg_display: Draw contours when true.

        Returns:
            A ``PIL.Image`` in RGB mode.
        """
        ds_x, ds_y = self.level_downsamples[vis_level]
        scale = np.array([1 / ds_x, 1 / ds_y])

        if top_left is not None and bot_right is not None:
            top_left = tuple(top_left)
            bot_right = tuple(bot_right)
            extent = (np.array(bot_right) * scale).astype(int) - (np.array(top_left) * scale).astype(int)
            region_size = (int(extent[0]), int(extent[1]))
        else:
            top_left = (0, 0)
            region_size = self.level_dim[vis_level]

        img = np.array(self.wsi.read_region(top_left, vis_level, region_size).convert("RGB"))

        if not view_slide_only and self.contours_tissue and seg_display:
            # Contours are in level-0 coordinates; after scaling to vis_level
            # they must be shifted so the region's corner lands on (0, 0).
            shift = -(np.array(top_left) * scale).astype(int)
            offset = (int(shift[0]), int(shift[1]))
            lt = max(1, int(line_thickness * math.sqrt(scale[0] * scale[1])))

            for cont in self.contours_tissue:
                contour = (cont * scale).astype(np.int32)
                cv2.drawContours(img, [contour], -1, color, lt, lineType=cv2.LINE_8, offset=offset)

            for holes in getattr(self, 'holes_tissue', None) or []:
                if not holes:
                    continue
                holes_scaled = [(h * scale).astype(np.int32) for h in holes]
                # Same offset as the tissue contours, otherwise holes are
                # misplaced whenever a sub-region is rendered.
                cv2.drawContours(img, holes_scaled, -1, hole_color, lt, lineType=cv2.LINE_8, offset=offset)
        return Image.fromarray(img)

    @staticmethod
    def _make_contour_checker(cont, ref_patch, contour_fn):
        """Build the predicate deciding whether a patch origin belongs to *cont*.

        ``'four_pt'`` accepts a patch when any of four points placed around
        the patch centre is inside or on the contour, ``'center'`` tests the
        patch centre and ``'basic'`` tests the patch origin itself.

        Raises:
            ValueError: If *contour_fn* is not one of the three names above.
        """
        half = ref_patch // 2

        def _hit(px, py):
            return cv2.pointPolygonTest(cont, (float(px), float(py)), False) >= 0

        if contour_fn == 'four_pt':
            shift = int(half * 0.5)

            def _four_pt(pt):
                cx, cy = pt[0] + half, pt[1] + half
                probes = (
                    (cx - shift, cy - shift),
                    (cx + shift, cy + shift),
                    (cx + shift, cy - shift),
                    (cx - shift, cy + shift),
                )
                return any(_hit(px, py) for px, py in probes)

            return _four_pt

        if contour_fn == 'center':
            return lambda pt: _hit(pt[0] + half, pt[1] + half)

        if contour_fn == 'basic':
            return lambda pt: _hit(pt[0], pt[1])

        raise ValueError(f"Unsupported contour_fn: {contour_fn}")

    @staticmethod
    def _center_in_holes(holes, pt, ref_patch):
        """Return True when the patch centre lies strictly inside any hole."""
        if not holes:
            return False
        center = (float(pt[0] + ref_patch / 2.0), float(pt[1] + ref_patch / 2.0))
        return any(cv2.pointPolygonTest(hole, center, False) > 0 for hole in holes)

    def write_patches(
        self,
        patch_level: int,
        patch_size: int,
        step_size: int,
        save_path: str,
        use_padding: bool = True,
        contour_fn: str = "four_pt",
        white_thresh: int = 15,
        black_thresh: int = 50,
    ) -> str:
        """Extract patch coordinates inside segmented tissue and save to HDF5.

        The output HDF5 contains:
        - coords: int32 [N, 2] (level-0 pixel coordinates)
        - contour_index: int32 [N]
        - file attrs: complete=True when finished, n_coords=N

        An existing output file already marked ``complete`` is returned as is;
        any other existing file is replaced.

        Args:
            patch_level: Pyramid level the patches will be read at.
            patch_size: Patch edge length in ``patch_level`` pixels.
            step_size: Grid stride in ``patch_level`` pixels.
            save_path: Output directory; the file is ``<slide name>.h5``.
            use_padding: When false, the grid is clipped so that every patch
                fits inside the level-0 image.
            contour_fn: ``'four_pt'``, ``'center'`` or ``'basic'``.
            white_thresh: Accepted for interface compatibility; not used here.
            black_thresh: Accepted for interface compatibility; not used here.

        Returns:
            Path of the HDF5 file.

        Raises:
            ValueError: If the level-0 step is not positive, or (while
                processing contours) if ``contour_fn`` is unsupported.
        """
        name = self.name
        downsample = self.level_downsamples[patch_level]
        os.makedirs(save_path, exist_ok=True)
        file_path = os.path.join(save_path, f"{name}.h5")

        if os.path.isfile(file_path):
            try:
                with h5py.File(file_path, 'r') as f:
                    if bool(f.attrs.get('complete', False)):
                        return file_path
            except Exception:
                # Unreadable or partial file: fall through and regenerate it.
                pass

        # Patch geometry expressed in level-0 pixels.
        scale_x, scale_y = int(downsample[0]), int(downsample[1])
        ref_patch_size = (patch_size * scale_x, patch_size * scale_y)
        step_size_x = int(step_size * scale_x)
        step_size_y = int(step_size * scale_y)
        if step_size_x <= 0 or step_size_y <= 0:
            raise ValueError(
                f"step_size must map to a positive level-0 stride, got {step_size!r}"
            )

        try:
            if os.path.isfile(file_path):
                os.remove(file_path)
        except OSError:
            # Opening with mode 'w' below truncates the file anyway.
            pass

        img_w, img_h = self.level_dim[0]
        hole_groups = getattr(self, 'holes_tissue', None) or []
        level_dim = tuple(np.array(self.level_dim[patch_level]))

        with h5py.File(file_path, 'w') as f:
            coords_d = f.create_dataset('coords', shape=(0, 2), maxshape=(None, 2),
                                        chunks=(1024, 2), dtype=np.int32)
            coords_d.attrs['patch_size'] = int(patch_size)
            coords_d.attrs['patch_level'] = int(patch_level)
            coords_d.attrs['downsample'] = downsample
            coords_d.attrs['downsampled_level_dim'] = level_dim
            coords_d.attrs['level_dim'] = level_dim
            coords_d.attrs['name'] = name

            cid_d = f.create_dataset('contour_index', shape=(0,), maxshape=(None,),
                                     chunks=(1024,), dtype=np.int32)
            cid_d.attrs['name'] = name

            total_written = 0
            for ci, cont in enumerate(self.contours_tissue):
                start_x, start_y, w, h = cv2.boundingRect(cont)
                stop_x, stop_y = start_x + w, start_y + h
                if not use_padding:
                    stop_x = min(stop_x, img_w - ref_patch_size[0] + 1)
                    stop_y = min(stop_y, img_h - ref_patch_size[1] + 1)

                # NOTE: both helpers receive the x-axis patch extent only, so
                # the same extent is applied along y as well.
                inside = self._make_contour_checker(cont, ref_patch_size[0], contour_fn)
                holes = hole_groups[ci] if ci < len(hole_groups) else []

                accepted = []
                for y in np.arange(start_y, stop_y, step=step_size_y):
                    for x in np.arange(start_x, stop_x, step=step_size_x):
                        pt = (int(x), int(y))
                        if inside(pt) and not self._center_in_holes(holes, pt, ref_patch_size[0]):
                            accepted.append(pt)
                if not accepted:
                    continue

                # Grow each dataset once per contour instead of once per row.
                end = total_written + len(accepted)
                coords_d.resize(end, axis=0)
                coords_d[total_written:end] = np.asarray(accepted, dtype=np.int32)
                cid_d.resize(end, axis=0)
                cid_d[total_written:end] = np.full(len(accepted), ci, dtype=np.int32)
                total_written = end

            f.attrs['complete'] = True
            f.attrs['n_coords'] = int(total_written)

        return file_path
|
|
# Public API: `from <module> import *` exports only the SlideImage class.
__all__ = ["SlideImage"]
|
|