File size: 6,908 Bytes
858826c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5800f64
 
 
 
 
858826c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# token_encoder.py (FIXED)

import torch
import torch.nn as nn
from typing import List, Any
from PIL import Image  

from models.multi_modal_processor import MultiModalEncoder
from models.wallet_set_encoder import WalletSetEncoder # Using your set encoder
from models.vocabulary import NUM_PROTOCOLS

class TokenEncoder(nn.Module):
    """
    Encodes a token's core identity into a single <TokenVibeEmbedding>.

    Five per-token features — pre-computed name / symbol / image embeddings,
    a protocol ID, and an is_vanity flag — are each projected to a shared
    internal dimension, fused as a 5-element sequence by a small transformer
    encoder (WalletSetEncoder), and projected to ``output_dim``.

    FIXED: This version uses a robust fusion architecture and provides
    a dynamic, smaller output dimension (e.g., 2048) suitable for
    being an input to a larger model.
    """
    def __init__(
        self,
        multi_dim: int,
        output_dim: int = 2048,
        internal_dim: int = 1024,
        protocol_embed_dim: int = 64,
        vanity_embed_dim: int = 32,
        nhead: int = 4,
        num_layers: int = 1,
        dtype: torch.dtype = torch.float16,
        device=None,  # NEW: optional device override; None keeps the old auto-detect
    ):
        """
        Initializes the TokenEncoder.

        Args:
            multi_dim (int):
                The embedding dimension of the upstream multimodal encoder
                output (e.g., 1152 for SigLIP).
            output_dim (int):
                The final dimension of the <TokenVibeEmbedding> (e.g., 2048).
            internal_dim (int):
                The shared dimension for the internal fusion transformer
                (e.g., 1024) — a balance between bottleneck and capacity.
            protocol_embed_dim (int):
                Small embedding dimension for the protocol ID (e.g., 64).
            vanity_embed_dim (int):
                Small embedding dimension for the is_vanity boolean flag.
            nhead (int):
                Attention heads for the fusion transformer.
            num_layers (int):
                Layers for the fusion transformer.
            dtype (torch.dtype):
                The parameter data type (e.g., torch.float16).
            device (str | torch.device | None):
                Target device for the module. Defaults to "cuda" when
                available, otherwise "cpu" (the previous hard-coded behavior).
        """
        super().__init__()
        self.output_dim = output_dim
        self.internal_dim = internal_dim
        self.dtype = dtype

        # Fixed output dim of the multimodal encoder (e.g., 1152 for SigLIP).
        self.multi_dim = multi_dim

        # --- 1. Projection layers ---
        # Project each pre-computed embedding from multi_dim to internal_dim.
        self.name_proj = nn.Linear(self.multi_dim, internal_dim)
        self.symbol_proj = nn.Linear(self.multi_dim, internal_dim)
        self.image_proj = nn.Linear(self.multi_dim, internal_dim)

        # --- 2. Categorical & boolean feature embeddings ---
        # Small vocab, small embed dim; projected up to internal_dim so the
        # fusion transformer sees a homogeneous sequence.
        self.protocol_embedding = nn.Embedding(NUM_PROTOCOLS, protocol_embed_dim)
        self.protocol_proj = nn.Linear(protocol_embed_dim, internal_dim)

        # is_vanity boolean flag: 2 classes (False / True).
        self.vanity_embedding = nn.Embedding(2, vanity_embed_dim)
        self.vanity_proj = nn.Linear(vanity_embed_dim, internal_dim)

        # --- 3. Fusion encoder ---
        # Re-use WalletSetEncoder to fuse the sequence of 5 features.
        self.fusion_transformer = WalletSetEncoder(
            d_model=internal_dim,
            nhead=nhead,
            num_layers=num_layers,
            dim_feedforward=internal_dim * 4,  # standard 4x expansion
            dtype=dtype,
        )

        # --- 4. Final output projection ---
        # internal_dim -> output_dim via a 2x-wide MLP with normalization.
        self.final_projection = nn.Sequential(
            nn.Linear(internal_dim, internal_dim * 2),
            nn.GELU(),
            nn.LayerNorm(internal_dim * 2),
            nn.Linear(internal_dim * 2, output_dim),
            nn.LayerNorm(output_dim),
        )

        # Move all layers to the requested (or auto-detected) device / dtype.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.to(device=device, dtype=dtype)

        # Log parameter counts for visibility at construction time.
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[TokenEncoder] Params: {total_params:,} (Trainable: {trainable_params:,})")

    def forward(
        self,
        name_embeds: torch.Tensor,
        symbol_embeds: torch.Tensor,
        image_embeds: torch.Tensor,
        protocol_ids: torch.Tensor,
        is_vanity_flags: torch.Tensor,
    ) -> torch.Tensor:
        """
        Processes a batch of token data to create a batch of embeddings.

        Args:
            name_embeds (torch.Tensor): Pre-computed embeddings for token
                names. Shape: [B, multi_dim]
            symbol_embeds (torch.Tensor): Pre-computed embeddings for token
                symbols. Shape: [B, multi_dim]
            image_embeds (torch.Tensor): Pre-computed embeddings for token
                images. Shape: [B, multi_dim]
            protocol_ids (torch.Tensor): Batch of protocol IDs, castable to
                long. Shape: [B]
            is_vanity_flags (torch.Tensor): Batch of 0/1 flags for vanity
                addresses, castable to long. Shape: [B]

        Returns:
            torch.Tensor: The final <TokenVibeEmbedding> batch.
                          Shape: [B, output_dim]
        """
        device = name_embeds.device
        batch_size = name_embeds.shape[0]

        # 1. Look up the categorical / boolean embeddings (cast to long for
        #    nn.Embedding indexing).
        protocol_ids_long = protocol_ids.to(device, dtype=torch.long)
        protocol_emb_raw = self.protocol_embedding(protocol_ids_long)  # [B, protocol_embed_dim]

        vanity_ids_long = is_vanity_flags.to(device, dtype=torch.long)
        vanity_emb_raw = self.vanity_embedding(vanity_ids_long)  # [B, vanity_embed_dim]

        # 2. Project all five features to internal_dim.
        name_emb = self.name_proj(name_embeds)
        symbol_emb = self.symbol_proj(symbol_embeds)
        image_emb = self.image_proj(image_embeds)
        protocol_emb = self.protocol_proj(protocol_emb_raw)
        vanity_emb = self.vanity_proj(vanity_emb_raw)

        # 3. Stack the projected features into a sequence. Shape: [B, 5, internal_dim]
        feature_sequence = torch.stack([
            name_emb,
            symbol_emb,
            image_emb,
            protocol_emb,
            vanity_emb,
        ], dim=1)

        # 4. All-False padding mask: every sample has the same fixed 5 features.
        padding_mask = torch.zeros(batch_size, feature_sequence.shape[1], device=device, dtype=torch.bool)

        # 5. Fuse the sequence with the transformer encoder; it returns the
        #    [CLS]-token output. Shape: [B, internal_dim]
        fused_embedding = self.fusion_transformer(
            item_embeds=feature_sequence,
            src_key_padding_mask=padding_mask
        )

        # 6. Project to the final output dimension. Shape: [B, output_dim]
        token_vibe_embedding = self.final_projection(fused_embedding)

        return token_vibe_embedding