### Embedding Mixin + Pooler import os import sqlite3 import networkx as nx import numpy as np import torch from tqdm.auto import tqdm from typing import Callable, List, Optional from torch.utils.data import DataLoader from torch.utils.data import Dataset as TorchDataset from transformers import PreTrainedTokenizerBase class Pooler: def __init__(self, pooling_types: List[str]): self.pooling_types = pooling_types self.pooling_options = { 'mean': self.mean_pooling, 'max': self.max_pooling, 'norm': self.norm_pooling, 'median': self.median_pooling, 'std': self.std_pooling, 'var': self.var_pooling, 'cls': self.cls_pooling, 'parti': self._pool_parti, } def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor: maxed_attentions = torch.max(attentions, dim=1)[0] return maxed_attentions def _page_rank(self, attention_matrix, personalization=None, nstart=None, prune_type="top_k_outdegree"): # Run PageRank on the attention matrix converted to a graph. # Raises exceptions if the graph doesn't match the token sequence or has no edges. # Returns the PageRank scores for each token node. G = self._convert_to_graph(attention_matrix) if G.number_of_nodes() != attention_matrix.shape[0]: raise Exception( f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.") if G.number_of_edges() == 0: raise Exception(f"You don't seem to have any attention edges left in the graph.") return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100) def _convert_to_graph(self, matrix): # Convert a matrix (e.g., attention scores) to a directed graph using networkx. # Each element in the matrix represents a directed edge with a weight. G = nx.from_numpy_array(matrix, create_using=nx.DiGraph) return G def _calculate_importance_weights(self, dict_importance, attention_mask: Optional[torch.Tensor] = None): # Remove keys where attention_mask is 0 if attention_mask is not None: for k in list(dict_importance.keys()): if attention_mask[k] == 0: del dict_importance[k] #dict_importance[0] # remove cls #dict_importance[-1] # remove eos total = sum(dict_importance.values()) return np.array([v / total for _, v in dict_importance.items()]) def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d) maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy() # emb is (b, L, d), maxed_attentions is (b, L, L) emb_pooled = [] for e, a, mask in zip(emb, maxed_attentions, attention_mask): dict_importance = self._page_rank(a) importance_weights = self._calculate_importance_weights(dict_importance, mask) num_tokens = int(mask.sum().item()) emb_pooled.append(np.average(e[:num_tokens], weights=importance_weights, axis=0)) pooled = torch.tensor(np.array(emb_pooled)) return pooled def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) if attention_mask is None: return emb.mean(dim=1) else: attention_mask = attention_mask.unsqueeze(-1) return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) if attention_mask is None: return emb.max(dim=1).values else: attention_mask = attention_mask.unsqueeze(-1) return (emb * attention_mask).max(dim=1).values def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) if attention_mask is None: return emb.norm(dim=1, p=2) else: attention_mask = attention_mask.unsqueeze(-1) return (emb * attention_mask).norm(dim=1, p=2) def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) if attention_mask is None: return emb.median(dim=1).values else: attention_mask = attention_mask.unsqueeze(-1) return (emb * attention_mask).median(dim=1).values def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) if attention_mask is None: return emb.std(dim=1) else: # Compute variance correctly over non-masked positions, then take sqrt var = self.var_pooling(emb, attention_mask, **kwargs) return torch.sqrt(var) def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) if attention_mask is None: return emb.var(dim=1) else: # Correctly compute variance over only non-masked positions attention_mask = attention_mask.unsqueeze(-1) # (b, L, 1) # Compute mean over non-masked positions mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d) mean = mean.unsqueeze(1) # (b, 1, d) # Compute squared differences from mean, only over non-masked positions squared_diff = (emb - mean) ** 2 # (b, L, d) # Sum squared differences over non-masked positions and divide by count var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d) return var def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs): # (b, L, d) -> (b, d) return emb[:, 0, :] def __call__( self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, attentions: Optional[torch.Tensor] = None ): # [mean, max] final_emb = [] for pooling_type in self.pooling_types: final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) # (b, d) return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d) class ProteinDataset(TorchDataset): """Simple dataset for protein sequences.""" def __init__(self, sequences: list[str]): self.sequences = sequences def __len__(self) -> int: return len(self.sequences) def __getitem__(self, idx: int) -> str: return self.sequences[idx] def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[list[str]], dict[str, torch.Tensor]]: def _collate_fn(sequences: list[str]) -> dict[str, torch.Tensor]: return tokenizer(sequences, return_tensors="pt", padding='longest') return _collate_fn class EmbeddingMixin: def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: raise NotImplementedError @property def device(self) -> torch.device: """Get the device of the model.""" return next(self.parameters()).device def _read_sequences_from_db(self, db_path: str) -> set[str]: """Read sequences from SQLite database.""" sequences = [] with sqlite3.connect(db_path) as conn: c = conn.cursor() c.execute("SELECT sequence FROM embeddings") while True: row = c.fetchone() if row is None: break sequences.append(row[0]) return set(sequences) def _ensure_embeddings_table(self, conn: sqlite3.Connection) -> None: cursor = conn.cursor() cursor.execute( "CREATE TABLE IF NOT EXISTS embeddings (" "sequence TEXT PRIMARY KEY, " "embedding BLOB NOT NULL, " "shape TEXT, " "dtype TEXT" ")" ) cursor.execute("PRAGMA table_info(embeddings)") rows = cursor.fetchall() column_names = [row[1] for row in rows] if "shape" not in column_names: cursor.execute("ALTER TABLE embeddings ADD COLUMN shape TEXT") if "dtype" not in column_names: cursor.execute("ALTER TABLE embeddings ADD COLUMN dtype TEXT") conn.commit() def load_embeddings_from_pth(self, save_path: str) -> dict[str, torch.Tensor]: assert os.path.exists(save_path), f"Embedding file does not exist: {save_path}" payload = torch.load(save_path, map_location="cpu", weights_only=True) assert isinstance(payload, dict), "Expected .pth embeddings file to contain a dictionary." for sequence, tensor in payload.items(): assert isinstance(sequence, str), "Expected embedding dictionary keys to be sequences (str)." assert isinstance(tensor, torch.Tensor), "Expected embedding dictionary values to be tensors." return payload def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> dict[str, torch.Tensor]: assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}" loaded: dict[str, torch.Tensor] = {} with sqlite3.connect(db_path) as conn: self._ensure_embeddings_table(conn) cursor = conn.cursor() if sequences is None: cursor.execute("SELECT sequence, embedding, shape, dtype FROM embeddings") else: if len(sequences) == 0: return loaded placeholders = ",".join(["?"] * len(sequences)) cursor.execute( f"SELECT sequence, embedding, shape, dtype FROM embeddings WHERE sequence IN ({placeholders})", tuple(sequences), ) rows = cursor.fetchall() for row in rows: sequence = row[0] embedding_bytes = row[1] shape_text = row[2] dtype_text = row[3] assert shape_text is not None, "Missing shape metadata in embeddings table." assert dtype_text is not None, "Missing dtype metadata in embeddings table." shape_values = [int(value) for value in shape_text.split(",") if len(value) > 0] assert len(shape_values) > 0, f"Invalid shape metadata for sequence: {sequence}" expected_size = int(np.prod(shape_values)) np_dtype = np.dtype(dtype_text) array = np.frombuffer(embedding_bytes, dtype=np_dtype) assert array.size == expected_size, f"Shape mismatch while reading sequence: {sequence}" reshaped = array.copy().reshape(tuple(shape_values)) loaded[sequence] = torch.from_numpy(reshaped) return loaded def embed_dataset( self, sequences: List[str], tokenizer: Optional[PreTrainedTokenizerBase] = None, batch_size: int = 2, max_len: int = 512, truncate: bool = True, full_embeddings: bool = False, embed_dtype: torch.dtype = torch.float32, pooling_types: List[str] = ['mean'], num_workers: int = 0, sql: bool = False, save: bool = True, sql_db_path: str = 'embeddings.db', save_path: str = 'embeddings.pth', **kwargs, ) -> Optional[dict[str, torch.Tensor]]: """ Embed a dataset of protein sequences. Supports two modes: - Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used. - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used. """ sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences])) sequences = sorted(sequences, key=len, reverse=True) hidden_size = self.config.hidden_size pooler = Pooler(pooling_types) if not full_embeddings else None tokenizer_mode = tokenizer is not None if tokenizer_mode: collate_fn = build_collator(tokenizer) device = self.device else: collate_fn = None device = None def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: if full_embeddings or residue_embeddings.ndim == 2: return residue_embeddings return pooler(residue_embeddings, attention_mask) def iter_batches(to_embed: List[str]): if tokenizer_mode: assert collate_fn is not None assert device is not None dataset = ProteinDataset(to_embed) dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, shuffle=False) for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'): seqs = to_embed[i * batch_size:(i + 1) * batch_size] input_ids = batch['input_ids'].to(device) attention_mask = batch['attention_mask'].to(device) residue_embeddings = self._embed(input_ids, attention_mask) yield seqs, residue_embeddings, attention_mask else: for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'): seqs = to_embed[batch_start:batch_start + batch_size] batch_output = self._embed(seqs, return_attention_mask=True, **kwargs) assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)." assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values." residue_embeddings, attention_mask = batch_output assert isinstance(attention_mask, torch.Tensor), "Sequence mode _embed must return attention_mask as a torch.Tensor." yield seqs, residue_embeddings, attention_mask if sql: conn = sqlite3.connect(sql_db_path) self._ensure_embeddings_table(conn) c = conn.cursor() already_embedded = self._read_sequences_from_db(sql_db_path) to_embed = [seq for seq in sequences if seq not in already_embedded] print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}") print(f"Embedding {len(to_embed)} new sequences") if len(to_embed) > 0: with torch.no_grad(): for i, (seqs, residue_embeddings, attention_mask) in enumerate(iter_batches(to_embed)): embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype) for seq, emb, mask in zip(seqs, embeddings, attention_mask): if full_embeddings: emb = emb[mask.bool()].reshape(-1, hidden_size) emb_np = emb.cpu().numpy() emb_shape = ",".join([str(dim) for dim in emb_np.shape]) emb_dtype = str(emb_np.dtype) c.execute( "INSERT OR REPLACE INTO embeddings (sequence, embedding, shape, dtype) VALUES (?, ?, ?, ?)", (seq, emb_np.tobytes(), emb_shape, emb_dtype), ) if tokenizer_mode and (i + 1) % 100 == 0: conn.commit() conn.commit() conn.close() return None embeddings_dict = {} if os.path.exists(save_path): embeddings_dict = self.load_embeddings_from_pth(save_path) to_embed = [seq for seq in sequences if seq not in embeddings_dict] print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}") print(f"Embedding {len(to_embed)} new sequences") else: to_embed = sequences print(f"Embedding {len(to_embed)} new sequences") if len(to_embed) > 0: with torch.no_grad(): for seqs, residue_embeddings, attention_mask in iter_batches(to_embed): embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype) for seq, emb, mask in zip(seqs, embeddings, attention_mask): if full_embeddings: emb = emb[mask.bool()].reshape(-1, hidden_size) embeddings_dict[seq] = emb.cpu() if save: torch.save(embeddings_dict, save_path) return embeddings_dict # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates # SPDX-License-Identifier: Apache-2.0 """ FastPLMs-compatible DPLM implementation. """ import torch import torch.nn as nn from torch.nn import functional as F from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Union from transformers import AutoTokenizer, EsmTokenizer from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, ModelOutput, SequenceClassifierOutput, TokenClassifierOutput, ) from transformers.models.esm.configuration_esm import EsmConfig from transformers.models.esm.modeling_esm import ( EsmAttention, EsmClassificationHead, EsmContactPredictionHead, EsmEmbeddings, EsmEncoder, EsmIntermediate, EsmLayer, EsmLMHead, EsmOutput, EsmPooler, EsmPreTrainedModel, EsmSelfAttention, EsmSelfOutput, ) try: from torch.nn.attention.flex_attention import create_block_mask, flex_attention except (ImportError, AttributeError): create_block_mask = None flex_attention = None from transformers import PreTrainedTokenizerBase class BaseSequenceTokenizer: def __init__(self, tokenizer: PreTrainedTokenizerBase): self.tokenizer = tokenizer def __call__(self, sequences, **kwargs): raise NotImplementedError def get_attention_mask( attn_backend: str, batch_size: int, seq_len: int, device: torch.device, attention_mask: Optional[torch.Tensor] = None, ) -> Tuple[Optional[torch.Tensor], Optional[object]]: if attention_mask is None: token_attention_mask = torch.ones((batch_size, seq_len), device=device).bool() else: token_attention_mask = attention_mask.bool() if attn_backend == "flex": assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable." def mask_mod(batch_idx, head_idx, q_idx, kv_idx): return token_attention_mask[batch_idx, q_idx] & token_attention_mask[batch_idx, kv_idx] flex_block_mask = create_block_mask( mask_mod, batch_size, 1, seq_len, seq_len, device=device, ) extended_attention_mask = None else: flex_block_mask = None extended_attention_mask = token_attention_mask[:, None, :, None] & token_attention_mask[:, None, None, :] return extended_attention_mask, flex_block_mask @dataclass class DPLMMaskedLMOutput(ModelOutput): loss: Optional[torch.Tensor] = None logits: Optional[torch.Tensor] = None last_hidden_state: Optional[torch.Tensor] = None hidden_states: Optional[Tuple[torch.Tensor, ...]] = None attentions: Optional[Tuple[torch.Tensor, ...]] = None class DPLMConfig(EsmConfig): model_type = "dplm" def __init__( self, attn_backend: str = "sdpa", **kwargs, ): super().__init__(**kwargs) self.attn_backend = attn_backend self.tie_word_embeddings = False class DPLMPreTrainedModel(EsmPreTrainedModel): config_class = DPLMConfig base_model_prefix = "dplm" supports_gradient_checkpointing = True tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D") all_tied_weights_keys = {} @property def attn_backend(self) -> str: return self.config.attn_backend @attn_backend.setter def attn_backend(self, backend: str) -> None: assert backend in ("sdpa", "flex"), f"Unsupported attn_backend: {backend}" self.config.attn_backend = backend class ModifiedEsmSelfAttention(EsmSelfAttention): def __init__(self, config, position_embedding_type=None): super().__init__(config, position_embedding_type) self.config = config def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, flex_block_mask: Optional[object] = None, **kwargs, ) -> Tuple[torch.Tensor]: if past_key_values is not None: past_key_value = past_key_values mixed_query_layer = self.query(hidden_states) is_cross_attention = encoder_hidden_states is not None if is_cross_attention and past_key_value is not None: key_layer = past_key_value[0] value_layer = past_key_value[1] attention_mask = encoder_attention_mask elif is_cross_attention: key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) attention_mask = encoder_attention_mask elif past_key_value is not None: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) key_layer = torch.cat([past_key_value[0], key_layer], dim=2) value_layer = torch.cat([past_key_value[1], value_layer], dim=2) else: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) * self.attention_head_size**-0.5 if self.is_decoder: past_key_value = (key_layer, value_layer) if self.position_embedding_type == "rotary": query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) if self.position_embedding_type in ["relative_key", "relative_key_query"]: raise NotImplementedError query_layer = query_layer.contiguous() key_layer = key_layer.contiguous() value_layer = value_layer.contiguous() if output_attentions: assert attention_mask is not None, "output_attentions=True requires a concrete attention mask." attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = attention_scores.masked_fill(attention_mask.logical_not(), float("-inf")) attention_probs = F.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype) context_layer = torch.matmul(attention_probs, value_layer) else: attention_probs = None if self.config.attn_backend == "flex": assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable." assert query_layer.dtype in (torch.float16, torch.bfloat16), ( f"Flex attention backend requires float16 or bfloat16, got {query_layer.dtype}." ) assert is_cross_attention is False, "Flex attention backend currently does not support cross-attention." assert past_key_value is None, "Flex attention backend currently does not support KV caching." assert flex_block_mask is not None, "Flex attention backend requires a block mask." context_layer = flex_attention( query_layer, key_layer, value_layer, block_mask=flex_block_mask, scale=1.0, ) else: context_layer = F.scaled_dot_product_attention( query_layer, key_layer, value_layer, attn_mask=attention_mask, scale=1.0, ) if head_mask is not None and torch.is_tensor(head_mask): context_layer = context_layer * head_mask context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) outputs = (context_layer, attention_probs) if self.is_decoder: outputs = outputs + (past_key_value,) return outputs class ModifiedEsmAttention(EsmAttention): def __init__(self, config): nn.Module.__init__(self) self.self = ModifiedEsmSelfAttention(config) self.output = EsmSelfOutput(config) self.pruned_heads = set() self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], head_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: bool = False, flex_block_mask: Optional[object] = None, ): hidden_states_ln = self.LayerNorm(hidden_states) self_outputs = self.self( hidden_states_ln, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, flex_block_mask=flex_block_mask, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] return outputs class ModifiedEsmLayer(EsmLayer): def __init__(self, config): nn.Module.__init__(self) self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.attention = ModifiedEsmAttention(config) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if self.is_decoder is False: raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added") self.crossattention = ModifiedEsmAttention(config) self.intermediate = EsmIntermediate(config) self.output = EsmOutput(config) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], head_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: bool = False, flex_block_mask: Optional[object] = None, ): self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, past_key_value=self_attn_past_key_value, flex_block_mask=flex_block_mask, ) attention_output = self_attention_outputs[0] if self.is_decoder: outputs = self_attention_outputs[1:-1] present_key_value = self_attention_outputs[-1] else: outputs = self_attention_outputs[1:] if self.is_decoder and encoder_hidden_states is not None: if self.add_cross_attention is False: raise AttributeError( f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention " "layers by setting `config.add_cross_attention=True`" ) cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, cross_attn_past_key_value, output_attentions, flex_block_mask=None, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] present_key_value = present_key_value + cross_attention_outputs[-1] layer_output = self.feed_forward_chunk(attention_output) outputs = (layer_output,) + outputs if self.is_decoder: outputs = outputs + (present_key_value,) return outputs class ModifiedEsmEncoder(EsmEncoder): def __init__(self, config): nn.Module.__init__(self) self.config = config self.layer = nn.ModuleList([ModifiedEsmLayer(config) for _ in range(config.num_hidden_layers)]) self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.gradient_checkpointing = False def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], head_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[Tuple[Tuple[torch.FloatTensor]]]] = None, use_cache: Optional[bool] = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, flex_block_mask: Optional[object] = None, ): all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, flex_block_mask, ) else: layer_outputs = layer_module( hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, flex_block_mask, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache = next_decoder_cache + (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: all_cross_attentions = all_cross_attentions + (layer_outputs[2],) if self.emb_layer_norm_after: hidden_states = self.emb_layer_norm_after(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if return_dict is False: return tuple( value for value in [ hidden_states, next_decoder_cache, all_hidden_states, all_self_attentions, all_cross_attentions, ] if value is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, ) class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin): config_class = DPLMConfig def get_input_embeddings(self) -> nn.Module: return self.embeddings.word_embeddings def __init__(self, config, add_pooling_layer=True): DPLMPreTrainedModel.__init__(self, config) self.config = config self.embeddings = EsmEmbeddings(config) self.encoder = ModifiedEsmEncoder(config) self.pooler = EsmPooler(config) if add_pooling_layer else None self.contact_head = EsmContactPredictionHead( in_features=config.num_hidden_layers * config.num_attention_heads, bias=True, ) self.post_init() def _convert_head_mask_to_5d(self, head_mask: torch.Tensor, num_hidden_layers: int) -> torch.Tensor: if head_mask.dim() == 1: head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1) elif head_mask.dim() == 2: head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) assert head_mask.dim() == 5, f"head_mask.dim != 5, got {head_mask.dim()}" head_mask = head_mask.to(dtype=self.dtype) return head_mask def get_head_mask( self, head_mask: Optional[torch.Tensor], num_hidden_layers: int, is_attention_chunked: bool = False, ) -> Union[torch.Tensor, List[None]]: if head_mask is None: return [None] * num_hidden_layers head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers) if is_attention_chunked: head_mask = head_mask.unsqueeze(-1) return head_mask def set_input_embeddings(self, value): self.embeddings.word_embeddings = value def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: if attention_mask is None: attention_mask = input_ids.ne(self.config.pad_token_id) outputs = self( input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=False, output_attentions=False, return_dict=True, ) return outputs.last_hidden_state def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: attns = self(input_ids, attention_mask=attention_mask, output_attentions=True).attentions attns = torch.stack(attns, dim=1) attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3) attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4) return self.contact_head(input_ids, attns) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states return_dict = return_dict if return_dict is not None else self.config.use_return_dict if self.config.is_decoder: use_cache = use_cache if use_cache is not None else self.config.use_cache else: use_cache = False if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") if input_ids is not None: input_shape = input_ids.size() elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if attention_mask is None: token_attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device).bool() elif attention_mask.dim() == 2: token_attention_mask = attention_mask.bool() elif attention_mask.dim() == 4: assert input_ids is not None, "4D attention_mask requires input_ids to infer token-level mask." token_attention_mask = input_ids.ne(self.config.pad_token_id) else: raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}") if self.config.is_decoder and encoder_hidden_states is not None: encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = encoder_attention_mask head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) embedding_attention_mask = token_attention_mask if embedding_attention_mask is None and input_ids is not None: embedding_attention_mask = input_ids.ne(self.config.pad_token_id) if self.config.attn_backend == "flex" and output_attentions: raise AssertionError("output_attentions=True is not supported with attn_backend='flex'.") extended_attention_mask, flex_block_mask = get_attention_mask( attn_backend=self.config.attn_backend, batch_size=batch_size, seq_len=seq_length, device=device, attention_mask=token_attention_mask, ) embedding_output = self.embeddings( input_ids=input_ids, position_ids=position_ids, attention_mask=embedding_attention_mask, inputs_embeds=inputs_embeds, ) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, flex_block_mask=flex_block_mask, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None if return_dict is False: return (sequence_output, pooled_output) + encoder_outputs[1:] return BaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, pooler_output=pooled_output, past_key_values=None, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, cross_attentions=encoder_outputs.cross_attentions, ) class DPLMForMaskedLM(DPLMPreTrainedModel, EmbeddingMixin): config_class = DPLMConfig def __init__(self, config, dropout: float = 0.1): config.hidden_dropout_prob = dropout DPLMPreTrainedModel.__init__(self, config) self.esm = DPLMModel(config, add_pooling_layer=False) self.lm_head = EsmLMHead(config) self.loss_fct = nn.CrossEntropyLoss() self.post_init() self.tokenizer = self.__class__.tokenizer if isinstance(config._name_or_path, str) and len(config._name_or_path) > 0: try: self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path) except Exception: self.tokenizer = self.__class__.tokenizer self.mask_id = self.tokenizer.mask_token_id self.pad_id = self.tokenizer.pad_token_id self.bos_id = self.tokenizer.cls_token_id self.eos_id = self.tokenizer.eos_token_id self.x_id = self.tokenizer.convert_tokens_to_ids("X") self.contact_head = None def get_input_embeddings(self) -> nn.Module: return self.esm.embeddings.word_embeddings def get_output_embeddings(self): return self.lm_head.decoder def set_output_embeddings(self, new_embeddings): self.lm_head.decoder = new_embeddings def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: return self.esm._embed(input_ids, attention_mask) def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: return self.esm.predict_contacts(input_ids, attention_mask=attention_mask) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, decoder_input_ids: Optional[torch.Tensor] = None, decoder_attention_mask: Optional[torch.Tensor] = None, decoder_inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], DPLMMaskedLMOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict if attention_mask is None and input_ids is not None: attention_mask = input_ids.ne(self.pad_id) outputs = self.esm( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, ) sequence_output = outputs.last_hidden_state logits = self.lm_head(sequence_output) loss = None if labels is not None: labels = labels.to(logits.device) loss = self.loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) if return_dict is False: output = (logits, sequence_output, outputs.hidden_states, outputs.attentions) if loss is not None: return (loss,) + output return output return DPLMMaskedLMOutput( loss=loss, logits=logits, last_hidden_state=sequence_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class DPLMForSequenceClassification(DPLMPreTrainedModel, EmbeddingMixin): config_class = DPLMConfig def get_input_embeddings(self) -> nn.Module: return self.esm.embeddings.word_embeddings def __init__(self, config): DPLMPreTrainedModel.__init__(self, config) self.num_labels = config.num_labels self.esm = DPLMModel(config, add_pooling_layer=False) self.classifier = EsmClassificationHead(config) self.mse = nn.MSELoss() self.ce = nn.CrossEntropyLoss() self.bce = nn.BCEWithLogitsLoss() self.post_init() def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: return self.esm._embed(input_ids, attention_mask) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: outputs = self.esm( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, ) sequence_output = outputs.last_hidden_state logits = self.classifier(sequence_output) loss = None if labels is not None: labels = labels.to(logits.device) if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": if self.num_labels == 1: loss = self.mse(logits.squeeze(), labels.squeeze()) else: loss = self.mse(logits, labels) elif self.config.problem_type == "single_label_classification": loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss = self.bce(logits, labels) return SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class DPLMForTokenClassification(DPLMPreTrainedModel, EmbeddingMixin): config_class = DPLMConfig def get_input_embeddings(self) -> nn.Module: return self.esm.embeddings.word_embeddings def __init__(self, config): DPLMPreTrainedModel.__init__(self, config) self.num_labels = config.num_labels self.esm = DPLMModel(config, add_pooling_layer=False) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.loss_fct = nn.CrossEntropyLoss() self.post_init() def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: return self.esm._embed(input_ids, attention_mask) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: outputs = self.esm( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, ) sequence_output = self.dropout(outputs.last_hidden_state) logits = self.classifier(sequence_output) loss = None if labels is not None: labels = labels.to(logits.device) loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, )