# This code file is from https://github.com/hao-ai-lab/FastVideo, which is licensed under Apache License 2.0.
import torch
from accelerate.logging import get_logger

logger = get_logger(__name__)


def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
    # Optimizer creation
    supported_optimizers = ["adam", "adamw", "prodigy"]
    if args.optimizer not in supported_optimizers:
        logger.warning(
            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include "
            f"{supported_optimizers}. Defaulting to AdamW.")
        args.optimizer = "adamw"

    # 8-bit Adam only applies to the Adam/AdamW optimizers.
    if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]:
        logger.warning(
            f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
            f"set to {args.optimizer.lower()}")

    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

    if args.optimizer.lower() == "adamw":
        optimizer_class = (bnb.optim.AdamW8bit
                           if args.use_8bit_adam else torch.optim.AdamW)

        optimizer = optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )
    elif args.optimizer.lower() == "adam":
        optimizer_class = (bnb.optim.Adam8bit
                           if args.use_8bit_adam else torch.optim.Adam)

        optimizer = optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )
    elif args.optimizer.lower() == "prodigy":
        try:
            import prodigyopt
        except ImportError:
            raise ImportError(
                "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`"
            )

        optimizer_class = prodigyopt.Prodigy

        if args.learning_rate <= 0.1:
            logger.warning(
                "Learning rate is too low. When using Prodigy, it's generally better to set the learning rate around 1.0."
            )

        optimizer = optimizer_class(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            beta3=args.prodigy_beta3,
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            decouple=args.prodigy_decouple,
            use_bias_correction=args.prodigy_use_bias_correction,
            safeguard_warmup=args.prodigy_safeguard_warmup,
        )

    return optimizer
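

# --- Usage sketch (not part of the original FastVideo file) ---
# A minimal, hedged example of how get_optimizer might be called. The attribute
# names on `example_args` mirror exactly the fields the function reads; the
# concrete values are illustrative assumptions, not defaults from the FastVideo
# repo. Note that the Adam/AdamW branches above do not pass `lr`, so the
# learning rate is expected to come from per-parameter groups inside
# `params_to_optimize` or from torch's optimizer default.
if __name__ == "__main__":
    from types import SimpleNamespace

    example_args = SimpleNamespace(
        optimizer="adamw",
        use_8bit_adam=False,
        learning_rate=1e-4,
        adam_beta1=0.9,
        adam_beta2=0.999,
        adam_epsilon=1e-8,
        adam_weight_decay=1e-2,
        # Prodigy-specific fields, only read when optimizer == "prodigy".
        prodigy_beta3=None,
        prodigy_decouple=True,
        prodigy_use_bias_correction=True,
        prodigy_safeguard_warmup=True,
    )

    model = torch.nn.Linear(4, 4)
    optimizer = get_optimizer(example_args, model.parameters())
    print(type(optimizer).__name__)  # expected: AdamW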