import torch
import torch.nn as nn
from typing import Optional, List
from dataclasses import dataclass
from torch.nn.attention.flex_attention import create_block_mask
from transformers import EsmTokenizer, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import ModelOutput

from model.attention import SelfAttention, MultiHeadPAttention
from model.utils import norm, MLP


class PLMConfig(PretrainedConfig):
    def __init__(
        self,
        hidden_size: int = 512,
        num_attention_heads: int = 8,
        num_hidden_layers: int = 12,
        num_att_tokens: int = 512,
        vocab_size: int = 33,
        expansion_ratio: float = 2.0,
        attention_soft_cap: float = 64.0,
        add_att_soft_cap: bool = True,
        soft_logit_cap: float = 16.0,
        sliding_window_size: int = 2048,
        p_attention: bool = False,
        tie_embeddings: bool = False,
        unet: bool = False,
        mlm: bool = False,
        token_dropout: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.num_att_tokens = num_att_tokens
        self.vocab_size = vocab_size
        self.expansion_ratio = expansion_ratio
        self.soft_logit_cap = soft_logit_cap
        self.attention_soft_cap = attention_soft_cap
        self.add_att_soft_cap = add_att_soft_cap
        self.sliding_window_size = sliding_window_size
        self.p_attention = p_attention
        self.tie_embeddings = tie_embeddings
        self.unet = unet
        self.mlm = mlm
        self.token_dropout = token_dropout


@dataclass
class ESMOutput(ModelOutput):
    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    last_hidden_state: Optional[torch.Tensor] = None


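# Layer-wise "value embeddings" in the style of modded-nanogpt: one embedding
# table per encoder layer, mirrored so that each decoder layer reuses the table
# of its paired encoder layer. The attention layers consume these as `vi`.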
class ValueEmbedding(nn.Module):
    def __init__(self, config: PLMConfig):
        super().__init__()
        self.embed = nn.ModuleList([
            nn.Embedding(config.vocab_size, config.hidden_size)
            for _ in range(config.num_hidden_layers // 2)
        ])

    def forward(self, inputs: torch.Tensor) -> List[torch.Tensor]:
        ve = [emb(inputs) for emb in self.embed]
        ve = ve + ve[::-1]  # mirror the encoder tables onto the decoder half
        return ve


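# Prediction head with soft logit capping: logits are squashed into
# (-soft_logit_cap, soft_logit_cap) via cap * tanh(x / cap), bounding the
# gradients contributed by outlier logits.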
class LMHead(nn.Module):
    def __init__(self, hidden_size: int, vocab_size: int, soft_logit_cap: float = 30.0):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        self.soft_logit_cap = soft_logit_cap
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.dense(norm(x))
        x = self.act(x)
        x = self.decoder(x) + self.bias
        return self.soft_logit_cap * torch.tanh(x / self.soft_logit_cap)


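# Pre-norm residual block. In U-Net mode the input is first re-mixed with the
# initial embedding x0 through the learned scalars `lambdas`, and the attention
# call additionally receives this layer's value embedding `vi`.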
class TransformerBlock(nn.Module):
    def __init__(self, config: PLMConfig):
        super().__init__()
        self.config = config
        if config.p_attention:
            self.attn = MultiHeadPAttention(config)
        else:
            self.attn = SelfAttention(config)
        self.mlp = MLP(config)
        self.unet = config.unet
        if config.unet:
            self.lambdas = nn.Parameter(torch.tensor([1., 0.]))

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        vi: Optional[torch.Tensor] = None,
        x0: Optional[torch.Tensor] = None,
        last_eos: Optional[int] = None,
        **kwargs,
    ) -> torch.Tensor:
        if self.unet:
            x = self.lambdas[0] * x + self.lambdas[1] * x0
            x = x + self.attn(
                x=norm(x),
                attention_mask=attention_mask,
                vi=vi,
                last_eos=last_eos,
                **kwargs,
            )
        else:
            x = x + self.attn(
                x=norm(x),
                attention_mask=attention_mask,
                last_eos=last_eos,
                **kwargs,
            )
        x = x + self.mlp(norm(x))
        return x


class Transformer(nn.Module):
    def __init__(self, config: PLMConfig):
        super().__init__()
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        for layer in self.layers:
            x = layer(
                x=x,
                attention_mask=attention_mask,
                **kwargs,
            )
        return x


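# Encoder/decoder stack with U-Net style skip connections: the activation from
# encoder layer i is added, scaled by a learned weight, to the input of the
# matching decoder layer (last-in-first-out pairing).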
class UnetTransformer(nn.Module):
    def __init__(self, config: PLMConfig):
        super().__init__()
        assert config.num_hidden_layers % 2 == 0
        self.num_encoder_layers = config.num_hidden_layers // 2
        self.num_decoder_layers = config.num_hidden_layers // 2

        self.skip_weights = nn.Parameter(torch.ones(self.num_decoder_layers))

        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        x: torch.Tensor,
        ve: List[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        x0 = x
        ve_enc, ve_dec = ve[:self.num_encoder_layers], ve[self.num_encoder_layers:]
        skip_connections = []
        for i in range(self.num_encoder_layers):
            x = self.layers[i](
                x=x,
                attention_mask=attention_mask,
                vi=ve_enc[i],
                x0=x0,
                **kwargs,
            )
            skip_connections.append(x)

        for i in range(self.num_decoder_layers):
            x = x + self.skip_weights[i] * skip_connections.pop()
            x = self.layers[self.num_encoder_layers + i](
                x=x,
                attention_mask=attention_mask,
                vi=ve_dec[i],
                x0=x0,
                **kwargs,
            )
        return x


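# HuggingFace-compatible protein language model: ESM-2 tokenizer special ids,
# token embedding, a plain or U-Net transformer stack, and the soft-capped LM head.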
class PLM(PreTrainedModel):
    config_class = PLMConfig

    def __init__(self, config: PLMConfig):
        super().__init__(config)
        self.config = config
        self.tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
        self.cls_token_id = self.tokenizer.cls_token_id
        self.eos_token_id = self.tokenizer.eos_token_id
        self.pad_token_id = self.tokenizer.pad_token_id
        self.mask_token_id = self.tokenizer.mask_token_id
        self.token_dropout = config.token_dropout

        self.vocab_size = config.vocab_size
        self.n_heads = config.num_attention_heads
        self.sliding_window_size = config.sliding_window_size

        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)

        self.unet = config.unet
        if config.unet:
            self.transformer = UnetTransformer(config)
            self.value_embeds = ValueEmbedding(config)
        else:
            self.transformer = Transformer(config)

        self.lm_head = LMHead(config.hidden_size, config.vocab_size, config.soft_logit_cap)
        if config.tie_embeddings:
            self.lm_head.decoder.weight = self.embedding.weight

        self.mlm = config.mlm
        self.ce = nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')

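    # Expects a packed 1D token stream in which each document starts with CLS
    # and ends with EOS. A FlexAttention block mask limits attention to token
    # pairs that fall within the bidirectional sliding window, belong to the
    # same document, and precede the final EOS (i.e., are not padding).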
    def get_last_hidden_state(self, input_ids: torch.Tensor, sliding_window_size: int) -> torch.Tensor:
        docs = (input_ids == self.cls_token_id).cumsum(0)
        eos_positions = (input_ids == self.eos_token_id).nonzero()
        if eos_positions.numel() > 0:
            last_eos = eos_positions[-1].squeeze()
        else:
            # no EOS present: treat the whole stream as real tokens
            last_eos = len(input_ids) - 1
        seq_len = len(input_ids)

        def doc_mask_mod(b, h, q_idx, kv_idx):
            bidirectional_sliding_window_mask = torch.abs(q_idx - kv_idx) < sliding_window_size
            doc_mask = docs[q_idx] == docs[kv_idx]
            pad_mask = (q_idx <= last_eos) & (kv_idx <= last_eos)
            return bidirectional_sliding_window_mask & doc_mask & pad_mask

        attention_mask = create_block_mask(
            mask_mod=doc_mask_mod,
            B=1,
            H=self.n_heads,
            Q_LEN=seq_len,
            KV_LEN=seq_len,
            device=input_ids.device,
        )

        x = self.embedding(input_ids)

        if self.token_dropout:
            # token dropout (ESM-2 style): zero out <mask> embeddings, then
            # rescale by the mask ratio observed over the real tokens
            x = x.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
            real_token_count = len(input_ids[:last_eos])
            mask_ratio_observed = (input_ids == self.mask_token_id).sum().float() / real_token_count
            x = (x * (1 - mask_ratio_observed)).to(x.dtype)

        x = norm(x)
        if self.unet:
            ve = self.value_embeds(input_ids)
            x = self.transformer(
                x=x,
                ve=ve,
                attention_mask=attention_mask,
                last_eos=last_eos,
            )
        else:
            x = self.transformer(
                x=x,
                attention_mask=attention_mask,
                last_eos=last_eos,
            )
        return x

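    # Mean-pools the final hidden states of each document (documents are
    # delimited by CLS tokens) into one embedding vector per sequence.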
    def get_vector_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        docs = (input_ids == self.cls_token_id).cumsum(0)
        x = self.get_last_hidden_state(input_ids, self.sliding_window_size)
        x = x.view(-1, self.config.hidden_size)

        # the CLS cumsum is 1-indexed; shift to 0-indexed document ids
        num_docs = docs.max().item()
        doc_ids = docs - 1

        doc_embeds = []
        for doc_idx in range(num_docs):
            mask = (doc_ids == doc_idx)
            doc_embeds.append(x[mask].mean(dim=0))

        return torch.stack(doc_embeds, dim=0)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        mask_rate: Optional[torch.Tensor] = None,  # accepted for trainer compatibility; unused here
        sliding_window_size: Optional[int] = None,
    ) -> ESMOutput:
        if sliding_window_size is None:
            sliding_window_size = self.sliding_window_size

        last_hidden_state = self.get_last_hidden_state(input_ids, sliding_window_size)
        lm_logits = self.lm_head(last_hidden_state)  # LMHead normalizes its input internally

        loss = None
        if labels is not None:
            loss = self.ce(
                lm_logits.view(-1, self.vocab_size),
                labels.view(-1).long(),
            )

            if torch.isnan(loss):
                torch.set_printoptions(profile="full")
                print("⚠️ NaN loss detected!")
                print("Input IDs:", input_ids.detach().cpu())
                print("Labels:", labels.detach().cpu())
                print("Logits:", lm_logits.detach().cpu())

                labels_cpu = labels.detach().cpu()
                if torch.all(labels_cpu == -100):
                    print("⚠️ All labels are -100!")
                else:
                    print("Unique labels present:", torch.unique(labels_cpu))

        return ESMOutput(
            loss=loss,
            logits=lm_logits,
            last_hidden_state=last_hidden_state,
        )


if __name__ == "__main__":
    from torchinfo import summary

    config = PLMConfig(
        hidden_size=768,
        num_attention_heads=6,
        num_hidden_layers=24,
        expansion_ratio=8/3,
        unet=True,
    )
    model = PLM(config).cuda()
    summary(model)

    # smoke test on a packed 1D token stream (CLS ... EOS), not a (batch, seq) tensor
    input_ids = torch.randint(4, 24, (100,)).cuda()  # random residue tokens
    input_ids[0] = model.cls_token_id
    input_ids[-1] = model.eos_token_id
    labels = input_ids.clone()

    output = model(input_ids, labels=labels)
    print(f"loss: {output.loss}")
    print(f"logits: {output.logits.shape}")
    print(f"last_hidden_state: {output.last_hidden_state.shape}")