Spaces:
Running on Zero
Running on Zero
| import builtins | |
| from omegaconf import OmegaConf | |
| from loguru import logger | |
| import sys | |
| import os | |
| import torch | |
| import numpy as np | |
| import wandb | |
| import random | |
| from torch.nn.parallel import DistributedDataParallel as DDP | |
| class UnNormalize(object): | |
| def __init__(self, mean, std, rescale_factor=None): | |
| self.mean = mean | |
| self.std = std | |
| self.rescale_factor = rescale_factor | |
| def __call__(self, image): | |
| image2 = torch.clone(image) | |
| dims = len(image2.shape) | |
| if dims == 3: | |
| image2 = image2.unsqueeze(0) | |
| image2 = image2.permute(1, 0, 2, 3) | |
| for t, m, s in zip(image2, self.mean, self.std): | |
| t.mul_(s).add_(m) | |
| image2 = image2.permute(1, 0, 2, 3) | |
| if dims == 3: | |
| image2 = image2.squeeze(0) | |
| if self.rescale_factor is not None: | |
| standard_rescale = 1.0 / 255.0 | |
| if abs(self.rescale_factor - standard_rescale) > 1e-6: | |
| # if the processor uses 1/127.5, needs /2.0 + 0.5 correction | |
| image2 = image2 / 2.0 + 0.5 | |
| return torch.clamp(image2, 0, 1) | |
| class AverageScalarMeter(object): | |
| def __init__(self, window_size): | |
| self.window_size = window_size | |
| self.current_size = 0 | |
| self.mean = 0 | |
| def update(self, values): | |
| size = values.size()[0] | |
| if size == 0: | |
| return | |
| new_mean = torch.mean(values.float(), dim=0).cpu().numpy().item() | |
| size = np.clip(size, 0, self.window_size) | |
| old_size = min(self.window_size - size, self.current_size) | |
| size_sum = old_size + size | |
| self.current_size = size_sum | |
| self.mean = (self.mean * old_size + new_mean * size) / size_sum | |
| def clear(self): | |
| self.current_size = 0 | |
| self.mean = 0 | |
| def __len__(self): | |
| return self.current_size | |
| def get_mean(self): | |
| return self.mean | |
| def plot_grad_norms(named_parameters, name_prefix=''): | |
| for name, param in named_parameters: | |
| if param.grad is not None: | |
| norm = torch.linalg.vector_norm(param.grad, 2.0).item() | |
| wandb.log({f'{name_prefix}{name}': norm}) | |
| def suppress_print(): | |
| """Suppresses printing from the current process.""" | |
| def ignore(*_objects, _sep=" ", _end="\n", _file=sys.stdout, _flush=False): | |
| pass | |
| builtins.print = ignore | |
| def suppress_wandb(): | |
| """Suppresses wandb logging from the current_process.""" | |
| # Store original functions | |
| original_functions = {} | |
| for attr_name in dir(wandb): | |
| attr = getattr(wandb, attr_name) | |
| if callable(attr) and not attr_name.startswith('__'): | |
| original_functions[attr_name] = attr | |
| # Replace with no-op function | |
| def make_noop(name): | |
| def noop(*args, **kwargs): | |
| pass | |
| return noop | |
| setattr(wandb, attr_name, make_noop(attr_name)) | |
| def suppress_logging(): | |
| """Suppresses loguru logging from the current process.""" | |
| logger.remove() # Remove all handlers | |
| logger.add(lambda _: None) # Add a no-op handler | |
| def dump_cfg(cfg, logdir): | |
| out_f = os.path.join(logdir, "config.yaml") | |
| with open(out_f, "w") as f: | |
| f.write(OmegaConf.to_yaml(cfg)) | |
| print("Wrote config to: {}".format(out_f)) | |
| def get_scheduled_temperature(step, total_steps, temp_schedule_args): | |
| if temp_schedule_args['mode'] == 'exp': | |
| t_start = temp_schedule_args['exp']['temp_start'] | |
| t_end = temp_schedule_args['exp']['temp_end'] | |
| return t_start * (t_end / t_start) ** (step / total_steps) | |
| else: | |
| raise ValueError(f"Unknown temp_schedule_args: {temp_schedule_args}") | |
| def seed_everything(seed: int): | |
| random.seed(seed) | |
| os.environ['PYTHONHASHSEED'] = str(seed) | |
| np.random.seed(seed) | |
| torch.manual_seed(seed) | |
| if torch.cuda.is_available(): | |
| torch.cuda.manual_seed(seed) | |
| torch.backends.cudnn.deterministic = True | |
| torch.backends.cudnn.benchmark = False | |
| if hasattr(torch, 'mps') and torch.backends.mps.is_available(): | |
| torch.mps.manual_seed(seed) | |
| def seed_worker(worker_id): | |
| worker_seed = torch.initial_seed() % 2**32 | |
| np.random.seed(worker_seed) | |
| random.seed(worker_seed) | |
| torch.manual_seed(worker_seed + worker_id) # Add worker_id to make it different | |
| def format_kwargs(cfg, optional_args): | |
| return { | |
| arg_name: getattr(getattr(cfg, section), attr) | |
| for arg_name, section, attr in optional_args | |
| if hasattr(cfg, section) and hasattr(getattr(cfg, section), attr) | |
| } | |
| def move_inputs_to_cuda(inputs): | |
| for k, v in inputs.items(): | |
| if isinstance(v, torch.Tensor): | |
| inputs[k] = v.cuda() | |
| elif isinstance(v, dict): | |
| inputs[k] = move_inputs_to_cuda(v) | |
| return inputs | |
| def unwrap_model(model): | |
| """Unwrap DDP model if needed.""" | |
| if isinstance(model, DDP): | |
| return model.module | |
| return model | |
| def get_gazing_pos_from_gazing_mask(gazing_mask: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Get the gazing positions from the gazing mask. | |
| inputs: | |
| gazing_mask: (B, N). 1 means gazed, 0 means not gazed. | |
| outputs: | |
| gazing_pos: (B, K). K is the maximum number of gazed tokens per instance. If the instance has less than K gazed tokens, the remaining positions are padded with -1. | |
| if_padded_gazing: (B, K). 1 means padded, 0 means not padded. | |
| """ | |
| # x: (B, N) with 0/1 values (float/bool/int all fine) | |
| gazing_mask = gazing_mask.to(torch.long) | |
| B, N = gazing_mask.shape | |
| # Indices per row | |
| idx = torch.arange(N, device=gazing_mask.device).expand(B, N) | |
| # Sort key: put ones first, keep original order among ones/zeros | |
| # - ones get key = idx (0..N-1) | |
| # - zeros get key = N + idx (pushed after all ones) | |
| key = (1 - gazing_mask) * N + idx | |
| order = key.argsort(dim=1, stable=True) # (B, N) | |
| sorted_idx = idx.gather(1, order) # ones first, then zeros | |
| # Max number of ones (K) and per-row counts | |
| counts = gazing_mask.sum(dim=1) # (B,) | |
| K = int(counts.max().item()) | |
| if K == 0: | |
| return sorted_idx[:, :0] # (B, 0) empty result | |
| topk = sorted_idx[:, :K] # (B, K) | |
| pos = torch.arange(K, device=gazing_mask.device).expand(B, K) | |
| mask = pos < counts.unsqueeze(1) # True where a real "1" exists | |
| # Pad with -1 where the row has fewer than K ones | |
| gazing_pos = topk.masked_fill(~mask, -1) | |
| if_padded_gazing = (gazing_pos == -1) | |
| return gazing_pos, if_padded_gazing |