alverciito
commited on
Commit
·
00f1b20
1
Parent(s):
6faa82b
fix docstrings
Browse files
model.py
CHANGED
|
@@ -232,14 +232,14 @@ class SentenceCoseNet(PreTrainedModel):
|
|
| 232 |
|
| 233 |
Returns:
|
| 234 |
torch.Tensor:
|
| 235 |
-
Similarity scores of shape (B, S
|
| 236 |
"""
|
| 237 |
# Concatenate embeddings (B, S, 2, D)
|
| 238 |
embeddings = torch.stack([embeddings_1, embeddings_2], dim=-2)
|
| 239 |
# Compute distances (B, S, 2, 2):
|
| 240 |
embeddings = self.model.distance_layer(embeddings)
|
| 241 |
# Return cosine similarities (B, S):
|
| 242 |
-
return embeddings[..., 0, 1]
|
| 243 |
|
| 244 |
def forward(
|
| 245 |
self,
|
|
|
|
| 232 |
|
| 233 |
Returns:
|
| 234 |
torch.Tensor:
|
| 235 |
+
Similarity scores of shape (B, S)
|
| 236 |
"""
|
| 237 |
# Concatenate embeddings (B, S, 2, D)
|
| 238 |
embeddings = torch.stack([embeddings_1, embeddings_2], dim=-2)
|
| 239 |
# Compute distances (B, S, 2, 2):
|
| 240 |
embeddings = self.model.distance_layer(embeddings)
|
| 241 |
# Return cosine similarities (B, S):
|
| 242 |
+
return (embeddings[..., 0, 1] + embeddings[..., 1, 0]) / 2
|
| 243 |
|
| 244 |
def forward(
|
| 245 |
self,
|