import torch
import torch.nn as nn
from typing import List, Dict, Any

class HolderDistributionEncoder(nn.Module):
    """
    Encodes a list of top holders (wallet embeddings + holding percentages)
    into a single fixed-size embedding representing the holder distribution.
    It uses a Transformer Encoder to capture patterns and relationships.
    """
    def __init__(self,
                 wallet_embedding_dim: int,
                 output_dim: int,
                 nhead: int = 4,
                 num_layers: int = 2,
                 dtype: torch.dtype = torch.float16):
        super().__init__()
        if wallet_embedding_dim % nhead != 0:
            raise ValueError(
                f"wallet_embedding_dim ({wallet_embedding_dim}) must be divisible by nhead ({nhead})"
            )
        self.wallet_embedding_dim = wallet_embedding_dim
        self.output_dim = output_dim
        self.dtype = dtype

        # 1. MLP to project holding percentage to the wallet embedding dimension
        self.pct_proj = nn.Sequential(
            nn.Linear(1, wallet_embedding_dim // 4),
            nn.GELU(),
            nn.Linear(wallet_embedding_dim // 4, wallet_embedding_dim)
        ).to(dtype)

        # 2. Transformer Encoder to process the sequence of holders
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=wallet_embedding_dim,
            nhead=nhead,
            dim_feedforward=wallet_embedding_dim * 4,
            dropout=0.1,
            activation='gelu',
            batch_first=True,
            dtype=dtype
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # 3. A learnable [CLS] token to aggregate the sequence information
        self.cls_token = nn.Parameter(torch.randn(1, 1, wallet_embedding_dim, dtype=dtype))

        # 4. Final projection layer to get the desired output dimension
        self.final_proj = nn.Linear(wallet_embedding_dim, output_dim).to(dtype)

    def forward(self, holder_data: List[Dict[str, Any]]) -> torch.Tensor:
        """
        Args:
            holder_data: A list of dictionaries, where each dict contains:
                         'wallet_embedding': A tensor of shape [wallet_embedding_dim]
                         'pct': The holding percentage as a float.

        Returns:
            A tensor of shape [1, output_dim] representing the entire distribution.
        """
        if not holder_data:
            # Return a zero tensor if there are no holders
            return torch.zeros(1, self.output_dim, device=self.cls_token.device, dtype=self.dtype)

        # Prepare inputs for the transformer
        # Cast to the module's dtype: incoming embeddings are often float32,
        # which would otherwise fail against float16 layer weights
        wallet_embeds = torch.stack([d['wallet_embedding'] for d in holder_data]).to(self.dtype)
        holder_pcts = torch.tensor([[d['pct']] for d in holder_data], device=wallet_embeds.device, dtype=self.dtype)

        # Project percentages and add to wallet embeddings to create holder features
        pct_embeds = self.pct_proj(holder_pcts)
        holder_inputs = (wallet_embeds + pct_embeds).unsqueeze(0) # Add batch dimension

        # Prepend the [CLS] token
        batch_size = holder_inputs.size(0)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        transformer_input = torch.cat((cls_tokens, holder_inputs), dim=1)

        # Pass through the transformer
        transformer_output = self.transformer_encoder(transformer_input)

        # Get the embedding of the [CLS] token (the first token)
        cls_embedding = transformer_output[:, 0, :]

        # Project to the final output dimension
        return self.final_proj(cls_embedding)