Transformers
Safetensors
dplm2
custom_code
lhallee committed on
Commit
d544c12
·
verified ·
1 Parent(s): c3e0129

Upload modeling_dplm2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_dplm2.py +1366 -0
modeling_dplm2.py ADDED
@@ -0,0 +1,1366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### Embedding Mixin + Pooler
2
+ import os
3
+ import sqlite3
4
+ import networkx as nx
5
+ import numpy as np
6
+ import torch
7
+ from tqdm.auto import tqdm
8
+ from typing import Callable, List, Optional
9
+ from torch.utils.data import DataLoader
10
+ from torch.utils.data import Dataset as TorchDataset
11
+ from transformers import PreTrainedTokenizerBase
12
+
13
+
14
class Pooler:
    """Collapse token-level embeddings ``(b, L, d)`` into per-sequence vectors.

    ``pooling_types`` names one or more strategies (keys of
    ``pooling_options``). Calling the instance runs each strategy and
    concatenates the ``(b, d)`` results along the last dimension, in the
    order requested.
    """

    def __init__(self, pooling_types: List[str]):
        self.pooling_types = pooling_types
        # Name -> bound method dispatch table used by __call__.
        self.pooling_options = {
            'mean': self.mean_pooling,
            'max': self.max_pooling,
            'norm': self.norm_pooling,
            'median': self.median_pooling,
            'std': self.std_pooling,
            'var': self.var_pooling,
            'cls': self.cls_pooling,
            'parti': self._pool_parti,
        }

    def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor:
        """Reduce ``attentions`` with an element-wise max over dimension 1."""
        maxed_attentions = torch.max(attentions, dim=1)[0]
        return maxed_attentions

    def _page_rank(self, attention_matrix, personalization=None, nstart=None, prune_type="top_k_outdegree"):
        """Run PageRank on the directed graph built from ``attention_matrix``.

        Returns the mapping of node index to PageRank score produced by
        ``networkx.pagerank``. ``prune_type`` is accepted but not used here.

        Raises:
            Exception: if the graph's node count differs from the number of
                rows of the matrix, or if the graph has no edges.
        """
        G = self._convert_to_graph(attention_matrix)
        if G.number_of_nodes() != attention_matrix.shape[0]:
            raise Exception(
                f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.")
        if G.number_of_edges() == 0:
            raise Exception("You don't seem to have any attention edges left in the graph.")

        return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)

    def _convert_to_graph(self, matrix):
        """Build a weighted ``networkx.DiGraph`` whose edge weights are the matrix entries."""
        G = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
        return G

    def _calculate_importance_weights(self, dict_importance, attention_mask: Optional[torch.Tensor] = None):
        """Normalise PageRank scores of the non-masked nodes so they sum to 1.

        Nodes whose ``attention_mask`` entry is 0 are dropped. The input
        dictionary is left untouched.
        """
        if attention_mask is None:
            kept = list(dict_importance.values())
        else:
            kept = [score for node, score in dict_importance.items() if attention_mask[node] != 0]
        total = sum(kept)
        return np.array([score / total for score in kept])

    def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):  # (b, L, d) -> (b, d)
        """PageRank-weighted average of token embeddings.

        For each sequence, the max-reduced attention matrix is turned into a
        graph, PageRank scores of the non-masked tokens become weights, and
        the first ``mask.sum()`` token embeddings are averaged with them.
        NOTE(review): slicing ``e[:num_tokens]`` presumes valid tokens come
        first (right padding) — confirm against the tokenizer in use.
        """
        if attention_mask is None:
            # The signature allows None; treat every position as valid.
            attention_mask = torch.ones(emb.shape[:2], dtype=torch.long, device=emb.device)
        # networkx / numpy work on host memory only.
        maxed_attentions = self._create_pooled_matrices_across_layers(attentions).detach().cpu().numpy()
        emb_cpu = emb.detach().cpu()
        mask_cpu = attention_mask.detach().cpu()
        # emb is (b, L, d), maxed_attentions is (b, L, L)
        emb_pooled = []
        for e, a, mask in zip(emb_cpu, maxed_attentions, mask_cpu):
            dict_importance = self._page_rank(a)
            importance_weights = self._calculate_importance_weights(dict_importance, mask)
            num_tokens = int(mask.sum().item())
            emb_pooled.append(np.average(e[:num_tokens].numpy(), weights=importance_weights, axis=0))
        pooled = torch.tensor(np.array(emb_pooled)).to(emb.device)
        return pooled

    def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Mean over the sequence dimension, restricted to non-masked positions."""
        if attention_mask is None:
            return emb.mean(dim=1)
        attention_mask = attention_mask.unsqueeze(-1)
        return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)

    def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Max over the sequence dimension, restricted to non-masked positions."""
        if attention_mask is None:
            return emb.max(dim=1).values
        keep = attention_mask.unsqueeze(-1).bool()
        # Padded positions must never win the max: fill them with -inf rather
        # than 0, which would beat every negative feature value.
        return emb.masked_fill(~keep, float("-inf")).max(dim=1).values

    def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """L2 norm over the sequence dimension; masked positions contribute zero."""
        if attention_mask is None:
            return emb.norm(dim=1, p=2)
        attention_mask = attention_mask.unsqueeze(-1)
        return (emb * attention_mask).norm(dim=1, p=2)

    def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Median over the sequence dimension, restricted to non-masked positions."""
        if attention_mask is None:
            return emb.median(dim=1).values
        keep = attention_mask.unsqueeze(-1).bool()
        # Mark padding as NaN so nanmedian ignores it instead of counting zeros.
        return emb.masked_fill(~keep, float("nan")).nanmedian(dim=1).values

    def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Standard deviation over the sequence dimension."""
        if attention_mask is None:
            return emb.std(dim=1)
        # Square root of the masked variance computed by var_pooling.
        var = self.var_pooling(emb, attention_mask, **kwargs)
        return torch.sqrt(var)

    def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Variance over the sequence dimension.

        With a mask, the mean and the squared deviations are accumulated over
        non-masked positions only and divided by the count of those positions.
        """
        if attention_mask is None:
            return emb.var(dim=1)
        attention_mask = attention_mask.unsqueeze(-1)  # (b, L, 1)
        mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)  # (b, d)
        mean = mean.unsqueeze(1)  # (b, 1, d)
        squared_diff = (emb - mean) ** 2  # (b, L, d)
        var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)  # (b, d)
        return var

    def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs):  # (b, L, d) -> (b, d)
        """Return the embedding at sequence position 0."""
        return emb[:, 0, :]

    def __call__(
        self,
        emb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attentions: Optional[torch.Tensor] = None
    ):  # [mean, max]
        """Apply every configured pooling and concatenate the results.

        Returns a tensor of shape ``(b, n_pooling_types * d)``.
        """
        final_emb = []
        for pooling_type in self.pooling_types:
            final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions))  # (b, d)
        return torch.cat(final_emb, dim=-1)  # (b, n_pooling_types * d)
139
+
140
+
141
class ProteinDataset(TorchDataset):
    """Map-style dataset that serves raw protein sequence strings by index."""

    def __init__(self, sequences: list[str]):
        # Kept by reference; nothing is copied or tokenized here.
        self.sequences = sequences

    def __len__(self) -> int:
        """Number of sequences held by the dataset."""
        count = len(self.sequences)
        return count

    def __getitem__(self, idx: int) -> str:
        """Return the sequence stored at position ``idx``."""
        item = self.sequences[idx]
        return item
151
+
152
+
153
def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[list[str]], dict[str, torch.Tensor]]:
    """Return a collate function that tokenizes a batch of sequences.

    The returned callable invokes ``tokenizer`` on the list of strings with
    ``return_tensors="pt"`` and ``padding='longest'`` and hands back whatever
    the tokenizer returns.
    """

    def _collate_fn(sequences: list[str]) -> dict[str, torch.Tensor]:
        batch = tokenizer(sequences, return_tensors="pt", padding='longest')
        return batch

    return _collate_fn
157
+
158
+
159
class EmbeddingMixin:
    """Mixin that adds dataset embedding plus SQLite / ``.pth`` caching to a model.

    The host class is expected to provide ``_embed``, ``parameters()`` (used
    by ``device``) and ``config.hidden_size`` (used by ``embed_dataset``).
    """

    # Upper bound on bound parameters per SQL statement. SQLite builds may cap
    # this as low as 999, so large lookups are issued in chunks of this size.
    _SQL_MAX_BOUND_VARIABLES = 900

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Produce embeddings for a batch; must be implemented by the host class."""
        raise NotImplementedError

    @property
    def device(self) -> torch.device:
        """Get the device of the model."""
        return next(self.parameters()).device

    def _read_sequences_from_db(self, db_path: str) -> set[str]:
        """Return the set of sequences stored in the ``embeddings`` table."""
        # sqlite3's connection context manager only commits/rolls back; close
        # explicitly so the file handle is not leaked.
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.execute("SELECT sequence FROM embeddings")
            return {row[0] for row in cursor}
        finally:
            conn.close()

    def _ensure_embeddings_table(self, conn: sqlite3.Connection) -> None:
        """Create the ``embeddings`` table if missing and add any missing metadata columns."""
        cursor = conn.cursor()
        cursor.execute(
            "CREATE TABLE IF NOT EXISTS embeddings ("
            "sequence TEXT PRIMARY KEY, "
            "embedding BLOB NOT NULL, "
            "shape TEXT, "
            "dtype TEXT"
            ")"
        )
        # A pre-existing table may lack the metadata columns; add them in place.
        cursor.execute("PRAGMA table_info(embeddings)")
        column_names = [row[1] for row in cursor.fetchall()]
        if "shape" not in column_names:
            cursor.execute("ALTER TABLE embeddings ADD COLUMN shape TEXT")
        if "dtype" not in column_names:
            cursor.execute("ALTER TABLE embeddings ADD COLUMN dtype TEXT")
        conn.commit()

    def load_embeddings_from_pth(self, save_path: str) -> dict[str, torch.Tensor]:
        """Load a ``{sequence: tensor}`` dictionary saved with ``torch.save``.

        Raises:
            AssertionError: if the file is missing or its content is not a
                dictionary of ``str`` keys and tensor values.
        """
        assert os.path.exists(save_path), f"Embedding file does not exist: {save_path}"
        payload = torch.load(save_path, map_location="cpu", weights_only=True)
        assert isinstance(payload, dict), "Expected .pth embeddings file to contain a dictionary."
        for sequence, tensor in payload.items():
            assert isinstance(sequence, str), "Expected embedding dictionary keys to be sequences (str)."
            assert isinstance(tensor, torch.Tensor), "Expected embedding dictionary values to be tensors."
        return payload

    def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> dict[str, torch.Tensor]:
        """Load embeddings from the SQLite database at ``db_path``.

        Args:
            db_path: Path to an existing database file.
            sequences: If given, only these sequences are looked up; if
                ``None``, every row is loaded. An empty list yields ``{}``.

        Returns:
            Mapping of sequence to a tensor rebuilt from the stored bytes
            using the row's ``shape`` and ``dtype`` metadata.

        Raises:
            AssertionError: if the file is missing or a row has missing or
                inconsistent shape/dtype metadata.
        """
        assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}"
        loaded: dict[str, torch.Tensor] = {}
        select_all = "SELECT sequence, embedding, shape, dtype FROM embeddings"
        conn = sqlite3.connect(db_path)
        try:
            self._ensure_embeddings_table(conn)
            cursor = conn.cursor()
            if sequences is None:
                cursor.execute(select_all)
                rows = cursor.fetchall()
            else:
                wanted = list(sequences)
                if len(wanted) == 0:
                    return loaded
                rows = []
                chunk_size = self._SQL_MAX_BOUND_VARIABLES
                # One bound variable per sequence: query in chunks so long
                # lists do not hit SQLite's "too many SQL variables" limit.
                for start in range(0, len(wanted), chunk_size):
                    chunk = wanted[start:start + chunk_size]
                    placeholders = ",".join(["?"] * len(chunk))
                    cursor.execute(
                        f"{select_all} WHERE sequence IN ({placeholders})",
                        tuple(chunk),
                    )
                    rows.extend(cursor.fetchall())
        finally:
            conn.close()

        for sequence, embedding_bytes, shape_text, dtype_text in rows:
            assert shape_text is not None, "Missing shape metadata in embeddings table."
            assert dtype_text is not None, "Missing dtype metadata in embeddings table."
            shape_values = [int(value) for value in shape_text.split(",") if len(value) > 0]
            assert len(shape_values) > 0, f"Invalid shape metadata for sequence: {sequence}"
            expected_size = int(np.prod(shape_values))
            array = np.frombuffer(embedding_bytes, dtype=np.dtype(dtype_text))
            assert array.size == expected_size, f"Shape mismatch while reading sequence: {sequence}"
            # frombuffer yields a read-only view of the blob; copy before wrapping.
            reshaped = array.copy().reshape(tuple(shape_values))
            loaded[sequence] = torch.from_numpy(reshaped)
        return loaded

    def embed_dataset(
        self,
        sequences: List[str],
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        batch_size: int = 2,
        max_len: int = 512,
        truncate: bool = True,
        full_embeddings: bool = False,
        embed_dtype: torch.dtype = torch.float32,
        pooling_types: List[str] = ['mean'],
        num_workers: int = 0,
        sql: bool = False,
        save: bool = True,
        sql_db_path: str = 'embeddings.db',
        save_path: str = 'embeddings.pth',
        **kwargs,
    ) -> Optional[dict[str, torch.Tensor]]:
        """
        Embed a dataset of protein sequences.

        Supports two modes:
        - Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used.
        - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used.

        Sequences are optionally truncated to `max_len` characters,
        de-duplicated and processed longest-first. Sequences already present
        in the target store (`sql_db_path` when `sql=True`, otherwise
        `save_path` if it exists) are skipped.

        Returns:
            `None` when `sql=True` (results are written to the database);
            otherwise the `{sequence: embedding}` dictionary, which is also
            written to `save_path` when `save=True`.
        """
        sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
        sequences = sorted(sequences, key=len, reverse=True)
        hidden_size = self.config.hidden_size
        pooler = Pooler(pooling_types) if not full_embeddings else None
        tokenizer_mode = tokenizer is not None
        if tokenizer_mode:
            collate_fn = build_collator(tokenizer)
            device = self.device
        else:
            collate_fn = None
            device = None

        def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
            # 2-D outputs are passed through unchanged; only 3-D outputs are pooled.
            if full_embeddings or residue_embeddings.ndim == 2:
                return residue_embeddings
            return pooler(residue_embeddings, attention_mask)

        def iter_batches(to_embed: List[str]):
            if tokenizer_mode:
                assert collate_fn is not None
                assert device is not None
                dataset = ProteinDataset(to_embed)
                dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, shuffle=False)
                for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
                    # shuffle=False, so batch i covers this slice of to_embed.
                    seqs = to_embed[i * batch_size:(i + 1) * batch_size]
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    residue_embeddings = self._embed(input_ids, attention_mask)
                    yield seqs, residue_embeddings, attention_mask
            else:
                for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
                    seqs = to_embed[batch_start:batch_start + batch_size]
                    batch_output = self._embed(seqs, return_attention_mask=True, **kwargs)
                    assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)."
                    assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values."
                    residue_embeddings, attention_mask = batch_output
                    assert isinstance(attention_mask, torch.Tensor), "Sequence mode _embed must return attention_mask as a torch.Tensor."
                    yield seqs, residue_embeddings, attention_mask

        if sql:
            conn = sqlite3.connect(sql_db_path)
            # try/finally guarantees the connection is closed even if
            # embedding or an insert raises part-way through.
            try:
                self._ensure_embeddings_table(conn)
                cursor = conn.cursor()
                already_embedded = self._read_sequences_from_db(sql_db_path)
                to_embed = [seq for seq in sequences if seq not in already_embedded]
                print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
                print(f"Embedding {len(to_embed)} new sequences")
                if len(to_embed) > 0:
                    with torch.no_grad():
                        for i, (seqs, residue_embeddings, attention_mask) in enumerate(iter_batches(to_embed)):
                            embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
                            for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                                if full_embeddings:
                                    # Keep only the rows at non-masked positions.
                                    emb = emb[mask.bool()].reshape(-1, hidden_size)
                                emb_np = emb.cpu().numpy()
                                emb_shape = ",".join([str(dim) for dim in emb_np.shape])
                                emb_dtype = str(emb_np.dtype)
                                cursor.execute(
                                    "INSERT OR REPLACE INTO embeddings (sequence, embedding, shape, dtype) VALUES (?, ?, ?, ?)",
                                    (seq, emb_np.tobytes(), emb_shape, emb_dtype),
                                )
                            if tokenizer_mode and (i + 1) % 100 == 0:
                                conn.commit()
                    conn.commit()
            finally:
                conn.close()
            return None

        embeddings_dict = {}
        if os.path.exists(save_path):
            embeddings_dict = self.load_embeddings_from_pth(save_path)
            to_embed = [seq for seq in sequences if seq not in embeddings_dict]
            print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}")
            print(f"Embedding {len(to_embed)} new sequences")
        else:
            to_embed = sequences
            print(f"Embedding {len(to_embed)} new sequences")

        if len(to_embed) > 0:
            with torch.no_grad():
                for seqs, residue_embeddings, attention_mask in iter_batches(to_embed):
                    embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
                    for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                        if full_embeddings:
                            emb = emb[mask.bool()].reshape(-1, hidden_size)
                        embeddings_dict[seq] = emb.cpu()

        if save:
            torch.save(embeddings_dict, save_path)

        return embeddings_dict
358
+
359
+
360
+ """
361
+ FastPLMs-compatible DPLM2 implementation.
362
+ """
363
+
364
+ import torch
365
+ import torch.nn as nn
366
+ from torch.nn import functional as F
367
+ from dataclasses import dataclass
368
+ from typing import List, Optional, Tuple, Union
369
+
370
+ from transformers import EsmTokenizer
371
+ from transformers.modeling_outputs import (
372
+ BaseModelOutputWithPastAndCrossAttentions,
373
+ BaseModelOutputWithPoolingAndCrossAttentions,
374
+ ModelOutput,
375
+ SequenceClassifierOutput,
376
+ TokenClassifierOutput,
377
+ )
378
+ from transformers.models.esm.configuration_esm import EsmConfig
379
+ from transformers.models.esm.modeling_esm import (
380
+ EsmAttention,
381
+ EsmClassificationHead,
382
+ EsmEmbeddings,
383
+ EsmEncoder,
384
+ EsmIntermediate,
385
+ EsmLayer,
386
+ EsmLMHead,
387
+ EsmOutput,
388
+ EsmPooler,
389
+ EsmPreTrainedModel,
390
+ EsmSelfAttention,
391
+ EsmSelfOutput,
392
+ RotaryEmbedding,
393
+ apply_rotary_pos_emb,
394
+ )
395
+
396
+ try:
397
+ from torch.nn.attention.flex_attention import create_block_mask, flex_attention
398
+ except (ImportError, AttributeError):
399
+ create_block_mask = None
400
+ flex_attention = None
401
+
402
+
403
+ from transformers import PreTrainedTokenizerBase
404
+
405
+
406
class BaseSequenceTokenizer:
    """Base wrapper that holds a tokenizer; subclasses implement ``__call__``."""

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        # The wrapped tokenizer is stored as-is for subclasses to use.
        self.tokenizer = tokenizer

    def __call__(self, sequences, **kwargs):
        """Tokenize ``sequences``; not implemented on the base class."""
        raise NotImplementedError
412
+
413
+
414
def get_attention_mask(
    attn_backend: str,
    batch_size: int,
    seq_len: int,
    device: torch.device,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[Optional[torch.Tensor], Optional[object]]:
    """Build the attention mask object for the selected backend.

    Returns ``(attention_mask_4d, flex_block_mask)``:
    - any backend other than ``"flex"``: a boolean ``(b, 1, L, L)`` mask that
      is True where both the query and the key position are valid, and
      ``None`` for the block mask;
    - ``"flex"``: ``None`` for the 4-D mask, and a block mask built with
      ``create_block_mask`` (or ``None`` when no ``attention_mask`` was given).

    When ``attention_mask`` is ``None`` every position is treated as valid.
    """
    mask_given = attention_mask is not None
    if mask_given:
        attention_mask_2d = attention_mask.bool()
    else:
        attention_mask_2d = torch.ones((batch_size, seq_len), device=device).bool()

    if attn_backend != "flex":
        # Outer "and" of the per-token mask: (b, 1, L, 1) & (b, 1, 1, L).
        query_valid = attention_mask_2d[:, None, :, None]
        key_valid = attention_mask_2d[:, None, None, :]
        return query_valid & key_valid, None

    assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
    if not mask_given:
        return None, None

    # The block mask marks the first `valid_lens[b]` positions of each row as
    # valid, which matches the 2-D mask only when valid tokens come first.
    valid_lens = attention_mask_2d.sum(dim=-1)

    def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
        limit = valid_lens[batch_idx]
        return (q_idx < limit) & (kv_idx < limit)

    flex_block_mask = create_block_mask(
        mask_mod,
        batch_size,
        1,
        seq_len,
        seq_len,
        device=device,
    )
    return None, flex_block_mask
451
+
452
+
453
+ def _infer_modality_type(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
454
+ input_mask = attention_mask.bool()
455
+ modality_type = ((input_ids < 33) & input_mask).int()
456
+ modality_type[~input_mask] = 2
457
+ return modality_type
458
+
459
+
460
@dataclass
class DPLM2MaskedLMOutput(ModelOutput):
    """Output container for the DPLM2 masked-LM head; every field is optional."""

    # Scalar training loss, when one was computed.
    loss: Optional[torch.Tensor] = None
    # Prediction scores from the language-model head.
    logits: Optional[torch.Tensor] = None
    # Final-layer hidden representation.
    last_hidden_state: Optional[torch.Tensor] = None
    # Per-layer hidden states, when requested.
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
    # Per-layer attention maps, when requested.
    attentions: Optional[Tuple[torch.Tensor, ...]] = None
467
+
468
+
469
class DPLM2Config(EsmConfig):
    """ESM configuration extended with DPLM2 attention-backend and token-type settings."""

    model_type = "dplm2"

    def __init__(
        self,
        attn_backend: str = "sdpa",
        aa_type: int = 1,
        struct_type: int = 0,
        pad_type: int = 2,
        **kwargs,
    ):
        """Store the DPLM2-specific options on top of the ESM configuration.

        Args:
            attn_backend: Name of the attention implementation to use.
            aa_type: Type id assigned to amino-acid tokens.
            struct_type: Type id assigned to structure tokens.
            pad_type: Type id assigned to padding positions.
            **kwargs: Forwarded unchanged to ``EsmConfig``.
        """
        super().__init__(**kwargs)
        # Token-type ids.
        self.aa_type = aa_type
        self.struct_type = struct_type
        self.pad_type = pad_type
        self.attn_backend = attn_backend
        # Always overridden after the parent constructor, regardless of kwargs.
        self.tie_word_embeddings = False
486
+
487
+
488
class DPLM2PreTrainedModel(EsmPreTrainedModel):
    """Common base for DPLM2 models: config binding, tokenizer and backend switch."""

    config_class = DPLM2Config
    base_model_prefix = "dplm2"
    supports_gradient_checkpointing = True
    # NOTE: evaluated once, when this class body is executed (i.e. at import).
    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
    all_tied_weights_keys = {}

    @classmethod
    def is_remote_code(cls) -> bool:
        """Always report remote code.

        Prevents post-load reinitialization of tensors already loaded from checkpoints.
        """
        return True

    @property
    def attn_backend(self) -> str:
        """Attention backend name, read from the model config."""
        backend = self.config.attn_backend
        return backend

    @attn_backend.setter
    def attn_backend(self, backend: str) -> None:
        # Only the two implemented backends may be written to the config.
        assert backend in ("sdpa", "flex"), f"Unsupported attn_backend: {backend}"
        self.config.attn_backend = backend
508
+
509
+
510
+
511
class ModifiedRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding that restarts positions for the second half of mixed inputs.

    When ``type_ids`` contains both the amino-acid and the structure type,
    the sequence is split into two equal halves along the sequence axis and
    the same cos/sin table (built for half the length) is applied to each.
    """

    def __init__(self, dim: int, aa_type: int, struct_type: int):
        super().__init__(dim)
        self.aa_type = aa_type
        self.struct_type = struct_type

    def _has_multimodal_tokens(self, type_ids: Optional[torch.Tensor]) -> bool:
        """True only if ``type_ids`` is given and holds both aa and struct types."""
        if type_ids is None:
            return False
        has_aa = (type_ids == self.aa_type).any()
        has_struct = (type_ids == self.struct_type).any()
        return bool(has_aa and has_struct)

    def _update_cos_sin_tables(
        self,
        x: torch.Tensor,
        type_ids: Optional[torch.Tensor],
        seq_dimension: int = 2,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return cached cos/sin tables, rebuilding them when length or device changed."""
        seq_len = x.shape[seq_dimension]
        if self._has_multimodal_tokens(type_ids):
            # Each modality occupies half of the sequence axis.
            seq_len = seq_len // 2

        cache_is_fresh = (
            self._cos_cached is not None
            and self._sin_cached is not None
            and seq_len == self._seq_len_cached
            and self._cos_cached.device == x.device
        )
        if not cache_is_fresh:
            self._seq_len_cached = seq_len
            positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(positions, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            # Leading singleton axes broadcast over batch and heads.
            self._cos_cached = emb.cos()[None, None, :, :]
            self._sin_cached = emb.sin()[None, None, :, :]

        return self._cos_cached, self._sin_cached

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        type_ids: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary position embedding to ``q`` and ``k`` and return both."""
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
            k,
            type_ids=type_ids,
            seq_dimension=-2,
        )
        cos, sin = self._cos_cached, self._sin_cached

        if not self._has_multimodal_tokens(type_ids):
            return (
                apply_rotary_pos_emb(q, cos, sin),
                apply_rotary_pos_emb(k, cos, sin),
            )

        # Rotate each half independently with the same table, then re-join.
        q_first, q_second = q.chunk(2, dim=-2)
        k_first, k_second = k.chunk(2, dim=-2)
        q_first = apply_rotary_pos_emb(q_first, cos, sin)
        q_second = apply_rotary_pos_emb(q_second, cos, sin)
        k_first = apply_rotary_pos_emb(k_first, cos, sin)
        k_second = apply_rotary_pos_emb(k_second, cos, sin)
        return torch.cat((q_first, q_second), dim=-2), torch.cat((k_first, k_second), dim=-2)
575
+
576
+
577
class ModifiedEsmSelfAttention(EsmSelfAttention):
    """ESM self-attention using ``ModifiedRotaryEmbedding`` and SDPA / flex kernels.

    The explicit matmul + softmax path is used only when attention
    probabilities are requested; otherwise the backend named by
    ``config.attn_backend`` computes the context directly.
    """

    def __init__(self, config, position_embedding_type=None):
        super().__init__(config, position_embedding_type)
        self.config = config
        self.rotary_embeddings = ModifiedRotaryEmbedding(
            dim=self.attention_head_size,
            aa_type=config.aa_type,
            struct_type=config.struct_type,
        )

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last dimension into heads: ``(b, L, h*d) -> (b, h, L, d)``."""
        split_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(split_shape).permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
        type_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        flex_block_mask: Optional[object] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor]:
        """Compute attention and return ``(context, attention_probs[, past_key_value])``.

        ``attention_probs`` is ``None`` unless ``output_attentions`` is true.
        The cached key/value pair is appended only when ``self.is_decoder``.
        """
        # `past_key_values` is an alias that takes precedence when supplied.
        if past_key_values is not None:
            past_key_value = past_key_values

        mixed_query_layer = self.query(hidden_states)
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            # Cross-attention always attends with the encoder's mask.
            attention_mask = encoder_attention_mask
            if past_key_value is not None:
                key_layer = past_key_value[0]
                value_layer = past_key_value[1]
            else:
                key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
                value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            if past_key_value is not None:
                # Prepend cached keys/values along the sequence axis.
                key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
                value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

        # Scaling is folded into the query, hence scale=1.0 in the kernels below.
        query_layer = self.transpose_for_scores(mixed_query_layer) * self.attention_head_size**-0.5

        if self.is_decoder:
            past_key_value = (key_layer, value_layer)

        if self.position_embedding_type == "rotary":
            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer, type_ids)

        if self.position_embedding_type in ["relative_key", "relative_key_query"]:
            raise NotImplementedError

        query_layer = query_layer.contiguous()
        key_layer = key_layer.contiguous()
        value_layer = value_layer.contiguous()

        attention_probs = None
        if output_attentions:
            # Explicit path so the probabilities can be returned.
            assert attention_mask is not None, "output_attentions=True requires a concrete attention mask."
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            attention_scores = attention_scores.masked_fill(attention_mask.logical_not(), float("-inf"))
            attention_probs = F.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
            context_layer = torch.matmul(attention_probs, value_layer)
        elif self.config.attn_backend == "flex":
            assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable."
            assert query_layer.dtype in (torch.float16, torch.bfloat16), (
                f"Flex attention backend requires float16 or bfloat16, got {query_layer.dtype}."
            )
            assert is_cross_attention is False, "Flex attention backend currently does not support cross-attention."
            assert past_key_value is None, "Flex attention backend currently does not support KV caching."
            assert flex_block_mask is not None, "Flex attention backend requires a block mask."
            context_layer = flex_attention(
                query_layer,
                key_layer,
                value_layer,
                block_mask=flex_block_mask,
                scale=1.0,
            )
        else:
            context_layer = F.scaled_dot_product_attention(
                query_layer,
                key_layer,
                value_layer,
                attn_mask=attention_mask,
                scale=1.0,
            )

        # The head mask is applied to the per-head context, not to the probabilities.
        if head_mask is not None and torch.is_tensor(head_mask):
            context_layer = context_layer * head_mask

        # Merge heads back: (b, h, L, d) -> (b, L, h*d).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(merged_shape)

        outputs = (context_layer, attention_probs)
        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
687
+
688
+
689
class ModifiedEsmAttention(EsmAttention):
    """Pre-LayerNorm attention block built around ``ModifiedEsmSelfAttention``."""

    def __init__(self, config):
        # Initialise as a plain nn.Module so the parent's own submodules are
        # not constructed; the ones below replace them.
        nn.Module.__init__(self)
        self.self = ModifiedEsmSelfAttention(config)
        self.output = EsmSelfOutput(config)
        self.pruned_heads = set()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        head_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: bool = False,
        type_ids: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
    ):
        """Normalise, attend, then combine with the un-normalised input.

        Returns the attention output followed by any extra items produced by
        the self-attention module.
        """
        normed_states = self.LayerNorm(hidden_states)
        attn_results = self.self(
            normed_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
            type_ids,
            flex_block_mask=flex_block_mask,
        )
        # The output projection receives the original (pre-norm) hidden states.
        attention_output = self.output(attn_results[0], hidden_states)
        return (attention_output,) + attn_results[1:]
724
+
725
+
726
class ModifiedEsmLayer(EsmLayer):
    """One transformer layer: self-attention, optional cross-attention, feed-forward.

    ``feed_forward_chunk`` is not defined here; it is resolved through the
    ``EsmLayer`` base class.
    """

    def __init__(self, config):
        # Initialise as a plain nn.Module; EsmLayer.__init__ is not invoked.
        nn.Module.__init__(self)
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = ModifiedEsmAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        # Cross-attention is only valid on a decoder.
        if self.add_cross_attention and self.is_decoder is False:
            raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
        if self.add_cross_attention:
            self.crossattention = ModifiedEsmAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        head_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: bool = False,
        type_ids: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
    ):
        """Apply the layer and return ``(layer_output, *attention_extras[, present_key_value])``.

        The trailing ``present_key_value`` entry is only appended when the
        layer is configured as a decoder.

        Raises:
            AttributeError: if ``encoder_hidden_states`` is given to a decoder
                layer that was built without cross-attention.
        """
        has_cache = past_key_value is not None

        # Self-attention reads the first two cached entries.
        self_attn = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=past_key_value[:2] if has_cache else None,
            type_ids=type_ids,
            flex_block_mask=flex_block_mask,
        )
        attention_output = self_attn[0]

        if self.is_decoder:
            # For decoders the last element is treated as the key/value cache.
            extras, present_key_value = self_attn[1:-1], self_attn[-1]
        else:
            extras = self_attn[1:]

        if self.is_decoder and encoder_hidden_states is not None:
            if self.add_cross_attention is False:
                raise AttributeError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention "
                    "layers by setting `config.add_cross_attention=True`"
                )

            # Cross-attention reads the last two cached entries and never
            # receives modality type ids or a flex block mask.
            cross_attn = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value[-2:] if has_cache else None,
                output_attentions,
                type_ids=None,
                flex_block_mask=None,
            )
            attention_output = cross_attn[0]
            extras = extras + cross_attn[1:-1]
            present_key_value = present_key_value + cross_attn[-1]

        layer_output = self.feed_forward_chunk(attention_output)
        result = (layer_output,) + extras
        if self.is_decoder:
            result = result + (present_key_value,)
        return result
801
+
802
+
803
class ModifiedEsmEncoder(EsmEncoder):
    """Stack of ``ModifiedEsmLayer`` modules followed by a final ``LayerNorm``."""

    def __init__(self, config):
        # Initialise as a plain nn.Module; EsmEncoder.__init__ is not invoked.
        nn.Module.__init__(self)
        self.config = config
        self.layer = nn.ModuleList([ModifiedEsmLayer(config) for _ in range(config.num_hidden_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        head_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[Tuple[torch.FloatTensor]]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        type_ids: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[object] = None,
    ):
        """Run every layer in order and collect the requested side outputs.

        Returns:
            A ``BaseModelOutputWithPastAndCrossAttentions``, or -- when
            ``return_dict`` is ``False`` -- a tuple of the same fields in the
            same order with ``None`` entries dropped.
        """
        # Accumulators are tuples when the output was requested, otherwise None.
        collected_states = () if output_hidden_states else None
        self_attn_maps = () if output_attentions else None
        cross_attn_maps = () if output_attentions and self.config.add_cross_attention else None
        kv_cache = () if use_cache else None

        for idx, block in enumerate(self.layer):
            if output_hidden_states:
                collected_states += (hidden_states,)

            # Both call paths below receive exactly the same positional arguments.
            block_args = (
                hidden_states,
                attention_mask,
                None if head_mask is None else head_mask[idx],
                encoder_hidden_states,
                encoder_attention_mask,
                None if past_key_values is None else past_key_values[idx],
                output_attentions,
                type_ids,
                flex_block_mask,
            )

            if self.gradient_checkpointing and self.training:
                block_out = self._gradient_checkpointing_func(block.__call__, *block_args)
            else:
                block_out = block(*block_args)

            hidden_states = block_out[0]
            if use_cache:
                kv_cache += (block_out[-1],)
            if output_attentions:
                self_attn_maps += (block_out[1],)
                if self.config.add_cross_attention:
                    cross_attn_maps += (block_out[2],)

        if self.emb_layer_norm_after:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        # The post-norm state is recorded as the final hidden-state entry.
        if output_hidden_states:
            collected_states += (hidden_states,)

        if return_dict is False:
            fields = [
                hidden_states,
                kv_cache,
                collected_states,
                self_attn_maps,
                cross_attn_maps,
            ]
            return tuple(field for field in fields if field is not None)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=kv_cache,
            hidden_states=collected_states,
            attentions=self_attn_maps,
            cross_attentions=cross_attn_maps,
        )
898
+
899
+
900
class FAST_DPLM2_ENCODER(DPLM2PreTrainedModel, EmbeddingMixin):
    """Inner encoder class that holds the actual ESM-style weights (embeddings, encoder)
    so that the weight keys are prefixed with 'esm.' in the outer DPLM2Model,
    matching pretrained DPLM2 checkpoints."""

    def __init__(self, config, **kwargs):
        DPLM2PreTrainedModel.__init__(self, config, **kwargs)
        self.config = config
        self.embeddings = EsmEmbeddings(config)
        self.encoder = ModifiedEsmEncoder(config)
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the token embedding module."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the token embedding module with ``value``."""
        self.embeddings.word_embeddings = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the last hidden state for ``input_ids``.

        When ``attention_mask`` is omitted it is derived from the positions
        that differ from ``config.pad_token_id``. Modality type ids are
        inferred with ``_infer_modality_type``.
        """
        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id)
        type_ids = _infer_modality_type(input_ids, attention_mask)
        outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            type_ids=type_ids,
            output_hidden_states=False,
            output_attentions=False,
            return_dict=True,
        )
        return outputs.last_hidden_state

    def _convert_head_mask_to_5d(self, head_mask: torch.Tensor, num_hidden_layers: int) -> torch.Tensor:
        """Broadcast a 1-D or 2-D head mask to five dimensions and cast it to the model dtype.

        A 1-D mask gains two leading and two trailing singleton axes and is
        expanded along the first axis to ``num_hidden_layers``; a 2-D mask
        gains one singleton axis after the first and two trailing ones. Any
        other input must already be 5-D.
        """
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
        assert head_mask.dim() == 5, f"head_mask.dim != 5, got {head_mask.dim()}"
        head_mask = head_mask.to(dtype=self.dtype)
        return head_mask

    def get_head_mask(
        self,
        head_mask: Optional[torch.Tensor],
        num_hidden_layers: int,
        is_attention_chunked: bool = False,
    ) -> Union[torch.Tensor, List[None]]:
        """Return a per-layer head mask, or a list of ``None`` when no mask was given."""
        if head_mask is None:
            return [None] * num_hidden_layers
        head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
        if is_attention_chunked:
            head_mask = head_mask.unsqueeze(-1)
        return head_mask

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        type_ids: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Embed the inputs and run them through the encoder stack.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be supplied.
        ``attention_mask`` may be omitted, 2-D, or 4-D; a 4-D mask requires
        ``input_ids`` because the token-level mask is then rebuilt from the
        non-padding positions.

        Returns:
            A ``BaseModelOutputWithPoolingAndCrossAttentions`` (its
            ``past_key_values`` is always ``None``), or the encoder's tuple
            output when ``return_dict`` resolves to ``False``.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds``
                are given, or ``attention_mask`` has an unsupported rank.
            AssertionError: if a 4-D mask is given without ``input_ids``, or
                ``output_attentions`` is requested with the ``"flex"`` backend.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Caching is only ever enabled for decoder configurations.
        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        # Every path below either assigns a boolean 2-D mask or raises, so
        # attention_mask_2d is never None past this point.
        if attention_mask is None:
            attention_mask_2d = torch.ones((batch_size, seq_length + past_key_values_length), device=device).bool()
        elif attention_mask.dim() == 2:
            attention_mask_2d = attention_mask.bool()
        elif attention_mask.dim() == 4:
            # NOTE(review): the caller's 4-D mask is not forwarded; only the
            # padding-derived 2-D mask is used from here on -- confirm intended.
            assert input_ids is not None, "4D attention_mask requires input_ids to infer token-level mask."
            attention_mask_2d = input_ids.ne(self.config.pad_token_id)
        else:
            raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")

        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = encoder_attention_mask

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if self.config.attn_backend == "flex" and output_attentions:
            raise AssertionError("output_attentions=True is not supported with attn_backend='flex'.")

        attention_mask_4d, flex_block_mask = get_attention_mask(
            attn_backend=self.config.attn_backend,
            batch_size=batch_size,
            seq_len=seq_length,
            device=device,
            attention_mask=attention_mask_2d,
        )

        # The embeddings receive the token-level (2-D) mask, the encoder the
        # backend-specific mask(s) built above.
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask_2d,
            inputs_embeds=inputs_embeds,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask_4d,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            type_ids=type_ids,
            flex_block_mask=flex_block_mask,
        )
        sequence_output = encoder_outputs[0]

        if return_dict is False:
            return (sequence_output,) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            past_key_values=None,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
1061
+
1062
+
1063
class DPLM2Model(DPLM2PreTrainedModel, EmbeddingMixin):
    """Bare DPLM2 model: the ``esm`` encoder plus an optional pooling head."""

    config_class = DPLM2Config

    def __init__(self, config, add_pooling_layer=True):
        DPLM2PreTrainedModel.__init__(self, config)
        self.config = config
        self.esm = FAST_DPLM2_ENCODER(config)
        self.pooler = EsmPooler(config) if add_pooling_layer else None
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the token embedding module of the inner encoder."""
        return self.esm.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the inner encoder."""
        self.esm.embeddings.word_embeddings = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Delegate to the inner encoder's ``_embed``."""
        return self.esm._embed(input_ids, attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        type_ids: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Run the inner encoder and, when a pooler exists, pool its sequence output.

        Returns:
            A ``BaseModelOutputWithPoolingAndCrossAttentions`` (with
            ``past_key_values`` set to ``None``), or the tuple
            ``(sequence_output, pooled_output, *rest)`` when ``return_dict``
            resolves to ``False``.
        """
        # Every argument is forwarded unchanged; the encoder resolves its own defaults.
        encoder_out = self.esm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            type_ids=type_ids,
        )
        sequence_output = encoder_out[0]

        pooled_output = None
        if self.pooler is not None:
            pooled_output = self.pooler(sequence_output)

        if return_dict is None:
            return_dict = self.config.use_return_dict
        if return_dict is False:
            return (sequence_output, pooled_output) + encoder_out[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=None,
            hidden_states=encoder_out.hidden_states,
            attentions=encoder_out.attentions,
            cross_attentions=encoder_out.cross_attentions,
        )
1127
+
1128
+
1129
class DPLM2ForMaskedLM(DPLM2PreTrainedModel, EmbeddingMixin):
    """DPLM2 encoder with a language-modelling head and cross-entropy loss."""

    config_class = DPLM2Config

    def __init__(self, config, dropout: float = 0.1, vocab_size: Optional[int] = None):
        # The supplied config object is modified in place before the modules
        # are built: dropout is overwritten, embedding tying is switched off
        # and the vocabulary size is optionally overridden.
        config.hidden_dropout_prob = dropout
        config.tie_word_embeddings = False
        if vocab_size is not None:
            config.vocab_size = vocab_size
        DPLM2PreTrainedModel.__init__(self, config)
        self.esm = FAST_DPLM2_ENCODER(config)
        self.lm_head = EsmLMHead(config)
        self.loss_fct = nn.CrossEntropyLoss()
        self.post_init()
        self.pad_id = config.pad_token_id
        # Start from the class-level tokenizer and replace it with one loaded
        # from the config's name/path when that is a non-empty string.
        self.tokenizer = type(self).tokenizer
        name_or_path = config._name_or_path
        if isinstance(name_or_path, str) and len(name_or_path) > 0:
            self.tokenizer = EsmTokenizer.from_pretrained(name_or_path)

    def get_input_embeddings(self) -> nn.Module:
        """Return the token embedding module of the inner encoder."""
        return self.esm.get_input_embeddings()

    def get_output_embeddings(self):
        """Return the decoder module of the LM head."""
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the decoder module of the LM head."""
        self.lm_head.decoder = new_embeddings

    def _get_modality_type(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Thin wrapper around ``_infer_modality_type``."""
        return _infer_modality_type(input_ids, attention_mask)

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return the encoder's last hidden state for ``input_ids``.

        A missing ``attention_mask`` is derived from the non-padding positions.
        """
        if attention_mask is None:
            attention_mask = input_ids.ne(self.pad_id)
        modality_ids = self._get_modality_type(input_ids, attention_mask)
        encoded = self.esm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            type_ids=modality_ids,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        return encoded.last_hidden_state

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        type_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        decoder_inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor], DPLM2MaskedLMOutput]:
        """Compute token logits and, when ``labels`` is given, the cross-entropy loss.

        ``decoder_input_ids``, ``decoder_attention_mask`` and
        ``decoder_inputs_embeds`` are accepted but never read. ``input_ids``
        is required whenever ``attention_mask`` or ``type_ids`` has to be
        inferred.

        Returns:
            A ``DPLM2MaskedLMOutput``, or -- when ``return_dict`` resolves to
            ``False`` -- ``(logits, last_hidden_state, hidden_states,
            attentions)`` with ``loss`` prepended if it was computed.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if attention_mask is None:
            assert input_ids is not None
            attention_mask = input_ids.ne(self.pad_id)

        if type_ids is None:
            assert input_ids is not None
            type_ids = self._get_modality_type(input_ids, attention_mask)

        # The encoder is always asked for a dict so the fields can be read by name.
        encoded = self.esm(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            type_ids=type_ids,
        )

        sequence_output = encoded.last_hidden_state
        logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            flat_logits = logits.view(-1, self.config.vocab_size)
            loss = self.loss_fct(flat_logits, labels.view(-1))

        if return_dict is False:
            fields = (logits, sequence_output, encoded.hidden_states, encoded.attentions)
            return fields if loss is None else (loss,) + fields

        return DPLM2MaskedLMOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=sequence_output,
            hidden_states=encoded.hidden_states,
            attentions=encoded.attentions,
        )
1230
+
1231
+
1232
class DPLM2ForSequenceClassification(DPLM2PreTrainedModel, EmbeddingMixin):
    """DPLM2 encoder with a sequence-level classification / regression head."""

    config_class = DPLM2Config

    def __init__(self, config):
        DPLM2PreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels
        self.esm = FAST_DPLM2_ENCODER(config)
        self.classifier = EsmClassificationHead(config)
        # One loss per supported problem type; the choice is made in forward().
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the token embedding module of the inner encoder."""
        return self.esm.get_input_embeddings()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Delegate to the inner encoder's ``_embed``."""
        return self.esm._embed(input_ids, attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        type_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        """Classify the sequence and, when ``labels`` is given, compute the loss.

        When ``config.problem_type`` is unset it is inferred on the first
        labelled call and written back to the config: ``"regression"`` for a
        single label, ``"single_label_classification"`` for several labels
        with integer targets, otherwise ``"multi_label_classification"``.
        Extra ``**kwargs`` are accepted and ignored.

        Returns:
            A ``SequenceClassifierOutput``. When ``return_dict`` resolves to
            ``False`` (explicitly, or through ``config.use_return_dict``) the
            same fields are returned as a tuple with ``None`` entries dropped.
        """
        # Previously the argument was accepted but ignored; resolve it the same
        # way the other model heads in this file do.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if type_ids is None and input_ids is not None:
            if attention_mask is None:
                attention_mask = input_ids.ne(self.config.pad_token_id)
            type_ids = _infer_modality_type(input_ids, attention_mask)

        # The encoder is always asked for a dict so the fields can be read by name.
        outputs = self.esm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            type_ids=type_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = outputs.last_hidden_state
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    # Drop singleton axes so logits and labels line up.
                    loss = self.mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = self.mse(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)

        result = SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        if not return_dict:
            # ModelOutput.to_tuple() keeps field order and skips None fields,
            # giving (loss?, logits, hidden_states?, attentions?).
            return result.to_tuple()
        return result
1307
+
1308
+
1309
class DPLM2ForTokenClassification(DPLM2PreTrainedModel, EmbeddingMixin):
    """DPLM2 encoder with a per-token linear classification head."""

    config_class = DPLM2Config

    def __init__(self, config):
        DPLM2PreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels
        self.esm = FAST_DPLM2_ENCODER(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the token embedding module of the inner encoder."""
        return self.esm.get_input_embeddings()

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Delegate to the inner encoder's ``_embed``."""
        return self.esm._embed(input_ids, attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        type_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        """Classify every token and, when ``labels`` is given, compute cross-entropy loss.

        When ``type_ids`` is missing and ``input_ids`` is available, the
        modality type ids (and, if needed, a padding-based attention mask) are
        inferred. Extra ``**kwargs`` are accepted and ignored.

        Returns:
            A ``TokenClassifierOutput``. When ``return_dict`` resolves to
            ``False`` (explicitly, or through ``config.use_return_dict``) the
            same fields are returned as a tuple with ``None`` entries dropped.
        """
        # Previously the argument was accepted but ignored; resolve it the same
        # way the other model heads in this file do.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if type_ids is None and input_ids is not None:
            if attention_mask is None:
                attention_mask = input_ids.ne(self.config.pad_token_id)
            type_ids = _infer_modality_type(input_ids, attention_mask)

        # The encoder is always asked for a dict so the fields can be read by name.
        outputs = self.esm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            type_ids=type_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Flatten all leading axes so each token is one classification example.
            loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        result = TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        if not return_dict:
            # ModelOutput.to_tuple() keeps field order and skips None fields,
            # giving (loss?, logits, hidden_states?, attentions?).
            return result.to_tuple()
        return result