File size: 16,483 Bytes
6021dd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
import os
import torch
import numpy as np
import lpips as lp
import pandas as pd
import torchmetrics
import matplotlib.pyplot as plt
from bisect import bisect_right
import torchvision.transforms as T
from torch import nn

from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib.lines import Line2D

from data import dutils

# =======================================================================
# Scheduler Helper Function
# =======================================================================

class SequentialLR(torch.optim.lr_scheduler._LRScheduler):
    """Receives the list of schedulers that is expected to be called sequentially during
    optimization process and milestone points that provides exact intervals to reflect
    which scheduler is supposed to be called at a given epoch.

    Unlike a plain chained scheduler, :meth:`step` accepts an optional ``ref`` value
    that is forwarded to the active scheduler when it is a
    :class:`torch.optim.lr_scheduler.ReduceLROnPlateau` (which needs the monitored
    metric on every step).

    Args:
        optimizer (Optimizer): Wrapped optimizer; stored on the instance.
        schedulers (list): List of chained schedulers.
        milestones (list): List of integers that reflects milestone points.
        last_epoch (int): Index of the last epoch; the internal counter starts at
            ``last_epoch + 1``.
        verbose (bool): Accepted for signature compatibility; not used by this class.

    Raises:
        ValueError: If the schedulers do not share one optimizer, or if the number
            of milestones is not exactly ``len(schedulers) - 1``.

    Example:
        >>> # Assuming optimizer uses lr = 1. for all groups
        >>> # lr = 0.1     if epoch == 0
        >>> # lr = 0.1     if epoch == 1
        >>> # lr = 0.9     if epoch == 2
        >>> # lr = 0.81    if epoch == 3
        >>> # lr = 0.729   if epoch == 4
        >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
        >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
        >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False):
        for scheduler_idx in range(1, len(schedulers)):
            if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
                raise ValueError(
                    "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
                    "got schedulers at index {} and {} to be different".format(0, scheduler_idx)
                )
        if (len(milestones) != len(schedulers) - 1):
            raise ValueError(
                "Sequential Schedulers expects number of schedulers provided to be one more "
                "than the number of milestone points, but got number of schedulers {} and the "
                "number of milestones to be equal to {}".format(len(schedulers), len(milestones))
            )
        # NOTE: the base-class __init__ is not invoked here; only the attributes
        # below are set on the instance.
        self.optimizer = optimizer
        self._schedulers = schedulers
        self._milestones = milestones
        self.last_epoch = last_epoch + 1

    def step(self, ref=None):
        """Advance one epoch and step whichever scheduler is active.

        Args:
            ref: Monitored metric, forwarded to the active scheduler when it is a
                ``ReduceLROnPlateau``. Ignored for every other scheduler type.
        """
        self.last_epoch += 1
        # number of milestones already reached == index of the active scheduler
        idx = bisect_right(self._milestones, self.last_epoch)
        active = self._schedulers[idx]
        if isinstance(active, torch.optim.lr_scheduler.ReduceLROnPlateau):
            # ReduceLROnPlateau.step takes the metric as its first positional
            # argument, so it must never receive the epoch-reset value 0 (that
            # would be recorded as the best metric seen so far).
            active.step(ref)
        elif idx > 0 and self._milestones[idx - 1] == self.last_epoch:
            # exactly on a milestone: restart the incoming scheduler at epoch 0
            active.step(0)
        else:
            active.step()

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The wrapped scheduler states will also be saved.
        """
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
        state_dict['_schedulers'] = [None] * len(self._schedulers)

        for idx, s in enumerate(self._schedulers):
            state_dict['_schedulers'][idx] = s.state_dict()

        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        _schedulers = state_dict.pop('_schedulers')
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict['_schedulers'] = _schedulers

        for idx, s in enumerate(_schedulers):
            self._schedulers[idx].load_state_dict(s)

def warmup_lambda(warmup_steps, min_lr_ratio=0.1):
    """Build a linear warm-up multiplier function.

    Args:
        warmup_steps: Number of steps over which the multiplier ramps up.
            A value of zero or less disables the warm-up.
        min_lr_ratio: Multiplier returned at step 0.

    Returns:
        A function mapping a step index to a multiplier that rises linearly from
        ``min_lr_ratio`` (step 0) to ``1.0`` (step ``warmup_steps``) and stays at
        ``1.0`` afterwards.
    """
    def ret_lambda(epoch):
        # no warm-up requested: skip the ramp, which would divide by zero
        if warmup_steps <= 0:
            return 1.0
        if epoch <= warmup_steps:
            return min_lr_ratio + (1.0 - min_lr_ratio) * epoch / warmup_steps
        return 1.0
    return ret_lambda

# =======================================================================
# Utils in utils :)
# =======================================================================
def to_cpu_tensor(*args):
    '''
    Convert any number of arrays/tensors to CPU torch.Tensor objects.

    numpy arrays (including ndarray subclasses) are turned into tensors via
    torch.Tensor(...); torch tensors (including subclasses such as
    nn.Parameter) are moved to the CPU. Any other value is passed through
    unchanged.

    Returns the single converted value when exactly one argument is given,
    otherwise a list of the converted values (an empty list for no arguments).
    '''
    out = []
    for item in args:
        # isinstance rather than an exact type match, so subclasses are handled too
        if isinstance(item, np.ndarray):
            # np.asarray strips any ndarray subclass before the conversion
            item = torch.Tensor(np.asarray(item))
        if isinstance(item, torch.Tensor):
            item = item.cpu()
        out.append(item)
    # single value input: return single value output
    if len(out) == 1:
        return out[0]
    return out

def merge_leading_dims(tensor, n=2):
    '''
    Collapse the first n axes of a tensor into one leading axis;
    the remaining axes keep their sizes.
    '''
    kept_dims = tuple(tensor.shape[n:])
    return tensor.reshape((-1,) + kept_dims)

# =======================================================================
# Model Preparation, saving & loading (copied from utils.py)
# =======================================================================
def build_model_name(model_type, model_config):
    '''
    Compose the model name (without extension) from the model type and its
    configuration.

    Every config entry becomes one '_'-separated segment: 'key-value' for a
    scalar, 'key-v1-v2-...' for a list/tuple, and just 'key' for a boolean.
    Booleans inside a list/tuple contribute an empty string.
    '''
    def _text(value):
        # booleans are represented by the key alone, so they add no text
        return '' if type(value) is bool else str(value)

    segments = [model_type]
    for key, value in model_config.items():
        if type(value) in (list, tuple):
            segments.append(key + '-' + '-'.join(_text(item) for item in value))
        elif type(value) is bool:
            segments.append(key)
        else:
            segments.append(key + '-' + str(value))
    return '_'.join(segments)

def build_model_path(base_dir, dataset_type, model_type, timestamp=None):
    '''
    Join base_dir/dataset_type/model_type, optionally followed by a timestamp
    folder. timestamp=None adds nothing, timestamp=True adds the current time
    formatted as YYYYmmddHHMMSS, and any other value is appended as given.
    '''
    root = os.path.join(base_dir, dataset_type, model_type)
    if timestamp is None:
        return root
    if timestamp == True:
        timestamp = pd.Timestamp.now().strftime('%Y%m%d%H%M%S')
    return os.path.join(root, timestamp)

# =======================================================================
# Preprocess Function for Loading HKO-7 dataset
# =======================================================================

def hko7_preprocess(x_seq, x_mask, dt_clip, args):
    '''
    Rescale, optionally resize, and split a batch of frames into input/target.

    x_seq:   5-D array; its first two axes are swapped and values divided by 255
             (presumably (seq, batch, 1, H, W) uint8 frames -- TODO confirm)
    x_mask:  accepted but not used by this function
    dt_clip: forwarded unchanged to the dutils scale-conversion helpers
    args:    options container supporting `in`; reads the optional entries
             `resize` (default: last axis size of x_seq), `seq_len` (default: 5)
             and `scale`

    Returns (x, y): the first `seq_len` frames and the remaining frames of each
    sequence, as float tensors of shape (batch, time, 1, resize, resize).
    '''
    # optional settings fall back to defaults when absent from args
    target_size = args.resize if 'resize' in args else x_seq.shape[-1]
    n_input = args.seq_len if 'seq_len' in args else 5

    # swap the two leading axes and bring the values into [0, 1]
    frames = x_seq.transpose((1, 0, 2, 3, 4)) / 255.
    use_nonlinear = 'scale' in args and args.scale == 'non-linear'
    if use_nonlinear:
        frames = dutils.linear_to_nonlinear_batched(frames, dt_clip)
    else:
        frames = dutils.nonlinear_to_linear_batched(frames, dt_clip)

    b, t, c, h, w = frames.shape
    assert c == 1, f'# channels ({c}) != 1'

    # flatten batch and time so each frame can be transformed on its own
    frames = torch.Tensor(frames).float().reshape((b*t, c, h, w))

    # resize (downsample) only when the requested size differs from the input
    steps = [T.ToPILImage(), T.Resize(target_size), T.ToTensor()] if target_size != h else []
    tform = T.Compose(steps)

    frames = torch.stack([tform(frame) for frame in frames], dim=0)
    frames = frames.reshape((b, t, c, target_size, target_size))

    return frames[:, :n_input], frames[:, n_input:]

# =======================================================================
# Evaluation Metrics-Related 
# =======================================================================

def mae(*args):
    '''Mean absolute error (torch l1_loss) returned as a numpy value.'''
    return torch.nn.functional.l1_loss(*args).detach().cpu().numpy()

def mse(*args):
    '''Mean squared error (torch mse_loss) returned as a numpy value.'''
    return torch.nn.functional.mse_loss(*args).detach().cpu().numpy()

def ssim(y_pred, y):
    '''
    Structural similarity between prediction and target.

    Both inputs are 5-D (B, T, C, H, W); they are moved to the CPU, flattened
    to (B*T, C, H, W) and clamped to [0, 1] before torchmetrics'
    StructuralSimilarityIndexMeasure (data_range=1.0) is applied.
    '''
    y, y_pred = to_cpu_tensor(y, y_pred)
    b, t, c, h, w = y.shape
    frame_shape = (b*t, c, h, w)
    # clamp so that neither input carries values outside the metric's data range
    target = torch.clamp(y.reshape(frame_shape), 0, 1)
    pred = torch.clamp(y_pred.reshape(frame_shape), 0, 1)
    metric = torchmetrics.image.StructuralSimilarityIndexMeasure(data_range=1.0)
    return metric(pred, target)

def psnr(y_pred, y):
    '''
    Peak signal-to-noise ratio averaged over every frame.

    Both inputs are 5-D (B, T, C, H, W); they are moved to the CPU and
    flattened to (B*T, C, H, W). PSNR (data_range=1.0) is computed per frame
    and the per-frame scores are averaged.
    '''
    y, y_pred = to_cpu_tensor(y, y_pred)
    b, t, c, h, w = y.shape
    n_frames = b * t
    y = y.reshape((n_frames, c, h, w))
    y_pred = y_pred.reshape((n_frames, c, h, w))
    acc_score = 0
    for pred_frame, true_frame in zip(y_pred, y):
        frame_score = torchmetrics.image.PeakSignalNoiseRatio(data_range=1.0)(pred_frame, true_frame)
        acc_score = acc_score + frame_score / n_frames
    return acc_score

GLOBAL_LPIPS_OBJ = None # lazily created LPIPS model, shared by every lpips64 call
_GLOBAL_LPIPS_NET = None # backbone name the cached model was built with
def lpips64(y_pred, y, net='vgg'):
    '''
    LPIPS distance between prediction and target, evaluated at 64x64.

    The first two axes of both inputs are merged, the frames are resized to
    64x64 with bicubic interpolation, clamped to [0, 1] and mapped to [-1, 1]
    before being passed to the LPIPS model. Returns the mean over all frames.

    y_pred, y: tensors whose values are assumed to lie in [0, 1]
    net:       backbone name handed to lpips.LPIPS

    The LPIPS model is built on first use and cached at module level; it is
    rebuilt when a different `net` is requested and moved to the input's
    device on every call.
    '''
    global GLOBAL_LPIPS_OBJ, _GLOBAL_LPIPS_NET

    y = merge_leading_dims(y)
    y_pred = merge_leading_dims(y_pred)

    y = torch.nn.functional.interpolate(y, (64, 64), mode='bicubic').clamp(0,1)
    y_pred = torch.nn.functional.interpolate(y_pred, (64, 64), mode='bicubic').clamp(0,1)

    # convert the image range into [-1, 1], assuming the input range to be [0, 1]
    y = (2 * y - 1)
    y_pred = (2 * y_pred - 1)

    # rebuild when the requested backbone differs from the cached one, instead
    # of silently reusing the model created by the first call
    if GLOBAL_LPIPS_OBJ is None or _GLOBAL_LPIPS_NET != net:
        GLOBAL_LPIPS_OBJ = lp.LPIPS(net=net)
        _GLOBAL_LPIPS_NET = net
    # the cached model may live on another device than this call's inputs
    GLOBAL_LPIPS_OBJ = GLOBAL_LPIPS_OBJ.to(y.device)
    return GLOBAL_LPIPS_OBJ(y_pred, y).mean()

def tfpn(y_pred, y, threshold, radius=1):
    '''
    Binary confusion counts of prediction vs. target at the given threshold.

    The first two axes of both inputs are merged. When radius > 1 both are
    max-pooled with kernel size `radius` first. Values >= threshold become 1,
    everything else 0, and the counts come from torchmetrics' binary
    confusion matrix.

    Returns (tp, tn, fp, fn) as CPU tensors.
    '''
    target = merge_leading_dims(y)
    pred = merge_leading_dims(y_pred)
    with torch.no_grad():
        if radius > 1:
            pool = nn.MaxPool2d(radius)
            target, pred = pool(target), pool(pred)
        target = torch.where(target >= threshold, 1, 0)
        pred = torch.where(pred >= threshold, 1, 0)
        mat = torchmetrics.functional.confusion_matrix(pred, target, task='binary', threshold=threshold)
        # the matrix rows are unpacked as (tn, fp) and (fn, tp)
        (tn, fp), (fn, tp) = to_cpu_tensor(mat)
    return tp, tn, fp, fn

def tfpn_pool(y_pred, y, threshold, radius):
    '''
    Binary confusion counts after thresholding and max-pooling.

    The first two axes of both inputs are merged, values >= threshold become
    1.0 and the rest 0.0, then both maps are max-pooled with kernel size
    `radius` and stride radius//4 (or `radius` when that quotient is 0).
    The counts come from torchmetrics' binary confusion matrix.

    Returns (tp, tn, fp, fn) as CPU tensors.
    '''
    pred = merge_leading_dims(y_pred)
    target = merge_leading_dims(y)
    quarter = radius // 4
    stride = quarter if quarter > 0 else radius
    pool = nn.MaxPool2d(radius, stride=stride)
    with torch.no_grad():
        target = pool(torch.where(target >= threshold, 1, 0).float())
        pred = pool(torch.where(pred >= threshold, 1, 0).float())
        mat = torchmetrics.functional.confusion_matrix(pred, target, task='binary', threshold=threshold)
        # the matrix rows are unpacked as (tn, fp) and (fn, tp)
        (tn, fp), (fn, tp) = to_cpu_tensor(mat)
    return tp, tn, fp, fn

def csi(tp, tn, fp, fn):
    '''
    Critical Success Index: tp / (tp + fn + fp). The larger the better.
    Returns 0. when the denominator is (numerically) zero; tn is not used.
    '''
    denom = tp + fn + fp
    return 0. if denom < 1e-7 else tp / denom

def hss(tp, tn, fp, fn):
    '''
    Heidke Skill Score. Larger is better, 1 is a perfect score.
    Returns 0. when the denominator is zero.
    '''
    denom = (tp+fn)*(fn+tn) + (tp+fp)*(fp+tn)
    if denom == 0:
        return 0.
    numer = 2 * (tp*tn - fp*fn)
    return numer / denom

# =======================================================================
# Data Visualization
# =======================================================================

def torch_visualize(sequences, savedir=None, horizontal=10, vmin=0, vmax=1):
    '''
    Plot one or more frame sequences as a grid of grayscale images.

    sequences:  a list/dict of numpy/torch arrays with shape (B, T, C, H, W).
                C is assumed to be 1 and squeezed. If batch > 1, only the
                first sequence of each batch is drawn. Dict keys are used as
                row labels; list items are labelled 'Item 1', 'Item 2', ...
    savedir:    output file path; '.png' is appended when it is missing.
                When empty/None the figure is shown instead of saved.
    horizontal: number of frames per row; longer sequences wrap onto
                additional rows.
    vmin, vmax: value range passed to imshow.
    '''
    # First pass: collect labels and compute how many subplot rows are needed
    if isinstance(sequences, dict):
        display_texts = list(sequences.keys())
        sequences = list(sequences.values())
    else:
        sequences = list(sequences)
        display_texts = [f'Item {i+1}' for i in range(len(sequences))]
    vertical = sum(int(np.ceil(seq.shape[1] / horizontal)) for seq in sequences)

    converted = to_cpu_tensor(*sequences)
    # to_cpu_tensor returns the bare value for a single argument; re-wrap it
    # so the loop below iterates over sequences rather than over the batch axis
    if len(sequences) == 1:
        converted = [converted]

    # Plot the sequences
    # squeeze=False keeps `axes` 2-D even when there is only one row or column
    fig, axes = plt.subplots(vertical, horizontal, figsize=(2*horizontal, 2*vertical),
                             tight_layout=True, squeeze=False)
    plt.setp(axes, xticks=[], yticks=[])
    row = 0
    for label, sequence in zip(display_texts, converted):
        # only take the first batch entry; axis 0 is now the temporal dim
        frames = sequence[0]
        if frames.ndim == 4:
            # drop only the channel axis so a length-1 time axis survives
            frames = frames.squeeze(1) # (T, H, W)
        axes[row, 0].set_ylabel(label)
        for i, frame in enumerate(frames):
            axes[row + i // horizontal, i % horizontal].imshow(frame, vmin=vmin, vmax=vmax, cmap='gray')
        row += int(np.ceil(frames.shape[0] / horizontal))
    if savedir:
        # append the extension only when it is missing
        out_path = savedir if savedir.endswith('.png') else savedir + '.png'
        plt.savefig(out_path)
        plt.close()
    else:
        plt.show()

""" Visualize function with colorbar and a line seprate input and output """
def color_visualize(sequences, savedir='', horizontal=5, skip=1, ypos=0):
    '''
    Plot frame sequences as a colour-mapped grid with a colorbar and a dashed
    horizontal line.

    sequences:  a list/dict of numpy/torch arrays with shape (B, T, C, H, W).
                C is assumed to be 1 and squeezed. If batch > 1, only the
                first sequence of each batch is drawn. Dict keys are used as
                row labels; list items are labelled 'Item 1', 'Item 2', ...
    savedir:    output file path. When empty the figure is shown instead.
    horizontal: number of frames per row; longer sequences wrap onto
                additional rows.
    skip:       step used in the '$t+k$' labels of the last sequence.
    ypos:       figure-relative height of the dashed line; 0 selects
                1 - 1/len(sequences) - 0.017.

    Frames are multiplied by 255 before being mapped onto the VIL colour
    levels. The first sequence gets '$t-k$' labels on top and the last one
    '$t+k$' labels at the bottom. Applies the matplotlib styles 'science' and
    'no-latex' globally (presumably provided by the SciencePlots package --
    TODO confirm it is installed).
    '''
    plt.style.use(['science', 'no-latex'])
    VIL_COLORS = [[0, 0, 0],
                [0.30196078431372547, 0.30196078431372547, 0.30196078431372547],
                [0.1568627450980392, 0.7450980392156863, 0.1568627450980392],
                [0.09803921568627451, 0.5882352941176471, 0.09803921568627451],
                [0.0392156862745098, 0.4117647058823529, 0.0392156862745098],
                [0.0392156862745098, 0.29411764705882354, 0.0392156862745098],
                [0.9607843137254902, 0.9607843137254902, 0.0],
                [0.9294117647058824, 0.6745098039215687, 0.0],
                [0.9411764705882353, 0.43137254901960786, 0.0],
                [0.6274509803921569, 0.0, 0.0],
                [0.9058823529411765, 0.0, 1.0]]

    VIL_LEVELS = [0.0, 16.0, 31.0, 59.0, 74.0, 100.0, 133.0, 160.0, 181.0, 219.0, 255.0]

    # build the colormap and its norm once instead of once per frame
    cmap = ListedColormap(VIL_COLORS)
    norm = BoundaryNorm(VIL_LEVELS, cmap.N)

    # First pass: collect labels and compute how many subplot rows are needed
    if isinstance(sequences, dict):
        display_texts = list(sequences.keys())
        sequences = list(sequences.values())
    else:
        sequences = list(sequences)
        display_texts = [f'Item {i+1}' for i in range(len(sequences))]
    vertical = sum(int(np.ceil(seq.shape[1] / horizontal)) for seq in sequences)

    converted = to_cpu_tensor(*sequences)
    # to_cpu_tensor returns the bare value for a single argument; re-wrap it
    # so the loop below iterates over sequences rather than over the batch axis
    if len(sequences) == 1:
        converted = [converted]

    # Plot the sequences
    # squeeze=False keeps `axes` 2-D even when there is only one row or column
    fig, axes = plt.subplots(vertical, horizontal, figsize=(2*horizontal, 2*vertical),
                             tight_layout=True, squeeze=False)
    plt.subplots_adjust(hspace=0.0, wspace=0.0) # tight layout
    plt.setp(axes, xticks=[], yticks=[])
    row = 0
    im = None
    last = len(converted) - 1
    for k, sequence in enumerate(converted):
        # only take the first batch entry; axis 0 is now the temporal dim
        frames = sequence[0]
        if frames.ndim == 4:
            # drop only the channel axis so a length-1 time axis survives
            frames = frames.squeeze(1) # (T, H, W)
        n_frames = len(frames)

        ## =================
        # = labels of time =
        # labels follow the same row wrapping as the frames, so sequences
        # longer than `horizontal` do not index past the last column
        if k == 0:
            for i in range(n_frames):
                ax = axes[row + i // horizontal, i % horizontal]
                ax.set_xlabel(f'$t-{(n_frames-i)-1}$', fontsize=16)
                ax.xaxis.set_label_position('top')
        elif k == last:
            for i in range(n_frames):
                ax = axes[row + i // horizontal, i % horizontal]
                ax.set_xlabel(f'$t+{skip*i+1}$', fontsize=16)
                ax.xaxis.set_label_position('bottom')
        ## =================
        axes[row, 0].set_ylabel(display_texts[k], fontsize=16)
        for i, frame in enumerate(frames):
            im = axes[row + i // horizontal, i % horizontal].imshow(frame*255, cmap=cmap, norm=norm)
        row += int(np.ceil(frames.shape[0] / horizontal))

    ## = plot splitting line =
    if ypos == 0:
        ypos = 1 - 1 / len(converted) - 0.017
    fig.lines.append(Line2D((0, 1), (ypos, ypos), transform=fig.transFigure, ls='--', linewidth=2, color='#444'))
    # color bar (only possible once at least one frame has been drawn)
    if im is not None:
        cax = fig.add_axes([1, 0.05, 0.02, 0.5])
        fig.colorbar(im, cax=cax)
    ## =================
    if savedir:
        plt.savefig(savedir)
        plt.close()
    else:
        plt.show()