|
|
import torch |
|
|
|
|
|
import logging |
|
|
import os |
|
|
import sys |
|
|
import time |
|
|
|
|
|
from dataset import load_data_from_dir, LD3Dataset |
|
|
from trainer import LD3Trainer, ModelConfig, TrainingConfig |
|
|
from utils import ( |
|
|
create_desc, |
|
|
is_trained, |
|
|
get_solvers, |
|
|
parse_arguments, |
|
|
adjust_hyper, |
|
|
save_arguments_to_yaml, |
|
|
) |
|
|
from models import prepare_stuff |
|
|
|
|
|
|
|
|
def setup_logging(log_dir):
    """Configure root logging to write to stdout and ``<log_dir>/log.txt``.

    Handlers already attached to the root logger are removed and closed
    first, so ``logging.basicConfig`` always takes effect and repeated calls
    do not duplicate output. The ``logging`` module itself is left in place
    (it is not reloaded), so logger objects that other modules obtained
    before this call still propagate to the handlers installed here.

    Args:
        log_dir: Directory that receives ``log.txt``. It is created when it
            does not exist yet.

    Returns:
        The ``logging`` module, so callers can keep using ``logging.info``
        and friends on the returned object.
    """
    root_logger = logging.getLogger()

    # Drop whatever was configured before; basicConfig() does nothing when
    # the root logger already has handlers. Iterate over a copy because
    # removeHandler() mutates the list.
    for stale_handler in list(root_logger.handlers):
        root_logger.removeHandler(stale_handler)
        stale_handler.close()

    log_format = "%(asctime)s %(message)s"
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format=log_format,
        datefmt="%m/%d %I:%M:%S %p",
    )

    # FileHandler opens the file immediately, so the directory must exist.
    os.makedirs(log_dir, exist_ok=True)
    fh = logging.FileHandler(os.path.join(log_dir, "log.txt"))
    # No datefmt is passed here: the file keeps the default asctime layout,
    # which differs from the stdout layout configured above.
    fh.setFormatter(logging.Formatter(log_format))
    root_logger.addHandler(fh)
    return logging
|
|
|
|
|
|
|
|
def main(args):
    """Run one LD3 training job described by the parsed arguments.

    Steps, in order: force ``use_ema`` off, build the model pieces with
    ``prepare_stuff``, pass the arguments through ``adjust_hyper``, derive
    the run directory ``<args.log_path>/<create_desc(args)>``, and return
    early when ``is_trained`` reports that directory as done. Otherwise the
    arguments are saved to ``config.yml``, logging is set up, the data is
    loaded and split into train/validation datasets, and
    ``LD3Trainer.train`` is executed; the elapsed time is logged afterwards.

    Args:
        args: Parsed arguments object (see ``parse_arguments``). It is
            modified in place: ``use_ema`` is reset to ``False`` when set.

    Returns:
        None.
    """
    if args.use_ema:
        print("Auto update use_ema to False for training")
        args.use_ema = False

    prepared = prepare_stuff(args)
    # Only five of the eight returned items are used by the training script.
    (
        wrapped_model,
        _,
        decoding_fn,
        noise_schedule,
        latent_resolution,
        latent_channel,
        _,
        _,
    ) = prepared

    adjust_hyper(args, latent_resolution, latent_channel)
    log_dir = os.path.join(args.log_path, create_desc(args))

    # Guard clause: nothing to do when this configuration was trained before.
    if is_trained(log_dir):
        print("Skip training")
        return

    print("The model hasn't been trained yet. Perform training")
    os.makedirs(log_dir, exist_ok=True)
    save_arguments_to_yaml(args, os.path.join(log_dir, "config.yml"))
    setup_logging(log_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    solver, steps, solver_extra_params = get_solvers(
        args.solver_name,
        NFEs=args.steps,
        order=args.order,
        noise_schedule=noise_schedule,
        unipc_variant=args.unipc_variant,
    )

    sample_limit = args.num_train + args.num_valid
    latents, targets, conditions, unconditions = load_data_from_dir(
        data_folder=args.data_dir, limit=sample_limit
    )
    # Independent copies of the latents, handed to the dataset alongside
    # the loaded ones.
    ori_latents = [latent.clone() for latent in latents]

    def build_split(span):
        """Create an LD3Dataset from the same slice of every loaded list."""
        return LD3Dataset(
            ori_latents[span],
            latents[span],
            targets[span],
            conditions[span],
            unconditions[span],
        )

    train_dataset = build_split(slice(None, args.num_train))
    # Without validation samples the training set doubles as validation set.
    if args.num_valid > 0:
        valid_dataset = build_split(slice(args.num_train, None))
    else:
        valid_dataset = train_dataset

    # Options forwarded to TrainingConfig under the same name they have on
    # ``args``; keyword order matches the explicit call it replaces.
    same_named_options = (
        "lr_time_1",
        "lr_time_2",
        "shift_lr",
        "shift_lr_decay",
        "min_lr_time_1",
        "min_lr_time_2",
        "win_rate",
        "patient",
        "lr_time_decay",
        "momentum_time_1",
        "weight_decay_time_1",
        "loss_type",
        "visualize",
        "no_v1",
    )
    training_kwargs = {
        "train_data": train_dataset,
        "valid_data": valid_dataset,
        "train_batch_size": args.main_train_batch_size,
        "valid_batch_size": args.main_valid_batch_size,
    }
    training_kwargs.update(
        (option, getattr(args, option)) for option in same_named_options
    )
    training_kwargs["prior_timesteps"] = args.gits_ts
    training_kwargs["match_prior"] = args.match_prior
    training_config = TrainingConfig(**training_kwargs)

    model_config = ModelConfig(
        net=wrapped_model,
        decoding_fn=decoding_fn,
        noise_schedule=noise_schedule,
        solver=solver,
        solver_name=args.solver_name,
        order=args.order,
        steps=steps,
        prior_bound=args.prior_bound,
        resolution=latent_resolution,
        channels=latent_channel,
        time_mode=args.time_mode,
        solver_extra_params=solver_extra_params,
        snapshot_path=log_dir,
        device=device,
    )

    trainer = LD3Trainer(model_config, training_config)

    # Wall-clock timing of the complete training run.
    start = time.time()
    trainer.train(args.training_rounds_v1, args.training_rounds_v2)
    end = time.time()
    logging.info(f"Training time: {end - start}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command-line arguments and hand them
    # straight to the training routine.
    main(parse_arguments())
|
|
|