import os import random import numpy as np from PIL import Image from loguru import logger import sys import inspect from timm.scheduler.cosine_lr import CosineLRScheduler import torch from torch import nn import torch.distributed as dist def init_random_seed(seed=None, device='cuda', rank=0, world_size=1): """Initialize random seed.""" if seed is not None: return seed # Make sure all ranks share the same random seed to prevent # some potential bugs. Please refer to # https://github.com/open-mmlab/mmdetection/issues/6339 seed = np.random.randint(2**31) if world_size == 1: return seed if rank == 0: random_num = torch.tensor(seed, dtype=torch.int32, device=device) else: random_num = torch.tensor(0, dtype=torch.int32, device=device) dist.broadcast(random_num, src=0) return random_num.item() def set_random_seed(seed, deterministic=False): """Set random seed.""" random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) if deterministic: torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False @torch.no_grad() def concat_all_gather(tensor): """ Performs all_gather operation on the provided tensors. *** Warning ***: torch.distributed.all_gather has no gradient. """ tensor = tensor.contiguous() tensors_gather = [ torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) ] torch.distributed.all_gather(tensors_gather, tensor, async_op=False) output = torch.cat(tensors_gather, dim=0) return output @torch.no_grad() def concat_all_gather_varsize(tensor): """ Performs all_gather operation on tensors of varying sizes across distributed processes. Handles cases where tensors have different first-dimension sizes (batch size). """ tensor = tensor.contiguous() world_size = torch.distributed.get_world_size() local_size = torch.tensor([tensor.shape[0]], dtype=torch.int64, device=tensor.device) all_sizes = [torch.zeros_like(local_size) for _ in range(world_size)] torch.distributed.all_gather(all_sizes, local_size) all_sizes = torch.tensor([s.item() for s in all_sizes], device=tensor.device) max_size = all_sizes.max().item() # Pad the tensor to match max_size padded_tensor = torch.zeros((max_size, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) padded_tensor[:tensor.shape[0]] = tensor # Gather all padded tensors gathered_tensors = [torch.zeros_like(padded_tensor) for _ in range(world_size)] torch.distributed.all_gather(gathered_tensors, padded_tensor) gathered_tensors = torch.cat(gathered_tensors, dim=0) valid_tensors = [] start_idx = 0 for size in all_sizes: if size > 0: valid_tensors.append(gathered_tensors[start_idx:start_idx + size]) start_idx += max_size # Move to the next chunk return torch.cat(valid_tensors, dim=0) if valid_tensors else torch.empty(0, dtype=tensor.dtype, device=tensor.device) @torch.no_grad() def concat_all_gather_varsize_optimized(tensor): """ Optimized version of concat_all_gather_varsize. Uses torch.split() for efficiency in extracting valid tensors. """ tensor = tensor.contiguous() world_size = torch.distributed.get_world_size() local_size = torch.tensor([tensor.shape[0]], dtype=torch.int64, device=tensor.device) all_sizes = [torch.zeros_like(local_size) for _ in range(world_size)] torch.distributed.all_gather(all_sizes, local_size) all_sizes = torch.tensor([s.item() for s in all_sizes], device=tensor.device) max_size = all_sizes.max().item() # Pad tensor padded_tensor = torch.zeros((max_size, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) padded_tensor[:tensor.shape[0]] = tensor gathered_tensors = [torch.zeros_like(padded_tensor) for _ in range(world_size)] torch.distributed.all_gather(gathered_tensors, padded_tensor) gathered_tensors = torch.cat(gathered_tensors, dim=0) # Efficient slicing using torch.split() split_tensors = torch.split(gathered_tensors, all_sizes.tolist()) return torch.cat(split_tensors, dim=0) if split_tensors else torch.empty(0, dtype=tensor.dtype, device=tensor.device) def worker_init_fn(worker_id, num_workers, rank, seed): # The seed of each worker equals to # num_worker * rank + worker_id + user_seed worker_seed = num_workers * rank + worker_id + seed np.random.seed(worker_seed) random.seed(worker_seed) class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self, name, fmt=":f"): self.name = name self.fmt = fmt self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count def __str__(self): if self.name == "Lr": fmtstr = "{name}={val" + self.fmt + "}" else: fmtstr = "{name}={val" + self.fmt + "} ({avg" + self.fmt + "})" return fmtstr.format(**self.__dict__) class ProgressMeter(object): def __init__(self, num_batches, meters, prefix=""): self.batch_fmtstr = self._get_batch_fmtstr(num_batches) self.meters = meters self.prefix = prefix def display(self, batch): entries = [self.prefix + self.batch_fmtstr.format(batch)] entries += [str(meter) for meter in self.meters] logger.info(" ".join(entries)) def _get_batch_fmtstr(self, num_batches): num_digits = len(str(num_batches // 1)) fmt = "{:" + str(num_digits) + "d}" return "[" + fmt + "/" + fmt.format(num_batches) + "]" def trainMetricGPU(output, target, threshold=0.35, pr_iou=0.5): assert (output.dim() in [2, 3, 4]) assert output.shape == target.shape output = output.flatten(1) target = target.flatten(1) output = torch.sigmoid(output) output[output < threshold] = 0. output[output >= threshold] = 1. # inter & union inter = (output.bool() & target.bool()).sum(dim=1) # b union = (output.bool() | target.bool()).sum(dim=1) # b ious = inter / (union + 1e-6) # 0 ~ 1 # iou & pr@5 iou = ious.mean() prec = (ious > pr_iou).float().mean() return 100. * iou, 100. * prec def ValMetricGPU(output, target, threshold=0.35): assert output.size(0) == 1 output = output.flatten(1) target = target.flatten(1) output = torch.sigmoid(output) output[output < threshold] = 0. output[output >= threshold] = 1. # inter & union inter = (output.bool() & target.bool()).sum(dim=1) # b union = (output.bool() | target.bool()).sum(dim=1) # b ious = inter / (union + 1e-6) # 0 ~ 1 return ious def intersectionAndUnionGPU(output, target, K, threshold=0.5): # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. assert (output.dim() in [1, 2, 3]) assert output.shape == target.shape output = output.view(-1) target = target.view(-1) output = torch.sigmoid(output) output[output < threshold] = 0. output[output >= threshold] = 1. intersection = output[output == target] area_intersection = torch.histc(intersection.float(), bins=K, min=0, max=K - 1) area_output = torch.histc(output.float(), bins=K, min=0, max=K - 1) area_target = torch.histc(target.float(), bins=K, min=0, max=K - 1) area_union = area_output + area_target - area_intersection return area_intersection[1], area_union[1] def group_weight(weight_group, module, lr): group_decay = [] group_no_decay = [] for m in module.modules(): if isinstance(m, nn.Linear): group_decay.append(m.weight) if m.bias is not None: group_no_decay.append(m.bias) elif isinstance(m, nn.modules.conv._ConvNd): group_decay.append(m.weight) if m.bias is not None: group_no_decay.append(m.bias) elif isinstance(m, nn.modules.batchnorm._BatchNorm): if m.weight is not None: group_no_decay.append(m.weight) if m.bias is not None: group_no_decay.append(m.bias) assert len(list( module.parameters())) == len(group_decay) + len(group_no_decay) weight_group.append(dict(params=group_decay, lr=lr)) weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr)) return weight_group def colorize(gray, palette): # gray: numpy array of the label and 1*3N size list palette color = Image.fromarray(gray.astype(np.uint8)).convert('P') color.putpalette(palette) return color def find_free_port(): import socket sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # Binding to port 0 will cause the OS to find an available port for us sock.bind(("", 0)) port = sock.getsockname()[1] sock.close() # NOTE: there is still a chance the port could be taken by other processes. return port def get_caller_name(depth=0): """ Args: depth (int): Depth of caller conext, use 0 for caller depth. Default value: 0. Returns: str: module name of the caller """ # the following logic is a little bit faster than inspect.stack() logic frame = inspect.currentframe().f_back for _ in range(depth): frame = frame.f_back return frame.f_globals["__name__"] class StreamToLoguru: """ stream object that redirects writes to a logger instance. """ def __init__(self, level="INFO", caller_names=("apex", "pycocotools")): """ Args: level(str): log level string of loguru. Default value: "INFO". caller_names(tuple): caller names of redirected module. Default value: (apex, pycocotools). """ self.level = level self.linebuf = "" self.caller_names = caller_names def write(self, buf): full_name = get_caller_name(depth=1) module_name = full_name.rsplit(".", maxsplit=-1)[0] if module_name in self.caller_names: for line in buf.rstrip().splitlines(): # use caller level log logger.opt(depth=2).log(self.level, line.rstrip()) else: sys.__stdout__.write(buf) def flush(self): pass def redirect_sys_output(log_level="INFO"): redirect_logger = StreamToLoguru(log_level) sys.stderr = redirect_logger sys.stdout = redirect_logger def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"): """setup logger for training and testing. Args: save_dir(str): location to save log file distributed_rank(int): device rank when multi-gpu environment filename (string): log save name. mode(str): log file write mode, `append` or `override`. default is `a`. Return: logger instance. """ loguru_format = ( "{time:YYYY-MM-DD HH:mm:ss} | " "{level: <8} | " "{name}:{line} - {message}") logger.remove() save_file = os.path.join(save_dir, filename) if mode == "o" and os.path.exists(save_file): os.remove(save_file) # only keep logger in rank0 process if distributed_rank == 0: logger.add( sys.stderr, format=loguru_format, level="INFO", enqueue=True, ) logger.add(save_file) # redirect stdout/stderr to loguru redirect_sys_output("INFO") def build_scheduler(config, optimizer, n_iter_per_epoch): num_steps = int(config.epochs * n_iter_per_epoch) warmup_steps = int(config.warmup_epochs * n_iter_per_epoch) lr_scheduler = CosineLRScheduler( optimizer, t_initial=num_steps, lr_min=config.min_lr, warmup_lr_init=config.warmup_lr, warmup_t=warmup_steps, cycle_limit=1, t_in_epochs=False, ) return lr_scheduler def collate_fn(batch): # img, word_vec, mask, pad_mask, params images, word_vecs, masks, pad_masks, params_list = zip(*batch) images = torch.cat(images) word_vecs = torch.cat(word_vecs) masks = torch.cat(masks) pad_masks = torch.cat(pad_masks) # params batchify batched_params = {} if params_list and isinstance(params_list[0], dict): all_keys = params_list[0].keys() for key in all_keys: if key == 'hardpos_emb': # sbert embddings hardpos_embs = [p[key] for p in params_list] batched_params[key] = torch.stack(hardpos_embs) if all(isinstance(e, torch.Tensor) for e in hardpos_embs) else hardpos_embs else: batched_params[key] = [p[key] for p in params_list] return images, word_vecs, masks, pad_masks, batched_params