# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.

import logging
import os
import random
import subprocess
from typing import Callable, List, Optional, Tuple

import numpy as np
import torch
from torch import Tensor, nn

logger = logging.getLogger("dinov3")


def cat_keep_shapes(x_list: List[Tensor]) -> Tuple[Tensor, List[Tuple[int, ...]], List[int]]:
    """Flatten each tensor to 2-D and concatenate them along dim 0.

    Every tensor has all of its leading dimensions merged (``flatten(0, -2)``)
    so only the last dimension is preserved; the results are then stacked
    row-wise with ``torch.cat``.

    Args:
        x_list: Tensors to concatenate. ``torch.cat`` requires them to share
            the same last dimension, and ``flatten(0, -2)`` expects at least
            two dimensions per tensor.

    Returns:
        A tuple ``(flattened, shapes, num_tokens)`` where ``flattened`` is the
        concatenated 2-D tensor, ``shapes`` holds the original shape of each
        input and ``num_tokens`` holds the number of rows each input
        contributed (the product of all of its dimensions but the last).
    """
    shapes = [x.shape for x in x_list]
    # Selecting index 0 on the last dim leaves exactly one element per row.
    num_tokens = [x.select(dim=-1, index=0).numel() for x in x_list]
    flattened = torch.cat([x.flatten(0, -2) for x in x_list])
    return flattened, shapes, num_tokens


def uncat_with_shapes(
    flattened: Tensor,
    shapes: List[Tuple[int, ...]],
    num_tokens: List[int],
) -> List[Tensor]:
    """Split a concatenated 2-D tensor back into tensors of the given shapes.

    Inverse of ``cat_keep_shapes``. The last dimension of every output is
    taken from ``flattened`` rather than from ``shapes``, so the function still
    works when the trailing dimension changed in between.

    Args:
        flattened: Tensor whose dim 0 is the concatenation of all chunks.
        shapes: Original shape of each chunk; only the leading dimensions
            (everything but the last entry) are used. Any sliceable sequence
            (``torch.Size``, tuple, list) is accepted.
        num_tokens: Number of rows of ``flattened`` belonging to each chunk.

    Returns:
        One tensor per entry of ``shapes``, reshaped to
        ``(*shape[:-1], flattened.shape[-1])``.
    """
    last_dim = flattened.shape[-1]
    chunks = torch.split_with_sizes(flattened, num_tokens, dim=0)
    return [chunk.reshape((*shape[:-1], last_dim)) for chunk, shape in zip(chunks, shapes)]


def named_replace(
    fn: Callable,
    module: nn.Module,
    name: str = "",
    depth_first: bool = True,
    include_root: bool = False,
) -> nn.Module:
    """Recursively replace modules with the result of ``fn``.

    ``fn`` is called with keyword arguments ``module`` and ``name`` (the
    dot-separated path from the root) and must return the module to use in
    place of the one it received.

    Args:
        fn: Callable invoked as ``fn(module=..., name=...)``.
        module: Root of the module tree to walk.
        name: Path prefix of ``module``; empty for the root.
        depth_first: If True, children are processed before their parent;
            otherwise the parent is processed first.
        include_root: Whether ``fn`` is applied to ``module`` itself. Children
            are always visited with ``include_root=True``.

    Returns:
        ``module``, or its replacement when ``include_root`` is True.
    """
    if not depth_first and include_root:
        module = fn(module=module, name=name)
    # Snapshot the children: setattr below rewrites them while we iterate.
    for child_name_o, child_module in list(module.named_children()):
        child_name = ".".join((name, child_name_o)) if name else child_name_o
        new_child = named_replace(
            fn=fn,
            module=child_module,
            name=child_name,
            depth_first=depth_first,
            include_root=True,
        )
        setattr(module, child_name_o, new_child)
    if depth_first and include_root:
        module = fn(module=module, name=name)
    return module


def named_apply(
    fn: Callable,
    module: nn.Module,
    name: str = "",
    depth_first: bool = True,
    include_root: bool = False,
) -> nn.Module:
    """Recursively call ``fn`` on every module of a tree.

    ``fn`` is called with keyword arguments ``module`` and ``name`` (the
    dot-separated path from the root); its return value is ignored.

    Args:
        fn: Callable invoked as ``fn(module=..., name=...)``.
        module: Root of the module tree to walk.
        name: Path prefix of ``module``; empty for the root.
        depth_first: If True, children are visited before their parent;
            otherwise the parent is visited first.
        include_root: Whether ``fn`` is applied to ``module`` itself. Children
            are always visited with ``include_root=True``.

    Returns:
        The ``module`` that was passed in.
    """
    if not depth_first and include_root:
        fn(module=module, name=name)
    # Snapshot the children (as named_replace does) so that a callback which
    # adds or removes submodules cannot invalidate the iteration.
    for child_name_o, child_module in list(module.named_children()):
        child_name = ".".join((name, child_name_o)) if name else child_name_o
        named_apply(
            fn=fn,
            module=child_module,
            name=child_name,
            depth_first=depth_first,
            include_root=True,
        )
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


def fix_random_seeds(seed: int = 31):
    """Seed the torch (CPU and all CUDA devices), NumPy and ``random`` RNGs.

    Args:
        seed: Seed value applied to every generator.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def get_sha() -> str:
    """Describe the git state of the checkout containing this file.

    The lookup is best-effort: if any git command fails (git missing, not a
    repository, ...) the fields that could not be determined keep their
    defaults and no exception is raised.

    Returns:
        A string of the form ``"sha: <sha>, status: <status>, branch: <branch>"``
        where ``<sha>`` and ``<branch>`` default to ``"N/A"`` and ``<status>``
        defaults to ``"clean"``.
    """
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()

    sha = "N/A"
    diff = "clean"
    branch = "N/A"
    try:
        sha = _run(["git", "rev-parse", "HEAD"])
        # Output intentionally discarded.
        # NOTE(review): presumably run to refresh git's index before
        # diff-index so stale stat info is not reported as a change - confirm.
        subprocess.check_output(["git", "diff"], cwd=cwd)
        dirty = _run(["git", "diff-index", "HEAD"])
        diff = "has uncommited changes" if dirty else "clean"
        branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    except Exception:
        # Deliberately broad: this is diagnostic metadata and must never
        # abort the caller. Leave a trace instead of failing silently.
        logger.debug("Could not query git metadata in %s", cwd, exc_info=True)
    message = f"sha: {sha}, status: {diff}, branch: {branch}"
    return message


def get_conda_env() -> Tuple[Optional[str], Optional[str]]:
    """Return the ``CONDA_DEFAULT_ENV`` and ``CONDA_PREFIX`` environment variables.

    Returns:
        ``(conda_env_name, conda_env_path)``; each entry is None when the
        corresponding variable is not set.
    """
    conda_env_name = os.environ.get("CONDA_DEFAULT_ENV")
    conda_env_path = os.environ.get("CONDA_PREFIX")
    return conda_env_name, conda_env_path


def count_parameters(module: nn.Module) -> int:
    """Return the total number of elements over all parameters of ``module``."""
    return sum(p.nelement() for p in module.parameters())


def has_batchnorms(model: nn.Module) -> bool:
    """Return True if ``model`` or any of its submodules is a batch-norm layer.

    Checked types: ``BatchNorm1d``, ``BatchNorm2d``, ``BatchNorm3d`` and
    ``SyncBatchNorm``.
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    return any(isinstance(m, bn_types) for m in model.modules())