Update README.md
Browse files
README.md
CHANGED
|
@@ -155,15 +155,17 @@ from torch import Tensor
|
|
| 155 |
from transformers import AutoTokenizer, AutoModel
|
| 156 |
|
| 157 |
|
| 158 |
-
def
|
| 159 |
attention_mask: Tensor) -> Tensor:
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
|
|
|
|
|
|
| 167 |
|
| 168 |
|
| 169 |
def get_detailed_instruct(task_description: str, query: str) -> str:
|
|
@@ -195,7 +197,7 @@ batch_dict = tokenizer(
|
|
| 195 |
)
|
| 196 |
batch_dict.to(model.device)
|
| 197 |
outputs = model(**batch_dict)
|
| 198 |
-
embeddings =
|
| 199 |
|
| 200 |
embeddings = F.normalize(embeddings, p=2, dim=1)
|
| 201 |
scores = (embeddings[:2] @ embeddings[2:].T)
|
|
|
|
| 155 |
from transformers import AutoTokenizer, AutoModel
|
| 156 |
|
| 157 |
|
| 158 |
+
def mean_pool(last_hidden_states: Tensor,
|
| 159 |
attention_mask: Tensor) -> Tensor:
|
| 160 |
+
|
| 161 |
+
seq_lengths = attention_mask.sum(dim=-1)
|
| 162 |
+
return torch.stack(
|
| 163 |
+
[
|
| 164 |
+
last_hidden_states[i, -length:, :].sum(dim=0) / length
|
| 165 |
+
for i, length in enumerate(seq_lengths)
|
| 166 |
+
],
|
| 167 |
+
dim=0,
|
| 168 |
+
)
|
| 169 |
|
| 170 |
|
| 171 |
def get_detailed_instruct(task_description: str, query: str) -> str:
|
|
|
|
| 197 |
)
|
| 198 |
batch_dict.to(model.device)
|
| 199 |
outputs = model(**batch_dict)
|
| 200 |
+
embeddings = mean_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
|
| 201 |
|
| 202 |
embeddings = F.normalize(embeddings, p=2, dim=1)
|
| 203 |
scores = (embeddings[:2] @ embeddings[2:].T)
|