from oldm.hack import disable_verbosity
disable_verbosity()

import os
import sys
import torch
from datetime import datetime

# Make the repository root importable when this script is run from its subdirectory.
file_path = os.path.abspath(__file__)
parent_dir = os.path.abspath(os.path.dirname(file_path) + '/..')
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from oldm.logger import ImageLogger
from oldm.model import create_model, load_state_dict
from dataset.utils import return_dataset

from oft import inject_trainable_oft, inject_trainable_oft_conv, inject_trainable_oft_extended, inject_trainable_oft_with_norm
from hra import inject_trainable_hra
from lora import inject_trainable_lora

import argparse

parser = argparse.ArgumentParser()

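# HRA adapter hyperparameters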
parser.add_argument('--hra_r', type=int, default=8)
parser.add_argument('--hra_apply_GS', action='store_true', default=False)
parser.add_argument('--hra_lamb', type=float, default=0.0)

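# OFT adapter hyperparameters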
parser.add_argument('--oft_r', type=int, default=4)
parser.add_argument('--oft_eps', type=float, default=7e-6)
parser.add_argument('--oft_coft', action="store_true", default=True)
parser.add_argument('--oft_block_share', action="store_true", default=False)

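# LoRA adapter hyperparameters; the default rank of 4 is an assumption (it mirrors --oft_r)
# and is consumed by the 'lora' branches below as args.lora_r.
parser.add_argument('--lora_r', type=int, default=4)

# General training and data options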
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--num_samples', type=int, default=1)
parser.add_argument('--plot_frequency', type=int, default=100)
parser.add_argument('--learning_rate', type=float, default=9e-4)
parser.add_argument('--sd_locked', action="store_true", default=True)
parser.add_argument('--only_mid_control', action="store_true", default=False)
parser.add_argument('--num_gpus', type=int, default=torch.cuda.device_count())
parser.add_argument('--resume_path',
                    type=str,
                    default='./models/hra_half_init_l_8.ckpt',
                    )
parser.add_argument('--time_str', type=str, default=datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f"))
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--control',
                    type=str,
                    help='control signal. Options are [segm, sketch, densepose, depth, canny, landmark]',
                    default="segm")

args = parser.parse_args()

if __name__ == "__main__":

    control = args.control

    train_dataset, val_dataset, data_name, logger_freq, max_epochs = return_dataset(control)

    resume_path = args.resume_path

    batch_size = args.batch_size
    num_samples = args.num_samples
    plot_frequency = args.plot_frequency
    learning_rate = args.learning_rate
    sd_locked = args.sd_locked
    only_mid_control = args.only_mid_control
    num_gpus = args.num_gpus
    time_str = args.time_str
    num_workers = args.num_workers

    for arg in vars(args):
        print(f'{arg}: {getattr(args, arg)}')
    print(f'data_name: {data_name}\nlogger_freq: {logger_freq}\nmax_epochs: {max_epochs}')

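    # Name the run after the adapter type (inferred from the checkpoint path) and its key hyperparameters.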
    if 'oft' in args.resume_path:
        experiment = 'oft_{}_{}_eps_{}_pe_diff_mlp_r_{}_{}gpu_{}'.format(data_name, control, args.oft_eps, args.oft_r, num_gpus, time_str)
    elif 'hra' in args.resume_path:
        if args.hra_apply_GS:
            experiment = 'hra_apply_GS_{}_{}_pe_diff_mlp_r_{}_{}gpu_{}'.format(data_name, control, args.hra_r, num_gpus, time_str)
        else:
            experiment = 'hra_{}_{}_pe_diff_mlp_r_{}_lambda_{}_lr_{}_{}gpu_{}'.format(data_name, control, args.hra_r, args.hra_lamb, args.learning_rate, num_gpus, time_str)
    elif 'lora' in args.resume_path:
        experiment = 'lora_{}_{}_pe_diff_mlp_r_{}_{}gpu_{}'.format(data_name, control, args.lora_r, num_gpus, time_str)

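    # Build the model on CPU and freeze the base diffusion model weights.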
    model = create_model('./configs/oft_ldm_v15.yaml').cpu()
    model.model.requires_grad_(False)
    print(f'Total parameters not requiring grad: {sum(p.numel() for p in model.model.parameters() if not p.requires_grad)}')

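    # Inject the selected trainable adapter (OFT / HRA / LoRA) into the frozen model.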
    if 'oft' in args.resume_path:
        unet_lora_params, train_names = inject_trainable_oft(model.model, r=args.oft_r, eps=args.oft_eps, is_coft=args.oft_coft, block_share=args.oft_block_share)
    elif 'hra' in args.resume_path:
        unet_lora_params, train_names = inject_trainable_hra(model.model, r=args.hra_r, apply_GS=args.hra_apply_GS)
    elif 'lora' in args.resume_path:
        unet_lora_params, train_names = inject_trainable_lora(model.model, rank=args.lora_r, network_alpha=None)

    print(f'Total parameters requiring grad: {sum(p.numel() for p in model.model.parameters() if p.requires_grad)}')

    model.load_state_dict(load_state_dict(resume_path, location='cpu'))
    model.learning_rate = learning_rate
    model.sd_locked = sd_locked
    model.only_mid_control = only_mid_control

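    # Save a checkpoint at the end of every epoch, plus 'last.ckpt' for resuming.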
    checkpoint_callback = ModelCheckpoint(
        dirpath='log/image_log_' + experiment,
        filename='model-{epoch:02d}',
        save_top_k=-1,
        save_last=True,
        every_n_epochs=1,
        monitor=None,
    )

    train_dataloader = DataLoader(train_dataset, num_workers=num_workers, batch_size=batch_size, shuffle=False)
    val_dataloader = DataLoader(val_dataset, num_workers=num_workers, batch_size=1, shuffle=False)

    logger = ImageLogger(
        val_dataloader=val_dataloader,
        batch_frequency=logger_freq,
        experiment=experiment,
        plot_frequency=plot_frequency,
        num_samples=num_samples,
    )

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        gpus=num_gpus,
        precision=32,
        callbacks=[logger, checkpoint_callback],
    )

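    # Resume from this experiment's last checkpoint if one exists; otherwise start training from scratch.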
    last_model_path = 'log/image_log_' + experiment + '/last.ckpt'
    if os.path.exists(last_model_path):
        trainer.fit(model, train_dataloader, ckpt_path=last_model_path)
    else:
        trainer.fit(model, train_dataloader)