Upload modeling_e1.py with huggingface_hub
Browse files- modeling_e1.py +4 -4
modeling_e1.py
CHANGED
|
@@ -1579,8 +1579,8 @@ class EmbeddingMixin:
|
|
| 1579 |
print(f"Embedding {len(to_embed)} new sequences")
|
| 1580 |
if len(to_embed) > 0:
|
| 1581 |
with torch.no_grad():
|
| 1582 |
-
for
|
| 1583 |
-
seqs = to_embed[
|
| 1584 |
input_ids, attention_mask = self._embed(seqs, return_attention_mask=True)
|
| 1585 |
embeddings = get_embeddings(input_ids, attention_mask).float() # sql requires float32
|
| 1586 |
for seq, emb, mask in zip(seqs, embeddings, attention_mask):
|
|
@@ -1604,8 +1604,8 @@ class EmbeddingMixin:
|
|
| 1604 |
|
| 1605 |
if len(to_embed) > 0:
|
| 1606 |
with torch.no_grad():
|
| 1607 |
-
for
|
| 1608 |
-
seqs = to_embed[
|
| 1609 |
last_hidden_state, attention_mask = self._embed(seqs, return_attention_mask=True)
|
| 1610 |
embeddings = get_embeddings(last_hidden_state, attention_mask).to(embed_dtype)
|
| 1611 |
for seq, emb, mask in zip(seqs, embeddings, attention_mask):
|
|
|
|
| 1579 |
print(f"Embedding {len(to_embed)} new sequences")
|
| 1580 |
if len(to_embed) > 0:
|
| 1581 |
with torch.no_grad():
|
| 1582 |
+
for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
|
| 1583 |
+
seqs = to_embed[batch_start:batch_start + batch_size]
|
| 1584 |
input_ids, attention_mask = self._embed(seqs, return_attention_mask=True)
|
| 1585 |
embeddings = get_embeddings(input_ids, attention_mask).float() # sql requires float32
|
| 1586 |
for seq, emb, mask in zip(seqs, embeddings, attention_mask):
|
|
|
|
| 1604 |
|
| 1605 |
if len(to_embed) > 0:
|
| 1606 |
with torch.no_grad():
|
| 1607 |
+
for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
|
| 1608 |
+
seqs = to_embed[batch_start:batch_start + batch_size]
|
| 1609 |
last_hidden_state, attention_mask = self._embed(seqs, return_attention_mask=True)
|
| 1610 |
embeddings = get_embeddings(last_hidden_state, attention_mask).to(embed_dtype)
|
| 1611 |
for seq, emb, mask in zip(seqs, embeddings, attention_mask):
|