| | |
| | |
| | |
| | |
| |
|
| | import logging |
| | import os |
| | import random |
| | import subprocess |
| | from urllib.parse import urlparse |
| |
|
| | import numpy as np |
| | import torch |
| | from torch import nn |
| |
|
| |
|
# Named module logger ("dinov2"); load_pretrained_weights below reports through it.
logger = logging.getLogger("dinov2")
| |
|
| |
|
def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
    """Load checkpoint weights into ``model`` (non-strictly) and log the result.

    Args:
        model: Module whose ``load_state_dict`` receives the weights.
        pretrained_weights: Local checkpoint path (``str`` or ``os.PathLike``)
            or a URL. A value whose parsed URL scheme is longer than one
            character is fetched with ``torch.hub.load_state_dict_from_url``;
            anything else is read with ``torch.load``.
        checkpoint_key: Optional key to descend into when the loaded object
            contains it; ignored when ``None`` or absent.

    Any leading run of ``"module."`` / ``"backbone."`` prefixes is stripped
    from each key before loading. Loading uses ``strict=False``, so missing or
    unexpected keys do not raise; they appear in the logged message instead.
    """
    source = os.fspath(pretrained_weights)
    # A one-letter scheme is a Windows drive ("C:\\w.pth" parses as scheme "c"),
    # not a URL, so it must take the local-file branch.
    if len(urlparse(source).scheme) > 1:
        state_dict = torch.hub.load_state_dict_from_url(source, map_location="cpu")
    else:
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
        # files (consider weights_only=True if the installed torch supports it).
        state_dict = torch.load(source, map_location="cpu")
    if checkpoint_key is not None and checkpoint_key in state_dict:
        logger.info("Take key %s in provided checkpoint dict", checkpoint_key)
        state_dict = state_dict[checkpoint_key]

    wrapper_prefixes = ("module.", "backbone.")

    def _strip_wrapper_prefixes(key):
        # Remove wrapper prefixes only at the start of the key (in any order or
        # multiplicity). str.replace would also mangle keys that merely contain
        # them mid-way, e.g. "submodule.weight" -> "subweight".
        while key.startswith(wrapper_prefixes):
            # Each prefix is a dot-free word plus ".", so the first "." ends it.
            key = key[key.index(".") + 1 :]
        return key

    state_dict = {_strip_wrapper_prefixes(k): v for k, v in state_dict.items()}
    msg = model.load_state_dict(state_dict, strict=False)
    logger.info("Pretrained weights found at %s and loaded with msg: %s", pretrained_weights, msg)
| |
|
| |
|
def fix_random_seeds(seed=31):
    """Seed the global random number generators with a single value.

    Seeds Python's ``random`` module, NumPy's legacy global RNG, torch's
    default generator and every CUDA device generator, so later draws from
    those global sources are reproducible for a given ``seed``.

    Args:
        seed: Integer seed passed to each generator (defaults to 31).
    """
    # Each call targets an independent generator; order does not matter.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
| |
|
| |
|
def get_sha():
    """Return a one-line summary of the git checkout containing this file.

    The result has the form
    ``"sha: <commit>, status: <clean|has uncommitted changes>, branch: <name>"``.
    Collection is best-effort. If git is missing, this file is not inside a
    repository, or a git command fails, every field not yet collected keeps
    its placeholder (``"N/A"`` for sha/branch, ``"clean"`` for status).
    """
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        # UTF-8 with replacement: branch names may be non-ASCII, and a strict
        # ASCII decode would raise and silently drop the remaining fields.
        # stderr is discarded so "not a git repository" noise does not leak.
        output = subprocess.check_output(command, cwd=cwd, stderr=subprocess.DEVNULL)
        return output.decode("utf-8", errors="replace").strip()

    sha = "N/A"
    diff = "clean"
    branch = "N/A"
    try:
        sha = _run(["git", "rev-parse", "HEAD"])
        # Output is discarded; this call is presumably kept for its side effect
        # of refreshing git's cached index stat info, which diff-index relies on.
        subprocess.check_output(["git", "diff"], cwd=cwd, stderr=subprocess.DEVNULL)
        # diff-index prints one line per changed path, so any output means dirty.
        diff = "has uncommitted changes" if _run(["git", "diff-index", "HEAD"]) else "clean"
        branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    except (OSError, subprocess.SubprocessError):
        # git not installed / not a repository / command failed: keep whatever
        # was collected before the failure.
        pass
    message = f"sha: {sha}, status: {diff}, branch: {branch}"
    return message
| |
|
| |
|
class CosineScheduler(object):
    """Precomputed per-iteration schedule: freeze, linear warmup, cosine decay.

    The schedule holds ``total_iters`` values laid out as ``freeze_iters``
    zeros, then ``warmup_iters`` points linearly spaced from
    ``start_warmup_value`` to ``base_value`` (both endpoints included), then a
    half-cosine that starts at ``base_value`` and decreases toward
    ``final_value`` over the remaining iterations. Indexing at or beyond
    ``total_iters`` returns ``final_value``.
    """

    def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
        super().__init__()
        self.final_value = final_value
        self.total_iters = total_iters

        frozen = np.zeros((freeze_iters))
        ramp = np.linspace(start_warmup_value, base_value, warmup_iters)

        steps = np.arange(total_iters - warmup_iters - freeze_iters)
        span = len(steps)
        # Half-cosine: equals base_value at step 0, approaches final_value.
        decay = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * steps / span))

        self.schedule = np.concatenate((frozen, ramp, decay))

        # Trips when warmup_iters + freeze_iters exceeds total_iters.
        assert len(self.schedule) == self.total_iters

    def __getitem__(self, it):
        """Return the scheduled value for iteration ``it``."""
        if it >= self.total_iters:
            # Past the precomputed range the schedule stays at its final value.
            return self.final_value
        return self.schedule[it]
| |
|
| |
|
def has_batchnorms(model):
    """Return True if ``model`` or any of its submodules is a batch-norm layer.

    Matches instances (including subclasses) of ``BatchNorm1d``,
    ``BatchNorm2d``, ``BatchNorm3d`` and ``SyncBatchNorm`` anywhere in
    ``model.modules()``, which also yields ``model`` itself.
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    # modules() walks the same tree as named_modules(); the names were unused.
    return any(isinstance(module, bn_types) for module in model.modules())
| |
|