| """ |
| RT-DETRv4: Painlessly Furthering Real-Time Object Detection with Vision Foundation Models |
| Copyright (c) 2025 The RT-DETRv4 Authors. All Rights Reserved. |
| --------------------------------------------------------------------------------- |
| Modified from RT-DETR (https://github.com/lyuwenyu/RT-DETR) |
| |
| reference |
| - https://github.com/pytorch/vision/blob/main/references/detection/utils.py |
| - https://github.com/facebookresearch/detr/blob/master/util/misc.py#L406 |
| |
| Copyright(c) 2023 lyuwenyu. All Rights Reserved. |
| """ |
|
|
| import os |
| import time |
| import random |
| from datetime import timedelta |
|
|
| import numpy as np |
| import atexit |
|
|
| import torch |
| import torch.nn as nn |
| import torch.distributed |
| import torch.backends.cudnn |
|
|
| from torch.nn.parallel import DataParallel as DP |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
|
|
| from torch.utils.data import DistributedSampler |
| |
| from ..data import DataLoader |
|
|
|
|
def setup_distributed(print_rank: int=0, print_method: str='builtin', seed: int=None):
    """Set up the distributed environment (when launched with torchrun / torch.distributed).

    Args:
        print_rank (int): rank allowed to print; other ranks are silenced.
        print_method (str): 'builtin' or 'rich'.
        seed (int, optional): base random seed; each process is offset by its rank.

    Returns:
        bool: True if distributed mode was initialized, False otherwise.
    """
    try:
        # Environment variables populated by torchrun / torch.distributed.launch.
        RANK = int(os.getenv('RANK', -1))
        LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))
        WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))

        torch.distributed.init_process_group(init_method='env://', timeout=timedelta(seconds=72000))
        torch.distributed.barrier()

        # Bind this process to its local GPU. LOCAL_RANK indexes GPUs within the node;
        # the global rank can exceed the per-node device count in multi-node runs.
        torch.cuda.set_device(LOCAL_RANK)
        torch.cuda.empty_cache()
        enabled_dist = True
        if get_rank() == print_rank:
            print('Initialized distributed mode...')
|
|
| except Exception: |
| enabled_dist = False |
        print('Distributed mode not initialized; running in single-process mode.')
|
|
| setup_print(get_rank() == print_rank, method=print_method) |
| if seed is not None: |
| setup_seed(seed) |
|
|
| return enabled_dist |
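# Illustrative usage (a sketch; the script name `train.py` is hypothetical):
#
#   torchrun --nproc_per_node=8 train.py
#
#   # inside train.py
#   enabled = setup_distributed(print_rank=0, print_method='builtin', seed=0)
#   # `enabled` is False when launched without torchrun, so the same entry point
#   # also works for single-process debugging.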
|
|
|
|
| def setup_print(is_main, method='builtin'): |
| """This function disables printing when not in master process |
| """ |
| import builtins as __builtin__ |
|
|
| if method == 'builtin': |
| builtin_print = __builtin__.print |
|
|
| elif method == 'rich': |
| import rich |
| builtin_print = rich.print |
|
|
    else:
        raise AttributeError(f'unknown print method: {method!r}, expected "builtin" or "rich"')
|
|
| def print(*args, **kwargs): |
| force = kwargs.pop('force', False) |
| if is_main or force: |
| builtin_print(*args, **kwargs) |
|
|
| __builtin__.print = print |
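# After setup_print(), ordinary print() calls are silently dropped on non-main ranks;
# passing force=True overrides this, e.g. print('per-rank stats', force=True).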
|
|
|
|
| def is_dist_available_and_initialized(): |
| if not torch.distributed.is_available(): |
| return False |
| if not torch.distributed.is_initialized(): |
| return False |
| return True |
|
|
|
|
| @atexit.register |
| def cleanup(): |
| """cleanup distributed environment |
| """ |
| if is_dist_available_and_initialized(): |
| torch.distributed.barrier() |
| torch.distributed.destroy_process_group() |
|
|
|
|
| def get_rank(): |
| if not is_dist_available_and_initialized(): |
| return 0 |
| return torch.distributed.get_rank() |
|
|
|
|
| def get_world_size(): |
| if not is_dist_available_and_initialized(): |
| return 1 |
| return torch.distributed.get_world_size() |
|
|
|
|
| def is_main_process(): |
| return get_rank() == 0 |
|
|
|
|
| def save_on_master(*args, **kwargs): |
| if is_main_process(): |
| torch.save(*args, **kwargs) |
|
|
|
|
|
|
| def warp_model( |
| model: torch.nn.Module, |
| sync_bn: bool=False, |
| dist_mode: str='ddp', |
| find_unused_parameters: bool=False, |
| compile: bool=False, |
| compile_mode: str='reduce-overhead', |
| **kwargs |
| ): |
| if is_dist_available_and_initialized(): |
| rank = get_rank() |
| model = nn.SyncBatchNorm.convert_sync_batchnorm(model) if sync_bn else model |
| if dist_mode == 'dp': |
| model = DP(model, device_ids=[rank], output_device=rank) |
| elif dist_mode == 'ddp': |
| model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=find_unused_parameters) |
        else:
            raise AttributeError(f'unknown dist_mode: {dist_mode!r}, expected "dp" or "ddp"')
|
|
| if compile: |
| model = torch.compile(model, mode=compile_mode) |
|
|
| return model |
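# Illustrative usage (a sketch; `build_model` is a hypothetical factory):
#
#   model = build_model().to(torch.cuda.current_device())
#   model = warp_model(model, sync_bn=True, dist_mode='ddp')
#   ...
#   save_on_master(de_model(model).state_dict(), 'checkpoint.pth')   # unwrap DDP / torch.compile first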
|
|
| def de_model(model): |
| return de_parallel(de_complie(model)) |
|
|
|
|
| def warp_loader(loader, shuffle=False): |
| if is_dist_available_and_initialized(): |
| sampler = DistributedSampler(loader.dataset, shuffle=shuffle) |
| loader = DataLoader(loader.dataset, |
| loader.batch_size, |
| sampler=sampler, |
| drop_last=loader.drop_last, |
| collate_fn=loader.collate_fn, |
| pin_memory=loader.pin_memory, |
| num_workers=loader.num_workers) |
| return loader |
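# Note: when the loader is rebuilt with a DistributedSampler, the training loop
# should call `loader.sampler.set_epoch(epoch)` at the start of each epoch so the
# shuffling order differs per epoch and stays consistent across ranks.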
|
|
|
|
|
|
| def is_parallel(model) -> bool: |
| |
| return type(model) in (torch.nn.parallel.DataParallel, torch.nn.parallel.DistributedDataParallel) |
|
|
|
|
| def de_parallel(model) -> nn.Module: |
| |
| return model.module if is_parallel(model) else model |
|
|
|
|
| def reduce_dict(data, avg=True): |
| """ |
| Args |
| data dict: input, {k: v, ...} |
| avg bool: true |
| """ |
| world_size = get_world_size() |
| if world_size < 2: |
| return data |
|
|
| with torch.no_grad(): |
| keys, values = [], [] |
| for k in sorted(data.keys()): |
| keys.append(k) |
| values.append(data[k]) |
|
|
| values = torch.stack(values, dim=0) |
| torch.distributed.all_reduce(values) |
|
|
| if avg is True: |
| values /= world_size |
|
|
| return {k: v for k, v in zip(keys, values)} |
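# Illustrative usage (a sketch; the loss names are hypothetical):
#
#   loss_dict = {'loss_vfl': loss_vfl, 'loss_bbox': loss_bbox, 'loss_giou': loss_giou}
#   reduced = reduce_dict(loss_dict)                 # averaged across ranks by default
#   if is_main_process():
#       print({k: v.item() for k, v in reduced.items()})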
|
|
|
|
| def all_gather(data): |
| """ |
| Run all_gather on arbitrary picklable data (not necessarily tensors) |
| Args: |
| data: any picklable object |
| Returns: |
| list[data]: list of data gathered from each rank |
| """ |
| world_size = get_world_size() |
| if world_size == 1: |
| return [data] |
| data_list = [None] * world_size |
| torch.distributed.all_gather_object(data_list, data) |
| return data_list |
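# Illustrative usage (a sketch): gather per-rank evaluation results, then merge
# them on the main process.
#
#   results = all_gather(local_results)   # list with one entry per rank
#   if is_main_process():
#       merged = [item for part in results for item in part]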
|
|
|
|
| def sync_time(): |
| """sync_time |
| """ |
| if torch.cuda.is_available(): |
| torch.cuda.synchronize() |
|
|
| return time.time() |
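# Illustrative usage (a sketch) for timing GPU work:
#
#   t0 = sync_time()
#   outputs = model(samples)
#   print(f'forward: {sync_time() - t0:.4f}s')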
|
|
|
|
|
|
| def setup_seed(seed: int, deterministic=False): |
| """setup_seed for reproducibility |
| torch.manual_seed(3407) is all you need. https://arxiv.org/abs/2109.08203 |
| """ |
| seed = seed + get_rank() |
| random.seed(seed) |
| np.random.seed(seed) |
| torch.manual_seed(seed) |
|
|
| if torch.cuda.is_available(): |
| torch.cuda.manual_seed_all(seed) |
|
|
| |
| if torch.backends.cudnn.is_available() and deterministic: |
| torch.backends.cudnn.deterministic = True |
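        # Note: fully deterministic runs usually also require
        # torch.backends.cudnn.benchmark = False and, for some ops,
        # torch.use_deterministic_algorithms(True); only cudnn.deterministic is set here.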
|
|
|
|
| |
def check_compile():
    import warnings
    gpu_ok = False
| gpu_ok = False |
| if torch.cuda.is_available(): |
| device_cap = torch.cuda.get_device_capability() |
| if device_cap in ((7, 0), (8, 0), (9, 0)): |
| gpu_ok = True |
| if not gpu_ok: |
| warnings.warn( |
| "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower " |
| "than expected." |
| ) |
| return gpu_ok |
|
|
| def is_compile(model): |
| import torch._dynamo |
| return type(model) in (torch._dynamo.OptimizedModule, ) |
|
|
| def de_complie(model): |
| return model._orig_mod if is_compile(model) else model |
|
|