File size: 6,535 Bytes
dbd79bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34f99b8
dbd79bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34f99b8
 
 
 
 
dbd79bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34f99b8
 
 
dbd79bd
 
 
34f99b8
 
 
dbd79bd
 
 
34f99b8
 
 
dbd79bd
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                                                           #
#   This file was created by: Alberto Palomo Alonso         #
# Universidad de Alcalá - Escuela Politécnica Superior      #
#                                                           #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import torch
from .config import ModelConfig
from .cosenet import CosineDistanceLayer, CoSeNet
from .transformers import EncoderBlock, PositionalEncoding, MaskedMeanPooling


class SegmentationNetwork(torch.nn.Module):
    """
    Segmentation network combining Transformer encoders with CoSeNet.

    This model integrates token embeddings and positional encodings with
    a stack of Transformer encoder blocks to produce contextualized
    representations. These representations are then processed by a
    CoSeNet module to perform structured segmentation, preceded by a
    cosine-based distance computation.

    The output of ``forward`` depends on the configured ``task``:

    * ``'token_encoding'``: contextualized per-token representations.
    * ``'sentence_encoding'``: masked-mean-pooled sentence representations.
    * ``'similarity'``: pair-wise cosine distance values.
    * ``'segmentation'``: CoSeNet output computed from the distances,
      suitable for segmentation or boundary detection tasks.
    """

    # Accepted values for the `task` constructor argument (validated in __init__).
    SUPPORTED_TASKS = ('segmentation', 'similarity', 'token_encoding', 'sentence_encoding')

    def __init__(self, model_config: ModelConfig, task: str = 'segmentation', **kwargs):
        """
        Initialize the segmentation network.

        The network is composed of an embedding layer, positional encoding,
        multiple Transformer encoder blocks, a CoSeNet segmentation module,
        and a cosine distance layer.

        Args:
            model_config (ModelConfig): Configuration object containing all
                hyperparameters required to build the model, including
                vocabulary size, model dimensionality, transformer settings,
                and CoSeNet parameters.
            task (str): Which intermediate or final output ``forward`` should
                return. One of 'segmentation', 'similarity', 'token_encoding'
                or 'sentence_encoding'. Defaults to 'segmentation'.
            **kwargs: Additional keyword arguments forwarded to
                `torch.nn.Module`.

        Raises:
            ValueError: If `task` is not one of the supported tasks.
        """
        super().__init__(**kwargs)

        # Fail fast: reject an unsupported task BEFORE building any layers,
        # so a typo does not cost a full (potentially large) model build.
        if task not in self.SUPPORTED_TASKS:
            raise ValueError(f"Invalid task '{task}'. Supported tasks are 'segmentation', 'similarity', "
                             f"'token_encoding', and 'sentence_encoding'.")
        self.task = task
        self.valid_padding = model_config.valid_padding

        # Build layers:
        self.embedding = torch.nn.Embedding(
            model_config.vocab_size,
            model_config.model_dim
        )
        self.positional_encoding = PositionalEncoding(
            emb_dim=model_config.model_dim,
            max_len=model_config.max_tokens
        )
        self.cosenet = CoSeNet(
            trainable=model_config.cosenet.trainable,
            init_scale=model_config.cosenet.init_scale
        )
        self.distance_layer = CosineDistanceLayer()
        self.pooling = MaskedMeanPooling(valid_pad=model_config.valid_padding)

        # One encoder block per entry in the transformer configuration list:
        self.encoder_blocks = torch.nn.ModuleList(
            EncoderBlock(
                feature_dim=model_config.model_dim,
                attention_heads=transformer_config.attention_heads,
                feed_forward_multiplier=transformer_config.feed_forward_multiplier,
                dropout=transformer_config.dropout,
                valid_padding=model_config.valid_padding,
                pre_normalize=transformer_config.pre_normalize
            )
            for transformer_config in model_config.transformers
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None, candidate_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Forward pass of the segmentation network.

        The input token indices are embedded and enriched with positional
        information, then processed by a stack of Transformer encoder
        blocks. Depending on ``self.task``, the pipeline returns early with
        token encodings, pooled sentence encodings, or pair-wise distances;
        otherwise the distances are segmented with CoSeNet.

        Args:
            x (torch.Tensor): Input tensor of token indices. The embedded
                tensor is unpacked as (batch, sentences, tokens, dim), so
                `x` is expected to be (batch, sentences, tokens).
            mask (torch.Tensor, optional): Optional mask tensor indicating
                valid or padded positions, depending on the configuration
                of the Transformer blocks. Defaults to None.

                If `valid_padding` is disabled, the mask is inverted after
                encoding, before the pooling/CoSeNet stage, to match their
                masking convention.

            candidate_mask (torch.Tensor, optional): Optional mask over
                candidate positions, applied to the CoSeNet output.
                Defaults to None.

                If `valid_padding` is enabled, this mask is inverted before
                use; positions where the (possibly inverted) boolean mask is
                True are zeroed in the output.

        Returns:
            torch.Tensor: Depending on ``self.task``: token encodings,
            sentence encodings, pair-wise distance values, or the CoSeNet
            segmentation output.
        """
        # Embedding requires integer indices:
        x = x.int()

        # Embedding and positional encoding:
        x = self.embedding(x)
        x = self.positional_encoding(x)

        # Flatten (batch, sentences) so each sentence is encoded independently:
        _b, _s, _t, _d = x.shape
        x = x.reshape(_b * _s, _t, _d)
        if mask is not None:
            mask = mask.reshape(_b * _s, _t).bool()

        # Encode the sequence:
        for encoder in self.encoder_blocks:
            x = encoder(x, mask=mask)

        # Restore (batch, sentences, ...) layout; flip the mask convention
        # for downstream modules when valid_padding is disabled:
        x = x.reshape(_b, _s, _t, _d)
        if mask is not None:
            mask = mask.reshape(_b, _s, _t)
            mask = torch.logical_not(mask) if not self.valid_padding else mask

        if self.task == 'token_encoding':
            return x

        # Apply pooling (tokens -> sentence vectors):
        x, mask = self.pooling(x, mask=mask)

        if self.task == 'sentence_encoding':
            return x

        # Compute pair-wise cosine distances:
        x = self.distance_layer(x)

        if self.task == 'similarity':
            return x

        # Pass through CoSeNet:
        x = self.cosenet(x, mask=mask)

        # Apply candidate mask if provided: positions flagged True (after the
        # valid_padding-dependent inversion) are zeroed in the output.
        if candidate_mask is not None:
            candidate_mask = candidate_mask.bool() if not self.valid_padding else torch.logical_not(candidate_mask.bool())
            candidate_mask = candidate_mask.to(device=x.device)
            x = x.masked_fill(candidate_mask, 0)

        return x
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                        END OF FILE                        #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #