Spaces:
Sleeping
Sleeping
File size: 5,342 Bytes
9a96e6d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
from typing import Callable, Tuple
import torch
def compute_ess(w, dim=-1):
    """Return the effective sample size (sum w)^2 / sum(w^2) along ``dim``.

    Args:
        w: tensor of non-negative weights (need not be normalised).
        dim: dimension to reduce over.
    """
    squared_total = torch.sum(w ** 2, dim=dim)
    total = w.sum(dim=dim)
    return total ** 2 / squared_total
def compute_ess_from_log_w(log_w, dim=-1):
    """Return the effective sample size for (possibly unnormalised) log-weights.

    The log-weights are first normalised along ``dim`` and exponentiated, then
    passed to ``compute_ess`` over the same dimension.
    """
    probabilities = normalize_weights(log_w, dim=dim)
    return compute_ess(probabilities, dim=dim)
def normalize_weights(log_weights, dim=-1):
    """Map log-weights to linear weights that sum to one along ``dim``."""
    normalized_log = normalize_log_weights(log_weights, dim=dim)
    return normalized_log.exp()
def normalize_log_weights(log_weights, dim=-1):
    """Shift log-weights so that their exponentials sum to one along ``dim``.

    The maximum is subtracted first and the log-normaliser is then removed
    with ``torch.logsumexp``; the output has the same shape as the input.
    """
    peak = torch.max(log_weights, dim=dim, keepdim=True).values
    shifted = log_weights - peak
    log_normaliser = torch.logsumexp(shifted, dim=dim, keepdim=True)
    return shifted - log_normaliser
def stratified_resample(log_weights: torch.Tensor):
    """Stratified resampling: one uniform draw inside each stratum [i/N, (i+1)/N).

    Args:
        log_weights: 1-D tensor of (possibly unnormalised) log-weights; its
            length N is also the number of indices drawn.

    Returns:
        1-D int64 tensor of N particle indices, each in ``[0, N)``.
    """
    N = log_weights.shape[0]
    device = log_weights.device
    weights = normalize_weights(log_weights)
    cdf = torch.cumsum(weights, dim=0)
    # Stratified uniform samples: offset i plus an independent U[0, 1) draw, scaled by 1/N.
    u = (torch.arange(N, dtype=torch.float32, device=device) + torch.rand(N, device=device)) / N
    indices = torch.searchsorted(cdf, u, right=True)
    # Round-off can leave cdf[-1] just below 1 (or round u up to 1.0), in which
    # case searchsorted returns N -- an out-of-range index. Map it to the last particle.
    return indices.clamp(max=N - 1)
def systematic_resample(log_weights: torch.Tensor, normalized=True):
    """Systematic resampling: one shared random offset, then N points spaced 1/N apart.

    Args:
        log_weights: 1-D tensor of (possibly unnormalised) log-weights; its
            length N is also the number of indices drawn.
        normalized: accepted for call-site compatibility but not used; the
            weights are always normalised internally.

    Returns:
        1-D int64 tensor of N particle indices, each in ``[0, N)``.
    """
    N = log_weights.shape[0]
    device = log_weights.device
    weights = normalize_weights(log_weights)
    cdf = torch.cumsum(weights, dim=0)
    # Single offset u0 in [0, 1/N), shared by all N evenly spaced points.
    u0 = torch.rand(1, device=device) / N
    u = u0 + torch.arange(N, dtype=torch.float32, device=device) / N
    indices = torch.searchsorted(cdf, u, right=True)
    # Round-off can leave cdf[-1] just below the largest u, making searchsorted
    # return N (out of range); map that to the last particle.
    return indices.clamp(max=N - 1)
def multinomial_resample(log_weights: torch.Tensor, normalized=True):
    """Draw N particle indices i.i.d. (with replacement) proportionally to the weights.

    Args:
        log_weights: 1-D tensor of (possibly unnormalised) log-weights; its
            length N is also the number of indices drawn.
        normalized: accepted for call-site compatibility but not used.

    Returns:
        1-D tensor of N indices sampled by ``torch.multinomial``.
    """
    probabilities = normalize_weights(log_weights)
    return torch.multinomial(probabilities, log_weights.shape[0], replacement=True)
def partial_resample(log_weights: torch.Tensor,
                     resample_fn: Callable[[torch.Tensor], torch.Tensor],
                     M: int,
                     num_high: int = 1) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Perform partial resampling on a subset of particles using PyTorch.

    The ``num_high`` highest-weight particles and the ``M - num_high``
    lowest-weight particles (two disjoint sets) are pooled and resampled among
    themselves with ``resample_fn``. Every other particle keeps its slot and
    weight. The pooled probability mass is spread evenly over the resampled
    slots, so the total weight is unchanged.

    Args:
        log_weights (torch.Tensor): 1D tensor of shape (K,) containing
            (possibly unnormalised) log-weights.
        resample_fn (callable): function that takes a 1D tensor of n
            log-weights and returns a tensor of shape (n,) of indices in [0, n).
        M (int): total number of particles to resample; clamped to [0, K].
            If it ends up 0, nothing is resampled.
        num_high (int): how many of the M slots are taken from the
            highest-weight particles (default 1, the previous hard-coded
            value); the rest come from the lowest-weight ones. Clamped to [0, M].

    Returns:
        new_indices (torch.Tensor): 1D tensor of shape (K,) mapping each output
            slot to an original particle index (identity for untouched slots).
        new_log_weights (torch.Tensor): 1D tensor of shape (K,) of updated
            log-weights, on the normalised scale (their exponentials sum to 1).
    """
    K = log_weights.numel()
    # Work with normalised log-weights; the returned log-weights are on this scale.
    log_weights = normalize_log_weights(log_weights)
    weights = torch.exp(log_weights)
    # Default mapping is the identity: slots that are not resampled keep their particle.
    new_indices = torch.arange(K, device=log_weights.device)

    # Clamp the request so the slicing below stays in range; M == 0 used to
    # produce a negative topk size and crash.
    M = max(0, min(int(M), K))
    if M == 0:
        return new_indices, log_weights
    M_hi = max(0, min(int(num_high), M))
    M_lo = M - M_hi

    # A single ascending sort yields disjoint low/high sets because
    # M_lo + M_hi <= K. Two independent topk calls are not guaranteed to be
    # disjoint when weights tie, which would double-count a particle's mass.
    order = torch.argsort(weights)
    hi_idx = order[K - M_hi:].flip(0)  # largest weights, descending
    lo_idx = order[:M_lo]  # smallest weights, ascending
    selected = torch.cat([hi_idx, lo_idx])  # indices selected for resampling

    # resample_fn sees only the selected subset and returns positions into it.
    local_sampled = resample_fn(log_weights[selected])
    new_indices[selected] = selected[local_sampled]

    # Give every resampled slot an equal share of the subset's total mass.
    uniform_log_weight = torch.log(weights[selected].sum() / selected.numel())
    new_log_weights = log_weights.clone()
    new_log_weights[selected] = uniform_log_weight
    return new_indices, new_log_weights
def resample(log_w, ess_threshold=None, partial=False, resample_fn=None):
    """
    Resample particles according to their log weights.

    Parameters
    ----------
    log_w : torch.Tensor
        1D tensor of shape (N,) with the (possibly unnormalised) log weights
        of the particles.
    ess_threshold : float, optional
        Relative effective sample size (ESS) threshold. Resampling is skipped
        when ``ESS >= ess_threshold * N``. If None, resampling is always
        performed.
    partial : bool, optional
        If True, only ``N // 2`` particles are resampled via
        ``partial_resample``. If False, all N particles are resampled.
    resample_fn : callable, optional
        Function mapping a 1D tensor of log weights to particle indices.
        Defaults to ``systematic_resample``.

    Returns
    -------
    indices : torch.Tensor
        For each of the N output slots, the index of the particle it is drawn
        from (the identity when resampling is skipped).
    resampled : bool
        Whether resampling was performed.
    log_w : torch.Tensor
        Log weights after this step: the input unchanged when skipped, all
        zeros (uniform) after full resampling, or the log weights returned by
        ``partial_resample`` after partial resampling.
    """
    if resample_fn is None:
        resample_fn = systematic_resample
    N = log_w.size(0)
    # The ESS is only needed to decide whether to skip, so compute it lazily.
    if ess_threshold is not None:
        ess = compute_ess_from_log_w(log_w)
        if ess >= ess_threshold * N:
            # Skip resampling: ESS is not below the threshold.
            return torch.arange(N, device=log_w.device), False, log_w
    if partial:
        indices, new_log_w = partial_resample(log_w, resample_fn, N // 2)
    else:
        indices = resample_fn(log_w)
        new_log_w = torch.zeros_like(log_w)
    return indices, True, new_log_w
|