from oldm.hack import disable_verbosity
disable_verbosity()
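
# Fine-tunes a control-conditioned latent diffusion model with a parameter-efficient
# adapter (OFT, HRA, or LoRA); the adapter type is inferred from the --resume_path name.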
import os
import sys
import torch
from datetime import datetime
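
# Make the repository root importable so the sibling packages
# (oldm, dataset, oft, hra, lora) resolve when this script is run directly.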
file_path = os.path.abspath(__file__)
parent_dir = os.path.abspath(os.path.dirname(file_path) + '/..')
if parent_dir not in sys.path:
    sys.path.append(parent_dir)
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from oldm.logger import ImageLogger
from oldm.model import create_model, load_state_dict
from dataset.utils import return_dataset
from oft import inject_trainable_oft, inject_trainable_oft_conv, inject_trainable_oft_extended, inject_trainable_oft_with_norm
from hra import inject_trainable_hra
from lora import inject_trainable_lora
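# Adapter injection utilities: OFT (block-diagonal orthogonal matrices), HRA
# (Householder reflections), LoRA (low-rank factors). Each inject_trainable_*
# helper returns (trainable_params, trained_layer_names).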
import argparse
parser = argparse.ArgumentParser()
# HRA (Householder Reflection Adaptation)
parser.add_argument('--hra_r', type=int, default=8)  # number of Householder reflections
parser.add_argument('--hra_apply_GS', action='store_true', default=False)  # apply Gram-Schmidt orthogonalization
parser.add_argument('--hra_lamb', type=float, default=0.0)  # orthogonality regularization strength
# OFT (Orthogonal Finetuning)
parser.add_argument('--oft_r', type=int, default=4)  # number of diagonal blocks
parser.add_argument('--oft_eps', type=float, default=7e-6)  # constraint radius for COFT
parser.add_argument('--oft_coft', action="store_true", default=True)  # NOTE: store_true with default=True means constrained OFT is always on
parser.add_argument('--oft_block_share', action="store_true", default=False)
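# LoRA (low-rank adaptation). The original script referenced an undefined
# `args.r` in the LoRA branches; a rank flag is assumed here (default mirrors --oft_r).
parser.add_argument('--lora_r', type=int, default=4)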
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--num_samples', type=int, default=1)
parser.add_argument('--plot_frequency', type=int, default=100)
parser.add_argument('--learning_rate', type=float, default=9e-4)
parser.add_argument('--sd_locked', action="store_true", default=True)  # NOTE: store_true with default=True means this flag is always on
parser.add_argument('--only_mid_control', action="store_true", default=False)
parser.add_argument('--num_gpus', type=int, default=torch.cuda.device_count())
parser.add_argument('--resume_path',
                    type=str,
                    default='./models/hra_half_init_l_8.ckpt')
parser.add_argument('--time_str', type=str, default=datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f"))
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--control',
                    type=str,
                    help='control signal. Options are [segm, sketch, densepose, depth, canny, landmark]',
                    default="segm")
args = parser.parse_args()
if __name__ == "__main__":
    # specify the control signal and dataset
    control = args.control

    # create dataset
    train_dataset, val_dataset, data_name, logger_freq, max_epochs = return_dataset(control)
    # Configs
    resume_path = args.resume_path
    batch_size = args.batch_size
    num_samples = args.num_samples
    plot_frequency = args.plot_frequency
    learning_rate = args.learning_rate
    sd_locked = args.sd_locked
    only_mid_control = args.only_mid_control
    num_gpus = args.num_gpus
    time_str = args.time_str
    num_workers = args.num_workers
    for arg in vars(args):
        print(f'{arg}: {getattr(args, arg)}')
    print(f'data_name: {data_name}\nlogger_freq: {logger_freq}\nmax_epochs: {max_epochs}')
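
    # NOTE: the adapter type is chosen by substring matching on --resume_path,
    # both for the experiment name here and for the injection step below.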
    if 'oft' in args.resume_path:
        experiment = 'oft_{}_{}_eps_{}_pe_diff_mlp_r_{}_{}gpu_{}'.format(data_name, control, args.oft_eps, args.oft_r, num_gpus, time_str)
    elif 'hra' in args.resume_path:
        if args.hra_apply_GS:
            experiment = 'hra_apply_GS_{}_{}_pe_diff_mlp_r_{}_{}gpu_{}'.format(data_name, control, args.hra_r, num_gpus, time_str)
        else:
            experiment = 'hra_{}_{}_pe_diff_mlp_r_{}_lambda_{}_lr_{}_{}gpu_{}'.format(data_name, control, args.hra_r, args.hra_lamb, args.learning_rate, num_gpus, time_str)
    elif 'lora' in args.resume_path:
        experiment = 'lora_{}_{}_pe_diff_mlp_r_{}_{}gpu_{}'.format(data_name, control, args.lora_r, num_gpus, time_str)
    else:
        raise ValueError(f'Cannot infer adapter type (oft/hra/lora) from resume_path: {args.resume_path}')
    # Load the model on CPU first; PyTorch Lightning will move it to the GPUs automatically.
    model = create_model('./configs/oft_ldm_v15.yaml').cpu()
    model.model.requires_grad_(False)
    print(f'Total parameters not requiring grad: {sum(p.numel() for p in model.model.parameters() if not p.requires_grad)}')
    # inject trainable adapter parameters into the frozen UNet
    if 'oft' in args.resume_path:
        unet_lora_params, train_names = inject_trainable_oft(model.model, r=args.oft_r, eps=args.oft_eps, is_coft=args.oft_coft, block_share=args.oft_block_share)
    elif 'hra' in args.resume_path:
        unet_lora_params, train_names = inject_trainable_hra(model.model, r=args.hra_r, apply_GS=args.hra_apply_GS)
    elif 'lora' in args.resume_path:
        unet_lora_params, train_names = inject_trainable_lora(model.model, rank=args.lora_r, network_alpha=None)
    print(f'Total parameters requiring grad: {sum(p.numel() for p in model.model.parameters() if p.requires_grad)}')
    model.load_state_dict(load_state_dict(resume_path, location='cpu'))
    model.learning_rate = learning_rate
    model.sd_locked = sd_locked
    model.only_mid_control = only_mid_control
    checkpoint_callback = ModelCheckpoint(
        dirpath='log/image_log_' + experiment,
        filename='model-{epoch:02d}',
        save_top_k=-1,  # keep every epoch's checkpoint
        save_last=True,
        every_n_epochs=1,
        monitor=None,  # no specific metric to monitor for saving
    )
    # Data loaders
    train_dataloader = DataLoader(train_dataset, num_workers=num_workers, batch_size=batch_size, shuffle=False)
    val_dataloader = DataLoader(val_dataset, num_workers=num_workers, batch_size=1, shuffle=False)
    logger = ImageLogger(
        val_dataloader=val_dataloader,
        batch_frequency=logger_freq,
        experiment=experiment,
        plot_frequency=plot_frequency,
        num_samples=num_samples,
    )
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        gpus=num_gpus,
        precision=32,
        callbacks=[logger, checkpoint_callback],
    )
    # Train! Resume from the last saved checkpoint if one exists.
    last_model_path = 'log/image_log_' + experiment + '/last.ckpt'
    if os.path.exists(last_model_path):
        trainer.fit(model, train_dataloader, ckpt_path=last_model_path)
    else:
        trainer.fit(model, train_dataloader)
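
# Example invocation (hypothetical script name; adjust paths to your checkout):
#   python train_control.py --control segm --batch_size 8 \
#       --resume_path ./models/hra_half_init_l_8.ckpt --hra_r 8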