import logging

import numpy as np
import torch

logger = logging.getLogger(__name__)
def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
    r"""Check model gradients against unexpected jumps and failures."""
    skip_flag = False
    if ignore_stopnet:
        if not amp_opt_params:
            # Clip all parameters except the stopnet ones.
            grad_norm = torch.nn.utils.clip_grad_norm_(
                [param for name, param in model.named_parameters() if "stopnet" not in name], grad_clip
            )
        else:
            grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip)
    else:
        if not amp_opt_params:
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        else:
            grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip)

    # `clip_grad_norm_` returns a float in older torch versions and a tensor in newer ones.
    if isinstance(grad_norm, float):
        if np.isinf(grad_norm):
            logger.warning("Gradient is INF !!")
            skip_flag = True
    else:
        if torch.isinf(grad_norm):
            logger.warning("Gradient is INF !!")
            skip_flag = True
    return grad_norm, skip_flag
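
# A hypothetical usage sketch for `check_update` (the `model`, `optimizer`, and
# `config` names below are illustrative and not defined in this module):
#
#     grad_norm, skip_flag = check_update(model, config.grad_clip, ignore_stopnet=True)
#     if not skip_flag:
#         optimizer.step()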


def gradual_training_scheduler(global_step, config):
    """Set up the gradual training schedule w.r.t. the number of active GPUs."""
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        num_gpus = 1
    new_values = None
    # Scale the step count by the number of GPUs and take the last schedule
    # entry whose start step has been reached. Assumes the first entry starts
    # at step 0, so `new_values` is always set before it is indexed.
    for values in config.gradual_training:
        if global_step * num_gpus >= values[0]:
            new_values = values
    return new_values[1], new_values[2]
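
# A hypothetical usage sketch for `gradual_training_scheduler` (assumes
# `config.gradual_training` is a list of [start_step, r, batch_size] entries,
# e.g. [[0, 7, 64], [10000, 5, 64], [50000, 3, 32]]; the names are illustrative):
#
#     r, batch_size = gradual_training_scheduler(global_step, config)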