"""
ESM2-Flash: ESM2 with flash attention and packed-sequence support.

Drop-in replacement for HuggingFace's EsmModel / EsmForMaskedLM with three
attention backends:

- flash_attn_varlen_func (packed sequences via cu_seqlens)
- scaled_dot_product_attention (default for padded sequences)
- eager matmul (when output_attentions=True)

Weight names are identical to the original ESM2 so pretrained checkpoints
load with strict=True.
"""
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn.functional import scaled_dot_product_attention
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
MaskedLMOutput,
)
from transformers.modeling_utils import PreTrainedModel
from .configuration_esm2_flash import Esm2FlashConfig
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func
FLASH_ATTN_AVAILABLE = True
except ImportError:
FLASH_ATTN_AVAILABLE = False
# ---------------------------------------------------------------------------
# Helper functions (matching original ESM2 exactly)
# ---------------------------------------------------------------------------
def rotate_half(x):
    """Map the two halves ``(a, b)`` of the last dimension to ``(-b, a)``.

    The last dimension is split into two chunks; the negated second chunk is
    placed in front of the first one.
    """
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    return torch.cat([second_half.neg(), first_half], dim=-1)
def apply_rotary_pos_emb(x, cos, sin):
    """Rotate ``x`` by the angles encoded in ``cos`` / ``sin``.

    Two layouts are accepted:

    Standard (original ESM2):
        x:   (batch, heads, seq, dim)
        cos: (1, 1, seq, dim)
        sin: (1, 1, seq, dim)
    Packed:
        x:   (total_tokens, heads, dim)
        cos: (total_tokens, 1, dim)
        sin: (total_tokens, 1, dim)

    Only the 4-D layout trims the tables to the sequence length of ``x``;
    any other rank uses ``cos`` / ``sin`` as given and relies on broadcasting.
    """
    if x.dim() == 4:
        # Cached tables may be longer than this input; keep the leading rows.
        seq_len = x.shape[-2]
        cos, sin = cos[:, :, :seq_len, :], sin[:, :, :seq_len, :]

    rotated = rotate_half(x)
    return (x * cos) + (rotated * sin)
def gelu(x):
    """Erf-based GELU, written out as in the original ESM code.

    The original note on this function says that swapping in ``F.gelu``
    yields subtly different results, so the explicit formula is kept.
    """
    one_plus_erf = 1.0 + torch.erf(x / math.sqrt(2.0))
    return x * 0.5 * one_plus_erf
def symmetrize(x):
    """Return ``x`` plus its transpose over the last two dimensions.

    Used for contact prediction, where the output must be symmetric.
    """
    transposed = torch.transpose(x, -2, -1)
    return x + transposed
def average_product_correct(x):
    """Apply the average product correction used for contact prediction.

    Subtracts ``row_sum * col_sum / total_sum`` (all taken over the last two
    dimensions) from every entry of ``x``.
    """
    row_sums = x.sum(dim=-1, keepdim=True)
    col_sums = x.sum(dim=-2, keepdim=True)
    total = x.sum(dim=(-1, -2), keepdim=True)

    # The outer-product temporary is divided in place; ``x`` is not modified.
    expected = (row_sums * col_sums).div_(total)
    return x - expected
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    """Number the non-padding tokens of each row, starting at ``padding_idx + 1``.

    Padding tokens keep the value ``padding_idx``. ``past_key_values_length``
    is added to the running count of every non-padding token.
    """
    not_pad = input_ids.ne(padding_idx).int()
    running_count = torch.cumsum(not_pad, dim=1).type_as(not_pad)
    # Multiplying by the mask resets padded slots to zero before the offset.
    positions = (running_count + past_key_values_length) * not_pad
    return positions.long() + padding_idx
# ---------------------------------------------------------------------------
# Rotary embeddings (extended with position_ids support for packing)
# ---------------------------------------------------------------------------
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on RoFormer. Extended to accept explicit
    position_ids for packed-sequence support.

    Two ways of producing the cos/sin tables are provided:

    - ``_update_cos_sin_tables``: positions ``0 .. seq_len - 1``, cached per
      (sequence length, device). Used when ``position_ids`` is not given.
    - ``_compute_from_position_ids``: arbitrary per-token positions, computed
      on every call. Used for packed sequences.
    """

    def __init__(self, dim: int):
        """Create the inverse-frequency buffer for a head dimension of ``dim``."""
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=2):
        """Return cached ``(cos, sin)`` tables of shape ``(1, 1, seq, dim)``.

        The tables are rebuilt when the sequence length of ``x`` (read from
        ``seq_dimension``) or its device differs from the cached one.
        """
        seq_len = x.shape[seq_dimension]

        # On the first call ``_seq_len_cached`` is None, so the left operand
        # is True and ``_cos_cached.device`` is never read while it is None.
        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, None, :, :]
            self._sin_cached = emb.sin()[None, None, :, :]

        return self._cos_cached, self._sin_cached

    def _compute_from_position_ids(self, position_ids, device, dtype):
        """Compute cos/sin tables from explicit position_ids (for packed sequences).

        Args:
            position_ids: (total_tokens,) int tensor, 0-indexed per sub-sequence
            device: device the tables are built on
            dtype: dtype of the returned tables; ``None`` keeps float32

        Returns:
            cos: (total_tokens, 1, dim)
            sin: (total_tokens, 1, dim)
        """
        # Angles are computed in float32, then cast once at the end. Without
        # the cast, float32 tables would promote fp16/bf16 queries and keys to
        # float32 when multiplied in ``apply_rotary_pos_emb``.
        t = position_ids.to(device=device, dtype=torch.float32)
        inv_freq = self.inv_freq.to(device=device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)

        cos = emb.cos().unsqueeze(1)  # (total_tokens, 1, dim)
        sin = emb.sin().unsqueeze(1)
        if dtype is not None:
            cos = cos.to(dtype)
            sin = sin.to(dtype)
        return cos, sin

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to ``q`` and ``k``.

        Args:
            q, k: query/key tensors.
                Standard: (batch, heads, seq, dim)
                Packed:   (total_tokens, heads, dim)
            position_ids: optional (total_tokens,) for packed mode

        Returns:
            The rotated ``(q, k)`` pair.
        """
        if position_ids is not None:
            # Packed path: tables follow the explicit positions and q's dtype.
            cos, sin = self._compute_from_position_ids(position_ids, q.device, q.dtype)
        else:
            # Standard path (original ESM2 behaviour)
            cos, sin = self._update_cos_sin_tables(k, seq_dimension=-2)

        return (
            apply_rotary_pos_emb(q, cos, sin),
            apply_rotary_pos_emb(k, cos, sin),
        )
# ---------------------------------------------------------------------------
# Contact prediction head (unchanged from ESM2)
# ---------------------------------------------------------------------------
class EsmContactPredictionHead(nn.Module):
    """Contact head: symmetrize, apply APC, then logistic regression over attention maps."""

    def __init__(self, in_features: int, bias=True, eos_idx: int = 2):
        """Build a 1-output linear layer over ``in_features`` attention maps."""
        super().__init__()
        self.in_features = in_features
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias)
        self.activation = nn.Sigmoid()

    def forward(self, tokens, attentions):
        """Return contact probabilities from stacked attention maps.

        ``attentions`` is unpacked as (batch, layers, heads, seq, seq) below;
        ``tokens`` is compared against ``eos_idx`` to mask EOS rows/columns.
        """
        # Zero every attention entry whose row or column is an EOS token.
        not_eos = tokens.ne(self.eos_idx).to(attentions)
        pair_mask = not_eos.unsqueeze(1) * not_eos.unsqueeze(2)
        attentions = attentions * pair_mask[:, None, None, :, :]

        # Drop the last and the first position along both sequence axes.
        attentions = attentions[..., 1:-1, 1:-1]

        batch_size, layers, heads, seqlen, _ = attentions.size()
        features = attentions.view(batch_size, layers * heads, seqlen, seqlen)
        features = average_product_correct(symmetrize(features))

        # Channels last so the regression consumes one feature vector per pair.
        features = features.permute(0, 2, 3, 1)
        logits = self.regression(features).squeeze(3)
        return self.activation(logits)
# ---------------------------------------------------------------------------
# Embeddings
# ---------------------------------------------------------------------------
class Esm2FlashEmbeddings(nn.Module):
    """
    Same as EsmEmbeddings with packed-sequence support for token_dropout.

    Token dropout zeroes the embeddings of mask tokens and rescales the
    remaining ones by ``(1 - 0.15 * 0.8) / (1 - observed_mask_ratio)``. For
    padded batches the ratio is computed per row; for packed input
    (``cu_seqlens`` given) it is computed per packed sub-sequence.
    """

    def __init__(self, config):
        """Create embedding tables, optional pre-layer-norm and dropout from ``config``."""
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        if config.emb_layer_norm_before:
            self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        else:
            self.layer_norm = None
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
        )
        self.token_dropout = config.token_dropout
        self.mask_token_id = config.mask_token_id

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        inputs_embeds=None,
        past_key_values_length=0,
        cu_seqlens=None,
    ):
        """Embed tokens, apply token dropout, positions, layer norm and masking.

        Args:
            input_ids: token ids; optional when ``inputs_embeds`` is given.
            attention_mask: optional mask multiplied into the output; in the
                padded path it also provides the per-row length for token
                dropout (the full row length is used when it is None).
            position_ids: optional; derived from ``input_ids`` or
                ``inputs_embeds`` when None.
            inputs_embeds: optional precomputed embeddings.
            past_key_values_length: offset used when deriving position ids.
            cu_seqlens: optional cumulative sequence lengths; selects the
                packed token-dropout path, which reads row 0 of ``input_ids``.

        Returns:
            The embedding tensor, in the dtype of the input embeddings.
        """
        if position_ids is None:
            if input_ids is not None:
                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
            else:
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds

        # Token dropout needs the token ids to locate mask tokens. With only
        # ``inputs_embeds`` there is nothing to compare, so it is skipped
        # rather than failing on ``None == mask_token_id``.
        if self.token_dropout and input_ids is not None:
            is_mask = input_ids == self.mask_token_id
            embeddings = embeddings.masked_fill(is_mask.unsqueeze(-1), 0.0)
            mask_ratio_train = 0.15 * 0.8

            if cu_seqlens is not None:
                per_token_scale = self._packed_token_dropout_scale(is_mask[0], cu_seqlens, 1 - mask_ratio_train)
                per_token_scale = per_token_scale.to(embeddings.dtype)
                embeddings = (embeddings * per_token_scale[None, :, None]).to(embeddings.dtype)
            else:
                if attention_mask is not None:
                    src_lengths = attention_mask.sum(-1)
                else:
                    # No mask supplied: every position of the row counts.
                    src_lengths = input_ids.shape[1]
                mask_ratio_observed = is_mask.sum(-1).float() / src_lengths
                embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
                    embeddings.dtype
                )

        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings

        if self.layer_norm is not None:
            embeddings = self.layer_norm(embeddings)
        if attention_mask is not None:
            embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
        return embeddings

    @staticmethod
    def _packed_token_dropout_scale(is_mask, cu_seqlens, keep_ratio_train):
        """Return the per-token rescale factor for a packed row, shape (total_tokens,).

        Each token inside ``[cu_seqlens[i], cu_seqlens[i + 1])`` receives
        ``keep_ratio_train / (1 - mask_fraction_of_sequence_i)``. Tokens not
        covered by any sequence receive 0. Computed with prefix sums and a
        ``searchsorted`` lookup, so no per-sequence Python loop is needed.
        """
        total_tokens = is_mask.shape[0]
        device = is_mask.device
        bounds = cu_seqlens.to(device=device, dtype=torch.long)
        num_seqs = bounds.numel() - 1
        if num_seqs < 1:
            return torch.zeros(total_tokens, device=device)

        seq_lengths = (bounds[1:] - bounds[:-1]).float()

        # prefix[j] = number of mask tokens among the first j tokens.
        prefix = torch.zeros(total_tokens + 1, dtype=torch.long, device=device)
        prefix[1:] = torch.cumsum(is_mask.long(), dim=0)
        clipped = bounds.clamp(min=0, max=total_tokens)
        mask_counts = (prefix[clipped[1:]] - prefix[clipped[:-1]]).float()

        scale = keep_ratio_train / (1 - mask_counts / seq_lengths)  # (num_seqs,)

        # right=True maps token t to the i with bounds[i] <= t < bounds[i + 1].
        token_idx = torch.arange(total_tokens, device=device)
        seq_idx = torch.searchsorted(bounds, token_idx, right=True) - 1
        covered = (seq_idx >= 0) & (seq_idx < num_seqs)
        per_token = scale[seq_idx.clamp(min=0, max=num_seqs - 1)]
        return torch.where(covered, per_token, torch.zeros_like(per_token))

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """Return sequential position ids starting at ``padding_idx + 1``.

        Used when only embeddings are available, so padding cannot be detected.
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        position_ids = torch.arange(
            self.padding_idx + 1,
            sequence_length + self.padding_idx + 1,
            dtype=torch.long,
            device=inputs_embeds.device,
        )
        return position_ids.unsqueeze(0).expand(input_shape)
# ---------------------------------------------------------------------------
# Attention
# ---------------------------------------------------------------------------
class Esm2FlashSelfAttention(nn.Module):
"""Self-attention with three backends: flash, SDPA, and eager."""
def __init__(self, config, position_embedding_type=None):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
self.rotary_embeddings = None
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
elif self.position_embedding_type == "rotary":
self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
"""Reshape (batch, seq, hidden) -> (batch, heads, seq, dim)."""
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
position_ids: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
) -> Tuple[torch.Tensor, ...]:
batch_size, seq_len, _ = hidden_states.shape
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
# ESM2-specific: scale query before rotary (not the scores)
query_layer = query_layer * self.attention_head_size**-0.5
# --- Flash attention path (packed sequences) ---
if cu_seqlens is not None:
assert FLASH_ATTN_AVAILABLE, (
"flash_attn is required for packed sequences. "
"Install with: pip install flash-attn --no-build-isolation"
)
assert not output_attentions, "output_attentions is not supported with packed sequences."
assert batch_size == 1, "Packed sequences require batch_size=1."
# Reshape to (total_tokens, heads, dim) for flash_attn_varlen
q = query_layer.squeeze(0).transpose(0, 1) # (heads, seq, dim) -> (seq, heads, dim)
k = key_layer.squeeze(0).transpose(0, 1)
v = value_layer.squeeze(0).transpose(0, 1)
# Apply rotary with explicit position_ids
if self.rotary_embeddings is not None:
# position_ids: (1, total_tokens) -> (total_tokens,)
pos_ids = position_ids.squeeze(0) if position_ids is not None else None
q, k = self.rotary_embeddings(q, k, position_ids=pos_ids)
# Flash attention requires fp16 or bf16
input_dtype = q.dtype
if input_dtype == torch.float32:
q = q.to(torch.bfloat16)
k = k.to(torch.bfloat16)
v = v.to(torch.bfloat16)
context_layer = flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=self.dropout.p if self.training else 0.0,
causal=False,
softmax_scale=1.0, # Q is already scaled
)
# Cast back to input dtype
if input_dtype == torch.float32:
context_layer = context_layer.to(input_dtype)
# (total_tokens, heads, dim) -> (1, total_tokens, hidden_size)
context_layer = context_layer.reshape(1, seq_len, self.all_head_size)
return (context_layer,)
# --- Standard paths (padded sequences) ---
# Apply rotary with sequential positions (original ESM2 behaviour)
if self.position_embedding_type == "rotary":
query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
# --- Eager path (output_attentions=True) ---
if output_attentions:
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype)
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum(
"bhld,lrd->bhlr", query_layer, positional_embedding
)
relative_position_scores_key = torch.einsum(
"bhrd,lrd->bhlr", key_layer, positional_embedding
)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
if attention_mask is not None:
attention_scores = attention_scores + attention_mask
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
attention_probs = self.dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs.to(value_layer.dtype), value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return (context_layer, attention_probs)
# --- SDPA path (default for padded sequences) ---
context_layer = scaled_dot_product_attention(
query=query_layer,
key=key_layer,
value=value_layer,
attn_mask=attention_mask,
dropout_p=self.dropout.p if self.training else 0.0,
scale=1.0, # Q is already scaled
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return (context_layer,)
class EsmSelfOutput(nn.Module):
    """Attention output projection followed by dropout and a residual add."""

    def __init__(self, config):
        """Create a hidden-to-hidden linear layer and its dropout."""
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``dropout(dense(hidden_states)) + input_tensor``."""
        projected = self.dropout(self.dense(hidden_states))
        return projected + input_tensor
class Esm2FlashAttention(nn.Module):
    """Pre-LayerNorm attention block: LayerNorm -> self-attention -> output projection + residual."""

    def __init__(self, config):
        """Create the self-attention, output projection and pre-attention LayerNorm."""
        super().__init__()
        self.self = Esm2FlashSelfAttention(config)
        self.output = EsmSelfOutput(config)
        self.pruned_heads = set()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        position_ids=None,
        cu_seqlens=None,
        max_seqlen=None,
    ):
        """Return ``(attention_output, *extras)``.

        ``extras`` is whatever the self-attention module returns after its
        context tensor. The residual uses the un-normalised ``hidden_states``.
        """
        normed = self.LayerNorm(hidden_states)
        context, *extras = self.self(
            normed,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        attention_output = self.output(context, hidden_states)
        return (attention_output, *extras)
# ---------------------------------------------------------------------------
# Feed-forward
# ---------------------------------------------------------------------------
class EsmIntermediate(nn.Module):
    """Feed-forward expansion: linear to ``intermediate_size`` followed by GELU."""

    def __init__(self, config):
        """Create the hidden-to-intermediate linear layer."""
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``gelu(dense(hidden_states))``."""
        return gelu(self.dense(hidden_states))
class EsmOutput(nn.Module):
    """Feed-forward contraction back to ``hidden_size`` with dropout and a residual add."""

    def __init__(self, config):
        """Create the intermediate-to-hidden linear layer and its dropout."""
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``dropout(dense(hidden_states)) + input_tensor``."""
        contracted = self.dropout(self.dense(hidden_states))
        return contracted + input_tensor
# ---------------------------------------------------------------------------
# Transformer layer
# ---------------------------------------------------------------------------
class Esm2FlashLayer(nn.Module):
    """One transformer layer: attention block followed by a pre-LayerNorm feed-forward block."""

    def __init__(self, config):
        """Create the attention block, feed-forward layers and feed-forward LayerNorm."""
        super().__init__()
        self.attention = Esm2FlashAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        position_ids=None,
        cu_seqlens=None,
        max_seqlen=None,
    ):
        """Return ``(layer_output, *attention_extras)``.

        The extras are the trailing items returned by the attention block
        (the attention probabilities when ``output_attentions`` is set).
        """
        attention_output, *attention_extras = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        layer_output = self.feed_forward_chunk(attention_output)
        return (layer_output, *attention_extras)

    def feed_forward_chunk(self, attention_output):
        """Apply LayerNorm -> intermediate -> output, with a residual from the un-normalised input."""
        normed = self.LayerNorm(attention_output)
        expanded = self.intermediate(normed)
        return self.output(expanded, attention_output)
# ---------------------------------------------------------------------------
# Encoder (stack of layers)
# ---------------------------------------------------------------------------
class Esm2FlashEncoder(nn.Module):
    """Stack of ``Esm2FlashLayer`` modules followed by a final LayerNorm."""

    def __init__(self, config):
        """Create ``config.num_hidden_layers`` layers and the closing LayerNorm."""
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([Esm2FlashLayer(config) for _ in range(config.num_hidden_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        position_ids=None,
        cu_seqlens=None,
        max_seqlen=None,
    ):
        """Run every layer in order and collect the optional traces.

        When ``output_hidden_states`` is set, the input of every layer plus
        the final (normalised) state is recorded. When ``output_attentions``
        is set, the second item of every layer's output is recorded.

        Returns:
            A ``BaseModelOutputWithPastAndCrossAttentions``, or, when
            ``return_dict`` is False, a tuple of the non-None values among
            (last hidden state, hidden states, attentions).
        """
        hidden_trace = [] if output_hidden_states else None
        attention_trace = [] if output_attentions else None
        use_checkpointing = self.gradient_checkpointing and self.training

        for index, layer_module in enumerate(self.layer):
            if output_hidden_states:
                hidden_trace.append(hidden_states)

            layer_head_mask = None if head_mask is None else head_mask[index]

            if use_checkpointing:
                # Arguments are positional here and follow the layer's forward signature.
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    output_attentions,
                    position_ids,
                    cu_seqlens,
                    max_seqlen,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=attention_mask,
                    head_mask=layer_head_mask,
                    output_attentions=output_attentions,
                    position_ids=position_ids,
                    cu_seqlens=cu_seqlens,
                    max_seqlen=max_seqlen,
                )

            hidden_states = layer_outputs[0]
            if output_attentions:
                attention_trace.append(layer_outputs[1])

        if self.emb_layer_norm_after:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        if output_hidden_states:
            hidden_trace.append(hidden_states)

        all_hidden_states = tuple(hidden_trace) if output_hidden_states else None
        all_self_attentions = tuple(attention_trace) if output_attentions else None

        if not return_dict:
            candidates = [hidden_states, all_hidden_states, all_self_attentions]
            return tuple(value for value in candidates if value is not None)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
# ---------------------------------------------------------------------------
# Pooler
# ---------------------------------------------------------------------------
class EsmPooler(nn.Module):
    """Pool a sequence by passing its first token through a linear layer and tanh."""

    def __init__(self, config):
        """Create the hidden-to-hidden linear layer and the tanh activation."""
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``tanh(dense(hidden_states[:, 0]))``."""
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))
# ---------------------------------------------------------------------------
# LM Head
# ---------------------------------------------------------------------------
class EsmLMHead(nn.Module):
    """ESM head for masked language modeling.

    Pipeline: dense -> GELU -> LayerNorm -> bias-free decoder, with a
    separately stored bias parameter added to the logits.
    """

    def __init__(self, config):
        """Create the transform layers, the decoder and the output bias."""
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, features, **kwargs):
        """Return vocabulary logits for ``features``; extra keyword arguments are ignored."""
        transformed = self.layer_norm(gelu(self.dense(features)))
        return self.decoder(transformed) + self.bias
# ---------------------------------------------------------------------------
# PreTrainedModel base
# ---------------------------------------------------------------------------
class Esm2FlashPreTrainedModel(PreTrainedModel):
    """Base class wiring ESM2-Flash models into ``PreTrainedModel``.

    Declares the config class, the base-model prefix, gradient-checkpointing
    support and the modules that must not be split, and provides weight
    initialisation.
    """

    config_class = Esm2FlashConfig
    base_model_prefix = "esm"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Esm2FlashLayer", "Esm2FlashEmbeddings"]

    def _init_weights(self, module):
        """Initialise ``module`` in place.

        Linear and Embedding weights are drawn from a normal distribution with
        std ``config.initializer_range``; Linear biases and the Embedding
        padding row are zeroed. LayerNorm gets weight 1 and bias 0. Other
        module types are left untouched.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    module.bias.data.zero_()
            elif module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
# ---------------------------------------------------------------------------
# Esm2FlashModel
# ---------------------------------------------------------------------------
class Esm2FlashModel(Esm2FlashPreTrainedModel):
    """
    ESM2 encoder with flash attention and packed-sequence support.

    Accepts the same inputs as EsmModel, plus:
        cu_seqlens: int32 tensor of cumulative sequence lengths for packing
        max_seqlen: maximum sequence length in the packed batch

    NOTE(review): the pooler reads ``hidden_states[:, 0]`` and the packed path
    requires batch size 1, so with packed input ``pooler_output`` is built from
    the very first token of the packed row only. Confirm that is intended.
    """

    def __init__(self, config, add_pooling_layer=True):
        """Build embeddings, encoder, optional pooler and the contact head."""
        super().__init__(config)
        self.config = config

        self.embeddings = Esm2FlashEmbeddings(config)
        self.encoder = Esm2FlashEncoder(config)

        self.pooler = EsmPooler(config) if add_pooling_layer else None

        self.contact_head = EsmContactPredictionHead(
            in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
        )

        self.post_init()

    def get_input_embeddings(self):
        """Return the word-embedding module."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding module."""
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Encode a padded batch or, when ``cu_seqlens`` is given, one packed row.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be provided.
        In packed mode ``max_seqlen`` is mandatory, the batch size must be 1,
        attentions cannot be returned, and ``attention_mask`` is not used.

        Returns:
            A ``BaseModelOutputWithPoolingAndCrossAttentions`` or, when
            ``return_dict`` is False, ``(sequence_output, pooled_output)``
            followed by the remaining encoder outputs.

        Raises:
            ValueError: if both or neither of ``input_ids`` and
                ``inputs_embeds`` are given.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            device = input_ids.device
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            device = inputs_embeds.device
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape

        if cu_seqlens is None:
            # --- Standard padded path ---
            if attention_mask is None:
                attention_mask = torch.ones(((batch_size, seq_length)), device=device)

            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
            head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

            embedding_output = self.embeddings(
                input_ids=input_ids,
                position_ids=position_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
            )
            encoder_outputs = self.encoder(
                embedding_output,
                attention_mask=extended_attention_mask,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        else:
            # --- Packed sequence path ---
            assert max_seqlen is not None, "max_seqlen must be provided when cu_seqlens is not None"
            assert batch_size == 1, "Packed sequences require batch_size=1"
            assert not output_attentions, "output_attentions is not supported with packed sequences"

            # Compute rotary-compatible position_ids if not provided.
            # For packed sequences, positions restart at 0 in every sub-sequence.
            if position_ids is None:
                position_ids = torch.zeros(1, seq_length, dtype=torch.long, device=device)
                boundaries = cu_seqlens.tolist()
                for start, end in zip(boundaries[:-1], boundaries[1:]):
                    position_ids[0, start:end] = torch.arange(end - start, device=device)

            embedding_output = self.embeddings(
                input_ids=input_ids,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
                cu_seqlens=cu_seqlens,
            )

            head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

            encoder_outputs = self.encoder(
                embedding_output,
                head_mask=head_mask,
                output_attentions=False,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                position_ids=position_ids,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
            )

        sequence_output = encoder_outputs[0]
        pooled_output = None if self.pooler is None else self.pooler(sequence_output)

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def predict_contacts(self, tokens, attention_mask):
        """Return contact predictions from the attention maps of all layers.

        Runs the padded path with ``output_attentions=True``, stacks the
        per-layer maps along a new dimension 1, masks padded rows and columns
        with ``attention_mask`` and feeds the result to the contact head.
        """
        outputs = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True)
        attns = torch.stack(outputs.attentions, dim=1)

        # Broadcast the (batch, seq) mask over the key axis, then the query axis.
        key_mask = attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        query_mask = attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
        attns *= key_mask
        attns *= query_mask
        return self.contact_head(tokens, attns)
# ---------------------------------------------------------------------------
# Esm2FlashForMaskedLM
# ---------------------------------------------------------------------------
class Esm2FlashForMaskedLM(Esm2FlashPreTrainedModel):
    """ESM2-Flash encoder (without pooler) topped with a masked-language-modeling head."""

    _tied_weights_keys = ["lm_head.decoder.weight"]

    def __init__(self, config):
        """Build the encoder and the LM head, then initialise weights."""
        super().__init__(config)

        self.esm = Esm2FlashModel(config, add_pooling_layer=False)
        self.lm_head = EsmLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        """Return the decoder layer of the LM head."""
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the decoder layer of the LM head."""
        self.lm_head.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        """Run the encoder, compute vocabulary logits and, if ``labels`` is given, the loss.

        The loss is ``CrossEntropyLoss`` with its default settings over the
        flattened logits and labels.

        Returns:
            A ``MaskedLMOutput`` or, when ``return_dict`` is False, a tuple
            ``(loss?, logits, *encoder_extras)`` where the loss is present
            only when labels were supplied.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        prediction_scores = self.lm_head(outputs[0])

        masked_lm_loss = None
        if labels is not None:
            # Labels may live on another device than the logits.
            targets = labels.to(prediction_scores.device)
            flat_scores = prediction_scores.view(-1, self.config.vocab_size)
            masked_lm_loss = CrossEntropyLoss()(flat_scores, targets.view(-1))

        if not return_dict:
            # Index 1 of the encoder tuple is the pooled output; skip it.
            output = (prediction_scores,) + outputs[2:]
            if masked_lm_loss is None:
                return output
            return (masked_lm_loss,) + output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def predict_contacts(self, tokens, attention_mask):
        """Delegate contact prediction to the wrapped encoder."""
        return self.esm.predict_contacts(tokens, attention_mask=attention_mask)
|