| import math |
| from typing import List, Optional, Tuple, Union |
|
|
| import torch |
|
|
|
|
class NullCtx:
    """No-op context manager used where a real context is not needed.

    ``__enter__`` yields ``None`` and ``__exit__`` returns ``None``, so any
    exception raised inside the ``with`` body propagates unchanged.
    """

    def __enter__(self):
        # Nothing to acquire; ``with NullCtx() as x`` binds ``x`` to None.
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning None (falsy) means exceptions are never suppressed.
        return None
|
|
|
|
class AmpOptimizer:
    """Wrap an optimizer with optional autocast, loss scaling and grad clipping.

    The wrapper bundles the backward pass, gradient unscaling/clipping, the
    optimizer step and ``zero_grad`` into :meth:`backward_clip_step`, and
    exposes ``amp_ctx`` as the context manager to run the forward pass under.
    """

    # Ceiling for the fp16 loss scale: when the scale read after a step is
    # above this value, the scaler is reset to exactly this value.
    _MAX_LOSS_SCALE = 32768.

    def __init__(
        self,
        mixed_precision: int,
        optimizer: torch.optim.Optimizer, names: List[str], paras: List[torch.nn.Parameter],
        grad_clip: float, n_gradient_accumulation: int = 1,
    ):
        """Set up the autocast context, the optional grad scaler and clipping mode.

        Args:
            mixed_precision: ``0`` disables autocast; ``1`` selects fp16 with a
                ``GradScaler``; any other positive value selects bf16 without
                a scaler.
            optimizer: The optimizer to step.
            names: Parameter names, stored as ``self.names``.
            paras: Parameters whose gradients are clipped when this class does
                the clipping itself.
            grad_clip: Max gradient norm; values ``<= 0`` disable clipping.
            n_gradient_accumulation: Number of micro-batches accumulated per
                optimizer step; each loss is multiplied by its reciprocal.

        Raises:
            ValueError: If ``n_gradient_accumulation`` is not positive.
        """
        # Validate first: 0 would raise an opaque ZeroDivisionError below and a
        # negative value would silently flip the sign of every loss.
        if n_gradient_accumulation <= 0:
            raise ValueError(
                f'n_gradient_accumulation must be positive, got {n_gradient_accumulation}'
            )

        self.enable_amp = mixed_precision > 0
        self.using_fp16_rather_bf16 = mixed_precision == 1

        if self.enable_amp:
            amp_dtype = torch.float16 if self.using_fp16_rather_bf16 else torch.bfloat16
            self.amp_ctx = torch.autocast('cuda', enabled=True, dtype=amp_dtype, cache_enabled=True)
            # Only the fp16 path gets a loss scaler; bf16 runs without one.
            if self.using_fp16_rather_bf16:
                self.scaler = torch.cuda.amp.GradScaler(init_scale=2. ** 11, growth_interval=1000)
            else:
                self.scaler = None
        else:
            self.amp_ctx = NullCtx()
            self.scaler = None

        self.optimizer, self.names, self.paras = optimizer, names, paras
        self.grad_clip = grad_clip
        # An optimizer exposing `global_grad_norm` is not clipped here; its
        # norm is read back after the step instead ("late" clipping).
        reports_own_norm = hasattr(optimizer, 'global_grad_norm')
        self.early_clipping = self.grad_clip > 0 and not reports_own_norm
        self.late_clipping = self.grad_clip > 0 and reports_own_norm

        self.r_accu = 1 / n_gradient_accumulation

    def backward_clip_step(
        self, stepping: bool, loss: torch.Tensor,
    ) -> Tuple[Optional[Union[torch.Tensor, float]], Optional[float]]:
        """Back-propagate ``loss`` and, when ``stepping``, clip, step and zero grads.

        Args:
            stepping: If false, only the backward pass runs (gradients
                accumulate). If true, the optimizer is stepped and gradients
                are cleared afterwards.
            loss: The loss of the current micro-batch; it is multiplied by
                ``self.r_accu`` before the backward pass.

        Returns:
            ``(orig_norm, scaler_sc)``. ``orig_norm`` is the value returned by
            ``clip_grad_norm_`` (early clipping) or the optimizer's
            ``global_grad_norm`` (late clipping), else ``None``. ``scaler_sc``
            is ``log2`` of the loss scale read right after the scaler's step
            (before its update), or ``None`` when no scaler is in use. Both
            are ``None`` when ``stepping`` is false.

        Raises:
            Exception: Whatever ``math.log2`` raises for the current scale
                (re-raised after printing a diagnostic).
        """
        loss = loss.mul(self.r_accu)
        orig_norm = scaler_sc = None
        if self.scaler is not None:
            self.scaler.scale(loss).backward(retain_graph=False, create_graph=False)
        else:
            loss.backward(retain_graph=False, create_graph=False)

        if not stepping:
            return orig_norm, scaler_sc

        # Gradients must be unscaled before their norm is measured/clipped.
        if self.scaler is not None:
            self.scaler.unscale_(self.optimizer)
        if self.early_clipping:
            orig_norm = torch.nn.utils.clip_grad_norm_(self.paras, self.grad_clip)

        if self.scaler is not None:
            self.scaler.step(self.optimizer)
            scaler_sc = self.scaler.get_scale()
            # Keep the loss scale from growing past the ceiling.
            if scaler_sc > self._MAX_LOSS_SCALE:
                self.scaler.update(new_scale=self._MAX_LOSS_SCALE)
            else:
                self.scaler.update()
            try:
                scaler_sc = float(math.log2(scaler_sc))
            except Exception:
                # Make the offending scale value hard to miss in the logs,
                # then re-raise with the original traceback.
                print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True)
                raise
        else:
            self.optimizer.step()

        if self.late_clipping:
            orig_norm = self.optimizer.global_grad_norm

        self.optimizer.zero_grad(set_to_none=True)

        return orig_norm, scaler_sc

    def state_dict(self):
        """Return the optimizer state, plus the scaler state when a scaler exists.

        Returns:
            A dict with key ``'optimizer'`` and, if ``self.scaler`` is set,
            key ``'scaler'``.
        """
        state = {}
        if self.scaler is not None:
            state['scaler'] = self.scaler.state_dict()
        state['optimizer'] = self.optimizer.state_dict()
        return state

    def load_state_dict(self, state, strict=True):
        """Restore optimizer (and, best-effort, scaler) state from ``state``.

        Args:
            state: A dict as produced by :meth:`state_dict`.
            strict: Accepted for interface compatibility; not used here.
        """
        if self.scaler is not None:
            # Best-effort: a missing or incompatible scaler state is reported
            # but does not prevent the optimizer state from loading.
            try:
                self.scaler.load_state_dict(state['scaler'])
            except Exception as e:
                print(f'[fp16 load_state_dict err] {e}')
        self.optimizer.load_state_dict(state['optimizer'])
|
|