# LD3 / main.py
# Hugging Face Hub page header (not part of the program): vinhtong97,
# "Upload folder using huggingface_hub", commit d382778 (verified).
import torch
import logging
import os
import sys
import time
from dataset import load_data_from_dir, LD3Dataset
from trainer import LD3Trainer, ModelConfig, TrainingConfig
from utils import (
create_desc,
is_trained,
get_solvers,
parse_arguments,
adjust_hyper,
save_arguments_to_yaml,
)
from models import prepare_stuff
def setup_logging(log_dir):
    """
    Send root-logger output to stdout and to ``<log_dir>/log.txt``.

    Handlers already attached to the root logger are removed and closed
    first, so repeated calls do not duplicate output. The ``logging`` module
    is deliberately NOT reloaded: a reload builds a brand-new root logger,
    and every logger that other modules created beforehand would stay
    attached to the old one and silently lose its records.

    Args:
        log_dir: Directory that receives ``log.txt``; created if missing.

    Returns:
        The ``logging`` module itself.
    """
    root_logger = logging.getLogger()

    # Reset the logging configuration in place. ``basicConfig`` is a no-op
    # while the root logger still has handlers, so they must go first.
    for stale_handler in list(root_logger.handlers):
        root_logger.removeHandler(stale_handler)
        stale_handler.close()

    log_format = "%(asctime)s %(message)s"
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format=log_format,
        datefmt="%m/%d %I:%M:%S %p",
    )

    # FileHandler opens the file immediately and fails on a missing directory.
    os.makedirs(log_dir, exist_ok=True)
    # The file formatter intentionally keeps the default timestamp format
    # (no datefmt), as in the original configuration.
    fh = logging.FileHandler(os.path.join(log_dir, "log.txt"))
    fh.setFormatter(logging.Formatter(log_format))
    root_logger.addHandler(fh)
    return logging
def main(args):
    """
    Run one training job described by ``args``.

    Steps, in order:
      1. Force ``args.use_ema`` to ``False`` (a notice is printed if it was set).
      2. Build the model pieces with ``prepare_stuff`` and let ``adjust_hyper``
         update ``args`` from the latent resolution / channel count.
      3. Derive the run directory from ``args.log_path`` and ``create_desc``;
         return immediately if ``is_trained`` reports it as already trained.
      4. Create the run directory, dump the arguments to ``config.yml`` and
         configure logging there.
      5. Load ``num_train + num_valid`` samples, split them into a training
         and a validation dataset (the training set doubles as validation
         set when ``num_valid`` is not positive).
      6. Build the trainer and run it, logging the elapsed wall-clock time.

    Args:
        args: Parsed command-line namespace (see ``parse_arguments``).
            It is mutated in place (``use_ema`` and whatever ``adjust_hyper``
            changes).

    Returns:
        None.
    """
    # EMA is always switched off for training, whatever was requested.
    if args.use_ema:
        print("Auto update use_ema to False for training")
        args.use_ema = False

    (
        net,
        _,
        decode,
        schedule,
        latent_res,
        latent_ch,
        _,
        _,
    ) = prepare_stuff(args)
    adjust_hyper(args, latent_res, latent_ch)

    run_dir = os.path.join(args.log_path, create_desc(args))
    if is_trained(run_dir):
        print("Skip training")
        return
    print("The model hasn't been trained yet. Perform training")

    os.makedirs(run_dir, exist_ok=True)
    save_arguments_to_yaml(args, os.path.join(run_dir, "config.yml"))
    setup_logging(run_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    solver, steps, solver_extra_params = get_solvers(
        args.solver_name,
        NFEs=args.steps,
        order=args.order,
        noise_schedule=schedule,
        unipc_variant=args.unipc_variant,
    )

    n_train = args.num_train
    latents, targets, conditions, unconditions = load_data_from_dir(
        data_folder=args.data_dir, limit=n_train + args.num_valid
    )
    # Independent copies of the latents, taken before any dataset is built.
    pristine_latents = [sample.clone() for sample in latents]
    columns = (pristine_latents, latents, targets, conditions, unconditions)

    def build_split(span):
        # Apply the same slice to every column and wrap the result.
        return LD3Dataset(*(column[span] for column in columns))

    train_dataset = build_split(slice(None, n_train))
    if args.num_valid > 0:
        valid_dataset = build_split(slice(n_train, None))
    else:
        # No held-out samples requested: validate on the training data.
        valid_dataset = train_dataset

    training_config = TrainingConfig(
        train_data=train_dataset,
        valid_data=valid_dataset,
        train_batch_size=args.main_train_batch_size,
        valid_batch_size=args.main_valid_batch_size,
        lr_time_1=args.lr_time_1,
        lr_time_2=args.lr_time_2,
        shift_lr=args.shift_lr,
        shift_lr_decay=args.shift_lr_decay,
        min_lr_time_1=args.min_lr_time_1,
        min_lr_time_2=args.min_lr_time_2,
        win_rate=args.win_rate,
        patient=args.patient,
        lr_time_decay=args.lr_time_decay,
        momentum_time_1=args.momentum_time_1,
        weight_decay_time_1=args.weight_decay_time_1,
        loss_type=args.loss_type,
        visualize=args.visualize,
        no_v1=args.no_v1,
        prior_timesteps=args.gits_ts,
        match_prior=args.match_prior,
    )
    model_config = ModelConfig(
        net=net,
        decoding_fn=decode,
        noise_schedule=schedule,
        solver=solver,
        solver_name=args.solver_name,
        order=args.order,
        steps=steps,
        prior_bound=args.prior_bound,
        resolution=latent_res,
        channels=latent_ch,
        time_mode=args.time_mode,
        solver_extra_params=solver_extra_params,
        snapshot_path=run_dir,
        device=device,
    )

    trainer = LD3Trainer(model_config, training_config)

    start = time.time()
    trainer.train(args.training_rounds_v1, args.training_rounds_v2)
    end = time.time()
    logging.info(f"Training time: {end - start}")
if __name__ == "__main__":
    # Script entry point: parse the command-line options and pass the
    # resulting namespace straight to the training routine.
    main(parse_arguments())