| | """AirRep model implementation.""" |
| |
|
| | from typing import Optional |
| | import torch |
| | import torch.nn as nn |
| | from transformers import BertModel, BertConfig, PreTrainedModel |
| | from transformers.modeling_outputs import BaseModelOutput |
| |
|
| |
|
def mean_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Mean-pool hidden states over the sequence dimension, ignoring padded positions.

    Padded positions are zeroed out before summing, and the sum is divided
    by the number of unmasked tokens in each sequence.
    """
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
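
# Shape sketch for mean_pooling (illustrative only; the sizes are arbitrary):
#
#     hidden = torch.randn(2, 4, 8)                      # (batch, seq, hidden)
#     mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # (batch, seq)
#     mean_pooling(hidden, mask)                         # -> shape (2, 8)
#
# Each output row is the average of only the unmasked token vectors.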


class AirRepConfig(BertConfig):
    """Configuration class for the AirRep model."""

    model_type = "airrep"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
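
# Usage sketch (the hyperparameter values below are arbitrary, not project
# defaults): AirRepConfig forwards all keyword arguments to BertConfig.
#
#     config = AirRepConfig(hidden_size=256, num_hidden_layers=4)
#     model = AirRepModel(config)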


class AirRepModel(PreTrainedModel):
    """
    AirRep model with a BERT encoder and a projection layer.

    This is a standalone model, not a wrapper.
    """

    config_class = AirRepConfig
    base_model_prefix = "airrep"

    def __init__(self, config: AirRepConfig):
        super().__init__(config)
        self.config = config

        # BERT encoder without the default CLS pooling head; pooling is done
        # manually with mean_pooling() in forward().
        self.bert = BertModel(config, add_pooling_layer=False)

        # Projection applied to the pooled embedding. Note: the layer is
        # created in bfloat16, which assumes the encoder is also run in
        # bfloat16 (e.g. via model.to(torch.bfloat16)); otherwise the
        # projection matmul will see mismatched dtypes.
        self.projector = nn.Linear(
            config.hidden_size,
            config.hidden_size,
            dtype=torch.bfloat16,
        )

        # Initialize weights and apply final processing (PreTrainedModel hook).
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Forward pass.

        Args:
            input_ids: Input token IDs, shape (batch_size, seq_len).
            attention_mask: Attention mask, shape (batch_size, seq_len).
            token_type_ids: Token type IDs, shape (batch_size, seq_len).

        Returns:
            Pooled and projected embeddings, shape (batch_size, hidden_size).
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )

        # Mean-pool the token embeddings; default to an all-ones mask when
        # none is provided (i.e. no padding).
        last_hidden_state = outputs.last_hidden_state
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        pooled = mean_pooling(last_hidden_state, attention_mask)

        # Project the pooled embedding to the output space.
        projected = self.projector(pooled)

        return projected
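
    # Typical invocation (sketch; `tokenizer` is an assumed companion
    # tokenizer, not defined in this module):
    #
    #     batch = tokenizer(texts, padding=True, return_tensors="pt")
    #     embeddings = model(**batch)    # (batch_size, hidden_size)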

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save model and config."""
        super().save_pretrained(save_directory, **kwargs)
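

if __name__ == "__main__":
    # Minimal smoke test; a sketch under assumptions (a deliberately tiny
    # config and random token IDs; all size values below are arbitrary).
    config = AirRepConfig(
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=128,
    )
    # Cast the whole model to bfloat16 so the encoder output matches the
    # bfloat16 projector (see the note in __init__).
    model = AirRepModel(config).to(torch.bfloat16).eval()

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    attention_mask[:, 12:] = 0  # pretend the last 4 positions are padding

    with torch.no_grad():
        embeddings = model(input_ids=input_ids, attention_mask=attention_mask)
    print(embeddings.shape)  # expected: torch.Size([2, 64])

    # Round-trip save/load via the standard PreTrainedModel API.
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        model.save_pretrained(tmp)
        reloaded = AirRepModel.from_pretrained(tmp)
        print(type(reloaded).__name__)  # expected: AirRepModel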