# PICS / run_train.py
# Author: Hang Zhou
# Uploaded with huggingface_hub (commit b230236, verified)
import os
import argparse
import pytorch_lightning as pl
from braceexpand import braceexpand
from torch.utils.data import DataLoader
from datasets.webdataset import MultiWebDataset
from cldm.logger import ImageLogger
from cldm.model import create_model, load_state_dict
from torch.utils.data import ConcatDataset
from cldm.hack import disable_verbosity, enable_sliced_attention
from omegaconf import OmegaConf
import torch
from datasets.base import BaseDataset
class BaseLogic(BaseDataset):
    """Thin helper exposing BaseDataset's collage-construction logic.

    Only stores the two knobs the collage builder reads; everything else
    (including ``_construct_collage``, used by ``main``) is inherited.

    NOTE(review): ``super().__init__()`` is not called — confirm that
    BaseDataset requires no initialization of its own.
    """

    def __init__(self, area_ratio, obj_thr):
        # area_ratio: minimum area fraction an object must occupy;
        # obj_thr: per-image object-count threshold used by the collage logic.
        self.area_ratio, self.obj_thr = area_ratio, obj_thr
# Report CUDA status at startup. Guarded with is_available(): the original
# called torch.cuda.current_device() / get_device_name(0) unconditionally,
# which raises on CPU-only machines and crashed the script before training.
if torch.cuda.is_available():
    print("Number of GPUs available: ", torch.cuda.device_count())
    print("Current device: ", torch.cuda.current_device())
    print("Device name: ", torch.cuda.get_device_name(0))
else:
    print("Number of GPUs available: ", torch.cuda.device_count())
    print("No CUDA device available; running on CPU.")
def get_args_parser():
    """Build the CLI parser for the PICS training script.

    Returns:
        argparse.ArgumentParser: a parser created with ``add_help=False`` so
        it can be passed as a ``parents=`` parser in ``__main__``.
    """
    parser = argparse.ArgumentParser('PICS Training Script', add_help=False)
    # Fix: the original passed ``required=None`` — a misuse of the boolean
    # ``required`` flag. The intent is an optional path defaulting to None
    # (main() checks `args.resume_path and os.path.exists(...)`).
    parser.add_argument('--resume_path', default=None, type=str,
                        help="Checkpoint to resume from (optional)")
    parser.add_argument('--root_dir', required=True, type=str,
                        help="Output root for logs and checkpoints")
    parser.add_argument('--batch_size', default=1, type=int)
    parser.add_argument('--limit_train_batches', default=1, type=float)
    parser.add_argument('--logger_freq', default=1000, type=int,
                        help="Image-logging frequency in steps")
    parser.add_argument('--learning_rate', default=1e-5, type=float)
    # Typo fix in help text: "Seprate" -> "Separate".
    parser.add_argument('--is_joint', action='store_true',
                        help="Joint/Separate training")
    parser.add_argument("--dataset_name", type=str, default='lvis',
                        help="Dataset name")
    return parser
def main(args):
    """Train the PICS model on a weighted mix of webdataset shards.

    Builds the model from ./configs/pics.yaml, optionally resumes from a
    checkpoint, assembles shard URLs weighted per dataset, and runs a
    single-GPU PyTorch Lightning fit loop.

    Args:
        args: Namespace from get_args_parser() — resume_path, root_dir,
            batch_size, limit_train_batches, logger_freq, learning_rate,
            is_joint, dataset_name.

    Raises:
        ValueError: if ``args.dataset_name`` is unknown in single-dataset mode.
    """
    # Hoisted from mid-function in the original; kept function-local since
    # only this entry point uses it.
    import random

    save_memory = False
    disable_verbosity()
    if save_memory:
        enable_sliced_attention()

    # Training flags: full model is trainable (sd_locked=False).
    sd_locked = False
    only_mid_control = False
    accumulate_grad_batches = 1
    obj_thr = {'obj_thr': 2}  # forwarded to ImageLogger's log_images kwargs

    model = create_model('./configs/pics.yaml').cpu()
    if args.resume_path and os.path.exists(args.resume_path):
        print(f"Loading checkpoint from: {args.resume_path}")
        checkpoint = load_state_dict(args.resume_path, location='cpu')
        # strict=False: tolerate missing/extra keys when fine-tuning.
        model.load_state_dict(checkpoint, strict=False)
    else:
        print("No checkpoint found or provided. Training from scratch...")
    model.learning_rate = args.learning_rate
    model.sd_locked = sd_locked
    model.only_mid_control = only_mid_control

    DConf = OmegaConf.load('./configs/datasets.yaml')

    # CLI name -> config key. Replaces the original six-way if/elif chain
    # with a lookup table; accepted names and resulting weights are unchanged
    # (including the historical 'object365' spelling).
    canonical = {
        'lvis': 'LVIS',
        'vitonhd': 'VITONHD',
        'object365': 'Objects365',
        'cityscapes': 'Cityscapes',
        'mapillaryvistas': 'MapillaryVistas',
        'bdd100k': 'BDD100K',
    }
    if args.is_joint:
        # Joint training: oversample the smaller datasets.
        # weights = {'LVIS': 30, 'VITONHD': 60, 'Objects365': 1, 'Cityscapes': 180, 'MapillaryVistas': 180, 'BDD100K': 180}
        weights = {'LVIS': 3, 'VITONHD': 6, 'Objects365': 1,
                   'Cityscapes': 18, 'MapillaryVistas': 18, 'BDD100K': 18}
    else:
        if args.dataset_name not in canonical:
            raise ValueError(f"Unsupported dataset name: {args.dataset_name}")
        active = canonical[args.dataset_name]
        # Weight 1 for the selected dataset, 0 for everything else.
        weights = {name: int(name == active) for name in canonical.values()}

    dataset_shards = [
        ('LVIS', DConf.Train.LVIS.shards),
        ('VITONHD', DConf.Train.VITONHD.shards),
        ('Objects365', DConf.Train.Objects365.shards),
        ('Cityscapes', DConf.Train.Cityscapes.shards),
        ('MapillaryVistas', DConf.Train.MapillaryVistas.shards),
        ('BDD100K', DConf.Train.BDD100K.shards),
    ]
    all_urls = []
    for name, path in dataset_shards:
        # braceexpand turns "shard-{0000..0099}.tar" patterns into concrete
        # URLs; repeating the list implements the sampling weight.
        all_urls.extend(list(braceexpand(path)) * weights.get(name, 1))
    # NOTE(review): this shuffle is unseeded, so shard order differs between
    # runs even though the dataset below uses seed=42 — confirm intended.
    random.shuffle(all_urls)

    logic_helper = BaseLogic(
        area_ratio=DConf.Defaults.area_ratio,
        obj_thr=DConf.Defaults.obj_thr,
    )
    dataset = MultiWebDataset(
        urls=all_urls,
        construct_collage_fn=logic_helper._construct_collage,
        shuffle_size=10000,
        seed=42,
        decode_mode="pil",
    )
    dataloader = DataLoader(
        dataset,
        num_workers=8,
        batch_size=args.batch_size,
    )

    logger = ImageLogger(batch_frequency=args.logger_freq, log_images_kwargs=obj_thr)
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=os.path.join(args.root_dir, 'checkpoints'),
        filename='pics-{step:06d}',
        every_n_train_steps=2000,
        save_top_k=-1,  # keep every periodic checkpoint
    )
    trainer = pl.Trainer(
        default_root_dir=args.root_dir,
        limit_train_batches=args.limit_train_batches,
        accelerator="gpu",
        devices=1,
        precision=16,
        callbacks=[logger, checkpoint_callback],
        accumulate_grad_batches=accumulate_grad_batches,
        max_epochs=50,
        val_check_interval=2000,
    )
    trainer.fit(model, dataloader)
if __name__ == '__main__':
    # Compose the CLI from the shared parent parser, then launch training.
    cli = argparse.ArgumentParser('PICS Training', parents=[get_args_parser()])
    main(cli.parse_args())