| """Diffusion utility functions.""" | |
| from functools import reduce | |
| from inspect import isfunction | |
| from math import ceil, floor, log2 | |
| from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union | |
| import torch | |
| import torch.nn.functional as F | |
| from typing_extensions import TypeGuard | |
| T = TypeVar("T") | |
| def exists(val: Optional[T]) -> TypeGuard[T]: | |
| return val is not None | |
def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
    """Return *val* when it is not ``None``; otherwise fall back to *d*.

    The fallback is evaluated lazily: *d* is called with no arguments only
    when it is a plain Python function (``def`` or ``lambda``), as decided by
    ``inspect.isfunction``.  Every other object -- including classes,
    builtins, bound methods and ``functools.partial`` objects -- is returned
    unchanged rather than called.
    """
    if val is None:
        if isfunction(d):
            # Lazy default: only build the fallback when it is actually needed.
            return d()
        return d
    return val
def rand_bool(shape, proba, device=None):
    """Sample a boolean tensor of *shape* whose entries are True with probability *proba*.

    The degenerate probabilities 1 and 0 short-circuit to an all-True /
    all-False tensor built directly, so no random draw is performed for
    them.  Any other *proba* fills a tensor of that shape (default dtype)
    which is passed to ``torch.bernoulli`` and the 0/1 result cast to bool.
    """
    if proba == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if proba == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    probabilities = torch.full(shape, proba, device=device)
    draws = torch.bernoulli(probabilities)
    return draws.to(torch.bool)
def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
    """Partition *d* by whether each key starts with *prefix*.

    Returns ``(matched, remaining)``: ``matched`` holds the entries whose key
    starts with *prefix* and ``remaining`` holds all others, both in the
    original iteration order of *d*.  Unless *keep_prefix* is True, the
    prefix is stripped from the keys placed in ``matched``.  Keys must be
    strings (``str.startswith`` is called on each one).
    """
    matched: Dict = {}
    remaining: Dict = {}
    # Slicing from 0 leaves the key intact, so one loop serves both modes.
    cut = 0 if keep_prefix else len(prefix)
    for key, value in d.items():
        if key.startswith(prefix):
            matched[key[cut:]] = value
        else:
            remaining[key] = value
    return matched, remaining