from oldm.hack import disable_verbosity
disable_verbosity()
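# Fine-tunes a control-signal-conditioned latent diffusion model with a
# parameter-efficient adapter (OFT, HRA, or LoRA). The adapter type and the
# experiment name are inferred from the substring 'oft', 'hra', or 'lora' in
# --resume_path, so checkpoint filenames must contain one of these tags.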

import os
import sys
import torch
from datetime import datetime

# Make the repository root importable when this script is run directly.
file_path = os.path.abspath(__file__)
parent_dir = os.path.abspath(os.path.dirname(file_path) + '/..')
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from oldm.logger import ImageLogger
from oldm.model import create_model, load_state_dict
from dataset.utils import return_dataset

from oft import inject_trainable_oft
from hra import inject_trainable_hra
from lora import inject_trainable_lora

import argparse


parser = argparse.ArgumentParser()

# HRA (Householder reflection adaptation): hra_r is the number of Householder
# reflections per layer, hra_lamb weights the orthogonality regularizer, and
# --hra_apply_GS orthogonalizes the reflection planes via Gram-Schmidt.
parser.add_argument('--hra_r', type=int, default=8)
parser.add_argument('--hra_apply_GS', action='store_true', default=False)
parser.add_argument('--hra_lamb', type=float, default=0.0)

# LoRA (the rank default here is an assumption; the original script referenced
# an undefined args.r in the LoRA branches below)
parser.add_argument('--lora_r', type=int, default=4)

# OFT (orthogonal finetuning): oft_r is the number of diagonal blocks in the
# orthogonal matrix; COFT additionally constrains the rotation to an eps-ball.
# Note: action="store_true" with default=True means --oft_coft is always on
# and cannot be disabled from the CLI as written.
parser.add_argument('--oft_r', type=int, default=4)
parser.add_argument('--oft_eps', type=float, default=7e-6)
parser.add_argument('--oft_coft', action="store_true", default=True)
parser.add_argument('--oft_block_share', action="store_true", default=False)

parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--num_samples', type=int, default=1)
parser.add_argument('--plot_frequency', type=int, default=100)
parser.add_argument('--learning_rate', type=float, default=9e-4)
parser.add_argument('--sd_locked', action="store_true", default=True)  # always True as written (store_true with default=True)
parser.add_argument('--only_mid_control', action="store_true", default=False)
parser.add_argument('--num_gpus', type=int, default=torch.cuda.device_count())
parser.add_argument('--resume_path', type=str, default='./models/hra_half_init_l_8.ckpt')
parser.add_argument('--time_str', type=str, default=datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f"))
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--control', 
                    type=str, 
                    help='control signal. Options are [segm, sketch, densepose, depth, canny, landmark]', 
                    default="segm")

args = parser.parse_args()


if __name__ == "__main__":
    # specify the control signal and dataset
    control = args.control
    
    # create dataset
    train_dataset, val_dataset, data_name, logger_freq, max_epochs = return_dataset(control)

    # Configs
    resume_path = args.resume_path
    
    batch_size = args.batch_size
    num_samples = args.num_samples
    plot_frequency = args.plot_frequency
    learning_rate = args.learning_rate
    sd_locked = args.sd_locked
    only_mid_control = args.only_mid_control
    num_gpus = args.num_gpus
    time_str = args.time_str
    num_workers = args.num_workers
    
    for arg in vars(args):
        print(f'{arg}: {getattr(args, arg)}')
    print(f'data_name: {data_name}\nlogger_freq: {logger_freq}\nmax_epochs: {max_epochs}')
    
    # Build the experiment name; the adapter type is inferred from the checkpoint filename.
    if 'oft' in args.resume_path:
        experiment = f'oft_{data_name}_{control}_eps_{args.oft_eps}_pe_diff_mlp_r_{args.oft_r}_{num_gpus}gpu_{time_str}'
    elif 'hra' in args.resume_path:
        if args.hra_apply_GS:
            experiment = f'hra_apply_GS_{data_name}_{control}_pe_diff_mlp_r_{args.hra_r}_{num_gpus}gpu_{time_str}'
        else:
            experiment = f'hra_{data_name}_{control}_pe_diff_mlp_r_{args.hra_r}_lambda_{args.hra_lamb}_lr_{args.learning_rate}_{num_gpus}gpu_{time_str}'
    elif 'lora' in args.resume_path:
        experiment = f'lora_{data_name}_{control}_pe_diff_mlp_r_{args.lora_r}_{num_gpus}gpu_{time_str}'
    else:
        raise ValueError(f"--resume_path must contain 'oft', 'hra', or 'lora', got: {args.resume_path}")

    # Load the model on CPU first; PyTorch Lightning will move it to the GPUs automatically.
    model = create_model('./configs/oft_ldm_v15.yaml').cpu()
    model.model.requires_grad_(False)
    print(f'Total parameters not requiring grad: {sum(p.numel() for p in model.model.parameters() if not p.requires_grad)}')

    # Inject the trainable adapter parameters (OFT, HRA, or LoRA, matching the checkpoint name)
    if 'oft' in args.resume_path:
        unet_lora_params, train_names = inject_trainable_oft(model.model, r=args.oft_r, eps=args.oft_eps, is_coft=args.oft_coft, block_share=args.oft_block_share)
    elif 'hra' in args.resume_path:
        unet_lora_params, train_names = inject_trainable_hra(model.model, r=args.hra_r, apply_GS=args.hra_apply_GS)
    elif 'lora' in args.resume_path:
        unet_lora_params, train_names = inject_trainable_lora(model.model, rank=args.lora_r, network_alpha=None)
    
    print(f'Total parameters requiring grad: {sum(p.numel() for p in model.model.parameters() if p.requires_grad)}')

    # Load the checkpoint after injection so the adapter weights it contains are restored too.
    model.load_state_dict(load_state_dict(resume_path, location='cpu'))
    model.learning_rate = learning_rate
    model.sd_locked = sd_locked
    model.only_mid_control = only_mid_control

    # Save a checkpoint every epoch (save_top_k=-1 keeps all of them), plus last.ckpt for resuming.
    checkpoint_callback = ModelCheckpoint(
        dirpath='log/image_log_' + experiment,
        filename='model-{epoch:02d}',
        save_top_k=-1,
        save_last=True,
        every_n_epochs=1,
        monitor=None,  # no metric is monitored; checkpoints are purely periodic
    )

    # Dataloaders. Note that the training set is not shuffled here; pass
    # shuffle=True if a randomized ordering is intended.
    train_dataloader = DataLoader(train_dataset, num_workers=num_workers, batch_size=batch_size, shuffle=False)
    val_dataloader = DataLoader(val_dataset, num_workers=num_workers, batch_size=1, shuffle=False)

    logger = ImageLogger(
        val_dataloader=val_dataloader,
        batch_frequency=logger_freq, 
        experiment=experiment, 
        plot_frequency=plot_frequency,
        num_samples=num_samples,
    )

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        gpus=num_gpus,  # pre-2.0 Lightning API; newer releases use accelerator='gpu', devices=num_gpus
        precision=32,
        callbacks=[logger, checkpoint_callback],
    )

    # Train, resuming from this experiment's last checkpoint if one exists.
    last_model_path = 'log/image_log_' + experiment + '/last.ckpt'
    if os.path.exists(last_model_path):
        trainer.fit(model, train_dataloader, ckpt_path=last_model_path)
    else:
        trainer.fit(model, train_dataloader)
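
# Example invocation (the script filename is repository-specific):
#   python train.py --control segm --batch_size 8 \
#       --resume_path ./models/hra_half_init_l_8.ckpt --hra_r 8 --hra_lamb 0.0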