File size: 5,832 Bytes
94dc344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import torch
from pytorch3d import _C


def sample_pdf(
    bins: torch.Tensor,
    weights: torch.Tensor,
    n_samples: int,
    det: bool = False,
    eps: float = 1e-5,
) -> torch.Tensor:
    """
    Samples probability density functions defined by bin edges `bins` and
    the non-negative per-bin probabilities `weights`.

    Args:
        bins: Tensor of shape `(..., n_bins+1)` denoting the edges of the sampling bins.
        weights: Tensor of shape `(..., n_bins)` containing non-negative numbers
            representing the probability of sampling the corresponding bin.
        n_samples: The number of samples to draw from each set of bins.
        det: If `False`, the sampling is random. `True` yields deterministic
            uniformly-spaced sampling from the inverse cumulative density function.
        eps: A constant preventing division by zero in case empty bins are present.

    Returns:
        samples: Tensor of shape `(..., n_samples)` containing `n_samples` samples
            drawn from each probability distribution.

    Refs:
        [1] https://github.com/bmild/nerf/blob/55d8b00244d7b5178f4d003526ab6667683c9da9/run_nerf_helpers.py#L183  # noqa E501
    """
    if torch.is_grad_enabled() and (bins.requires_grad or weights.requires_grad):
        raise NotImplementedError("sample_pdf differentiability.")
    if weights.min() <= -eps:
        raise ValueError("Negative weights provided.")
    batch_shape = bins.shape[:-1]
    n_bins = weights.shape[-1]
    if n_bins + 1 != bins.shape[-1] or weights.shape[:-1] != batch_shape:
        shapes = f"{bins.shape}{weights.shape}"
        raise ValueError("Inconsistent shapes of bins and weights: " + shapes)
    output_shape = batch_shape + (n_samples,)

    if det:
        u = torch.linspace(0.0, 1.0, n_samples, device=bins.device, dtype=torch.float32)
        output = u.expand(output_shape).contiguous()
    else:
        output = torch.rand(output_shape, dtype=torch.float32, device=bins.device)

    # pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`.
    _C.sample_pdf(
        bins.reshape(-1, n_bins + 1),
        weights.reshape(-1, n_bins),
        output.reshape(-1, n_samples),
        eps,
    )

    return output


def sample_pdf_python(
    bins: torch.Tensor,
    weights: torch.Tensor,
    N_samples: int,
    det: bool = False,
    eps: float = 1e-5,
) -> torch.Tensor:
    """
    This is a pure python implementation of the `sample_pdf` function.
    It may be faster than sample_pdf when the number of bins is very large,
    because it behaves as O(batchsize * [n_bins + log(n_bins) * n_samples] )
    whereas sample_pdf behaves as O(batchsize * n_bins * n_samples).
    For 64 bins sample_pdf is much faster.

    Samples probability density functions defined by bin edges `bins` and
    the non-negative per-bin probabilities `weights`.

    Note: This is a direct conversion of the TensorFlow function from the original
    release [1] to PyTorch. It requires PyTorch 1.6 or greater due to the use of
    torch.searchsorted.

    Args:
        bins: Tensor of shape `(..., n_bins+1)` denoting the edges of the sampling bins.
        weights: Tensor of shape `(..., n_bins)` containing non-negative numbers
            representing the probability of sampling the corresponding bin.
        N_samples: The number of samples to draw from each set of bins.
        det: If `False`, the sampling is random. `True` yields deterministic
            uniformly-spaced sampling from the inverse cumulative density function.
        eps: A constant preventing division by zero in case empty bins are present.

    Returns:
        samples: Tensor of shape `(..., N_samples)` containing `N_samples` samples
            drawn from each probability distribution.

    Refs:
        [1] https://github.com/bmild/nerf/blob/55d8b00244d7b5178f4d003526ab6667683c9da9/run_nerf_helpers.py#L183  # noqa E501
    """

    # Get pdf
    weights = weights + eps  # prevent nans
    if weights.min() <= 0:
        raise ValueError("Negative weights provided.")
    pdf = weights / weights.sum(dim=-1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)

    # Take uniform samples u of shape (..., N_samples)
    if det:
        u = torch.linspace(0.0, 1.0, N_samples, device=cdf.device, dtype=cdf.dtype)
        u = u.expand(list(cdf.shape[:-1]) + [N_samples]).contiguous()
    else:
        u = torch.rand(
            list(cdf.shape[:-1]) + [N_samples], device=cdf.device, dtype=cdf.dtype
        )

    # Invert CDF
    inds = torch.searchsorted(cdf, u, right=True)
    # inds has shape (..., N_samples) identifying the bin of each sample.
    below = (inds - 1).clamp(0)
    above = inds.clamp(max=cdf.shape[-1] - 1)
    # Below and above are of shape (..., N_samples), identifying the bin
    # edges surrounding each sample.

    inds_g = torch.stack([below, above], -1).view(
        *below.shape[:-1], below.shape[-1] * 2
    )
    cdf_g = torch.gather(cdf, -1, inds_g).view(*below.shape, 2)
    bins_g = torch.gather(bins, -1, inds_g).view(*below.shape, 2)
    # cdf_g and bins_g are of shape (..., N_samples, 2) and identify
    # the cdf and the index of the two bin edges surrounding each sample.

    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom = torch.where(denom < eps, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    # t is of shape  (..., N_samples) and identifies how far through
    # each sample is in its bin.

    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples