alverciito committed on
Commit
0aeae6d
·
1 Parent(s): 8021b9c

zero shot experiment (fix v3)

Browse files
Files changed (1) hide show
  1. model.py +6 -6
model.py CHANGED
@@ -169,8 +169,8 @@ class SentenceCoseNet(PreTrainedModel):
169
  `(batch_size, sequence_length, emb_dim)`.
170
  """
171
  # Convert to type:
172
- x = input_ids.int()
173
- mask = attention_mask if attention_mask is not None else None
174
 
175
  # Embedding and positional encoding:
176
  x = self.model.embedding(x)
@@ -188,7 +188,7 @@ class SentenceCoseNet(PreTrainedModel):
188
 
189
  # Reshape x and mask:
190
  x = x.reshape(_b, _s, _t, _d)
191
- return x.squeeze(0)
192
 
193
  def get_sentence_embedding(
194
  self,
@@ -213,8 +213,8 @@ class SentenceCoseNet(PreTrainedModel):
213
  Sentence embeddings of shape (B, D)
214
  """
215
  # Convert to type:
216
- x = input_ids.int()
217
- mask = attention_mask if attention_mask is not None else None
218
 
219
  # Embedding and positional encoding:
220
  x = self.model.embedding(x)
@@ -242,7 +242,7 @@ class SentenceCoseNet(PreTrainedModel):
242
  # Apply normalization if required:
243
  if normalize:
244
  x = torch.nn.functional.normalize(x, p=2, dim=-1)
245
- return x.squeeze(0)
246
 
247
  def similarity(self, embeddings_1: torch.Tensor, embeddings_2: torch.Tensor) -> torch.Tensor:
248
  """
 
169
  `(batch_size, sequence_length, emb_dim)`.
170
  """
171
  # Convert to type:
172
+ x = input_ids.int().unsqueeze(1)
173
+ mask = attention_mask.unsqueeze(1) if attention_mask is not None else None
174
 
175
  # Embedding and positional encoding:
176
  x = self.model.embedding(x)
 
188
 
189
  # Reshape x and mask:
190
  x = x.reshape(_b, _s, _t, _d)
191
+ return x.squeeze(1)
192
 
193
  def get_sentence_embedding(
194
  self,
 
213
  Sentence embeddings of shape (B, D)
214
  """
215
  # Convert to type:
216
+ x = input_ids.int().unsqueeze(1)
217
+ mask = attention_mask.unsqueeze(1) if attention_mask is not None else None
218
 
219
  # Embedding and positional encoding:
220
  x = self.model.embedding(x)
 
242
  # Apply normalization if required:
243
  if normalize:
244
  x = torch.nn.functional.normalize(x, p=2, dim=-1)
245
+ return x.squeeze(1)
246
 
247
  def similarity(self, embeddings_1: torch.Tensor, embeddings_2: torch.Tensor) -> torch.Tensor:
248
  """