alverciito commited on
Commit
00f1b20
·
1 Parent(s): 6faa82b

fix docstrings

Browse files
Files changed (1) hide show
  1. model.py +2 -2
model.py CHANGED
@@ -232,14 +232,14 @@ class SentenceCoseNet(PreTrainedModel):
232
 
233
  Returns:
234
  torch.Tensor:
235
- Similarity scores of shape (B, S, S)
236
  """
237
  # Concatenate embeddings (B, S, 2, D)
238
  embeddings = torch.stack([embeddings_1, embeddings_2], dim=-2)
239
  # Compute distances (B, S, 2, 2):
240
  embeddings = self.model.distance_layer(embeddings)
241
  # Return cosine similarities (B, S):
242
- return embeddings[..., 0, 1]
243
 
244
  def forward(
245
  self,
 
232
 
233
  Returns:
234
  torch.Tensor:
235
+ Similarity scores of shape (B, S)
236
  """
237
  # Concatenate embeddings (B, S, 2, D)
238
  embeddings = torch.stack([embeddings_1, embeddings_2], dim=-2)
239
  # Compute distances (B, S, 2, 2):
240
  embeddings = self.model.distance_layer(embeddings)
241
  # Return cosine similarities (B, S):
242
+ return (embeddings[..., 0, 1] + embeddings[..., 1, 0]) / 2
243
 
244
  def forward(
245
  self,