import torch
import torch.nn as nn


class WalletSetEncoder(nn.Module):
    """
    Encodes a variable-length set of embeddings into a single fixed-size vector
    using a Transformer encoder and a learnable [CLS] token.

    This is used to pool:
    1. A wallet's `wallet_holdings` (a set of [holding_embeds]).
    2. A wallet's Neo4j links (a set of [link_embeds]).
    3. A wallet's `deployed_tokens` (a set of [token_name_embeds]).
    """
    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        dtype: torch.dtype = torch.float16
    ):
        """
        Initializes the Set Encoder.

        Args:
            d_model (int): The input/output dimension of the embeddings.
            nhead (int): Number of attention heads.
            num_layers (int): Number of transformer encoder layers.
            dim_feedforward (int): Hidden dimension of the feedforward network.
            dropout (float): Dropout rate.
            dtype (torch.dtype): Data type for all module parameters.
        """
        super().__init__()
        self.d_model = d_model
        self.dtype = dtype

        # Learnable [CLS] token, prepended to every set; its final hidden
        # state becomes the pooled set representation.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls_token, std=0.02)

        # Standard Transformer encoder stack operating on [batch, seq, dim].
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )
        self.output_norm = nn.LayerNorm(d_model)

        # Cast all parameters (including LayerNorm) to the requested dtype.
        self.to(dtype)

    def forward(
        self,
        item_embeds: torch.Tensor,
        src_key_padding_mask: torch.Tensor
    ) -> torch.Tensor:
| """ |
| Forward pass. |
| |
| Args: |
| item_embeds (torch.Tensor): |
| The batch of item embeddings. |
| Shape: [batch_size, seq_len, d_model] |
| src_key_padding_mask (torch.Tensor): |
| The boolean padding mask for the items, where True indicates |
| a padded position that should be ignored. |
| Shape: [batch_size, seq_len] |
| |
| Returns: |
| torch.Tensor: The pooled set embedding. |
| Shape: [batch_size, d_model] |
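
        Example (illustrative shapes; float32 keeps the sketch CPU-friendly):
            >>> enc = WalletSetEncoder(d_model=64, nhead=4, num_layers=2,
            ...                        dtype=torch.float32)
            >>> embeds = torch.randn(2, 5, 64)
            >>> mask = torch.tensor([[False] * 5,
            ...                      [False, False, True, True, True]])
            >>> enc(embeds, mask).shape
            torch.Size([2, 64])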
| """ |
        batch_size = item_embeds.shape[0]

        # Match the module dtype (torch.cat requires matching dtypes), then
        # prepend the [CLS] token: [B, S, D] -> [B, S + 1, D].
        item_embeds = item_embeds.to(self.dtype)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1).to(self.dtype)
        x = torch.cat([cls_tokens, item_embeds], dim=1)

        # Extend the mask with a False entry for the [CLS] position so it is
        # never masked out: [B, S] -> [B, S + 1].
        cls_mask = torch.zeros(
            batch_size, 1, device=src_key_padding_mask.device, dtype=torch.bool
        )
        full_padding_mask = torch.cat([cls_mask, src_key_padding_mask], dim=1)

        # Encode the set; padded positions are excluded from attention.
        transformer_output = self.transformer_encoder(
            x,
            src_key_padding_mask=full_padding_mask
        )

        # The [CLS] hidden state summarizes the whole set (order-invariant,
        # since no positional encoding is added).
        cls_output = transformer_output[:, 0, :]

        return self.output_norm(cls_output)
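

# Minimal usage sketch (an illustration, not part of the module above).
# It pads two variable-length holding sets into one batch, builds the matching
# boolean mask, and pools each set. The `pad_sets` helper, the sizes, and the
# float32 dtype (chosen so the sketch also runs on CPU) are assumptions for
# demonstration only.
if __name__ == "__main__":
    def pad_sets(sets):
        """Hypothetical helper: right-pad [n_i, D] tensors to a common length
        and return the batch plus its padding mask (True = padded)."""
        max_len = max(s.shape[0] for s in sets)
        batch = torch.zeros(len(sets), max_len, sets[0].shape[1])
        mask = torch.ones(len(sets), max_len, dtype=torch.bool)
        for i, s in enumerate(sets):
            batch[i, : s.shape[0]] = s
            mask[i, : s.shape[0]] = False
        return batch, mask

    d_model = 64
    encoder = WalletSetEncoder(
        d_model=d_model, nhead=4, num_layers=2, dtype=torch.float32
    )

    # Two wallets with 3 and 5 holdings respectively.
    holdings = [torch.randn(3, d_model), torch.randn(5, d_model)]
    item_embeds, padding_mask = pad_sets(holdings)

    pooled = encoder(item_embeds, padding_mask)
    print(pooled.shape)  # torch.Size([2, 64])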