# TODO: This is a very inflexible sampling algorithm -- it only works for
#   semi-autoregressive sampling with one token addition at a time.
# TODO: This code still needs refactoring -- could the tensor bookkeeping be
#   expressed with ein / einops?
"""Samplers for masked-diffusion (MDM) and any-order mask-insertion models.

Every sampler starts from a fully masked (MDM) or fully padded (insertion)
batch and advances time from t=0 to t=1 in ``steps`` equal increments of
``dt = 1 / steps``, querying the model for rates at every step. The public
samplers return ``(samples, trace)`` tuples.
"""
import math
from dataclasses import dataclass
from typing import Any, Literal, Optional

import torch
import torch.nn.functional as F

from lightning_modules.mdm import MaskedDiffusionModule


@dataclass
class SamplingTraceDatapoint:
    """One edit event recorded while sampling a single sequence."""

    t: float  # sampler time at which the event happened
    event_type: Literal["insertion", "change"]
    # Index into the sequence as it was at the start of the sampling step.
    position: int
    token: Any


@dataclass
class SamplingResult:
    """Sampled sequences plus an optional per-sequence event trace."""

    samples: torch.Tensor
    # Trace is supposed to be processed sequentially as updates are not commutative
    trace: Optional[list[SamplingTraceDatapoint]]

    def __iter__(self):
        # Allows ``samples, trace = result`` tuple-style unpacking.
        yield self.samples
        yield self.trace


# Sample from categorical distribution for each position using the transition probabilities
def _sample_tokens(probs: torch.Tensor) -> torch.Tensor:
    """Sample one token per position from a probability distribution.

    Args:
        probs: [batch_size, seq_len, vocab_size] non-negative weights. Each row
            is normalized by ``torch.multinomial`` and must have a positive sum.

    Returns:
        [batch_size, seq_len] sampled token indices.
    """
    batch_size, seq_len, vocab_size = probs.shape
    # reshape (rather than view) also accepts non-contiguous inputs.
    samples = torch.multinomial(probs.reshape(-1, vocab_size), num_samples=1)
    return samples.view(batch_size, seq_len)


def _euler_transition_probs(
    xt: torch.Tensor,
    unmask_rate: torch.Tensor,
    mask: int,
    dt: float,
    stay_token: torch.Tensor,
    is_last_step: bool,
) -> torch.Tensor:
    """Build first-order (Euler) transition probabilities for one step.

    Only masked positions may jump, and never to the mask token itself. The
    jump probabilities are ``clamp(rate * dt, 0, 1)``; the remaining mass
    ``max(0, 1 - sum(jumps))`` stays on ``stay_token``. The stay mass is added
    AFTER clamping: clamping a negative diagonal first would silently turn the
    stay weight into a constant 1.

    Args:
        xt: (B, L) current tokens.
        unmask_rate: (B, L, V) rates of jumping from a masked slot to each token.
        mask: id of the mask token.
        dt: step size.
        stay_token: (B, L) column that receives the stay mass for each position.
        is_last_step: if True, masked positions get zero stay mass so they are
            forced to unmask.

    Returns:
        (B, L, V) non-negative weights suitable for ``_sample_tokens``.
    """
    is_mask = xt == mask  # (B, L)
    jump = unmask_rate.masked_fill(~is_mask.unsqueeze(-1), 0.0)
    jump = (jump * dt).clamp(0.0, 1.0)
    jump[..., mask] = 0.0
    stay = (1.0 - jump.sum(dim=-1)).clamp(min=0.0)  # (B, L)
    probs = jump.scatter_add(2, stay_token.unsqueeze(-1), stay.unsqueeze(-1))
    if is_last_step:
        # Remove the mask token from sampling at the last step.
        probs[is_mask, mask] = 0.0
    return probs


def _argmax_excluding(rates: torch.Tensor, token: int) -> torch.Tensor:
    """Return ``rates.argmax(-1)`` with ``token`` never selected."""
    rates = rates.clone()
    rates[..., token] = float("-inf")
    return rates.argmax(dim=-1)


def _tau_leap_jumps(
    unmask_rate: torch.Tensor, is_mask: torch.Tensor, mask: int, dt: float
) -> tuple[torch.Tensor, torch.Tensor]:
    """Draw Poisson event counts and decide which masked positions jump.

    A position jumps only when exactly one event fires across the vocabulary;
    zero or several events leave it masked. Non-mask positions and
    mask-to-mask events are discarded.

    Returns:
        ``(accept, token)``, both (B, L): whether each position jumps and the
        token it would jump to.
    """
    counts = torch.poisson(unmask_rate * dt).long()  # (B, L, V)
    counts = counts.masked_fill(~is_mask.unsqueeze(-1), 0)
    counts[..., mask] = 0
    accept = counts.sum(dim=2) == 1
    return accept, counts.argmax(dim=2)


def _insert_masks(
    xt: torch.Tensor,
    new_xt: torch.Tensor,
    ext: torch.Tensor,
    mask: int,
    pad: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Insert ``ext[b, g]`` mask tokens into gap ``g`` of row ``b``.

    Gap ``g`` is the slot just before position ``g``; gap ``len`` is the end of
    the sequence. Non-pad tokens are assumed to occupy a prefix of each row.
    Gaps past the end are ignored. A row whose insertions would exceed
    ``max_length`` gets no insertions at all.

    Args:
        xt: (B, L) tokens at the start of the step; defines each row's length.
        new_xt: (B, L) tokens after this step's unmasking (same pad layout).
        ext: (B, L+1) requested insertion counts per gap.
        mask: id of the mask token.
        pad: id of the pad token.

    Returns:
        ``(xt_next, ext)``: the extended sequences and the insertion counts that
        were actually applied.
    """
    batch_size, max_length = xt.shape
    device = xt.device
    pos_idx = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    batch_idx = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )

    xt_len = xt.ne(pad).sum(dim=1)  # (B,)
    gaps = torch.arange(max_length + 1, device=device).view(1, -1)
    ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
    fits = xt_len + ext.sum(dim=1) <= max_length
    ext = ext * fits.view(batch_size, 1).long()
    # Length must be computed AFTER rejecting overflowing rows; otherwise a
    # rejected row would have its whole pad tail filled with masks.
    new_len = xt_len + ext.sum(dim=1)  # (B,)

    # Token at original position p moves right by the insertions in gaps 0..p.
    shift = ext.cumsum(dim=1)[:, :max_length]  # (B, L)

    xt_next = torch.full_like(xt, pad)
    xt_next[pos_idx < new_len.view(batch_size, 1)] = mask
    keep = pos_idx < xt_len.view(batch_size, 1)
    xt_next[batch_idx[keep], (pos_idx + shift)[keep]] = new_xt[keep]
    return xt_next, ext


def _append_trace(
    trace: list,
    t: torch.Tensor,
    xt: torch.Tensor,
    new_xt: torch.Tensor,
    ext: Optional[torch.Tensor],
    mask: int,
    pad: int,
) -> None:
    """Append this step's events to ``trace`` (one list per batch row).

    Token changes are recorded first, in ascending position. Insertions follow
    in descending gap order, so that, when the events are replayed one by one,
    an insertion never shifts the position of a later one. ``ext`` is ``None``
    when no insertion ran this step.
    """
    # Move to Python once instead of calling .item() per element.
    times = t.tolist()
    old_rows = xt.tolist()
    new_rows = new_xt.tolist()
    ext_rows = None if ext is None else ext.tolist()
    for b, events in enumerate(trace):
        for pos, (old, new) in enumerate(zip(old_rows[b], new_rows[b])):
            if old != pad and old != new:
                events.append(
                    SamplingTraceDatapoint(
                        t=times[b], event_type="change", position=pos, token=new
                    )
                )
        if ext_rows is None:
            continue
        for gap in reversed(range(len(old_rows[b]))):
            if ext_rows[b][gap]:
                events.append(
                    SamplingTraceDatapoint(
                        t=times[b], event_type="insertion", position=gap, token=mask
                    )
                )


@torch.no_grad()
def mdm_euler_sampling(
    model: MaskedDiffusionModule,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
):
    """Sample fixed-length sequences from an MDM with Euler steps.

    Args:
        model: model returning rates; converted via ``model.interpolant.to_actual_rate``.
        steps: number of Euler steps.
        mask: id of the mask token.
        pad: unused; accepted for signature parity with the other samplers.
        batch_size: number of sequences.
        max_length: sequence length.
        return_trace: must be False (tracing is not implemented here).

    Returns:
        ``(xt, [])`` where ``xt`` is a (batch_size, max_length) int64 tensor.
    """
    assert not return_trace, "Trace is not yet implemented in MDM Euler sampling"
    device = model.device
    xt = torch.full((batch_size, max_length), mask, dtype=torch.int64, device=device)
    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    for step in range(steps):
        pred_rate = model.interpolant.to_actual_rate(xt, model(xt, t), t)
        trans_prob = _euler_transition_probs(
            xt,
            pred_rate.unmask_rate,
            mask,
            dt,
            stay_token=xt,
            is_last_step=step == steps - 1,
        )
        sampled = _sample_tokens(trans_prob)
        # Only masked positions may change.
        xt = torch.where(xt == mask, sampled, xt)
        t = t + dt

    return xt, []


@torch.no_grad()
def any_order_mask_insertion_euler_sampling(
    model: torch.nn.Module,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
) -> SamplingResult:
    """Sample variable-length sequences with Euler unmasking plus Bernoulli insertion.

    Starts from all-pad rows. Each step first unmasks with an Euler step. Then,
    except on the last step, it inserts at most one mask token per gap with
    probability ``clamp(length_rate * dt, 0, 1)``.

    Returns:
        ``(xt, trace)``. ``trace`` is a per-row list of ``SamplingTraceDatapoint``
        when ``return_trace`` is set, else ``None``.
    """
    device = model.device
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    for step in range(steps):
        is_last_step = step == steps - 1
        pred_rate = model.interpolant.to_actual_rate(xt, model(xt, t), t)
        unmask_rate = pred_rate.unmask_rate  # (B, L, V)
        len_rate = pred_rate.length_rate  # (B, L+1)

        # Pad slots have no outgoing rate; route their stay mass to the mask
        # column so every row is a valid distribution. Pads are restored below.
        stay_token = xt.masked_fill(xt == pad, mask)
        trans_prob = _euler_transition_probs(
            xt, unmask_rate, mask, dt, stay_token, is_last_step
        )
        sampled = _sample_tokens(trans_prob)
        # Keep pads and already-decoded tokens; only masks take the sample.
        new_xt = torch.where(xt == mask, sampled, xt)

        if is_last_step:
            xt_next, ext = new_xt, None
        else:
            ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()  # (B, L+1)
            xt_next, ext = _insert_masks(xt, new_xt, ext, mask, pad)

        if return_trace:
            _append_trace(sampling_trace, t, xt, new_xt, ext, mask, pad)

        xt = xt_next
        t = t + dt

    return xt, sampling_trace


@torch.no_grad()
def mdm_tau_leaping_sampling(
    model: MaskedDiffusionModule,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
):
    """Sample fixed-length sequences from an MDM with tau-leaping.

    Intermediate steps accept a jump only when exactly one Poisson event fires
    at a masked position. The last step unmasks every remaining mask with the
    highest-rate non-mask token.

    Returns:
        ``(xt, [])`` where ``xt`` is a (batch_size, max_length) int64 tensor.
    """
    assert not return_trace, "Trace is not yet supported"
    device = model.device
    xt = torch.full((batch_size, max_length), mask, dtype=torch.int64, device=device)
    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    for step in range(steps):
        pred = model.interpolant.to_actual_rate(xt, model(xt, t), t)
        unmask_rate = pred.unmask_rate  # (B, L, V)
        is_mask = xt == mask  # (B, L)

        if step == steps - 1:
            # Last step: deterministic unmask via argmax (never back to mask).
            xt = torch.where(is_mask, _argmax_excluding(unmask_rate, mask), xt)
        else:
            accept, proposal = _tau_leap_jumps(unmask_rate, is_mask, mask, dt)
            xt = torch.where(accept, proposal, xt)
        t = t + dt

    return xt, []


# Not used in production, for debugging purposes
lengths = {4: 0.1, 16: 0.4, 32: 0.4, 64: 0.1}


def binomial_mass(k, n, p):
    """
    Calculate the probability mass function (PMF) for a binomial distribution.

    Args:
        k (int): Number of successes
        n (int): Number of trials
        p (float): Probability of success in a single trial

    Returns:
        float: Probability mass P(X = k); 0.0 when k is outside [0, n].
    """
    if k < 0 or k > n:
        return 0.0
    # math.comb is exact; converting once matches the old factorial quotient.
    binom_coef = float(math.comb(n, k))
    return binom_coef * (p**k) * ((1 - p) ** (n - k))


def calculate_rate_batch(alpha_t, len_t):
    """
    Calculate rate for a batch of alpha_t and len_t values.

    For each row, this is the ``lengths``-prior posterior mean of
    ``(length - len_t)`` under a Binomial(length, alpha_t) likelihood of
    observing ``len_t``.

    Args:
        alpha_t (torch.Tensor): Tensor of shape (batch_size,)
        len_t (torch.Tensor): Tensor of shape (batch_size,)

    Returns:
        torch.Tensor: Tensor of shape (batch_size,) containing calculated rates
        (0 where no prior length is compatible with len_t).
    """
    nom = torch.zeros(alpha_t.shape[0], device=alpha_t.device)
    denom = torch.zeros_like(nom)

    for length, probability in lengths.items():
        compatible = (len_t <= length) & (len_t >= 0)
        if not compatible.any():
            continue
        idx = compatible.nonzero(as_tuple=True)[0]
        k = len_t[idx]
        binom = torch.distributions.Binomial(total_count=length, probs=alpha_t[idx])
        binom_probs = binom.log_prob(k).exp()
        nom[idx] += (length - k) * probability * binom_probs
        denom[idx] += probability * binom_probs

    # Avoid division by zero where no length contributed any mass.
    result = torch.zeros_like(nom)
    has_mass = denom > 0
    result[has_mass] = nom[has_mass] / denom[has_mass]
    return result


# Keep the original function for backward compatibility
def calculate_rate(alpha_t, len_t):
    """Legacy scalar version of calculate_rate.

    Dispatches to ``calculate_rate_batch`` for tensors with ndim > 0.
    """
    if isinstance(alpha_t, torch.Tensor) and alpha_t.ndim > 0:
        return calculate_rate_batch(alpha_t, len_t)

    nom, denom = 0, 0
    for length, probability in lengths.items():
        if length >= len_t:
            mass = binomial_mass(len_t, length, alpha_t)
            nom += (length - len_t) * probability * mass
            denom += probability * mass

    if denom == 0:
        return 0.0
    return nom / denom


def _unmask_confidence(
    unmask_rate: torch.Tensor, xt: torch.Tensor, pad: int, method: str
) -> torch.Tensor:
    """Score every position; higher means "unmask first".

    The probability-based methods apply softmax directly to ``unmask_rate``,
    treating the rates as logits.

    Returns:
        (B, L) confidence scores.

    Raises:
        ValueError: for an unknown ``method``.
    """
    if method == "position":
        # Earlier positions are more confident: 1 - i / len(xt).
        seq_len = (xt != pad).sum(dim=1, keepdim=True).float().clamp(min=1)  # (B, 1)
        positions = torch.arange(xt.shape[1], device=xt.device).unsqueeze(0).float()
        return 1.0 - positions / seq_len
    if method in ("top_prob", "prob_diff", "entropy"):
        probs = F.softmax(unmask_rate, dim=-1)  # (B, L, V)
        if method == "top_prob":
            return probs.max(dim=-1)[0]
        if method == "prob_diff":
            top2 = torch.topk(probs, k=2, dim=-1)[0]  # (B, L, 2)
            return top2[:, :, 0] - top2[:, :, 1]
        # Lower entropy -> higher confidence.
        entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)
        return -entropy
    raise ValueError(f"Unknown confidence_method: {method}")


def _confidence_unmask(
    xt: torch.Tensor,
    unmask_rate: torch.Tensor,
    is_mask: torch.Tensor,
    mask: int,
    pad: int,
    dt: float,
    alpha: float,
    max_window: int,
    confidence_method: str,
    use_sliding_window: bool,
) -> torch.Tensor:
    """Unmask the most confident masked positions of each row.

    The number of positions per row is drawn from Poisson(num_masks * dt) and
    then capped at the number of eligible positions. Without the cap, ``topk``
    would pick ``-inf`` slots and overwrite already-decoded tokens or pads.
    Selected positions receive their highest-rate non-mask token.
    """
    batch_size = xt.shape[0]
    device = xt.device
    num_masks = is_mask.sum(dim=1)  # (B,)
    unmask_counts = torch.poisson(num_masks.float() * dt).long()  # (B,)

    confidence = _unmask_confidence(unmask_rate, xt, pad, confidence_method)

    if use_sliding_window:
        # Only the first k masks (k = min(alpha * count, max_window, #masks)) are eligible.
        window = torch.minimum(
            (alpha * unmask_counts).long().clamp(max=max_window), num_masks
        )
        eligible = is_mask & (is_mask.cumsum(dim=1) <= window.unsqueeze(1))
    else:
        eligible = is_mask
    confidence = confidence.masked_fill(~eligible, float("-inf"))
    unmask_counts = torch.minimum(unmask_counts, eligible.sum(dim=1))

    new_xt = xt.clone()
    max_unmask = int(unmask_counts.max().item())
    if max_unmask > 0:
        top_idx = torch.topk(confidence, k=max_unmask, dim=1, largest=True)[1]
        take = torch.arange(max_unmask, device=device).unsqueeze(0) < unmask_counts.unsqueeze(1)
        rows = torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, max_unmask)[take]
        cols = top_idx[take]
        best_tokens = _argmax_excluding(unmask_rate, mask)  # (B, L)
        new_xt[rows, cols] = best_tokens[rows, cols]
    return new_xt


@torch.no_grad()
@torch.compile(mode="reduce-overhead")
def any_order_mask_insertion_tau_leaping_sampling(
    model: torch.nn.Module,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    confidence_based_sampling: bool = True,  # whether to use confidence-based decoding
    alpha: float = 5.0,  # hyperparameter for window size calculation
    max_window: int = 32,  # Maximum window size for sliding window
    confidence_method: str = "prob_diff",  # "position", "top_prob", "prob_diff", "entropy"
    use_sliding_window: bool = False,  # whether to use sliding window for position selection
) -> SamplingResult:
    """Sample variable-length sequences with tau-leaping unmasking and Poisson insertion.

    Starts from all-pad rows. Each non-final step unmasks positions, either
    confidence-based or by plain tau-leaping, and then inserts
    Poisson(length_rate * dt) masks per gap. The final step only unmasks,
    deterministically via argmax.

    Returns:
        ``(xt, trace)``. ``trace`` is a list of the (B, L) token tensors after
        each non-final step when ``return_trace`` is set, else an empty list.
    """
    device = model.device
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
    sampling_trace = []
    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    for step in range(steps):
        pred = model.interpolant.to_actual_rate(xt, model(xt, t), t)
        unmask_rate = pred.unmask_rate  # (B, L, V)
        len_rate = pred.length_rate  # (B, L+1)
        is_mask = xt == mask  # (B, L)

        if step == steps - 1:
            # Last step: deterministic unmask via argmax, no insertion.
            xt = torch.where(is_mask, _argmax_excluding(unmask_rate, mask), xt)
            t = t + dt
            continue

        if confidence_based_sampling:
            new_xt = _confidence_unmask(
                xt,
                unmask_rate,
                is_mask,
                mask,
                pad,
                dt,
                alpha,
                max_window,
                confidence_method,
                use_sliding_window,
            )
        else:
            accept, proposal = _tau_leap_jumps(unmask_rate, is_mask, mask, dt)
            new_xt = torch.where(accept, proposal, xt)

        ext = torch.poisson(len_rate * dt).long()  # (B, L+1)
        xt, _ = _insert_masks(xt, new_xt, ext, mask, pad)
        t = t + dt

        if return_trace:
            sampling_trace.append(xt)

    return xt, sampling_trace