Upload modeling_e1.py with huggingface_hub
modeling_e1.py  CHANGED  (+2, -2)
@@ -1581,8 +1581,8 @@ class EmbeddingMixin:
         with torch.no_grad():
             for i, batch in tqdm(enumerate(range(0, len(to_embed), batch_size)), desc='Embedding batches'):
                 seqs = to_embed[i:i + batch_size]
-                input_ids, attention_mask = self._embed(seqs, return_attention_mask=True)
-                embeddings = get_embeddings(input_ids, attention_mask)
+                input_ids, attention_mask = self._embed(seqs, return_attention_mask=True)
+                embeddings = get_embeddings(input_ids, attention_mask).float()  # sql requires float32
                 for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                     if full_embeddings:
                         emb = emb[mask.bool()].reshape(-1, hidden_size)
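The added comment points at a real constraint: embeddings produced by a half-precision model (or under autocast) come back as float16 or bfloat16 tensors, while a SQL-backed store generally expects float32, so the tensor is cast before serialization. Below is a minimal sketch of the pitfall and the fix, assuming a sqlite3 BLOB table; the table layout, names, and dtypes are illustrative, not taken from modeling_e1.py.

# Illustrative sketch (not part of the commit): why the .float() cast matters
# before handing embeddings to a SQL layer. Table/column names are hypothetical.
import sqlite3

import numpy as np
import torch

# Stand-in for a batch of model embeddings in half precision.
embeddings = torch.randn(4, 8, dtype=torch.bfloat16)

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE embeddings (seq TEXT PRIMARY KEY, emb BLOB)")

# embeddings.numpy() would raise for bfloat16 (numpy has no bfloat16 dtype);
# .float() yields a float32 tensor that converts cleanly and round-trips
# through np.frombuffer on read-back.
emb32 = embeddings.float().cpu().numpy()  # shape (4, 8), dtype float32
for i, row in enumerate(emb32):
    con.execute("INSERT INTO embeddings VALUES (?, ?)", (f"seq{i}", row.tobytes()))

# Read one row back and verify the bytes decode to the same float32 vector.
blob = con.execute("SELECT emb FROM embeddings WHERE seq = 'seq0'").fetchone()[0]
restored = np.frombuffer(blob, dtype=np.float32)
assert np.allclose(restored, emb32[0])

Casting on the torch side, before any .numpy() conversion, is the simpler choice here: it fixes the dtype once at the source, so every downstream consumer sees plain float32 regardless of what precision the model ran in.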