alverciito committed on
Commit ·
0aeae6d
1
Parent(s): 8021b9c
zero shot experiment (fix v3)
Browse files
model.py
CHANGED
|
@@ -169,8 +169,8 @@ class SentenceCoseNet(PreTrainedModel):
|
|
| 169 |
`(batch_size, sequence_length, emb_dim)`.
|
| 170 |
"""
|
| 171 |
# Convert to type:
|
| 172 |
-
x = input_ids.int()
|
| 173 |
-
mask = attention_mask if attention_mask is not None else None
|
| 174 |
|
| 175 |
# Embedding and positional encoding:
|
| 176 |
x = self.model.embedding(x)
|
|
@@ -188,7 +188,7 @@ class SentenceCoseNet(PreTrainedModel):
|
|
| 188 |
|
| 189 |
# Reshape x and mask:
|
| 190 |
x = x.reshape(_b, _s, _t, _d)
|
| 191 |
-
return x.squeeze(
|
| 192 |
|
| 193 |
def get_sentence_embedding(
|
| 194 |
self,
|
|
@@ -213,8 +213,8 @@ class SentenceCoseNet(PreTrainedModel):
|
|
| 213 |
Sentence embeddings of shape (B, D)
|
| 214 |
"""
|
| 215 |
# Convert to type:
|
| 216 |
-
x = input_ids.int()
|
| 217 |
-
mask = attention_mask if attention_mask is not None else None
|
| 218 |
|
| 219 |
# Embedding and positional encoding:
|
| 220 |
x = self.model.embedding(x)
|
|
@@ -242,7 +242,7 @@ class SentenceCoseNet(PreTrainedModel):
|
|
| 242 |
# Apply normalization if required:
|
| 243 |
if normalize:
|
| 244 |
x = torch.nn.functional.normalize(x, p=2, dim=-1)
|
| 245 |
-
return x.squeeze(
|
| 246 |
|
| 247 |
def similarity(self, embeddings_1: torch.Tensor, embeddings_2: torch.Tensor) -> torch.Tensor:
|
| 248 |
"""
|
|
|
|
| 169 |
`(batch_size, sequence_length, emb_dim)`.
|
| 170 |
"""
|
| 171 |
# Convert to type:
|
| 172 |
+
x = input_ids.int().unsqueeze(1)
|
| 173 |
+
mask = attention_mask.unsqueeze(1) if attention_mask is not None else None
|
| 174 |
|
| 175 |
# Embedding and positional encoding:
|
| 176 |
x = self.model.embedding(x)
|
|
|
|
| 188 |
|
| 189 |
# Reshape x and mask:
|
| 190 |
x = x.reshape(_b, _s, _t, _d)
|
| 191 |
+
return x.squeeze(1)
|
| 192 |
|
| 193 |
def get_sentence_embedding(
|
| 194 |
self,
|
|
|
|
| 213 |
Sentence embeddings of shape (B, D)
|
| 214 |
"""
|
| 215 |
# Convert to type:
|
| 216 |
+
x = input_ids.int().unsqueeze(1)
|
| 217 |
+
mask = attention_mask.unsqueeze(1) if attention_mask is not None else None
|
| 218 |
|
| 219 |
# Embedding and positional encoding:
|
| 220 |
x = self.model.embedding(x)
|
|
|
|
| 242 |
# Apply normalization if required:
|
| 243 |
if normalize:
|
| 244 |
x = torch.nn.functional.normalize(x, p=2, dim=-1)
|
| 245 |
+
return x.squeeze(1)
|
| 246 |
|
| 247 |
def similarity(self, embeddings_1: torch.Tensor, embeddings_2: torch.Tensor) -> torch.Tensor:
|
| 248 |
"""
|