File size: 5,828 Bytes
7435e6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from PIL import Image
from typing import Tuple
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
import torch
from torchvision.transforms import functional as tvf

from pathlib import Path

def sliced_mean(x, slice_size):
    """Mean over every slice_size x slice_size window of ``x`` (valid mode).

    Uses zero-prefixed cumulative sums along each axis so each window mean
    costs O(1) instead of O(slice_size**2).  For an (H, W) input the output
    has shape (H - slice_size + 1, W - slice_size + 1).
    """
    n = slice_size
    # Running column sums, prefixed with a zero row so that
    # col_sums[i + n] - col_sums[i] is the sum of rows i .. i+n-1.
    col_sums = np.cumsum(x, axis=0)
    zero_row = np.zeros((1, col_sums.shape[1]), dtype=col_sums.dtype)
    col_sums = np.concatenate((zero_row, col_sums), axis=0)
    row_means = (col_sums[n:] - col_sums[:-n]) / n
    # Same trick along the horizontal axis, applied to the row means.
    row_cums = np.cumsum(row_means, axis=1)
    zero_col = np.zeros((row_cums.shape[0], 1), dtype=row_cums.dtype)
    row_cums = np.concatenate((zero_col, row_cums), axis=1)
    return (row_cums[:, n:] - row_cums[:, :-n]) / n

def sliced_var(x, slice_size):
    """Variance over every slice_size x slice_size window, as E[x^2] - E[x]^2.

    The input is promoted to float64 first to limit cancellation error in
    the difference of the two means.
    """
    x64 = np.asarray(x).astype('float64')
    mean_of_squares = sliced_mean(x64 * x64, slice_size)
    square_of_mean = sliced_mean(x64, slice_size) ** 2
    return mean_of_squares - square_of_mean

def calculate_local_variance(img, var_window):
    """Return a local-variance map zero-padded back to the input image's size.

    ``sliced_var`` shrinks each axis by ``var_window - 1``; the zero padding
    restores the original shape.  NOTE(review): the split ``var_window//2 - 1``
    vs the remainder is slightly left-light for odd windows — presumably
    intentional, kept as-is.
    """
    var = sliced_var(img, var_window)

    before = var_window // 2 - 1
    after = (var_window - 1) - before
    return np.pad(var, ((before, after), (before, after)))

def get_crop_batch(img: np.ndarray, mask: np.ndarray, crop_size=96, crop_scales=np.geomspace(0.5, 2, 7), samples_per_scale=32, use_variance_threshold=False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Generate a batch of cropped images from an input image and corresponding mask, at various scales and rotations.

    Parameters
    ----------
    img : np.ndarray
        The input image (2D, single channel) from which crops are generated.
    mask : np.ndarray
        The binary mask indicating the region of interest in the image.
    crop_size : int, optional
        The size of the square crop.
    crop_scales : np.ndarray, optional
        An array of scale factors to apply to the crop size.  The default
        array is created once at import time but is never mutated here, so
        sharing it across calls is safe.
    samples_per_scale : int, optional
        Number of samples to generate per scale factor.
    use_variance_threshold : bool, optional
        Flag to use variance thresholding for selecting crop locations.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        A tuple containing the tensor of crops (N, crop_size, crop_size),
        their rotation angles in degrees, and scale factors.  When no valid
        crop location exists, empty tensors (N == 0) are returned.
    """

    # Pad so even the largest rotated crop stays inside the array bounds.
    pad_size = int(np.ceil(0.5*crop_size*max(crop_scales)*(np.sqrt(2)-1)))
    img_padded = np.pad(img, pad_size)
    mask_padded = np.pad(mask, pad_size)

    # Distance from each masked pixel to the mask boundary; used to keep
    # crop centres far enough inside the region of interest.
    distance_map_padded = ndimage.distance_transform_edt(mask_padded)
    # TODO: adjust scales and samples_per_scale

    if use_variance_threshold:
        variance_window = min(crop_size//2, min(img.shape))
        variance_map_padded = np.pad(calculate_local_variance(img, variance_window), pad_size)
        # Median computed only over pixels deep enough inside the mask.
        variance_median = np.ma.median(np.ma.masked_where(distance_map_padded<0.5*variance_window, variance_map_padded))
        variance_mask = variance_map_padded >= variance_median
    else:
        # Explicit bool dtype so "&" below is a logical AND regardless of
        # the caller's mask dtype.
        variance_mask = np.ones_like(mask_padded, dtype=bool)

    # initialize output
    crops_granum = []
    angles_granum = []
    scales_granum = []
    # loop over scales
    for scale in crop_scales:
        half_crop_size_scaled = int(np.floor(scale*0.5*crop_size)) # half of crop size after scaling
        if half_crop_size_scaled == 0:
            # Scale too small to yield a non-empty crop at this crop_size.
            continue
        crop_pad = int(np.ceil((np.sqrt(2) - 1)*half_crop_size_scaled)) # pad added in order to allow rotation
        half_crop_size_external = half_crop_size_scaled + crop_pad # size of "external crop" which will be rotated

        possible_indices = np.stack(np.where(variance_mask & (distance_map_padded >= 2*half_crop_size_scaled)), axis=1)
        if len(possible_indices) == 0:
            continue
        chosen_indices = np.random.choice(np.arange(len(possible_indices)), min(len(possible_indices), samples_per_scale), replace=False)

        crops = [
            img_padded[y-half_crop_size_external:y+half_crop_size_external, x-half_crop_size_external:x+half_crop_size_external] for y, x in possible_indices[chosen_indices]
        ]

        # Rotate by a random angle in [-90, 90), then trim the rotation pad.
        # BUGFIX: the original slice [crop_pad:-crop_pad] produced an empty
        # array whenever crop_pad == 0; slice with explicit end indices.
        rotation_angles = np.random.rand(len(crops))*180 - 90
        crops = [
            ndimage.rotate(crop, angle, reshape=False)[crop_pad:crop.shape[0]-crop_pad, crop_pad:crop.shape[1]-crop_pad]
            for crop, angle in zip(crops, rotation_angles)
        ]
        # add to output, resized back to (crop_size, crop_size)
        crops_granum.append(tvf.resize(torch.tensor(np.array(crops)), (crop_size,crop_size),antialias=True))
        angles_granum.extend(rotation_angles.tolist())
        scales_granum.extend([scale]*len(crops))

    if len(angles_granum) == 0:
        # BUGFIX: return empty tensors matching the annotated return type,
        # not plain lists, so callers get consistent tensor semantics.
        return (
            torch.empty((0, crop_size, crop_size)),
            torch.empty(0),
            torch.empty(0),
        )

    crops_granum = torch.concat(crops_granum)
    angles_granum = torch.tensor(angles_granum, dtype=torch.float)
    scales_granum = torch.tensor(scales_granum, dtype=torch.float)

    return crops_granum, angles_granum, scales_granum

def get_crop_batch_from_path(img_path, mask_path=None, use_variance_threshold=False):
    """
    Load an image and its mask from file paths and generate a batch of cropped images.

    Parameters
    ----------
    img_path : str
        Path to the input image.
    mask_path : str, optional
        Path to the binary mask (a ``.npy`` file). If None, the mask path is
        derived from ``img_path`` by replacing its extension with '.npy'.
    use_variance_threshold : bool, optional
        Flag to use variance thresholding for selecting crop locations.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        A tuple containing the tensor of crops, their rotation angles, and scale factors, obtained from the specified image path.
    """
    if mask_path is None:
        mask_path = str(Path(img_path).with_suffix('.npy'))
    mask = np.load(mask_path)
    img = np.array(Image.open(img_path))
    # BUGFIX: the original unconditionally indexed [:, :, 0], which raised
    # IndexError for single-channel (grayscale) images whose array is 2-D.
    # Take the first channel only when a channel axis is present.
    if img.ndim == 3:
        img = img[:, :, 0]

    return get_crop_batch(img, mask, use_variance_threshold=use_variance_threshold)