Seemanth's picture
Upload Chiluka TTS model
f28049f verified
"""Diffusion utility functions."""
from functools import reduce
from inspect import isfunction
from math import ceil, floor, log2
from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
import torch
import torch.nn.functional as F
from typing_extensions import TypeGuard
# Generic type variable used by the ``exists`` / ``default`` signatures so the
# value type flows through from argument to return annotation.
T = TypeVar("T")
def exists(val: Optional[T]) -> TypeGuard[T]:
return val is not None
def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
    """Return ``val`` if it is not ``None``, otherwise fall back to ``d``.

    The fallback is invoked with no arguments only when
    ``inspect.isfunction(d)`` is true (a ``def``- or ``lambda``-defined
    function). Any other object, including classes, builtins and other
    callables, is returned as-is rather than called.
    """
    if not exists(val):
        # Lazily produce the fallback only for plain Python functions.
        if isfunction(d):
            return d()
        return d
    return val
def rand_bool(shape, proba, device=None):
    """Draw a boolean tensor whose entries are True with probability ``proba``.

    A ``proba`` equal to exactly 1 or 0 short-circuits to an all-True or
    all-False tensor without drawing random numbers. Any other value fills a
    tensor of ``shape`` with ``proba`` and samples it via ``torch.bernoulli``.
    The result is converted to ``torch.bool``.
    """
    if proba == 1:
        constant_factory = torch.ones
    elif proba == 0:
        constant_factory = torch.zeros
    else:
        probabilities = torch.full(shape, proba, device=device)
        return torch.bernoulli(probabilities).to(torch.bool)
    return constant_factory(shape, device=device, dtype=torch.bool)
def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
    """Partition ``d`` by whether each key starts with ``prefix``.

    Returns ``(matching, remaining)``: ``matching`` holds the entries whose
    key starts with ``prefix``, and ``remaining`` holds everything else. The
    prefix is stripped from matching keys unless ``keep_prefix`` is true.
    Both dicts preserve the iteration order of ``d``, and ``d`` is not
    modified. Keys are expected to be strings, since ``str.startswith`` is
    called on them.
    """
    # Slicing from 0 keeps the key intact; slicing past the prefix strips it.
    cut = 0 if keep_prefix else len(prefix)
    matching: Dict = {}
    remaining: Dict = {}
    for key, value in d.items():
        if key.startswith(prefix):
            matching[key[cut:]] = value
        else:
            remaining[key] = value
    return matching, remaining