|
|
""" |
|
|
Utility functions for image processing in hyperspectral datasets. |
|
|
|
|
|
Author: Ole-Christian Galbo Engstrøm |
|
|
E-mail: ocge@foss.dk |
|
|
""" |
|
|
|
|
|
from typing import Tuple, Union |
|
|
|
|
|
import numpy as np |
|
|
import spectral as spy |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
def pad_with_average_spectrum(
    img: Union[spy.io.bilfile.BilFile, np.ndarray],
    output_height: int,
    output_width: int,
    is_mask: bool,
) -> np.ndarray:
    """Pad an image (H, W, C) or mask (H, W) to at least the requested size.

    Rows and columns are added symmetrically (the extra row/column of an odd
    remainder goes to the bottom/right). Images are padded with the average
    spectrum of the left- and right-edge columns; masks are padded with zeros.

    Parameters
    ----------
    img : spectral BilFile or np.ndarray
        Input image. Any non-ndarray input is assumed to expose ``.load()``
        (e.g. a spectral file object) and is materialized first.
    output_height : int
        Minimum height of the returned array.
    output_width : int
        Minimum width of the returned array.
    is_mask : bool
        If True, treat ``img`` as a 2-D mask and pad with zeros.

    Returns
    -------
    np.ndarray
        The padded array. Returned unchanged when both dimensions already
        meet the requested size; a dimension that is already large enough is
        never shrunk.
    """
    # Materialize lazily-loaded spectral images (duck-typed: anything that is
    # not already an ndarray is expected to provide .load()).
    if not isinstance(img, np.ndarray):
        img = img.load()
    h, w = img.shape[:2]
    if h >= output_height and w >= output_width:
        return img
    # Give masks a trailing channel axis so both cases share the 3-D code path.
    if is_mask:
        img = np.expand_dims(img, axis=-1)
    if is_mask:
        # Masks are padded with background (zeros).
        avg_spectrum = np.zeros(img.shape[-1], dtype=img.dtype)
    else:
        # Images are padded with the mean spectrum of the two edge columns.
        left_column = img[:, 0:1, :]
        right_column = img[:, -1:, :]
        avg_spectrum_left_column = np.mean(left_column, axis=(0, 1))
        avg_spectrum_right_column = np.mean(right_column, axis=(0, 1))
        avg_spectrum = (avg_spectrum_left_column + avg_spectrum_right_column) / 2

    if h < output_height:
        num_top_pad_rows = (output_height - img.shape[0]) // 2
        num_bottom_pad_rows = num_top_pad_rows + (output_height - img.shape[0]) % 2
        top_pad = np.tile(
            avg_spectrum[None, None, :], (num_top_pad_rows, img.shape[1], 1)
        )
        bottom_pad = np.tile(
            avg_spectrum[None, None, :], (num_bottom_pad_rows, img.shape[1], 1)
        )
        img = np.concatenate([top_pad, img, bottom_pad], axis=0)

    if w < output_width:
        num_left_pad_columns = (output_width - img.shape[1]) // 2
        num_right_pad_columns = num_left_pad_columns + (output_width - img.shape[1]) % 2
        # Column pads use the current (possibly already row-padded) height.
        left_pad = np.tile(
            avg_spectrum[None, None, :], (img.shape[0], num_left_pad_columns, 1)
        )
        right_pad = np.tile(
            avg_spectrum[None, None, :], (img.shape[0], num_right_pad_columns, 1)
        )
        # BUG FIX: the original concatenated left_pad/right_pad
        # unconditionally, raising NameError whenever only height padding
        # was required (w >= output_width but h < output_height).
        img = np.concatenate([left_pad, img, right_pad], axis=1)

    if is_mask:
        img = img.squeeze(axis=-1)
    return img
|
|
|
|
|
|
|
|
def pad_with_random_spectrum(
    img: Union[spy.io.bilfile.BilFile, np.ndarray], output_height: int, is_mask: bool
) -> np.ndarray:
    """Pad an image (H, W, C) or mask (H, W) vertically to ``output_height``.

    Extra rows are split evenly between top and bottom (the odd remainder
    goes to the bottom). Image rows are filled with spectra sampled uniformly
    at random from the pixels of the left- and right-edge columns; mask rows
    are filled with zeros. Not deterministic for images (uses
    ``np.random.randint`` without a seed).

    Parameters
    ----------
    img : spectral BilFile or np.ndarray
        Input image. Any non-ndarray input is assumed to expose ``.load()``
        (e.g. a spectral file object) and is materialized first.
    output_height : int
        Minimum height of the returned array.
    is_mask : bool
        If True, treat ``img`` as a 2-D mask and pad with zeros.

    Returns
    -------
    np.ndarray
        The padded array, or the input unchanged when it is already tall
        enough.
    """
    # Materialize lazily-loaded spectral images (duck-typed, consistent with
    # pad_with_average_spectrum).
    if not isinstance(img, np.ndarray):
        img = img.load()
    if img.shape[0] >= output_height:
        return img
    # Give masks a trailing channel axis so both cases share the 3-D code path.
    if is_mask:
        img = np.expand_dims(img, axis=-1)

    num_top_pad_rows = (output_height - img.shape[0]) // 2
    num_bottom_pad_rows = num_top_pad_rows + (output_height - img.shape[0]) % 2

    if is_mask:
        print("Padding mask with zero spectrum")
        zero_spectrum = np.zeros(img.shape[-1], dtype=img.dtype)
        top_pad = np.tile(
            zero_spectrum[None, None, :], (num_top_pad_rows, img.shape[1], 1)
        )
        bottom_pad = np.tile(
            zero_spectrum[None, None, :], (num_bottom_pad_rows, img.shape[1], 1)
        )
    else:
        print("Padding image with random spectrum")
        # Candidate pool: every pixel spectrum of the two edge columns.
        # (Only computed here — the original built it on the mask path too,
        # where it was never used.)
        stacked_columns = np.concatenate(
            [img[:, 0:1, :], img[:, -1:, :]], axis=0
        ).squeeze(axis=1)
        random_indices = np.random.randint(
            0,
            stacked_columns.shape[0],
            size=(num_top_pad_rows + num_bottom_pad_rows, img.shape[1]),
        )
        top_pad = stacked_columns[random_indices[:num_top_pad_rows]]
        bottom_pad = stacked_columns[random_indices[num_top_pad_rows:]]

    img = np.concatenate([top_pad, img, bottom_pad], axis=0)
    if is_mask:
        img = img.squeeze(axis=-1)
    return img
|
|
|
|
|
|
|
|
def bilinear_interpolation(
    img: torch.Tensor,
    output_height: int,
    output_width: int,
) -> torch.Tensor:
    """Bilinearly resize ``img`` to (output_height, output_width).

    A leading batch axis is added before interpolation and is not removed,
    so a (C, H, W) input yields a (1, C, output_height, output_width) output.
    Uses ``align_corners=False``.
    """
    batched = img.unsqueeze(0)
    resized = F.interpolate(
        batched,
        size=(output_height, output_width),
        mode="bilinear",
        align_corners=False,
    )
    return resized
|
|
|
|
|
|
|
|
def discard_less_than_750nm_and_bin(img):
    """Drop the first 176 of 300 spectral bands and bin the rest pairwise.

    The remaining 124 bands are averaged in consecutive pairs, yielding a
    (62, H, W) array.

    NOTE(review): the sibling ``discard_less_than_750nm`` cuts at index 177,
    not 176 — confirm which offset actually corresponds to 750 nm. (176 is
    required here so the remaining band count is even for pairwise binning.)

    Raises
    ------
    ValueError
        If ``img`` does not have exactly 300 bands on axis 0. (Replaces a
        bare ``assert``, which is stripped under ``python -O``.)
    """
    if img.shape[0] != 300:
        raise ValueError(f"Expected 300 spectral bands, got {img.shape[0]}")
    img = img[176:, ...]

    # Average consecutive band pairs: (124, H, W) -> (62, 2, H, W) -> (62, H, W).
    img = img.reshape(img.shape[0] // 2, 2, img.shape[1], img.shape[2]).mean(axis=1)
    return img
|
|
|
|
|
|
|
|
def discard_less_than_750nm(img):
    """Drop the first 177 of 300 spectral bands, keeping the last 123.

    NOTE(review): the sibling ``discard_less_than_750nm_and_bin`` cuts at
    index 176, not 177 — confirm which offset actually corresponds to 750 nm.

    Raises
    ------
    ValueError
        If ``img`` does not have exactly 300 bands on axis 0. (Replaces a
        bare ``assert``, which is stripped under ``python -O``.)
    """
    if img.shape[0] != 300:
        raise ValueError(f"Expected 300 spectral bands, got {img.shape[0]}")
    return img[177:, ...]
|
|
|
|
|
|
|
|
def convert_mask_to_8bit(mask):
    """Cast a mask array to ``uint8``.

    A plain dtype cast: float values are truncated toward zero; values
    outside the uint8 range are not clipped first.
    """
    return mask.astype(np.uint8)
|
|
|
|
|
|
|
|
def quantize_16_bit(arr):
    """Linearly quantize a numeric array to the full uint16 range.

    Maps ``arr.min()`` to 0 and ``arr.max()`` to 65535 via
    ``round((arr - bias) * scale)``.

    Returns
    -------
    tuple
        ``(uint16_array, bias, scale)`` where ``bias = arr.min()`` and
        ``scale = 65535 / (arr.max() - arr.min())``, suitable for
        ``dequantize_16_bit``.

    Raises
    ------
    ValueError
        If ``arr`` is constant. (The original divided by zero here, silently
        producing an inf scale and NaN output cast to garbage uint16 values.)
    """
    bias = arr.min()
    max_uint16 = 65535
    value_range = arr.max() - bias
    if value_range == 0:
        raise ValueError("Cannot quantize a constant array (zero dynamic range)")
    scale = max_uint16 / value_range
    uint_16_arr = np.round((arr - bias) * scale)
    # Sanity check: the minimum maps exactly to 0 by construction.
    assert uint_16_arr.min() == 0
    # Best-effort check: rounding error can leave the max just short of
    # 65535; warn rather than fail (deliberately kept from the original).
    try:
        assert uint_16_arr.max() == max_uint16
    except AssertionError:
        print(f"Max value is {uint_16_arr.max()}")
    uint_16_arr = uint_16_arr.astype(np.uint16)
    return uint_16_arr, bias, scale
|
|
|
|
|
|
|
|
def dequantize_16_bit(arr, bias, scale):
    """Invert the uint16 quantization: recover ``arr / scale + bias``.

    Expects ``arr`` as uint16 and ``bias``/``scale`` as float32 scalars,
    enforced by dtype assertions.

    NOTE(review): ``quantize_16_bit`` returns bias/scale with the dtype of
    its input, which need not be float32 — confirm callers convert before
    passing them here.
    """
    assert arr.dtype == np.uint16
    assert bias.dtype == np.float32
    assert scale.dtype == np.float32
    rescaled = arr / scale
    return rescaled + bias
|
|
|
|
|
|
|
|
def compute_crop_central_coordinates(
    crop_height: int,
    crop_width: int,
    central_crop_height: int,
    central_crop_width: int,
) -> Tuple[int, int, int, int]:
    """Return (start_h, end_h, start_w, end_w) of a centered sub-crop.

    Centers a sub-crop of size (central_crop_height, central_crop_width)
    inside a crop of size (crop_height, crop_width); an odd margin leaves
    the extra row/column on the bottom/right.
    """
    top = (crop_height - central_crop_height) // 2
    left = (crop_width - central_crop_width) // 2
    return (
        top,
        top + central_crop_height,
        left,
        left + central_crop_width,
    )
|
|
|
|
|
|
|
|
def mask_image(
    img: torch.Tensor,
    mask: torch.Tensor,
) -> torch.Tensor:
    """Apply ``mask`` to ``img`` by elementwise multiplication.

    Standard broadcasting rules apply between the two tensors.
    """
    return img * mask
|
|
|
|
|
|
|
|
def load_image_crop_overlap_tile(
    img: Union[spy.io.bilfile.BilFile, np.ndarray],
    crop_start_height_coordinate: int,
    crop_end_height_coordinate: int,
    crop_start_width_coordinate: int,
    crop_end_width_coordinate: int,
    crop_expanded_height: int,
    crop_expanded_width: int,
) -> torch.Tensor:
    """Extract a crop expanded symmetrically to a fixed context size.

    Overlap-tile strategy: the requested half-open crop
    ``[start_h:end_h, start_w:end_w]`` is grown to
    ``(crop_expanded_height, crop_expanded_width)`` by taking extra context
    from the surrounding image; wherever that context would fall outside the
    image bounds, the missing border is filled with reflection padding.

    Parameters
    ----------
    img : spectral BilFile or np.ndarray
        Source image indexed as (height, width[, bands]).
    crop_start_height_coordinate, crop_end_height_coordinate : int
        Half-open row range of the core crop.
    crop_start_width_coordinate, crop_end_width_coordinate : int
        Half-open column range of the core crop.
    crop_expanded_height, crop_expanded_width : int
        Target size of the expanded crop; assumed >= the core crop size
        — TODO confirm callers guarantee this.

    Returns
    -------
    torch.Tensor
        The expanded crop in channel-first layout: 2-D inputs get a
        singleton channel axis, 3-D inputs are moved from (H, W, C) to
        (C, H, W); values are cast to float32.
    """
    img_height = img.shape[0]
    img_width = img.shape[1]

    crop_width = crop_end_width_coordinate - crop_start_width_coordinate
    crop_height = crop_end_height_coordinate - crop_start_height_coordinate

    # Split the extra context evenly per side; an odd remainder goes to the
    # bottom/right.
    expanded_left_size = (crop_expanded_width - crop_width) // 2
    expanded_right_size = expanded_left_size + (crop_expanded_width - crop_width) % 2
    expanded_top_size = (crop_expanded_height - crop_height) // 2
    expanded_bottom_size = expanded_top_size + (crop_expanded_height - crop_height) % 2

    # Ideal (possibly out-of-bounds) coordinates of the expanded crop.
    calculated_crop_start_height_coordinate = (
        crop_start_height_coordinate - expanded_top_size
    )
    calculated_crop_end_height_coordinate = (
        crop_end_height_coordinate + expanded_bottom_size
    )
    calculated_crop_start_width_coordinate = (
        crop_start_width_coordinate - expanded_left_size
    )
    calculated_crop_end_width_coordinate = (
        crop_end_width_coordinate + expanded_right_size
    )
    # Clamp to the image so the slice below stays in bounds.
    actual_crop_start_height_coordinate = max(
        0, calculated_crop_start_height_coordinate
    )
    actual_crop_end_height_coordinate = min(
        img_height, calculated_crop_end_height_coordinate
    )
    actual_crop_start_width_coordinate = max(0, calculated_crop_start_width_coordinate)
    actual_crop_end_width_coordinate = min(
        img_width, calculated_crop_end_width_coordinate
    )

    expanded_crop = img[
        actual_crop_start_height_coordinate:actual_crop_end_height_coordinate,
        actual_crop_start_width_coordinate:actual_crop_end_width_coordinate,
    ]

    # Convert ndarray crops to a float32, channel-first torch tensor.
    # NOTE(review): non-ndarray crops (e.g. a BilFile slice that is not an
    # ndarray) are passed to F.pad untouched — confirm they already arrive
    # as channel-first tensors.
    if isinstance(expanded_crop, np.ndarray):
        expanded_crop = np.asarray(expanded_crop, dtype=np.float32)
        if len(expanded_crop.shape) == 2:
            # Grayscale: add a singleton channel axis in front.
            expanded_crop = expanded_crop[None, ...]
        elif len(expanded_crop.shape) == 3:
            # (H, W, C) -> (C, H, W).
            expanded_crop = np.moveaxis(expanded_crop, -1, 0)
        expanded_crop = torch.tensor(expanded_crop)

    # How much of the ideal expansion was lost to clamping on each side;
    # this is exactly the amount that must be synthesized by mirroring.
    top_mirror_size = (
        actual_crop_start_height_coordinate - calculated_crop_start_height_coordinate
    )
    bottom_mirror_size = (
        calculated_crop_end_height_coordinate - actual_crop_end_height_coordinate
    )
    left_mirror_size = (
        actual_crop_start_width_coordinate - calculated_crop_start_width_coordinate
    )
    right_mirror_size = (
        calculated_crop_end_width_coordinate - actual_crop_end_width_coordinate
    )

    # Fill the clipped borders by mirroring. F.pad with mode="reflect"
    # requires each pad size to be smaller than the corresponding input
    # dimension — assumed to hold for sane tile/image sizes.
    if (
        top_mirror_size > 0
        or bottom_mirror_size > 0
        or left_mirror_size > 0
        or right_mirror_size > 0
    ):
        expanded_crop = F.pad(
            expanded_crop,
            (left_mirror_size, right_mirror_size, top_mirror_size, bottom_mirror_size),
            mode="reflect",
        )

    return expanded_crop
|
|
|