Hang Zhou committed on
Commit
b230236
·
verified ·
1 Parent(s): bd557e8

Upload run_train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. run_train.py +147 -0
run_train.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse
import os
import random

import pytorch_lightning as pl
import torch
from braceexpand import braceexpand
from omegaconf import OmegaConf
from torch.utils.data import ConcatDataset, DataLoader

from cldm.hack import disable_verbosity, enable_sliced_attention
from cldm.logger import ImageLogger
from cldm.model import create_model, load_state_dict
from datasets.base import BaseDataset
from datasets.webdataset import MultiWebDataset
16
+
17
class BaseLogic(BaseDataset):
    """Thin configuration holder layered on BaseDataset.

    Stores only the two knobs read by the (presumably inherited)
    ``_construct_collage`` helper that main() hands to MultiWebDataset.
    """

    def __init__(self, area_ratio, obj_thr):
        # NOTE(review): super().__init__() is deliberately not called here —
        # confirm BaseDataset does not require its own initialisation before
        # _construct_collage is used.
        self.area_ratio, self.obj_thr = area_ratio, obj_thr
21
+
22
# Report GPU visibility at import time so misconfigured runs fail loudly in logs.
print("Number of GPUs available: ", torch.cuda.device_count())
if torch.cuda.is_available():
    print("Current device: ", torch.cuda.current_device())
    print("Device name: ", torch.cuda.get_device_name(0))
else:
    # current_device()/get_device_name(0) raise RuntimeError without a CUDA
    # device, which previously crashed the script on CPU-only machines.
    print("CUDA is not available; trainer is configured for accelerator='gpu'.")
25
+
26
def get_args_parser():
    """Build the CLI parser for the PICS training script.

    Returns:
        argparse.ArgumentParser: parser with all training options.
        ``add_help`` is disabled so it can be used as a parent parser
        in the ``__main__`` entry point.
    """
    parser = argparse.ArgumentParser('PICS Training Script', add_help=False)

    # argparse's `required` expects a bool; the original passed None, which
    # only worked because None is falsy. Use False explicitly.
    parser.add_argument('--resume_path', required=False, type=str,
                        help="Optional checkpoint to resume training from")
    parser.add_argument('--root_dir', required=True, type=str,
                        help="Root directory for trainer logs and checkpoints")
    parser.add_argument('--batch_size', default=1, type=int)
    parser.add_argument('--limit_train_batches', default=1, type=float)
    parser.add_argument('--logger_freq', default=1000, type=int)
    parser.add_argument('--learning_rate', default=1e-5, type=float)
    # Typo fixed: "Seprate" -> "Separate".
    parser.add_argument('--is_joint', action='store_true',
                        help="Joint/Separate training")
    parser.add_argument("--dataset_name", type=str, default='lvis',
                        help="Dataset name")

    return parser
39
+
40
def _dataset_weights(args):
    """Map each training dataset to a shard replication weight.

    Joint training oversamples the smaller datasets relative to Objects365;
    single-dataset training activates exactly one source with weight 1 and
    zeroes out the rest.

    Args:
        args: parsed CLI namespace (reads ``is_joint`` and ``dataset_name``).

    Returns:
        dict[str, int]: weight per canonical dataset name.

    Raises:
        ValueError: if ``args.dataset_name`` is not a known dataset.
    """
    if args.is_joint:
        return {'LVIS': 3, 'VITONHD': 6, 'Objects365': 1,
                'Cityscapes': 18, 'MapillaryVistas': 18, 'BDD100K': 18}

    # CLI name -> canonical dataset key (replaces the previous six-way elif chain).
    canonical = {
        'lvis': 'LVIS',
        'vitonhd': 'VITONHD',
        'object365': 'Objects365',
        'cityscapes': 'Cityscapes',
        'mapillaryvistas': 'MapillaryVistas',
        'bdd100k': 'BDD100K',
    }
    if args.dataset_name not in canonical:
        raise ValueError(f"Unsupported dataset name: {args.dataset_name}")
    active = canonical[args.dataset_name]
    return {name: int(name == active) for name in canonical.values()}


def main(args):
    """Train the PICS model on a weighted mix of webdataset shards.

    Builds the model (optionally resuming from ``args.resume_path``),
    expands and replicates shard URLs according to the dataset weights,
    and runs a single-GPU PyTorch Lightning fit loop.

    Args:
        args: parsed CLI namespace from :func:`get_args_parser`.
    """
    save_memory = False
    disable_verbosity()
    if save_memory:
        enable_sliced_attention()

    sd_locked = False
    only_mid_control = False
    accumulate_grad_batches = 1
    obj_thr = {'obj_thr': 2}

    model = create_model('./configs/pics.yaml').cpu()
    if args.resume_path and os.path.exists(args.resume_path):
        print(f"Loading checkpoint from: {args.resume_path}")
        checkpoint = load_state_dict(args.resume_path, location='cpu')
        # strict=False tolerates missing/extra keys in the checkpoint.
        model.load_state_dict(checkpoint, strict=False)
    else:
        print("No checkpoint found or provided. Training from scratch...")

    model.learning_rate = args.learning_rate
    model.sd_locked = sd_locked
    model.only_mid_control = only_mid_control

    DConf = OmegaConf.load('./configs/datasets.yaml')

    weights = _dataset_weights(args)

    dataset_shards = [
        ('LVIS', DConf.Train.LVIS.shards),
        ('VITONHD', DConf.Train.VITONHD.shards),
        ('Objects365', DConf.Train.Objects365.shards),
        ('Cityscapes', DConf.Train.Cityscapes.shards),
        ('MapillaryVistas', DConf.Train.MapillaryVistas.shards),
        ('BDD100K', DConf.Train.BDD100K.shards)
    ]

    # Replicate each dataset's expanded shard list by its weight so the
    # sampling ratio is controlled purely by URL multiplicity.
    all_urls = []
    for name, path in dataset_shards:
        expanded = list(braceexpand(path))
        all_urls.extend(expanded * weights.get(name, 1))

    # NOTE(review): this shuffle is unseeded, so the shard order is not
    # reproducible across runs (the dataset itself gets seed=42) — confirm
    # this is intended.
    random.shuffle(all_urls)

    logic_helper = BaseLogic(
        area_ratio=DConf.Defaults.area_ratio,
        obj_thr=DConf.Defaults.obj_thr
    )

    dataset = MultiWebDataset(
        urls=all_urls,
        construct_collage_fn=logic_helper._construct_collage,
        shuffle_size=10000,
        seed=42,
        decode_mode="pil",
    )

    dataloader = DataLoader(
        dataset,
        num_workers=8,
        batch_size=args.batch_size,
    )

    logger = ImageLogger(batch_frequency=args.logger_freq, log_images_kwargs=obj_thr)

    # Keep every periodic checkpoint (save_top_k=-1), one every 2000 steps.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=os.path.join(args.root_dir, 'checkpoints'),
        filename='pics-{step:06d}',
        every_n_train_steps=2000,
        save_top_k=-1,
    )

    trainer = pl.Trainer(
        default_root_dir=args.root_dir,
        limit_train_batches=args.limit_train_batches,
        accelerator="gpu",
        devices=1,
        precision=16,
        callbacks=[logger, checkpoint_callback],
        accumulate_grad_batches=accumulate_grad_batches,
        max_epochs=50,
        val_check_interval=2000,
    )
    trainer.fit(model, dataloader)
141
+
142
+
143
if __name__ == '__main__':
    # Compose the CLI from the shared parent parser, then launch training.
    cli = argparse.ArgumentParser('PICS Training', parents=[get_args_parser()])
    main(cli.parse_args())
147
+