import torch
import torch.nn as nn


class WalletSetEncoder(nn.Module):
    """
    Encodes a variable-length set of embeddings into a single fixed-size
    vector using a Transformer encoder and a [CLS] token.

    This is used to pool:
    1. A wallet's `wallet_holdings` (a set of [holding_embeds]).
    2. A wallet's `Neo4J links` (a set of [link_embeds]).
    3. A wallet's `deployed_tokens` (a set of [token_name_embeds]).
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        dtype: torch.dtype = torch.float16,
    ):
        """
        Initializes the Set Encoder.

        Args:
            d_model (int): The input/output dimension of the embeddings.
                NOTE: must be divisible by `nhead` (enforced by
                nn.TransformerEncoderLayer).
            nhead (int): Number of attention heads.
            num_layers (int): Number of transformer layers.
            dim_feedforward (int): Hidden dimension of the feedforward network.
            dropout (float): Dropout rate.
            dtype (torch.dtype): Data type the module's parameters are cast to.
        """
        super().__init__()
        self.d_model = d_model
        self.dtype = dtype

        # The learnable [CLS] token, which will aggregate the set
        # representation. Initialized with a small std, following the common
        # BERT/ViT convention.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls_token, std=0.02)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,  # inputs/outputs are [batch, seq, d_model]
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
        )
        self.output_norm = nn.LayerNorm(d_model)

        # Cast all parameters (including cls_token) to the requested dtype.
        self.to(dtype)

    def forward(
        self,
        item_embeds: torch.Tensor,
        src_key_padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass.

        Args:
            item_embeds (torch.Tensor): The batch of item embeddings.
                Shape: [batch_size, seq_len, d_model]. Cast to the module's
                dtype internally, so callers may pass any float dtype.
            src_key_padding_mask (torch.Tensor): The boolean padding mask for
                the items, where True indicates a padded position that should
                be ignored. Shape: [batch_size, seq_len]

        Returns:
            torch.Tensor: The pooled set embedding.
                Shape: [batch_size, d_model]
        """
        batch_size = item_embeds.shape[0]

        # 1. Create [CLS] token batch and concatenate with item embeddings.
        #    FIX: item_embeds is also cast to self.dtype — torch.cat raises on
        #    mixed dtypes, so the original crashed whenever the caller's
        #    embeddings were not already in the module's dtype.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1).to(self.dtype)
        x = torch.cat([cls_tokens, item_embeds.to(self.dtype)], dim=1)

        # 2. Create the mask for the [CLS] token (it is never masked). This
        #    also guarantees every row has at least one unmasked position,
        #    avoiding NaNs from fully-masked attention rows.
        cls_mask = torch.zeros(
            batch_size, 1, device=src_key_padding_mask.device, dtype=torch.bool
        )

        # 3. Concatenate the [CLS] mask with the item mask.
        full_padding_mask = torch.cat([cls_mask, src_key_padding_mask], dim=1)

        # 4. Pass through Transformer.
        transformer_output = self.transformer_encoder(
            x,
            src_key_padding_mask=full_padding_mask,
        )

        # 5. Extract the output of the [CLS] token (the first token in the
        #    sequence) and normalize it.
        cls_output = transformer_output[:, 0, :]
        return self.output_norm(cls_output)