Upload modeling_esm_plusplus.py with huggingface_hub
modeling_esm_plusplus.py  +3 -3
modeling_esm_plusplus.py CHANGED

@@ -711,7 +711,7 @@ class EmbeddingMixin:
             seqs = to_embed[i * batch_size:(i + 1) * batch_size]
             input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
             residue_embeddings = self._embed(input_ids, attention_mask).float() # sql requires float32
-            embeddings = get_embeddings(residue_embeddings, attention_mask)
+            embeddings = get_embeddings(residue_embeddings, attention_mask)
             for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                 if full_embeddings:
                     emb = emb[mask.bool()].reshape(-1, hidden_size)
@@ -743,11 +743,11 @@ class EmbeddingMixin:
             seqs = to_embed[i * batch_size:(i + 1) * batch_size]
             input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
             residue_embeddings = self._embed(input_ids, attention_mask)
-            embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
+            embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
             for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                 if full_embeddings:
                     emb = emb[mask.bool()].reshape(-1, hidden_size)
-                embeddings_dict[seq] = emb
+                embeddings_dict[seq] = emb.cpu()
 
         if save:
             torch.save(embeddings_dict, save_path)
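The substantive change is in the second hunk: pooled embeddings are now moved to the CPU (`emb.cpu()`) before being stored in `embeddings_dict`. Below is a minimal sketch of why that matters, using a stand-in embedder and tokenizer rather than this repo's actual `_embed`/`get_embeddings` API; every name in it is illustrative.

import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
embedder = nn.Embedding(33, 16).to(device)  # stand-in for the model's _embed call

sequences = ['MKTAYIAKQR', 'GAVLILLWF']
embeddings_dict = {}
with torch.no_grad():
    for seq in sequences:
        # Stand-in tokenization: random ids, one token per residue.
        input_ids = torch.randint(0, 33, (1, len(seq)), device=device)
        residue_embeddings = embedder(input_ids)         # (1, seq_len, hidden)
        emb = residue_embeddings.squeeze(0).mean(dim=0)  # stand-in mean pooling
        # Store on CPU: without .cpu(), each dict entry keeps a live CUDA
        # tensor, so device memory grows with every sequence embedded, and a
        # checkpoint saved from the dict needs map_location to load on CPU.
        embeddings_dict[seq] = emb.cpu()

torch.save(embeddings_dict, 'embeddings.pt')

The `.float()` cast in the first hunk (with its `# sql requires float32` comment) addresses the analogous storage concern: half-precision model outputs are presumably upcast so the file's SQL-backed store receives float32 arrays.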