Update modeling_fastesm.py
Browse files- modeling_fastesm.py +129 -13
modeling_fastesm.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
from torch.nn import functional as F
|
|
|
|
| 4 |
from typing import Optional, Tuple, Union
|
| 5 |
from einops import rearrange
|
| 6 |
from transformers import PreTrainedModel, PretrainedConfig
|
|
@@ -20,11 +21,11 @@ from transformers.models.esm.modeling_esm import (
|
|
| 20 |
EsmClassificationHead,
|
| 21 |
create_position_ids_from_input_ids,
|
| 22 |
)
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
class FastEsmConfig(PretrainedConfig):
|
| 26 |
model_type = "fast_esm"
|
| 27 |
-
|
| 28 |
def __init__(
|
| 29 |
self,
|
| 30 |
vocab_size=None,
|
|
@@ -141,14 +142,6 @@ class EsmEmbeddings(nn.Module):
|
|
| 141 |
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
| 142 |
)
|
| 143 |
|
| 144 |
-
self.padding_idx = config.pad_token_id
|
| 145 |
-
self.position_embeddings = nn.Embedding(
|
| 146 |
-
config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
|
| 147 |
-
)
|
| 148 |
-
# Token dropout does not work correctly so we disable it
|
| 149 |
-
# self.token_dropout = config.token_dropout
|
| 150 |
-
self.mask_token_id = config.mask_token_id
|
| 151 |
-
|
| 152 |
def forward(
|
| 153 |
self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
| 154 |
):
|
|
@@ -164,10 +157,6 @@ class EsmEmbeddings(nn.Module):
|
|
| 164 |
|
| 165 |
embeddings = inputs_embeds
|
| 166 |
|
| 167 |
-
if self.position_embedding_type == "absolute":
|
| 168 |
-
position_embeddings = self.position_embeddings(position_ids)
|
| 169 |
-
embeddings = embeddings + position_embeddings
|
| 170 |
-
|
| 171 |
if self.layer_norm is not None:
|
| 172 |
embeddings = self.layer_norm(embeddings)
|
| 173 |
if attention_mask is not None:
|
|
@@ -336,6 +325,19 @@ class EsmEncoder(nn.Module):
|
|
| 336 |
)
|
| 337 |
|
| 338 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
class FastEsmPreTrainedModel(PreTrainedModel):
|
| 340 |
"""
|
| 341 |
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
|
@@ -364,6 +366,120 @@ class FastEsmPreTrainedModel(PreTrainedModel):
|
|
| 364 |
except AttributeError:
|
| 365 |
return self.esm.embeddings.word_embeddings
|
| 366 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
|
| 368 |
class FastEsmModel(FastEsmPreTrainedModel):
|
| 369 |
def __init__(self, config, add_pooling_layer=True):
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
from torch.nn import functional as F
|
| 4 |
+
from torch.utils.data import Dataset, DataLoader
|
| 5 |
from typing import Optional, Tuple, Union
|
| 6 |
from einops import rearrange
|
| 7 |
from transformers import PreTrainedModel, PretrainedConfig
|
|
|
|
| 21 |
EsmClassificationHead,
|
| 22 |
create_position_ids_from_input_ids,
|
| 23 |
)
|
| 24 |
+
from tqdm.auto import tqdm
|
| 25 |
|
| 26 |
|
| 27 |
class FastEsmConfig(PretrainedConfig):
|
| 28 |
model_type = "fast_esm"
|
|
|
|
| 29 |
def __init__(
|
| 30 |
self,
|
| 31 |
vocab_size=None,
|
|
|
|
| 142 |
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
| 143 |
)
|
| 144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
def forward(
|
| 146 |
self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
| 147 |
):
|
|
|
|
| 157 |
|
| 158 |
embeddings = inputs_embeds
|
| 159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
if self.layer_norm is not None:
|
| 161 |
embeddings = self.layer_norm(embeddings)
|
| 162 |
if attention_mask is not None:
|
|
|
|
| 325 |
)
|
| 326 |
|
| 327 |
|
| 328 |
+
### Dataset for Embedding
class ProteinDataset(Dataset):
    """Minimal map-style dataset over raw protein sequence strings.

    Sequences are stored untouched; tokenization/padding is deferred to the
    DataLoader's collate_fn so batches can be padded to their longest member.
    """

    def __init__(self, sequences: list[str]):
        self.sequences = sequences

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> str:
        return self.sequences[idx]
|
| 339 |
+
|
| 340 |
+
|
| 341 |
class FastEsmPreTrainedModel(PreTrainedModel):
|
| 342 |
"""
|
| 343 |
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
|
|
|
| 366 |
except AttributeError:
|
| 367 |
return self.esm.embeddings.word_embeddings
|
| 368 |
|
| 369 |
+
@property
def device(self) -> torch.device:
    """Device on which the model's parameters currently reside."""
    # Any parameter will do; they are assumed to live on a single device.
    first_param = next(self.parameters())
    return first_param.device
|
| 373 |
+
|
| 374 |
+
def mean_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Average token embeddings over the sequence dimension.

    When an attention mask is given, padded positions are zeroed out before
    summing and the sum is divided by the number of unmasked tokens per
    sequence; otherwise a plain mean over dim 1 is taken.
    """
    if attention_mask is None:
        return x.mean(dim=1)
    mask = attention_mask.unsqueeze(-1)
    summed = (x * mask).sum(dim=1)
    return summed / mask.sum(dim=1)
|
| 381 |
+
|
| 382 |
+
def _collate_fn(self, sequences: list[str]):
    """Tokenize a batch of sequences, padding to the longest (in multiples of 8).

    Returns the tokenizer's encoding — a mapping holding 'input_ids' and
    'attention_mask' tensors (callers index it like a dict). The previous
    annotation claimed a tuple of tensors, which was incorrect.
    NOTE(review): assumes self.tokenizer is a HuggingFace-style tokenizer
    supporting return_tensors/padding/pad_to_multiple_of — confirm.
    """
    return self.tokenizer(sequences, return_tensors="pt", padding='longest', pad_to_multiple_of=8)
|
| 385 |
+
|
| 386 |
+
def _read_sequences_from_db(self, db_path: str) -> set[str]:
    """Collect every sequence already stored in the SQLite 'embeddings' table."""
    import sqlite3
    with sqlite3.connect(db_path) as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT sequence FROM embeddings")
        # Cursors are iterable: each row is a 1-tuple (sequence,).
        return {row[0] for row in cursor}
|
| 399 |
+
|
| 400 |
+
def embed_dataset(
    self,
    sequences: list[str],
    batch_size: int = 2,
    max_len: int = 512,
    full_embeddings: bool = False,
    full_precision: bool = False,
    pooling_type: str = 'mean',
    num_workers: int = 0,
    sql: bool = False,
    sql_db_path: str = 'embeddings.db',
) -> Optional[dict[str, torch.Tensor]]:
    """Embed a dataset of protein sequences.

    Args:
        sequences: List of protein sequences
        batch_size: Batch size for processing
        max_len: Maximum sequence length (sequences are truncated to this)
        full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
        full_precision: Whether to cast to full precision (float32) before storage - relevant for dict storage
        pooling_type: Type of pooling ('mean' or 'cls')
        num_workers: Number of workers for data loading, 0 for the main process
        sql: Whether to store embeddings in SQLite database - will be stored in float32
        sql_db_path: Path to SQLite database

    Returns:
        Dictionary mapping sequences to embeddings, or None if sql=True
    """
    # Deduplicate + truncate, then sort longest-first so per-batch padding is minimized.
    sequences = list(set([seq[:max_len] for seq in sequences]))
    sequences = sorted(sequences, key=len, reverse=True)
    device = self.device

    def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Dispatch on the embedding mode: residue-wise, masked mean, or CLS token.
        if full_embeddings:
            return residue_embeddings
        elif pooling_type == 'mean':
            return self.mean_pooling(residue_embeddings, attention_mask)
        else:
            return residue_embeddings[:, 0, :]

    if sql:
        import sqlite3
        conn = sqlite3.connect(sql_db_path)
        c = conn.cursor()
        c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')
        already_embedded = self._read_sequences_from_db(sql_db_path)
        to_embed = [seq for seq in sequences if seq not in already_embedded]
        print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
        print(f"Embedding {len(to_embed)} new sequences")
        # BUG FIX: build the dataloader from the NOT-yet-embedded sequences.
        # Previously it was built from all `sequences`, so `to_embed` was
        # computed but ignored and everything was re-embedded every run.
        dataloader = DataLoader(ProteinDataset(to_embed), batch_size=batch_size,
                                num_workers=num_workers, collate_fn=self._collate_fn)

        with torch.no_grad():
            for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
                # DataLoader preserves order (no shuffling), so this slice
                # matches the sequences in the current batch.
                seqs = to_embed[i * batch_size:(i + 1) * batch_size]
                input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
                residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].float()  # required for sql
                embeddings = get_embeddings(residue_embeddings, attention_mask)

                for seq, emb in zip(seqs, embeddings):
                    c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
                            (seq, emb.cpu().numpy().tobytes()))

                # Commit periodically so a crash loses at most ~100 batches.
                if (i + 1) % 100 == 0:
                    conn.commit()

        conn.commit()
        conn.close()
        return None

    dataloader = DataLoader(ProteinDataset(sequences), batch_size=batch_size,
                            num_workers=num_workers, collate_fn=self._collate_fn)
    embeddings_dict = {}
    with torch.no_grad():
        for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
            seqs = sequences[i * batch_size:(i + 1) * batch_size]
            input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
            residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1]
            # BUG FIX: only upcast when requested. The previous code called
            # .float() unconditionally, which made the full_precision flag a
            # no-op and silently doubled memory for half-precision models.
            if full_precision:
                residue_embeddings = residue_embeddings.float()
            embeddings = get_embeddings(residue_embeddings, attention_mask)
            for seq, emb in zip(seqs, embeddings):
                embeddings_dict[seq] = emb

    return embeddings_dict
|
| 483 |
|
| 484 |
class FastEsmModel(FastEsmPreTrainedModel):
|
| 485 |
def __init__(self, config, add_pooling_layer=True):
|