|
|
import os |
|
|
import torch |
|
|
import numpy as np |
|
|
import lpips as lp |
|
|
import pandas as pd |
|
|
import torchmetrics |
|
|
import matplotlib.pyplot as plt |
|
|
from bisect import bisect_right |
|
|
import torchvision.transforms as T |
|
|
from torch import nn |
|
|
|
|
|
from matplotlib.colors import ListedColormap, BoundaryNorm |
|
|
from matplotlib.lines import Line2D |
|
|
|
|
|
from data import dutils |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SequentialLR(torch.optim.lr_scheduler._LRScheduler): |
|
|
"""Receives the list of schedulers that is expected to be called sequentially during |
|
|
optimization process and milestone points that provides exact intervals to reflect |
|
|
which scheduler is supposed to be called at a given epoch. |
|
|
|
|
|
Args: |
|
|
schedulers (list): List of chained schedulers. |
|
|
milestones (list): List of integers that reflects milestone points. |
|
|
|
|
|
Example: |
|
|
>>> # Assuming optimizer uses lr = 1. for all groups |
|
|
>>> # lr = 0.1 if epoch == 0 |
|
|
>>> # lr = 0.1 if epoch == 1 |
|
|
>>> # lr = 0.9 if epoch == 2 |
|
|
>>> # lr = 0.81 if epoch == 3 |
|
|
>>> # lr = 0.729 if epoch == 4 |
|
|
>>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2) |
|
|
>>> scheduler2 = ExponentialLR(self.opt, gamma=0.9) |
|
|
>>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2]) |
|
|
>>> for epoch in range(100): |
|
|
>>> train(...) |
|
|
>>> validate(...) |
|
|
>>> scheduler.step() |
|
|
""" |
|
|
|
|
|
def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False): |
|
|
for scheduler_idx in range(1, len(schedulers)): |
|
|
if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer): |
|
|
raise ValueError( |
|
|
"Sequential Schedulers expects all schedulers to belong to the same optimizer, but " |
|
|
"got schedulers at index {} and {} to be different".format(0, scheduler_idx) |
|
|
) |
|
|
if (len(milestones) != len(schedulers) - 1): |
|
|
raise ValueError( |
|
|
"Sequential Schedulers expects number of schedulers provided to be one more " |
|
|
"than the number of milestone points, but got number of schedulers {} and the " |
|
|
"number of milestones to be equal to {}".format(len(schedulers), len(milestones)) |
|
|
) |
|
|
self.optimizer = optimizer |
|
|
self._schedulers = schedulers |
|
|
self._milestones = milestones |
|
|
self.last_epoch = last_epoch + 1 |
|
|
|
|
|
def step(self, ref=None): |
|
|
self.last_epoch += 1 |
|
|
idx = bisect_right(self._milestones, self.last_epoch) |
|
|
if idx > 0 and self._milestones[idx - 1] == self.last_epoch: |
|
|
self._schedulers[idx].step(0) |
|
|
else: |
|
|
|
|
|
if isinstance(self._schedulers[idx], torch.optim.lr_scheduler.ReduceLROnPlateau): |
|
|
self._schedulers[idx].step(ref) |
|
|
else: |
|
|
self._schedulers[idx].step() |
|
|
|
|
|
def state_dict(self): |
|
|
"""Returns the state of the scheduler as a :class:`dict`. |
|
|
|
|
|
It contains an entry for every variable in self.__dict__ which |
|
|
is not the optimizer. |
|
|
The wrapped scheduler states will also be saved. |
|
|
""" |
|
|
state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')} |
|
|
state_dict['_schedulers'] = [None] * len(self._schedulers) |
|
|
|
|
|
for idx, s in enumerate(self._schedulers): |
|
|
state_dict['_schedulers'][idx] = s.state_dict() |
|
|
|
|
|
return state_dict |
|
|
|
|
|
def load_state_dict(self, state_dict): |
|
|
"""Loads the schedulers state. |
|
|
|
|
|
Args: |
|
|
state_dict (dict): scheduler state. Should be an object returned |
|
|
from a call to :meth:`state_dict`. |
|
|
""" |
|
|
_schedulers = state_dict.pop('_schedulers') |
|
|
self.__dict__.update(state_dict) |
|
|
|
|
|
|
|
|
state_dict['_schedulers'] = _schedulers |
|
|
|
|
|
for idx, s in enumerate(_schedulers): |
|
|
self._schedulers[idx].load_state_dict(s) |
|
|
|
|
|
def warmup_lambda(warmup_steps, min_lr_ratio=0.1): |
|
|
def ret_lambda(epoch): |
|
|
if epoch <= warmup_steps: |
|
|
return min_lr_ratio + (1.0 - min_lr_ratio) * epoch / warmup_steps |
|
|
else: |
|
|
return 1.0 |
|
|
return ret_lambda |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def to_cpu_tensor(*args): |
|
|
''' |
|
|
Input arbitrary number of array/tensors, each will be converted to CPU torch.Tensor |
|
|
''' |
|
|
out = [] |
|
|
for tensor in args: |
|
|
if type(tensor) is np.ndarray: |
|
|
tensor = torch.Tensor(tensor) |
|
|
if type(tensor) is torch.Tensor: |
|
|
tensor = tensor.cpu() |
|
|
out.append(tensor) |
|
|
|
|
|
if len(out) == 1: |
|
|
return out[0] |
|
|
return out |
|
|
|
|
|
def merge_leading_dims(tensor, n=2): |
|
|
''' |
|
|
Merge the first N dimension of a tensor |
|
|
''' |
|
|
return tensor.reshape((-1, *tensor.shape[n:])) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_model_name(model_type, model_config): |
|
|
''' |
|
|
Build the model name (without extension) |
|
|
''' |
|
|
model_name = model_type + '_' |
|
|
for k, v in model_config.items(): |
|
|
model_name += k |
|
|
if type(v) is list or type(v) is tuple: |
|
|
model_name += '-' |
|
|
for i, item in enumerate(v): |
|
|
model_name += (str(item) if type(item) is not bool else '') + ('-' if i < len(v)-1 else '') |
|
|
else: |
|
|
model_name += (('-' + str(v)) if type(v) is not bool else '') |
|
|
model_name += '_' |
|
|
return model_name[:-1] |
|
|
|
|
|
def build_model_path(base_dir, dataset_type, model_type, timestamp=None): |
|
|
if timestamp is None: |
|
|
return os.path.join(base_dir, dataset_type, model_type) |
|
|
elif timestamp == True: |
|
|
return os.path.join(base_dir, dataset_type, model_type, pd.Timestamp.now().strftime('%Y%m%d%H%M%S')) |
|
|
return os.path.join(base_dir, dataset_type, model_type, timestamp) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def hko7_preprocess(x_seq, x_mask, dt_clip, args): |
|
|
resize = args.resize if 'resize' in args else x_seq.shape[-1] |
|
|
seq_len = args.seq_len if 'seq_len' in args else 5 |
|
|
|
|
|
|
|
|
x_seq = x_seq.transpose((1, 0, 2, 3, 4)) / 255. |
|
|
if 'scale' in args and args.scale == 'non-linear': |
|
|
x_seq = dutils.linear_to_nonlinear_batched(x_seq, dt_clip) |
|
|
else: |
|
|
x_seq = dutils.nonlinear_to_linear_batched(x_seq, dt_clip) |
|
|
|
|
|
b, t, c, h, w = x_seq.shape |
|
|
assert c == 1, f'# channels ({c}) != 1' |
|
|
|
|
|
|
|
|
x_seq = torch.Tensor(x_seq).float().reshape((b*t, c, h, w)) |
|
|
if resize != h: |
|
|
tform = T.Compose([ |
|
|
T.ToPILImage(), |
|
|
T.Resize(resize), |
|
|
T.ToTensor(), |
|
|
]) |
|
|
else: |
|
|
tform = T.Compose([]) |
|
|
|
|
|
x_seq = torch.stack([tform(x_frame) for x_frame in x_seq], dim=0) |
|
|
x_seq = x_seq.reshape((b, t, c, resize, resize)) |
|
|
|
|
|
x, y = x_seq[:, :seq_len], x_seq[:, seq_len:] |
|
|
return x, y |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mae = lambda *args: torch.nn.functional.l1_loss(*args).cpu().detach().numpy() |
|
|
mse = lambda *args: torch.nn.functional.mse_loss(*args).cpu().detach().numpy() |
|
|
|
|
|
def ssim(y_pred, y): |
|
|
y, y_pred = to_cpu_tensor(y, y_pred) |
|
|
b, t, c, h, w = y.shape |
|
|
y = y.reshape((b*t, c, h, w)) |
|
|
y_pred = y_pred.reshape((b*t, c, h, w)) |
|
|
|
|
|
y = torch.clamp(y, 0, 1) |
|
|
y_pred = torch.clamp(y_pred, 0, 1) |
|
|
return torchmetrics.image.StructuralSimilarityIndexMeasure(data_range=1.0)(y_pred, y) |
|
|
|
|
|
def psnr(y_pred, y): |
|
|
y, y_pred = to_cpu_tensor(y, y_pred) |
|
|
b, t, c, h, w = y.shape |
|
|
y = y.reshape((b*t, c, h, w)) |
|
|
y_pred = y_pred.reshape((b*t, c, h, w)) |
|
|
acc_score = 0 |
|
|
for i in range(b*t): |
|
|
acc_score += torchmetrics.image.PeakSignalNoiseRatio(data_range=1.0)(y_pred[i], y[i]) / (b*t) |
|
|
return acc_score |
|
|
|
|
|
GLOBAL_LPIPS_OBJ = None |
|
|
def lpips64(y_pred, y, net='vgg'): |
|
|
|
|
|
y = merge_leading_dims(y) |
|
|
y_pred = merge_leading_dims(y_pred) |
|
|
|
|
|
y = torch.nn.functional.interpolate(y, (64, 64), mode='bicubic').clamp(0,1) |
|
|
y_pred = torch.nn.functional.interpolate(y_pred, (64, 64), mode='bicubic').clamp(0,1) |
|
|
|
|
|
y = (2 * y - 1) |
|
|
y_pred = (2 * y_pred - 1) |
|
|
global GLOBAL_LPIPS_OBJ |
|
|
if GLOBAL_LPIPS_OBJ is None: |
|
|
GLOBAL_LPIPS_OBJ = lp.LPIPS(net=net).to(y.device) |
|
|
return GLOBAL_LPIPS_OBJ(y_pred, y).mean() |
|
|
|
|
|
def tfpn(y_pred, y, threshold, radius=1): |
|
|
''' |
|
|
convert to cpu, and merge the first two dimensions |
|
|
''' |
|
|
y = merge_leading_dims(y) |
|
|
y_pred = merge_leading_dims(y_pred) |
|
|
with torch.no_grad(): |
|
|
if radius > 1: |
|
|
pool = nn.MaxPool2d(radius) |
|
|
y = pool(y) |
|
|
y_pred = pool(y_pred) |
|
|
y = torch.where(y >= threshold, 1, 0) |
|
|
y_pred = torch.where(y_pred >= threshold, 1, 0) |
|
|
mat = torchmetrics.functional.confusion_matrix(y_pred, y, task='binary', threshold=threshold) |
|
|
(tn, fp), (fn, tp) = to_cpu_tensor(mat) |
|
|
return tp, tn, fp, fn |
|
|
|
|
|
def tfpn_pool(y_pred, y, threshold, radius): |
|
|
y_pred = merge_leading_dims(y_pred) |
|
|
y = merge_leading_dims(y) |
|
|
pool = nn.MaxPool2d(radius, stride=radius//4 if radius//4 > 0 else radius) |
|
|
with torch.no_grad(): |
|
|
y = torch.where(y>=threshold, 1, 0).float() |
|
|
y_pred = torch.where(y_pred>=threshold, 1, 0).float() |
|
|
y = pool(y) |
|
|
y_pred = pool(y_pred) |
|
|
mat = torchmetrics.functional.confusion_matrix(y_pred, y, task='binary', threshold=threshold) |
|
|
(tn, fp), (fn, tp) = to_cpu_tensor(mat) |
|
|
return tp, tn, fp, fn |
|
|
|
|
|
def csi(tp, tn, fp, fn): |
|
|
'''Critical Success Index. The larger the better.''' |
|
|
if (tp + fn + fp) < 1e-7: |
|
|
return 0. |
|
|
return tp / (tp + fn + fp) |
|
|
|
|
|
def hss(tp, tn, fp, fn): |
|
|
'''Heidke Skill Score. (-inf, 1]. Larger better.''' |
|
|
if (tp+fn)*(fn+tn) + (tp+fp)*(fp+tn) == 0: |
|
|
return 0. |
|
|
return 2 * (tp*tn - fp*fn) / ((tp+fn)*(fn+tn) + (tp+fp)*(fp+tn)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def torch_visualize(sequences, savedir=None, horizontal=10, vmin=0, vmax=1): |
|
|
''' |
|
|
input: sequences, a list/dict of numpy/torch arrays with shape (B, T, C, H, W) |
|
|
C is assumed to be 1 and squeezed |
|
|
If batch > 1, only the first sequence will be printed |
|
|
''' |
|
|
|
|
|
vertical = 0 |
|
|
display_texts = [] |
|
|
if (type(sequences) is dict): |
|
|
temp = [] |
|
|
for k, v in sequences.items(): |
|
|
vertical += int(np.ceil(v.shape[1] / horizontal)) |
|
|
temp.append(v) |
|
|
display_texts.append(k) |
|
|
sequences = temp |
|
|
else: |
|
|
for i, sequence in enumerate(sequences): |
|
|
vertical += int(np.ceil(sequence.shape[1] / horizontal)) |
|
|
display_texts.append(f'Item {i+1}') |
|
|
sequences = to_cpu_tensor(*sequences) |
|
|
|
|
|
j = 0 |
|
|
fig, axes = plt.subplots(vertical, horizontal, figsize=(2*horizontal, 2*vertical), tight_layout=True) |
|
|
plt.setp(axes, xticks=[], yticks=[]) |
|
|
for k, sequence in enumerate(sequences): |
|
|
|
|
|
sequence = sequence[0].squeeze() |
|
|
axes[j, 0].set_ylabel(display_texts[k]) |
|
|
for i, frame in enumerate(sequence): |
|
|
j_shift = j + i // horizontal |
|
|
i_shift = i % horizontal |
|
|
axes[j_shift, i_shift].imshow(frame, vmin=vmin, vmax=vmax, cmap='gray') |
|
|
j += int(np.ceil(sequence.shape[0] / horizontal)) |
|
|
if savedir: |
|
|
plt.savefig(savedir + '' if savedir.endswith('.png') else '.png') |
|
|
plt.close() |
|
|
else: |
|
|
plt.show() |
|
|
|
|
|
""" Visualize function with colorbar and a line seprate input and output """ |
|
|
def color_visualize(sequences, savedir='', horizontal=5, skip=1, ypos=0): |
|
|
''' |
|
|
input: sequences, a list/dict of numpy/torch arrays with shape (B, T, C, H, W) |
|
|
C is assumed to be 1 and squeezed |
|
|
If batch > 1, only the first sequence will be printed |
|
|
''' |
|
|
plt.style.use(['science', 'no-latex']) |
|
|
VIL_COLORS = [[0, 0, 0], |
|
|
[0.30196078431372547, 0.30196078431372547, 0.30196078431372547], |
|
|
[0.1568627450980392, 0.7450980392156863, 0.1568627450980392], |
|
|
[0.09803921568627451, 0.5882352941176471, 0.09803921568627451], |
|
|
[0.0392156862745098, 0.4117647058823529, 0.0392156862745098], |
|
|
[0.0392156862745098, 0.29411764705882354, 0.0392156862745098], |
|
|
[0.9607843137254902, 0.9607843137254902, 0.0], |
|
|
[0.9294117647058824, 0.6745098039215687, 0.0], |
|
|
[0.9411764705882353, 0.43137254901960786, 0.0], |
|
|
[0.6274509803921569, 0.0, 0.0], |
|
|
[0.9058823529411765, 0.0, 1.0]] |
|
|
|
|
|
VIL_LEVELS = [0.0, 16.0, 31.0, 59.0, 74.0, 100.0, 133.0, 160.0, 181.0, 219.0, 255.0] |
|
|
|
|
|
|
|
|
vertical = 0 |
|
|
display_texts = [] |
|
|
if (type(sequences) is dict): |
|
|
temp = [] |
|
|
for k, v in sequences.items(): |
|
|
vertical += int(np.ceil(v.shape[1] / horizontal)) |
|
|
temp.append(v) |
|
|
display_texts.append(k) |
|
|
sequences = temp |
|
|
else: |
|
|
for i, sequence in enumerate(sequences): |
|
|
vertical += int(np.ceil(sequence.shape[1] / horizontal)) |
|
|
display_texts.append(f'Item {i+1}') |
|
|
sequences = to_cpu_tensor(*sequences) |
|
|
|
|
|
j = 0 |
|
|
fig, axes = plt.subplots(vertical, horizontal, figsize=(2*horizontal, 2*vertical), tight_layout=True) |
|
|
plt.subplots_adjust(hspace=0.0, wspace=0.0) |
|
|
plt.setp(axes, xticks=[], yticks=[]) |
|
|
for k, sequence in enumerate(sequences): |
|
|
|
|
|
sequence = sequence[0].squeeze() |
|
|
|
|
|
|
|
|
|
|
|
if k == 0: |
|
|
for i in range(len(sequence)): |
|
|
axes[j, i].set_xlabel(f'$t-{(len(sequence)-i)-1}$', fontsize=16) |
|
|
axes[j, i].xaxis.set_label_position('top') |
|
|
elif k == len(sequences)-1: |
|
|
for i in range(len(sequence)): |
|
|
axes[j, i].set_xlabel(f'$t+{skip*i+1}$', fontsize=16) |
|
|
axes[j, i].xaxis.set_label_position('bottom') |
|
|
|
|
|
axes[j, 0].set_ylabel(display_texts[k], fontsize=16) |
|
|
for i, frame in enumerate(sequence): |
|
|
j_shift = j + i // horizontal |
|
|
i_shift = i % horizontal |
|
|
im = axes[j_shift, i_shift].imshow(frame*255, cmap=ListedColormap(VIL_COLORS), \ |
|
|
norm=BoundaryNorm(VIL_LEVELS, ListedColormap(VIL_COLORS).N)) |
|
|
j += int(np.ceil(sequence.shape[0] / horizontal)) |
|
|
|
|
|
|
|
|
if ypos == 0: |
|
|
ypos = 1 - 1 / len(sequences) - 0.017 |
|
|
fig.lines.append(Line2D((0, 1), (ypos, ypos), transform=fig.transFigure, ls='--', linewidth=2, color='#444')) |
|
|
|
|
|
cax = fig.add_axes([1, 0.05, 0.02, 0.5]) |
|
|
fig.colorbar(im, cax=cax) |
|
|
|
|
|
if savedir: |
|
|
plt.savefig(savedir + '' if len(savedir)>0 else 'out.png') |
|
|
plt.close() |
|
|
else: |
|
|
plt.show() |
|
|
|