Upload modeling_fastesm.py with huggingface_hub
modeling_fastesm.py (+8 -4)
@@ -556,10 +556,6 @@ class FastEsmPreTrainedModel(PreTrainedModel):
         Returns:
             Dictionary mapping sequences to embeddings, or None if sql=True
         """
-        sequences = list(set([seq[:max_len] for seq in sequences]))
-        sequences = sorted(sequences, key=len, reverse=True)
-        dataset = ProteinDataset(sequences)
-        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
         device = self.device
 
         def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
@@ -570,6 +566,7 @@ class FastEsmPreTrainedModel(PreTrainedModel):
             else:
                 return residue_embeddings[:, 0, :]
 
+        sequences = list(set([seq[:max_len] for seq in sequences]))
         if sql:
             import sqlite3
             conn = sqlite3.connect(sql_db_path)
@@ -580,6 +577,9 @@ class FastEsmPreTrainedModel(PreTrainedModel):
             print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
             print(f"Embedding {len(to_embed)} new sequences")
             if len(to_embed) > 0:
+                to_embed = sorted(to_embed, key=len, reverse=True)
+                dataset = ProteinDataset(to_embed)
+                dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
                 with torch.no_grad():
                     for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
                         seqs = sequences[i * batch_size:(i + 1) * batch_size]
@@ -598,6 +598,10 @@ class FastEsmPreTrainedModel(PreTrainedModel):
             conn.close()
             return None
 
+        sequences = list(set([seq[:max_len] for seq in sequences]))
+        sequences = sorted(sequences, key=len, reverse=True)
+        dataset = ProteinDataset(sequences)
+        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
         embeddings_dict = {}
         with torch.no_grad():
             for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):