import math
import os
import signal
import sys
import time
from typing import List, Optional, Tuple, Union

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

import infinity.utils.dist as dist

class NullCtx:
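    """A no-op context manager, used in place of torch.autocast when AMP is disabled."""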
    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass

class AmpOptimizer:
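    """Wraps an optimizer with autocast, optional fp16 loss scaling, and gradient clipping."""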
    def __init__(
        self,
        model_name_3letters: str, mixed_precision: int,
        optimizer: torch.optim.Optimizer, model_maybe_fsdp: Union[torch.nn.Module, FSDP],
        r_accu: float, grad_clip: float, zero: int,
    ):
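        # mixed_precision: 0 disables AMP; 2 selects bf16; any other positive value selects fp16
        # (a value > 128 additionally serves as the maximum loss scale)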
        self.enable_amp = mixed_precision > 0
        self.zero = zero
        if self.enable_amp:
            self.using_fp16_rather_bf16 = mixed_precision != 2
            self.max_sc = float(mixed_precision if mixed_precision > 128 else 32768)
            
            # disable the autocast cast cache under ZeRO/FSDP, where sharded parameters
            # can make cached casts stale
            self.amp_ctx = torch.autocast('cuda', enabled=True, dtype=torch.float16 if self.using_fp16_rather_bf16 else torch.bfloat16, cache_enabled=self.zero == 0)
            if self.using_fp16_rather_bf16:
                # fp16 has a narrow dynamic range, so loss scaling is required
                self.scaler = torch.cuda.amp.GradScaler(init_scale=2. ** 11, growth_interval=1000)
            else:
                self.scaler = None  # bf16 has a wide dynamic range; no loss scaling needed
        else:
            self.using_fp16_rather_bf16 = True
            self.amp_ctx = NullCtx()
            self.scaler = None
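        
        # sanity check: all ranks must agree on the AMP configuration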
        t = torch.zeros(dist.get_world_size())
        t[dist.get_rank()] = float(self.enable_amp)
        dist.allreduce(t)
        assert round(t.sum().item()) in {0, dist.get_world_size()}, f'enable_amp: {t}'
        
        t = torch.zeros(dist.get_world_size())
        t[dist.get_rank()] = float(self.using_fp16_rather_bf16)
        dist.allreduce(t)
        assert round(t.sum().item()) in {0, dist.get_world_size()}, f'using_fp16_rather_bf16: {t}'
        
        self.model_name_3letters = model_name_3letters
        self.optimizer, self.model_maybe_fsdp = optimizer, model_maybe_fsdp
        self.r_accu = r_accu  # gradient-accumulation ratio, used to rescale each micro-batch loss
        
        self.paras = self.names = ...  # placeholders (Ellipsis)
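        
        # a grad_clip value above 100 encodes per-parameter clipping; the actual threshold is grad_clip % 100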
        self.grad_clip, self.grad_clip_we = grad_clip, 0
        if self.grad_clip > 100:
            self.grad_clip %= 100
            self.per_param = True
        else:
            self.per_param = False
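        
        # "late" clipping: the optimizer clips internally and exposes `global_grad_norm`;
        # "early" clipping: clip manually before the optimizer step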
        self.early_clipping = grad_clip > 0 and not hasattr(optimizer, 'global_grad_norm')
        self.late_clipping = grad_clip > 0 and hasattr(optimizer, 'global_grad_norm')
        
        self.fp = None  # optional file handle for lightweight step logging
        self.last_orig_norm: torch.Tensor = torch.tensor(0.1)
    
    def backward_clip_step(
        self, ep: int, it: int, g_it: int, stepping: bool, loss: torch.Tensor, clip_decay_ratio=1, stable=False,
    ) -> Tuple[Optional[torch.Tensor], Optional[float]]:
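        # rescale the micro-batch loss by the gradient-accumulation ratio, then backward
        # (through the loss scaler when training in fp16)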
        loss = loss.mul(self.r_accu)
        orig_norm = scaler_sc = None
        if self.scaler is not None:
            self.scaler.scale(loss).backward(retain_graph=False, create_graph=False)
        else:
            loss.backward(retain_graph=False, create_graph=False)
        
        if stepping:
            # fp16 gradients must be unscaled before any norm-based clipping
            if self.scaler is not None: self.scaler.unscale_(self.optimizer)
            
            skipped, orig_norm = 0, self.last_orig_norm
            
            if self.fp is not None:
                # periodically truncate the log file so it stays small
                if g_it % 10 == 0: self.fp.seek(0); self.fp.truncate(0)
                self.fp.write(f'<ep{ep} it{it} {g_it}>\n'); self.fp.flush()
            if self.early_clipping:
                c = self.grad_clip * clip_decay_ratio
                if self.zero:
                    # FSDP shards gradients, so its own clip_grad_norm_ must be used
                    orig_norm: Optional[torch.Tensor] = self.model_maybe_fsdp.clip_grad_norm_(c)
                else:
                    orig_norm: Optional[torch.Tensor] = torch.nn.utils.clip_grad_norm_(self.model_maybe_fsdp.parameters(), c)
            
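            # fp16 path: GradScaler.step skips the optimizer step if infs/NaNs were found,
            # and update() then grows or shrinks the loss scale accordingly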
            if self.scaler is not None:
                self.scaler: torch.cuda.amp.GradScaler
                if self.zero:
                    # with sharded (ZeRO/FSDP) gradients each rank only sees its own shard's
                    # inf/nan flags; all-reduce them so every rank agrees on skipping the step
                    for optimizer_state in self.scaler._per_optimizer_states.values():
                        for t in optimizer_state['found_inf_per_device'].values():
                            dist.allreduce(t)
                
                self.scaler.step(self.optimizer)
                scaler_sc: Optional[float] = self.scaler.get_scale()
                if scaler_sc > self.max_sc:
                    # cap the loss scale so it cannot grow without bound
                    self.scaler.update(new_scale=self.max_sc)
                else:
                    self.scaler.update()
                try:
                    scaler_sc = float(math.log2(scaler_sc))  # report the scale as log2
                except Exception as e:
                    print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True)
                    time.sleep(1)
                    print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True)
                    raise e
            else:
                self.optimizer.step()
            
            if self.late_clipping:
                orig_norm: Optional[torch.Tensor] = self.optimizer.global_grad_norm
            self.last_orig_norm = orig_norm
            # reset gradients once the step has been taken, so the next
            # accumulation window starts clean
            self.optimizer.zero_grad(set_to_none=True)
        
        return orig_norm, scaler_sc
    
    def state_dict(self):
        return {
            'optimizer': self.optimizer.state_dict()
        } if self.scaler is None else {
            'scaler': self.scaler.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }
    
    def load_state_dict(self, state, strict=True):
        if self.scaler is not None:
            # a failed scaler load is recoverable (the scale simply re-adapts), so only warn
            try: self.scaler.load_state_dict(state['scaler'])
            except Exception as e: print(f'[fp16 load_state_dict err] {e}')
        self.optimizer.load_state_dict(state['optimizer'])
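

# A minimal usage sketch (illustrative only): `build_model`, `batch`, and the optimizer
# settings below are hypothetical stand-ins; the real training loop lives elsewhere.
#
#   model = build_model().cuda()
#   amp_opt = AmpOptimizer(
#       model_name_3letters='inf', mixed_precision=1,
#       optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4),
#       model_maybe_fsdp=model, r_accu=1.0, grad_clip=2.0, zero=0,
#   )
#   with amp_opt.amp_ctx:                      # autocast forward pass
#       loss = model(batch)
#   orig_norm, scale_log2 = amp_opt.backward_clip_step(
#       ep=0, it=0, g_it=0, stepping=True, loss=loss,
#   )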