# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                                                           #
#   This file was created by: Alberto Palomo Alonso         #
#   Universidad de Alcalá - Escuela Politécnica Superior    #
#                                                           #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
from typing import Optional

import torch

from .config import ModelConfig
from .cosenet import CosineDistanceLayer, CoSeNet
from .transformers import EncoderBlock, PositionalEncoding, MaskedMeanPooling


class SegmentationNetwork(torch.nn.Module):
    """
    Segmentation network combining Transformer encoders with CoSeNet.

    This model integrates token embeddings and positional encodings with a
    stack of Transformer encoder blocks to produce contextualized token
    representations. These are mean-pooled into sentence embeddings,
    transformed into a pair-wise cosine distance matrix, and then refined by
    a CoSeNet module to perform structured segmentation. The final output is
    a pair-wise distance matrix suitable for segmentation or boundary
    detection tasks.
    """
    def __init__(self, model_config: ModelConfig, task: str = 'segmentation', **kwargs):
        """
        Initialize the segmentation network.

        The network is composed of an embedding layer, positional encoding,
        multiple Transformer encoder blocks, a masked mean pooling layer, a
        cosine distance layer, and a CoSeNet segmentation module.

        Args:
            model_config (ModelConfig): Configuration object containing all
                hyperparameters required to build the model, including
                vocabulary size, model dimensionality, transformer settings,
                and CoSeNet parameters.
            task (str): Output mode of the network. One of 'segmentation',
                'similarity', 'token_encoding', or 'sentence_encoding'.
                Defaults to 'segmentation'.
            **kwargs: Additional keyword arguments forwarded to
                `torch.nn.Module`.
        """
        super().__init__(**kwargs)
        self.valid_padding = model_config.valid_padding

        # Build layers:
        self.embedding = torch.nn.Embedding(
            model_config.vocab_size,
            model_config.model_dim
        )
        self.positional_encoding = PositionalEncoding(
            emb_dim=model_config.model_dim,
            max_len=model_config.max_tokens
        )
        self.cosenet = CoSeNet(
            trainable=model_config.cosenet.trainable,
            init_scale=model_config.cosenet.init_scale
        )
        self.distance_layer = CosineDistanceLayer()
        self.pooling = MaskedMeanPooling(valid_pad=model_config.valid_padding)

        # Build encoder blocks:
        module_list = []
        for transformer_config in model_config.transformers:
            encoder_block = EncoderBlock(
                feature_dim=model_config.model_dim,
                attention_heads=transformer_config.attention_heads,
                feed_forward_multiplier=transformer_config.feed_forward_multiplier,
                dropout=transformer_config.dropout,
                valid_padding=model_config.valid_padding,
                pre_normalize=transformer_config.pre_normalize
            )
            module_list.append(encoder_block)
        self.encoder_blocks = torch.nn.ModuleList(module_list)

        self.task = task
        if self.task not in ['segmentation', 'similarity', 'token_encoding', 'sentence_encoding']:
            raise ValueError(f"Invalid task '{self.task}'. Supported tasks are 'segmentation', 'similarity', "
                             f"'token_encoding', and 'sentence_encoding'.")

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None,
                candidate_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass of the segmentation network.

        The input token indices are embedded and enriched with positional
        information, then processed by a stack of Transformer encoder blocks.
        The token representations are mean-pooled into sentence embeddings,
        transformed into a pair-wise cosine distance matrix, and finally
        refined by CoSeNet. Intermediate representations are returned early
        for the 'token_encoding', 'sentence_encoding', and 'similarity'
        tasks.

        Args:
            x (torch.Tensor): Input tensor of token indices with shape
                (batch_size, num_sentences, tokens_per_sentence).
            mask (torch.Tensor, optional): Optional mask tensor indicating
                valid or padded positions, depending on the configuration of
                the Transformer blocks. Defaults to None. If `valid_padding`
                is disabled, the mask is inverted after encoding to match
                the masking convention of the pooling and CoSeNet layers.
            candidate_mask (torch.Tensor, optional): Optional mask tensor for
                candidate positions in the CoSeNet output. Defaults to None.
                If `valid_padding` is enabled, the mask is inverted before
                being applied, so that positions outside the candidate set
                are zeroed in the output.

        Returns:
            torch.Tensor: Task-dependent output. For 'segmentation', a tensor
                of pair-wise distance values derived from the segmented
                representations; intermediate encodings or similarities for
                the other tasks.
        """
        # Ensure integer token indices for the embedding lookup:
        x = x.int()

        # Embedding and positional encoding:
        x = self.embedding(x)
        x = self.positional_encoding(x)

        # Flatten the batch and sentence dimensions for the encoders:
        _b, _s, _t, _d = x.shape
        x = x.reshape(_b * _s, _t, _d)
        if mask is not None:
            mask = mask.reshape(_b * _s, _t).bool()

        # Encode the sequence:
        for encoder in self.encoder_blocks:
            x = encoder(x, mask=mask)

        # Restore the batch and sentence dimensions:
        x = x.reshape(_b, _s, _t, _d)
        if mask is not None:
            mask = mask.reshape(_b, _s, _t)
            mask = torch.logical_not(mask) if not self.valid_padding else mask

        if self.task == 'token_encoding':
            return x

        # Apply pooling:
        x, mask = self.pooling(x, mask=mask)
        if self.task == 'sentence_encoding':
            return x

        # Compute distances:
        x = self.distance_layer(x)
        if self.task == 'similarity':
            return x

        # Pass through CoSeNet:
        x = self.cosenet(x, mask=mask)

        # Apply candidate mask if provided:
        if candidate_mask is not None:
            candidate_mask = candidate_mask.bool() if not self.valid_padding \
                else torch.logical_not(candidate_mask.bool())
            candidate_mask = candidate_mask.to(device=x.device)
            x = x.masked_fill(candidate_mask, 0)
        return x
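
# The block below is an editor's usage sketch, not part of the original
# module. It assumes `ModelConfig` exposes exactly the attributes read in
# `__init__` (vocab_size, model_dim, max_tokens, valid_padding, a `cosenet`
# sub-config, and an iterable of per-block transformer sub-configs); a
# duck-typed `SimpleNamespace` stands in for it here. Because this file uses
# relative imports, run it as a module, e.g. `python -m <package>.<module>`.
if __name__ == "__main__":
    from types import SimpleNamespace

    # Hypothetical stand-in for ModelConfig with the attributes used above:
    config = SimpleNamespace(
        vocab_size=1000,
        model_dim=64,
        max_tokens=32,
        valid_padding=True,
        cosenet=SimpleNamespace(trainable=True, init_scale=1.0),
        transformers=[
            SimpleNamespace(attention_heads=4, feed_forward_multiplier=2,
                            dropout=0.1, pre_normalize=True),
        ],
    )
    model = SegmentationNetwork(config, task='segmentation')

    # Dummy batch: 2 documents, 8 sentences each, 16 tokens per sentence.
    tokens = torch.randint(0, config.vocab_size, (2, 8, 16))
    mask = torch.ones(2, 8, 16)  # valid_padding=True: ones mark valid tokens
    with torch.no_grad():
        distances = model(tokens, mask=mask)
    print(distances.shape)  # expected: a pair-wise sentence distance matrix

# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                        END OF FILE                        #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #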