| import math |
| import os |
| import pdb |
|
|
| import cv2 |
| import h5py |
| import numpy as np |
| from PIL import Image |
|
|
| from .util_classes import Mosaic_Canvas |
|
|
|
|
def isWhitePatch(patch, satThresh=5):
    """Report whether a patch is, on average, nearly unsaturated ("white").

    The patch is converted from RGB to HSV with OpenCV and the mean of the
    saturation channel (index 1) is compared against ``satThresh``.

    Args:
        patch: Image array accepted by ``cv2.cvtColor`` with
            ``cv2.COLOR_RGB2HSV``.
        satThresh: Exclusive upper bound on the mean saturation.

    Returns:
        ``True`` when the mean saturation is strictly below ``satThresh``,
        ``False`` otherwise.
    """
    saturation = cv2.cvtColor(patch, cv2.COLOR_RGB2HSV)[:, :, 1]
    if np.mean(saturation) < satThresh:
        return True
    return False
|
|
|
|
def isBlackPatch(patch, rgbThresh=40):
    """Report whether every channel of a patch is dark on average.

    Args:
        patch: Array whose first two axes are averaged away, leaving one mean
            per remaining channel.
        rgbThresh: Exclusive upper bound each channel mean must stay under.

    Returns:
        ``True`` when all channel means are strictly below ``rgbThresh``.
    """
    channel_means = np.mean(patch, axis=(0, 1))
    return bool(np.all(channel_means < rgbThresh))
|
|
|
|
def isBlackPatch_S(patch, rgbThresh=20, percentage=0.05):
    """Report whether more than a given share of pixels are dark.

    A pixel counts as dark when every value along axis 2 of the array form of
    ``patch`` is strictly below ``rgbThresh``.

    Args:
        patch: Object exposing a two-element ``size`` and convertible with
            ``np.array`` to an array with at least three axes
            (presumably a PIL image -- confirm against callers).
        rgbThresh: Exclusive upper bound for a pixel to count as dark.
        percentage: Fraction of ``size[0] * size[1]`` that the dark-pixel
            count must strictly exceed.

    Returns:
        ``True`` when the dark-pixel count exceeds the allowed fraction.
    """
    width, height = patch.size[0], patch.size[1]
    dark_mask = np.all(np.array(patch) < rgbThresh, axis=2)
    if dark_mask.sum() > width * height * percentage:
        return True
    return False
|
|
|
|
def isWhitePatch_S(patch, rgbThresh=220, percentage=0.2):
    """Report whether more than a given share of pixels are bright.

    A pixel counts as bright when every value along axis 2 of the array form
    of ``patch`` is strictly above ``rgbThresh``.

    Args:
        patch: Object exposing a two-element ``size`` and convertible with
            ``np.array`` to an array with at least three axes
            (presumably a PIL image -- confirm against callers).
        rgbThresh: Exclusive lower bound for a pixel to count as bright.
        percentage: Fraction of ``size[0] * size[1]`` that the bright-pixel
            count must strictly exceed.

    Returns:
        ``True`` when the bright-pixel count exceeds the allowed fraction.
    """
    width, height = patch.size[0], patch.size[1]
    bright_mask = np.all(np.array(patch) > rgbThresh, axis=2)
    if bright_mask.sum() > width * height * percentage:
        return True
    return False
|
|
|
|
def coord_generator(x_start, x_end, x_step, y_start, y_end, y_step, args_dict=None):
    """Yield grid points column by column, optionally wrapped in a dict.

    The x range is the outer loop and the y range the inner loop.

    Args:
        x_start, x_end, x_step: ``range`` arguments for the x axis.
        y_start, y_end, y_step: ``range`` arguments for the y axis.
        args_dict: Optional dict; when given, each yielded item is a fresh
            copy of it with the point stored under the ``"pt"`` key.

    Yields:
        ``(x, y)`` tuples, or copies of ``args_dict`` carrying ``"pt"``.
    """
    wrap_in_dict = args_dict is not None
    for col in range(x_start, x_end, x_step):
        for row in range(y_start, y_end, y_step):
            point = (col, row)
            if not wrap_in_dict:
                yield point
                continue
            # Copy per point so consumers never share one mutable dict.
            payload = args_dict.copy()
            payload.update({"pt": point})
            yield payload
|
|
|
|
def savePatchIter_bag_hdf5(patch, patient_id):
    """Append one patch, and its metadata when tracked, to an existing HDF5 bag.

    The target file is ``<save_path>/<name>.h5`` and must already contain an
    ``"imgs"`` dataset. The ``"coords"`` and ``"contour_index"`` datasets are
    extended only when they are present in the file.

    Args:
        patch: Mapping whose values, in iteration order, are ``x``, ``y``,
            ``cont_idx``, ``patch_level``, ``downsample``,
            ``downsampled_level_dim``, ``level_dim``, ``img_patch``, ``name``
            and ``save_path``. Exactly ten values are required because they
            are unpacked positionally.
        patient_id: Not used by this function; kept so existing callers keep
            working.
    """
    (
        x,
        y,
        cont_idx,
        _patch_level,
        _downsample,
        _downsampled_level_dim,
        _level_dim,
        img_patch,
        name,
        save_path,
    ) = tuple(patch.values())

    # Add a leading axis so the patch is a batch of one along the append axis.
    img_patch = np.array(img_patch)[np.newaxis, ...]
    num_new = img_patch.shape[0]

    file_path = os.path.join(save_path, name) + ".h5"
    # The context manager closes the file even when a resize or write raises.
    with h5py.File(file_path, "a") as file:
        dset = file["imgs"]
        dset.resize(len(dset) + num_new, axis=0)
        dset[-num_new:] = img_patch

        if "coords" in file:
            coord_dset = file["coords"]
            coord_dset.resize(len(coord_dset) + num_new, axis=0)
            coord_dset[-num_new:] = (x, y)

        if "contour_index" in file:
            cid_dset = file["contour_index"]
            cid_dset.resize(len(cid_dset) + num_new, axis=0)
            cid_dset[-num_new:] = cont_idx
|
|
|
|
def save_hdf5(output_path, asset_dict, attr_dict=None, mode="a"):
    """Create or extend HDF5 datasets from a dict of arrays.

    For each key in ``asset_dict``: when the dataset does not exist it is
    created as resizable along axis 0 (one row per chunk) and filled with the
    value; otherwise the value is appended along axis 0.

    Args:
        output_path: Path of the HDF5 file to open.
        asset_dict: Mapping of dataset name to an array-like exposing
            ``shape`` and ``dtype``.
        attr_dict: Optional mapping of dataset name to a dict of attributes.
            Attributes are written only when the dataset is first created.
        mode: File mode handed to ``h5py.File``.

    Returns:
        ``output_path``, unchanged.
    """
    # The context manager closes the file even when a write raises.
    with h5py.File(output_path, mode) as file:
        for key, val in asset_dict.items():
            data_shape = val.shape

            if key in file:
                dset = file[key]
                # An explicit start index (rather than a negative slice)
                # keeps a zero-row append from addressing the whole dataset.
                start = len(dset)
                dset.resize(start + data_shape[0], axis=0)
                dset[start:] = val
                continue

            dset = file.create_dataset(
                key,
                shape=data_shape,
                maxshape=(None,) + data_shape[1:],
                chunks=(1,) + data_shape[1:],
                dtype=val.dtype,
            )
            dset[:] = val
            if attr_dict is not None and key in attr_dict:
                for attr_key, attr_val in attr_dict[key].items():
                    dset.attrs[attr_key] = attr_val
    return output_path
|
|
|
|
def initialize_hdf5_bag(first_patch, save_coord=False):
    """Create the datasets of an HDF5 bag from its first patch.

    ``<save_path>/<name>.h5`` is opened in append mode (creating
    ``save_path`` if needed). An ``"imgs"`` dataset, resizable along axis 0,
    is created holding the first patch, with slide metadata stored as
    attributes. ``"coords"`` is created only when ``save_coord`` is true.

    Args:
        first_patch: Mapping whose values, in iteration order, are ``x``,
            ``y``, ``cont_idx``, ``patch_level``, ``downsample``,
            ``downsampled_level_dim``, ``level_dim``, ``img_patch``, ``name``
            and ``save_path``.
        save_coord: Whether to also create the ``"coords"`` dataset.

    Returns:
        Path of the HDF5 file.
    """
    (
        x,
        y,
        cont_idx,
        patch_level,
        downsample,
        downsampled_level_dim,
        level_dim,
        img_patch,
        name,
        save_path,
    ) = tuple(first_patch.values())

    os.makedirs(save_path, exist_ok=True)
    file_path = os.path.join(save_path, name) + ".h5"

    # Add a leading axis so the patch is a batch of one along the append axis.
    img_patch = np.array(img_patch)[np.newaxis, ...]
    img_shape = img_patch.shape

    # The context manager closes the file even when dataset creation raises.
    with h5py.File(file_path, "a") as file:
        dset = file.create_dataset(
            "imgs",
            shape=img_shape,
            maxshape=(None,) + img_shape[1:],
            chunks=img_shape,
            dtype=img_patch.dtype,
        )
        dset[:] = img_patch
        dset.attrs["patch_level"] = patch_level
        dset.attrs["wsi_name"] = name
        dset.attrs["downsample"] = downsample
        dset.attrs["level_dim"] = level_dim
        dset.attrs["downsampled_level_dim"] = downsampled_level_dim

        if save_coord:
            coord_dset = file.create_dataset(
                "coords",
                shape=(1, 2),
                maxshape=(None, 2),
                chunks=(1, 2),
                dtype=np.int32,
            )
            coord_dset[:] = (x, y)

        # NOTE(review): the contour index is created regardless of
        # ``save_coord`` -- confirm this matches the intended layout.
        cid_dset = file.create_dataset(
            "contour_index",
            shape=(1,),
            maxshape=(None,),
            chunks=(1,),
            dtype=np.int32,
        )
        cid_dset[:] = cont_idx

    return file_path
|
|
|
|
def sample_indices(
    scores, k, start=0.48, end=0.52, convert_to_percentile=False, seed=1
):
    """Randomly pick up to ``k`` indices whose score lies in a closed window.

    NumPy's global random state is reseeded with ``seed`` on every call.

    Args:
        scores: Array of scores.
        k: Maximum number of indices to draw (without replacement).
        start: Lower window bound, inclusive.
        end: Upper window bound, inclusive.
        convert_to_percentile: When true, ``start`` and ``end`` are treated
            as quantile levels and mapped through ``np.quantile(scores, ...)``
            before comparing.
        seed: Seed passed to ``np.random.seed``.

    Returns:
        Array of sampled indices, or the integer ``-1`` when no score falls
        inside the window.
    """
    np.random.seed(seed)

    upper = np.quantile(scores, end) if convert_to_percentile else end
    lower = np.quantile(scores, start) if convert_to_percentile else start

    in_window = np.logical_and(scores >= lower, scores <= upper)
    candidates = np.where(in_window)[0]
    n_candidates = len(candidates)
    if n_candidates < 1:
        return -1
    return np.random.choice(candidates, min(k, n_candidates), replace=False)
|
|
|
|
def top_k(scores, k, invert=False):
    """Return the indices of the ``k`` largest (or smallest) scores.

    Args:
        scores: Array exposing ``argsort``.
        k: Number of indices to keep.
        invert: When true, return the ``k`` smallest scores instead.

    Returns:
        Indices ordered from most to least extreme: descending score by
        default, ascending score when ``invert`` is true.
    """
    order = scores.argsort()
    if not invert:
        order = order[::-1]
    return order[:k]
|
|
|
|
def to_percentiles(scores):
    """Convert scores to percentile ranks in the range (0, 100].

    Ranks come from ``scipy.stats.rankdata`` with average tie handling, are
    divided by ``len(scores)`` and scaled by 100.

    Args:
        scores: Sequence of scores.

    Returns:
        Array of percentile ranks.
    """
    from scipy.stats import rankdata

    n_scores = len(scores)
    ranks = rankdata(scores, method="average")
    return ranks / n_scores * 100
|
|
|
|
def screen_coords(scores, coords, top_left, bot_right):
    """Keep only the scores and coordinates inside a bounding box.

    Both corners are inclusive; a row of ``coords`` is kept when every
    component lies between the matching components of the two corners.

    Args:
        scores: Array aligned with the rows of ``coords``.
        coords: 2-D array of coordinates, one point per row.
        top_left: Lower bound for each coordinate component.
        bot_right: Upper bound for each coordinate component.

    Returns:
        Tuple ``(scores, coords)`` restricted to the points inside the box.
    """
    upper = np.array(bot_right)
    lower = np.array(top_left)
    inside = np.all(coords >= lower, axis=1) & np.all(coords <= upper, axis=1)
    return scores[inside], coords[inside]
|
|
|
|
def sample_rois(
    scores,
    coords,
    k=5,
    mode="range_sample",
    seed=1,
    score_start=0.45,
    score_end=0.55,
    top_left=None,
    bot_right=None,
):
    """Select up to ``k`` regions of interest by score.

    Scores are first converted to percentile ranks with ``to_percentiles``
    and, when both corners are given, restricted to a bounding box with
    ``screen_coords``.

    Args:
        scores: Array of scores; a 2-D array is flattened first.
        coords: Array of coordinates aligned with ``scores``.
        k: Maximum number of regions to return.
        mode: ``"range_sample"`` (random draw inside a score window),
            ``"topk"`` (highest scores) or ``"reverse_topk"`` (lowest scores).
        seed: Random seed used by ``"range_sample"``.
        score_start: Lower window bound for ``"range_sample"``. It is compared
            directly against the percentile ranks, which are on a 0-100
            scale.
        score_end: Upper window bound for ``"range_sample"``, same scale.
        top_left: Optional inclusive lower corner of the bounding box.
        bot_right: Optional inclusive upper corner of the bounding box.

    Returns:
        Dict with ``"sampled_coords"`` and ``"sampled_scores"``. Both are
        empty arrays when ``"range_sample"`` finds no score in the window.

    Raises:
        NotImplementedError: If ``mode`` is not one of the supported values.
    """
    if len(scores.shape) == 2:
        scores = scores.flatten()

    scores = to_percentiles(scores)
    if top_left is not None and bot_right is not None:
        scores, coords = screen_coords(scores, coords, top_left, bot_right)

    if mode == "range_sample":
        sampled_ids = sample_indices(
            scores,
            start=score_start,
            end=score_end,
            k=k,
            convert_to_percentile=False,
            seed=seed,
        )
        # ``sample_indices`` signals "nothing in range" with the scalar -1;
        # used as an index that would silently select the last element.
        if np.ndim(sampled_ids) == 0:
            sampled_ids = np.array([], dtype=np.intp)
    elif mode == "topk":
        sampled_ids = top_k(scores, k, invert=False)
    elif mode == "reverse_topk":
        sampled_ids = top_k(scores, k, invert=True)
    else:
        raise NotImplementedError("unsupported sampling mode: {}".format(mode))

    return {
        "sampled_coords": coords[sampled_ids],
        "sampled_scores": scores[sampled_ids],
    }
|
|
|
|
def DrawGrid(img, coord, shape, thickness=2, color=(0, 0, 0, 255)):
    """Draw a rectangular outline for one patch onto ``img`` in place.

    The rectangle is shifted up and left by half the line thickness, and its
    top-left corner is clamped at the origin.

    Args:
        img: Image array drawn on by ``cv2.rectangle``.
        coord: ``(x, y)`` of the patch's top-left corner.
        shape: ``(width, height)`` of the patch.
        thickness: Line thickness in pixels.
        color: Outline color. Previously accepted but ignored in favour of a
            hard-coded value equal to the default.

    Returns:
        The same ``img`` object.
    """
    coord = np.asarray(coord)
    offset = thickness // 2
    top_left = np.maximum([0, 0], coord - offset)
    bottom_right = coord - offset + np.array(shape)
    cv2.rectangle(
        img,
        # Plain Python ints avoid depending on how OpenCV parses NumPy scalars.
        tuple(int(v) for v in top_left),
        tuple(int(v) for v in bottom_right),
        color,
        thickness=thickness,
    )
    return img
|
|
|
|
def DrawMap(
    canvas, patch_dset, coords, patch_size, indices=None, verbose=1, draw_grid=True
):
    """Paste resized patches from a dataset onto a canvas.

    Each selected patch is resized to ``patch_size`` and written into the
    first three channels of ``canvas`` at its coordinate; patches that
    overhang the canvas edge are cropped to fit.

    Args:
        canvas: Image array that is modified in place.
        patch_dset: Indexable collection of patch images; its
            ``attrs["wsi_name"]`` is printed when ``verbose`` is positive.
        coords: Per-patch ``(x, y)`` canvas positions, indexed like
            ``patch_dset``.
        patch_size: ``(width, height)`` each patch is resized to.
        indices: Patch ids to draw; defaults to every entry of ``coords``.
        verbose: Progress is printed when greater than zero.
        draw_grid: Whether to outline each patch with ``DrawGrid``.

    Returns:
        PIL image built from the filled canvas.
    """
    if indices is None:
        indices = np.arange(len(coords))
    total = len(indices)

    show_progress = verbose > 0
    if show_progress:
        ten_percent_chunk = math.ceil(total * 0.1)
        print("start stitching {}".format(patch_dset.attrs["wsi_name"]))

    tile_w, tile_h = patch_size[0], patch_size[1]
    for idx, patch_id in enumerate(indices):
        if show_progress and idx % ten_percent_chunk == 0:
            print("progress: {}/{} stitched".format(idx, total))

        tile = cv2.resize(patch_dset[patch_id], patch_size)
        coord = coords[patch_id]
        rows = slice(coord[1], coord[1] + tile_h)
        cols = slice(coord[0], coord[0] + tile_w)

        # The canvas slice may be smaller than the tile at the border, so
        # crop the tile to whatever actually fits.
        fit_h, fit_w = canvas[rows, cols, :3].shape[:2]
        canvas[rows, cols, :3] = tile[:fit_h, :fit_w, :]

        if draw_grid:
            DrawGrid(canvas, coord, patch_size)

    return Image.fromarray(canvas)
|
|
|
|
def DrawMapFromCoords(
    canvas,
    wsi_object,
    coords,
    patch_size,
    vis_level,
    indices=None,
    verbose=1,
    draw_grid=True,
):
    """Read patches from a slide at ``vis_level`` and paste them on a canvas.

    ``patch_size`` and every coordinate are divided by the downsample factor
    of ``vis_level`` (rounded up) to get canvas-space sizes and positions.
    Patches that overhang the canvas edge are cropped to fit.

    Args:
        canvas: Image array that is modified in place.
        wsi_object: Object whose ``wsi`` attribute provides
            ``level_downsamples`` and ``read_region``.
        coords: Per-patch ``(x, y)`` positions passed to ``read_region``.
        patch_size: ``(width, height)`` before downscaling.
        vis_level: Slide level to read from.
        indices: Patch ids to draw; defaults to every entry of ``coords``.
        verbose: Progress is printed when greater than zero.
        draw_grid: Whether to outline each patch with ``DrawGrid``.

    Returns:
        PIL image built from the filled canvas.
    """
    downsamples = wsi_object.wsi.level_downsamples[vis_level]
    if indices is None:
        indices = np.arange(len(coords))
    total = len(indices)

    show_progress = verbose > 0
    if show_progress:
        ten_percent_chunk = math.ceil(total * 0.1)

    scaled_size = np.ceil(np.array(patch_size) / np.array(downsamples))
    patch_size = tuple(scaled_size.astype(np.int32))
    print("downscaled patch size: {}x{}".format(patch_size[0], patch_size[1]))

    tile_w, tile_h = patch_size[0], patch_size[1]
    for idx, patch_id in enumerate(indices):
        if show_progress and idx % ten_percent_chunk == 0:
            print("progress: {}/{} stitched".format(idx, total))

        coord = coords[patch_id]
        region = wsi_object.wsi.read_region(tuple(coord), vis_level, patch_size)
        tile = np.array(region.convert("RGB"))

        # Map the read position into canvas space.
        coord = np.ceil(coord / downsamples).astype(np.int32)
        rows = slice(coord[1], coord[1] + tile_h)
        cols = slice(coord[0], coord[0] + tile_w)

        # Crop the tile to whatever part of the canvas it actually covers.
        fit_h, fit_w = canvas[rows, cols, :3].shape[:2]
        canvas[rows, cols, :3] = tile[:fit_h, :fit_w, :]

        if draw_grid:
            DrawGrid(canvas, coord, patch_size)

    return Image.fromarray(canvas)
|
|
|
|
def StitchPatches(
    hdf5_file_path, downscale=16, draw_grid=False, bg_color=(0, 0, 0), alpha=-1
):
    """Stitch the patches stored in an HDF5 bag into one downscaled image.

    The canvas size comes from the ``"downsampled_level_dim"`` attribute of
    the ``"imgs"`` dataset when present, otherwise from ``"level_dim"``, and
    is integer-divided by ``downscale``, as are patch coordinates and size.

    Args:
        hdf5_file_path: Path of a file holding ``"imgs"`` and ``"coords"``.
        downscale: Integer factor applied to the canvas, coordinates and
            patch size.
        draw_grid: Whether to outline each patch.
        bg_color: RGB background color of the canvas.
        alpha: Negative for an RGB canvas; otherwise the opacity used for an
            RGBA canvas.

    Returns:
        PIL image of the stitched patches.

    Raises:
        Image.DecompressionBombError: If the canvas would exceed
            ``Image.MAX_IMAGE_PIXELS`` (skipped when that limit is ``None``).
    """
    # The context manager closes the file on every exit path, including the
    # size check below.
    with h5py.File(hdf5_file_path, "r") as file:
        dset = file["imgs"]
        coords = file["coords"][:]
        if "downsampled_level_dim" in dset.attrs:
            w, h = dset.attrs["downsampled_level_dim"]
        else:
            w, h = dset.attrs["level_dim"]
        print("original size: {} x {}".format(w, h))

        w = w // downscale
        h = h // downscale
        coords = (coords / downscale).astype(np.int32)
        print("downscaled size for stiching: {} x {}".format(w, h))
        print("number of patches: {}".format(len(dset)))

        img_shape = dset[0].shape
        print("patch shape: {}".format(img_shape))
        # (width, height) order, as expected by the resize in DrawMap.
        downscaled_shape = (img_shape[1] // downscale, img_shape[0] // downscale)

        # The limit is None when Pillow's size check has been disabled.
        if Image.MAX_IMAGE_PIXELS is not None and w * h > Image.MAX_IMAGE_PIXELS:
            raise Image.DecompressionBombError(
                "Visualization Downscale %d is too large" % downscale
            )

        if alpha < 0:
            heatmap = Image.new(size=(w, h), mode="RGB", color=bg_color)
        else:
            heatmap = Image.new(
                size=(w, h),
                mode="RGBA",
                color=tuple(bg_color) + (int(255 * alpha),),
            )

        heatmap = DrawMap(
            np.array(heatmap),
            dset,
            coords,
            downscaled_shape,
            indices=None,
            draw_grid=draw_grid,
        )

    return heatmap
|
|
|
|
def StitchCoords(
    hdf5_file_path,
    wsi_object,
    downscale=16,
    draw_grid=False,
    bg_color=(0, 0, 0),
    alpha=-1,
):
    """Stitch a slide preview from patch coordinates stored in an HDF5 file.

    Patch size and level are read from the attributes of the ``"coords"``
    dataset. Patches are re-read from the slide at the level that best
    matches ``downscale``.

    Args:
        hdf5_file_path: Path of a file holding a ``"coords"`` dataset with
            ``"name"``, ``"patch_size"`` and ``"patch_level"`` attributes.
        wsi_object: Object providing ``getOpenSlide()``; also passed on to
            ``DrawMapFromCoords``.
        downscale: Target downsample factor used to pick the slide level.
        draw_grid: Whether to outline each patch.
        bg_color: RGB background color of the canvas.
        alpha: Negative for an RGB canvas; otherwise the opacity used for an
            RGBA canvas.

    Returns:
        PIL image of the stitched patches.

    Raises:
        Image.DecompressionBombError: If the canvas would exceed
            ``Image.MAX_IMAGE_PIXELS``.
    """
    wsi = wsi_object.getOpenSlide()
    vis_level = wsi.get_best_level_for_downsample(downscale)

    # Read everything needed up front so the file is closed on every exit
    # path, including the size check below.
    with h5py.File(hdf5_file_path, "r") as file:
        dset = file["coords"]
        coords = dset[:]
        slide_name = dset.attrs["name"]
        patch_size = dset.attrs["patch_size"]
        patch_level = dset.attrs["patch_level"]

    w, h = wsi.level_dimensions[0]
    print("start stitching {}".format(slide_name))
    print("original size: {} x {}".format(w, h))

    w, h = wsi.level_dimensions[vis_level]
    print("downscaled size for stiching: {} x {}".format(w, h))
    print("number of patches: {}".format(len(coords)))
    print(
        "patch size: {}x{} patch level: {}".format(patch_size, patch_size, patch_level)
    )

    # Express the patch size in level-0 pixels.
    patch_size = tuple(
        (
            np.array((patch_size, patch_size)) * wsi.level_downsamples[patch_level]
        ).astype(np.int32)
    )
    print("ref patch size: {}x{}".format(patch_size[0], patch_size[1]))

    if Image.MAX_IMAGE_PIXELS is not None and w * h > Image.MAX_IMAGE_PIXELS:
        raise Image.DecompressionBombError(
            "Visualization Downscale %d is too large" % downscale
        )

    if alpha < 0:
        heatmap = Image.new(size=(w, h), mode="RGB", color=bg_color)
    else:
        heatmap = Image.new(
            size=(w, h), mode="RGBA", color=tuple(bg_color) + (int(255 * alpha),)
        )

    return DrawMapFromCoords(
        np.array(heatmap),
        wsi_object,
        coords,
        patch_size,
        vis_level,
        indices=None,
        draw_grid=draw_grid,
    )
|
|
|
|
def StitchCoords2(
    hdf5_file_path,
    wsi_object,
    patch_size,
    downscale=16,
    draw_grid=False,
    bg_color=(0, 0, 0),
    alpha=-1,
):
    """Stitch a slide preview from an HDF5 bag, with a caller-given patch size.

    Coordinates come from the ``"coords"`` dataset; the slide name and patch
    level come from the attributes of the ``"imgs"`` dataset. Patches are
    re-read from the slide at the level that best matches ``downscale``.

    Args:
        hdf5_file_path: Path of a file holding ``"imgs"`` (with ``"wsi_name"``
            and ``"patch_level"`` attributes) and ``"coords"``.
        wsi_object: Object providing ``getOpenSlide()``; also passed on to
            ``DrawMapFromCoords``.
        patch_size: Scalar edge length of a patch at its patch level.
        downscale: Target downsample factor used to pick the slide level.
        draw_grid: Whether to outline each patch.
        bg_color: RGB background color of the canvas.
        alpha: Negative for an RGB canvas; otherwise the opacity used for an
            RGBA canvas.

    Returns:
        PIL image of the stitched patches.

    Raises:
        Image.DecompressionBombError: If the canvas would exceed
            ``Image.MAX_IMAGE_PIXELS`` (skipped when that limit is ``None``).
    """
    wsi = wsi_object.getOpenSlide()
    vis_level = wsi.get_best_level_for_downsample(downscale)

    # Read everything needed up front so the file is closed on every exit
    # path, including the size check below.
    with h5py.File(hdf5_file_path, "r") as file:
        dset = file["imgs"]
        coords = file["coords"][:]
        slide_name = dset.attrs["wsi_name"]
        patch_level = dset.attrs["patch_level"]

    w, h = wsi.level_dimensions[0]
    print("start stitching {}".format(slide_name))
    print("original size: {} x {}".format(w, h))

    w, h = wsi.level_dimensions[vis_level]
    print("downscaled size for stiching: {} x {}".format(w, h))
    print("number of patches: {}".format(len(coords)))
    print(
        "patch size: {}x{} patch level: {}".format(patch_size, patch_size, patch_level)
    )

    # Express the patch size in level-0 pixels.
    patch_size = tuple(
        (
            np.array((patch_size, patch_size)) * wsi.level_downsamples[patch_level]
        ).astype(np.int32)
    )
    print("ref patch size: {}x{}".format(patch_size[0], patch_size[1]))

    # The limit is None when Pillow's size check has been disabled.
    if Image.MAX_IMAGE_PIXELS is not None and w * h > Image.MAX_IMAGE_PIXELS:
        raise Image.DecompressionBombError(
            "Visualization Downscale %d is too large" % downscale
        )

    if alpha < 0:
        heatmap = Image.new(size=(w, h), mode="RGB", color=bg_color)
    else:
        heatmap = Image.new(
            size=(w, h), mode="RGBA", color=tuple(bg_color) + (int(255 * alpha),)
        )

    return DrawMapFromCoords(
        np.array(heatmap),
        wsi_object,
        coords,
        patch_size,
        vis_level,
        indices=None,
        draw_grid=draw_grid,
    )
|
|
|
|
def SamplePatches(
    coords_file_path,
    save_file_path,
    wsi_object,
    patch_level=0,
    custom_downsample=1,
    patch_size=256,
    sample_num=100,
    seed=1,
    stitch=True,
    verbose=1,
    mode="w",
):
    """Randomly sample patches from a slide and save them to an HDF5 file.

    Up to ``sample_num`` coordinates are drawn without replacement from the
    ``"coords"`` dataset of ``coords_file_path``. Each patch is read from the
    slide, optionally shrunk by ``custom_downsample``, optionally pasted on a
    mosaic canvas, and appended to ``save_file_path`` via ``save_hdf5``.

    Args:
        coords_file_path: HDF5 file with a ``"coords"`` dataset carrying
            ``"patch_size"`` and ``"patch_level"`` attributes.
        save_file_path: Output HDF5 file receiving ``"imgs"`` and ``"coords"``.
        wsi_object: Object whose ``wsi`` attribute provides ``read_region``.
        patch_level: Slide level to read from; a negative value means "use
            the level stored in the coords file".
        custom_downsample: When greater than 1, patches are resized to
            ``patch_size / custom_downsample``.
        patch_size: Edge length read from the slide; a negative value means
            "use the size stored in the coords file".
        sample_num: Maximum number of patches to sample.
        seed: Seed passed to ``np.random.seed`` (global NumPy random state).
        stitch: Whether to build a ``Mosaic_Canvas`` preview.
        verbose: Summary lines are printed when greater than zero.
        mode: File mode for the first write to ``save_file_path``; later
            writes always append.

    Returns:
        Tuple ``(canvas, total_coords, num_sampled)``; ``canvas`` is ``None``
        when ``stitch`` is false.
    """
    # Load what is needed and close the coords file straight away; it was
    # previously left open.
    with h5py.File(coords_file_path, "r") as file:
        dset = file["coords"]
        coords = dset[:]
        h5_patch_size = dset.attrs["patch_size"]
        h5_patch_level = dset.attrs["patch_level"]

    if verbose > 0:
        print("in .h5 file: total number of patches: {}".format(len(coords)))
        print(
            "in .h5 file: patch size: {}x{} patch level: {}".format(
                h5_patch_size, h5_patch_size, h5_patch_level
            )
        )

    if patch_level < 0:
        patch_level = h5_patch_level

    if patch_size < 0:
        patch_size = h5_patch_size

    np.random.seed(seed)
    indices = np.random.choice(
        np.arange(len(coords)), min(len(coords), sample_num), replace=False
    )

    target_patch_size = np.array([patch_size, patch_size])
    if custom_downsample > 1:
        target_patch_size = (
            np.array([patch_size, patch_size]) / custom_downsample
        ).astype(np.int32)

    if stitch:
        canvas = Mosaic_Canvas(
            patch_size=target_patch_size[0],
            n=sample_num,
            downscale=4,
            n_per_row=10,
            bg_color=(0, 0, 0),
            alpha=-1,
        )
    else:
        canvas = None

    for idx in indices:
        coord = coords[idx]
        patch = wsi_object.wsi.read_region(
            coord, patch_level, (patch_size, patch_size)
        ).convert("RGB")
        if custom_downsample > 1:
            patch = patch.resize(tuple(target_patch_size))

        if stitch:
            canvas.paste_patch(patch)

        # Give both arrays a leading batch axis so the saved "coords" dataset
        # has one (x, y) row per patch, aligned with "imgs", instead of a
        # flattened run of values.
        asset_dict = {
            "imgs": np.array(patch)[np.newaxis, ...],
            "coords": coord[np.newaxis, ...],
        }
        save_hdf5(save_file_path, asset_dict, mode=mode)
        # Only the first write may truncate; later patches must append.
        mode = "a"

    return canvas, len(coords), len(indices)
|
|