File size: 5,491 Bytes
b230236
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os
import argparse
import pytorch_lightning as pl
from braceexpand import braceexpand
from torch.utils.data import DataLoader
from datasets.webdataset import MultiWebDataset

from cldm.logger import ImageLogger
from cldm.model import create_model, load_state_dict
from torch.utils.data import ConcatDataset
from cldm.hack import disable_verbosity, enable_sliced_attention
from omegaconf import OmegaConf
import torch

from datasets.base import BaseDataset

class BaseLogic(BaseDataset):
    """Lightweight BaseDataset subclass that only carries collage parameters.

    In this script it is instantiated in ``main`` solely so that its
    ``_construct_collage`` method can be handed to ``MultiWebDataset``.

    Args:
        area_ratio: value stored as ``self.area_ratio``.
        obj_thr: value stored as ``self.obj_thr``.
    """

    def __init__(self, area_ratio, obj_thr):
        # BaseDataset.__init__ is intentionally not invoked; only these two
        # attributes are set. NOTE(review): confirm `_construct_collage` needs
        # no other state from BaseDataset initialisation.
        self.area_ratio, self.obj_thr = area_ratio, obj_thr

# Report the visible GPUs at import time. The per-device queries are guarded
# because torch.cuda.current_device() / get_device_name() raise on a machine
# without CUDA, which would otherwise make merely importing this script (or
# running it with --help) crash before argument parsing.
print("Number of GPUs available: ", torch.cuda.device_count())
if torch.cuda.is_available():
    print("Current device: ", torch.cuda.current_device())
    print("Device name: ", torch.cuda.get_device_name(0))
else:
    print("CUDA is not available; the GPU trainer configured in main() will not run.")

def get_args_parser():
    """Return the parent argument parser for the PICS training script.

    The parser is created with ``add_help=False`` so it can be passed via
    ``parents=[...]`` to another ``ArgumentParser`` that owns ``--help``.
    """
    parser = argparse.ArgumentParser('PICS Training Script', add_help=False)

    # Optional: when omitted (None) training starts from scratch.
    parser.add_argument('--resume_path', default=None, type=str)
    parser.add_argument('--root_dir', required=True, type=str)
    parser.add_argument('--batch_size', default=1, type=int)
    # argparse does not run ``type`` on non-string defaults, so an int default
    # of 1 would reach the Trainer as int 1 ("one batch per epoch") rather than
    # float 1.0 ("the whole epoch"). Keep the default a float literal.
    parser.add_argument('--limit_train_batches', default=1.0, type=float)
    parser.add_argument('--logger_freq', default=1000, type=int)
    parser.add_argument('--learning_rate', default=1e-5, type=float)
    parser.add_argument('--is_joint', action='store_true', help="Joint/separate training")
    parser.add_argument("--dataset_name", type=str, default='lvis', help="Dataset name")

    return parser

# Dataset keys as they appear under ``Train`` in ./configs/datasets.yaml, in the
# order their shards are gathered.
_DATASET_KEYS = ('LVIS', 'VITONHD', 'Objects365', 'Cityscapes', 'MapillaryVistas', 'BDD100K')

# Shard-repetition factors for joint training: a dataset's shard list is
# repeated this many times in the URL pool, so larger values mean it is sampled
# more often. (An earlier configuration used ten times these values.)
_JOINT_WEIGHTS = {
    'LVIS': 3,
    'VITONHD': 6,
    'Objects365': 1,
    'Cityscapes': 18,
    'MapillaryVistas': 18,
    'BDD100K': 18,
}

# ``--dataset_name`` CLI value -> config key, for single-dataset training.
_SINGLE_DATASET_KEYS = {
    'lvis': 'LVIS',
    'vitonhd': 'VITONHD',
    'object365': 'Objects365',
    'cityscapes': 'Cityscapes',
    'mapillaryvistas': 'MapillaryVistas',
    'bdd100k': 'BDD100K',
}

# Shared by the shard-list shuffle and MultiWebDataset's sample shuffle so a
# run's data order is reproducible.
_SEED = 42


def _dataset_weights(is_joint, dataset_name):
    """Return ``{config key: repetition factor}`` for the requested mode.

    Raises:
        ValueError: ``dataset_name`` is unknown and ``is_joint`` is false.
    """
    if is_joint:
        return dict(_JOINT_WEIGHTS)
    try:
        selected = _SINGLE_DATASET_KEYS[dataset_name]
    except KeyError:
        raise ValueError(f"Unsupported dataset name: {dataset_name}") from None
    return {key: int(key == selected) for key in _DATASET_KEYS}


def _collect_shard_urls(dconf, weights, seed):
    """Brace-expand each dataset's shard pattern, repeat it by its weight, shuffle.

    Datasets with weight <= 0 are skipped entirely, so their config entries are
    not required for single-dataset runs.
    """
    import random

    urls = []
    for key in _DATASET_KEYS:
        repeat = weights.get(key, 1)
        if repeat <= 0:
            continue
        expanded = list(braceexpand(dconf.Train[key].shards))
        urls.extend(expanded * repeat)
    random.Random(seed).shuffle(urls)
    return urls


def main(args):
    """Build the PICS model and weighted WebDataset pipeline, then train.

    Args:
        args: namespace produced by ``get_args_parser()``; reads
            ``resume_path``, ``root_dir``, ``batch_size``,
            ``limit_train_batches``, ``logger_freq``, ``learning_rate``,
            ``is_joint`` and ``dataset_name``.

    Raises:
        ValueError: unsupported ``dataset_name`` in single-dataset mode
            (checked before the model is built).
    """
    save_memory = False
    disable_verbosity()
    if save_memory:
        enable_sliced_attention()

    # Validate the dataset selection first so a bad CLI value fails fast,
    # before the expensive model construction.
    weights = _dataset_weights(args.is_joint, args.dataset_name)

    sd_locked = False
    only_mid_control = False
    accumulate_grad_batches = 1
    # Passed to ImageLogger as ``log_images_kwargs``.
    # NOTE(review): hard-coded 2, while the collage helper below reads
    # DConf.Defaults.obj_thr — confirm which value is intended here.
    log_images_kwargs = {'obj_thr': 2}

    model = create_model('./configs/pics.yaml').cpu()
    if args.resume_path and os.path.exists(args.resume_path):
        print(f"Loading checkpoint from: {args.resume_path}")
        checkpoint = load_state_dict(args.resume_path, location='cpu')
        incompatible = model.load_state_dict(checkpoint, strict=False)
        # strict=False silently tolerates mismatches; surface their size.
        missing = getattr(incompatible, 'missing_keys', None)
        unexpected = getattr(incompatible, 'unexpected_keys', None)
        if missing is not None and unexpected is not None:
            print(f"  missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
    elif args.resume_path:
        # A path was supplied but does not exist — most likely a typo; make the
        # fallback to random init explicit rather than silent.
        print(f"WARNING: checkpoint not found at {args.resume_path}. Training from scratch...")
    else:
        print("No checkpoint found or provided. Training from scratch...")

    model.learning_rate = args.learning_rate
    model.sd_locked = sd_locked
    model.only_mid_control = only_mid_control

    dconf = OmegaConf.load('./configs/datasets.yaml')
    all_urls = _collect_shard_urls(dconf, weights, seed=_SEED)

    logic_helper = BaseLogic(
        area_ratio=dconf.Defaults.area_ratio,
        obj_thr=dconf.Defaults.obj_thr,
    )

    dataset = MultiWebDataset(
        urls=all_urls,
        construct_collage_fn=logic_helper._construct_collage,
        shuffle_size=10000,
        seed=_SEED,
        decode_mode="pil",
    )

    dataloader = DataLoader(
        dataset,
        num_workers=8,
        batch_size=args.batch_size,
    )

    logger = ImageLogger(batch_frequency=args.logger_freq, log_images_kwargs=log_images_kwargs)

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=os.path.join(args.root_dir, 'checkpoints'),
        filename='pics-{step:06d}',
        every_n_train_steps=2000,
        save_top_k=-1,
    )

    trainer = pl.Trainer(
        default_root_dir=args.root_dir,
        limit_train_batches=args.limit_train_batches,
        accelerator="gpu",
        devices=1,
        precision=16,
        callbacks=[logger, checkpoint_callback],
        accumulate_grad_batches=accumulate_grad_batches,
        max_epochs=50,
        val_check_interval=2000,
    )
    trainer.fit(model, dataloader)


if __name__ == '__main__':
    # Wrap the help-less parent parser in one that owns --help, then train.
    cli = argparse.ArgumentParser('PICS Training', parents=[get_args_parser()])
    main(cli.parse_args())