# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                                                           #
#   This file was created by: Alberto Palomo Alonso         #
#   Universidad de Alcalá - Escuela Politécnica Superior    #
#                                                           #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import torch


class MaskedMeanPooling(torch.nn.Module):
    """
    Mean pooling layer with explicit masking support.

    This layer computes the mean over the sequence dimension while
    ignoring padded elements according to a boolean mask. It supports
    both PyTorch-style padding masks and valid-position masks.
    """
    def __init__(self, valid_pad: bool = True, eps: float = 1e-6):
        """
        Initialize the masked mean pooling layer.

        Args:
            valid_pad (bool, optional): Mask interpretation mode. If True,
                ``True`` values in the mask indicate valid (non-padded)
                positions. If False, ``True`` values indicate padded
                positions, following PyTorch-style padding conventions.
                Defaults to True.
            eps (float, optional): Small constant to avoid division by
                zero when all positions are masked. Defaults to 1e-6.
        """
        super().__init__()
        self.valid_pad = valid_pad
        self.eps = eps

    def forward(
            self,
            x: torch.Tensor,
            mask: torch.Tensor | None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Apply masked mean pooling.

        Args:
            x (torch.Tensor): Input tensor of shape (..., S, D), where S
                is the sequence length and D the feature dimension.
            mask (torch.Tensor, optional): Boolean mask tensor of shape
                (..., S). The interpretation depends on ``valid_pad``.
                If None, all positions are treated as valid.

        Returns:
            tuple:
                torch.Tensor: Pooled tensor of shape (..., D).
                torch.Tensor: Updated valid mask after pooling, of
                    shape (...,).
        """
        # Mask handling: default to an all-valid mask when none is given.
        if mask is None:
            valid_mask = torch.ones(x.shape[:-1], dtype=torch.bool, device=x.device)
        else:
            valid_mask = mask
        # Invert PyTorch-style padding masks so True marks valid positions:
        if not self.valid_pad:
            valid_mask = torch.logical_not(valid_mask)
        # Valid mask pooling (any valid position survives the reduction):
        pooled_mask = valid_mask.any(dim=-1)                    # (...,)
        # Masked mean over the sequence dimension:
        valid_mask = valid_mask.unsqueeze(-1).to(x.dtype)       # (..., S, 1)
        summed = torch.sum(x * valid_mask, dim=-2)              # (..., D)
        denom = valid_mask.sum(dim=-2).clamp(min=self.eps)      # (..., 1)
        return summed / denom, pooled_mask


# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                       END OF FILE                         #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #