Spaces:
Sleeping
Sleeping
File size: 7,309 Bytes
fcb1768 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 | import importlib
import os
import random
import torch
import numpy as np
from collections import abc
from einops import rearrange
from functools import partial
import multiprocessing as mp
from threading import Thread
from queue import Queue
from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont
import shutil
def count_params(model, verbose=False):
    """Return the total number of elements over all parameters of ``model``.

    :param model: Object exposing ``parameters()``; every yielded item must
        provide ``numel()``.
    :param verbose: When true, also print the count (in millions) together
        with the model's class name.
    :return: Sum of ``numel()`` over every yielded parameter.
    """
    n_elements = 0
    for tensor in model.parameters():
        n_elements += tensor.numel()

    if verbose:
        print(f"{model.__class__.__name__} has {n_elements * 1.e-6:.2f} M params.")
    return n_elements
def instantiate_from_cfg(config):
    """Build an object from a ``{"type": ..., "params": ...}`` style mapping.

    ``config["type"]`` is a dotted path resolved by ``get_obj_from_str``; the
    resolved callable is invoked with ``config["params"]`` as keyword
    arguments (no arguments when ``params`` is absent).

    :param config: Mapping holding ``type`` and, optionally, ``params``.
    :return: Whatever the resolved callable returns.
    :raises KeyError: If ``config`` has no ``type`` key.
    """
    if "type" in config:
        factory = get_obj_from_str(config["type"])
        factory_kwargs = config.get("params", dict())
        return factory(**factory_kwargs)
    raise KeyError("Expected key `type` to instantiate.")
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"pkg.module.Name"`` to the named object.

    The text before the last dot is imported as a module; the text after it
    is looked up as an attribute of that module.

    :param string: Dotted path containing at least one ``.``.
    :param reload: When true, the module is reloaded before the lookup.
    :return: The attribute found on the imported module.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_path))
    module_obj = importlib.import_module(module_path, package=None)
    return getattr(module_obj, attr_name)
def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
    """Worker body: run ``func`` on one chunk and report the outcome via ``Q``.

    On success ``[idx, result]`` is queued. If ``func`` raises, an
    ``("Error", idx, traceback_text)`` tuple is queued instead. Either way a
    final ``"Done"`` marker follows, so the consumer is never left waiting on
    a worker that died.

    :param func: Callable applied to ``data``.
    :param Q: Queue-like object with a ``put`` method.
    :param data: The chunk handed to this worker.
    :param idx: Index of this worker; echoed back so results can be ordered.
    :param idx_to_fn: When true, ``idx`` is passed to ``func`` as ``worker_id``.
    """
    import traceback

    try:
        if idx_to_fn:
            res = func(data, worker_id=idx)
        else:
            res = func(data)
        Q.put([idx, res])
    except Exception:
        # Ship the failure to the parent instead of dying silently; the
        # formatted traceback is a plain string, so it can always be queued.
        Q.put(("Error", idx, traceback.format_exc()))
    finally:
        Q.put("Done")


def parallel_data_prefetch(
    func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
):
    """Split ``data`` into chunks and apply ``func`` to them in parallel workers.

    :param func: Callable invoked once per chunk as ``func(chunk)`` or, when
        ``use_worker_id`` is true, ``func(chunk, worker_id=i)``.
    :param data: ``np.ndarray`` or any iterable. For a dict only the values
        are used.
    :param n_proc: Maximum number of workers (must be >= 1). Fewer workers are
        started when the data yields fewer chunks.
    :param target_data_type: ``"ndarray"`` splits with ``np.array_split`` and
        returns the concatenated results; ``"list"`` slices a list and returns
        one flat list; any other value slices a list and returns the
        per-worker results as a list of length ``n_proc``.
    :param cpu_intensive: Use ``multiprocessing`` processes when true,
        threads otherwise.
    :param use_worker_id: Forward the worker index to ``func``.
    :return: Results ordered by chunk index (see ``target_data_type``).
    :raises ValueError: If ``n_proc < 1`` or an ndarray is given with
        ``target_data_type == "list"``.
    :raises TypeError: If ``data`` is neither an ndarray nor an iterable.
    :raises RuntimeError: If ``func`` raised inside a worker.
    """
    import time

    if n_proc < 1:
        raise ValueError(f"n_proc has to be a positive integer, but is {n_proc}.")

    if isinstance(data, np.ndarray) and target_data_type == "list":
        raise ValueError("list expected but function got ndarray.")
    elif isinstance(data, abc.Iterable):
        if isinstance(data, dict):
            print(
                'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
            )
            data = list(data.values())
        if target_data_type == "ndarray":
            if not isinstance(data, (np.ndarray, abc.Sequence)):
                # Generators, sets, ... would otherwise be wrapped as a single
                # 0-d object array that cannot be split.
                data = list(data)
            data = np.asarray(data)
        else:
            data = list(data)
    else:
        raise TypeError(
            f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
        )

    if cpu_intensive:
        Q = mp.Queue(1000)
        proc = mp.Process
    else:
        Q = Queue(1000)
        proc = Thread

    # Partition the data; the list branch may yield fewer than n_proc chunks.
    if target_data_type == "ndarray":
        chunks = np.array_split(data, n_proc)
    else:
        # Ceil division; clamp to 1 so that empty input yields no chunks
        # instead of a zero step.
        step = max(1, -(-len(data) // n_proc))
        chunks = [data[i: i + step] for i in range(0, len(data), step)]

    # One worker per chunk, never more workers than chunks.
    processes = [
        proc(target=_do_parallel_data_prefetch, args=[func, Q, part, i, use_worker_id])
        for i, part in enumerate(chunks)
    ]

    print("Start prefetching...")
    start = time.time()
    gather_res = [[] for _ in range(n_proc)]
    started = []
    try:
        for p in processes:
            p.start()
            started.append(p)

        finished = 0
        while finished < len(processes):
            res = Q.get()
            if isinstance(res, str) and res == "Done":
                finished += 1
            elif isinstance(res, tuple) and res[0] == "Error":
                raise RuntimeError(f"Prefetch worker {res[1]} failed:\n{res[2]}")
            else:
                gather_res[res[0]] = res[1]
    except Exception as e:
        print("Exception: ", e)
        for p in started:
            # Threads cannot be terminated; only processes offer this.
            if hasattr(p, "terminate"):
                p.terminate()
        raise
    finally:
        # Only join what was actually started.
        for p in started:
            p.join()
        print(f"Prefetching complete. [{time.time() - start} sec.]")

    if target_data_type == 'ndarray':
        # np.asarray is a no-op for results that already are arrays.
        return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
    elif target_data_type == 'list':
        out = []
        for r in gather_res:
            out.extend(r)
        return out
    else:
        return gather_res
def set_seed(seed):
    """Seed torch, numpy and ``random`` and request deterministic cuDNN.

    Also stores ``str(seed)`` in the ``PYTHONHASHSEED`` environment variable.

    :param seed: Value forwarded to every seeding call.
    """
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    # NOTE(review): presumably only affects interpreters launched after this
    # call; the running process fixed its hash seed at startup - confirm.
    os.environ["PYTHONHASHSEED"] = str(seed)
def transform_date_str(date_str):
    """Turn ``YYYY-MM-DD`` into ``<year><month:02>week<k>``.

    ``k`` counts 7-day blocks from the first day of the month, so days 1-7
    give ``week1``, days 8-14 give ``week2``, and so on.

    :param date_str: Date text in ``%Y-%m-%d`` format.
    :return: For example ``"202401week2"`` for ``"2024-01-08"``.
    """
    from datetime import datetime

    parsed = datetime.strptime(date_str, '%Y-%m-%d')
    month_start = parsed.replace(day=1)
    days_into_month = (parsed - month_start).days
    week_index = days_into_month // 7 + 1
    return f"{parsed.year}{parsed.month:02}week{week_index}"
def save_files(base_dir, run_directory, extensions=('.py', '.yaml')):
    """Copy selected source files into ``<run_directory>/run``.

    A fixed set of sub-directories of ``base_dir`` is walked recursively. The
    directory layout is recreated below the destination, and files whose name
    ends with one of ``extensions`` are copied.

    :param base_dir: Root containing the source sub-directories.
    :param run_directory: Directory that receives the ``run`` snapshot folder.
    :param extensions: Tuple of filename suffixes to copy.
    """
    snapshot_root = os.path.join(run_directory, 'run')
    os.makedirs(snapshot_root, exist_ok=True)

    tracked_names = ("configs", "criteria", "datasets", "models", "trainers", "utils")
    for name in tracked_names:
        for current, _subdirs, filenames in os.walk(os.path.join(base_dir, name)):
            # Mirror the position of the walked directory relative to base_dir.
            target_dir = os.path.join(snapshot_root, os.path.relpath(current, base_dir))
            os.makedirs(target_dir, exist_ok=True)

            for filename in filenames:
                if not filename.endswith(extensions):
                    continue
                shutil.copy(
                    os.path.join(current, filename),
                    os.path.join(target_dir, filename),
                )
def call_model_method(model, method_name, *args, **kwargs):
    """
    Invoke a named method on a model, looking through a DataParallel wrapper.
    :param model: The model, or a ``torch.nn.DataParallel`` wrapping it.
    :param method_name: Name of the method to look up on the underlying model.
    :param args: Positional arguments forwarded to the method.
    :param kwargs: Keyword arguments forwarded to the method.
    :return: Whatever the method returns.
    """
    # DataParallel keeps the wrapped network on ``.module``.
    unwrapped = model.module if isinstance(model, torch.nn.DataParallel) else model
    return getattr(unwrapped, method_name)(*args, **kwargs)
def get_attributes_with_prefix(instance, prefix):
    """Collect the instance attributes whose names start with ``prefix``.

    Names come from ``vars(instance)``; values are read with ``getattr``.

    :param instance: Object supporting ``vars()``.
    :param prefix: Leading text an attribute name must have to be included.
    :return: Dict mapping each matching attribute name to its value.
    """
    selected = {}
    for name in vars(instance):
        if name.startswith(prefix):
            selected[name] = getattr(instance, name)
    return selected
def update_ema_params(model, ema_model, alpha, global_step):
    """Blend ``model``'s parameters into ``ema_model`` in place.

    Each EMA parameter becomes ``ema * a + param * (1 - a)``, where ``a`` is
    the smaller of ``alpha`` and ``1 - 1 / (global_step + 1)``. At step 0 this
    means the EMA parameters simply take the model's values.

    :param model: Source of the current parameters.
    :param ema_model: Model whose parameters are updated in place; parameters
        are paired with those of ``model`` in ``parameters()`` order.
    :param alpha: Upper bound for the decay factor.
    :param global_step: Step counter used to ramp the decay factor up.
    """
    warmup_decay = 1 - 1 / (global_step + 1)
    decay = min(warmup_decay, alpha)

    for averaged, current in zip(ema_model.parameters(), model.parameters()):
        averaged.data.mul_(decay)
        averaged.data.add_(current.data, alpha=1 - decay)
|