model setup optimizations
main.py (CHANGED)
@@ -15,12 +15,11 @@ from rewards import get_reward_losses
 from training import LatentNoiseTrainer, get_optimizer


-def setup(args):
-
+def setup(args, loaded_model_setup=None):
     seed_everything(args.seed)
     bf.makedirs(f"{args.save_dir}/logs/{args.task}")
+
     # Set up logging and name settings
-    # Get the root logger and clear existing handlers
     logger = logging.getLogger()
     logger.handlers.clear()  # Clear existing handlers
     settings = (
@@ -34,6 +33,7 @@ def setup(args):
         f"{'_imagereward' + str(args.imagereward_weighting) if args.enable_imagereward else ''}"
         f"{'_aesthetic' + str(args.aesthetic_weighting) if args.enable_aesthetic else ''}"
     )
+
     file_stream = open(f"{args.save_dir}/logs/{args.task}/{settings}.txt", "w")
     handler = logging.StreamHandler(file_stream)
     formatter = logging.Formatter("%(asctime)s - %(message)s")
@@ -43,16 +43,68 @@ def setup(args):
     consoleHandler = logging.StreamHandler()
     consoleHandler.setFormatter(formatter)
     logger.addHandler(consoleHandler)
+
     logging.info(args)
+
     if args.device_id is not None:
         logging.info(f"Using CUDA device {args.device_id}")
         os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
         os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id
+
     device = torch.device("cuda")
     if args.dtype == "float32":
         dtype = torch.float32
     elif args.dtype == "float16":
         dtype = torch.float16
+
+    # If args.model is the same as the one in loaded_model_setup, reuse the trainer and pipe
+    if loaded_model_setup and args.model == loaded_model_setup[0].model:
+        # Reuse the trainer and pipe from the loaded model setup
+        print(f"Reusing model {args.model} from loaded setup.")
+        trainer = loaded_model_setup[1]  # Trainer is at position 1 in loaded_model_setup
+
+        # Update trainer with the new arguments
+        trainer.n_iters = args.n_iters
+        trainer.n_inference_steps = args.n_inference_steps
+        trainer.seed = args.seed
+        trainer.save_all_images = args.save_all_images
+        trainer.no_optim = args.no_optim
+        trainer.regularize = args.enable_reg
+        trainer.regularization_weight = args.reg_weight
+        trainer.grad_clip = args.grad_clip
+        trainer.log_metrics = args.task == "single" or not args.no_optim
+        trainer.imageselect = args.imageselect
+
+        # Get latents (this step is still required)
+        if args.model == "flux":
+            shape = (1, 16 * 64, 64)
+        elif args.model != "pixart":
+            height = trainer.model.unet.config.sample_size * trainer.model.vae_scale_factor
+            width = trainer.model.unet.config.sample_size * trainer.model.vae_scale_factor
+            shape = (
+                1,
+                trainer.model.unet.in_channels,
+                height // trainer.model.vae_scale_factor,
+                width // trainer.model.vae_scale_factor,
+            )
+        else:
+            height = trainer.model.transformer.config.sample_size * trainer.model.vae_scale_factor
+            width = trainer.model.transformer.config.sample_size * trainer.model.vae_scale_factor
+            shape = (
+                1,
+                trainer.model.transformer.config.in_channels,
+                height // trainer.model.vae_scale_factor,
+                width // trainer.model.vae_scale_factor,
+            )
+
+        multi_apply_fn = loaded_model_setup[6]
+        enable_grad = not args.no_optim
+
+        return args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings
+
+    # Proceed with full model loading if args.model is different
+    print(f"Loading new model: {args.model}")
+
     # Get reward losses
     reward_losses = get_reward_losses(args, dtype, device, args.cache_dir)

@@ -63,7 +115,7 @@ def setup(args):

     torch.cuda.empty_cache()  # Free up cached memory
     gc.collect()
-
+
     trainer = LatentNoiseTrainer(
         reward_losses=reward_losses,
         model=pipe,
@@ -85,7 +137,6 @@ def setup(args):

     # Create latents
     if args.model == "flux":
-        # currently only support 512x512 generation
         shape = (1, 16 * 64, 64)
     elif args.model != "pixart":
         height = pipe.unet.config.sample_size * pipe.vae_scale_factor
@@ -107,6 +158,9 @@ def setup(args):
     )

     enable_grad = not args.no_optim
+
+    torch.cuda.empty_cache()  # Free up cached memory
+    gc.collect()

     if args.enable_multi_apply:
         multi_apply_fn = get_multi_apply_fn(
@@ -121,6 +175,7 @@ def setup(args):
         multi_apply_fn = None

     torch.cuda.empty_cache()  # Free up cached memory
+    gc.collect()

     return args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings

@@ -308,7 +363,7 @@ def execute_task(args, trainer, device, dtype, shape, enable_grad, multi_apply_f

 def main():
     args = parse_args()
-    args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings = setup(args)
+    args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings = setup(args, loaded_model_setup=None)
     execute_task(args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings)

 if __name__ == "__main__":
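Usage note: the new loaded_model_setup parameter expects the same tuple that setup() returns (args at index 0, the trainer at index 1, multi_apply_fn at index 6). A caller that invokes setup() once per generation request can keep that tuple around and pass it back in, so the pipeline and trainer are only rebuilt when args.model changes. The sketch below shows one way a caller might do this; it assumes setup() and execute_task() are importable from main.py, and run_requests / request_args_list are hypothetical names that are not part of this commit.

    from main import setup, execute_task

    def run_requests(request_args_list):
        # request_args_list: parsed argument namespaces, one per generation request
        # (hypothetical; how requests arrive depends on the Space's UI code).
        cached = None  # last tuple returned by setup()
        for args in request_args_list:
            # setup() returns (args, trainer, device, dtype, shape,
            # enable_grad, multi_apply_fn, settings). Passing the previous tuple
            # lets it reuse cached[1] (trainer) and cached[6] (multi_apply_fn)
            # when args.model is unchanged, skipping the full model load.
            cached = setup(args, loaded_model_setup=cached)
            execute_task(*cached)
        return cached

On the reuse path the latent shape is still recomputed, but from trainer.model rather than from a freshly constructed pipe, since no new pipeline object is created in that branch.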