File size: 5,129 Bytes
6021dd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
Sample Command
"""
import os, sys, logging, argparse
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F

from stldm import *
import utilspp as utpp
from data.config import SEVIR_13_12, HKO7_5_20, METEONET_5_20
from data.loader import GET_TestLoader
from data.dutils import resize

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Dataset related
    parser.add_argument('-d', '--dataset', type=str, default='', help='the dataset definition to be set')
    # model related
    parser.add_argument('-f', dest='checkpt', type=str, default='', help='model checkpoint to be loaded from (Empty = not loading)')
    parser.add_argument('-m', '--model', type=str, default='', help='the model definition to be created')
    parser.add_argument('--type', type=str, default='3D', help='Determine which kind of model to use, 2D or 3D')
    parser.add_argument('--c_str', type=float, default=0.0, help='CFG strength')
    parser.add_argument('--e_id', type=int, default=0, help='Ensemble ID')
    # hyperparams
    parser.add_argument('-s', '--step', type=int, default=-1, help='The number of steps to run. -1: the entire dataloader')
    parser.add_argument('-b', '--batch_size', type=int, default=16, help='The batch size')
    # logging related
    parser.add_argument('--print_every', type=int, default=100, help='The number of steps to log the training loss')
    parser.add_argument('-o', '--output', default=None, help='The path to save the log files')
    args = parser.parse_args()

    # Prepare logger. Default location: a 'logs' sub-directory next to the
    # checkpoint, named after the checkpoint file.
    if args.output is None:
        # os.path.split/splitext instead of str.split('/'): keeps absolute
        # paths intact (the old form dropped the leading '/') and is
        # independent of the checkpoint extension length.
        ckpt_dir, ckpt_file = os.path.split(args.checkpt)
        logfile_name = os.path.join(ckpt_dir, 'logs', f'{os.path.splitext(ckpt_file)[0]}.log')
    else:
        logfile_name = f'{args.output}.log'
    # FileHandler does not create missing directories; ensure the target exists.
    log_dir = os.path.dirname(logfile_name)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    logging.basicConfig(level=logging.NOTSET, handlers=[logging.FileHandler(logfile_name), logging.StreamHandler()], format='%(message)s')
    logging.info(f'Model checkpoint: {args.checkpt}')
    logging.info(f'Steps: {args.step}')

    # Samples are written two directory levels above the log file (i.e. next
    # to the checkpoint's parent); dirname() preserves absolute paths, unlike
    # the previous split('/')[:-2] approach.
    sampler_dir = os.path.join(os.path.dirname(os.path.dirname(logfile_name)), f'CFG={args.c_str}_samples')
    os.makedirs(sampler_dir, exist_ok=True)

    # Prepare Dataloader: dataset configs (e.g. SEVIR_13_12) are resolved by
    # name from the wildcard-imported globals.
    dataset_config = globals()[args.dataset]
    dataset_param, dataset_meta = dataset_config['param'], dataset_config['meta']
    loader = GET_TestLoader(dataset_meta, dataset_param, args.batch_size)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Prepare Model
    assert args.type in ['2D', '3D'], 'Please specify either 2D or 3D'
    model_config = globals()[args.model]
    model = n2n_setup[args.type](model_config, print_info=True, cfg_str=args.c_str if args.c_str != 0.0 else None).to(device)
    logging.info(f'CFG Scheduler: Const-{args.c_str}')

    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # checkpoints from trusted sources (consider weights_only=True).
    ckpt = torch.load(args.checkpt, map_location=device)
    # Accept both {'model': state_dict} wrappers and bare state_dicts.
    model.load_state_dict(ckpt['model'] if 'model' in ckpt else ckpt)
    model.eval()  # inference only; hoisted out of the sampling loop

    # Temporal lengths and spatial size expected by the model.
    in_len, out_len = model_config['vp_param']['shape_in'][0], model_config['vp_param']['shape_out'][0]
    img_size = model_config['vp_param']['shape_in'][-1]

    step = 0
    out = []
    # NOTE(review): with --step N this runs N+1 iterations (steps 0..N
    # inclusive); kept as-is to match the original behaviour -- confirm intent.
    while args.step < 0 or step <= args.step:
        # Fetch one (x, y) batch; each dataset family exposes a different API.
        if dataset_meta['dataset'] == 'HKO-7':
            setattr(args, 'seq_len', in_len)
            try:
                data = loader.sample(batch_size=args.batch_size)
            except Exception as e:
                # Sampler exhausted / failed: stop cleanly with what we have.
                logging.error(e)
                break
            x_seq, x_mask, dt_clip, _ = data
            x, y = utpp.hko7_preprocess(x_seq, x_mask, dt_clip, args)
        elif dataset_meta['dataset'] == 'SEVIR':
            data = loader.sample(batch_size=args.batch_size)
            if data is None:
                break
            x, y = data['vil'][:, :in_len], data['vil'][:, in_len:]
        elif dataset_meta['dataset'].startswith('meteo'):
            try:
                x, y = next(loader)
            except Exception as e:
                logging.error(e)
                break
        else:
            # Previously an unknown dataset fell through to an opaque
            # NameError on x/y; fail fast with a clear message instead.
            raise ValueError(f"Unsupported dataset: {dataset_meta['dataset']}")

        x, y = x.to(device), y.to(device)

        with torch.no_grad():
            if x.shape[-1] != img_size:
                x = resize(x, img_size)
                y = resize(y, img_size) # TO compare with DiffCast paper
            if model_config['pre'] is not None:
                x = model_config['pre'](x)

            y_pred = model(x)

            if model_config['post'] is not None:
                x = model_config['post'](x)
                y_pred = model_config['post'](y_pred)
            y_pred = y_pred.clamp(0, 1)

            # Accumulate on CPU to keep GPU memory bounded.
            out.append(y_pred.detach().cpu())

        step += 1
        # log/print every
        if step == 1 or step % args.print_every == 0:
            logging.info(f'{step} Steps Generated, {len(out)} in out array')

    logging.info(f'{step} Steps Generated, {len(out)} in out array')
    if not out:
        # torch.cat raises on an empty list; exit with a clear diagnostic.
        logging.error('No batches were generated; nothing to save.')
        sys.exit(1)
    out = torch.cat(out, dim=0)
    out = out.numpy()
    save_path = os.path.join(sampler_dir, f'BTCHW_total-no:{len(out)}_e={args.e_id}.npy')
    np.save(save_path, out)
    print('Output saved in', save_path)