import torch
import torch.nn as nn
from typing import List, Dict, Any


class HolderDistributionEncoder(nn.Module):
    """Encode a list of top holders into one fixed-size distribution embedding.

    Each holder is described by a pre-computed wallet embedding plus its
    holding percentage. A Transformer encoder with a learnable [CLS] token
    aggregates the sequence of holder features into a single vector that
    represents the overall holder distribution.
    """

    def __init__(self, wallet_embedding_dim: int, output_dim: int,
                 nhead: int = 4, num_layers: int = 2,
                 dtype: torch.dtype = torch.float16):
        """
        Args:
            wallet_embedding_dim: Dimension of each wallet embedding; must be
                divisible by ``nhead`` (Transformer head constraint).
            output_dim: Dimension of the returned distribution embedding.
            nhead: Number of attention heads per encoder layer.
            num_layers: Number of stacked encoder layers.
            dtype: Parameter dtype. Defaults to fp16 (presumably for GPU
                inference); pass ``torch.float32`` for CPU use.
        """
        super().__init__()
        self.wallet_embedding_dim = wallet_embedding_dim
        self.output_dim = output_dim
        self.dtype = dtype

        # 1. MLP projecting the scalar holding percentage up to the wallet
        #    embedding dimension so it can be summed with the embedding.
        self.pct_proj = nn.Sequential(
            nn.Linear(1, wallet_embedding_dim // 4),
            nn.GELU(),
            nn.Linear(wallet_embedding_dim // 4, wallet_embedding_dim),
        ).to(dtype)

        # 2. Transformer encoder over the sequence of holder features.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=wallet_embedding_dim,
            nhead=nhead,
            dim_feedforward=wallet_embedding_dim * 4,
            dropout=0.1,
            activation='gelu',
            batch_first=True,
            dtype=dtype,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers
        )

        # 3. Learnable [CLS] token prepended to the sequence; its output
        #    position serves as the aggregate representation.
        self.cls_token = nn.Parameter(
            torch.randn(1, 1, wallet_embedding_dim, dtype=dtype)
        )

        # 4. Final projection to the requested output dimension.
        self.final_proj = nn.Linear(wallet_embedding_dim, output_dim).to(dtype)

    def forward(self, holder_data: List[Dict[str, Any]]) -> torch.Tensor:
        """Encode the holder list into a single distribution embedding.

        Args:
            holder_data: A list of dictionaries, where each dict contains:
                'wallet_embedding': A tensor of shape [wallet_embedding_dim]
                    (any float dtype/device; coerced to the module's).
                'pct': The holding percentage as a float.

        Returns:
            A tensor of shape [1, output_dim] representing the entire
            distribution. All-zeros when ``holder_data`` is empty.
        """
        if not holder_data:
            # No holders: return a zero embedding on the module's device.
            return torch.zeros(1, self.output_dim,
                               device=self.cls_token.device, dtype=self.dtype)

        device = self.cls_token.device
        # FIX: coerce caller-supplied embeddings to the module's device and
        # dtype. The original stacked them as-is; with the default fp16
        # parameters, float32/float64 embeddings type-promote through the
        # add and then crash in the transformer's first Linear (dtype
        # mismatch), and CPU embeddings vs. a CUDA module raise a device
        # mismatch.
        wallet_embeds = torch.stack(
            [d['wallet_embedding'] for d in holder_data]
        ).to(device=device, dtype=self.dtype)
        holder_pcts = torch.tensor(
            [[d['pct']] for d in holder_data],
            device=device, dtype=self.dtype,
        )

        # Project percentages and fuse with wallet embeddings additively.
        pct_embeds = self.pct_proj(holder_pcts)
        holder_inputs = (wallet_embeds + pct_embeds).unsqueeze(0)  # [1, H, D]

        # Prepend the [CLS] token.
        batch_size = holder_inputs.size(0)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        transformer_input = torch.cat((cls_tokens, holder_inputs), dim=1)

        # Encode and read out the [CLS] position (first token).
        transformer_output = self.transformer_encoder(transformer_input)
        cls_embedding = transformer_output[:, 0, :]

        return self.final_proj(cls_embedding)