# Uploaded by Bailan-Alex via huggingface_hub (revision 4f2b2f4, verified).
import abc
from typing import Optional
import torch
from torch import Tensor
from dataclasses import dataclass
from schedule import Schedule
import torch.nn.functional as F
@dataclass
class ModelPrediction:
    """One forward pass of the model: token logits plus length information.

    At least one of ``length_posterior`` / ``expected_gaps`` must be given;
    when ``expected_gaps`` is missing it is derived from ``length_posterior``.
    """

    token_logits: Tensor  # per-position vocabulary logits; assumed (B, L, V) — confirm with model head
    length_posterior: Optional[Tensor]  # 3-D logits over gap sizes (per the unpack below); may be None
    expected_gaps: Optional[Tensor]  # expected gap size per position; computed if not supplied

    def __init__(
        self,
        token_logits: Tensor,
        length_posterior: Optional[Tensor] = None,
        expected_gaps: Optional[Tensor] = None,
    ):
        # Need at least one source of length information.
        assert length_posterior is not None or expected_gaps is not None
        self.token_logits = token_logits
        self.length_posterior = length_posterior
        self.expected_gaps = expected_gaps
        if self.expected_gaps is None:
            # E[gap] = sum_k k * softmax(length_posterior)_k over the last axis.
            _, _, L = self.length_posterior.shape
            index = torch.arange(0, L, device=token_logits.device).view(1, 1, -1)
            self.expected_gaps = (F.softmax(self.length_posterior, dim=-1) * index).sum(dim=-1)
@dataclass
class Rate:
    """CTMC rates produced by an interpolant's ``to_actual_rate``."""

    unmask_rate: Tensor  # Shape [Batch, Length, Vocab]
    # NOTE(review): original comment said Shape [Batch] but the producer at
    # AnyOrderMaskInsertionInterpolant.to_actual_rate labels it (B, L+1) — verify.
    # Optional because MDMInterpolant.to_actual_rate passes None here.
    length_rate: Optional[Tensor]
@dataclass
class HittingTime:
    """Per-token hitting times of the two noising events."""

    insertion_time: Tensor  # Shape [Batch, Length]
    unmasking_time: Tensor  # Shape [Batch, Length]

    def __iter__(self):
        # Support tuple-style unpacking: insertion time first, then unmasking.
        return iter((self.insertion_time, self.unmasking_time))
@dataclass
class JointInterpolantResult:
    """Result of sampling the joint interpolant at time t.

    ``xt`` is the noised (possibly left-compacted) sequence; ``st`` maps each
    slot of ``xt`` back to its source position in the clean sequence ``_x1``.
    """

    # Joint Interpolant
    xt: Tensor  # Shape [Batch, Length]
    st: Tensor  # Shape [Batch, Length]; source index into _x1 per xt slot (0 on pad slots)
    _x1: Tensor  # clean reference sequence, same layout as xt
    _pad_token: int  # token id marking padding / deleted slots
    _mask_token: int  # token id marking masked slots

    @property
    def mask_indices(self) -> Tensor:
        # Boolean mask of xt slots currently holding the mask token.
        return self.xt == self._mask_token

    @property
    def unmasked(self) -> Tensor:
        # Clean tokens aligned to xt's slots: _x1 gathered through the slot map st.
        return torch.gather(self._x1, 1, self.st)

    @property
    def xt_length(self) -> Tensor:
        # Calculate length of xt (non-pad tokens per row).
        return (self.xt != self._pad_token).sum(dim=1)

    @property
    def x1_length(self) -> Tensor:
        # Calculate length of x1 (non-pad tokens per row).
        return (self._x1 != self._pad_token).sum(dim=1)

    @property
    def gaps_and_mask(self) -> tuple[Tensor, Tensor]:
        # Count how many x1 tokens were deleted in each of the gaps around the
        # kept slots, plus a validity mask over those gap positions.
        x1_len = self.x1_length
        gaps = self.st.clone()
        pad_front = gaps.new_zeros((gaps.shape[0], 1)) - 1  # -1 for the front padding
        pad_back = gaps.new_zeros((gaps.shape[0], 1))
        gaps = torch.cat([pad_front, gaps, pad_back], dim=1)  # prepend -1 / append 0 sentinel columns
        # Write x1_len just past the last kept slot so the trailing gap counts
        # every token deleted after it.
        gaps.scatter_(
            1, self.xt_length.unsqueeze(1) + 1, x1_len.unsqueeze(1)
        )  # Fill the last position with x1_len
        # Difference of consecutive kept indices minus one = tokens skipped in between.
        gaps = gaps[:, 1:] - gaps[:, :-1] - 1
        gaps = torch.clamp(gaps, min=0)
        idx = torch.arange(gaps.size(1), device=self.xt.device).unsqueeze(
            0
        )  # shape [1, max_gap]
        # A row with k kept tokens has k+1 meaningful gaps (positions 0..k).
        mask = idx <= self.xt_length.unsqueeze(1)
        gaps[~mask] = 0
        return gaps, mask
class JointInterpolant(abc.ABC):
    """Abstract base class for joint length/token interpolants.

    A joint interpolant defines how a clean sequence ``x1`` is corrupted into
    an intermediate state ``xt`` at time ``t`` (``sample_interpolant``), how
    training-loss terms are weighted (``elbo_weight``), and how model outputs
    become sampling rates (``to_actual_rate``).
    """

    def __init__(
        self,
        vocab_size: int,
        mask_token: int,
        pad_token: int,
        max_length: int,
    ):
        """
        Store the vocabulary/tokenizer configuration shared by all interpolants.

        TODO: Add knobs
        """
        self.mask_token = mask_token
        self.pad_token = pad_token
        self.max_length = max_length
        self.vocab_size = vocab_size

    @abc.abstractmethod
    def elbo_weight(self, t: Tensor, x1: Tensor):
        """
        Return the ELBO weights for training; may change depending on empirical results.
        Shape:
            t: [B]
            x1: [B, L]
        Returns:
            weight_unmask: [B, L]
            weight_delete: [B, L+1] (interpolants without deletion may omit this)
        """
        raise NotImplementedError

    @abc.abstractmethod
    def to_actual_rate(self, xt: Tensor, prediction: ModelPrediction, t: Tensor) -> Rate:
        """
        Convert a model prediction at state ``xt`` and time ``t`` into a Rate.

        NOTE(review): ``xt`` added to the abstract signature so it matches the
        concrete implementations, which all accept ``(xt, prediction, t)``.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def sample_interpolant(self, t: Tensor, x1: Tensor) -> JointInterpolantResult:
        """
        Sample the interpolant xt from x1 at time t.
        Shapes:
            x1: [B, L]
            t: [B]
        Returns:
            A JointInterpolantResult carrying xt, the slot map st, and the
            clean reference x1 (see JointInterpolantResult for derived views).
        """
        raise NotImplementedError
class AnyOrderMaskInsertionInterpolant(JointInterpolant):
    """Joint interpolant where each clean token carries an insertion time and
    a later unmasking time: before insertion it is deleted from xt, between
    insertion and unmasking it appears as the mask token, afterwards it is
    the clean token."""

    def __init__(
        self,
        insertion_schedule: Schedule,
        unmask_schedule: Schedule,
        vocab_size: int,
        mask_token: int,
        pad_token: int,
        max_length: int,
    ):
        """Store the two noise schedules on top of the base configuration."""
        super().__init__(vocab_size, mask_token, pad_token, max_length)
        self.insertion_schedule = insertion_schedule
        self.unmask_schedule = unmask_schedule

    def hitting_time(self, t: Tensor, x1: Tensor) -> tuple[Tensor, Tensor]:
        """
        Sample per-token (insertion_time, unmasking_time) pairs.

        The insertion time is drawn from the insertion schedule and shifted by
        eps so it is never exactly 0; the unmasking time is drawn from the
        unmask schedule truncated below at the insertion time, so unmasking
        never precedes insertion.
        Shapes: x1: [B, L]; both returned tensors are [B, L]. ``t`` is unused here.
        """
        B, L = x1.shape
        eps = 1e-6
        insert_time = self.insertion_schedule.sample((B, L), device=x1.device)
        insert_time = eps + (1 - eps) * insert_time  # ensure t1 is not 0
        unmask_time = self.unmask_schedule.sample_truncated(
            insert_time, (B, L), device=x1.device
        )
        return insert_time, unmask_time

    def elbo_weight(self, t: Tensor, x1: Tensor):
        """
        Return the ELBO weight for the training, can be changed depends on the empirical results
        Returns:
            unmask_weight: [B, L]
            insert_weight: [B, L+1]
        """
        insert_weight = self.insertion_schedule.rate_scale_factor(t)
        insert_weight = insert_weight[:, None].expand(-1, x1.shape[1] + 1)  # (B, L+1)
        unmask_weight = self.unmask_schedule.rate_scale_factor(t)
        unmask_weight = unmask_weight.unsqueeze(1).expand(-1, x1.shape[1])  # (B, L)
        return unmask_weight, insert_weight

    def to_actual_rate(
        self, xt: Tensor, prediction: ModelPrediction, t: Tensor
    ) -> Rate:
        """
        Return the actual rate for the sampling
        Args:
            xt: [B, L] the sampled tokens (not used in this implementation)
            prediction: ModelPrediction object containing token_posterior and expected_gaps
            t: [B] the time parameter
        """
        token_posterior = F.softmax(prediction.token_logits, dim=-1)  # (B, L, V)
        unmask_rate = token_posterior * self.unmask_schedule.rate_scale_factor(t).view(
            -1, 1, 1
        )
        # Insertion rate per gap slot; shape follows expected_gaps
        # (presumably (B, L+1) — confirm against the model's length head).
        length_rate = (
            prediction.expected_gaps
            * self.insertion_schedule.rate_scale_factor(t).view(-1, 1)
        )
        return Rate(
            unmask_rate=unmask_rate,  # (B, L, V)
            length_rate=length_rate,  # (B, L+1)
        )

    def sample_interpolant(self, t: Tensor, x1: Tensor) -> JointInterpolantResult:
        """
        Sample the interpolant xt from x1 at time t.
        Shapes:
            x1: [B, L]
            t: [B]
        Returns:
            A JointInterpolantResult: its ``xt`` is x1 with not-yet-inserted
            tokens removed (replaced by pad and compacted to the left) and
            inserted-but-not-unmasked tokens replaced by the mask token;
            its ``st`` maps each xt slot back to its position in x1.
        """
        # sample the stopping time (B, L, 2)
        insertion_time, unmasking_time = self.hitting_time(t, x1)
        clean_tokens = x1.ne(self.pad_token)
        # Token not yet inserted at time t -> treated as deleted.
        deleted_tokens = clean_tokens & (t[:, None] < insertion_time)
        # Inserted but not yet unmasked -> shown as the mask token.
        masked_tokens = (
            clean_tokens
            & (t[:, None] >= insertion_time)
            & (t[:, None] < unmasking_time)
        )
        xt = torch.where(
            deleted_tokens,
            self.pad_token,  # for deletion, change to pad token
            torch.where(
                masked_tokens,
                self.mask_token,  # for masking, change to mask token
                x1,
            ),
        )
        # Stable argsort on the non-pad flag compacts surviving tokens to the
        # left while preserving their order; st holds their source indices in x1.
        st = xt.ne(self.pad_token).argsort(dim=1, descending=True, stable=True)
        xt = torch.gather(xt, 1, st)
        st[xt == self.pad_token] = 0  # slot map is meaningless on pad slots
        return JointInterpolantResult(
            xt=xt, st=st, _x1=x1, _pad_token=self.pad_token, _mask_token=self.mask_token
        )
class MDMInterpolant(JointInterpolant):
    """Vanilla masked-diffusion interpolant: tokens are only masked, never
    deleted, so the sequence length is fixed throughout."""

    def __init__(
        self,
        unmask_schedule: Schedule,
        vocab_size: int,
        mask_token: int,
        pad_token: int,
        max_length: int,
    ):
        """Store the unmasking schedule on top of the base configuration."""
        super().__init__(vocab_size, mask_token, pad_token, max_length)
        self.unmask_schedule = unmask_schedule

    def elbo_weight(self, t: Tensor, x1: Tensor):
        """
        Per-position unmasking weight for training, shape (B, L).
        There is no deletion weight for the vanilla MDM.
        """
        base = self.unmask_schedule.rate_scale_factor(t)  # (B,)
        return base.unsqueeze(1).expand(-1, x1.shape[1])  # (B, L)

    def to_actual_rate(self, xt: Tensor, prediction: Tensor, t: Tensor) -> Rate:
        """
        Convert raw token logits into unmasking rates for sampling.
        Length rate is None: the sequence length never changes.
        """
        scale = self.unmask_schedule.rate_scale_factor(t).view(-1, 1, 1)
        posterior = F.softmax(prediction, dim=-1)  # (B, L, V)
        return Rate(
            unmask_rate=posterior * scale,  # (B, L, V)
            length_rate=None,
        )

    def sample_interpolant(self, t: Tensor, x1: Tensor) -> JointInterpolantResult:
        """Mask every token whose unmask hitting time lies after ``t``."""
        eps = 1e-6
        hit_time = self.unmask_schedule.sample(
            (x1.shape[0], x1.shape[1]), device=x1.device
        )
        hit_time = hit_time * (1 - eps) + eps  # keep hitting times strictly positive
        still_masked = t[:, None] < hit_time
        xt = torch.where(still_masked, self.mask_token, x1)
        # Length is fixed, so each slot maps to itself: 0..L-1 per row.
        st = torch.arange(xt.shape[1], device=xt.device, dtype=torch.long).repeat(
            xt.shape[0], 1
        )
        return JointInterpolantResult(
            xt=xt, st=st, _x1=x1, _pad_token=self.pad_token, _mask_token=self.mask_token
        )