import torch
import torch.nn as nn
from typing import List, Dict, Any
class HolderDistributionEncoder(nn.Module):
    """
    Encodes a list of top holders (wallet embeddings + holding percentages)
    into a single fixed-size embedding representing the holder distribution.

    Each holder is represented as its wallet embedding plus a learned
    projection of its holding percentage; the resulting sequence is processed
    by a Transformer Encoder and summarized via a learnable [CLS] token.
    """

    def __init__(self,
                 wallet_embedding_dim: int,
                 output_dim: int,
                 nhead: int = 4,
                 num_layers: int = 2,
                 dtype: torch.dtype = torch.float16):
        """
        Args:
            wallet_embedding_dim: Dimensionality of each wallet embedding
                (also the transformer's d_model; must be divisible by nhead).
            output_dim: Dimensionality of the final distribution embedding.
            nhead: Number of attention heads in each encoder layer.
            num_layers: Number of stacked transformer encoder layers.
            dtype: Parameter/computation dtype for the whole module.
        """
        super().__init__()
        self.wallet_embedding_dim = wallet_embedding_dim
        self.output_dim = output_dim
        self.dtype = dtype

        # 1. MLP to project the scalar holding percentage up to the wallet
        #    embedding dimension so it can be added to the wallet embedding.
        self.pct_proj = nn.Sequential(
            nn.Linear(1, wallet_embedding_dim // 4),
            nn.GELU(),
            nn.Linear(wallet_embedding_dim // 4, wallet_embedding_dim)
        ).to(dtype)

        # 2. Transformer Encoder to capture relationships across holders.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=wallet_embedding_dim,
            nhead=nhead,
            dim_feedforward=wallet_embedding_dim * 4,
            dropout=0.1,
            activation='gelu',
            batch_first=True,
            dtype=dtype
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # 3. A learnable [CLS] token prepended to the sequence; its output
        #    embedding aggregates information from all holders via attention.
        self.cls_token = nn.Parameter(torch.randn(1, 1, wallet_embedding_dim, dtype=dtype))

        # 4. Final projection to the desired output dimension.
        self.final_proj = nn.Linear(wallet_embedding_dim, output_dim).to(dtype)

    def forward(self, holder_data: List[Dict[str, Any]]) -> torch.Tensor:
        """
        Args:
            holder_data: A list of dictionaries, where each dict contains:
                'wallet_embedding': A tensor of shape [wallet_embedding_dim]
                'pct': The holding percentage as a float.
        Returns:
            A tensor of shape [1, output_dim] representing the entire distribution.
        """
        if not holder_data:
            # No holders: return a zero embedding on the module's device.
            return torch.zeros(1, self.output_dim, device=self.cls_token.device, dtype=self.dtype)

        device = self.cls_token.device
        # Stack per-holder embeddings into [num_holders, wallet_embedding_dim].
        # Cast to the module's dtype/device: callers may supply embeddings in
        # a different precision (e.g. float32 while the module runs float16),
        # which would otherwise produce a dtype mismatch inside the transformer.
        wallet_embeds = torch.stack(
            [d['wallet_embedding'] for d in holder_data]
        ).to(device=device, dtype=self.dtype)
        # Percentages as a [num_holders, 1] column vector for the MLP.
        holder_pcts = torch.tensor(
            [[d['pct']] for d in holder_data], device=device, dtype=self.dtype
        )

        # Fuse percentage information into each wallet embedding and add a
        # batch dimension: [1, num_holders, wallet_embedding_dim].
        pct_embeds = self.pct_proj(holder_pcts)
        holder_inputs = (wallet_embeds + pct_embeds).unsqueeze(0)

        # Prepend the [CLS] token: [1, num_holders + 1, wallet_embedding_dim].
        batch_size = holder_inputs.size(0)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        transformer_input = torch.cat((cls_tokens, holder_inputs), dim=1)

        # Attend across the sequence, then read out the [CLS] position.
        transformer_output = self.transformer_encoder(transformer_input)
        cls_embedding = transformer_output[:, 0, :]

        # Project to the final output dimension: [1, output_dim].
        return self.final_proj(cls_embedding)