File size: 2,964 Bytes
dbd79bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                                                           #
#   This file was created by: Alberto Palomo Alonso         #
# Universidad de Alcalá - Escuela Politécnica Superior      #
#                                                           #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import torch


class MaskedMeanPooling(torch.nn.Module):
    """
    Mean pooling layer with explicit masking support.

    This layer computes the mean over the sequence dimension while
    ignoring padded elements according to a boolean mask. It supports
    both PyTorch-style padding masks and valid-position masks.
    """

    def __init__(self, valid_pad: bool = True, eps: float = 1e-6):
        """
        Initialize the masked mean pooling layer.

        Args:
            valid_pad (bool, optional): Mask interpretation mode. If True,
                `True` values in the mask indicate valid (non-padded) positions.
                If False, `True` values indicate padded positions, following
                PyTorch-style padding conventions. Defaults to True.
            eps (float, optional): Small constant to avoid division by zero
                when all positions are masked. Defaults to 1e-6.
        """
        super().__init__()
        self.valid_pad = valid_pad
        self.eps = eps

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Apply masked mean pooling.

        Args:
            x (torch.Tensor): Input tensor of shape (..., S, D), where
                S is the sequence length and D the feature dimension.
            mask (torch.Tensor): Boolean mask tensor of shape (..., S),
                or None to treat every position as valid. The
                interpretation of a provided mask depends on `valid_pad`.

        Returns:
            tuple:
                torch.Tensor: Pooled tensor of shape (..., D).
                torch.Tensor: Boolean validity mask after pooling of
                    shape (...,); True where at least one source position
                    was valid.
        """
        # Resolve the mask into canonical "True == valid" boolean form:
        if mask is None:
            # No mask given: all positions are valid regardless of
            # convention. Shape is x without the feature dim, i.e. (..., S).
            # (Bug fix: was x.shape[:3], which only matches 3-D inputs and
            # produced a wrongly-shaped mask for the documented (..., S, D)
            # contract; it was also inverted below when valid_pad=False.)
            valid = torch.ones(x.shape[:-1], dtype=torch.bool, device=x.device)
        elif self.valid_pad:
            valid = mask.bool()
        else:
            # PyTorch-style padding mask: True marks padding, so invert.
            valid = torch.logical_not(mask)

        weights = valid.unsqueeze(-1).to(x.dtype)            # (..., S, 1)
        summed = torch.sum(x * weights, dim=-2)              # (..., D)
        # Clamp guards the all-masked case (denominator would be 0).
        denom = weights.sum(dim=-2).clamp(min=self.eps)      # (..., 1)

        # Valid mask pooling (any): a pooled vector is valid when at least
        # one of its source positions was valid.
        pooled_mask = valid.any(dim=-1)

        return summed / denom, pooled_mask
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                        END OF FILE                        #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #