lhallee commited on
Commit
2237f88
·
verified ·
1 Parent(s): ce0e094

Upload modeling_e1.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_e1.py +2 -2
modeling_e1.py CHANGED
@@ -1581,8 +1581,8 @@ class EmbeddingMixin:
1581
  with torch.no_grad():
1582
  for i, batch in tqdm(enumerate(range(0, len(to_embed), batch_size)), desc='Embedding batches'):
1583
  seqs = to_embed[i:i + batch_size]
1584
- input_ids, attention_mask = self._embed(seqs, return_attention_mask=True).float() # sql requires float32
1585
- embeddings = get_embeddings(input_ids, attention_mask)
1586
  for seq, emb, mask in zip(seqs, embeddings, attention_mask):
1587
  if full_embeddings:
1588
  emb = emb[mask.bool()].reshape(-1, hidden_size)
 
1581
  with torch.no_grad():
1582
  for i, batch in tqdm(enumerate(range(0, len(to_embed), batch_size)), desc='Embedding batches'):
1583
  seqs = to_embed[i:i + batch_size]
1584
+ input_ids, attention_mask = self._embed(seqs, return_attention_mask=True)
1585
+ embeddings = get_embeddings(input_ids, attention_mask).float() # sql requires float32
1586
  for seq, emb, mask in zip(seqs, embeddings, attention_mask):
1587
  if full_embeddings:
1588
  emb = emb[mask.bool()].reshape(-1, hidden_size)