Spaces:
Sleeping
Sleeping
File size: 5,828 Bytes
7435e6e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
from PIL import Image
from typing import Tuple
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
import torch
from torchvision.transforms import functional as tvf
from pathlib import Path
def sliced_mean(x, slice_size):
    """Mean over every ``slice_size`` x ``slice_size`` window of ``x``.

    Uses a separable cumulative-sum (integral image) pass per axis, so the
    cost is O(H*W) regardless of the window size.

    Parameters
    ----------
    x : np.ndarray
        2-D input array of shape (H, W).
    slice_size : int
        Side length of the square averaging window.

    Returns
    -------
    np.ndarray
        Array of shape (H - slice_size + 1, W - slice_size + 1) holding the
        mean of each fully-contained window ("valid" positions only).
    """
    w = slice_size
    # Vertical pass: prepend a zero row to the column-wise cumsum so that
    # the difference of two rows w apart is the sum over a w-tall strip.
    col_cumsum = np.insert(np.cumsum(x, axis=0), 0, 0, axis=0)
    strip_means = (col_cumsum[w:] - col_cumsum[:-w]) / w
    # Horizontal pass: same trick along the columns of the strip means.
    row_cumsum = np.insert(np.cumsum(strip_means, axis=1), 0, 0, axis=1)
    return (row_cumsum[:, w:] - row_cumsum[:, :-w]) / w
def sliced_var(x, slice_size):
    """Variance over every ``slice_size`` x ``slice_size`` window of ``x``.

    Computed as E[x^2] - E[x]^2 in float64 to limit cancellation error.
    The windowed mean is inlined here (cumsum/integral-image trick) so the
    function is self-contained.

    Parameters
    ----------
    x : np.ndarray
        2-D input array of shape (H, W); cast to float64 internally.
    slice_size : int
        Side length of the square window.

    Returns
    -------
    np.ndarray
        Array of shape (H - slice_size + 1, W - slice_size + 1) with the
        per-window variance.
    """
    def _window_mean(a):
        # Separable sliding mean: prepend a zero row/column to each cumsum
        # so a difference of entries slice_size apart is a strip sum.
        cs0 = np.insert(np.cumsum(a, axis=0), 0, 0, axis=0)
        strips = (cs0[slice_size:] - cs0[:-slice_size]) / slice_size
        cs1 = np.insert(np.cumsum(strips, axis=1), 0, 0, axis=1)
        return (cs1[:, slice_size:] - cs1[:, :-slice_size]) / slice_size

    data = x.astype('float64')
    return _window_mean(data * data) - _window_mean(data) ** 2
def calculate_local_variance(img, var_window):
    """Return a local-variance map with the same shape as the input image.

    The valid-window variance map produced by ``sliced_var`` is smaller than
    ``img`` by ``var_window - 1`` along each axis; it is zero-padded back to
    the original shape here.

    Parameters
    ----------
    img : np.ndarray
        2-D input image. ``var_window`` must not exceed either dimension,
        otherwise the padded result cannot match ``img.shape``.
    var_window : int
        Side length of the square variance window (>= 1).

    Returns
    -------
    np.ndarray
        Variance map of shape ``img.shape`` (border entries are zero).
    """
    var = sliced_var(img, var_window)
    # Split the (var_window - 1) pixels of lost border between the two
    # sides. The max(..., 0) guards var_window == 1, where
    # var_window // 2 - 1 would be -1 and np.pad would raise on a
    # negative pad width.
    left_pad = max(var_window // 2 - 1, 0)
    right_pad = var_window - 1 - left_pad
    var_padded = np.pad(
        var,
        pad_width=(
            (left_pad, right_pad),
            (left_pad, right_pad)
        ))
    return var_padded
def get_crop_batch(img: np.ndarray, mask: np.ndarray, crop_size=96, crop_scales=np.geomspace(0.5, 2, 7), samples_per_scale=32, use_variance_threshold=False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Generate a batch of cropped images from an input image and corresponding mask, at various scales and rotations.

    For each scale, crop centers are drawn uniformly (without replacement)
    from the pixels whose distance to the mask boundary is large enough for
    the whole scaled crop to fit inside the mask. Each crop is rotated by a
    random angle, trimmed back to the scaled size, and resized to
    ``crop_size`` x ``crop_size``.

    NOTE(review): ``crop_scales=np.geomspace(...)`` is a mutable default
    argument evaluated once at import time; it is only read here, so this is
    safe as long as callers do not mutate the returned default.

    Parameters
    ----------
    img : np.ndarray
        The input image (2-D) from which crops are generated.
    mask : np.ndarray
        The binary mask indicating the region of interest in the image.
        Must have the same shape as ``img``.
    crop_size : int, optional
        The size of the square crop.
    crop_scales : np.ndarray, optional
        An array of scale factors to apply to the crop size.
    samples_per_scale : int, optional
        Maximum number of samples to generate per scale factor (fewer if not
        enough valid crop centers exist).
    use_variance_threshold : bool, optional
        Flag to use variance thresholding for selecting crop locations:
        only centers whose local variance is at least the median variance
        inside the mask interior are kept.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        Crops of shape (N, crop_size, crop_size), their rotation angles in
        degrees (float tensor, range [-90, 90)), and their scale factors
        (float tensor). If no valid crop center exists at any scale, three
        empty *lists* are returned instead — callers must handle that case.
    """
    # Pad by the extra margin a rotation needs at the largest scale:
    # half-diagonal minus half-side of the biggest crop, 0.5*size*(sqrt(2)-1).
    pad_size = int(np.ceil(0.5*crop_size*max(crop_scales)*(np.sqrt(2)-1)))
    img_padded = np.pad(img, pad_size)
    mask_padded = np.pad(mask, pad_size)
    # Euclidean distance from each pixel to the nearest zero-valued mask
    # pixel; thresholding it keeps whole crops inside the mask.
    distance_map_padded = ndimage.distance_transform_edt(mask_padded)
    # TODO: adjust scales and samples_per_scale
    if use_variance_threshold:
        variance_window = min(crop_size//2, min(img.shape))
        variance_map_padded = np.pad(calculate_local_variance(img, variance_window), pad_size)
        # Median variance over the mask interior only (pixels closer than
        # half a variance window to the boundary are excluded).
        variance_median = np.ma.median(np.ma.masked_where(distance_map_padded<0.5*variance_window, variance_map_padded))
        variance_mask = variance_map_padded >= variance_median
    else:
        # No thresholding: every in-mask pixel is a candidate center.
        variance_mask = np.ones_like(mask_padded)
    # initialize output accumulators
    crops_granum = []
    angles_granum = []
    scales_granum = []
    # loop over scales
    for scale in crop_scales:
        half_crop_size_scaled = int(np.floor(scale*0.5*crop_size)) # half of crop size after scaling
        crop_pad = int(np.ceil((np.sqrt(2) - 1)*half_crop_size_scaled)) # pad added in order to allow rotation
        half_crop_size_external = half_crop_size_scaled + crop_pad # size of "external crop" which will be rotated
        # Candidate centers: pass the variance filter AND are at least one
        # full (scaled) crop side away from the mask boundary.
        possible_indices = np.stack(np.where(variance_mask & (distance_map_padded >= 2*half_crop_size_scaled)), axis=1)
        if len(possible_indices) == 0:
            continue
        chosen_indices = np.random.choice(np.arange(len(possible_indices)), min(len(possible_indices), samples_per_scale), replace=False)
        # Cut the enlarged "external" crops (extra crop_pad border on each side).
        crops = [
            img_padded[y-half_crop_size_external:y+half_crop_size_external, x-half_crop_size_external:x+half_crop_size_external] for y, x in possible_indices[chosen_indices]
        ]
        # rotate by a uniform random angle in [-90, 90), then trim the
        # rotation border back off so only valid pixels remain
        rotation_angles = np.random.rand(len(crops))*180 - 90
        crops = [
            ndimage.rotate(crop, angle, reshape=False)[crop_pad:-crop_pad,crop_pad:-crop_pad] for crop, angle in zip(crops, rotation_angles)
        ]
        # add to output
        crops_granum.append(tvf.resize(torch.tensor(np.array(crops)), (crop_size,crop_size),antialias=True)) # resize crops to crop_size
        angles_granum.extend(rotation_angles.tolist())
        scales_granum.extend([scale]*len(crops))
    if len(angles_granum) == 0:
        # NOTE(review): returns plain lists here, not empty tensors — the
        # annotated return type is not honored in this branch.
        return [], [], []
    crops_granum = torch.concat(crops_granum)
    angles_granum = torch.tensor(angles_granum, dtype=torch.float)
    scales_granum = torch.tensor(scales_granum, dtype=torch.float)
    return crops_granum, angles_granum, scales_granum
def get_crop_batch_from_path(img_path, mask_path=None, use_variance_threshold=False):
    """
    Load an image and its mask from disk and generate a batch of crops.

    Parameters
    ----------
    img_path : str
        Path to the input image.
    mask_path : str, optional
        Path to the binary mask stored as a .npy array. When None, the
        image path with its extension replaced by '.npy' is used.
    use_variance_threshold : bool, optional
        Forwarded to ``get_crop_batch``.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        Crops, rotation angles and scale factors produced by
        ``get_crop_batch`` for the loaded image/mask pair.
    """
    resolved_mask_path = str(Path(img_path).with_suffix('.npy')) if mask_path is None else mask_path
    roi_mask = np.load(resolved_mask_path)
    # Keep only the first channel of the loaded image.
    grayscale = np.array(Image.open(img_path))[:, :, 0]
    return get_crop_batch(grayscale, roi_mask, use_variance_threshold=use_variance_threshold)
|