| import torch |
| from einops import reduce |
| from jaxtyping import Float, Int64 |
| from torch import Tensor |
|
|
|
|
def sample_discrete_distribution(
    pdf: Float[Tensor, "*batch bucket"],
    num_samples: int,
    eps: float = torch.finfo(torch.float32).eps,
) -> tuple[
    Int64[Tensor, "*batch sample"],
    Float[Tensor, "*batch sample"],
]:
    """Draw `num_samples` bucket indices per batch element from a discrete PDF.

    Args:
        pdf: Non-negative per-bucket weights; need not sum to one (they are
            normalized internally).
        num_samples: Number of independent draws per batch element.
        eps: Added to the normalizer so an all-zero PDF does not divide by zero.

    Returns:
        A tuple of (sampled bucket indices, normalized PDF value at each
        sampled index), both shaped `(*batch, num_samples)`.
    """
    *batch, bucket = pdf.shape

    # Normalize each row; eps keeps an all-zero row from producing NaNs.
    normalized_pdf = pdf / (eps + pdf.sum(dim=-1, keepdim=True))
    cdf = normalized_pdf.cumsum(dim=-1)

    # Inverse-transform sampling: u ~ U[0, 1), pick the first bucket whose CDF
    # exceeds u. Match pdf's dtype — torch.searchsorted rejects mixed dtypes,
    # so the default-float32 rand would break for float64/float16 PDFs.
    samples = torch.rand((*batch, num_samples), dtype=pdf.dtype, device=pdf.device)

    # clip guards the rare case where rounding makes u >= cdf[..., -1], which
    # would otherwise yield the out-of-range index `bucket`.
    index = torch.searchsorted(cdf, samples, right=True).clip(max=bucket - 1)
    return index, normalized_pdf.gather(dim=-1, index=index)
|
|
|
|
def gather_discrete_topk(
    pdf: Float[Tensor, "*batch bucket"],
    num_samples: int,
    eps: float = torch.finfo(torch.float32).eps,
) -> tuple[
    Int64[Tensor, "*batch sample"],
    Float[Tensor, "*batch sample"],
]:
    """Select the `num_samples` most probable buckets of each PDF.

    Returns the top-k bucket indices together with the corresponding
    normalized probabilities (the PDF divided by its eps-stabilized row sum).
    Ranking on the raw PDF is equivalent to ranking on the normalized one,
    since each row is scaled by a positive scalar.
    """
    total = eps + pdf.sum(dim=-1, keepdim=True)
    probabilities = pdf / total
    top_indices = pdf.topk(k=num_samples, dim=-1).indices
    top_probabilities = probabilities.gather(dim=-1, index=top_indices)
    return top_indices, top_probabilities
|
|