from typing import Callable, Tuple

import torch


def compute_ess(w, dim=-1):
    """Return the effective sample size ``(sum w)^2 / sum(w^2)`` along ``dim``.

    Args:
        w: tensor of (not necessarily normalized) non-negative weights.
        dim: dimension that indexes the particles.

    Returns:
        Tensor with ``dim`` reduced away. The result is NaN when every weight
        along ``dim`` is zero (0 / 0).
    """
    return w.sum(dim=dim) ** 2 / torch.sum(w ** 2, dim=dim)


def compute_ess_from_log_w(log_w, dim=-1):
    """Return the effective sample size of the particles given log-weights."""
    return compute_ess(normalize_weights(log_w, dim=dim), dim=dim)


def normalize_weights(log_weights, dim=-1):
    """Return weights that sum to one along ``dim``, computed from log-weights."""
    return torch.exp(normalize_log_weights(log_weights, dim=dim))


def normalize_log_weights(log_weights, dim=-1):
    """Return log-weights shifted so that ``logsumexp`` along ``dim`` is zero.

    The input tensor is not modified.
    """
    # Shift by the maximum first, then subtract the log-normalizer.
    log_weights = log_weights - log_weights.max(dim=dim, keepdim=True)[0]
    return log_weights - torch.logsumexp(log_weights, dim=dim, keepdim=True)


def _inverse_cdf_indices(weights: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    """Map sorted uniform draws ``u`` to particle indices via the weight CDF."""
    cdf = torch.cumsum(weights, dim=0)
    indices = torch.searchsorted(cdf, u, right=True)
    # Floating-point error can leave cdf[-1] slightly below 1, or round a draw
    # up to exactly 1.0; searchsorted then returns N, which is out of range.
    return indices.clamp_(max=weights.shape[0] - 1)


def stratified_resample(log_weights: torch.Tensor):
    """Draw N particle indices with stratified resampling.

    One uniform draw is taken inside each of the N equal-width strata of
    [0, 1) and inverted through the cumulative weight distribution.

    Args:
        log_weights: 1D tensor of N (unnormalized) log-weights.

    Returns:
        1D integer tensor of N indices, each in ``[0, N)``.
    """
    N = log_weights.shape[0]
    weights = normalize_weights(log_weights)
    # Build the draws in the same dtype as the CDF; torch.searchsorted
    # rejects mismatched dtypes (e.g. float64 weights vs. float32 draws).
    strata = torch.arange(N, dtype=weights.dtype, device=weights.device)
    jitter = torch.rand(N, dtype=weights.dtype, device=weights.device)
    u = (strata + jitter) / N
    return _inverse_cdf_indices(weights, u)


def systematic_resample(log_weights: torch.Tensor, normalized=True):
    """Draw N particle indices with systematic resampling.

    A single uniform offset in ``[0, 1/N)`` is shared by N evenly spaced
    points, which are inverted through the cumulative weight distribution.

    Args:
        log_weights: 1D tensor of N (unnormalized) log-weights.
        normalized: unused; kept for interface compatibility. The weights are
            always normalized internally.

    Returns:
        1D integer tensor of N indices, each in ``[0, N)``.
    """
    N = log_weights.shape[0]
    weights = normalize_weights(log_weights)
    # Same dtype as the CDF, for the reason given in stratified_resample.
    u0 = torch.rand(1, dtype=weights.dtype, device=weights.device) / N
    u = u0 + torch.arange(N, dtype=weights.dtype, device=weights.device) / N
    return _inverse_cdf_indices(weights, u)


def multinomial_resample(log_weights: torch.Tensor, normalized=True):
    """Draw N particle indices i.i.d. from the normalized weights.

    Args:
        log_weights: 1D tensor of N (unnormalized) log-weights.
        normalized: unused; kept for interface compatibility. The weights are
            always normalized internally.

    Returns:
        1D integer tensor of N indices sampled with replacement.
    """
    N = log_weights.shape[0]
    weights = normalize_weights(log_weights)
    return torch.multinomial(weights, N, replacement=True)


def partial_resample(log_weights: torch.Tensor,
                     resample_fn: Callable[[torch.Tensor], torch.Tensor],
                     M: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Resample only a subset of M particles and leave the rest untouched.

    The subset consists of the single highest-weight particle together with
    the ``M - 1`` lowest-weight particles. The subset is resampled among
    itself with ``resample_fn`` and each resampled slot receives an equal
    share of the subset's total weight, so the overall weight mass is
    preserved.

    Args:
        log_weights (torch.Tensor): 1D tensor of shape (K,) containing
            log-weights (need not be normalized).
        resample_fn (callable): function that takes a 1D tensor of
            log-weights of length n and returns a tensor of n sampled indices
            in ``[0, n)``.
        M (int): number of particles to resample; must satisfy
            ``1 <= M <= K``.

    Returns:
        new_indices (torch.Tensor): 1D tensor of shape (K,) mapping each
            output slot to an original particle index.
        new_log_weights (torch.Tensor): 1D tensor of shape (K,) of updated,
            normalized log-weights.

    Raises:
        ValueError: if ``M`` is outside ``[1, K]``.
    """
    K = log_weights.numel()
    if not 1 <= M <= K:
        raise ValueError(f"M must satisfy 1 <= M <= {K}, got {M}")

    # Convert log-weights to normalized weights
    log_weights = normalize_log_weights(log_weights)
    weights = torch.exp(log_weights)

    # Determine how many high and low weights to resample
    M_hi = 1
    M_lo = M - M_hi

    # Take both ends of one ascending sort. Two independent top-k calls could
    # pick the same index when weights tie; slicing a single permutation
    # guarantees the high and low sets are disjoint.
    order = torch.argsort(weights)
    hi_idx = order[K - M_hi:]
    lo_idx = order[:M_lo]
    I = torch.cat([hi_idx, lo_idx])  # indices selected for resampling

    # Resample within the subset; resample_fn expects the subset's log-weights
    local_sampled = resample_fn(log_weights[I])  # indices in [0, len(I))

    # Build new index mapping: identity everywhere except the resampled slots
    new_indices = torch.arange(K, device=log_weights.device)
    new_indices[I] = I[local_sampled]

    # Resampled slots share the subset's total weight equally; all other
    # slots keep their (normalized) log-weight.
    uniform_weight = weights[I].sum() / M
    new_log_weight = log_weights.clone()
    new_log_weight[I] = torch.log(uniform_weight)

    return new_indices, new_log_weight


def resample(log_w, ess_threshold=None, partial=False):
    """
    Resample particles according to their log weights.

    Parameters
    ----------
    log_w : torch.Tensor
        1D tensor with the log weights of the N particles.
    ess_threshold : float, optional
        Threshold expressed as a fraction of N. Resampling is skipped when
        the effective sample size (ESS) is at least ``ess_threshold * N``.
        If None, resampling is always performed.
    partial : bool, optional
        If True, only ``max(1, N // 2)`` particles are resampled (see
        ``partial_resample``). If False, all particles are resampled.

    Returns
    -------
    indices : torch.Tensor
        Indices of the resampled particles (identity when skipped).
    resampled : bool
        True if resampling was performed, False if it was skipped.
    log_w : torch.Tensor
        Log weights after the step: the input unchanged when skipped, all
        zeros after a full resample, and the partially equalized normalized
        log weights after a partial resample.
    """
    base_sampling_fn = systematic_resample
    N = log_w.size(0)

    ess = compute_ess_from_log_w(log_w)
    if ess_threshold is not None and ess >= ess_threshold * N:
        # Skip resampling as ess is not below the threshold
        return torch.arange(N, device=log_w.device), False, log_w

    if partial:
        # N // 2 is 0 for a single particle; always resample at least one.
        resample_indices, log_w = partial_resample(
            log_w, base_sampling_fn, max(1, N // 2)
        )
    else:
        resample_indices = base_sampling_fn(log_w)
        log_w = torch.zeros_like(log_w)

    return resample_indices, True, log_w