import torch
from accelerate.logging import get_logger

logger = get_logger(__name__)

def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
    # NOTE: `use_deepspeed` is currently unused in this function.

    # Fall back to AdamW when an unsupported optimizer is requested.
    supported_optimizers = ["adam", "adamw", "prodigy"]
    if args.optimizer not in supported_optimizers:
        logger.warning(
            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW"
        )
        args.optimizer = "adamw"

    # `use_8bit_adam` only applies to the Adam/AdamW optimizers.
    if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]:
        logger.warning(
            f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
            f"set to {args.optimizer.lower()}"
        )

    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

    if args.optimizer.lower() == "adamw":
        # Use 8-bit AdamW from bitsandbytes when requested, otherwise torch's AdamW.
        optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW

        optimizer = optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )
    elif args.optimizer.lower() == "adam":
        optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam

        optimizer = optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )
    elif args.optimizer.lower() == "prodigy":
        try:
            import prodigyopt
        except ImportError:
            raise ImportError(
                "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`"
            )

        optimizer_class = prodigyopt.Prodigy

        if args.learning_rate <= 0.1:
            logger.warning(
                "Learning rate is too low. When using Prodigy, it's generally better to set the learning rate around 1.0"
            )

        optimizer = optimizer_class(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            beta3=args.prodigy_beta3,
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            decouple=args.prodigy_decouple,
            use_bias_correction=args.prodigy_use_bias_correction,
            safeguard_warmup=args.prodigy_safeguard_warmup,
        )

    return optimizer
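
# A minimal usage sketch (hypothetical: the attribute names below are simply the
# fields `get_optimizer` reads; the real training script would build `args` with
# its own argument parser):
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        optimizer="adamw",
        use_8bit_adam=False,
        learning_rate=1e-4,
        adam_beta1=0.9,
        adam_beta2=0.999,
        adam_epsilon=1e-8,
        adam_weight_decay=1e-2,
        prodigy_beta3=None,
        prodigy_decouple=True,
        prodigy_use_bias_correction=True,
        prodigy_safeguard_warmup=True,
    )
    model = torch.nn.Linear(16, 16)
    optimizer = get_optimizer(args, model.parameters())
    print(type(optimizer).__name__)  # expected: AdamW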