# miqa/utils/misc.py
import os
import random
import shutil

import numpy as np
import torch
import torch.distributed as dist
from torch import inf


def setup_seed(seed):
    """Seed Python, NumPy, and PyTorch RNGs and make cuDNN deterministic."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning for reproducible kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
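
# Usage sketch for setup_seed: call it once at startup, before models or
# dataloaders are built (args.seed is an assumed command-line argument):
#     setup_seed(args.seed)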


def save_checkpoint(args, state, is_best):
    """Save the latest training state and copy it aside when it is the best so far."""
    os.makedirs(os.path.join(args.output_dir, 'checkpoints'), exist_ok=True)
    filename = os.path.join(args.output_dir, 'checkpoints', 'checkpoint.pth.tar')
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, os.path.join(args.output_dir, 'checkpoints',
                                               '{}_best.pth.tar'.format(args.run_name)))
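
# Usage sketch for save_checkpoint, assuming an argparse namespace with
# output_dir and run_name, and a caller-defined state dict:
#     state = {'epoch': epoch, 'model': model.state_dict(),
#              'optimizer': optimizer.state_dict()}
#     save_checkpoint(args, state, is_best=val_score > best_score)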


def get_grad_norm(parameters, norm_type=2):
    """Return the total norm_type-norm of the gradients as a Python float."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm
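
# Usage sketch for get_grad_norm: call after loss.backward() and before the
# optimizer step (model is an assumed nn.Module):
#     grad_norm = get_grad_norm(model.parameters())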


def reduce_tensor(tensor):
    """Sum a tensor across all distributed processes and average by world size."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt
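
# Usage sketch for reduce_tensor, assuming torch.distributed has already been
# initialized (e.g. via dist.init_process_group):
#     avg_loss = reduce_tensor(loss.detach())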


def ampscaler_get_grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Tensor-valued gradient norm, mirroring torch.nn.utils.clip_grad_norm_."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.0)
    device = parameters[0].grad.device
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(
            torch.stack(
                [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]
            ),
            norm_type,
        )
    return total_norm
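
# ampscaler_get_grad_norm is the norm helper used by the scaler wrapper below
# when no clipping is requested; it can also be called on its own
# (model is an assumed nn.Module):
#     norm = ampscaler_get_grad_norm(model.parameters())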


class NativeScalerWithGradNormCount:
    """Wrapper around torch.cuda.amp.GradScaler that also reports the grad norm."""

    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(
        self,
        loss,
        optimizer,
        clip_grad=None,
        parameters=None,
        create_graph=False,
        update_grad=True,
    ):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                # Unscale the gradients of the optimizer's assigned params in-place.
                self._scaler.unscale_(optimizer)
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = ampscaler_get_grad_norm(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)
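
# Usage sketch for NativeScalerWithGradNormCount in an AMP training loop
# (model, images, targets, and the clip value are assumptions for illustration):
#     loss_scaler = NativeScalerWithGradNormCount()
#     optimizer.zero_grad()
#     with torch.cuda.amp.autocast():
#         loss = model(images, targets)
#     grad_norm = loss_scaler(loss, optimizer, clip_grad=5.0,
#                             parameters=model.parameters())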


def auto_resume_helper(output_dir):
    """Return the most recently modified .pth checkpoint in output_dir, or None."""
    checkpoints = os.listdir(output_dir)
    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith("pth")]
    print(f"All checkpoints found in {output_dir}: {checkpoints}")
    if len(checkpoints) > 0:
        latest_checkpoint = max(
            [os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime
        )
        print(f"The latest checkpoint found: {latest_checkpoint}")
        resume_file = latest_checkpoint
    else:
        resume_file = None
    return resume_file
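
# Usage sketch for auto_resume_helper (note that it only matches files ending
# in "pth", so the .pth.tar files written by save_checkpoint above are ignored):
#     resume_file = auto_resume_helper(args.output_dir)
#     if resume_file is not None:
#         checkpoint = torch.load(resume_file, map_location='cpu')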


if __name__ == '__main__':
    pass
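    # Minimal smoke test, assuming single-process CPU execution; the
    # distributed, AMP, and checkpoint helpers need a real training setup.
    setup_seed(0)
    p = torch.nn.Parameter(torch.randn(3))
    p.sum().backward()
    print('grad norm:', get_grad_norm([p]))
    print('amp grad norm:', ampscaler_get_grad_norm([p]))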