alverciito committed on
Commit
34f99b8
·
1 Parent(s): 8068e2f

fix huggingface model mismatch

Browse files
Files changed (2) hide show
  1. model.py +43 -50
  2. src/model/segmentation.py +15 -1
model.py CHANGED
@@ -168,27 +168,20 @@ class SentenceCoseNet(PreTrainedModel):
168
  Contextualized token embeddings with shape
169
  `(batch_size, sequence_length, emb_dim)`.
170
  """
 
 
171
  # Convert to type:
172
- x = input_ids.int().unsqueeze(1)
173
- mask = attention_mask.unsqueeze(1) if attention_mask is not None else None
174
-
175
- # Embedding and positional encoding:
176
- x = self.model.embedding(x)
177
- x = self.model.positional_encoding(x)
178
-
179
- # Reshape x and mask:
180
- _b, _s, _t, _d = x.shape
181
- x = x.reshape(_b * _s, _t, _d)
182
- if mask is not None:
183
- mask = mask.reshape(_b * _s, _t).bool()
184
-
185
- # Encode the sequence:
186
- for encoder in self.model.encoder_blocks:
187
- x = encoder(x, mask=mask)
188
-
189
- # Reshape x and mask:
190
- x = x.reshape(_b, _s, _t, _d)
191
- return x.squeeze(1)
192
 
193
  def get_sentence_embedding(
194
  self,
@@ -212,37 +205,14 @@ class SentenceCoseNet(PreTrainedModel):
212
  torch.Tensor:
213
  Sentence embeddings of shape (B, D)
214
  """
215
- # Convert to type:
216
- x = input_ids.int().unsqueeze(1)
217
- mask = attention_mask.unsqueeze(1) if attention_mask is not None else None
218
-
219
- # Embedding and positional encoding:
220
- x = self.model.embedding(x)
221
- x = self.model.positional_encoding(x)
222
 
223
- # Reshape x and mask:
224
- _b, _s, _t, _d = x.shape
225
- x = x.reshape(_b * _s, _t, _d)
226
- if mask is not None:
227
- mask = mask.reshape(_b * _s, _t).bool()
228
-
229
- # Encode the sequence:
230
- for encoder in self.model.encoder_blocks:
231
- x = encoder(x, mask=mask)
232
-
233
- # Reshape x and mask:
234
- x = x.reshape(_b, _s, _t, _d)
235
- if mask is not None:
236
- mask = mask.reshape(_b, _s, _t)
237
- mask = torch.logical_not(mask) if not self.model.valid_padding else mask
238
-
239
- # Apply pooling:
240
- x, mask = self.model.pooling(x, mask=mask)
241
-
242
- # Apply normalization if required:
243
  if normalize:
244
- x = torch.nn.functional.normalize(x, p=2, dim=-1)
245
- return x.squeeze(1)
 
246
 
247
  def similarity(self, embeddings_1: torch.Tensor, embeddings_2: torch.Tensor) -> torch.Tensor:
248
  """
@@ -268,7 +238,6 @@ class SentenceCoseNet(PreTrainedModel):
268
  # Return cosine similarities (B, S):
269
  return embeddings[..., 0, 1]
270
 
271
-
272
  def forward(
273
  self,
274
  input_ids: torch.Tensor,
@@ -296,6 +265,7 @@ class SentenceCoseNet(PreTrainedModel):
296
  Returns:
297
  Model-specific output as produced by `SegmentationNetwork`.
298
  """
 
299
  return self.model(
300
  x=input_ids,
301
  mask=attention_mask,
@@ -303,6 +273,29 @@ class SentenceCoseNet(PreTrainedModel):
303
  **kwargs,
304
  )
305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  @staticmethod
307
  def to_model_config(config: SentenceCoseNetConfig) -> ModelConfig:
308
  """
 
168
  Contextualized token embeddings with shape
169
  `(batch_size, sequence_length, emb_dim)`.
170
  """
171
+ # Set the model task:
172
+ self.model.task = 'token_encoding'
173
  # Convert to type:
174
+ if len(input_ids.shape) == 2:
175
+ x = input_ids.int().unsqueeze(1)
176
+ mask = attention_mask.unsqueeze(1) if attention_mask is not None else None
177
+ output = self.model(x=x, mask=mask).squeeze(1)
178
+ elif len(input_ids.shape) == 3:
179
+ x = input_ids.int()
180
+ mask = attention_mask if attention_mask is not None else None
181
+ output = self.model(x=x, mask=mask)
182
+ else:
183
+ raise ValueError("Input tensor must be of shape (Batch, Tokens) or (Batch, Sentences, Tokens).")
184
+ return output
 
 
 
 
 
 
 
 
 
185
 
186
  def get_sentence_embedding(
187
  self,
 
205
  torch.Tensor:
206
  Sentence embeddings of shape (B, D)
207
  """
208
+ # Set the model task:
209
+ self.model.task = 'sentence_encoding'
210
+ output = self.call(input_ids, attention_mask)
 
 
 
 
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  if normalize:
213
+ output = torch.nn.functional.normalize(output, p=2, dim=-1)
214
+
215
+ return output
216
 
217
  def similarity(self, embeddings_1: torch.Tensor, embeddings_2: torch.Tensor) -> torch.Tensor:
218
  """
 
238
  # Return cosine similarities (B, S):
239
  return embeddings[..., 0, 1]
240
 
 
241
  def forward(
242
  self,
243
  input_ids: torch.Tensor,
 
265
  Returns:
266
  Model-specific output as produced by `SegmentationNetwork`.
267
  """
268
+ self.model.task = 'segmentation'
269
  return self.model(
270
  x=input_ids,
271
  mask=attention_mask,
 
273
  **kwargs,
274
  )
275
 
276
+ def call(self, input_ids: torch.Tensor, attention_mask=None) -> torch.Tensor:
277
+ """
278
+ Internal method to handle different input shapes (task already selected).
279
+ Args:
280
+ input_ids:
281
+ Tensor of token IDs with shape
282
+ `(batch_size, sequence_length)`.
283
+ attention_mask:
284
+ Optional attention mask tensor.
285
+ """
286
+ # Convert to type:
287
+ if len(input_ids.shape) == 2:
288
+ x = input_ids.int().unsqueeze(1)
289
+ mask = attention_mask.unsqueeze(1) if attention_mask is not None else None
290
+ output = self.model(x=x, mask=mask).squeeze(1)
291
+ elif len(input_ids.shape) == 3:
292
+ x = input_ids.int()
293
+ mask = attention_mask if attention_mask is not None else None
294
+ output = self.model(x=x, mask=mask)
295
+ else:
296
+ raise ValueError("Input tensor must be of shape (Batch, Tokens) or (Batch, Sentences, Tokens).")
297
+ return output
298
+
299
  @staticmethod
300
  def to_model_config(config: SentenceCoseNetConfig) -> ModelConfig:
301
  """
src/model/segmentation.py CHANGED
@@ -24,7 +24,7 @@ class SegmentationNetwork(torch.nn.Module):
24
  The final output is a pair-wise distance matrix suitable for
25
  segmentation or boundary detection tasks.
26
  """
27
- def __init__(self, model_config: ModelConfig, **kwargs):
28
  """
29
  Initialize the segmentation network.
30
 
@@ -73,6 +73,11 @@ class SegmentationNetwork(torch.nn.Module):
73
  module_list.append(encoder_block)
74
 
75
  self.encoder_blocks = torch.nn.ModuleList(module_list)
 
 
 
 
 
76
 
77
  def forward(self, x: torch.Tensor, mask: torch.Tensor = None, candidate_mask: torch.Tensor = None) -> torch.Tensor:
78
  """
@@ -126,12 +131,21 @@ class SegmentationNetwork(torch.nn.Module):
126
  mask = mask.reshape(_b, _s, _t)
127
  mask = torch.logical_not(mask) if not self.valid_padding else mask
128
 
 
 
 
129
  # Apply pooling:
130
  x, mask = self.pooling(x, mask=mask)
131
 
 
 
 
132
  # Compute distances:
133
  x = self.distance_layer(x)
134
 
 
 
 
135
  # Pass through CoSeNet:
136
  x = self.cosenet(x, mask=mask)
137
 
 
24
  The final output is a pair-wise distance matrix suitable for
25
  segmentation or boundary detection tasks.
26
  """
27
+ def __init__(self, model_config: ModelConfig, task='segmentation', **kwargs):
28
  """
29
  Initialize the segmentation network.
30
 
 
73
  module_list.append(encoder_block)
74
 
75
  self.encoder_blocks = torch.nn.ModuleList(module_list)
76
+ self.task = task
77
+ if self.task not in ['segmentation', 'similarity', 'token_encoding', 'sentence_encoding']:
78
+ raise ValueError(f"Invalid task '{self.task}'. Supported tasks are 'segmentation', 'similarity', "
79
+ f"'token_encoding', and 'sentence_encoding'.")
80
+
81
 
82
  def forward(self, x: torch.Tensor, mask: torch.Tensor = None, candidate_mask: torch.Tensor = None) -> torch.Tensor:
83
  """
 
131
  mask = mask.reshape(_b, _s, _t)
132
  mask = torch.logical_not(mask) if not self.valid_padding else mask
133
 
134
+ if self.task == 'token_encoding':
135
+ return x
136
+
137
  # Apply pooling:
138
  x, mask = self.pooling(x, mask=mask)
139
 
140
+ if self.task == 'sentence_encoding':
141
+ return x
142
+
143
  # Compute distances:
144
  x = self.distance_layer(x)
145
 
146
+ if self.task == 'similarity':
147
+ return x
148
+
149
  # Pass through CoSeNet:
150
  x = self.cosenet(x, mask=mask)
151