dianecy's picture
Upload folder using huggingface_hub
ea1014e verified
import os
import random
import numpy as np
from PIL import Image
from loguru import logger
import sys
import inspect
from timm.scheduler.cosine_lr import CosineLRScheduler
import torch
from torch import nn
import torch.distributed as dist
def init_random_seed(seed=None, device='cuda', rank=0, world_size=1):
    """Choose the random seed, keeping it identical on every rank.

    A seed supplied by the caller is handed back untouched. Otherwise a
    fresh one is drawn from NumPy; when more than one process takes part,
    the value drawn on rank 0 is broadcast so all ranks agree on it (see
    https://github.com/open-mmlab/mmdetection/issues/6339).

    Args:
        seed (int, optional): Seed chosen by the user. Default: None.
        device (str): Device on which the broadcast tensor is created.
        rank (int): Rank of the calling process.
        world_size (int): Number of participating processes.

    Returns:
        int: The seed to use.
    """
    if seed is not None:
        return seed

    drawn = np.random.randint(2**31)
    if world_size == 1:
        return drawn

    # Only rank 0's draw survives; the other ranks send a placeholder that
    # the broadcast overwrites.
    payload = drawn if rank == 0 else 0
    shared = torch.tensor(payload, dtype=torch.int32, device=device)
    dist.broadcast(shared, src=0)
    return shared.item()
def set_random_seed(seed, deterministic=False):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Args:
        seed (int): Value passed to every generator.
        deterministic (bool): When True, additionally switch cuDNN to
            deterministic mode and turn its benchmark mode off.
            Default: False.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    if not deterministic:
        return

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Gather `tensor` from every process and concatenate along dim 0.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    local = tensor.contiguous()
    n_ranks = torch.distributed.get_world_size()

    # One receive buffer per rank, each shaped like the local tensor.
    buffers = []
    for _ in range(n_ranks):
        buffers.append(torch.ones_like(local))

    torch.distributed.all_gather(buffers, local, async_op=False)
    return torch.cat(buffers, dim=0)
@torch.no_grad()
def concat_all_gather_varsize(tensor):
    """
    Performs all_gather on tensors whose first dimension differs across ranks.

    Each rank pads its tensor with zeros up to the largest first-dimension
    size, the padded tensors are gathered, and the padding is stripped again
    before concatenating along dim 0 in rank order.

    Args:
        tensor (torch.Tensor): Local tensor; only dim 0 may differ between
            ranks.

    Returns:
        torch.Tensor: Concatenation of every rank's rows. If every rank
        contributed zero rows the result has shape ``(0, *tensor.shape[1:])``
        so the trailing dimensions are preserved.
    """
    tensor = tensor.contiguous()
    world_size = torch.distributed.get_world_size()

    # Exchange the per-rank row counts first.
    local_size = torch.tensor([tensor.shape[0]], dtype=torch.int64, device=tensor.device)
    size_buffers = [torch.zeros_like(local_size) for _ in range(world_size)]
    torch.distributed.all_gather(size_buffers, local_size)
    # Plain Python ints are all that is needed for slicing below.
    sizes = [int(s.item()) for s in size_buffers]
    max_size = max(sizes)

    # Pad the tensor to match max_size
    padded = tensor.new_zeros((max_size, *tensor.shape[1:]))
    padded[:tensor.shape[0]] = tensor

    # Gather all padded tensors
    gathered = [torch.zeros_like(padded) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, padded)

    # Keep only the real (unpadded) rows of each rank's chunk.
    valid = [chunk[:size] for chunk, size in zip(gathered, sizes) if size > 0]
    if not valid:
        return tensor.new_empty((0, *tensor.shape[1:]))
    return torch.cat(valid, dim=0)
@torch.no_grad()
def concat_all_gather_varsize_optimized(tensor):
    """
    Variable-size all_gather that strips padding per rank before concatenating.

    Each rank pads its tensor with zeros up to the largest first-dimension
    size and the padded tensors are gathered. The valid rows are then taken
    from the front of each rank's own chunk. Splitting the concatenated
    padded buffer by the true sizes is not correct: that buffer holds
    ``world_size * max_size`` rows, so the split sizes neither sum to its
    length nor line up with the chunk boundaries.

    Args:
        tensor (torch.Tensor): Local tensor; only dim 0 may differ between
            ranks.

    Returns:
        torch.Tensor: Every rank's rows concatenated along dim 0 in rank
        order, with shape ``(sum_of_sizes, *tensor.shape[1:])``.
    """
    tensor = tensor.contiguous()
    world_size = torch.distributed.get_world_size()

    # Exchange the per-rank row counts first.
    local_size = torch.tensor([tensor.shape[0]], dtype=torch.int64, device=tensor.device)
    size_buffers = [torch.zeros_like(local_size) for _ in range(world_size)]
    torch.distributed.all_gather(size_buffers, local_size)
    sizes = [int(s.item()) for s in size_buffers]
    max_size = max(sizes)

    # Pad tensor
    padded = tensor.new_zeros((max_size, *tensor.shape[1:]))
    padded[:tensor.shape[0]] = tensor

    gathered = [torch.zeros_like(padded) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, padded)

    # Slice each chunk to its real length; the slices are views, so the only
    # copy is the final concatenation.
    valid = [chunk[:size] for chunk, size in zip(gathered, sizes)]
    return torch.cat(valid, dim=0)
def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed NumPy and Python's `random` for one data-loading worker.

    The seed used is ``num_workers * rank + worker_id + seed``.

    Args:
        worker_id (int): Index of the worker.
        num_workers (int): Number of workers per process.
        rank (int): Rank of the owning process.
        seed (int): Base seed chosen by the user.
    """
    worker_offset = num_workers * rank + worker_id
    worker_seed = worker_offset + seed
    for apply_seed in (np.random.seed, random.seed):
        apply_seed(worker_seed)
class AverageMeter(object):
    """Keeps the latest value of a metric together with its running mean."""

    def __init__(self, name, fmt=":f"):
        """
        Args:
            name (str): Label printed in front of the value.
            fmt (str): Format spec appended to the value fields, e.g. ":.4f".
        """
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the current value, mean, weighted sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as the newest value, counted `n` times in the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # The meter named "Lr" prints only its current value; every other
        # meter also prints the running mean in parentheses.
        pieces = ["{name}={val", self.fmt, "}"]
        if self.name != "Lr":
            pieces += [" ({avg", self.fmt, "})"]
        return "".join(pieces).format(**vars(self))
class ProgressMeter(object):
    """Logs a "[batch/total]" counter followed by a list of meters."""

    def __init__(self, num_batches, meters, prefix=""):
        """
        Args:
            num_batches (int): Total number of batches shown after the slash.
            meters (list): Objects whose `str()` is appended to each line.
            prefix (str): Text placed in front of the counter.
        """
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Log one line for batch index `batch` at INFO level."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        logger.info(" ".join([header, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the running index to the width of the total so columns align.
        width = len(str(num_batches // 1))
        slot = "{:%dd}" % width
        return "[%s/%s]" % (slot, slot.format(num_batches))
def trainMetricGPU(output, target, threshold=0.35, pr_iou=0.5):
    """Mean IoU and precision of thresholded predictions over a batch.

    `output` is passed through a sigmoid and binarised at `threshold`; IoU is
    computed per sample over all dimensions after the first.

    Args:
        output (torch.Tensor): Raw scores with 2, 3 or 4 dimensions.
        target (torch.Tensor): Ground truth with the same shape as `output`;
            non-zero entries count as foreground.
        threshold (float): Cut-off applied to the sigmoid output.
        pr_iou (float): A sample counts as a hit when its IoU exceeds this.

    Returns:
        tuple: (mean IoU, fraction of samples with IoU > `pr_iou`), both
        multiplied by 100.
    """
    assert (output.dim() in [2, 3, 4])
    assert output.shape == target.shape

    scores = output.flatten(1)
    labels = target.flatten(1)

    # sigmoid returns a new tensor, so the in-place fills below never touch
    # the caller's `output`.
    probs = torch.sigmoid(scores)
    probs.masked_fill_(probs < threshold, 0.)
    probs.masked_fill_(probs >= threshold, 1.)

    pred_mask = probs.bool()
    label_mask = labels.bool()

    # inter & union, one value per sample
    inter = (pred_mask & label_mask).sum(dim=1)
    union = (pred_mask | label_mask).sum(dim=1)
    ious = inter / (union + 1e-6)  # 0 ~ 1

    mean_iou = ious.mean()
    hit_rate = (ious > pr_iou).float().mean()
    return 100. * mean_iou, 100. * hit_rate
def ValMetricGPU(output, target, threshold=0.35):
    """IoU of a single thresholded prediction against its target.

    Args:
        output (torch.Tensor): Raw scores; the first dimension must be 1.
        target (torch.Tensor): Ground truth; non-zero entries count as
            foreground.
        threshold (float): Cut-off applied to the sigmoid output.

    Returns:
        torch.Tensor: One IoU value per entry of the first dimension.
    """
    assert output.size(0) == 1

    scores = output.flatten(1)
    labels = target.flatten(1)

    # sigmoid returns a new tensor, so the in-place fills below never touch
    # the caller's `output`.
    probs = torch.sigmoid(scores)
    probs.masked_fill_(probs < threshold, 0.)
    probs.masked_fill_(probs >= threshold, 1.)

    pred_mask = probs.bool()
    label_mask = labels.bool()

    # inter & union
    inter = (pred_mask & label_mask).sum(dim=1)
    union = (pred_mask | label_mask).sum(dim=1)
    return inter / (union + 1e-6)  # 0 ~ 1
def intersectionAndUnionGPU(output, target, K, threshold=0.5):
    """Intersection and union areas of class 1 for a thresholded prediction.

    `output` is passed through a sigmoid and binarised at `threshold`, then
    per-class histograms with `K` bins over [0, K - 1] are built for the
    matching entries, the prediction and the target.

    Args:
        output (torch.Tensor): Raw scores with 1, 2 or 3 dimensions.
        target (torch.Tensor): Labels with the same shape as `output`.
        K (int): Number of histogram bins (classes).
        threshold (float): Cut-off applied to the sigmoid output.

    Returns:
        tuple: (intersection, union) counts taken from bin index 1.
    """
    assert (output.dim() in [1, 2, 3])
    assert output.shape == target.shape

    flat_scores = output.view(-1)
    flat_target = target.view(-1)

    # sigmoid returns a new tensor, so the in-place fills below never touch
    # the caller's `output`.
    pred = torch.sigmoid(flat_scores)
    pred.masked_fill_(pred < threshold, 0.)
    pred.masked_fill_(pred >= threshold, 1.)

    def class_hist(values):
        # Count how many entries fall into each of the K class bins.
        return torch.histc(values.float(), bins=K, min=0, max=K - 1)

    # Entries where prediction and target agree, labelled by their class.
    agreed = pred[pred == flat_target]
    area_intersection = class_hist(agreed)
    area_output = class_hist(pred)
    area_target = class_hist(flat_target)
    area_union = area_output + area_target - area_intersection
    return area_intersection[1], area_union[1]
def group_weight(weight_group, module, lr):
    """Split a module's parameters into decayed and non-decayed groups.

    Weights of Linear and convolution layers go to the first group; their
    biases and all BatchNorm weights/biases go to a second group created with
    ``weight_decay=0``. Both groups are appended to `weight_group`.

    Args:
        weight_group (list): List of optimizer parameter groups; extended in
            place.
        module (nn.Module): Module whose parameters are grouped.
        lr (float): Learning rate stored in both groups.

    Returns:
        list: The same `weight_group` list.

    Raises:
        AssertionError: If `module` owns parameters outside the layer types
            handled here.
    """
    decayed, undecayed = [], []
    weighted_layers = (nn.Linear, nn.modules.conv._ConvNd)
    norm_layers = nn.modules.batchnorm._BatchNorm

    for layer in module.modules():
        if isinstance(layer, weighted_layers):
            decayed.append(layer.weight)
            if layer.bias is not None:
                undecayed.append(layer.bias)
        elif isinstance(layer, norm_layers):
            undecayed.extend(
                p for p in (layer.weight, layer.bias) if p is not None)

    # Every parameter must have landed in exactly one of the two groups.
    n_params = len(list(module.parameters()))
    assert n_params == len(decayed) + len(undecayed)

    weight_group.append(dict(params=decayed, lr=lr))
    weight_group.append(dict(params=undecayed, weight_decay=.0, lr=lr))
    return weight_group
def colorize(gray, palette):
    """Turn a label map into a palette-mode PIL image.

    Args:
        gray (numpy.ndarray): Label array; cast to uint8 before conversion.
        palette (list): Flat palette passed to `Image.putpalette`.

    Returns:
        PIL.Image.Image: Image in mode 'P' using `palette`.
    """
    label_image = Image.fromarray(gray.astype(np.uint8))
    paletted = label_image.convert('P')
    paletted.putpalette(palette)
    return paletted
def find_free_port():
    """Ask the OS for a currently unused TCP port.

    Returns:
        int: A port number that was free at the time of the call.
    """
    import socket

    # The context manager closes the socket even if bind() or getsockname()
    # raises, so no file descriptor is leaked.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Binding to port 0 will cause the OS to find an available port for us
        sock.bind(("", 0))
        port = sock.getsockname()[1]
    # NOTE: there is still a chance the port could be taken by other processes.
    return port
def get_caller_name(depth=0):
    """
    Look up the module name of a frame on the call stack.

    Args:
        depth (int): How many frames above the direct caller to inspect;
            0 means the direct caller. Default value: 0.

    Returns:
        str: `__name__` of the module that the selected frame runs in.
    """
    # Walking f_back by hand is a little faster than inspect.stack().
    caller = inspect.currentframe().f_back
    for _hop in range(depth):
        caller = caller.f_back
    module_globals = caller.f_globals
    return module_globals["__name__"]
class StreamToLoguru:
    """
    File-like object that forwards writes from selected modules to loguru.
    """

    def __init__(self, level="INFO", caller_names=("apex", "pycocotools")):
        """
        Args:
            level(str): loguru level used for forwarded lines.
                Default value: "INFO".
            caller_names(tuple): top-level module names whose writes are
                forwarded. Default value: (apex, pycocotools).
        """
        self.level = level
        self.linebuf = ""
        self.caller_names = caller_names

    def write(self, buf):
        """Log `buf` line by line if it comes from a listed module,
        otherwise pass it through to the original stdout."""
        caller_module = get_caller_name(depth=1)
        top_level_package = caller_module.split(".")[0]

        if top_level_package not in self.caller_names:
            sys.__stdout__.write(buf)
            return

        for line in buf.rstrip().splitlines():
            # depth=2 makes loguru report the writer's caller, not this method
            logger.opt(depth=2).log(self.level, line.rstrip())

    def flush(self):
        """Nothing is buffered, so there is nothing to flush."""
        pass
def redirect_sys_output(log_level="INFO"):
    """Replace sys.stderr and sys.stdout with one shared StreamToLoguru.

    Args:
        log_level (str): loguru level handed to the stream.
    """
    # Chained assignment binds left to right: stderr first, then stdout.
    sys.stderr = sys.stdout = StreamToLoguru(log_level)
def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"):
    """setup logger for training and testing.

    All existing loguru handlers are removed. On rank 0 a stderr handler and
    a file handler are installed; other ranks get no handlers. Finally
    stdout/stderr are redirected through loguru on every rank.

    Args:
        save_dir(str): location to save log file
        distributed_rank(int): device rank when multi-gpu environment
        filename (string): log save name.
        mode(str): log file write mode. `"o"` deletes an existing log file
            before logging starts; any other value leaves an existing file
            in place. default is `a`.

    Return:
        logger instance.
    """
    loguru_format = (
        "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
        "<level>{level: <8}</level> | "
        "<cyan>{name}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>")

    logger.remove()
    save_file = os.path.join(save_dir, filename)

    # only keep logger in rank0 process
    if distributed_rank == 0:
        # Only the rank that writes the file may delete it. If every rank
        # did, they would race on exists()/remove() and could unlink the
        # file after rank 0 has already opened it.
        if mode == "o" and os.path.exists(save_file):
            os.remove(save_file)
        logger.add(
            sys.stderr,
            format=loguru_format,
            level="INFO",
            enqueue=True,
        )
        logger.add(save_file)

    # redirect stdout/stderr to loguru
    redirect_sys_output("INFO")
    return logger
def build_scheduler(config, optimizer, n_iter_per_epoch):
    """Create a single-cycle cosine LR scheduler stepped per iteration.

    Args:
        config: Object providing `epochs`, `warmup_epochs`, `min_lr` and
            `warmup_lr`.
        optimizer: Optimizer whose learning rate is scheduled.
        n_iter_per_epoch (int): Iterations in one epoch; used to convert the
            epoch counts in `config` into step counts.

    Returns:
        CosineLRScheduler: The configured scheduler.
    """
    total_steps = int(config.epochs * n_iter_per_epoch)
    warmup_steps = int(config.warmup_epochs * n_iter_per_epoch)

    cosine_kwargs = dict(
        t_initial=total_steps,
        lr_min=config.min_lr,
        warmup_lr_init=config.warmup_lr,
        warmup_t=warmup_steps,
        cycle_limit=1,
        # step counts above are iterations, not epochs
        t_in_epochs=False,
    )
    return CosineLRScheduler(optimizer, **cosine_kwargs)
def collate_fn(batch):
    """Merge samples of (img, word_vec, mask, pad_mask, params) into a batch.

    The four tensor fields are concatenated along dim 0. When the first
    sample's params is a dict, its keys are collected across all samples into
    lists; the 'hardpos_emb' entry is stacked into one tensor if every value
    is a tensor.

    Args:
        batch (list): Samples, each a 5-tuple as described above.

    Returns:
        tuple: (images, word_vecs, masks, pad_masks, batched_params) where
        `batched_params` is a dict (empty if params are not dicts).
    """
    images, word_vecs, masks, pad_masks, params_list = zip(*batch)
    images, word_vecs, masks, pad_masks = (
        torch.cat(field) for field in (images, word_vecs, masks, pad_masks))

    # params batchify
    batched_params = {}
    if params_list and isinstance(params_list[0], dict):
        for key in params_list[0].keys():
            values = [params[key] for params in params_list]
            # sbert embeddings are stacked when they are all tensors
            if key == 'hardpos_emb' and all(
                    isinstance(v, torch.Tensor) for v in values):
                values = torch.stack(values)
            batched_params[key] = values

    return images, word_vecs, masks, pad_masks, batched_params