YuPeng0214 committed on
Commit
f1bf335
·
verified ·
1 Parent(s): 004af97

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +11 -9
README.md CHANGED
@@ -155,15 +155,17 @@ from torch import Tensor
155
  from transformers import AutoTokenizer, AutoModel
156
 
157
 
158
def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    """Select each sequence's final real-token embedding.

    When the batch is left-padded every sequence ends at the last position,
    so the whole last column can be taken at once; otherwise the last real
    token of each row sits at index (token count - 1).

    Args:
        last_hidden_states: token embeddings, shape (batch, seq_len, dim).
        attention_mask: 0/1 mask, shape (batch, seq_len); 1 marks real tokens.

    Returns:
        Tensor of shape (batch, dim) with one embedding per sequence.
    """
    batch = last_hidden_states.shape[0]
    # All rows end with a real token exactly when the batch is left-padded.
    is_left_padded = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if is_left_padded:
        return last_hidden_states[:, -1]
    last_idx = attention_mask.sum(dim=1) - 1
    rows = torch.arange(batch, device=last_hidden_states.device)
    return last_hidden_states[rows, last_idx]
 
 
167
 
168
 
169
  def get_detailed_instruct(task_description: str, query: str) -> str:
@@ -195,7 +197,7 @@ batch_dict = tokenizer(
195
  )
196
  batch_dict.to(model.device)
197
  outputs = model(**batch_dict)
198
- embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
199
 
200
  embeddings = F.normalize(embeddings, p=2, dim=1)
201
  scores = (embeddings[:2] @ embeddings[2:].T)
 
155
  from transformers import AutoTokenizer, AutoModel
156
 
157
 
158
def mean_pool(last_hidden_states: Tensor,
              attention_mask: Tensor) -> Tensor:
    """Mean-pool token embeddings over the valid (non-padding) positions.

    Args:
        last_hidden_states: token embeddings, shape (batch, seq_len, dim).
        attention_mask: 0/1 mask, shape (batch, seq_len); 1 marks real
            tokens, 0 marks padding (assumed binary — TODO confirm).

    Returns:
        Tensor of shape (batch, dim): the mean of each sequence's real
        token embeddings.
    """
    # Zero out padding positions, then divide by the true token count.
    # Unlike slicing the last `length` positions, this is correct for BOTH
    # left- and right-padded batches, and it is fully vectorized (no
    # per-row Python loop + torch.stack).
    mask = attention_mask.unsqueeze(-1).to(last_hidden_states.dtype)
    summed = (last_hidden_states * mask).sum(dim=1)
    # clamp guards an all-padding row against division by zero (yields 0s
    # instead of NaN/inf).
    counts = mask.sum(dim=1).clamp(min=1)
    return summed / counts
169
 
170
 
171
  def get_detailed_instruct(task_description: str, query: str) -> str:
 
197
  )
198
  batch_dict.to(model.device)
199
  outputs = model(**batch_dict)
200
+ embeddings = mean_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
201
 
202
  embeddings = F.normalize(embeddings, p=2, dim=1)
203
  scores = (embeddings[:2] @ embeddings[2:].T)