| import builtins |
|
|
| from omegaconf import OmegaConf |
| from loguru import logger |
| import sys |
| import os |
| import torch |
| import numpy as np |
| import wandb |
| import random |
| from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
|
|
class UnNormalize(object):
    """Undo a per-channel ``(x - mean) / std`` normalization on image tensors.

    Attributes are stored exactly as passed in:
        mean: Per-channel offsets added back to the image.
        std: Per-channel scales the image is multiplied by.
        rescale_factor: Optional scalar. When it is set and differs from
            ``1 / 255`` by more than ``1e-6``, the un-normalized image is
            additionally mapped through ``x / 2 + 0.5`` (presumably for
            pipelines whose pixel range is [-1, 1] -- confirm with callers).
    """

    def __init__(self, mean, std, rescale_factor=None):
        self.mean = mean
        self.std = std
        self.rescale_factor = rescale_factor

    def __call__(self, image):
        """Return an un-normalized copy of ``image`` clamped to ``[0, 1]``.

        Args:
            image: Tensor with 3 dims (a batch axis is added temporarily) or
                4 dims, with channels on axis 1 once batched. The input is
                cloned first and never modified.

        Returns:
            A tensor of the same shape as ``image``.
        """
        restored = torch.clone(image)
        was_unbatched = len(restored.shape) == 3
        if was_unbatched:
            restored = restored.unsqueeze(0)

        # Channel-major *view* of the clone: iterating it yields one channel
        # at a time, and the in-place ops write straight into ``restored``.
        # ``zip`` stops at the shortest of (channels, mean, std).
        channel_major = restored.permute(1, 0, 2, 3)
        for channel, offset, scale in zip(channel_major, self.mean, self.std):
            channel.mul_(scale).add_(offset)

        if was_unbatched:
            restored = restored.squeeze(0)

        factor = self.rescale_factor
        if factor is not None and abs(factor - 1.0 / 255.0) > 1e-6:
            restored = restored / 2.0 + 0.5

        return torch.clamp(restored, 0, 1)
|
|
|
|
class AverageScalarMeter(object):
    """Running mean over (approximately) the last ``window_size`` scalars.

    Each ``update`` folds the mean of a new batch into the stored mean. The
    batch is weighted by ``min(batch_size, window_size)`` and the previous
    mean by however much of the window is left for it.
    """

    def __init__(self, window_size):
        self.window_size = window_size
        self.current_size = 0
        self.mean = 0

    def update(self, values):
        """Fold a batch of values into the running mean.

        Args:
            values: Tensor whose first axis indexes the samples. An empty
                batch leaves the meter untouched.
        """
        batch_size = values.size()[0]
        if batch_size == 0:
            return

        batch_mean = torch.mean(values.float(), dim=0).cpu().numpy().item()

        # The new batch may claim at most the whole window ...
        new_weight = np.clip(batch_size, 0, self.window_size)
        # ... and the history keeps whatever room is left, capped by how many
        # samples it actually represents.
        kept_weight = min(self.window_size - new_weight, self.current_size)
        total_weight = kept_weight + new_weight

        weighted_sum = self.mean * kept_weight + batch_mean * new_weight
        self.current_size = total_weight
        self.mean = weighted_sum / total_weight

    def clear(self):
        """Reset the meter to its freshly constructed state."""
        self.current_size = 0
        self.mean = 0

    def __len__(self):
        """Number of samples the current mean represents."""
        return self.current_size

    def get_mean(self):
        """Return the current running mean (``0`` when empty)."""
        return self.mean
|
|
|
|
def plot_grad_norms(named_parameters, name_prefix=''):
    """Log the L2 norm of every available parameter gradient to wandb.

    Args:
        named_parameters: Iterable of ``(name, parameter)`` pairs, such as
            the result of ``module.named_parameters()``. Parameters whose
            ``grad`` is ``None`` are skipped.
        name_prefix: String prepended to each parameter name to build the
            metric key.

    All norms are sent in ONE ``wandb.log`` call. Logging each parameter
    separately would commit a new wandb step per parameter, so a model with
    N parameters would advance the step counter N times per invocation and
    scatter the norms of a single backward pass across N steps.
    """
    grad_norms = {
        f'{name_prefix}{name}': torch.linalg.vector_norm(param.grad, 2.0).item()
        for name, param in named_parameters
        if param.grad is not None
    }
    # No gradients at all: make no wandb call, as a per-parameter loop would.
    if grad_norms:
        wandb.log(grad_norms)
|
|
|
|
def suppress_print():
    """Suppresses printing from the current process.

    Replaces ``builtins.print`` with a function that discards its arguments
    and returns ``None``. The replacement is process-wide and is not undone
    by this module.
    """

    def ignore(*_objects, **_kwargs):
        # Must swallow the real print keywords (sep, end, file, flush).
        # Naming them ``_sep``/``_end``/... would make ordinary calls such
        # as ``print(x, end="")`` raise TypeError once printing is
        # suppressed.
        pass

    builtins.print = ignore
|
|
|
|
def suppress_wandb():
    """Suppresses wandb logging from the current_process.

    Every callable attribute of the ``wandb`` module whose name does not
    start with a double underscore (functions and classes alike) is rebound
    to a function that accepts any arguments, does nothing and returns
    ``None``. Non-callable attributes are left as they are. The original
    callables are not kept, so the suppression cannot be undone from here.
    """

    def _noop(*args, **kwargs):
        pass

    # dir() returns a fresh list, so rebinding attributes while looping over
    # it is safe.
    for attr_name in dir(wandb):
        if attr_name.startswith('__'):
            continue
        if callable(getattr(wandb, attr_name)):
            setattr(wandb, attr_name, _noop)
|
|
|
|
def suppress_logging():
    """Suppresses loguru logging from the current process.

    Removes every handler currently registered on the loguru ``logger`` and
    installs a single sink that throws each message away.
    """

    def _discard(_message):
        return None

    logger.remove()
    logger.add(_discard)
|
|
|
|
def dump_cfg(cfg, logdir):
    """Write ``cfg`` as YAML to ``<logdir>/config.yaml`` and print the path.

    Args:
        cfg: Configuration object accepted by ``OmegaConf.to_yaml``.
        logdir: Existing directory to write into. It is not created here,
            so a missing directory surfaces as the error raised by ``open``.
    """
    out_f = os.path.join(logdir, "config.yaml")
    # Pin the encoding: without it the file is written in the platform's
    # locale encoding, so non-ASCII config values can fail to encode or
    # produce a file that differs between machines.
    with open(out_f, "w", encoding="utf-8") as f:
        f.write(OmegaConf.to_yaml(cfg))
    print("Wrote config to: {}".format(out_f))
|
|
|
|
def get_scheduled_temperature(step, total_steps, temp_schedule_args):
    """Return the temperature scheduled for ``step``.

    Only ``mode == 'exp'`` is implemented: a geometric interpolation that
    yields ``temp_start`` at step 0 and ``temp_end`` at ``total_steps``.

    Args:
        step: Current step.
        total_steps: Step at which ``temp_end`` is reached.
        temp_schedule_args: Mapping with a ``'mode'`` entry and, for the
            exponential mode, an ``'exp'`` entry holding ``'temp_start'``
            and ``'temp_end'``.

    Raises:
        ValueError: If ``mode`` is anything other than ``'exp'``.
    """
    if temp_schedule_args['mode'] != 'exp':
        raise ValueError(f"Unknown temp_schedule_args: {temp_schedule_args}")

    first_temp = temp_schedule_args['exp']['temp_start']
    last_temp = temp_schedule_args['exp']['temp_end']
    ratio = last_temp / first_temp
    progress = step / total_steps
    return first_temp * ratio ** progress
|
|
|
|
def seed_everything(seed: int):
    """Seed Python, NumPy and PyTorch RNGs with the same value.

    Also exports ``PYTHONHASHSEED`` into ``os.environ``. When CUDA is
    available, the CUDA generator is seeded and cuDNN is switched to
    deterministic mode with benchmarking disabled. When an MPS backend is
    available, its generator is seeded too.

    Args:
        seed: Seed value applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    for reseed in (random.seed, np.random.seed, torch.manual_seed):
        reseed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    # Older torch builds have no ``torch.mps`` module, hence the hasattr.
    mps_ready = hasattr(torch, 'mps') and torch.backends.mps.is_available()
    if mps_ready:
        torch.mps.manual_seed(seed)
|
|
|
|
def seed_worker(worker_id):
    """Re-seed NumPy, ``random`` and torch from torch's current initial seed.

    The signature matches a ``DataLoader`` ``worker_init_fn`` (presumably
    how it is used -- confirm at the call site).

    Args:
        worker_id: Integer added to the derived seed before it is handed
            back to ``torch.manual_seed``.
    """
    # NumPy only accepts 32-bit seeds, so fold torch's seed into that range.
    base_seed = torch.initial_seed() % 2**32
    for reseed in (np.random.seed, random.seed):
        reseed(base_seed)
    torch.manual_seed(base_seed + worker_id)
|
|
|
|
def format_kwargs(cfg, optional_args):
    """Collect optional keyword arguments that are present in ``cfg``.

    Args:
        cfg: Object whose sections are reachable as attributes.
        optional_args: Iterable of ``(arg_name, section, attr)`` triples.

    Returns:
        Dict mapping ``arg_name`` to ``cfg.<section>.<attr>`` for every
        triple whose section and attribute both exist. Triples that do not
        resolve are silently left out.
    """
    kwargs = {}
    for arg_name, section, attr in optional_args:
        if not hasattr(cfg, section):
            continue
        section_cfg = getattr(cfg, section)
        if hasattr(section_cfg, attr):
            kwargs[arg_name] = getattr(section_cfg, attr)
    return kwargs
|
|
|
|
def move_inputs_to_cuda(inputs):
    """Move every tensor in a (possibly nested) dict to the CUDA device.

    The dict is updated in place and also returned. Tensor values are
    replaced by their ``.cuda()`` copy, dict values are processed
    recursively, and everything else is left unchanged.

    Args:
        inputs: Dict whose values may be tensors, dicts or arbitrary objects.

    Returns:
        The same ``inputs`` object.
    """
    # Only existing keys are reassigned, so mutating during iteration is safe.
    for key, value in inputs.items():
        if isinstance(value, torch.Tensor):
            moved = value.cuda()
        elif isinstance(value, dict):
            moved = move_inputs_to_cuda(value)
        else:
            continue
        inputs[key] = moved
    return inputs
|
|
|
|
def unwrap_model(model):
    """Return the module wrapped by DDP, or ``model`` itself if not wrapped."""
    return model.module if isinstance(model, DDP) else model
|
|
|
|
def get_gazing_pos_from_gazing_mask(gazing_mask: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
    """
    Get the gazing positions from the gazing mask.
    inputs:
        gazing_mask: (B, N). 1 means gazed, 0 means not gazed.
    outputs:
        gazing_pos: (B, K). K is the maximum number of gazed tokens per instance. Positions are listed in
            ascending order. If the instance has less than K gazed tokens, the remaining positions are
            padded with -1.
        if_padded_gazing: (B, K). True means padded, False means not padded.
    Both outputs are always returned as a pair. When nothing is gazed (or the batch is empty) they have
    shape (B, 0).
    """
    gazing_mask = gazing_mask.to(torch.long)
    B, N = gazing_mask.shape
    device = gazing_mask.device

    counts = gazing_mask.sum(dim=1)
    # ``max`` raises on an empty tensor, so an empty batch counts as K == 0.
    K = int(counts.max().item()) if B > 0 else 0

    if K == 0:
        # Return the same (positions, padding) pair as the general path so
        # callers can always unpack two tensors.
        empty_pos = gazing_mask.new_empty((B, 0))
        empty_pad = torch.zeros((B, 0), dtype=torch.bool, device=device)
        return empty_pos, empty_pad

    idx = torch.arange(N, device=device).expand(B, N)

    # Gazed tokens get key == position, non-gazed tokens get N + position.
    # An ascending sort therefore lists the gazed positions first, each group
    # in increasing index order. The sort permutation already is the sorted
    # list of positions, so no extra gather is needed.
    key = (1 - gazing_mask) * N + idx
    sorted_idx = key.argsort(dim=1, stable=True)

    topk = sorted_idx[:, :K]
    slot = torch.arange(K, device=device).expand(B, K)
    is_real = slot < counts.unsqueeze(1)

    # Slots beyond an instance's own gaze count hold non-gazed positions;
    # overwrite them with the -1 padding marker.
    gazing_pos = topk.masked_fill(~is_real, -1)
    if_padded_gazing = (gazing_pos == -1)

    return gazing_pos, if_padded_gazing