import math

import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import MaskedLMOutput

from .configuration_nicheformer import NicheformerConfig
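
# Special-token ids used by `complete_masking` below. The concrete values are an
# assumption: this file only implies that the padding id is 1 (padding_idx=1 in the
# embedding table) and that ids below 3 are reserved (random replacement tokens are
# drawn from [3, n_tokens)). Adjust these to match the actual tokenizer vocabulary.
MASK_TOKEN = 0
PAD_TOKEN = 1
CLS_TOKEN = 2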


class PositionalEncoding(nn.Module):
    """Positional encoding using sine and cosine functions."""

    def __init__(self, d_model: int, max_seq_len: int):
        super().__init__()
        encoding = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        encoding[:, 0::2] = torch.sin(position * div_term)
        encoding[:, 1::2] = torch.cos(position * div_term)
        # Shape (1, max_seq_len, d_model) so it broadcasts over the batch dimension.
        encoding = encoding.unsqueeze(0)

        self.register_buffer('encoding', encoding, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encoding to an input of shape (batch, seq_len, d_model)."""
        return x + self.encoding[:, :x.size(1)]


class NicheformerPreTrainedModel(PreTrainedModel):
    """Base class for Nicheformer models."""

    config_class = NicheformerConfig
    base_model_prefix = "nicheformer"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)


class NicheformerModel(NicheformerPreTrainedModel):
    def __init__(self, config: NicheformerConfig):
        super().__init__(config)

        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.dim_model,
            nhead=config.nheads,
            dim_feedforward=config.dim_feedforward,
            batch_first=config.batch_first,
            dropout=config.dropout,
            layer_norm_eps=1e-12
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer=self.encoder_layer,
            num_layers=config.nlayers,
            enable_nested_tensor=False
        )

        # Token embedding table; `+5` reserves extra ids beyond the vocabulary (e.g. for special tokens).
        self.embeddings = nn.Embedding(
            num_embeddings=config.n_tokens + 5,
            embedding_dim=config.dim_model,
            padding_idx=1
        )

        if config.learnable_pe:
            # Learnable absolute positional embeddings.
            self.positional_embedding = nn.Embedding(
                num_embeddings=config.context_length,
                embedding_dim=config.dim_model
            )
            self.dropout = nn.Dropout(p=config.dropout)
            self.register_buffer('pos', torch.arange(0, config.context_length, dtype=torch.long))
        else:
            # Fixed sinusoidal positional encoding.
            self.positional_embedding = PositionalEncoding(
                d_model=config.dim_model,
                max_seq_len=config.context_length
            )

        self.post_init()

    def forward(self, input_ids, attention_mask=None):
        token_embedding = self.embeddings(input_ids)

        if self.config.learnable_pe:
            # Slice the position ids to the actual sequence length so inputs shorter
            # than `context_length` are supported.
            pos = self.pos[:token_embedding.size(1)].to(token_embedding.device)
            pos_embedding = self.positional_embedding(pos)
            embeddings = self.dropout(token_embedding + pos_embedding)
        else:
            embeddings = self.positional_embedding(token_embedding)

        # HF-style attention masks mark valid tokens with 1, while PyTorch's
        # `src_key_padding_mask` expects True at padding positions, so invert.
        if attention_mask is not None:
            attention_mask = ~attention_mask.bool()

        transformer_output = self.encoder(
            embeddings,
            src_key_padding_mask=attention_mask,
            is_causal=False
        )

        return transformer_output

    def get_embeddings(self, input_ids, attention_mask=None, layer: int = -1, with_context: bool = False) -> torch.Tensor:
        """Get embeddings from the model.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            layer: Which transformer layer to extract embeddings from (-1 means last layer)
            with_context: Whether to include context tokens in the embeddings

        Returns:
            torch.Tensor: Embeddings tensor
        """
        token_embedding = self.embeddings(input_ids)

        if self.config.learnable_pe:
            pos = self.pos[:token_embedding.size(1)].to(token_embedding.device)
            pos_embedding = self.positional_embedding(pos)
            embeddings = self.dropout(token_embedding + pos_embedding)
        else:
            embeddings = self.positional_embedding(token_embedding)

        # Convert a negative layer index (e.g. -1 for the last layer) into an absolute index.
        if layer < 0:
            layer = self.config.nlayers + layer

        if attention_mask is not None:
            padding_mask = ~attention_mask.bool()
        else:
            padding_mask = None

        # Run the encoder layer by layer so we can stop at the requested depth.
        for i in range(layer + 1):
            embeddings = self.encoder.layers[i](
                embeddings,
                src_key_padding_mask=padding_mask,
                is_causal=False
            )

        # Drop the leading context tokens unless they are explicitly requested.
        if not with_context:
            embeddings = embeddings[:, 3:, :]

        # Mean-pool over the sequence dimension (padding positions are included in the mean).
        embeddings = embeddings.mean(dim=1)

        return embeddings


class NicheformerForMaskedLM(NicheformerPreTrainedModel):
    def __init__(self, config: NicheformerConfig):
        super().__init__(config)

        self.nicheformer = NicheformerModel(config)
        # Decoder head: created without a bias, then given an explicit zero-initialized bias parameter.
        self.classifier_head = nn.Linear(config.dim_model, config.n_tokens, bias=False)
        self.classifier_head.bias = nn.Parameter(torch.zeros(config.n_tokens))

        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        return_dict=None,
        apply_masking=False,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if apply_masking:
            batch = {
                'input_ids': input_ids,
                'attention_mask': attention_mask
            }
            masked_batch = complete_masking(batch, self.config.masking_p, self.config.n_tokens)
            input_ids = masked_batch['masked_indices']
            labels = masked_batch['input_ids']
            mask = masked_batch['mask']

            # Only score masked positions; -100 is ignored by CrossEntropyLoss.
            labels = torch.where(mask, labels, torch.tensor(-100, device=labels.device)).long()

        transformer_output = self.nicheformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        prediction_scores = self.classifier_head(transformer_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.n_tokens),
                labels.view(-1)
            )

        if not return_dict:
            output = (prediction_scores, transformer_output)
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=transformer_output,
        )

    def get_embeddings(self, input_ids, attention_mask=None, layer: int = -1, with_context: bool = False) -> torch.Tensor:
        """Get embeddings from the model.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            layer: Which transformer layer to extract embeddings from (-1 means last layer)
            with_context: Whether to include context tokens in the embeddings

        Returns:
            torch.Tensor: Embeddings tensor
        """
        return self.nicheformer.get_embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            layer=layer,
            with_context=with_context
        )


def complete_masking(batch, masking_p, n_tokens):
    """Apply masking to input batch for masked language modeling.

    Args:
        batch (dict): Input batch containing 'input_ids' and 'attention_mask'
        masking_p (float): Probability of masking a token
        n_tokens (int): Total number of tokens in vocabulary

    Returns:
        dict: Batch with masked indices and masking information
    """
    device = batch['input_ids'].device
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']

    # Select positions with probability `masking_p`, never masking padding or CLS tokens.
    prob = torch.rand(input_ids.shape, device=device)
    mask = (prob < masking_p) & (input_ids != PAD_TOKEN) & (input_ids != CLS_TOKEN)

    masked_indices = input_ids.clone()

    num_tokens_to_mask = mask.sum().item()

    # BERT-style corruption of the selected positions:
    # 80% become MASK_TOKEN, 10% become a random token, 10% are left unchanged.
    mask_mask = torch.rand(num_tokens_to_mask, device=device) < 0.8
    random_mask = (torch.rand(num_tokens_to_mask, device=device) < 0.5) & ~mask_mask

    masked_values = masked_indices[mask]
    masked_values = torch.where(
        mask_mask,
        torch.tensor(MASK_TOKEN, device=device, dtype=torch.long),
        masked_values
    )

    # Draw replacement tokens from the regular vocabulary (ids >= 3).
    random_tokens = torch.randint(
        3, n_tokens,
        (int(random_mask.sum()),),
        device=device,
        dtype=torch.long
    )
    masked_values[random_mask] = random_tokens

    # Write back in a single indexed assignment; chained indexing such as
    # masked_indices[mask][random_mask] = ... would only modify a temporary copy.
    masked_indices[mask] = masked_values

    return {
        'masked_indices': masked_indices,
        'attention_mask': attention_mask,
        'mask': mask,
        'input_ids': input_ids
    }
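

if __name__ == "__main__":
    # Minimal smoke-test sketch, not a definitive usage recipe. It assumes that
    # `NicheformerConfig` accepts the attributes read in this file (dim_model, nheads,
    # dim_feedforward, batch_first, dropout, nlayers, n_tokens, learnable_pe,
    # context_length, masking_p) as keyword arguments; the values below are illustrative only.
    config = NicheformerConfig(
        dim_model=64,
        nheads=4,
        dim_feedforward=128,
        batch_first=True,
        dropout=0.1,
        nlayers=2,
        n_tokens=100,
        learnable_pe=True,
        context_length=16,
        masking_p=0.15,
    )
    model = NicheformerForMaskedLM(config)

    # Random token ids from the regular vocabulary (ids >= 3), full-length sequences.
    input_ids = torch.randint(3, config.n_tokens, (2, config.context_length))
    attention_mask = torch.ones_like(input_ids)

    # Masked-language-modeling pass with on-the-fly masking.
    output = model(input_ids=input_ids, attention_mask=attention_mask, apply_masking=True)
    print("loss:", output.loss.item(), "logits:", tuple(output.logits.shape))

    # Mean-pooled sequence embeddings from the last encoder layer.
    embeddings = model.get_embeddings(input_ids, attention_mask=attention_mask)
    print("embeddings:", tuple(embeddings.shape))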