alverciito committed on
Commit
da14095
·
1 Parent(s): dbd79bd

add model.py for deployment test

Browse files
Files changed (5) hide show
  1. README.md +8 -0
  2. __init__.py +10 -0
  3. config.json +18 -8
  4. configurations.py +0 -0
  5. model.py +226 -130
README.md CHANGED
@@ -1,7 +1,15 @@
1
  ---
 
 
 
 
 
 
2
  license: apache-2.0
3
  ---
4
 
 
 
5
  ## Baseline Comparison
6
  | Category | Model / Method | Spanish Support | Training |
7
  |---|---|---|----------|
 
1
  ---
2
+ library_name: transformers
3
+ pipeline_tag: sentence-similarity
4
+ tags:
5
+ - sentence-embeddings
6
+ - information-retrieval
7
+ - semantic-search
8
  license: apache-2.0
9
  ---
10
 
11
+ # SentenceCoseNet
12
+
13
  ## Baseline Comparison
14
  | Category | Model / Method | Spanish Support | Training |
15
  |---|---|---|----------|
__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
2
+ # #
3
+ # This file was created by: Alberto Palomo Alonso #
4
+ # Universidad de Alcalá - Escuela Politécnica Superior #
5
+ # #
6
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
+ from .model import SentenceCoseNet, SentenceCoseNetConfig
8
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
9
+ # END OF FILE #
10
+ # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
config.json CHANGED
@@ -1,16 +1,26 @@
1
  {
2
- "architectures": [
3
- "CoseNetTransformer"
4
- ],
5
- "dropout": 0.0,
 
 
6
  "emb_dim": 256,
7
- "model_type": "sentence_transformer",
8
- "seq_len": ...,
 
 
 
 
 
 
 
 
9
  "torch_dtype": "float32",
10
  "transformers_version": "4.57.3",
11
- "vocab_size": 32768,
12
  "auto_map": {
13
- "AutoConfig": "configurations.SentenceCoseNetConfig",
14
  "AutoModel": "model.SentenceCoseNet"
15
  }
16
  }
 
1
  {
2
+ "architectures": ["SentenceCoseNet"],
3
+
4
+ "model_type": "sentence_cosenet",
5
+
6
+ "vocab_size": 32768,
7
+ "hidden_size": 256,
8
  "emb_dim": 256,
9
+
10
+ "max_position_embeddings": 382,
11
+ "seq_len": 382,
12
+
13
+ "dropout": 0.0,
14
+
15
+ "pad_token_id": 0,
16
+ "bos_token_id": 1,
17
+ "eos_token_id": 2,
18
+
19
  "torch_dtype": "float32",
20
  "transformers_version": "4.57.3",
21
+
22
  "auto_map": {
23
+ "AutoConfig": "model.SentenceCoseNetConfig",
24
  "AutoModel": "model.SentenceCoseNet"
25
  }
26
  }
configurations.py DELETED
File without changes
model.py CHANGED
@@ -4,173 +4,269 @@
4
  # Universidad de Alcalá - Escuela Politécnica Superior #
5
  # #
6
  # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
7
- # Import statements:
8
  import torch
9
- from src.model.config import ModelConfig
10
- from src.model.cosenet import CosineDistanceLayer, CoSeNet
11
- from src.model.transformers import EncoderBlock, PositionalEncoding, MaskedMeanPooling
12
 
13
 
14
- class CoseNetTransformer(torch.nn.Module):
15
  """
16
- Segmentation network combining Transformer encoders with CoSeNet.
17
 
18
- This model integrates token embeddings and positional encodings with
19
- a stack of Transformer encoder blocks to produce contextualized
20
- representations. These representations are then processed by a
21
- CoSeNet module to perform structured segmentation, followed by a
22
- cosine-based distance computation.
23
 
24
- The final output is a pair-wise distance matrix suitable for
25
- segmentation or boundary detection tasks.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  """
27
- def __init__(self, model_config: ModelConfig, **kwargs):
28
- """
29
- Initialize the segmentation network.
30
 
31
- The network is composed of an embedding layer, positional encoding,
32
- multiple Transformer encoder blocks, a CoSeNet segmentation module,
33
- and a cosine distance layer.
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  Args:
36
- model_config (ModelConfig): Configuration object containing all
37
- hyperparameters required to build the model, including
38
- vocabulary size, model dimensionality, transformer settings,
39
- and CoSeNet parameters.
40
- **kwargs: Additional keyword arguments forwarded to
41
- `torch.nn.Module`.
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  """
43
  super().__init__(**kwargs)
44
- self.valid_padding = model_config.valid_padding
45
 
46
- # Build layers:
47
- self.embedding = torch.nn.Embedding(
48
- model_config.vocab_size,
49
- model_config.model_dim
50
- )
51
- self.positional_encoding = PositionalEncoding(
52
- emb_dim=model_config.model_dim,
53
- max_len=model_config.max_tokens
54
- )
55
- self.cosenet = CoSeNet(
56
- trainable=model_config.cosenet.trainable,
57
- init_scale=model_config.cosenet.init_scale
58
- )
59
- self.distance_layer = CosineDistanceLayer()
60
- self.pooling = MaskedMeanPooling(valid_pad=model_config.valid_padding)
61
-
62
- # Build encoder blocks:
63
- module_list = list()
64
- for transformer_config in model_config.transformers:
65
- encoder_block = EncoderBlock(
66
- feature_dim=model_config.model_dim,
67
- attention_heads=transformer_config.attention_heads,
68
- feed_forward_multiplier=transformer_config.feed_forward_multiplier,
69
- dropout=transformer_config.dropout,
70
- valid_padding=model_config.valid_padding,
71
- pre_normalize=transformer_config.pre_normalize
72
- )
73
- module_list.append(encoder_block)
74
-
75
- self.encoder_blocks = torch.nn.ModuleList(module_list)
76
-
77
- def encode(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  """
79
- Encode input sequences into contextualized representations.
80
- The input token indices are embedded and enriched with positional
81
- information, then processed by a stack of Transformer encoder
82
- blocks.
83
 
84
  Args:
85
- x (torch.Tensor): Input tensor of token indices with shape
86
- (batch_size, max_tokens).
87
- mask (torch.Tensor, optional): Optional mask tensor indicating
88
- valid or padded positions, depending on the configuration
89
- of the Transformer blocks. Defaults to None. Dimensions should be
90
- (batch_size, max_tokens).
91
  """
92
- # Convert to type:
93
- x = x.int()
94
- # Embedding and positional encoding:
95
- x = self.embedding(x)
96
- x = self.positional_encoding(x)
97
- # Check mask inversion:
98
- if mask[0, 0] == 0:
99
- mask = torch.logical_not(mask)
100
- # Encode:
101
- for encoder in self.encoder_blocks:
102
- x = encoder(x, mask=mask)
103
- return x
104
 
 
 
105
 
106
- def forward(self, x: torch.Tensor, mask: torch.Tensor = None, candidate_mask: torch.Tensor = None) -> torch.Tensor:
 
 
 
 
107
  """
108
- Forward pass of the segmentation network.
109
 
110
- The input token indices are embedded and enriched with positional
111
- information, then processed by a stack of Transformer encoder
112
- blocks. The resulting representations are segmented using CoSeNet
113
- and finally transformed into a pair-wise distance representation.
114
 
115
  Args:
116
- x (torch.Tensor): Input tensor of token indices with shape
117
- (batch_size, sequence_length).
118
- mask (torch.Tensor, optional): Optional mask tensor indicating
119
- valid or padded positions, depending on the configuration
120
- of the Transformer blocks. Defaults to None.
 
 
121
 
122
- If `valid_padding` is disabled, the mask is inverted before being
123
- passed to CoSeNet to match its masking convention.
 
 
 
 
 
124
 
125
- candidate_mask (torch.Tensor, optional): Optional mask tensor for
126
- candidate positions in CoSeNet. Defaults to None.
 
127
 
128
- If `valid_padding` is disabled, the mask is inverted before being
129
- passed to CoSeNet to match its masking convention.
 
 
130
 
131
- Returns:
132
- torch.Tensor: Output tensor containing pairwise distance values
133
- derived from the segmented representations.
 
 
134
  """
135
- # Convert to type:
136
- x = x.int()
137
 
138
- # Embedding and positional encoding:
139
- x = self.embedding(x)
140
- x = self.positional_encoding(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
- # Reshape x and mask:
143
- _b, _s, _t, _d = x.shape
144
- x = x.reshape(_b * _s, _t, _d)
145
- if mask is not None:
146
- mask = mask.reshape(_b * _s, _t).bool()
 
 
 
 
147
 
148
- # Encode the sequence:
149
- for encoder in self.encoder_blocks:
150
- x = encoder(x, mask=mask)
151
 
152
- # Reshape x and mask:
153
- x = x.reshape(_b, _s, _t, _d)
154
- if mask is not None:
155
- mask = mask.reshape(_b, _s, _t)
156
- mask = torch.logical_not(mask) if not self.valid_padding else mask
 
 
 
 
 
157
 
158
- # Apply pooling:
159
- x, mask = self.pooling(x, mask=mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
- # Compute distances:
162
- x = self.distance_layer(x)
 
 
163
 
164
- # Pass through CoSeNet:
165
- x = self.cosenet(x, mask=mask)
166
 
167
- # Apply candidate mask if provided:
168
- if candidate_mask is not None:
169
- candidate_mask = candidate_mask.bool() if not self.valid_padding else torch.logical_not(candidate_mask.bool())
170
- candidate_mask = candidate_mask.to(device=x.device)
171
- x = x.masked_fill(candidate_mask, 0)
172
 
173
- return x
174
  # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
175
  # END OF FILE #
176
  # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
 
4
  # Universidad de Alcalá - Escuela Politécnica Superior #
5
  # #
6
  # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
 
7
  import torch
8
+ from transformers import PreTrainedModel, PretrainedConfig
9
+ from src.model import SegmentationNetwork
10
+ from src.model.config import ModelConfig, TransformerConfig, CoSeNetConfig
11
 
12
 
13
class SentenceCoseNetConfig(PretrainedConfig):
    """
    Hyperparameter container for `SentenceCoseNet`.

    Subclasses Hugging Face's `PretrainedConfig`, so every value stored
    here travels with the model when it is saved, loaded, or shared
    through the Hub.

    Attributes:
        model_type (str):
            Identifier under which the model is registered
            (``"sentence_cosenet"``).
        vocab_size (int):
            Size of the tokenizer vocabulary.
        emb_dim (int):
            Dimensionality of token embeddings.
        seq_len (int):
            Maximum input sequence length.
        dropout (float):
            Dropout probability stored on the configuration.
        valid_padding (bool):
            Padding-mask convention flag handed to the core model.
        cosenet (dict):
            Keyword arguments for the CoSeNet head configuration.
        transformers (list[dict]):
            One keyword-argument dictionary per Transformer encoder block.
        hidden_size (int):
            Mirror of ``emb_dim`` under the conventional HF name.
        max_position_embeddings (int):
            Mirror of ``seq_len`` under the conventional HF name.
    """

    model_type = "sentence_cosenet"

    def __init__(
        self,
        vocab_size: int = 32768,
        emb_dim: int = 256,
        seq_len: int = 382,
        dropout: float = 0.0,
        valid_padding: bool = True,
        cosenet: dict | None = None,
        transformers: list | None = None,
        **kwargs,
    ):
        """
        Build a SentenceCoseNet configuration.

        Args:
            vocab_size:
                Size of the tokenizer vocabulary.
            emb_dim:
                Dimension of token embeddings.
            seq_len:
                Maximum number of tokens per input sequence.
            dropout:
                Dropout probability stored on the configuration.
            valid_padding:
                Padding-mask convention flag.
            cosenet:
                Optional dictionary configuring the CoSeNet head.
                Any falsy value (``None`` or an empty dict) selects
                the built-in default.
            transformers:
                Optional list of dictionaries, one per Transformer
                encoder block. Any falsy value (``None`` or an empty
                list) selects the built-in two-block default.
            **kwargs:
                Remaining keyword arguments, forwarded to
                `PretrainedConfig`.
        """
        super().__init__(**kwargs)

        # Resolve the optional sections first; a fresh object is built on
        # every call so no default is ever shared between instances.
        if cosenet:
            head_settings = cosenet
        else:
            head_settings = {
                "trainable": True,
                "init_scale": 5.0
            }

        if transformers:
            block_settings = transformers
        else:
            block_settings = [
                {
                    "attention_heads": 16,
                    "feed_forward_multiplier": 8,
                    "dropout": 0.0,
                    "pre_normalize": True
                }
                for _ in range(2)
            ]

        # Core hyperparameters.
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.seq_len = seq_len
        self.dropout = dropout
        self.valid_padding = valid_padding

        # Nested sections.
        self.cosenet = head_settings
        self.transformers = block_settings

        # Aliases under the names the Transformers ecosystem expects.
        self.hidden_size = emb_dim
        self.max_position_embeddings = seq_len
108
+
109
+
110
class SentenceCoseNet(PreTrainedModel):
    """
    Hugging Face wrapper around the CoseNet segmentation network.

    The class owns a `SegmentationNetwork` instance (stored as
    ``self.model``) and exposes it through the `PreTrainedModel`
    interface so it can be used with the Transformers ecosystem.

    Intended uses:
        - Sentence embeddings
        - Semantic search
        - Information retrieval
        - Similarity learning
    """

    config_class = SentenceCoseNetConfig
    base_model_prefix = "cosenet"

    def __init__(self, config: SentenceCoseNetConfig):
        """
        Build the wrapped network from a Hugging Face configuration.

        Args:
            config:
                `SentenceCoseNetConfig` holding the model
                hyperparameters.
        """
        super().__init__(config)

        # Translate the HF config into the project's own config object
        # and build the core PyTorch network from it.
        internal_config = to_model_config(config)
        self.model = SegmentationNetwork(internal_config)

        # Standard HF hook that finalises weight initialisation.
        self.post_init()

    def encode(
        self,
        input_ids: torch.Tensor,
        attention_mask=None
    ) -> torch.Tensor:
        """
        Produce token-level contextualized embeddings.

        Token IDs are cast to int32, embedded, given positional
        information, and passed through every encoder block of the
        wrapped network in order.

        Args:
            input_ids:
                Tensor of token IDs, shape
                `(batch_size, sequence_length)`.
            attention_mask:
                Optional mask, shape `(batch_size, sequence_length)`,
                documented as valid (1) / padded (0). It is handed to
                each encoder block unchanged.
                NOTE(review): no cast or inversion happens here;
                confirm this matches the encoder's mask convention.

        Returns:
            torch.Tensor:
                Contextualized token embeddings, expected shape
                `(batch_size, sequence_length, emb_dim)`.
        """
        core = self.model

        # Integer cast -> embedding lookup -> positional encoding.
        hidden = core.embedding(input_ids.int())
        hidden = core.positional_encoding(hidden)

        # Run the encoder stack sequentially.
        for block in core.encoder_blocks:
            hidden = block(hidden, mask=attention_mask)
        return hidden

    def get_sentence_embedding(
        self,
        input_ids: torch.Tensor,
        attention_mask=None,
    ) -> torch.Tensor:
        """
        Compute one embedding per input sentence, for zero-shot
        transfer and information retrieval.

        Args:
            input_ids (torch.Tensor):
                Tensor of shape (B, T).
            attention_mask (torch.Tensor, optional):
                Boolean or binary mask of shape (B, T).

        Returns:
            torch.Tensor:
                Sentence embeddings, expected shape (B, D).
        """
        # Token-level representations, expected (B, T, D).
        contextual = self.encode(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # The wrapped network's pooling layer returns a pair; only the
        # pooled tensor is of interest here.
        sentence_vectors, _ = self.model.pooling(
            contextual,
            attention_mask
        )
        return sentence_vectors

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask=None,
        candidate_mask=None,
        **kwargs,
    ):
        """
        Run the wrapped `SegmentationNetwork`.

        The HF-style argument names are mapped onto the core model's
        own keyword names (`x`, `mask`, `candidate_mask`).

        Args:
            input_ids:
                Tensor of token IDs; passed to the core model as `x`.
            attention_mask:
                Optional attention mask; passed as `mask`.
            candidate_mask:
                Optional mask for candidate segments or spans.
            **kwargs:
                Extra keyword arguments forwarded to the core model.

        Returns:
            Whatever the wrapped `SegmentationNetwork` returns.
        """
        output = self.model(
            x=input_ids,
            mask=attention_mask,
            candidate_mask=candidate_mask,
            **kwargs,
        )
        return output
247
+
248
+
249
def to_model_config(self) -> ModelConfig:
    """
    Translate a Hugging Face configuration into the internal `ModelConfig`.

    Despite its name, the `self` parameter is the configuration object
    being converted: it is read for `vocab_size`, `emb_dim`,
    `valid_padding`, `cosenet` and `transformers`.

    NOTE(review): `seq_len` and the top-level `dropout` are not copied
    here, so `ModelConfig` keeps its own defaults for anything derived
    from them; confirm this is intended.

    Returns:
        ModelConfig: Freshly built internal configuration.
    """
    internal = ModelConfig()

    # Scalar hyperparameters; note the rename emb_dim -> model_dim.
    internal.vocab_size = self.vocab_size
    internal.model_dim = self.emb_dim
    internal.valid_padding = self.valid_padding

    # The nested dictionaries become typed config objects.
    internal.cosenet = CoSeNetConfig(**self.cosenet)
    internal.transformers = [
        TransformerConfig(**block_kwargs)
        for block_kwargs in self.transformers
    ]

    return internal
270
  # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
271
  # END OF FILE #
272
  # - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #