File size: 2,600 Bytes
411a334
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""AirRep model implementation."""

from typing import Optional
import torch
import torch.nn as nn
from transformers import BertModel, BertConfig, PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput


def mean_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Mean-pool token embeddings over the sequence dimension, ignoring padding.

    Args:
        last_hidden_states: Token embeddings of shape (batch, seq_len, hidden).
        attention_mask: Mask of shape (batch, seq_len); nonzero marks real tokens.

    Returns:
        Tensor of shape (batch, hidden): per-row mean over unmasked positions.
        A row whose mask is all zeros yields a zero vector instead of NaN.
    """
    # Zero out padded positions so they do not contribute to the sum.
    masked = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    # Clamp the token count to >= 1 to avoid division by zero on all-padding
    # rows; any row with at least one real token is unaffected.
    counts = attention_mask.sum(dim=1)[..., None].clamp(min=1)
    return masked.sum(dim=1) / counts


class AirRepConfig(BertConfig):
    """Configuration class for the AirRep model.

    A thin subclass of ``BertConfig``: every hyperparameter is inherited
    unchanged, and only ``model_type`` is overridden to tag checkpoints as
    "airrep". The explicit pass-through ``__init__`` of the original was
    redundant and is omitted — plain inheritance is behaviorally identical.
    """

    model_type = "airrep"


class AirRepModel(PreTrainedModel):
    """
    AirRep model: a BERT encoder followed by mean pooling and a linear
    projection, producing one fixed-size embedding per input sequence.

    This is a standalone model, not a wrapper.
    """

    config_class = AirRepConfig
    base_model_prefix = "airrep"

    def __init__(self, config: AirRepConfig):
        """Build the encoder and projection head from *config*."""
        super().__init__(config)  # PreTrainedModel already stores self.config

        # BERT encoder; the pooler head is skipped because pooling is done
        # explicitly via mean_pooling() in forward().
        self.bert = BertModel(config, add_pooling_layer=False)

        # Projection layer. NOTE(review): this is hard-coded to bfloat16 while
        # the encoder defaults to float32; forward() casts its input to the
        # projector's parameter dtype so the matmul cannot dtype-mismatch
        # regardless of how the model is cast later.
        self.projector = nn.Linear(
            config.hidden_size,
            config.hidden_size,
            dtype=torch.bfloat16
        )

        # Initialize weights (standard PreTrainedModel hook).
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs
    ) -> torch.Tensor:
        """
        Forward pass.

        Args:
            input_ids: Input token IDs, shape (batch, seq_len).
            attention_mask: Optional attention mask; when None, all positions
                are treated as real tokens.
            token_type_ids: Optional token type IDs.

        Returns:
            Pooled and projected embeddings of shape (batch_size, hidden_size),
            in the projector's parameter dtype.
        """
        # Encode. output_hidden_states is deliberately not requested: only
        # last_hidden_state is consumed below, so materializing every
        # intermediate layer was wasted memory in the original.
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )

        # Mean pooling over real (unmasked) tokens.
        last_hidden_state = outputs.last_hidden_state
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        pooled = mean_pooling(last_hidden_state, attention_mask)

        # Cast to the projector's dtype (bfloat16 by construction) so a
        # float32 encoder output cannot crash the matmul.
        projected = self.projector(pooled.to(self.projector.weight.dtype))

        return projected

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save model weights and config (delegates to PreTrainedModel)."""
        super().save_pretrained(save_directory, **kwargs)