|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
class MaskedMeanPooling(torch.nn.Module):
    """
    Mean pooling layer with explicit masking support.

    This layer computes the mean over the sequence dimension (``dim=-2``)
    while ignoring padded elements according to a boolean mask. It supports
    both PyTorch-style padding masks (``True`` = padded) and valid-position
    masks (``True`` = valid).
    """

    def __init__(self, valid_pad: bool = True, eps: float = 1e-6):
        """
        Initialize the masked mean pooling layer.

        Args:
            valid_pad (bool, optional): Mask interpretation mode. If True,
                `True` values in the mask indicate valid (non-padded)
                positions. If False, `True` values indicate padded positions,
                following PyTorch-style padding conventions. Defaults to True.
            eps (float, optional): Lower bound applied to the number of valid
                positions before dividing. When every position is masked the
                (zero) sum is divided by `eps`, so the output is zero rather
                than NaN. Defaults to 1e-6.
        """
        super().__init__()
        self.valid_pad = valid_pad
        self.eps = eps

    def forward(
        self,
        x: torch.Tensor,
        mask: "torch.Tensor | None" = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Apply masked mean pooling.

        Args:
            x (torch.Tensor): Input tensor of shape (..., S, D), where S is the
                sequence length (the pooled axis) and D the feature dimension.
                Any number of leading dimensions is allowed.
            mask (torch.Tensor, optional): Boolean mask of shape (..., S)
                (broadcastable against ``x.shape[:-1]``). Its interpretation
                depends on `valid_pad`. If None, every position is treated as
                valid regardless of `valid_pad`.

        Returns:
            tuple:
                torch.Tensor: Pooled tensor of shape (..., D).
                torch.Tensor: Boolean mask of shape (...,) that is True where
                    at least one sequence position contributed to the mean.
        """
        if mask is None:
            # No mask: every sequence position is valid. The mask spans all
            # dims of x except the trailing feature dim D. Building it here,
            # after mode selection, keeps "no mask" meaning "all valid" even
            # when valid_pad is False.
            keep = torch.ones(x.shape[:-1], dtype=torch.bool, device=x.device)
        elif self.valid_pad:
            keep = mask
        else:
            # PyTorch-style padding mask: True marks padding, so invert it.
            keep = torch.logical_not(mask)

        # Add a trailing axis so the mask broadcasts over the feature dim, and
        # cast to x's dtype so it can act as multiplicative weights.
        weights = keep.unsqueeze(-1).to(x.dtype)

        summed = torch.sum(x * weights, dim=-2)
        # Clamp the count so fully masked rows do not divide by zero.
        count = weights.sum(dim=-2).clamp(min=self.eps)

        # A pooled row is valid if any sequence position contributed to it.
        pooled_valid = weights.squeeze(-1).any(dim=-1)

        return summed / count, pooled_valid
|
|
|
|
|
|
|
|
|
|
|
|