zhouyik's picture
Upload folder using huggingface_hub
032e687 verified
# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Engine utilities."""
import collections
import functools
import pickle
import torch
from torch import nn
import numpy as np
from third_parts.tokenize_anything.utils import logging
# Process group used for data parallelism; stays None until create_ddp_group() assigns it.
GLOBAL_DDP_GROUP = None
def count_params(module, trainable=True, unit="M"):
    """Count the parameters of a module, scaled to the requested unit.

    Args:
        module: Object exposing ``parameters()``.
        trainable: If true, only parameters with ``requires_grad`` set are
            counted; otherwise every parameter is counted.
        unit: ``"M"`` to report millions, ``"B"`` to report billions.

    Returns:
        The number of selected parameter elements divided by the unit scale.

    Raises:
        KeyError: If ``unit`` is neither ``"M"`` nor ``"B"``.
    """
    total = 0
    for tensor in module.parameters():
        # A frozen parameter is only counted when the caller asks for all of them.
        if tensor.requires_grad or not trainable:
            total += tensor.numel()
    divisor = {"M": 1e6, "B": 1e9}[unit]
    return total / divisor
def freeze_module(module, trainable=False):
    """Switch a module between frozen and trainable state.

    Args:
        module: Module whose mode and parameters are updated in place.
        trainable: If false (the default), the module is put in eval mode and
            every parameter stops requiring gradients; if true, the module is
            put in train mode and every parameter requires gradients.
    """
    if trainable:
        module.train()
    else:
        module.eval()
    for weight in module.parameters():
        weight.requires_grad = trainable
def get_device(index):
    """Pick the best available device for the given index.

    CUDA is preferred; otherwise each backend listed below is probed through
    ``torch.backends``; when none is available the CPU device is returned.

    Args:
        index: Device index passed to ``torch.device`` for non-CPU devices.

    Returns:
        A ``torch.device`` for CUDA, MPS, or CPU.
    """
    if torch.cuda.is_available():
        return torch.device("cuda", index)
    for backend_name in ("mps",):
        try:
            backend = getattr(torch.backends, backend_name)
            if not backend.is_available():
                continue
            return torch.device(backend_name, index)
        except AttributeError:
            # This torch build does not expose the backend; try the next one.
            continue
    return torch.device("cpu")
def get_param_groups(model, layer_lr_decay=1.0):
    """Bucket trainable parameters by their optimizer attributes.

    Parameters that end up with the same ``lr_scale`` / ``weight_decay``
    attributes share one group. Each parameter is placed at most once, even
    when it is reachable through several modules.

    Args:
        model: Module to traverse with ``named_modules()``.
            ``model.image_encoder`` is only accessed when
            ``layer_lr_decay < 1.0``.
        layer_lr_decay: When below 1.0 and the image encoder provides
            ``get_lr_scale``, that method (called with ``decay=layer_lr_decay``)
            supplies a per-parameter ``lr_scale``.

    Returns:
        A list of dicts, each holding the shared attributes plus a ``"params"``
        list, in order of first appearance.
    """
    seen = set()
    grouped = collections.OrderedDict()
    scale_fn = None
    if layer_lr_decay < 1.0 and hasattr(model.image_encoder, "get_lr_scale"):
        scale_fn = functools.partial(model.image_encoder.get_lr_scale, decay=layer_lr_decay)
    no_decay_layers = (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm, nn.LayerNorm)
    for mod_name, mod in model.named_modules():
        is_norm_layer = isinstance(mod, no_decay_layers)
        for p_name, p in mod.named_parameters(recurse=False):
            if not p.requires_grad or p in seen:
                continue
            seen.add(p)
            attrs = collections.OrderedDict()
            if scale_fn is not None:
                attrs["lr_scale"] = scale_fn(f"{mod_name}.{p_name}")
            # An explicit attribute on the parameter wins over the layer-wise scale.
            if hasattr(p, "lr_scale"):
                attrs["lr_scale"] = p.lr_scale
            if getattr(p, "no_weight_decay", False) or is_norm_layer:
                attrs["weight_decay"] = 0
            # The group key is the textual form of the attributes, e.g. "weight_decay:0".
            key = "/".join(["%s:%s" % item for item in attrs.items()])
            if key not in grouped:
                grouped[key] = {**attrs, "params": []}
            grouped[key]["params"].append(p)
    return list(grouped.values())
def load_weights(module, weights_file, prefix_removed="", strict=True):
    """Load a weights file into ``module``.

    Files ending in ``.pkl`` are unpickled and every value is converted with
    ``torch.as_tensor``; any other file is read with ``torch.load`` mapped to
    the CPU.

    Args:
        module: Module receiving the weights through ``load_state_dict``.
        weights_file: Path of the weights file (string or path-like). A falsy
            value makes the call a no-op.
        prefix_removed: Prefix stripped from the *start* of each key. Keys
            that do not start with it are left unchanged.
        strict: Forwarded to ``module.load_state_dict``.
    """
    if not weights_file:
        return
    if str(weights_file).endswith(".pkl"):
        # NOTE(security): unpickling can execute arbitrary code; only load
        # ``.pkl`` weights that come from a trusted source.
        with open(weights_file, "rb") as f:
            state_dict = pickle.load(f)
        for k, v in state_dict.items():
            state_dict[k] = torch.as_tensor(v)
    else:
        state_dict = torch.load(weights_file, map_location="cpu")
    if prefix_removed:
        # Keep the container type (e.g. OrderedDict) of the loaded state dict.
        renamed = type(state_dict)()
        prefix_len = len(prefix_removed)
        for k in list(state_dict.keys()):
            # Strip only a leading prefix; ``str.replace`` would also delete
            # occurrences in the middle of a key and corrupt nested names.
            new_key = k[prefix_len:] if k.startswith(prefix_removed) else k
            renamed[new_key] = state_dict.pop(k)
        state_dict = renamed
    module.load_state_dict(state_dict, strict=strict)
def manual_seed(seed, device_and_seed=None):
    """Seed the torch generator and, optionally, NumPy and the device generator.

    Args:
        seed: Seed passed to ``torch.manual_seed``.
        device_and_seed: Optional ``(device_index, device_seed)`` pair. When
            given, ``device_seed`` seeds NumPy and, if the device resolved by
            ``get_device(device_index)`` is CUDA or MPS, that backend as well.
    """
    torch.manual_seed(seed)
    if device_and_seed is None:
        return
    device_index, device_seed = device_and_seed
    backend = get_device(device_index).type
    np.random.seed(device_seed)
    if backend == "cuda" or backend == "mps":
        # torch.cuda / torch.mps each expose their own manual_seed.
        getattr(torch, backend).manual_seed(device_seed)
def synchronize_device(device):
    """Block until pending work on ``device`` has finished.

    Only CUDA and MPS devices are synchronized; any other device type is a
    no-op.

    Args:
        device: Object with a ``type`` attribute, e.g. a ``torch.device``.
    """
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    elif device.type == "mps":
        # torch.mps.synchronize() takes no arguments; forwarding the device
        # (as a generic getattr dispatch would) raises TypeError.
        torch.mps.synchronize()
def create_ddp_group(cfg, ranks=None, devices=None):
    """Create the data-parallel process group and bind this process to a GPU.

    Initializes ``torch.distributed`` with the NCCL backend if it is not
    initialized yet, stores the chosen device in ``cfg.GPU_ID``, makes it the
    current CUDA device, and records the new group in ``GLOBAL_DDP_GROUP``.

    Args:
        cfg: Config object; ``cfg.NUM_GPUS`` is read when ``ranks`` is empty
            and ``cfg.GPU_ID`` is written.
        ranks: Ranks forming the group. Defaults to ``0 .. cfg.NUM_GPUS - 1``.
        devices: Device index list, looked up with the global rank. Defaults
            to ``i % 8`` for each position in ``ranks``.

    Returns:
        The newly created process group.

    NOTE(review): ``devices`` is indexed by the global rank rather than by the
    position inside ``ranks``, so this presumably assumes ``ranks`` starts at
    0 -- confirm. The ``% 8`` looks like an 8-GPUs-per-node assumption --
    confirm.
    """
    global GLOBAL_DDP_GROUP
    dist = torch.distributed
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    world_rank = dist.get_rank()
    ranks = ranks or list(range(cfg.NUM_GPUS))
    # Only the first rank of the group is passed True for set_root.
    logging.set_root(world_rank == ranks[0])
    devices = devices or [position % 8 for position in range(len(ranks))]
    cfg.GPU_ID = devices[world_rank]
    torch.cuda.set_device(cfg.GPU_ID)
    GLOBAL_DDP_GROUP = dist.new_group(ranks)
    return GLOBAL_DDP_GROUP
def get_ddp_group():
    """Return the current data-parallel process group.

    Returns:
        The value of ``GLOBAL_DDP_GROUP``; ``None`` when no group has been
        created.
    """
    current_group = GLOBAL_DDP_GROUP
    return current_group
def get_ddp_rank():
    """Return this process's rank inside the data-parallel group.

    Returns:
        ``0`` when no group has been created, otherwise the rank reported by
        ``torch.distributed.get_rank`` for that group.
    """
    group = get_ddp_group()
    return 0 if group is None else torch.distributed.get_rank(group)
def apply_ddp(module):
    """Wrap a module for distributed data parallelism when a group exists.

    Args:
        module: Module to wrap.

    Returns:
        ``module`` unchanged when no data-parallel group has been created,
        otherwise a ``DistributedDataParallel`` wrapper bound to that group.
    """
    group = get_ddp_group()
    if group is not None:
        module = torch.nn.parallel.DistributedDataParallel(module, process_group=group)
    return module
def apply_deepspeed(model, optimizer, ds_config=None, log_lvl="WARNING"):
    """Wrap a model with a deepspeed engine, keeping BN running stats in float32.

    Args:
        model: Model handed to ``deepspeed.initialize``.
        optimizer: Optimizer handed to ``deepspeed.initialize``.
        ds_config: Deepspeed config. A falsy value skips deepspeed entirely.
        log_lvl: Level set on the deepspeed logger.

    Returns:
        ``None`` when ``ds_config`` is falsy, otherwise the first element of
        the ``deepspeed.initialize`` result.
    """
    if not ds_config:
        return None
    import deepspeed

    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.SyncBatchNorm)

    def running_stats():
        """Yield (layer, buffer name) for every non-empty BN running-stat buffer."""
        for layer in model.modules():
            if not isinstance(layer, bn_types):
                continue
            for name, stat in layer._buffers.items():
                if stat is not None and "running" in name:
                    yield layer, name

    # Store the float32 batchnorm stats to avoid post-upcasting.
    saved = {}
    for layer, name in running_stats():
        saved["{}_{}".format(id(layer), name)] = layer._buffers[name]
    deepspeed.logger.setLevel(log_lvl)
    engine = deepspeed.initialize(None, model, optimizer, config=ds_config)[0]
    # Revert the downcasting batchnorm stats in deepspeed initialization.
    # torch>=2.1 is required, which fixes PR#98332 for 16bit BN weight/bias.
    for layer, name in running_stats():
        layer._buffers[name] = saved["{}_{}".format(id(layer), name)]
        assert layer._buffers[name].dtype == torch.float32
    return engine