|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from dataclasses import dataclass |
|
|
|
|
|
@dataclass
class SMArgs:
    """Configuration bundle for soft-masking.

    The fields are read by ``get_mixing_factors_for_softmasking`` and the
    helpers it calls; field order is part of the positional constructor.
    """

    # --- algorithm / schedule selection -------------------------------
    # Soft-masking variant. Values checked in this file:
    # "mixinputs_with_temp" (apply `mixinputs_temp` to the softmax) and
    # "mixinputs_with_topk" (mix in only the top-`mixinputs_k` probabilities).
    sm_alg: str = "none"
    # Schedule for the scale over decoding progress. Values handled in this
    # file: "none", "linear", "stepwise".
    sm_schedule: str = "none"

    # --- lambda = scale * sigmoid(steepness * (neg_entropy - offset)) --
    # Multiplier on the sigmoid; 0.0 yields an all-zero lambda.
    scale: float = 0.0
    # Slope of the sigmoid applied to the negative entropy.
    steepness: float = 0.0
    # Shift subtracted from the negative entropy before the sigmoid.
    offset: float = 0.0

    # --- variant-specific knobs ---------------------------------------
    # Number of highest-logit tokens kept for "mixinputs_with_topk".
    mixinputs_k: int = 3
    # Softmax temperature used for "mixinputs_with_temp".
    mixinputs_temp: float = 1.0
|
|
|
|
|
def get_mixing_factors_for_softmasking(input_ids, logits_prelim, mask_token_id, max_gen_length, sm_args):
    """Blend the current one-hot tokens with predicted probabilities.

    For every position a weight ``lambda`` is derived from the negative
    entropy of the preliminary prediction; the result is
    ``(1 - lambda) * one_hot(input_ids) + lambda * p``.

    Args:
        input_ids: Integer token ids of the current sequence state.
        logits_prelim: Preliminary logits; the last dimension is used as the
            number of classes for the one-hot encoding.
        mask_token_id: Id that marks a still-masked position.
        max_gen_length: Passed to ``get_time_dependence`` when a schedule is
            active.
        sm_args: Settings object providing ``sm_alg``, ``sm_schedule``,
            ``scale``, ``steepness``, ``offset``, ``mixinputs_k`` and
            ``mixinputs_temp``.

    Returns:
        The mixed distribution tensor, in the dtype of ``logits_prelim``
        (assuming the helper outputs keep that dtype).
    """
    vocab_size = logits_prelim.shape[-1]
    xt_one_hot = F.one_hot(input_ids, num_classes=vocab_size).to(logits_prelim.dtype)

    # Only the temperature variant deviates from a plain softmax.
    if sm_args.sm_alg == "mixinputs_with_temp":
        temperature = sm_args.mixinputs_temp
    else:
        temperature = 1.0
    neg_entropy, probs = get_neg_entropy_and_probabilities(logits_prelim, temperature=temperature)

    is_masked = input_ids == mask_token_id

    scale = sm_args.scale
    if sm_args.sm_schedule != "none":
        # The count is taken over the whole tensor, i.e. across every row if
        # `input_ids` is batched.
        remaining_masks = is_masked.sum().item()
        scale = get_time_dependence(
            max_gen_length=max_gen_length,
            num_mask_token=remaining_masks,
            scale=scale,
            schedule=sm_args.sm_schedule,
        )

    lambda_tensor = calculate_lambda_tensor(
        neg_entropy, is_masked, scale, sm_args.steepness, sm_args.offset
    )

    # The top-k variant replaces the full distribution (the entropy above was
    # still computed from the full softmax).
    if sm_args.sm_alg == "mixinputs_with_topk":
        probs = get_only_topk_probs(logits_prelim, sm_args.mixinputs_k)

    keep_weight = 1 - lambda_tensor
    return keep_weight * xt_one_hot + lambda_tensor * probs
|
|
|
|
|
def get_neg_entropy_and_probabilities(logits, temperature=1.0): |
|
|
"""Get negative entropy and probabilities from logits""" |
|
|
|
|
|
epsilon = 1e-10 |
|
|
p = torch.softmax(logits / temperature, dim=-1) |
|
|
logp = torch.log(p + epsilon) |
|
|
neg_entropy = torch.sum(p * logp, dim=-1) |
|
|
return neg_entropy, p |
|
|
|
|
|
def calculate_lambda_tensor(neg_entropy, mask_positions, scale, steepness, offset):
    """Turn negative entropies into per-position mixing weights.

    The weight is ``scale * sigmoid(steepness * (neg_entropy - offset))`` at
    masked positions and 0 everywhere else.

    Args:
        neg_entropy: Tensor of negative entropies, or ``None``.
        mask_positions: Boolean tensor, broadcastable to ``neg_entropy``,
            that is True where the weight should be kept.
        scale: Multiplier on the sigmoid; 0 disables mixing.
        steepness: Slope of the sigmoid.
        offset: Value subtracted from ``neg_entropy`` before the sigmoid.

    Returns:
        The weights with a trailing singleton dimension appended, so they
        broadcast against a tensor that has one extra trailing dimension.
        Every code path returns this same rank.
    """
    if neg_entropy is None:
        # There is no tensor to copy dtype from; use the default float dtype
        # and the shape/device of the mask.
        zeros = torch.zeros(mask_positions.shape, device=mask_positions.device)
        return zeros.unsqueeze(-1)

    if scale == 0.0:
        # Keep the trailing dimension so this shortcut has the same shape as
        # the regular path below.
        return torch.zeros_like(neg_entropy).unsqueeze(-1)

    lambda_tensor = scale * torch.sigmoid(steepness * (neg_entropy - offset))

    # Positions that are not masked never get mixed.
    lambda_tensor = torch.where(mask_positions, lambda_tensor, torch.zeros_like(lambda_tensor))
    return lambda_tensor.unsqueeze(-1)
|
|
|
|
|
def get_only_topk_probs(logits, mixinputs_k=3):
    """Build a full-vocabulary distribution supported on the top-k tokens.

    For every position (all leading dimensions) the ``mixinputs_k`` largest
    logits along the last dimension are renormalised with a softmax; every
    other entry of the result is zero.

    Args:
        logits: Floating-point tensor of any rank >= 1; the last dimension is
            the vocabulary.
        mixinputs_k: Number of entries kept per position.

    Returns:
        Tensor with the shape and dtype of ``logits``.

    Raises:
        AssertionError: If the renormalised probabilities do not sum to about
            1 (for example when the logits contain NaN).
    """
    topk_logits, topk_indices = torch.topk(logits, k=mixinputs_k, dim=-1)

    topk_probs = torch.softmax(topk_logits, dim=-1)
    topk_sum = topk_probs.sum(dim=-1)
    assert torch.allclose(topk_sum, torch.ones_like(topk_sum), atol=1e-1), \
        f"Top-k softmax probabilities do not sum to 1: max deviation = {(topk_sum - 1).abs().max().item()}"

    probs_full = torch.zeros_like(logits)
    probs_full.scatter_(-1, topk_indices, topk_probs)

    # No check on the number of non-zero entries: a top-k probability can
    # legitimately underflow to exactly 0 when the logits are far apart, and
    # such a count would also tie the function to rank-3 inputs.
    return probs_full
|
|
|
|
|
def get_time_dependence(
    max_gen_length: int,
    num_mask_token: int,
    scale: float,
    schedule: str,
    sm_to_hm: bool = True,
    threshold: float = 0.5,
) -> float:
    """Scale the soft-masking strength according to decoding progress.

    Progress is measured as ``t = num_mask_token / max_gen_length`` (taken as
    1.0 when ``max_gen_length`` is falsy).

    Args:
        max_gen_length: Denominator of the progress ratio.
        num_mask_token: Numerator of the progress ratio.
        scale: Base scale to be modulated.
        schedule: One of ``"none"``, ``"linear"`` or ``"stepwise"``.
        sm_to_hm: Selects the direction of the schedule: when True the result
            follows ``t``, otherwise ``1 - t`` (linear) or the opposite side
            of the threshold (stepwise).
        threshold: Cut-off compared against ``t`` for the stepwise schedule.

    Returns:
        ``scale`` unchanged for ``"none"``; ``scale * t`` or
        ``scale * (1 - t)`` for ``"linear"``; ``scale`` or the integer ``0``
        for ``"stepwise"``.

    Raises:
        ValueError: If ``schedule`` is not one of the supported names.
    """
    if max_gen_length:
        t = num_mask_token / max_gen_length
    else:
        t = 1.0

    if schedule == "none":
        return scale

    if schedule == "linear":
        factor = t if sm_to_hm else 1 - t
        return scale * factor

    if schedule == "stepwise":
        if sm_to_hm:
            active = t > threshold
        else:
            active = t < threshold
        if active:
            return scale
        return 0

    raise ValueError(f"Unknown schedule: {schedule}")
|
|
|