Sm00thix's picture
Initial upload of source and weights
78a947a
"""
Utility functions for image processing in hyperspectral datasets.
Author: Ole-Christian Galbo Engstrøm
E-mail: ocge@foss.dk
"""
from typing import Tuple, Union
import numpy as np
import spectral as spy
import torch
import torch.nn.functional as F
def pad_with_average_spectrum(
img: Union[spy.io.bilfile.BilFile, np.ndarray],
output_height: int,
output_width: int,
is_mask: bool,
):
if isinstance(img, spy.io.bilfile.BilFile):
img = img.load()
h, w = img.shape[:2]
if h >= output_height and w >= output_width:
return img
if is_mask:
img = np.expand_dims(img, axis=-1)
left_column = img[:, 0:1, :]
right_column = img[:, -1:, :]
if is_mask:
avg_spectrum = np.zeros(img.shape[-1], dtype=img.dtype)
# print(f"Padding mask with avg_spectrum")
else:
# print(f"Padding image with avg_spectrum")
avg_spectrum_left_column = np.mean(left_column, axis=(0, 1))
avg_spectrum_right_column = np.mean(right_column, axis=(0, 1))
avg_spectrum = (avg_spectrum_left_column + avg_spectrum_right_column) / 2
if h < output_height:
num_top_pad_rows = (output_height - img.shape[0]) // 2
num_bottom_pad_rows = num_top_pad_rows + (output_height - img.shape[0]) % 2
else:
num_top_pad_rows = 0
num_bottom_pad_rows = 0
output_height = h
top_pad = np.tile(avg_spectrum[None, None, :], (num_top_pad_rows, img.shape[1], 1))
bottom_pad = np.tile(
avg_spectrum[None, None, :], (num_bottom_pad_rows, img.shape[1], 1)
)
if w < output_width:
num_left_pad_columns = (output_width - img.shape[1]) // 2
num_right_pad_columns = num_left_pad_columns + (output_width - img.shape[1]) % 2
left_pad = np.tile(
avg_spectrum[None, None, :], (output_height, num_left_pad_columns, 1)
)
right_pad = np.tile(
avg_spectrum[None, None, :], (output_height, num_right_pad_columns, 1)
)
img = np.concatenate([top_pad, img, bottom_pad], axis=0)
img = np.concatenate([left_pad, img, right_pad], axis=1)
if is_mask:
img = img.squeeze(axis=-1)
return img
def pad_with_random_spectrum(
img: Union[spy.io.bilfile.BilFile, np.ndarray], output_height: int, is_mask: bool
):
if isinstance(img, spy.io.bilfile.BilFile):
img = img.load()
if img.shape[0] >= output_height:
return img
if is_mask:
img = np.expand_dims(img, axis=-1)
left_column = img[:, 0:1, :]
right_column = img[:, -1:, :]
if is_mask:
zero_spectrum = np.zeros(img.shape[-1], dtype=img.dtype)
print(f"Padding mask with zero spectrum")
else:
print(f"Padding image with random spectrum")
# stack the left and right columns
stacked_columns = np.concatenate([left_column, right_column], axis=0).squeeze(
axis=1
)
num_top_pad_rows = (output_height - img.shape[0]) // 2
num_bottom_pad_rows = num_top_pad_rows + (output_height - img.shape[0]) % 2
if is_mask:
top_pad = np.tile(
zero_spectrum[None, None, :], (num_top_pad_rows, img.shape[1], 1)
)
bottom_pad = np.tile(
zero_spectrum[None, None, :], (num_bottom_pad_rows, img.shape[1], 1)
)
else:
random_indices = np.random.randint(
0,
stacked_columns.shape[0],
size=(num_top_pad_rows + num_bottom_pad_rows, img.shape[1]),
)
top_pad = stacked_columns[random_indices[:num_top_pad_rows]]
bottom_pad = stacked_columns[random_indices[num_top_pad_rows:]]
img = np.concatenate([top_pad, img, bottom_pad], axis=0)
if is_mask:
img = img.squeeze(axis=-1)
return img
def bilinear_interpolation(
img: torch.Tensor,
output_height: int,
output_width: int,
) -> torch.Tensor:
return F.interpolate(
img[None, ...],
size=(output_height, output_width),
mode="bilinear",
align_corners=False,
)
def discard_less_than_750nm_and_bin(img):
assert img.shape[0] == 300
img = img[176:, ...]
# bin by taking the average of every two bands
img = img.reshape(img.shape[0] // 2, 2, img.shape[1], img.shape[2]).mean(axis=1)
return img
def discard_less_than_750nm(img):
assert img.shape[0] == 300
return img[177:, ...]
def convert_mask_to_8bit(mask):
mask = mask.astype(np.uint8)
return mask
def quantize_16_bit(arr):
bias = arr.min()
max_uint16 = 65535
scale = max_uint16 / (arr.max() - bias)
uint_16_arr = np.round((arr - bias) * scale)
assert uint_16_arr.min() == 0
try:
assert uint_16_arr.max() == max_uint16
except AssertionError:
print(f"Max value is {uint_16_arr.max()}")
uint_16_arr = uint_16_arr.astype(np.uint16)
return uint_16_arr, bias, scale
def dequantize_16_bit(arr, bias, scale):
assert arr.dtype == np.uint16
assert bias.dtype == np.float32
assert scale.dtype == np.float32
return (arr / scale) + bias
def compute_crop_central_coordinates(
crop_height: int,
crop_width: int,
central_crop_height: int,
central_crop_width: int,
) -> Tuple[int, int, int, int]:
crop_start_height_coordinate = (crop_height - central_crop_height) // 2
crop_end_height_coordinate = crop_start_height_coordinate + central_crop_height
crop_start_width_coordinate = (crop_width - central_crop_width) // 2
crop_end_width_coordinate = crop_start_width_coordinate + central_crop_width
return (
crop_start_height_coordinate,
crop_end_height_coordinate,
crop_start_width_coordinate,
crop_end_width_coordinate,
)
def mask_image(
img: torch.Tensor,
mask: torch.Tensor,
) -> torch.Tensor:
img = img * mask
return img
def load_image_crop_overlap_tile(
img: Union[spy.io.bilfile.BilFile, np.ndarray],
crop_start_height_coordinate: int,
crop_end_height_coordinate: int,
crop_start_width_coordinate: int,
crop_end_width_coordinate: int,
crop_expanded_height: int,
crop_expanded_width: int,
) -> torch.Tensor:
img_height = img.shape[0]
img_width = img.shape[1]
crop_width = crop_end_width_coordinate - crop_start_width_coordinate
crop_height = crop_end_height_coordinate - crop_start_height_coordinate
# How much the crop is expanded on each side (i.e., the extra pixels on each side of the original crop)
expanded_left_size = (crop_expanded_width - crop_width) // 2
expanded_right_size = expanded_left_size + (crop_expanded_width - crop_width) % 2
expanded_top_size = (crop_expanded_height - crop_height) // 2
expanded_bottom_size = expanded_top_size + (crop_expanded_height - crop_height) % 2
# Get as much of the expanded crop as possible without going out of bounds
calculated_crop_start_height_coordinate = (
crop_start_height_coordinate - expanded_top_size
)
calculated_crop_end_height_coordinate = (
crop_end_height_coordinate + expanded_bottom_size
)
calculated_crop_start_width_coordinate = (
crop_start_width_coordinate - expanded_left_size
)
calculated_crop_end_width_coordinate = (
crop_end_width_coordinate + expanded_right_size
)
actual_crop_start_height_coordinate = max(
0, calculated_crop_start_height_coordinate
)
actual_crop_end_height_coordinate = min(
img_height, calculated_crop_end_height_coordinate
)
actual_crop_start_width_coordinate = max(0, calculated_crop_start_width_coordinate)
actual_crop_end_width_coordinate = min(
img_width, calculated_crop_end_width_coordinate
)
# Get the expanded crop
expanded_crop = img[
actual_crop_start_height_coordinate:actual_crop_end_height_coordinate,
actual_crop_start_width_coordinate:actual_crop_end_width_coordinate,
]
if isinstance(expanded_crop, np.ndarray):
expanded_crop = np.asarray(expanded_crop, dtype=np.float32)
if len(expanded_crop.shape) == 2:
expanded_crop = expanded_crop[None, ...]
elif len(expanded_crop.shape) == 3:
expanded_crop = np.moveaxis(expanded_crop, -1, 0)
expanded_crop = torch.tensor(expanded_crop)
# Check if the expanded crop needs mirroring
top_mirror_size = (
actual_crop_start_height_coordinate - calculated_crop_start_height_coordinate
)
bottom_mirror_size = (
calculated_crop_end_height_coordinate - actual_crop_end_height_coordinate
)
left_mirror_size = (
actual_crop_start_width_coordinate - calculated_crop_start_width_coordinate
)
right_mirror_size = (
calculated_crop_end_width_coordinate - actual_crop_end_width_coordinate
)
# Mirror the expanded crop if necessary
if (
top_mirror_size > 0
or bottom_mirror_size > 0
or left_mirror_size > 0
or right_mirror_size > 0
):
expanded_crop = F.pad(
expanded_crop,
(left_mirror_size, right_mirror_size, top_mirror_size, bottom_mirror_size),
mode="reflect",
)
return expanded_crop