|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
from transformers import PreTrainedModel, PretrainedConfig |
|
|
from src.model import SegmentationNetwork |
|
|
from src.model.config import ModelConfig, TransformerConfig, CoSeNetConfig |
|
|
|
|
|
|
|
|
class SentenceCoseNetConfig(PretrainedConfig):
    """
    Configuration class for SentenceCoseNet.

    This class stores all hyperparameters needed to initialize
    a `SentenceCoseNet` model. It follows Hugging Face's
    `PretrainedConfig` interface so the model can be saved,
    loaded, and shared via the Hub.

    Attributes:
        model_type (str):
            Identifier used by Hugging Face to register the model.
        vocab_size (int):
            Size of the tokenizer vocabulary.
        emb_dim (int):
            Dimensionality of token embeddings.
        seq_len (int):
            Maximum input sequence length supported by the model.
        dropout (float):
            Dropout probability applied in Transformer blocks.
        valid_padding (bool):
            Whether padding tokens are treated as valid positions.
        cosenet (dict):
            Configuration of the cosine-similarity network head.
        transformers (list[dict]):
            List of Transformer encoder block configurations.
    """

    model_type = "sentence_cosenet"

    def __init__(
        self,
        vocab_size: int = 32768,
        emb_dim: int = 256,
        seq_len: int = 382,
        dropout: float = 0.0,
        valid_padding: bool = True,
        cosenet: dict | None = None,
        transformers: list | None = None,
        **kwargs,
    ):
        """
        Initialize SentenceCoseNet configuration.

        Args:
            vocab_size:
                Size of the tokenizer vocabulary.
            emb_dim:
                Dimension of token embeddings.
            seq_len:
                Maximum number of tokens per input sequence.
            dropout:
                Dropout probability used throughout the network.
            valid_padding:
                Whether padded tokens should be considered valid.
            cosenet:
                Optional configuration dictionary for the cosine
                similarity network head. Defaults are used when
                this is ``None``.
            transformers:
                Optional list of dictionaries describing each
                Transformer encoder block. Defaults are used when
                this is ``None``.
            **kwargs:
                Additional keyword arguments passed to
                `PretrainedConfig`.
        """
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.seq_len = seq_len
        self.dropout = dropout
        self.valid_padding = valid_padding

        # Use an explicit `is None` test (not `or`) so an intentionally
        # passed empty dict/list is preserved instead of being silently
        # replaced by the built-in defaults.
        self.cosenet = cosenet if cosenet is not None else {
            "trainable": True,
            "init_scale": 5.0,
        }

        self.transformers = transformers if transformers is not None else [
            {
                "attention_heads": 16,
                "feed_forward_multiplier": 8,
                "dropout": 0.0,
                "pre_normalize": True,
            },
            {
                "attention_heads": 16,
                "feed_forward_multiplier": 8,
                "dropout": 0.0,
                "pre_normalize": True,
            },
        ]

        # Standard Hugging Face aliases so generic tooling can discover
        # the model's hidden size and maximum sequence length.
        self.hidden_size = emb_dim
        self.max_position_embeddings = seq_len
|
|
|
|
|
|
|
|
class SentenceCoseNet(PreTrainedModel):
    """
    Sentence-level encoder model based on CoseNet.

    This class wraps a custom PyTorch segmentation network
    and exposes it as a Hugging Face `PreTrainedModel`,
    enabling interoperability with the Transformers ecosystem.

    The model is intended for:
        - Sentence embeddings
        - Semantic search
        - Information retrieval
        - Similarity learning
    """

    config_class = SentenceCoseNetConfig
    base_model_prefix = "cosenet"

    def __init__(self, config: SentenceCoseNetConfig):
        """
        Initialize the SentenceCoseNet model.

        Args:
            config:
                Instance of `SentenceCoseNetConfig` containing
                model hyperparameters.
        """
        super().__init__(config)

        # Core PyTorch network; the HF config is translated into the
        # project-internal ModelConfig before construction.
        self.model = SegmentationNetwork(self.to_model_config(config))

        # Hugging Face weight-initialization / tying hooks.
        self.post_init()

        # Inference is the default mode for this wrapper.
        self.model.eval()

    def encode(
        self,
        input_ids: torch.Tensor,
        attention_mask=None
    ) -> torch.Tensor:
        """
        Encode input token sequences into contextualized embeddings.

        This method performs embedding lookup, positional encoding,
        and Transformer-based contextualization, returning token-level
        representations.

        Args:
            input_ids:
                Tensor of token IDs with shape
                `(batch_size, sequence_length)` or
                `(batch_size, sentences, sequence_length)`.
            attention_mask:
                Optional attention mask indicating valid (1) and
                padded (0) positions, with the same leading shape
                as `input_ids`.

        Returns:
            torch.Tensor:
                Contextualized token embeddings with shape
                `(batch_size, sequence_length, emb_dim)`.
        """
        self.model.task = 'token_encoding'
        # Delegate the (B, T) / (B, S, T) shape handling to `call`
        # instead of duplicating it here (the two copies previously
        # risked drifting apart).
        return self.call(input_ids, attention_mask)

    def get_sentence_embedding(
        self,
        input_ids: torch.Tensor,
        attention_mask=None,
        normalize: bool = False,
    ) -> torch.Tensor:
        """
        Compute sentence embeddings for zero-shot transfer and
        information retrieval.

        Args:
            input_ids (torch.Tensor):
                Tensor of shape (B, T)
            attention_mask (torch.Tensor, optional):
                Boolean or binary mask of shape (B, T)
            normalize (bool, optional):
                Whether to L2-normalize the output embeddings.

        Returns:
            torch.Tensor:
                Sentence embeddings of shape (B, D)
        """
        self.model.task = 'sentence_encoding'
        output = self.call(input_ids, attention_mask)

        if normalize:
            output = torch.nn.functional.normalize(output, p=2, dim=-1)

        return output

    def similarity(self, embeddings_1: torch.Tensor, embeddings_2: torch.Tensor) -> torch.Tensor:
        """
        Compute cosine similarity scores between two sets of embeddings.

        Args:
            embeddings_1 (torch.Tensor):
                Tensor of shape (B, S, D) containing the first set of
                embeddings concatenated along the first dimension.
            embeddings_2 (torch.Tensor):
                Tensor of shape (B, S, D) containing the second set of
                embeddings concatenated along the first dimension.

        Returns:
            torch.Tensor:
                Similarity scores of shape (B, S)
        """
        # Pair the two sets along a new axis so the distance layer sees
        # a 2x2 similarity matrix per position.
        embeddings = torch.stack([embeddings_1, embeddings_2], dim=-2)
        embeddings = self.model.distance_layer(embeddings)

        # Average the two off-diagonal entries; the distance layer is not
        # guaranteed to be perfectly symmetric, so take both directions.
        return (embeddings[..., 0, 1] + embeddings[..., 1, 0]) / 2

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask=None,
        candidate_mask=None,
        **kwargs,
    ):
        """
        Forward pass of the SentenceCoseNet model.

        This method delegates execution to the underlying
        `SegmentationNetwork`.

        Args:
            input_ids:
                Tensor of token IDs with shape
                `(batch_size, sequence_length)`.
            attention_mask:
                Optional attention mask tensor.
            candidate_mask:
                Optional mask indicating candidate segments or spans.
            **kwargs:
                Additional arguments forwarded to the core model.

        Returns:
            Model-specific output as produced by `SegmentationNetwork`.
        """
        self.model.task = 'segmentation'
        return self.model(
            x=input_ids,
            mask=attention_mask,
            candidate_mask=candidate_mask,
            **kwargs,
        )

    def call(self, input_ids: torch.Tensor, attention_mask=None) -> torch.Tensor:
        """
        Internal method to handle different input shapes (task already selected).

        Args:
            input_ids:
                Tensor of token IDs with shape
                `(batch_size, sequence_length)` or
                `(batch_size, sentences, sequence_length)`.
            attention_mask:
                Optional attention mask tensor with the same leading
                shape as `input_ids`.

        Raises:
            ValueError: If `input_ids` is neither 2- nor 3-dimensional.
        """
        if input_ids.dim() == 2:
            # Promote (B, T) to (B, 1, T): the core model expects a
            # sentence axis; squeeze it back out of the result.
            x = input_ids.int().unsqueeze(1)
            mask = attention_mask.unsqueeze(1) if attention_mask is not None else None
            return self.model(x=x, mask=mask).squeeze(1)

        if input_ids.dim() == 3:
            return self.model(x=input_ids.int(), mask=attention_mask)

        raise ValueError("Input tensor must be of shape (Batch, Tokens) or (Batch, Sentences, Tokens).")

    @staticmethod
    def to_model_config(config: SentenceCoseNetConfig) -> ModelConfig:
        """
        Convert Hugging Face config to internal ModelConfig.
        """
        mc = ModelConfig()

        mc.vocab_size = config.vocab_size
        mc.model_dim = config.emb_dim
        mc.valid_padding = config.valid_padding

        mc.cosenet = CoSeNetConfig(**config.cosenet)

        mc.transformers = [
            TransformerConfig(**cfg)
            for cfg in config.transformers
        ]

        return mc
|
|
|
|
|
|
|
|
|
|
|
|