File size: 3,875 Bytes
f484856
 
 
 
 
 
 
 
 
 
 
54519d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f484856
 
54519d5
 
 
 
 
f484856
 
54519d5
 
 
f484856
54519d5
 
 
f484856
54519d5
 
 
 
 
 
 
 
 
 
 
f484856
 
 
 
 
 
54519d5
 
 
f484856
 
 
 
 
 
 
 
 
54519d5
f484856
 
 
 
 
 
 
 
 
 
 
54519d5
f484856
 
 
 
 
 
 
 
54519d5
 
 
 
 
 
 
 
f484856
54519d5
f484856
 
 
 
54519d5
f484856
54519d5
 
 
f484856
 
54519d5
f484856
 
54519d5
 
 
 
 
 
 
f484856
 
 
54519d5
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import re
import random
import time
import torch
import numpy as np
from os import path as osp

from .dist_util import master_only
from .logger import get_root_logger


# ---------------------------
#   GPU / MPS Compatibility
# ---------------------------

# Check if PyTorch ≥ 1.12 for MPS (Apple Silicon)
try:
    # Parse the leading "major.minor.patch" triple of the installed PyTorch
    # version; anything after it (e.g. a "+cu117" local tag) is not matched.
    version_match = re.findall(
        r"^([0-9]+)\.([0-9]+)\.([0-9]+)",
        torch.__version__
    )[0]
    # List comparison is lexicographic, so this is True for 1.12.0 and newer.
    IS_HIGH_VERSION = [int(x) for x in version_match] >= [1, 12, 0]
except (AttributeError, IndexError, TypeError, ValueError):
    # AttributeError: torch exposes no __version__.
    # IndexError: the version string does not start with three numeric parts.
    # TypeError / ValueError: the version is not a parseable string.
    # Only parsing failures are handled here, so KeyboardInterrupt and
    # SystemExit still propagate instead of being silently swallowed.
    IS_HIGH_VERSION = False


def gpu_is_available():
    """Tell whether an accelerator backend (MPS or CUDA) can be used.

    Returns:
        True when the MPS backend reports itself available (only queried
        when ``IS_HIGH_VERSION`` is set); otherwise the result of requiring
        both CUDA and cuDNN to be available.
    """
    mps_ready = IS_HIGH_VERSION and torch.backends.mps.is_available()
    # `or` short-circuits, so CUDA/cuDNN are only probed when MPS is not usable.
    return mps_ready or (
        torch.cuda.is_available() and torch.backends.cudnn.is_available()
    )


def get_device(gpu_id=None):
    """Pick the device to run on, preferring MPS, then CUDA, then CPU.

    Args:
        gpu_id: Optional CUDA device index. It is only used when the CUDA
            branch is taken and only when it is an ``int``; any other value
            yields the plain ``"cuda"`` device.

    Returns:
        A ``torch.device`` for ``"mps"``, ``"cuda"`` (optionally with an
        index suffix), or ``"cpu"``.
    """
    # Apple MPS takes priority and carries no device index.
    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return torch.device("mps")

    cuda_usable = torch.cuda.is_available() and torch.backends.cudnn.is_available()
    if not cuda_usable:
        # Neither accelerator is usable: fall back to the CPU.
        return torch.device("cpu")

    index_suffix = f":{gpu_id}" if isinstance(gpu_id, int) else ""
    return torch.device("cuda" + index_suffix)


# ---------------------------
#   Utilities
# ---------------------------

def set_random_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed unchanged to every seeding function.
    """
    # Same order as before: stdlib `random`, NumPy, then PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    # CUDA generators are only seeded when CUDA is reported available.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_time_str():
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
    now = time.localtime()
    stamp = time.strftime('%Y%m%d_%H%M%S', now)
    return stamp


def mkdir_and_rename(path):
    """Create ``path`` as a fresh directory, archiving any existing one.

    If ``path`` already exists it is renamed to
    ``<path>_archived_<timestamp>`` (timestamp from ``get_time_str()``) and a
    notice is printed; afterwards ``path`` is created, including parents.

    Args:
        path: Directory to create, as a ``str`` or ``os.PathLike``.

    Raises:
        OSError: If the rename or the directory creation fails.
    """
    if osp.exists(path):
        # Drop trailing separators before appending the archive suffix;
        # otherwise "dir/" would become "dir/_archived_..." and the rename
        # would try to move the directory into itself.
        separators = os.sep + (os.altsep or '')
        path_str = os.fspath(path)
        base = path_str.rstrip(separators) or path_str
        new_name = base + '_archived_' + get_time_str()
        print(f'Path already exists. Renamed to {new_name}', flush=True)
        os.rename(path, new_name)
    os.makedirs(path, exist_ok=True)


@master_only
def make_exp_dirs(opt):
    """Create the directories listed under ``opt['path']``.

    The root directory (``experiments_root`` when ``opt['is_train']`` is
    truthy, ``results_root`` otherwise) is handed to ``mkdir_and_rename``.
    Every remaining entry is created with ``os.makedirs`` unless its key
    contains ``strict_load``, ``pretrain_network`` or ``resume``.

    Args:
        opt: Options mapping with an ``is_train`` flag and a ``path`` dict.
            The ``path`` dict is copied first, so ``opt`` is not modified.
    """
    remaining = opt['path'].copy()
    root_key = 'experiments_root' if opt['is_train'] else 'results_root'
    mkdir_and_rename(remaining.pop(root_key))

    # Keys containing any of these substrings are not created as directories.
    skipped_markers = ('strict_load', 'pretrain_network', 'resume')
    for key, path in remaining.items():
        if any(marker in key for marker in skipped_markers):
            continue
        os.makedirs(path, exist_ok=True)


def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Lazily yield the non-hidden files found under a directory.

    Args:
        dir_path: Directory to scan.
        suffix: Optional ``str`` or tuple of ``str``; when given, only paths
            ending with it are yielded.
        recursive: If True, descend into sub-directories.
        full_path: If True, yield each entry's path as reported by
            ``os.scandir``; otherwise yield paths relative to ``dir_path``.

    Returns:
        A generator of file paths. Files whose name starts with ``.`` are
        skipped. Order follows ``os.scandir`` and is not sorted.
    """
    root = dir_path

    def _scan(path):
        # The context manager closes the directory handle even when the
        # consumer stops iterating early or an error interrupts the scan.
        with os.scandir(path) as entries:
            for entry in entries:
                if entry.is_file() and not entry.name.startswith('.'):
                    file_path = entry.path if full_path else osp.relpath(entry.path, root)
                    if suffix is None or file_path.endswith(suffix):
                        yield file_path
                elif entry.is_dir() and recursive:
                    yield from _scan(entry.path)

    return _scan(dir_path)


def check_resume(opt, resume_iter):
    """Point each network's pretrain path at its checkpoint when resuming.

    Nothing is changed unless ``opt['path']['resume_state']`` is truthy. When
    it is, a warning is logged if any ``pretrain_network_*`` path is already
    set. Then, for every ``network_*`` key in ``opt``,
    ``opt['path']['pretrain_<network>']`` is overwritten with
    ``<models>/net_<basename>_<resume_iter>.pth``. A network is skipped when
    its basename is listed in ``opt['path']['ignore_resume_networks']``.

    Args:
        opt: Options mapping; its ``path`` dict is modified in place.
        resume_iter: Iteration identifier embedded in the checkpoint name.
    """
    logger = get_root_logger()
    path_opt = opt['path']

    if not path_opt['resume_state']:
        return

    network_keys = [name for name in opt.keys() if name.startswith('network_')]

    if any(path_opt.get(f'pretrain_{name}') for name in network_keys):
        logger.warning('pretrain_network path will be ignored during resuming.')

    for network in network_keys:
        basename = network.replace('network_', '')
        ignored = path_opt.get('ignore_resume_networks')
        if ignored is not None and basename in ignored:
            # Explicitly excluded from resuming: leave its pretrain path alone.
            continue
        checkpoint = osp.join(path_opt['models'], f'net_{basename}_{resume_iter}.pth')
        path_opt[f'pretrain_{network}'] = checkpoint
        logger.info(f"Set pretrain for {network}")


def sizeof_fmt(size, suffix='B'):
    """Format a byte count as a human-readable string using 1024 steps.

    Args:
        size: Number of bytes (int or float). Negative values are scaled by
            magnitude and keep their sign.
        suffix: Text appended after the unit prefix. Defaults to ``'B'``.

    Returns:
        A string such as ``'1.5 KB'``: the scaled value with one decimal,
        a space, the unit prefix and the suffix.
    """
    # Every prefix between P and Y must be listed; without 'E' and 'Z' an
    # exa- or zetta-scale value would fall through to the 'Y' label below.
    for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
        # Compare the magnitude so negative sizes are scaled as well.
        if abs(size) < 1024:
            return f"{size:3.1f} {unit}{suffix}"
        size /= 1024
    return f"{size:3.1f} Y{suffix}"