Update modeling_esm_plusplus.py
modeling_esm_plusplus.py  (+73, -17)  CHANGED

@@ -16,7 +16,7 @@ import torch.nn.functional as F
 from dataclasses import dataclass
 from functools import cache, partial
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 from einops import rearrange, repeat
 from huggingface_hub import snapshot_download
 from tokenizers import Tokenizer
@@ -48,6 +48,7 @@ class ESMplusplusConfig(PretrainedConfig):
         num_hidden_layers: int = 30,
         num_labels: int = 2,
         problem_type: str | None = None,
+        dropout: float = 0.0,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -57,6 +58,7 @@ class ESMplusplusConfig(PretrainedConfig):
         self.num_hidden_layers = num_hidden_layers
         self.num_labels = num_labels
         self.problem_type = problem_type
+        self.dropout = dropout


 ### Rotary Embeddings
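The two config hunks above add a dropout field that is stored on the config but, per the TODO added to ESMplusplusForMaskedLM further down, not yet consumed by any layer. A minimal sketch of constructing the config with the new argument, assuming the file is importable as a module and the remaining constructor arguments keep their defaults:

from modeling_esm_plusplus import ESMplusplusConfig

# dropout defaults to 0.0 and is simply recorded as config.dropout for now;
# wiring it into nn.Dropout layers is listed as a TODO in this commit.
config = ESMplusplusConfig(num_hidden_layers=30, num_labels=2, dropout=0.1)
print(config.dropout)  # 0.1
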
@@ -290,15 +292,17 @@ class MultiHeadAttention(nn.Module):
         k = k.flatten(-2, -1)
         return q, k

-    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """
         Args:
             x: Input tensor
             attention_mask: Optional attention mask
+            output_attentions: Whether to return attention weights

         Returns:
-            Output tensor after self attention
+            Output tensor after self attention, and optionally attention weights
         """
+        attn_weights = None
         qkv_BLD3 = self.layernorm_qkv(x)
         query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
         query_BLD, key_BLD = (
@@ -307,11 +311,29 @@ class MultiHeadAttention(nn.Module):
         )
         query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
         query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
-        context_BHLD = F.scaled_dot_product_attention(
-            query_BHLD, key_BHLD, value_BHLD, attention_mask
-        )
+
+        if output_attentions: # Manual attention computation
+            L, S = query_BLD.size(-2), key_BLD.size(-2)
+            scale = 1 / math.sqrt(query_BLD.size(-1))
+            attn_bias = torch.zeros(L, S, dtype=query_BLD.dtype, device=query_BLD.device)
+            if attention_mask is not None:
+                if attention_mask.dtype == torch.bool:
+                    attention_mask.masked_fill_(attention_mask.logical_not(), float('-inf'))
+                else:
+                    attn_bias += attention_mask
+
+            attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
+            attn_weights += attn_bias
+            attn_weights = F.softmax(attn_weights, dim=-1)
+            context_BHLD = torch.matmul(attn_weights, value_BHLD)
+        else:
+            context_BHLD = F.scaled_dot_product_attention(
+                query_BHLD, key_BHLD, value_BHLD, attention_mask
+            )
+
         context_BLD = rearrange(context_BHLD, "b h s d -> b s (h d)")
-        return self.out_proj(context_BLD)
+        output = self.out_proj(context_BLD)
+        return output, attn_weights


 ### Regression Head
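The fused F.scaled_dot_product_attention path is unchanged and is still taken by default; the manual branch only runs when output_attentions=True, materializing the softmax weights so they can be returned. Note that the manual branch derives its scale from query_BLD.size(-1) (the hidden size before the heads are split out), whereas F.scaled_dot_product_attention defaults to 1/sqrt(head_dim). A small self-contained check, independent of the module above, showing that plain matmul-softmax attention with the per-head scale reproduces the fused kernel when no mask is applied:

import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq, head_dim = 2, 4, 8, 16
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)

# Manual attention, mirroring the output_attentions branch (no mask, per-head scale).
scale = 1 / math.sqrt(head_dim)
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
attn_weights = F.softmax(attn_weights, dim=-1)
manual = torch.matmul(attn_weights, v)

# Fused kernel used on the default path.
fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(manual, fused, atol=1e-6))  # True
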
@@ -360,19 +382,23 @@ class UnifiedTransformerBlock(nn.Module):
         self,
         x: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+        output_attentions: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """
         Args:
             x: Input tensor
             attention_mask: Optional attention mask
+            output_attentions: Whether to return attention weights

         Returns:
-            Output tensor after transformer block
+            Output tensor after transformer block, and optionally attention weights
         """
-        r1 = self.attn(x, attention_mask)
-        x = x + r1 / self.scaling_factor
+        attn_output, attn_weights = self.attn(x, attention_mask, output_attentions)
+        x = x + attn_output / self.scaling_factor
         r3 = self.ffn(x) / self.scaling_factor
         x = x + r3
+        if output_attentions:
+            return x, attn_weights
         return x


@@ -382,6 +408,7 @@ class TransformerOutput(ModelOutput):
     """Output type for transformer encoder."""
     last_hidden_state: Optional[torch.Tensor] = None
     hidden_states: Optional[Tuple[torch.Tensor]] = None
+    attentions: Optional[Tuple[torch.Tensor]] = None


 @dataclass
@@ -391,6 +418,7 @@ class ESMplusplusOutput(ModelOutput):
     logits: Optional[torch.Tensor] = None
     last_hidden_state: Optional[torch.Tensor] = None
     hidden_states: Optional[Tuple[torch.Tensor]] = None
+    attentions: Optional[Tuple[torch.Tensor]] = None


 ### Transformer Stack
@@ -426,25 +454,42 @@ class TransformerStack(nn.Module):
         x: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         output_hidden_states: bool = False,
+        output_attentions: bool = False,
     ) -> TransformerOutput:
         """
         Args:
             x: Input tensor
             attention_mask: Optional attention mask
             output_hidden_states: Whether to return all hidden states
+            output_attentions: Whether to return attention weights

         Returns:
-            TransformerOutput containing last hidden state and optionally all hidden states
+            TransformerOutput containing last hidden state and optionally all hidden states and attention weights
         """
         batch_size, seq_len, _ = x.shape
-        hidden_states = ()
+        hidden_states = () if output_hidden_states else None
+        attentions = () if output_attentions else None
+
         if attention_mask is not None:
             attention_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len).bool()
+
         for block in self.blocks:
-            x = block(x, attention_mask)
+            if output_attentions:
+                x, attn_weights = block(x, attention_mask, output_attentions)
+                if attentions is not None:
+                    attentions += (attn_weights,)
+            else:
+                x = block(x, attention_mask, output_attentions)
+
             if output_hidden_states:
+                assert hidden_states is not None
                 hidden_states += (x,)
-        return TransformerOutput(last_hidden_state=self.norm(x), hidden_states=hidden_states)
+
+        return TransformerOutput(
+            last_hidden_state=self.norm(x),
+            hidden_states=hidden_states,
+            attentions=attentions
+        )


 ### Dataset for Embedding
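hidden_states and attentions are now None unless requested, so callers can tell "not requested" apart from "empty". A hedged usage sketch of the new stack output (constructing TransformerStack itself is elided; the shapes follow from the code above):

# Usage sketch only: `stack` is an already-constructed TransformerStack and
# `x` is an embedded batch of shape (batch, seq_len, hidden_size).
out = stack(x, attention_mask=None, output_hidden_states=True, output_attentions=True)

out.last_hidden_state        # (batch, seq_len, hidden_size), after the final norm
len(out.hidden_states)       # one entry per block, captured after each block runs
out.attentions[0].shape      # (batch, n_heads, seq_len, seq_len) softmax weights

# With the default flags, both collections come back as None rather than ().
default_out = stack(x)
assert default_out.hidden_states is None and default_out.attentions is None
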
@@ -604,12 +649,19 @@ class ESMplusplusForMaskedLM(PreTrainedModel):

         return embeddings_dict

+    """
+    TODO
+    - Add dropout (default 0.0)
+    - Class method for returning manually computed attention maps
+    """
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         output_hidden_states: bool = False,
+        output_attentions: bool = False,
     ) -> ESMplusplusOutput:
         """Forward pass for masked language modeling.

@@ -618,12 +670,13 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
             attention_mask: Attention mask
             labels: Optional labels for masked tokens
             output_hidden_states: Whether to return all hidden states
+            output_attentions: Whether to return attention weights

         Returns:
-            ESMplusplusOutput containing loss, logits, and hidden states
+            ESMplusplusOutput containing loss, logits, hidden states and attention weights
         """
         x = self.embed(input_ids)
-        output = self.transformer(x, attention_mask, output_hidden_states)
+        output = self.transformer(x, attention_mask, output_hidden_states, output_attentions)
         x = output.last_hidden_state
         logits = self.sequence_head(x)
         loss = None
@@ -634,6 +687,7 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
             logits=logits,
             last_hidden_state=x,
             hidden_states=output.hidden_states,
+            attentions=output.attentions,
         )


@@ -658,6 +712,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
         attention_mask: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
         output_hidden_states: bool = False,
+        output_attentions: bool = False,
     ) -> ESMplusplusOutput:
         """Forward pass for sequence classification.

@@ -666,6 +721,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
             attention_mask: Attention mask
             labels: Optional labels for classification
             output_hidden_states: Whether to return all hidden states
+            output_attentions: Whether to return attention weights

         Returns:
             ESMplusplusOutput containing loss, logits, and hidden states
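With output_attentions threaded through both task heads, attention maps can be requested directly from a checkpoint that ships this module. A hedged end-to-end sketch; the repo id is a placeholder, loading assumes the checkpoint exposes this file via trust_remote_code, and the attention mask is omitted since a single unpadded sequence does not need one:

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Placeholder repo id: substitute the checkpoint that bundles modeling_esm_plusplus.py.
repo = "some-org/esm-plusplus-checkpoint"
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(repo, trust_remote_code=True).eval()

# A single unpadded sequence needs no attention_mask.
input_ids = tokenizer("MSEQWENCE", return_tensors="pt")["input_ids"]
with torch.no_grad():
    out = model(input_ids, output_hidden_states=True, output_attentions=True)

print(out.logits.shape)          # (1, seq_len, vocab_size)
print(len(out.attentions))       # one tuple entry per transformer block
print(out.attentions[0].shape)   # (1, n_heads, seq_len, seq_len)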