alverciito
upload safetensors and refactor research files
dbd79bd
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# #
# This file was created by: Alberto Palomo Alonso #
# Universidad de Alcalá - Escuela Politécnica Superior #
# #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import torch
class MaskedMeanPooling(torch.nn.Module):
    """
    Mean pooling layer with explicit masking support.

    This layer computes the mean over the sequence dimension (dim -2) while
    ignoring padded elements according to a mask. It supports both
    PyTorch-style padding masks (True = padded) and valid-position masks
    (True = valid).
    """

    def __init__(self, valid_pad: bool = True, eps: float = 1e-6):
        """
        Initialize the masked mean pooling layer.

        Args:
            valid_pad (bool, optional): Mask interpretation mode. If True,
                `True` values in the mask indicate valid (non-padded) positions.
                If False, `True` values indicate padded positions, following
                PyTorch-style padding conventions. Defaults to True.
            eps (float, optional): Lower bound applied to the denominator to
                avoid division by zero when all positions are masked.
                Defaults to 1e-6.
        """
        super().__init__()
        self.valid_pad = valid_pad
        self.eps = eps

    def forward(
        self,
        x: torch.Tensor,
        mask: "torch.Tensor | None" = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Apply masked mean pooling over the sequence dimension.

        Args:
            x (torch.Tensor): Input tensor of shape (..., S, D), where S is
                the sequence length and D the feature dimension. Any number
                of leading dimensions is accepted.
            mask (torch.Tensor, optional): Mask tensor of shape (..., S).
                Its interpretation depends on `valid_pad`. If None, every
                position is treated as valid, regardless of `valid_pad`.

        Returns:
            tuple:
                torch.Tensor: Pooled tensor of shape (..., D). Rows whose
                    positions are all masked produce a zero vector.
                torch.Tensor: Boolean tensor of shape (...) that is True
                    where at least one position contributed to the mean.
        """
        if mask is None:
            # No mask means nothing is padded under either convention, so
            # build an all-valid mask directly (never invert it). Use
            # x.shape[:-1] so inputs of any rank are supported.
            valid_mask = torch.ones(x.shape[:-1], dtype=torch.bool, device=x.device)
        elif self.valid_pad:
            valid_mask = mask
        else:
            # PyTorch-style padding mask: True marks padding, so invert it.
            valid_mask = torch.logical_not(mask)

        weights = valid_mask.unsqueeze(-1).to(x.dtype)  # (..., S, 1)
        summed = torch.sum(x * weights, dim=-2)  # (..., D)
        # Clamping keeps fully-masked rows at 0 / eps = 0 instead of NaN.
        denom = weights.sum(dim=-2).clamp(min=self.eps)  # (..., 1)
        # A pooled row is valid if any of its positions was valid.
        pooled_valid = weights.squeeze(-1).any(dim=-1)  # (...)
        return summed / denom, pooled_valid
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# END OF FILE #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #