diff --git "a/modeling_dplm.py" "b/modeling_dplm.py" --- "a/modeling_dplm.py" +++ "b/modeling_dplm.py" @@ -1,1168 +1,1180 @@ -### Embedding Mixin + Pooler -import os -import sqlite3 -import networkx as nx -import numpy as np -import torch -from tqdm.auto import tqdm -from typing import Callable, List, Optional -from torch.utils.data import DataLoader -from torch.utils.data import Dataset as TorchDataset -from transformers import PreTrainedTokenizerBase - - -class Pooler: - def __init__(self, pooling_types: List[str]): - self.pooling_types = pooling_types - self.pooling_options = { - 'mean': self.mean_pooling, - 'max': self.max_pooling, - 'norm': self.norm_pooling, - 'median': self.median_pooling, - 'std': self.std_pooling, - 'var': self.var_pooling, - 'cls': self.cls_pooling, - 'parti': self._pool_parti, - } - - def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor: - maxed_attentions = torch.max(attentions, dim=1)[0] - return maxed_attentions - - def _page_rank(self, attention_matrix, personalization=None, nstart=None, prune_type="top_k_outdegree"): - # Run PageRank on the attention matrix converted to a graph. - # Raises exceptions if the graph doesn't match the token sequence or has no edges. - # Returns the PageRank scores for each token node. - G = self._convert_to_graph(attention_matrix) - if G.number_of_nodes() != attention_matrix.shape[0]: - raise Exception( - f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.") - if G.number_of_edges() == 0: - raise Exception(f"You don't seem to have any attention edges left in the graph.") - - return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100) - - def _convert_to_graph(self, matrix): - # Convert a matrix (e.g., attention scores) to a directed graph using networkx. 
- # Each element in the matrix represents a directed edge with a weight. - G = nx.from_numpy_array(matrix, create_using=nx.DiGraph) - return G - - def _calculate_importance_weights(self, dict_importance, attention_mask: Optional[torch.Tensor] = None): - # Remove keys where attention_mask is 0 - if attention_mask is not None: - for k in list(dict_importance.keys()): - if attention_mask[k] == 0: - del dict_importance[k] - - #dict_importance[0] # remove cls - #dict_importance[-1] # remove eos - total = sum(dict_importance.values()) - return np.array([v / total for _, v in dict_importance.items()]) - - def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d) - maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy() - # emb is (b, L, d), maxed_attentions is (b, L, L) - emb_pooled = [] - for e, a, mask in zip(emb, maxed_attentions, attention_mask): - dict_importance = self._page_rank(a) - importance_weights = self._calculate_importance_weights(dict_importance, mask) - num_tokens = int(mask.sum().item()) - emb_pooled.append(np.average(e[:num_tokens], weights=importance_weights, axis=0)) - pooled = torch.tensor(np.array(emb_pooled)) - return pooled - - def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) - if attention_mask is None: - return emb.mean(dim=1) - else: - attention_mask = attention_mask.unsqueeze(-1) - return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) - - def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) - if attention_mask is None: - return emb.max(dim=1).values - else: - attention_mask = attention_mask.unsqueeze(-1) - return (emb * attention_mask).max(dim=1).values - - def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) - if 
attention_mask is None: - return emb.norm(dim=1, p=2) - else: - attention_mask = attention_mask.unsqueeze(-1) - return (emb * attention_mask).norm(dim=1, p=2) - - def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) - if attention_mask is None: - return emb.median(dim=1).values - else: - attention_mask = attention_mask.unsqueeze(-1) - return (emb * attention_mask).median(dim=1).values - - def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) - if attention_mask is None: - return emb.std(dim=1) - else: - # Compute variance correctly over non-masked positions, then take sqrt - var = self.var_pooling(emb, attention_mask, **kwargs) - return torch.sqrt(var) - - def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) - if attention_mask is None: - return emb.var(dim=1) - else: - # Correctly compute variance over only non-masked positions - attention_mask = attention_mask.unsqueeze(-1) # (b, L, 1) - # Compute mean over non-masked positions - mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d) - mean = mean.unsqueeze(1) # (b, 1, d) - # Compute squared differences from mean, only over non-masked positions - squared_diff = (emb - mean) ** 2 # (b, L, d) - # Sum squared differences over non-masked positions and divide by count - var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d) - return var - - def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) - return emb[:, 0, :] - - def __call__( - self, - emb: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - attentions: Optional[torch.Tensor] = None - ): # [mean, max] - final_emb = [] - for pooling_type in self.pooling_types: - final_emb.append(self.pooling_options[pooling_type](emb=emb, 
attention_mask=attention_mask, attentions=attentions)) # (b, d) - return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d) - - -class ProteinDataset(TorchDataset): - """Simple dataset for protein sequences.""" - def __init__(self, sequences: list[str]): - self.sequences = sequences - - def __len__(self) -> int: - return len(self.sequences) - - def __getitem__(self, idx: int) -> str: - return self.sequences[idx] - - -def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[list[str]], dict[str, torch.Tensor]]: - def _collate_fn(sequences: list[str]) -> dict[str, torch.Tensor]: - return tokenizer(sequences, return_tensors="pt", padding='longest') - return _collate_fn - - -class EmbeddingMixin: - def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - raise NotImplementedError - - @property - def device(self) -> torch.device: - """Get the device of the model.""" - return next(self.parameters()).device - - def _read_sequences_from_db(self, db_path: str) -> set[str]: - """Read sequences from SQLite database.""" - sequences = [] - with sqlite3.connect(db_path) as conn: - c = conn.cursor() - c.execute("SELECT sequence FROM embeddings") - while True: - row = c.fetchone() - if row is None: - break - sequences.append(row[0]) - return set(sequences) - - def _ensure_embeddings_table(self, conn: sqlite3.Connection) -> None: - cursor = conn.cursor() - cursor.execute( - "CREATE TABLE IF NOT EXISTS embeddings (" - "sequence TEXT PRIMARY KEY, " - "embedding BLOB NOT NULL, " - "shape TEXT, " - "dtype TEXT" - ")" - ) - cursor.execute("PRAGMA table_info(embeddings)") - rows = cursor.fetchall() - column_names = [row[1] for row in rows] - if "shape" not in column_names: - cursor.execute("ALTER TABLE embeddings ADD COLUMN shape TEXT") - if "dtype" not in column_names: - cursor.execute("ALTER TABLE embeddings ADD COLUMN dtype TEXT") - conn.commit() - - def load_embeddings_from_pth(self, save_path: str) -> dict[str, 
torch.Tensor]: - assert os.path.exists(save_path), f"Embedding file does not exist: {save_path}" - payload = torch.load(save_path, map_location="cpu", weights_only=True) - assert isinstance(payload, dict), "Expected .pth embeddings file to contain a dictionary." - for sequence, tensor in payload.items(): - assert isinstance(sequence, str), "Expected embedding dictionary keys to be sequences (str)." - assert isinstance(tensor, torch.Tensor), "Expected embedding dictionary values to be tensors." - return payload - - def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> dict[str, torch.Tensor]: - assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}" - loaded: dict[str, torch.Tensor] = {} - with sqlite3.connect(db_path) as conn: - self._ensure_embeddings_table(conn) - cursor = conn.cursor() - if sequences is None: - cursor.execute("SELECT sequence, embedding, shape, dtype FROM embeddings") - else: - if len(sequences) == 0: - return loaded - placeholders = ",".join(["?"] * len(sequences)) - cursor.execute( - f"SELECT sequence, embedding, shape, dtype FROM embeddings WHERE sequence IN ({placeholders})", - tuple(sequences), - ) - - rows = cursor.fetchall() - for row in rows: - sequence = row[0] - embedding_bytes = row[1] - shape_text = row[2] - dtype_text = row[3] - assert shape_text is not None, "Missing shape metadata in embeddings table." - assert dtype_text is not None, "Missing dtype metadata in embeddings table." 
- shape_values = [int(value) for value in shape_text.split(",") if len(value) > 0] - assert len(shape_values) > 0, f"Invalid shape metadata for sequence: {sequence}" - expected_size = int(np.prod(shape_values)) - np_dtype = np.dtype(dtype_text) - array = np.frombuffer(embedding_bytes, dtype=np_dtype) - assert array.size == expected_size, f"Shape mismatch while reading sequence: {sequence}" - reshaped = array.copy().reshape(tuple(shape_values)) - loaded[sequence] = torch.from_numpy(reshaped) - return loaded - - def embed_dataset( - self, - sequences: List[str], - tokenizer: Optional[PreTrainedTokenizerBase] = None, - batch_size: int = 2, - max_len: int = 512, - truncate: bool = True, - full_embeddings: bool = False, - embed_dtype: torch.dtype = torch.float32, - pooling_types: List[str] = ['mean'], - num_workers: int = 0, - sql: bool = False, - save: bool = True, - sql_db_path: str = 'embeddings.db', - save_path: str = 'embeddings.pth', - **kwargs, - ) -> Optional[dict[str, torch.Tensor]]: - """ - Embed a dataset of protein sequences. - - Supports two modes: - - Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used. - - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used. 
- """ - sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences])) - sequences = sorted(sequences, key=len, reverse=True) - hidden_size = self.config.hidden_size - pooler = Pooler(pooling_types) if not full_embeddings else None - tokenizer_mode = tokenizer is not None - if tokenizer_mode: - collate_fn = build_collator(tokenizer) - device = self.device - else: - collate_fn = None - device = None - - def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - if full_embeddings or residue_embeddings.ndim == 2: - return residue_embeddings - return pooler(residue_embeddings, attention_mask) - - def iter_batches(to_embed: List[str]): - if tokenizer_mode: - assert collate_fn is not None - assert device is not None - dataset = ProteinDataset(to_embed) - dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, shuffle=False) - for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'): - seqs = to_embed[i * batch_size:(i + 1) * batch_size] - input_ids = batch['input_ids'].to(device) - attention_mask = batch['attention_mask'].to(device) - residue_embeddings = self._embed(input_ids, attention_mask) - yield seqs, residue_embeddings, attention_mask - else: - for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'): - seqs = to_embed[batch_start:batch_start + batch_size] - batch_output = self._embed(seqs, return_attention_mask=True, **kwargs) - assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)." - assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values." - residue_embeddings, attention_mask = batch_output - assert isinstance(attention_mask, torch.Tensor), "Sequence mode _embed must return attention_mask as a torch.Tensor." 
- yield seqs, residue_embeddings, attention_mask - - if sql: - conn = sqlite3.connect(sql_db_path) - self._ensure_embeddings_table(conn) - c = conn.cursor() - already_embedded = self._read_sequences_from_db(sql_db_path) - to_embed = [seq for seq in sequences if seq not in already_embedded] - print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}") - print(f"Embedding {len(to_embed)} new sequences") - if len(to_embed) > 0: - with torch.no_grad(): - for i, (seqs, residue_embeddings, attention_mask) in enumerate(iter_batches(to_embed)): - embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype) - for seq, emb, mask in zip(seqs, embeddings, attention_mask): - if full_embeddings: - emb = emb[mask.bool()].reshape(-1, hidden_size) - emb_np = emb.cpu().numpy() - emb_shape = ",".join([str(dim) for dim in emb_np.shape]) - emb_dtype = str(emb_np.dtype) - c.execute( - "INSERT OR REPLACE INTO embeddings (sequence, embedding, shape, dtype) VALUES (?, ?, ?, ?)", - (seq, emb_np.tobytes(), emb_shape, emb_dtype), - ) - if tokenizer_mode and (i + 1) % 100 == 0: - conn.commit() - conn.commit() - conn.close() - return None - - embeddings_dict = {} - if os.path.exists(save_path): - embeddings_dict = self.load_embeddings_from_pth(save_path) - to_embed = [seq for seq in sequences if seq not in embeddings_dict] - print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}") - print(f"Embedding {len(to_embed)} new sequences") - else: - to_embed = sequences - print(f"Embedding {len(to_embed)} new sequences") - - if len(to_embed) > 0: - with torch.no_grad(): - for seqs, residue_embeddings, attention_mask in iter_batches(to_embed): - embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype) - for seq, emb, mask in zip(seqs, embeddings, attention_mask): - if full_embeddings: - emb = emb[mask.bool()].reshape(-1, hidden_size) - embeddings_dict[seq] = emb.cpu() - - if save: - torch.save(embeddings_dict, 
save_path) - - return embeddings_dict - - -# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates -# SPDX-License-Identifier: Apache-2.0 -""" -FastPLMs-compatible DPLM implementation. -""" - -import torch -import torch.nn as nn -from torch.nn import functional as F -from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Union - -from transformers import AutoTokenizer, EsmTokenizer -from transformers.modeling_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, - BaseModelOutputWithPoolingAndCrossAttentions, - ModelOutput, - SequenceClassifierOutput, - TokenClassifierOutput, -) -from transformers.models.esm.configuration_esm import EsmConfig -from transformers.models.esm.modeling_esm import ( - EsmAttention, - EsmClassificationHead, - EsmContactPredictionHead, - EsmEmbeddings, - EsmEncoder, - EsmIntermediate, - EsmLayer, - EsmLMHead, - EsmOutput, - EsmPooler, - EsmPreTrainedModel, - EsmSelfAttention, - EsmSelfOutput, -) - -try: - from torch.nn.attention.flex_attention import create_block_mask, flex_attention -except (ImportError, AttributeError): - create_block_mask = None - flex_attention = None - - -from transformers import PreTrainedTokenizerBase - - -class BaseSequenceTokenizer: - def __init__(self, tokenizer: PreTrainedTokenizerBase): - self.tokenizer = tokenizer - - def __call__(self, sequences, **kwargs): - raise NotImplementedError - - -def _create_pad_block_mask(attention_mask_2d: torch.Tensor): - assert create_block_mask is not None, "Flex attention block mask requires create_block_mask." 
- token_valid = attention_mask_2d.bool() - batch_size, seq_len = token_valid.shape - - def mask_mod(batch_idx, head_idx, q_idx, kv_idx): - return token_valid[batch_idx, q_idx] & token_valid[batch_idx, kv_idx] - - return create_block_mask( - mask_mod, - batch_size, - 1, - seq_len, - seq_len, - device=attention_mask_2d.device, - ) - - -@dataclass -class DPLMMaskedLMOutput(ModelOutput): - loss: Optional[torch.Tensor] = None - logits: Optional[torch.Tensor] = None - last_hidden_state: Optional[torch.Tensor] = None - hidden_states: Optional[Tuple[torch.Tensor, ...]] = None - attentions: Optional[Tuple[torch.Tensor, ...]] = None - - -class DPLMConfig(EsmConfig): - model_type = "dplm" - - def __init__( - self, - attn_backend: str = "sdpa", - **kwargs, - ): - super().__init__(**kwargs) - self.attn_backend = attn_backend - self.tie_word_embeddings = False - - -class DPLMPreTrainedModel(EsmPreTrainedModel): - config_class = DPLMConfig - base_model_prefix = "dplm" - supports_gradient_checkpointing = True - tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D") - all_tied_weights_keys = {} - - -class ModifiedEsmSelfAttention(EsmSelfAttention): - def __init__(self, config, position_embedding_type=None): - super().__init__(config, position_embedding_type) - self.attn_backend = config.attn_backend - - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - output_attentions: Optional[bool] = False, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - 
flex_block_mask: Optional[object] = None, - **kwargs, - ) -> Tuple[torch.Tensor]: - if past_key_values is not None: - past_key_value = past_key_values - - mixed_query_layer = self.query(hidden_states) - is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention and past_key_value is not None: - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - query_layer = self.transpose_for_scores(mixed_query_layer) * self.attention_head_size**-0.5 - - if self.is_decoder: - past_key_value = (key_layer, value_layer) - - if self.position_embedding_type == "rotary": - query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) - - if self.position_embedding_type in ["relative_key", "relative_key_query"]: - raise NotImplementedError - - query_layer = query_layer.contiguous() - key_layer = key_layer.contiguous() - value_layer = value_layer.contiguous() - - if output_attentions: - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - if attention_mask is not None: - attention_scores = attention_scores + attention_mask - attention_probs = F.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype) - context_layer = torch.matmul(attention_probs, value_layer) - else: - attention_probs = None - if 
self.attn_backend == "flex": - assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable." - assert query_layer.dtype in (torch.float16, torch.bfloat16), ( - f"Flex attention backend requires float16 or bfloat16, got {query_layer.dtype}." - ) - assert is_cross_attention is False, "Flex attention backend currently does not support cross-attention." - assert past_key_value is None, "Flex attention backend currently does not support KV caching." - if attention_mask is not None: - assert flex_block_mask is not None, ( - "Flex attention backend requires a block mask when attention_mask is provided." - ) - context_layer = flex_attention( - query_layer, - key_layer, - value_layer, - block_mask=flex_block_mask, - scale=1.0, - ) - else: - context_layer = F.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attn_mask=attention_mask, - scale=1.0, - ) - - if head_mask is not None and torch.is_tensor(head_mask): - context_layer = context_layer * head_mask - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = (context_layer, attention_probs) - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs - - -class ModifiedEsmAttention(EsmAttention): - def __init__(self, config): - nn.Module.__init__(self) - self.self = ModifiedEsmSelfAttention(config) - self.output = EsmSelfOutput(config) - self.pruned_heads = set() - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - flex_block_mask=None, - ): - hidden_states_ln = self.LayerNorm(hidden_states) - self_outputs = self.self( - hidden_states_ln, - 
attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - flex_block_mask=flex_block_mask, - ) - attention_output = self.output(self_outputs[0], hidden_states) - outputs = (attention_output,) + self_outputs[1:] - return outputs - - -class ModifiedEsmLayer(EsmLayer): - def __init__(self, config): - nn.Module.__init__(self) - self.chunk_size_feed_forward = config.chunk_size_feed_forward - self.seq_len_dim = 1 - self.attention = ModifiedEsmAttention(config) - self.is_decoder = config.is_decoder - self.add_cross_attention = config.add_cross_attention - if self.add_cross_attention: - if self.is_decoder is False: - raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = ModifiedEsmAttention(config) - self.intermediate = EsmIntermediate(config) - self.output = EsmOutput(config) - self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - flex_block_mask=None, - ): - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - self_attention_outputs = self.attention( - hidden_states, - attention_mask, - head_mask, - output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, - flex_block_mask=flex_block_mask, - ) - attention_output = self_attention_outputs[0] - - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] - - if self.is_decoder and encoder_hidden_states is not None: - if self.add_cross_attention is False: - raise AttributeError( - f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention " - "layers by setting `config.add_cross_attention=True`" - ) - 
- cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - cross_attention_outputs = self.crossattention( - attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, - flex_block_mask=None, - ) - attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] - present_key_value = present_key_value + cross_attention_outputs[-1] - - layer_output = self.feed_forward_chunk(attention_output) - outputs = (layer_output,) + outputs - - if self.is_decoder: - outputs = outputs + (present_key_value,) - return outputs - - -class ModifiedEsmEncoder(EsmEncoder): - def __init__(self, config): - nn.Module.__init__(self) - self.config = config - self.layer = nn.ModuleList([ModifiedEsmLayer(config) for _ in range(config.num_hidden_layers)]) - self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - flex_block_mask=None, - ): - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - next_decoder_cache = () if use_cache else None - - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - 
hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - flex_block_mask, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - flex_block_mask, - ) - - hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = next_decoder_cache + (layer_outputs[-1],) - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (layer_outputs[2],) - - if self.emb_layer_norm_after: - hidden_states = self.emb_layer_norm_after(hidden_states) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if return_dict is False: - return tuple( - value - for value in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if value is not None - ) - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - - -class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin): - config_class = DPLMConfig - - def get_input_embeddings(self) -> nn.Module: - return self.embeddings.word_embeddings - - def __init__(self, config, add_pooling_layer=True): - DPLMPreTrainedModel.__init__(self, config) - self.config = config - self.embeddings = EsmEmbeddings(config) - self.encoder = ModifiedEsmEncoder(config) - self.pooler = EsmPooler(config) if add_pooling_layer else None - self.contact_head = EsmContactPredictionHead( - in_features=config.num_hidden_layers * config.num_attention_heads, - bias=True, - ) - self.post_init() - - def _convert_head_mask_to_5d(self, head_mask: torch.Tensor, 
num_hidden_layers: int) -> torch.Tensor: - if head_mask.dim() == 1: - head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) - head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1) - elif head_mask.dim() == 2: - head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) - assert head_mask.dim() == 5, f"head_mask.dim != 5, got {head_mask.dim()}" - head_mask = head_mask.to(dtype=self.dtype) - return head_mask - - def get_head_mask( - self, - head_mask: Optional[torch.Tensor], - num_hidden_layers: int, - is_attention_chunked: bool = False, - ) -> Union[torch.Tensor, List[None]]: - if head_mask is None: - return [None] * num_hidden_layers - head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers) - if is_attention_chunked: - head_mask = head_mask.unsqueeze(-1) - return head_mask - - def set_input_embeddings(self, value): - self.embeddings.word_embeddings = value - - def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - if attention_mask is None: - attention_mask = input_ids.ne(self.config.pad_token_id) - outputs = self( - input_ids=input_ids, - attention_mask=attention_mask, - output_hidden_states=False, - output_attentions=False, - return_dict=True, - ) - return outputs.last_hidden_state - - def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: - attns = self(input_ids, attention_mask=attention_mask, output_attentions=True).attentions - attns = torch.stack(attns, dim=1) - attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3) - attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4) - return self.contact_head(input_ids, attns) - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: 
Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if self.config.is_decoder: - use_cache = use_cache if use_cache is not None else self.config.use_cache - else: - use_cache = False - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - if input_ids is not None: - input_shape = input_ids.size() - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - batch_size, seq_length = input_shape - device = input_ids.device if input_ids is not None else inputs_embeds.device - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - - if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device) - - token_attention_mask = None - if attention_mask.dim() == 2: - token_attention_mask = attention_mask.bool() - if self.config.attn_backend == "flex" and output_attentions is False: - extended_attention_mask = None - else: - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) - elif attention_mask.dim() == 4: - if self.config.attn_backend == "flex" and output_attentions is False: - 
extended_attention_mask = None - else: - extended_attention_mask = attention_mask - if input_ids is not None: - token_attention_mask = input_ids.ne(self.config.pad_token_id) - else: - raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}") - - if self.config.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_extended_attention_mask = encoder_attention_mask - - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - embedding_attention_mask = token_attention_mask - if embedding_attention_mask is None and input_ids is not None: - embedding_attention_mask = input_ids.ne(self.config.pad_token_id) - - flex_block_mask = None - if ( - self.config.attn_backend == "flex" - and token_attention_mask is not None - and output_attentions is False - ): - assert create_block_mask is not None, ( - "Flex attention backend requested but torch.create_block_mask is unavailable." 
- ) - flex_block_mask = _create_pad_block_mask(token_attention_mask) - - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=embedding_attention_mask, - inputs_embeds=inputs_embeds, - ) - encoder_outputs = self.encoder( - embedding_output, - attention_mask=extended_attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - flex_block_mask=flex_block_mask, - ) - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - - if return_dict is False: - return (sequence_output, pooled_output) + encoder_outputs[1:] - - return BaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - past_key_values=None, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - cross_attentions=encoder_outputs.cross_attentions, - ) - - -class DPLMForMaskedLM(DPLMPreTrainedModel, EmbeddingMixin): - config_class = DPLMConfig - - def __init__(self, config, dropout: float = 0.1): - config.hidden_dropout_prob = dropout - DPLMPreTrainedModel.__init__(self, config) - self.esm = DPLMModel(config, add_pooling_layer=False) - self.lm_head = EsmLMHead(config) - self.loss_fct = nn.CrossEntropyLoss() - self.post_init() - - self.tokenizer = self.__class__.tokenizer - if isinstance(config._name_or_path, str) and len(config._name_or_path) > 0: - try: - self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path) - except Exception: - self.tokenizer = self.__class__.tokenizer - - self.mask_id = self.tokenizer.mask_token_id - self.pad_id = self.tokenizer.pad_token_id - self.bos_id = self.tokenizer.cls_token_id - self.eos_id = 
self.tokenizer.eos_token_id - self.x_id = self.tokenizer.convert_tokens_to_ids("X") - self.contact_head = None - - def get_input_embeddings(self) -> nn.Module: - return self.esm.embeddings.word_embeddings - - def get_output_embeddings(self): - return self.lm_head.decoder - - def set_output_embeddings(self, new_embeddings): - self.lm_head.decoder = new_embeddings - - def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - return self.esm._embed(input_ids, attention_mask) - - def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: - return self.esm.predict_contacts(input_ids, attention_mask=attention_mask) - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - decoder_input_ids: Optional[torch.Tensor] = None, - decoder_attention_mask: Optional[torch.Tensor] = None, - decoder_inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - ) -> Union[Tuple[torch.Tensor], DPLMMaskedLMOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if attention_mask is None and input_ids is not None: - attention_mask = input_ids.ne(self.pad_id) - - outputs = self.esm( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=True, - ) - sequence_output = outputs.last_hidden_state - logits = self.lm_head(sequence_output) - - loss = None - if labels is 
not None: - labels = labels.to(logits.device) - loss = self.loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) - - if return_dict is False: - output = (logits, sequence_output, outputs.hidden_states, outputs.attentions) - if loss is not None: - return (loss,) + output - return output - - return DPLMMaskedLMOutput( - loss=loss, - logits=logits, - last_hidden_state=sequence_output, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -class DPLMForSequenceClassification(DPLMPreTrainedModel, EmbeddingMixin): - config_class = DPLMConfig - - def get_input_embeddings(self) -> nn.Module: - return self.esm.embeddings.word_embeddings - - def __init__(self, config): - DPLMPreTrainedModel.__init__(self, config) - self.num_labels = config.num_labels - self.esm = DPLMModel(config, add_pooling_layer=False) - self.classifier = EsmClassificationHead(config) - self.mse = nn.MSELoss() - self.ce = nn.CrossEntropyLoss() - self.bce = nn.BCEWithLogitsLoss() - self.post_init() - - def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - return self.esm._embed(input_ids, attention_mask) - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs, - ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: - outputs = self.esm( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=True, - ) - sequence_output = outputs.last_hidden_state - logits = self.classifier(sequence_output) - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is 
None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - if self.num_labels == 1: - loss = self.mse(logits.squeeze(), labels.squeeze()) - else: - loss = self.mse(logits, labels) - elif self.config.problem_type == "single_label_classification": - loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss = self.bce(logits, labels) - - return SequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -class DPLMForTokenClassification(DPLMPreTrainedModel, EmbeddingMixin): - config_class = DPLMConfig - - def get_input_embeddings(self) -> nn.Module: - return self.esm.embeddings.word_embeddings - - def __init__(self, config): - DPLMPreTrainedModel.__init__(self, config) - self.num_labels = config.num_labels - self.esm = DPLMModel(config, add_pooling_layer=False) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) - self.loss_fct = nn.CrossEntropyLoss() - self.post_init() - - def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - return self.esm._embed(input_ids, attention_mask) - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs, - ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: - outputs = self.esm( - 
input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=True, - ) - sequence_output = self.dropout(outputs.last_hidden_state) - logits = self.classifier(sequence_output) - - loss = None - if labels is not None: - labels = labels.to(logits.device) - loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) +### Embedding Mixin + Pooler +import os +import sqlite3 +import networkx as nx +import numpy as np +import torch +from tqdm.auto import tqdm +from typing import Callable, List, Optional +from torch.utils.data import DataLoader +from torch.utils.data import Dataset as TorchDataset +from transformers import PreTrainedTokenizerBase + + +class Pooler: + def __init__(self, pooling_types: List[str]): + self.pooling_types = pooling_types + self.pooling_options = { + 'mean': self.mean_pooling, + 'max': self.max_pooling, + 'norm': self.norm_pooling, + 'median': self.median_pooling, + 'std': self.std_pooling, + 'var': self.var_pooling, + 'cls': self.cls_pooling, + 'parti': self._pool_parti, + } + + def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor: + maxed_attentions = torch.max(attentions, dim=1)[0] + return maxed_attentions + + def _page_rank(self, attention_matrix, personalization=None, nstart=None, prune_type="top_k_outdegree"): + # Run PageRank on the attention matrix converted to a graph. + # Raises exceptions if the graph doesn't match the token sequence or has no edges. + # Returns the PageRank scores for each token node. + G = self._convert_to_graph(attention_matrix) + if G.number_of_nodes() != attention_matrix.shape[0]: + raise Exception( + f"The number of nodes in the graph should be equal to the number of tokens in sequence! 
You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.") + if G.number_of_edges() == 0: + raise Exception(f"You don't seem to have any attention edges left in the graph.") + + return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100) + + def _convert_to_graph(self, matrix): + # Convert a matrix (e.g., attention scores) to a directed graph using networkx. + # Each element in the matrix represents a directed edge with a weight. + G = nx.from_numpy_array(matrix, create_using=nx.DiGraph) + return G + + def _calculate_importance_weights(self, dict_importance, attention_mask: Optional[torch.Tensor] = None): + # Remove keys where attention_mask is 0 + if attention_mask is not None: + for k in list(dict_importance.keys()): + if attention_mask[k] == 0: + del dict_importance[k] + + #dict_importance[0] # remove cls + #dict_importance[-1] # remove eos + total = sum(dict_importance.values()) + return np.array([v / total for _, v in dict_importance.items()]) + + def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d) + maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy() + # emb is (b, L, d), maxed_attentions is (b, L, L) + emb_pooled = [] + for e, a, mask in zip(emb, maxed_attentions, attention_mask): + dict_importance = self._page_rank(a) + importance_weights = self._calculate_importance_weights(dict_importance, mask) + num_tokens = int(mask.sum().item()) + emb_pooled.append(np.average(e[:num_tokens], weights=importance_weights, axis=0)) + pooled = torch.tensor(np.array(emb_pooled)) + return pooled + + def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) + if attention_mask is None: + return emb.mean(dim=1) + else: + attention_mask = attention_mask.unsqueeze(-1) + return (emb * attention_mask).sum(dim=1) / 
attention_mask.sum(dim=1) + + def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) + if attention_mask is None: + return emb.max(dim=1).values + else: + attention_mask = attention_mask.unsqueeze(-1) + return (emb * attention_mask).max(dim=1).values + + def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) + if attention_mask is None: + return emb.norm(dim=1, p=2) + else: + attention_mask = attention_mask.unsqueeze(-1) + return (emb * attention_mask).norm(dim=1, p=2) + + def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) + if attention_mask is None: + return emb.median(dim=1).values + else: + attention_mask = attention_mask.unsqueeze(-1) + return (emb * attention_mask).median(dim=1).values + + def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) + if attention_mask is None: + return emb.std(dim=1) + else: + # Compute variance correctly over non-masked positions, then take sqrt + var = self.var_pooling(emb, attention_mask, **kwargs) + return torch.sqrt(var) + + def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) + if attention_mask is None: + return emb.var(dim=1) + else: + # Correctly compute variance over only non-masked positions + attention_mask = attention_mask.unsqueeze(-1) # (b, L, 1) + # Compute mean over non-masked positions + mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d) + mean = mean.unsqueeze(1) # (b, 1, d) + # Compute squared differences from mean, only over non-masked positions + squared_diff = (emb - mean) ** 2 # (b, L, d) + # Sum squared differences over non-masked positions and divide by count + var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d) 
+ return var + + def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) + return emb[:, 0, :] + + def __call__( + self, + emb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attentions: Optional[torch.Tensor] = None + ): # [mean, max] + final_emb = [] + for pooling_type in self.pooling_types: + final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) # (b, d) + return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d) + + +class ProteinDataset(TorchDataset): + """Simple dataset for protein sequences.""" + def __init__(self, sequences: list[str]): + self.sequences = sequences + + def __len__(self) -> int: + return len(self.sequences) + + def __getitem__(self, idx: int) -> str: + return self.sequences[idx] + + +def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[list[str]], dict[str, torch.Tensor]]: + def _collate_fn(sequences: list[str]) -> dict[str, torch.Tensor]: + return tokenizer(sequences, return_tensors="pt", padding='longest') + return _collate_fn + + +class EmbeddingMixin: + def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + raise NotImplementedError + + @property + def device(self) -> torch.device: + """Get the device of the model.""" + return next(self.parameters()).device + + def _read_sequences_from_db(self, db_path: str) -> set[str]: + """Read sequences from SQLite database.""" + sequences = [] + with sqlite3.connect(db_path) as conn: + c = conn.cursor() + c.execute("SELECT sequence FROM embeddings") + while True: + row = c.fetchone() + if row is None: + break + sequences.append(row[0]) + return set(sequences) + + def _ensure_embeddings_table(self, conn: sqlite3.Connection) -> None: + cursor = conn.cursor() + cursor.execute( + "CREATE TABLE IF NOT EXISTS embeddings (" + "sequence TEXT PRIMARY KEY, " + "embedding BLOB NOT NULL, " + 
"shape TEXT, " + "dtype TEXT" + ")" + ) + cursor.execute("PRAGMA table_info(embeddings)") + rows = cursor.fetchall() + column_names = [row[1] for row in rows] + if "shape" not in column_names: + cursor.execute("ALTER TABLE embeddings ADD COLUMN shape TEXT") + if "dtype" not in column_names: + cursor.execute("ALTER TABLE embeddings ADD COLUMN dtype TEXT") + conn.commit() + + def load_embeddings_from_pth(self, save_path: str) -> dict[str, torch.Tensor]: + assert os.path.exists(save_path), f"Embedding file does not exist: {save_path}" + payload = torch.load(save_path, map_location="cpu", weights_only=True) + assert isinstance(payload, dict), "Expected .pth embeddings file to contain a dictionary." + for sequence, tensor in payload.items(): + assert isinstance(sequence, str), "Expected embedding dictionary keys to be sequences (str)." + assert isinstance(tensor, torch.Tensor), "Expected embedding dictionary values to be tensors." + return payload + + def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> dict[str, torch.Tensor]: + assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}" + loaded: dict[str, torch.Tensor] = {} + with sqlite3.connect(db_path) as conn: + self._ensure_embeddings_table(conn) + cursor = conn.cursor() + if sequences is None: + cursor.execute("SELECT sequence, embedding, shape, dtype FROM embeddings") + else: + if len(sequences) == 0: + return loaded + placeholders = ",".join(["?"] * len(sequences)) + cursor.execute( + f"SELECT sequence, embedding, shape, dtype FROM embeddings WHERE sequence IN ({placeholders})", + tuple(sequences), + ) + + rows = cursor.fetchall() + for row in rows: + sequence = row[0] + embedding_bytes = row[1] + shape_text = row[2] + dtype_text = row[3] + assert shape_text is not None, "Missing shape metadata in embeddings table." + assert dtype_text is not None, "Missing dtype metadata in embeddings table." 
+ shape_values = [int(value) for value in shape_text.split(",") if len(value) > 0] + assert len(shape_values) > 0, f"Invalid shape metadata for sequence: {sequence}" + expected_size = int(np.prod(shape_values)) + np_dtype = np.dtype(dtype_text) + array = np.frombuffer(embedding_bytes, dtype=np_dtype) + assert array.size == expected_size, f"Shape mismatch while reading sequence: {sequence}" + reshaped = array.copy().reshape(tuple(shape_values)) + loaded[sequence] = torch.from_numpy(reshaped) + return loaded + + def embed_dataset( + self, + sequences: List[str], + tokenizer: Optional[PreTrainedTokenizerBase] = None, + batch_size: int = 2, + max_len: int = 512, + truncate: bool = True, + full_embeddings: bool = False, + embed_dtype: torch.dtype = torch.float32, + pooling_types: List[str] = ['mean'], + num_workers: int = 0, + sql: bool = False, + save: bool = True, + sql_db_path: str = 'embeddings.db', + save_path: str = 'embeddings.pth', + **kwargs, + ) -> Optional[dict[str, torch.Tensor]]: + """ + Embed a dataset of protein sequences. + + Supports two modes: + - Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used. + - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used. 
+ """ + sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences])) + sequences = sorted(sequences, key=len, reverse=True) + hidden_size = self.config.hidden_size + pooler = Pooler(pooling_types) if not full_embeddings else None + tokenizer_mode = tokenizer is not None + if tokenizer_mode: + collate_fn = build_collator(tokenizer) + device = self.device + else: + collate_fn = None + device = None + + def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + if full_embeddings or residue_embeddings.ndim == 2: + return residue_embeddings + return pooler(residue_embeddings, attention_mask) + + def iter_batches(to_embed: List[str]): + if tokenizer_mode: + assert collate_fn is not None + assert device is not None + dataset = ProteinDataset(to_embed) + dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, shuffle=False) + for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'): + seqs = to_embed[i * batch_size:(i + 1) * batch_size] + input_ids = batch['input_ids'].to(device) + attention_mask = batch['attention_mask'].to(device) + residue_embeddings = self._embed(input_ids, attention_mask) + yield seqs, residue_embeddings, attention_mask + else: + for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'): + seqs = to_embed[batch_start:batch_start + batch_size] + batch_output = self._embed(seqs, return_attention_mask=True, **kwargs) + assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)." + assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values." + residue_embeddings, attention_mask = batch_output + assert isinstance(attention_mask, torch.Tensor), "Sequence mode _embed must return attention_mask as a torch.Tensor." 
+ yield seqs, residue_embeddings, attention_mask + + if sql: + conn = sqlite3.connect(sql_db_path) + self._ensure_embeddings_table(conn) + c = conn.cursor() + already_embedded = self._read_sequences_from_db(sql_db_path) + to_embed = [seq for seq in sequences if seq not in already_embedded] + print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}") + print(f"Embedding {len(to_embed)} new sequences") + if len(to_embed) > 0: + with torch.no_grad(): + for i, (seqs, residue_embeddings, attention_mask) in enumerate(iter_batches(to_embed)): + embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype) + for seq, emb, mask in zip(seqs, embeddings, attention_mask): + if full_embeddings: + emb = emb[mask.bool()].reshape(-1, hidden_size) + emb_np = emb.cpu().numpy() + emb_shape = ",".join([str(dim) for dim in emb_np.shape]) + emb_dtype = str(emb_np.dtype) + c.execute( + "INSERT OR REPLACE INTO embeddings (sequence, embedding, shape, dtype) VALUES (?, ?, ?, ?)", + (seq, emb_np.tobytes(), emb_shape, emb_dtype), + ) + if tokenizer_mode and (i + 1) % 100 == 0: + conn.commit() + conn.commit() + conn.close() + return None + + embeddings_dict = {} + if os.path.exists(save_path): + embeddings_dict = self.load_embeddings_from_pth(save_path) + to_embed = [seq for seq in sequences if seq not in embeddings_dict] + print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}") + print(f"Embedding {len(to_embed)} new sequences") + else: + to_embed = sequences + print(f"Embedding {len(to_embed)} new sequences") + + if len(to_embed) > 0: + with torch.no_grad(): + for seqs, residue_embeddings, attention_mask in iter_batches(to_embed): + embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype) + for seq, emb, mask in zip(seqs, embeddings, attention_mask): + if full_embeddings: + emb = emb[mask.bool()].reshape(-1, hidden_size) + embeddings_dict[seq] = emb.cpu() + + if save: + torch.save(embeddings_dict, 
save_path) + + return embeddings_dict + + +# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: Apache-2.0 +""" +FastPLMs-compatible DPLM implementation. +""" + +import torch +import torch.nn as nn +from torch.nn import functional as F +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +from transformers import AutoTokenizer, EsmTokenizer +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + ModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.models.esm.configuration_esm import EsmConfig +from transformers.models.esm.modeling_esm import ( + EsmAttention, + EsmClassificationHead, + EsmContactPredictionHead, + EsmEmbeddings, + EsmEncoder, + EsmIntermediate, + EsmLayer, + EsmLMHead, + EsmOutput, + EsmPooler, + EsmPreTrainedModel, + EsmSelfAttention, + EsmSelfOutput, +) + +try: + from torch.nn.attention.flex_attention import create_block_mask, flex_attention +except (ImportError, AttributeError): + create_block_mask = None + flex_attention = None + + +from transformers import PreTrainedTokenizerBase + + +class BaseSequenceTokenizer: + def __init__(self, tokenizer: PreTrainedTokenizerBase): + self.tokenizer = tokenizer + + def __call__(self, sequences, **kwargs): + raise NotImplementedError + + +def get_attention_mask( + attn_backend: str, + batch_size: int, + seq_len: int, + device: torch.device, + attention_mask: Optional[torch.Tensor] = None, +) -> Tuple[Optional[torch.Tensor], Optional[object]]: + if attention_mask is None: + token_attention_mask = torch.ones((batch_size, seq_len), device=device).bool() + else: + token_attention_mask = attention_mask.bool() + + if attn_backend == "flex": + assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable." 
+ + def mask_mod(batch_idx, head_idx, q_idx, kv_idx): + return token_attention_mask[batch_idx, q_idx] & token_attention_mask[batch_idx, kv_idx] + + flex_block_mask = create_block_mask( + mask_mod, + batch_size, + 1, + seq_len, + seq_len, + device=device, + ) + extended_attention_mask = None + else: + flex_block_mask = None + extended_attention_mask = token_attention_mask[:, None, :, None] & token_attention_mask[:, None, None, :] + + return extended_attention_mask, flex_block_mask + + +@dataclass +class DPLMMaskedLMOutput(ModelOutput): + loss: Optional[torch.Tensor] = None + logits: Optional[torch.Tensor] = None + last_hidden_state: Optional[torch.Tensor] = None + hidden_states: Optional[Tuple[torch.Tensor, ...]] = None + attentions: Optional[Tuple[torch.Tensor, ...]] = None + + +class DPLMConfig(EsmConfig): + model_type = "dplm" + + def __init__( + self, + attn_backend: str = "sdpa", + **kwargs, + ): + super().__init__(**kwargs) + self.attn_backend = attn_backend + self.tie_word_embeddings = False + + +class DPLMPreTrainedModel(EsmPreTrainedModel): + config_class = DPLMConfig + base_model_prefix = "dplm" + supports_gradient_checkpointing = True + tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D") + all_tied_weights_keys = {} + + @property + def attn_backend(self) -> str: + return self.config.attn_backend + + @attn_backend.setter + def attn_backend(self, backend: str) -> None: + assert backend in ("sdpa", "flex"), f"Unsupported attn_backend: {backend}" + self.config.attn_backend = backend + + +class ModifiedEsmSelfAttention(EsmSelfAttention): + def __init__(self, config, position_embedding_type=None): + super().__init__(config, position_embedding_type) + self.config = config + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + 
attention_mask: Optional[torch.Tensor], + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + flex_block_mask: Optional[object] = None, + **kwargs, + ) -> Tuple[torch.Tensor]: + if past_key_values is not None: + past_key_value = past_key_values + + mixed_query_layer = self.query(hidden_states) + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) * self.attention_head_size**-0.5 + + if self.is_decoder: + past_key_value = (key_layer, value_layer) + + if self.position_embedding_type == "rotary": + query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) + + if self.position_embedding_type in ["relative_key", "relative_key_query"]: + raise NotImplementedError + + query_layer = query_layer.contiguous() + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + 
if output_attentions: + assert attention_mask is not None, "output_attentions=True requires a concrete attention mask." + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores.masked_fill(attention_mask.logical_not(), float("-inf")) + attention_probs = F.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype) + context_layer = torch.matmul(attention_probs, value_layer) + else: + attention_probs = None + if self.config.attn_backend == "flex": + assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable." + assert query_layer.dtype in (torch.float16, torch.bfloat16), ( + f"Flex attention backend requires float16 or bfloat16, got {query_layer.dtype}." + ) + assert is_cross_attention is False, "Flex attention backend currently does not support cross-attention." + assert past_key_value is None, "Flex attention backend currently does not support KV caching." + assert flex_block_mask is not None, "Flex attention backend requires a block mask." 
+ context_layer = flex_attention( + query_layer, + key_layer, + value_layer, + block_mask=flex_block_mask, + scale=1.0, + ) + else: + context_layer = F.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + scale=1.0, + ) + + if head_mask is not None and torch.is_tensor(head_mask): + context_layer = context_layer * head_mask + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class ModifiedEsmAttention(EsmAttention): + def __init__(self, config): + nn.Module.__init__(self) + self.self = ModifiedEsmSelfAttention(config) + self.output = EsmSelfOutput(config) + self.pruned_heads = set() + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: bool = False, + flex_block_mask: Optional[object] = None, + ): + hidden_states_ln = self.LayerNorm(hidden_states) + self_outputs = self.self( + hidden_states_ln, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + flex_block_mask=flex_block_mask, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] + return outputs + + +class ModifiedEsmLayer(EsmLayer): + def __init__(self, config): + nn.Module.__init__(self) + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + 
class ModifiedEsmLayer(EsmLayer):
    """One transformer layer: self-attention, optional cross-attention, feed-forward.

    The feed-forward step is delegated to ``feed_forward_chunk``, which this
    class does not define itself.
    """

    def __init__(self, config):
        # Bypass EsmLayer.__init__; only nn.Module state is initialized and
        # the sub-modules are built explicitly.
        nn.Module.__init__(self)
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = ModifiedEsmAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            # Cross-attention is only allowed on decoder layers.
            if self.is_decoder is False:
                raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = ModifiedEsmAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        head_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: bool = False,
        flex_block_mask: Optional[object] = None,
    ):
        """Run the layer.

        Returns:
            ``(layer_output, *attention_extras)``; decoder layers additionally
            append the present key/value entry as the last element.

        Raises:
            AttributeError: if ``encoder_hidden_states`` is given to a decoder
                layer that was built without cross-attention.
        """
        decoding = self.is_decoder

        # The first two cache entries belong to self-attention.
        cached_self_kv = None if past_key_value is None else past_key_value[:2]
        attn_result = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=cached_self_kv,
            flex_block_mask=flex_block_mask,
        )
        attended = attn_result[0]

        if decoding:
            # Decoder attention returns its key/value entry last; peel it off.
            extras = attn_result[1:-1]
            present_key_value = attn_result[-1]
        else:
            extras = attn_result[1:]

        if decoding and encoder_hidden_states is not None:
            if self.add_cross_attention is False:
                raise AttributeError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention "
                    "layers by setting `config.add_cross_attention=True`"
                )

            # The last two cache entries belong to cross-attention.
            cached_cross_kv = None if past_key_value is None else past_key_value[-2:]
            cross_result = self.crossattention(
                attended,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cached_cross_kv,
                output_attentions,
                flex_block_mask=None,
            )
            attended = cross_result[0]
            extras = extras + cross_result[1:-1]
            present_key_value = present_key_value + cross_result[-1]

        result = (self.feed_forward_chunk(attended),) + extras
        if decoding:
            result = result + (present_key_value,)
        return result
class ModifiedEsmEncoder(EsmEncoder):
    """Stack of ``ModifiedEsmLayer`` modules followed by a final LayerNorm."""

    def __init__(self, config):
        # Bypass EsmEncoder.__init__; sub-modules are built explicitly.
        nn.Module.__init__(self)
        self.config = config
        self.layer = nn.ModuleList([ModifiedEsmLayer(config) for _ in range(config.num_hidden_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        head_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[Tuple[torch.FloatTensor]]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        flex_block_mask: Optional[object] = None,
    ):
        """Run every layer in order and apply the final LayerNorm.

        Args:
            hidden_states: Input to the first layer.
            attention_mask: Mask forwarded unchanged to every layer.
            head_mask: Indexable per-layer head masks, or None.
            encoder_hidden_states / encoder_attention_mask: Forwarded to each
                layer for cross-attention.
            past_key_values: Indexable per-layer cache entries, or None.
            use_cache: When truthy, the last element of each layer's output is
                collected into the returned cache.
            output_attentions: Collect per-layer attention maps.
            output_hidden_states: Collect the input of every layer plus the
                final normalized output.
            return_dict: When ``False`` a tuple of the non-None results is
                returned instead of a model-output object.
            flex_block_mask: Forwarded unchanged to every layer.

        Returns:
            ``BaseModelOutputWithPastAndCrossAttentions`` or, when
            ``return_dict is False``, a tuple of its non-None fields.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        next_decoder_cache = () if use_cache else None

        # A layer only emits a cross-attention map when it actually ran
        # cross-attention, i.e. when encoder states were supplied. Without
        # them, index 2 of a decoder layer's output is its present key/value,
        # which must not be reported as an attention map.
        has_cross_attention_maps = self.config.add_cross_attention and encoder_hidden_states is not None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    flex_block_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    flex_block_mask,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache = next_decoder_cache + (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if has_cross_attention_maps:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        # Explicit None check instead of relying on module truthiness.
        if self.emb_layer_norm_after is not None:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if return_dict is False:
            return tuple(
                value
                for value in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if value is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin):
    """Backbone model: ESM embeddings, ``ModifiedEsmEncoder``, optional pooler, contact head."""

    config_class = DPLMConfig

    def get_input_embeddings(self) -> nn.Module:
        """Return the word-embedding module."""
        return self.embeddings.word_embeddings

    def __init__(self, config, add_pooling_layer=True):
        DPLMPreTrainedModel.__init__(self, config)
        self.config = config
        self.embeddings = EsmEmbeddings(config)
        self.encoder = ModifiedEsmEncoder(config)
        self.pooler = EsmPooler(config) if add_pooling_layer else None
        # The contact head consumes one attention map per (layer, head) pair.
        self.contact_head = EsmContactPredictionHead(
            in_features=config.num_hidden_layers * config.num_attention_heads,
            bias=True,
        )
        self.post_init()

    def _convert_head_mask_to_5d(self, head_mask: torch.Tensor, num_hidden_layers: int) -> torch.Tensor:
        """Expand a 1-D or 2-D head mask to five dimensions and cast it to the model dtype.

        A 1-D mask of size ``n`` becomes ``(num_hidden_layers, 1, n, 1, 1)``;
        a 2-D mask of shape ``(a, n)`` becomes ``(a, 1, n, 1, 1)``. Masks of
        any other rank must already be 5-D.
        """
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
        assert head_mask.dim() == 5, f"head_mask.dim != 5, got {head_mask.dim()}"
        head_mask = head_mask.to(dtype=self.dtype)
        return head_mask

    def get_head_mask(
        self,
        head_mask: Optional[torch.Tensor],
        num_hidden_layers: int,
        is_attention_chunked: bool = False,
    ) -> Union[torch.Tensor, List[None]]:
        """Return a per-layer indexable head mask.

        ``None`` yields a list of ``num_hidden_layers`` ``None`` entries so the
        encoder can index it per layer; otherwise the mask is expanded to 5-D
        (plus one trailing axis when ``is_attention_chunked`` is set).
        """
        if head_mask is None:
            return [None] * num_hidden_layers
        head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
        if is_attention_chunked:
            head_mask = head_mask.unsqueeze(-1)
        return head_mask

    def set_input_embeddings(self, value):
        """Replace the word-embedding module."""
        self.embeddings.word_embeddings = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the last hidden state; the mask defaults to non-pad positions."""
        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id)
        outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=False,
            output_attentions=False,
            return_dict=True,
        )
        return outputs.last_hidden_state

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Feed the stacked, padding-zeroed attention maps to the contact head."""
        # return_dict is forced so `.attentions` is available even when the
        # config default asks for tuple outputs.
        attns = self(
            input_ids,
            attention_mask=attention_mask,
            output_attentions=True,
            return_dict=True,
        ).attentions
        attns = torch.stack(attns, dim=1)
        # Zero the attention entries along both trailing axes using the mask.
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
        return self.contact_head(input_ids, attns)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Embed the inputs, run the encoder and (optionally) the pooler.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
        ``attention_mask`` may be omitted (all positions attended), 2-D, or
        4-D (in which case ``input_ids`` is required).

        Raises:
            ValueError: both or neither of ``input_ids`` / ``inputs_embeds``
                are given, or ``attention_mask`` has an unsupported rank.
            AssertionError: ``output_attentions`` is requested with the
                ``"flex"`` attention backend, or a 4-D mask is passed without
                ``input_ids``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Caching is only meaningful for decoder configurations.
        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        # Derive a boolean token-level mask; every branch assigns it or raises.
        if attention_mask is None:
            token_attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device).bool()
        elif attention_mask.dim() == 2:
            token_attention_mask = attention_mask.bool()
        elif attention_mask.dim() == 4:
            assert input_ids is not None, "4D attention_mask requires input_ids to infer token-level mask."
            # NOTE(review): only the pad-derived token mask is used from here
            # on; the caller's 4-D mask itself is not forwarded -- confirm
            # this is intended.
            token_attention_mask = input_ids.ne(self.config.pad_token_id)
        else:
            raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")

        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = encoder_attention_mask

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if self.config.attn_backend == "flex" and output_attentions:
            raise AssertionError("output_attentions=True is not supported with attn_backend='flex'.")

        extended_attention_mask, flex_block_mask = get_attention_mask(
            attn_backend=self.config.attn_backend,
            batch_size=batch_size,
            seq_len=seq_length,
            device=device,
            attention_mask=token_attention_mask,
        )

        # token_attention_mask is never None at this point, so it is passed to
        # the embeddings directly.
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=token_attention_mask,
            inputs_embeds=inputs_embeds,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            flex_block_mask=flex_block_mask,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if return_dict is False:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            # Forward the encoder cache; it is None whenever use_cache is off.
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
class DPLMForMaskedLM(DPLMPreTrainedModel, EmbeddingMixin):
    """``DPLMModel`` backbone (no pooler) with an ``EsmLMHead`` on top."""

    config_class = DPLMConfig

    def __init__(self, config, dropout: float = 0.1):
        # The dropout argument overrides whatever the config carried.
        config.hidden_dropout_prob = dropout
        DPLMPreTrainedModel.__init__(self, config)
        self.esm = DPLMModel(config, add_pooling_layer=False)
        self.lm_head = EsmLMHead(config)
        self.loss_fct = nn.CrossEntropyLoss()
        self.post_init()

        # Start from the class-level tokenizer; try to load one from the
        # checkpoint path and fall back to the class-level one on any failure.
        fallback_tokenizer = self.__class__.tokenizer
        tokenizer = fallback_tokenizer
        name_or_path = config._name_or_path
        if isinstance(name_or_path, str) and len(name_or_path) > 0:
            try:
                tokenizer = AutoTokenizer.from_pretrained(name_or_path)
            except Exception:
                tokenizer = fallback_tokenizer
        self.tokenizer = tokenizer

        # Cache the special-token ids read from the tokenizer.
        self.mask_id = tokenizer.mask_token_id
        self.pad_id = tokenizer.pad_token_id
        self.bos_id = tokenizer.cls_token_id
        self.eos_id = tokenizer.eos_token_id
        self.x_id = tokenizer.convert_tokens_to_ids("X")
        self.contact_head = None

    def get_input_embeddings(self) -> nn.Module:
        """Return the backbone's word-embedding module."""
        return self.esm.embeddings.word_embeddings

    def get_output_embeddings(self):
        """Return the LM head's decoder module."""
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head's decoder module."""
        self.lm_head.decoder = new_embeddings

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Delegate to the backbone's ``_embed``."""
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Delegate to the backbone's ``predict_contacts``."""
        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        decoder_inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor], DPLMMaskedLMOutput]:
        """Compute token logits and, when ``labels`` is given, the cross-entropy loss.

        The ``decoder_*`` arguments are accepted but not used by this method.
        With ``return_dict=False`` the result is
        ``(loss?, logits, last_hidden_state, hidden_states, attentions)``,
        where ``loss`` is present only when labels were supplied.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        # Default the mask to every non-pad token.
        if attention_mask is None and input_ids is not None:
            attention_mask = input_ids.ne(self.pad_id)

        backbone_out = self.esm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = backbone_out.last_hidden_state
        logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            flat_labels = labels.to(logits.device).view(-1)
            flat_logits = logits.view(-1, self.config.vocab_size)
            loss = self.loss_fct(flat_logits, flat_labels)

        if return_dict is False:
            as_tuple = (logits, sequence_output, backbone_out.hidden_states, backbone_out.attentions)
            return as_tuple if loss is None else (loss,) + as_tuple

        return DPLMMaskedLMOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=sequence_output,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
class DPLMForSequenceClassification(DPLMPreTrainedModel, EmbeddingMixin):
    """``DPLMModel`` backbone (no pooler) with an ``EsmClassificationHead``."""

    config_class = DPLMConfig

    def get_input_embeddings(self) -> nn.Module:
        """Return the backbone's word-embedding module."""
        return self.esm.embeddings.word_embeddings

    def __init__(self, config):
        DPLMPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels
        self.esm = DPLMModel(config, add_pooling_layer=False)
        self.classifier = EsmClassificationHead(config)
        # One loss module per supported problem type.
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.post_init()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Delegate to the backbone's ``_embed``."""
        return self.esm._embed(input_ids, attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        """Classify (or regress on) a sequence.

        When ``labels`` is given the loss is chosen from
        ``config.problem_type``; if that is unset it is inferred from
        ``num_labels`` and the label dtype and stored back on the config.

        Returns:
            ``SequenceClassifierOutput``, or -- when ``return_dict`` resolves
            to ``False`` -- a tuple of its non-None fields in the order
            ``(loss, logits, hidden_states, attentions)``.
        """
        # Honour return_dict the same way DPLMForMaskedLM does.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.esm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = outputs.last_hidden_state
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                # Infer the problem type once and remember it on the config.
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = self.mse(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)

        if return_dict is False:
            # Same element order a model-output object exposes when indexed:
            # only the fields that are not None.
            return tuple(
                value
                for value in (loss, logits, outputs.hidden_states, outputs.attentions)
                if value is not None
            )

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class DPLMForTokenClassification(DPLMPreTrainedModel, EmbeddingMixin):
    """``DPLMModel`` backbone (no pooler) with dropout and a linear per-token classifier."""

    config_class = DPLMConfig

    def get_input_embeddings(self) -> nn.Module:
        """Return the backbone's word-embedding module."""
        return self.esm.embeddings.word_embeddings

    def __init__(self, config):
        DPLMPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels
        self.esm = DPLMModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
        self.post_init()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Delegate to the backbone's ``_embed``."""
        return self.esm._embed(input_ids, attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        """Produce per-token logits and, when ``labels`` is given, the cross-entropy loss.

        Returns:
            ``TokenClassifierOutput``, or -- when ``return_dict`` resolves to
            ``False`` -- a tuple of its non-None fields in the order
            ``(loss, logits, hidden_states, attentions)``.
        """
        # Honour return_dict the same way DPLMForMaskedLM does.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.esm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Flatten so every token position is one classification example.
            loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if return_dict is False:
            # Same element order a model-output object exposes when indexed:
            # only the fields that are not None.
            return tuple(
                value
                for value in (loss, logits, outputs.hidden_states, outputs.attentions)
                if value is not None
            )

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )