lhallee committed
Commit d434f74 · verified · 1 Parent(s): 11c0794

Upload modeling_e1.py with huggingface_hub

Files changed (1): modeling_e1.py +4 -4
modeling_e1.py CHANGED
@@ -1579,8 +1579,8 @@ class EmbeddingMixin:
         print(f"Embedding {len(to_embed)} new sequences")
         if len(to_embed) > 0:
             with torch.no_grad():
-                for i, batch in tqdm(enumerate(range(0, len(to_embed), batch_size)), desc='Embedding batches'):
-                    seqs = to_embed[i:i + batch_size]
+                for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
+                    seqs = to_embed[batch_start:batch_start + batch_size]
                     input_ids, attention_mask = self._embed(seqs, return_attention_mask=True)
                     embeddings = get_embeddings(input_ids, attention_mask).float() # sql requires float32
                     for seq, emb, mask in zip(seqs, embeddings, attention_mask):
@@ -1604,8 +1604,8 @@ class EmbeddingMixin:
 
         if len(to_embed) > 0:
             with torch.no_grad():
-                for i, batch in tqdm(enumerate(range(0, len(to_embed), batch_size)), desc='Embedding batches'):
-                    seqs = to_embed[i:i + batch_size]
+                for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
+                    seqs = to_embed[batch_start:batch_start + batch_size]
                     last_hidden_state, attention_mask = self._embed(seqs, return_attention_mask=True)
                     embeddings = get_embeddings(last_hidden_state, attention_mask).to(embed_dtype)
                     for seq, emb, mask in zip(seqs, embeddings, attention_mask):
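
Both hunks fix the same bug: enumerate() over the batch offsets yields a running count i = 0, 1, 2, ... (one per batch), and that count, rather than the actual offset, was used as the slice start. Successive batches therefore overlapped almost entirely, and the tail of to_embed was never embedded. Below is a minimal standalone sketch of the before/after behavior; plain strings stand in for sequences, list slicing stands in for the model's batching, and the names items and batch_size are illustrative, not taken from modeling_e1.py.

# Minimal sketch of the batching bug this commit fixes (runs on its own,
# no model required).
items = [f"seq{n}" for n in range(10)]
batch_size = 4

# Before: i counts batches (0, 1, 2, ...), so the slice start advances by
# one item per iteration instead of by batch_size. Consecutive batches
# overlap by batch_size - 1 items and the tail of the list is dropped.
buggy = [items[i:i + batch_size]
         for i, _ in enumerate(range(0, len(items), batch_size))]
# -> [items[0:4], items[1:5], items[2:6]]: overlapping windows, seq6-seq9 missing

# After: iterate over the batch start offsets directly, exactly as the
# fixed loop does.
fixed = [items[start:start + batch_size]
         for start in range(0, len(items), batch_size)]
# -> [items[0:4], items[4:8], items[8:10]]: disjoint batches covering everything

assert [s for batch in fixed for s in batch] == items        # each item once
assert "seq9" not in {s for batch in buggy for s in batch}   # tail was lost

With the fix, the slice start advances by batch_size per iteration, so every sequence is embedded exactly once, and the unused batch loop variable disappears as well.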