File size: 6,535 Bytes
dbd79bd 34f99b8 dbd79bd 34f99b8 dbd79bd 34f99b8 dbd79bd 34f99b8 dbd79bd 34f99b8 dbd79bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# #
# This file was created by: Alberto Palomo Alonso #
# Universidad de Alcalá - Escuela Politécnica Superior #
# #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import torch
from .config import ModelConfig
from .cosenet import CosineDistanceLayer, CoSeNet
from .transformers import EncoderBlock, PositionalEncoding, MaskedMeanPooling
class SegmentationNetwork(torch.nn.Module):
    """
    Segmentation network combining Transformer encoders with CoSeNet.

    This model integrates token embeddings and positional encodings with
    a stack of Transformer encoder blocks to produce contextualized
    representations. These representations are then processed by a
    CoSeNet module to perform structured segmentation, following a
    cosine-based distance computation.

    The final output is a pair-wise distance matrix suitable for
    segmentation or boundary detection tasks. Intermediate
    representations can be returned instead by selecting one of the
    alternative ``task`` modes ('token_encoding', 'sentence_encoding',
    'similarity').
    """

    # Closed set of supported task modes; checked fail-fast in __init__.
    _SUPPORTED_TASKS = ('segmentation', 'similarity', 'token_encoding', 'sentence_encoding')

    def __init__(self, model_config: ModelConfig, task: str = 'segmentation', **kwargs) -> None:
        """
        Initialize the segmentation network.

        The network is composed of an embedding layer, positional encoding,
        multiple Transformer encoder blocks, a CoSeNet segmentation module,
        and a cosine distance layer.

        Args:
            model_config (ModelConfig): Configuration object containing all
                hyperparameters required to build the model, including
                vocabulary size, model dimensionality, transformer settings,
                and CoSeNet parameters.
            task (str): One of 'segmentation', 'similarity',
                'token_encoding' or 'sentence_encoding'; selects how far
                through the pipeline `forward` runs before returning.
            **kwargs: Additional keyword arguments forwarded to
                `torch.nn.Module`.

        Raises:
            ValueError: If `task` is not one of the supported task modes.
        """
        super().__init__(**kwargs)
        # Validate the task up-front so misconfiguration fails before any
        # (potentially expensive) layer construction.
        self.task = task
        if self.task not in self._SUPPORTED_TASKS:
            raise ValueError(f"Invalid task '{self.task}'. Supported tasks are 'segmentation', 'similarity', "
                             f"'token_encoding', and 'sentence_encoding'.")
        # True means masks mark VALID positions; False means they mark padding.
        self.valid_padding = model_config.valid_padding
        # Build layers:
        self.embedding = torch.nn.Embedding(
            model_config.vocab_size,
            model_config.model_dim
        )
        self.positional_encoding = PositionalEncoding(
            emb_dim=model_config.model_dim,
            max_len=model_config.max_tokens
        )
        self.cosenet = CoSeNet(
            trainable=model_config.cosenet.trainable,
            init_scale=model_config.cosenet.init_scale
        )
        self.distance_layer = CosineDistanceLayer()
        self.pooling = MaskedMeanPooling(valid_pad=model_config.valid_padding)
        # Build encoder blocks (one per transformer config entry):
        module_list = list()
        for transformer_config in model_config.transformers:
            encoder_block = EncoderBlock(
                feature_dim=model_config.model_dim,
                attention_heads=transformer_config.attention_heads,
                feed_forward_multiplier=transformer_config.feed_forward_multiplier,
                dropout=transformer_config.dropout,
                valid_padding=model_config.valid_padding,
                pre_normalize=transformer_config.pre_normalize
            )
            module_list.append(encoder_block)
        self.encoder_blocks = torch.nn.ModuleList(module_list)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None, candidate_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Forward pass of the segmentation network.

        The input token indices are embedded and enriched with positional
        information, then processed by a stack of Transformer encoder
        blocks. The resulting representations are segmented using CoSeNet
        and finally transformed into a pair-wise distance representation.
        Depending on ``self.task``, the pipeline may return early with
        token encodings, sentence encodings, or the raw similarity matrix.

        Args:
            x (torch.Tensor): Input tensor of token indices. The reshape
                below implies a rank-3 layout, presumably
                (batch, sentences, tokens) — confirm against callers.
            mask (torch.Tensor, optional): Optional mask tensor indicating
                valid or padded positions, depending on the configuration
                of the Transformer blocks. Defaults to None.
                If `valid_padding` is disabled, the mask is inverted before
                being passed to CoSeNet to match its masking convention.
            candidate_mask (torch.Tensor, optional): Optional mask tensor for
                candidate positions in CoSeNet. Defaults to None.
                If `valid_padding` is disabled, the mask is inverted before
                being passed to CoSeNet to match its masking convention.

        Returns:
            torch.Tensor: Output tensor containing pairwise distance values
            derived from the segmented representations (or an intermediate
            representation for the non-default task modes).
        """
        # Embedding requires integer indices:
        x = x.int()
        # Embedding and positional encoding:
        x = self.embedding(x)
        x = self.positional_encoding(x)
        # Flatten (batch, sentence) into one axis so the encoders see a
        # standard (N, tokens, dim) batch:
        _b, _s, _t, _d = x.shape
        x = x.reshape(_b * _s, _t, _d)
        if mask is not None:
            mask = mask.reshape(_b * _s, _t).bool()
        # Encode the sequence:
        for encoder in self.encoder_blocks:
            x = encoder(x, mask=mask)
        # Restore the (batch, sentence, token, dim) layout:
        x = x.reshape(_b, _s, _t, _d)
        if mask is not None:
            mask = mask.reshape(_b, _s, _t)
            # Normalize to the True-means-valid convention expected
            # downstream. Guarded here: calling logical_not on a None mask
            # would raise a TypeError.
            if not self.valid_padding:
                mask = torch.logical_not(mask)
        if self.task == 'token_encoding':
            return x
        # Apply pooling (token axis collapsed to sentence vectors):
        x, mask = self.pooling(x, mask=mask)
        if self.task == 'sentence_encoding':
            return x
        # Compute distances:
        x = self.distance_layer(x)
        if self.task == 'similarity':
            return x
        # Pass through CoSeNet:
        x = self.cosenet(x, mask=mask)
        # Zero out non-candidate positions if a candidate mask is provided.
        # masked_fill writes where True, so under valid_padding the mask
        # (True = candidate) must be inverted; otherwise it already marks
        # the positions to suppress.
        if candidate_mask is not None:
            candidate_mask = candidate_mask.bool() if not self.valid_padding else torch.logical_not(candidate_mask.bool())
            candidate_mask = candidate_mask.to(device=x.device)
            x = x.masked_fill(candidate_mask, 0)
        return x
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# END OF FILE #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|