"""Diffusion utility functions.""" from functools import reduce from inspect import isfunction from math import ceil, floor, log2 from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union import torch import torch.nn.functional as F from typing_extensions import TypeGuard T = TypeVar("T") def exists(val: Optional[T]) -> TypeGuard[T]: return val is not None def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T: if exists(val): return val return d() if isfunction(d) else d def rand_bool(shape, proba, device=None): if proba == 1: return torch.ones(shape, device=device, dtype=torch.bool) elif proba == 0: return torch.zeros(shape, device=device, dtype=torch.bool) else: return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool) def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]: kwargs_with_prefix = {k: v for k, v in d.items() if k.startswith(prefix)} kwargs = {k: v for k, v in d.items() if not k.startswith(prefix)} if keep_prefix: return kwargs_with_prefix, kwargs kwargs_no_prefix = {k[len(prefix):]: v for k, v in kwargs_with_prefix.items()} return kwargs_no_prefix, kwargs