|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
from .config import ModelConfig |
|
|
from .cosenet import CosineDistanceLayer, CoSeNet |
|
|
from .transformers import EncoderBlock, PositionalEncoding, MaskedMeanPooling |
|
|
|
|
|
|
|
|
class SegmentationNetwork(torch.nn.Module):
    """
    Segmentation network combining Transformer encoders with CoSeNet.

    This model integrates token embeddings and positional encodings with
    a stack of Transformer encoder blocks to produce contextualized
    representations. These representations are then pooled per sentence,
    turned into a pair-wise cosine distance matrix, and processed by a
    CoSeNet module to perform structured segmentation.

    Depending on the configured ``task``, the forward pass can instead
    return intermediate representations (token encodings, sentence
    encodings, or the raw similarity matrix).
    """

    # Closed set of tasks this network supports; checked once at init.
    _SUPPORTED_TASKS = ('segmentation', 'similarity', 'token_encoding', 'sentence_encoding')

    def __init__(self, model_config: ModelConfig, task: str = 'segmentation', **kwargs):
        """
        Initialize the segmentation network.

        The network is composed of an embedding layer, positional encoding,
        multiple Transformer encoder blocks, a CoSeNet segmentation module,
        and a cosine distance layer.

        Args:
            model_config (ModelConfig): Configuration object containing all
                hyperparameters required to build the model, including
                vocabulary size, model dimensionality, transformer settings,
                and CoSeNet parameters.
            task (str): One of 'segmentation', 'similarity',
                'token_encoding', or 'sentence_encoding'. Controls how far
                the forward pass runs and what it returns.
            **kwargs: Additional keyword arguments forwarded to
                `torch.nn.Module`.

        Raises:
            ValueError: If `task` is not one of the supported tasks.
        """
        # Fail fast: validate the task before building any sub-modules so a
        # misconfiguration does not waste time allocating parameters.
        if task not in self._SUPPORTED_TASKS:
            raise ValueError(f"Invalid task '{task}'. Supported tasks are 'segmentation', 'similarity', "
                             f"'token_encoding', and 'sentence_encoding'.")

        super().__init__(**kwargs)
        self.task = task
        # True -> masks mark VALID positions; False -> masks mark PADDING.
        self.valid_padding = model_config.valid_padding

        self.embedding = torch.nn.Embedding(
            model_config.vocab_size,
            model_config.model_dim
        )
        self.positional_encoding = PositionalEncoding(
            emb_dim=model_config.model_dim,
            max_len=model_config.max_tokens
        )
        self.cosenet = CoSeNet(
            trainable=model_config.cosenet.trainable,
            init_scale=model_config.cosenet.init_scale
        )
        self.distance_layer = CosineDistanceLayer()
        self.pooling = MaskedMeanPooling(valid_pad=model_config.valid_padding)

        # One encoder block per entry in the transformer configuration list;
        # all blocks share the model dimension but may differ otherwise.
        self.encoder_blocks = torch.nn.ModuleList([
            EncoderBlock(
                feature_dim=model_config.model_dim,
                attention_heads=transformer_config.attention_heads,
                feed_forward_multiplier=transformer_config.feed_forward_multiplier,
                dropout=transformer_config.dropout,
                valid_padding=model_config.valid_padding,
                pre_normalize=transformer_config.pre_normalize
            )
            for transformer_config in model_config.transformers
        ])

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None, candidate_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Forward pass of the segmentation network.

        The input token indices are embedded and enriched with positional
        information, then processed by a stack of Transformer encoder
        blocks. The resulting representations are mean-pooled per sentence,
        transformed into a pair-wise cosine distance representation, and
        finally segmented using CoSeNet.

        Args:
            x (torch.Tensor): Input tensor of token indices with shape
                (batch_size, num_sentences, num_tokens). The embedding
                produces a 4-D tensor that is flattened to
                (batch_size * num_sentences, num_tokens, model_dim) for the
                encoder blocks.
            mask (torch.Tensor, optional): Optional mask tensor of shape
                (batch_size, num_sentences, num_tokens) indicating valid or
                padded positions, depending on `valid_padding`. Defaults to
                None.

                If `valid_padding` is disabled, the mask is inverted before
                being passed to the pooling layer and CoSeNet to match
                their masking convention.
            candidate_mask (torch.Tensor, optional): Optional mask tensor
                for candidate positions in CoSeNet's output. Defaults to
                None.

                If `valid_padding` is enabled, the mask is inverted before
                being applied, so that non-candidate positions are the ones
                zeroed out.

        Returns:
            torch.Tensor: Depending on `self.task`:
                - 'token_encoding': contextualized token representations of
                  shape (batch, sentences, tokens, dim);
                - 'sentence_encoding': pooled sentence representations;
                - 'similarity': pair-wise cosine distance matrix;
                - 'segmentation': CoSeNet output, with non-candidate
                  positions zeroed when `candidate_mask` is given.
        """
        # Embedding requires integer indices.
        x = x.int()

        x = self.embedding(x)
        x = self.positional_encoding(x)

        # Flatten (batch, sentences) into one axis so each sentence is an
        # independent sequence for the encoder blocks.
        _b, _s, _t, _d = x.shape
        x = x.reshape(_b * _s, _t, _d)
        if mask is not None:
            mask = mask.reshape(_b * _s, _t).bool()

        for encoder in self.encoder_blocks:
            x = encoder(x, mask=mask)

        # Restore the (batch, sentences, tokens, dim) layout.
        x = x.reshape(_b, _s, _t, _d)
        if mask is not None:
            mask = mask.reshape(_b, _s, _t)
            # Downstream modules expect the valid-padding convention; invert
            # the mask when it marks padding instead of valid positions.
            mask = torch.logical_not(mask) if not self.valid_padding else mask

        if self.task == 'token_encoding':
            return x

        # Collapse the token axis into one vector per sentence.
        x, mask = self.pooling(x, mask=mask)

        if self.task == 'sentence_encoding':
            return x

        # Pair-wise cosine distances between sentence vectors.
        x = self.distance_layer(x)

        if self.task == 'similarity':
            return x

        x = self.cosenet(x, mask=mask)

        if candidate_mask is not None:
            # masked_fill zeroes positions where the mask is True, i.e. the
            # NON-candidate positions; with valid-padding semantics the
            # candidate mask marks candidates, so it must be inverted first.
            candidate_mask = candidate_mask.bool() if not self.valid_padding else torch.logical_not(candidate_mask.bool())
            candidate_mask = candidate_mask.to(device=x.device)
            x = x.masked_fill(candidate_mask, 0)

        return x
|
|
|
|
|
|
|
|
|
|
|
|