# nowcasting/models/deconvolution.py
import os
import logging
import random

import mxnet as mx
import numpy as np

from nowcasting.config import cfg, cfg_from_file
from nowcasting.utils import *  # TODO use explicit imports (parse_ctx is used below)
from nowcasting.utils import load_params
from nowcasting.my_module import MyModule
from nowcasting.models.deconvolution_symbol import discriminator_symbol, generator_symbol
from nowcasting.hko_factory import HKONowcastingFactory
### Losses
def construct_l2_loss(gt, pred, normalize_gt=False):
    """Construct the symbol of the L2 loss.

    Args:
        gt: ground truth variable
        pred: prediction (or real data during training) variable
        normalize_gt: if True, divide gt by 255.0
    """
    if normalize_gt:
        gt = gt / 255.0
    if cfg.DATASET == "MOVINGMNIST":
        return mx.sym.MakeLoss(
            mx.sym.mean(mx.sym.square(gt - pred)),
            grad_scale=cfg.MODEL.L2_LAMBDA,
            name="mse")
    elif cfg.DATASET == "HKO":
        factory = HKONowcastingFactory(
            batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
            in_seq_len=cfg.HKO.BENCHMARK.IN_LEN,
            out_seq_len=cfg.HKO.BENCHMARK.OUT_LEN)
        return factory.loss_sym(pred=pred, target=gt)
    else:
        raise NotImplementedError("Unsupported dataset: {}".format(cfg.DATASET))
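

# A minimal usage sketch (illustrative, assuming cfg.DATASET == "MOVINGMNIST"):
# the loss is built from free symbol variables and only bound to concrete
# shapes later, in construct_modules below.
def _example_l2_loss_symbol():
    gt = mx.sym.Variable('gt')
    pred = mx.sym.Variable('pred')
    return construct_l2_loss(gt, pred, normalize_gt=False)
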
### Modules
def construct_modules(args):
    """Construct the bound modules for training or testing.

    Testing mode is read from cfg.MODEL.TESTING (not from args); it only
    changes the momentum passed to the symbol constructors.

    Returns:
        [generator_net, loss_net]
    """
### Symbol construction
context = mx.sym.Variable('context')
gt = mx.sym.Variable('gt')
pred = mx.sym.Variable('pred')
if cfg.MODEL.TESTING:
sym_g = generator_symbol(context, momentum=1)
sym_d = discriminator_symbol(context, pred, momentum=1)
else:
sym_g = generator_symbol(context)
sym_d = discriminator_symbol(context, pred)
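    # NOTE: sym_d is constructed for symmetry, but no discriminator module is
    # bound in this file; only the generator and the L2 loss are used below.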
sym_l2_loss = construct_l2_loss(gt, pred)
### Module construction
modules = []
module_names = []
generator_net = MyModule(
sym_g, data_names=('context', ), label_names=None, context=args.ctx)
modules.append(generator_net)
module_names.append("generator")
loss_data_names = ['gt', 'pred']
if cfg.DATASET == "HKO":
loss_data_names.append('mask')
loss_net = MyModule(
mx.sym.Group([
sym_l2_loss, mx.sym.BlockGrad(
mx.sym.mean(
mx.sym.square(mx.sym.clip(pred, a_min=0, a_max=1) - gt)),
name="real_mse")
]),
data_names=loss_data_names,
label_names=None,
context=args.ctx)
modules.append(loss_net)
module_names.append("loss")
### Module binding
# Bind generator
if cfg.DATASET == "MOVINGMNIST":
IN_LEN = cfg.MOVINGMNIST.IN_LEN
OUT_LEN = cfg.MOVINGMNIST.OUT_LEN
IMG_SIZE = cfg.MOVINGMNIST.IMG_SIZE
elif cfg.DATASET == "HKO":
IN_LEN = cfg.HKO.BENCHMARK.IN_LEN
OUT_LEN = cfg.HKO.BENCHMARK.OUT_LEN
IMG_SIZE = cfg.HKO.ITERATOR.WIDTH
data_shapes = {
'context':
mx.io.DataDesc(
name='context',
shape=(cfg.MODEL.TRAIN.BATCH_SIZE, 1, IN_LEN, IMG_SIZE, IMG_SIZE),
layout="NCDHW"),
'gt':
mx.io.DataDesc(
name='gt',
shape=(cfg.MODEL.TRAIN.BATCH_SIZE, 1, OUT_LEN, IMG_SIZE, IMG_SIZE),
layout="NCDHW"),
'pred':
mx.io.DataDesc(
name='pred',
shape=(cfg.MODEL.TRAIN.BATCH_SIZE, 1, OUT_LEN, IMG_SIZE, IMG_SIZE),
layout="NCDHW")
}
if cfg.DATASET == "HKO":
data_shapes["mask"] = mx.io.DataDesc(
name='mask',
shape=(cfg.MODEL.TRAIN.BATCH_SIZE, 1, OUT_LEN, IMG_SIZE, IMG_SIZE),
layout="NCDHW")
label_shapes = {
'label':
mx.io.DataDesc(name='label', shape=(cfg.MODEL.TRAIN.BATCH_SIZE, 1))
}
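    # Both modules are created with label_names=None, so `ls` below is
    # expected to resolve to None; label_shapes is kept only for symmetry.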
init = mx.init.Xavier(rnd_type="gaussian", magnitude=1)
for m, name in zip(modules, module_names):
        ds = [data_shapes[dname] for dname in m.data_names]
        ls = [label_shapes[lname] for lname in m.label_names]
if len(ls) == 0:
ls = None
m.bind(data_shapes=ds, label_shapes=ls, inputs_need_grad=True)
if not cfg.MODEL.RESUME or name not in ["generator", "gan"]:
# Only "generator" and "gan" support being restored.
# All other modules are freshly initialized, even if RESUME == True.
m.init_params(initializer=init)
else:
logging.info("Loading parameters of {} from {}, Iter = {}".format(
name, os.path.realpath(
cfg.MODEL.LOAD_DIR), cfg.MODEL.LOAD_ITER))
arg_params, aux_params = load_params(
prefix=os.path.join(cfg.MODEL.LOAD_DIR, name),
epoch=cfg.MODEL.LOAD_ITER)
m.init_params(
arg_params=arg_params,
aux_params=aux_params,
allow_missing=False,
force_init=True)
logging.info("Loading complete!")
lr_scheduler = mx.lr_scheduler.FactorScheduler(
step=cfg.MODEL.TRAIN.LR_DECAY_ITER,
factor=cfg.MODEL.TRAIN.LR_DECAY_FACTOR,
stop_factor_lr=cfg.MODEL.TRAIN.MIN_LR)
        # Pick hyper-parameters from the online fine-tuning config during
        # testing with fine-tuning enabled, otherwise from the training config
        if cfg.MODEL.TESTING and cfg.MODEL.TEST.FINETUNE:
            opt_cfg = cfg.MODEL.TEST.ONLINE
            scheduler = None  # the learning rate stays fixed while fine-tuning
        else:
            opt_cfg = cfg.MODEL.TRAIN
            scheduler = lr_scheduler
        optimizer_name = opt_cfg.OPTIMIZER
        optimizer_params = {
            'learning_rate': opt_cfg.LR,
            'rescale_grad': 1.0,
            'lr_scheduler': scheduler,
            'wd': opt_cfg.WD
        }
        if optimizer_name == "adam":
            optimizer_params['epsilon'] = opt_cfg.EPS
            optimizer_params['beta1'] = opt_cfg.BETA1
        elif optimizer_name == "rmsprop":
            optimizer_params['epsilon'] = opt_cfg.EPS
            optimizer_params['gamma1'] = opt_cfg.GAMMA1
        elif optimizer_name == "adagrad":
            pass
        else:
            raise NotImplementedError(
                "Unsupported optimizer: {}".format(optimizer_name))
        m.init_optimizer(
            optimizer=optimizer_name, optimizer_params=optimizer_params)
m.summary()
return modules
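

# Note: both returned modules come back already bound, initialized (or
# restored, when cfg.MODEL.RESUME is set) and with optimizers attached. The
# generator consumes a single 'context' input in NCDHW layout; the loss module
# consumes 'gt' and 'pred' (plus 'mask' for HKO), matching the DataDescs above.
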
### Arguments
def mode_args(parser):
group = parser.add_argument_group('Mode',
'Run in training or testing mode.')
group.add_argument(
'--test',
help='Run testing code. Implies --resume.',
action='store_true')
group.add_argument(
'--cfg',
dest='cfg_file',
help='Optional configuration file. '
'Given command line options will override defaults set in this configuration file.',
type=str)
group.add_argument('--save_dir', help='The saving directory', type=str)
group.add_argument(
'--resume',
help='Continue to train the previous model. This is implied by --test.',
action='store_true',
default=False)
group.add_argument(
'--load_dir',
help='Load model parameters from load_dir to continue training the previous model. '
'Only honoured if --resume is specified.',
type=str)
group.add_argument(
'--load_iter',
help='Load model parameters from specified iteration.',
type=int)
group.add_argument(
'--saving_postfix',
help='The postfix of the saving directory',
type=str)
group.add_argument(
'--ctx',
dest='ctx',
help='Running Context. E.g `--ctx gpu` or `--ctx gpu0,gpu1` or `--ctx cpu`',
type=str,
default='gpu')
def parse_mode_args(args):
args.ctx = parse_ctx(args.ctx)
if args.cfg_file:
cfg_from_file(args.cfg_file, target=cfg)
# Parameter loading
if args.test or cfg.MODEL.TESTING:
cfg.MODEL.TESTING = True
args.resume = True
if args.resume:
cfg.MODEL.RESUME = True
    if args.load_dir:
        cfg.MODEL.LOAD_DIR = args.load_dir
    if args.load_iter is not None:
        cfg.MODEL.LOAD_ITER = args.load_iter
def training_args(parser):
group = parser.add_argument_group('Training',
'Configure training/testing process.')
group.add_argument(
'--seed',
help="Initialize mxnet and numpy random state with this seed.",
type=int)
    group.add_argument(
        '--batch_size',
        dest='batch_size',
        help="Batch size of the training process",
        type=int)
group.add_argument('--lr', dest='lr', help='learning rate', type=float)
group.add_argument('--wd', dest='wd', help='weight decay', type=float)
group.add_argument(
'--grad_clip',
dest='grad_clip',
help='gradient clipping threshold',
type=float)
group.add_argument(
'--optimizer', dest='optimizer', help='optimizer to use', type=str)
    group.add_argument(
        '--l2_lambda',
        dest='l2_lambda',
        help="Weight λ_l2 of the L2 term in: GAN_loss * λ_gan + L2_loss * λ_l2",
        type=float)
    group.add_argument(
        '--gan_lambda',
        dest='gan_lambda',
        help="Weight λ_gan of the GAN term in: GAN_loss * λ_gan + L2_loss * λ_l2",
        type=float)
    group.add_argument(
        '--original_gan_loss',
        dest='use_original_gan_loss',
        help="Use the original GAN loss formulation",
        action="store_true")
group.add_argument(
'--label_smoothing_alpha',
dest='label_smoothing_alpha',
help="Change one sided label smoothing α",
type=float)
group.add_argument(
'--label_smoothing_beta',
dest='label_smoothing_beta',
help="Change two sided label smoothing β",
type=float)
def parse_training_args(args):
if args.batch_size:
cfg.MODEL.TRAIN.BATCH_SIZE = args.batch_size
if args.lr:
cfg.MODEL.TRAIN.LR = args.lr
if args.wd:
cfg.MODEL.TRAIN.WD = args.wd
if args.grad_clip:
cfg.MODEL.TRAIN.GRAD_CLIP = args.grad_clip
if args.optimizer:
cfg.MODEL.TRAIN.OPTIMIZER = args.optimizer
if args.l2_lambda:
cfg.MODEL.L2_LAMBDA = args.l2_lambda
    if args.seed is not None:
        cfg.SEED = args.seed
    if cfg.SEED is not None:
        logging.info("Fixing random seed to {}".format(cfg.SEED))
        random.seed(cfg.SEED)
        mx.random.seed(cfg.SEED)
        np.random.seed(cfg.SEED)
def model_args(parser):
    group = parser.add_argument_group('Model',
                                      'Configure the model architecture.')
group.add_argument(
'--use_2d',
dest='use_2d',
help="Use 2D convolutions / deconvolutions with same number of parameters as 3D model",
action="store_true")
    group.add_argument(
        '--encoder',
        dest='encoder',
        help="'concat', 'shared' or 'separate'. The way to encode context frames."
    )
group.add_argument(
'--no_bn',
dest='bn',
help="Disable batch norm everywhere.",
action="store_false")
group.add_argument(
'--num_filter',
dest='num_filter',
help="Set the base number of filters.",
type=int)
def parse_model_args(args):
if args.use_2d:
cfg.MODEL.DECONVBASELINE.USE_3D = not args.use_2d
if args.encoder:
assert args.encoder in ["concat", "shared", "separate"]
cfg.MODEL.DECONVBASELINE.ENCODER = args.encoder
    if not args.bn:
        # --no_bn stores False; only override the config default when the
        # flag is actually given on the command line.
        cfg.MODEL.DECONVBASELINE.BN = False
if args.num_filter:
cfg.MODEL.DECONVBASELINE.BASE_NUM_FILTER = args.num_filter
def get_base_dir(args):
    if args.save_dir:
        return args.save_dir
    return "conv3d" if cfg.MODEL.DECONVBASELINE.USE_3D else "conv2d"
### Training
def train_step(generator_net,
               loss_net,
               context_nd,
               gt_nd,
               mask_nd=None):
    """Run one optimization step of the generator.

    Args:
        generator_net: module mapping context frames to predicted frames
        loss_net: module computing the loss and its gradient w.r.t. pred
        context_nd: NDArray of input context frames
        gt_nd: NDArray of ground-truth future frames
        mask_nd: optional NDArray mask (used for the HKO dataset)

    Returns:
        (forecast_target, avg_l2, avg_real_mse, generator_grad_norm)
    """
# Forward generator
generator_net.forward(
is_train=True, data_batch=mx.io.DataBatch(data=[context_nd]))
generator_outputs = dict(
zip(generator_net.output_names, generator_net.get_outputs()))
pred_nd = generator_outputs["pred_output"]
    # Forward-backward through the loss to obtain the gradient w.r.t. pred
loss_net.forward_backward(data_batch=mx.io.DataBatch(
data=[gt_nd, pred_nd]
if mask_nd is None else [gt_nd, pred_nd, mask_nd]))
loss_input_grads = dict(
zip(loss_net.data_names, loss_net.get_input_grads()))
pred_grad = loss_input_grads["pred"]
loss_out = dict(zip(loss_net.output_names, loss_net.get_outputs()))
avg_l2 = float(loss_out["mse_output"].asnumpy())
avg_real_mse = float(loss_out["real_mse_output"].asnumpy())
# Backward generator
generator_net.backward(out_grads=[pred_grad])
    # Clip gradients by global norm, then update the generator parameters
    generator_grad_norm = generator_net.clip_by_global_norm(
        max_norm=cfg.MODEL.TRAIN.GRAD_CLIP)
    generator_net.update()
    return generator_outputs["forecast_target_output"],\
           avg_l2, avg_real_mse, generator_grad_norm
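

# A hedged sketch of one training iteration (the numpy batch source is
# hypothetical; shapes follow the NCDHW DataDescs declared in
# construct_modules, and args.ctx is assumed to be parsed already):
def _example_train_iteration(args, context_np, gt_np):
    generator_net, loss_net = construct_modules(args)
    context_nd = mx.nd.array(context_np)  # (batch, 1, IN_LEN, H, W)
    gt_nd = mx.nd.array(gt_np)  # (batch, 1, OUT_LEN, H, W)
    pred, avg_l2, avg_real_mse, grad_norm = train_step(
        generator_net, loss_net, context_nd, gt_nd)
    logging.info("l2=%g real_mse=%g grad_norm=%g",
                 avg_l2, avg_real_mse, grad_norm)
    return pred
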
### Testing
def test_step(generator_net, context_nd):
    """Autoregressively generate frames with the trained generator.

    Returns:
        When cfg.MOVINGMNIST.OUT_LEN == 1, a numpy array of shape
        (cfg.MODEL.TRAIN.BATCH_SIZE, cfg.MOVINGMNIST.TESTING_LEN, 1,
        cfg.MOVINGMNIST.IMG_SIZE, cfg.MOVINGMNIST.IMG_SIZE).
        Otherwise, the raw prediction NDArray produced by the generator.
    """
if cfg.DATASET != "MOVINGMNIST":
raise NotImplementedError
if cfg.MOVINGMNIST.OUT_LEN == 1:
frames = np.empty(
shape=(cfg.MOVINGMNIST.TESTING_LEN, cfg.MODEL.TRAIN.BATCH_SIZE, 1,
cfg.MOVINGMNIST.IMG_SIZE, cfg.MOVINGMNIST.IMG_SIZE))
for frame_num in range(cfg.MOVINGMNIST.TESTING_LEN):
# Generate 1 frame
generator_net.forward(
data_batch=mx.io.DataBatch(data=[context_nd]), is_train=False)
generator_outputs = dict(
zip(generator_net.output_names, generator_net.get_outputs()))
pred_nd = generator_outputs["pred_output"]
pred_np = pred_nd.asnumpy()
            # Shift the context window one frame ahead and append the newly
            # generated frame as the new last context frame
            context_np = context_nd.asnumpy()
            context_np = np.roll(a=context_np, shift=-1, axis=2)
            context_np[:, :, -1] = pred_np[:, :, -1]
            context_nd = mx.nd.array(context_np, ctx=context_nd.context)
            # Store the generated frame
            frames[frame_num] = pred_np[:, :, -1]
return np.moveaxis(frames, 0, 1)
else:
generator_net.forward(
data_batch=mx.io.DataBatch(data=[context_nd]), is_train=False)
generator_outputs = dict(
zip(generator_net.output_names, generator_net.get_outputs()))
pred_nd = generator_outputs["pred_output"]
return pred_nd
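

# A hedged sketch of evaluation on MovingMNIST (the numpy batch source is
# hypothetical; args.ctx is assumed to be parsed already):
def _example_test(args, context_np):
    generator_net, _ = construct_modules(args)
    context_nd = mx.nd.array(context_np)  # (batch, 1, IN_LEN, H, W)
    frames = test_step(generator_net, context_nd)
    return frames  # (batch, TESTING_LEN, 1, H, W) when OUT_LEN == 1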