File size: 10,913 Bytes
dbd79bd
 
 
 
 
 
 
da14095
 
 
dbd79bd
 
da14095
dbd79bd
da14095
dbd79bd
da14095
 
 
 
dbd79bd
da14095
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbd79bd
 
da14095
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbd79bd
 
da14095
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbd79bd
 
 
da14095
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbd79bd
da14095
dbd79bd
 
da14095
 
 
dbd79bd
da14095
 
 
8068e2f
869db96
da14095
33c844e
dbd79bd
4aa1cf7
 
 
da14095
 
 
 
 
dbd79bd
da14095
dbd79bd
da14095
 
 
dbd79bd
 
da14095
 
 
 
 
 
 
dbd79bd
da14095
 
 
 
 
34f99b8
 
27cdddc
34f99b8
 
 
 
 
 
 
 
 
 
 
dbd79bd
da14095
 
 
 
27cdddc
da14095
dbd79bd
da14095
 
dbd79bd
da14095
 
 
 
 
9a7365b
 
da14095
 
 
 
 
34f99b8
 
 
27cdddc
 
34f99b8
 
 
27cdddc
 
d0a3e2d
 
 
 
dd4f3a8
 
 
 
 
 
d0a3e2d
 
 
 
00f1b20
d0a3e2d
27cdddc
0d019f1
27cdddc
 
 
00f1b20
27cdddc
da14095
 
 
 
 
 
 
 
 
dbd79bd
da14095
 
dbd79bd
da14095
 
 
 
 
 
 
 
 
 
dbd79bd
da14095
 
 
34f99b8
da14095
 
 
 
 
 
dbd79bd
34f99b8
 
 
 
 
 
 
 
 
 
 
 
6faa82b
 
 
34f99b8
 
 
 
 
 
 
 
8068e2f
 
b8b34d9
 
 
 
 
 
8068e2f
 
 
dbd79bd
b8b34d9
8068e2f
dbd79bd
b8b34d9
 
 
8068e2f
b8b34d9
dbd79bd
b8b34d9
dbd79bd
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                                                           #
#   This file was created by: Alberto Palomo Alonso         #
# Universidad de Alcalá - Escuela Politécnica Superior      #
#                                                           #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
import torch
from transformers import PreTrainedModel, PretrainedConfig
from src.model import SegmentationNetwork
from src.model.config import ModelConfig, TransformerConfig, CoSeNetConfig


class SentenceCoseNetConfig(PretrainedConfig):
    """
    Configuration class for SentenceCoseNet.

    This class stores all hyperparameters needed to initialize
    a `SentenceCoseNet` model. It follows Hugging Face's
    `PretrainedConfig` interface so the model can be saved,
    loaded, and shared via the Hub.

    Attributes:
        model_type (str):
            Identifier used by Hugging Face to register the model.
        vocab_size (int):
            Size of the tokenizer vocabulary.
        emb_dim (int):
            Dimensionality of token embeddings.
        seq_len (int):
            Maximum input sequence length supported by the model.
        dropout (float):
            Dropout probability applied in Transformer blocks.
        valid_padding (bool):
            Whether padding tokens are treated as valid positions.
        cosenet (dict):
            Configuration of the cosine-similarity network head.
        transformers (list[dict]):
            List of Transformer encoder block configurations.
    """

    model_type = "sentence_cosenet"

    def __init__(
        self,
        vocab_size: int = 32768,
        emb_dim: int = 256,
        seq_len: int = 382,
        dropout: float = 0.0,
        valid_padding: bool = True,
        cosenet: dict | None = None,
        transformers: list | None = None,
        **kwargs,
    ):
        """
        Initialize SentenceCoseNet configuration.

        Args:
            vocab_size:
                Size of the tokenizer vocabulary.
            emb_dim:
                Dimension of token embeddings.
            seq_len:
                Maximum number of tokens per input sequence.
            dropout:
                Dropout probability used throughout the network.
            valid_padding:
                Whether padded tokens should be considered valid.
            cosenet:
                Optional configuration dictionary for the cosine
                similarity network head. Defaults applied only when
                this argument is ``None``.
            transformers:
                Optional list of dictionaries describing each
                Transformer encoder block. Defaults applied only when
                this argument is ``None``.
            **kwargs:
                Additional keyword arguments passed to
                `PretrainedConfig`.
        """
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.seq_len = seq_len
        self.dropout = dropout
        self.valid_padding = valid_padding

        # Use explicit `is None` checks (not truthiness): an explicitly
        # passed empty dict/list must be preserved, not silently replaced
        # by the built-in defaults (as `cosenet or {...}` would do).
        if cosenet is None:
            cosenet = {
                "trainable": True,
                "init_scale": 5.0
            }
        self.cosenet = cosenet

        if transformers is None:
            transformers = [
                {
                    "attention_heads": 16,
                    "feed_forward_multiplier": 8,
                    "dropout": 0.0,
                    "pre_normalize": True
                },
                {
                    "attention_heads": 16,
                    "feed_forward_multiplier": 8,
                    "dropout": 0.0,
                    "pre_normalize": True
                }
            ]
        self.transformers = transformers

        # Standard HF aliases so generic tooling can read the model size.
        self.hidden_size = emb_dim
        self.max_position_embeddings = seq_len


class SentenceCoseNet(PreTrainedModel):
    """
    Sentence-level encoder model based on CoseNet.

    This class wraps a custom PyTorch segmentation network
    and exposes it as a Hugging Face `PreTrainedModel`,
    enabling interoperability with the Transformers ecosystem.

    The model is intended for:
    - Sentence embeddings
    - Semantic search
    - Information retrieval
    - Similarity learning
    """

    config_class = SentenceCoseNetConfig
    base_model_prefix = "cosenet"

    def __init__(self, config: SentenceCoseNetConfig):
        """
        Initialize the SentenceCoseNet model.

        Args:
            config:
                Instance of `SentenceCoseNetConfig` containing
                model hyperparameters.
        """
        super().__init__(config)

        # Core PyTorch model wrapped by this HF-compatible shell.
        self.model = SegmentationNetwork(self.to_model_config(config))

        # Initialize weights following HF conventions.
        self.post_init()

        # Set evaluation mode by default (inference-oriented wrapper).
        self.model.eval()

    def encode(
        self,
        input_ids: torch.Tensor,
        attention_mask=None
    ) -> torch.Tensor:
        """
        Encode input token sequences into contextualized embeddings.

        This method performs embedding lookup, positional encoding,
        and Transformer-based contextualization, returning token-level
        representations.

        Args:
            input_ids:
                Tensor of token IDs with shape
                `(batch_size, sequence_length)` or
                `(batch_size, sentences, sequence_length)`.
            attention_mask:
                Optional attention mask indicating valid (1) and
                padded (0) positions, with the same leading shape
                as `input_ids`.

        Returns:
            torch.Tensor:
                Contextualized token embeddings with shape
                `(batch_size, sequence_length, emb_dim)`.
        """
        # Select the task, then delegate the shape-dependent dispatch to
        # `call` — previously this logic was duplicated here verbatim.
        self.model.task = 'token_encoding'
        return self.call(input_ids, attention_mask)

    def get_sentence_embedding(
            self,
            input_ids: torch.Tensor,
            attention_mask=None,
            normalize: bool = False,
    ) -> torch.Tensor:
        """
        Compute sentence embeddings for zero-shot transfer and
        information retrieval.

        Args:
            input_ids (torch.Tensor):
                Tensor of shape (B, T)
            attention_mask (torch.Tensor, optional):
                Boolean or binary mask of shape (B, T)
            normalize (bool, optional):
                Whether to L2-normalize the output embeddings.

        Returns:
            torch.Tensor:
                Sentence embeddings of shape (B, D)
        """
        # Select the task, then let `call` handle 2-D vs 3-D inputs.
        self.model.task = 'sentence_encoding'
        output = self.call(input_ids, attention_mask)

        if normalize:
            output = torch.nn.functional.normalize(output, p=2, dim=-1)

        return output

    def similarity(self, embeddings_1: torch.Tensor, embeddings_2: torch.Tensor) -> torch.Tensor:
        """
        Compute cosine similarity scores between two sets of embeddings.

        Args:
            embeddings_1 (torch.Tensor):
                Tensor of shape (B, S, D) containing the first set of
                embeddings concatenated along the first dimension.

            embeddings_2 (torch.Tensor):
                Tensor of shape (B, S, D) containing the second set of
                embeddings concatenated along the first dimension.

        Returns:
            torch.Tensor:
                Similarity scores of shape (B, S)
        """
        # Stack embeddings pairwise: (B, S, 2, D).
        embeddings = torch.stack([embeddings_1, embeddings_2], dim=-2)
        # Pairwise distance matrix from the model head: (B, S, 2, 2).
        embeddings = self.model.distance_layer(embeddings)
        # Average the two off-diagonal entries -> symmetric score (B, S).
        return (embeddings[..., 0, 1] + embeddings[..., 1, 0]) / 2

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask=None,
        candidate_mask=None,
        **kwargs,
    ):
        """
        Forward pass of the SentenceCoseNet model.

        This method delegates execution to the underlying
        `SegmentationNetwork` in segmentation mode.

        Args:
            input_ids:
                Tensor of token IDs with shape
                `(batch_size, sequence_length)`.
            attention_mask:
                Optional attention mask tensor.
            candidate_mask:
                Optional mask indicating candidate segments or spans.
            **kwargs:
                Additional arguments forwarded to the core model.

        Returns:
            Model-specific output as produced by `SegmentationNetwork`.
        """
        self.model.task = 'segmentation'
        return self.model(
            x=input_ids,
            mask=attention_mask,
            candidate_mask=candidate_mask,
            **kwargs,
        )

    def call(self, input_ids: torch.Tensor, attention_mask=None) -> torch.Tensor:
        """
        Internal method to handle different input shapes (task already selected).

        Args:
            input_ids:
                Tensor of token IDs with shape
                `(batch_size, sequence_length)` or
                `(batch_size, sentences, sequence_length)`.
            attention_mask:
                Optional attention mask tensor with the same leading
                shape as `input_ids`.

        Returns:
            torch.Tensor:
                Model output; a 2-D input is lifted to a single-sentence
                3-D batch and the sentence axis squeezed back out.

        Raises:
            ValueError: If `input_ids` is neither 2-D nor 3-D.
        """
        if input_ids.dim() == 2:
            # (B, T) -> (B, 1, T): the core model expects a sentence axis.
            x = input_ids.int().unsqueeze(1)
            mask = attention_mask.unsqueeze(1) if attention_mask is not None else None
            output = self.model(x=x, mask=mask).squeeze(1)
        elif input_ids.dim() == 3:
            x = input_ids.int()
            mask = attention_mask if attention_mask is not None else None
            output = self.model(x=x, mask=mask)
        else:
            raise ValueError("Input tensor must be of shape (Batch, Tokens) or (Batch, Sentences, Tokens).")
        return output

    @staticmethod
    def to_model_config(config: SentenceCoseNetConfig) -> ModelConfig:
        """
        Convert Hugging Face config to internal ModelConfig.

        Args:
            config: HF-style configuration to translate.

        Returns:
            ModelConfig: Internal configuration for `SegmentationNetwork`.
        """
        mc = ModelConfig()

        # Core dimensions.
        mc.vocab_size = config.vocab_size
        mc.model_dim = config.emb_dim
        mc.valid_padding = config.valid_padding
        # NOTE(review): config.seq_len and config.dropout are never
        # forwarded here — confirm whether ModelConfig's defaults are
        # intended, or whether these fields should be copied as well.

        # CoSeNet head configuration.
        mc.cosenet = CoSeNetConfig(**config.cosenet)

        # Transformer encoder stack.
        mc.transformers = [
            TransformerConfig(**cfg)
            for cfg in config.transformers
        ]

        return mc
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                        END OF FILE                        #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #