from dataclasses import dataclass

import torch
import torch.nn.functional as F


@dataclass
class SMArgs:
    """Arguments for Softmasking."""

    # sm algorithm
    sm_alg: str = "none"  # "mixinputs_with_topk" or "mixinputs_with_temp"
    sm_schedule: str = "none"  # "none", "linear", or "stepwise"

    # lambda(·) parameters
    scale: float = 0.0  # overall strength of mixing (0 disables mixing)
    steepness: float = 0.0  # sigmoid steepness for entropy->lambda map
    offset: float = 0.0  # sigmoid offset for entropy->lambda map

    # used only when sm_alg == "mixinputs_with_topk"
    mixinputs_k: int = 3

    # used only when sm_alg == "mixinputs_with_temp"
    mixinputs_temp: float = 1.0


def get_mixing_factors_for_softmasking(input_ids, logits_prelim, mask_token_id, max_gen_length, sm_args):
    """Compute the Softmasking output distribution for every position.

    Each position's output is the convex combination
    ``(1 - lambda) * one_hot(input_ids) + lambda * p``, where ``p`` is derived
    from ``logits_prelim``:
    - a temperature-scaled softmax for "mixinputs_with_temp";
    - a top-k-restricted softmax for "mixinputs_with_topk";
    - a plain softmax otherwise.

    ``lambda`` comes from the negative entropy of the prediction and is
    non-zero only where ``input_ids == mask_token_id``.

    Args:
        input_ids: integer token ids. They are one-hot encoded against the
            last dimension of ``logits_prelim``.
        logits_prelim: preliminary logits whose last dimension is the
            vocabulary. Presumably (B, T, V); TODO confirm with callers.
        mask_token_id: token id that marks masked positions.
        max_gen_length: normaliser for the schedule's progress ratio. Only
            used when ``sm_args.sm_schedule != "none"``.
        sm_args: an ``SMArgs`` instance.

    Returns:
        ``p_out``, with the same shape as ``logits_prelim``. Despite the
        function's name, the mixing factors themselves are not returned.
    """
    vocab_size = logits_prelim.shape[-1]

    # One-hot distribution for the current (partially masked) input x_t.
    xt_one_hot = F.one_hot(input_ids, num_classes=vocab_size).to(logits_prelim.dtype)

    # Negative entropy of the (optionally tempered) prediction drives lambda.
    temperature = sm_args.mixinputs_temp if sm_args.sm_alg == "mixinputs_with_temp" else 1.0
    neg_entropy, p = get_neg_entropy_and_probabilities(logits_prelim, temperature=temperature)

    mask_positions = input_ids == mask_token_id

    # Optionally modulate the mixing strength by decoding progress.
    if sm_args.sm_schedule != "none":
        # NOTE(review): this counts mask tokens over the whole batch, so for a
        # batch size > 1 the progress ratio can exceed 1. Confirm whether a
        # per-sample count was intended.
        num_mask_token = mask_positions.sum().item()
        scale = get_time_dependence(
            max_gen_length=max_gen_length,
            num_mask_token=num_mask_token,
            scale=sm_args.scale,
            schedule=sm_args.sm_schedule,
        )
    else:
        scale = sm_args.scale

    # Shape (..., 1) so it broadcasts against the (..., V) distributions.
    lambda_tensor = calculate_lambda_tensor(
        neg_entropy, mask_positions, scale, sm_args.steepness, sm_args.offset
    )

    if sm_args.sm_alg == "mixinputs_with_topk":
        # Mix only over the top-k candidates. The entropy above still uses
        # the full softmax.
        p = get_only_topk_probs(logits_prelim, sm_args.mixinputs_k)

    # Convex combination of the input one-hot and the predicted distribution.
    return (1 - lambda_tensor) * xt_one_hot + lambda_tensor * p


def get_neg_entropy_and_probabilities(logits, temperature=1.0):
    """Return the negative entropy and softmax probabilities of ``logits``.

    Args:
        logits: tensor whose last dimension indexes the vocabulary.
        temperature: the logits are divided by this before the softmax.

    Returns:
        ``(neg_entropy, p)``, where ``p = softmax(logits / temperature)``
        over the last dimension, and ``neg_entropy = sum(p * log p)`` over
        that same dimension (one fewer dimension than ``logits``).
    """
    p = torch.softmax(logits / temperature, dim=-1)  # (B,T,V)
    # xlogy defines 0 * log(0) = 0, so zero-probability entries (from -inf
    # logits or underflow) contribute nothing instead of NaN. An additive
    # epsilon such as 1e-10 would round to zero in float16 and not help there.
    neg_entropy = torch.xlogy(p, p).sum(dim=-1)
    return neg_entropy, p


def calculate_lambda_tensor(neg_entropy, mask_positions, scale, steepness, offset):
    """Map negative entropy to per-position mixing weights.

    ``lambda = scale * sigmoid(steepness * (neg_entropy - offset))`` at
    positions where ``mask_positions`` is True, and 0 elsewhere.

    Args:
        neg_entropy: per-position negative entropy.
        mask_positions: boolean tensor broadcastable with ``neg_entropy``.
        scale: overall mixing strength. 0 disables mixing.
        steepness: sigmoid steepness.
        offset: sigmoid offset.

    Returns:
        ``neg_entropy``'s shape with a trailing singleton dimension appended,
        so it broadcasts against (..., V) distributions. This holds for every
        code path, including ``scale == 0``.

    Raises:
        ValueError: if ``neg_entropy`` is None.
    """
    if neg_entropy is None:
        raise ValueError("neg_entropy must be a tensor, got None")
    if scale == 0.0:
        # Mixing disabled. Still return the same (..., 1) shape as the normal
        # path so the caller's broadcasting against (..., V) works.
        return torch.zeros_like(neg_entropy).unsqueeze(-1)

    # Squash negative entropy through a scaled sigmoid.
    lambda_tensor = scale * torch.sigmoid(steepness * (neg_entropy - offset))
    # Apply only at masked positions.
    lambda_tensor = torch.where(mask_positions, lambda_tensor, torch.zeros_like(lambda_tensor))
    return lambda_tensor.unsqueeze(-1)  # (B,T,1)


def get_only_topk_probs(logits, mixinputs_k=3):
    """Build a full-vocabulary distribution supported on the top-k tokens.

    At each position, the top-k logits (over the last dimension) get softmax
    probabilities computed among themselves. All other entries are zero.

    Args:
        logits: tensor whose last dimension indexes the vocabulary.
        mixinputs_k: number of candidates to keep. Values larger than the
            vocabulary size are clamped to it.

    Returns:
        Tensor with the same shape and dtype as ``logits``.

    Raises:
        ValueError: if ``mixinputs_k < 1``.
    """
    if mixinputs_k < 1:
        raise ValueError(f"mixinputs_k must be >= 1, got {mixinputs_k}")
    k = min(mixinputs_k, logits.shape[-1])

    topk_logits, topk_indices = torch.topk(logits, k=k, dim=-1)  # (..., k)
    topk_probs = torch.softmax(topk_logits, dim=-1)  # (..., k)

    # softmax normalises by construction. This check mainly guards against
    # non-finite logits, which make softmax return NaN.
    topk_sum = topk_probs.sum(dim=-1)
    assert torch.allclose(topk_sum, torch.ones_like(topk_sum), atol=1e-1), \
        f"Top-k softmax probabilities do not sum to 1: max deviation = {(topk_sum - 1).abs().max().item()}"

    # No check on the count of non-zero entries: a top-k probability can
    # legitimately underflow to exactly 0 when logits differ by a lot.
    probs_full = torch.zeros_like(logits)
    probs_full.scatter_(-1, topk_indices, topk_probs)  # fill top-k
    return probs_full


def get_time_dependence(
    max_gen_length: int,
    num_mask_token: int,
    scale: float,
    schedule: str,
    sm_to_hm: bool = True,
    threshold: float = 0.5,
) -> float:
    """Return the mixing scale adjusted for decoding progress.

    The progress ratio is ``t = num_mask_token / max_gen_length``, or 1.0
    when ``max_gen_length`` is falsy.

    Schedules:
    - "none": ``scale`` unchanged.
    - "linear": ``scale * t`` when ``sm_to_hm``, else ``scale * (1 - t)``.
    - "stepwise": ``scale`` while ``t > threshold`` (when ``sm_to_hm``) or
      while ``t < threshold`` (otherwise), and 0.0 after that.

    Raises:
        ValueError: for an unknown ``schedule``.
    """
    t = num_mask_token / max_gen_length if max_gen_length else 1.0

    if schedule == "none":
        return scale
    if schedule == "linear":
        return scale * (t if sm_to_hm else 1 - t)
    if schedule == "stepwise":
        active = t > threshold if sm_to_hm else t < threshold
        return scale if active else 0.0
    raise ValueError(f"Unknown schedule: {schedule}")