# token_encoder.py (FIXED)
import torch
import torch.nn as nn
from typing import List, Any
from PIL import Image

from models.multi_modal_processor import MultiModalEncoder
from models.wallet_set_encoder import WalletSetEncoder  # Using your set encoder
from models.vocabulary import NUM_PROTOCOLS


class TokenEncoder(nn.Module):
    """
    Encodes a token's core identity into a single fixed-size embedding.

    FIXED: This version uses a robust fusion architecture and provides a
    dynamic, smaller output dimension (e.g., 2048) suitable for being an
    input to a larger model.

    Pipeline: three pre-computed multimodal embeddings (name, symbol, image)
    plus two categorical features (protocol ID, is_vanity flag) are each
    projected to a shared internal dimension, stacked as a 5-token sequence,
    fused by a small transformer set-encoder, and projected to `output_dim`.
    """

    def __init__(
        self,
        multi_dim: int,                      # Dimension of the upstream multimodal embeddings
        output_dim: int = 2048,
        internal_dim: int = 1024,            # INCREASED: balance between bottleneck and capacity
        protocol_embed_dim: int = 64,
        vanity_embed_dim: int = 32,          # NEW: small embedding for the vanity flag
        nhead: int = 4,
        num_layers: int = 1,
        dtype: torch.dtype = torch.float16,
    ):
        """
        Initializes the TokenEncoder.

        Args:
            multi_dim (int): The embedding dimension of the multimodal
                encoder outputs fed to `forward` (e.g., 1152).
            output_dim (int): The final dimension of the fused embedding
                (e.g., 2048).
            internal_dim (int): The shared dimension for the internal fusion
                transformer (e.g., 1024).
            protocol_embed_dim (int): Small dimension for the protocol ID
                embedding (e.g., 64).
            vanity_embed_dim (int): Small dimension for the is_vanity
                boolean flag embedding.
            nhead (int): Attention heads for the fusion transformer.
            num_layers (int): Layers for the fusion transformer.
            dtype (torch.dtype): The parameter/compute data type
                (e.g., torch.float16).
        """
        super().__init__()
        self.output_dim = output_dim
        self.internal_dim = internal_dim
        self.dtype = dtype

        # Store the multimodal encoder's fixed output dim (e.g., 1152).
        self.multi_dim = multi_dim

        # --- 1. Projection Layers ---
        # Project each dense feature from multi_dim to internal_dim.
        self.name_proj = nn.Linear(self.multi_dim, internal_dim)
        self.symbol_proj = nn.Linear(self.multi_dim, internal_dim)
        self.image_proj = nn.Linear(self.multi_dim, internal_dim)

        # --- 2. Categorical & Boolean Feature Embeddings ---
        # Small vocab size and small embed dim, then project up to internal_dim.
        self.protocol_embedding = nn.Embedding(NUM_PROTOCOLS, protocol_embed_dim)
        self.protocol_proj = nn.Linear(protocol_embed_dim, internal_dim)

        # NEW: Embedding for the is_vanity boolean flag (2 classes: True/False).
        self.vanity_embedding = nn.Embedding(2, vanity_embed_dim)
        self.vanity_proj = nn.Linear(vanity_embed_dim, internal_dim)

        # --- 3. Fusion Encoder ---
        # Re-use WalletSetEncoder to fuse the sequence of 5 features.
        self.fusion_transformer = WalletSetEncoder(
            d_model=internal_dim,
            nhead=nhead,
            num_layers=num_layers,
            dim_feedforward=internal_dim * 4,  # Standard 4x expansion
            dtype=dtype,
        )

        # --- 4. Final Output Projection ---
        # Project from the transformer's output (internal_dim) to the final,
        # dynamic output_dim.
        self.final_projection = nn.Sequential(
            nn.Linear(internal_dim, internal_dim * 2),
            nn.GELU(),
            nn.LayerNorm(internal_dim * 2),
            nn.Linear(internal_dim * 2, output_dim),
            nn.LayerNorm(output_dim),
        )

        # Cast all layers to the correct dtype and move to the default device.
        # NOTE(review): moving to a device inside __init__ is kept for
        # backward compatibility; callers can still .to(...) afterwards.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.to(device=device, dtype=dtype)

        # Log parameter counts for quick sanity checks at construction time.
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[TokenEncoder] Params: {total_params:,} (Trainable: {trainable_params:,})")

    def forward(
        self,
        name_embeds: torch.Tensor,
        symbol_embeds: torch.Tensor,
        image_embeds: torch.Tensor,
        protocol_ids: torch.Tensor,
        is_vanity_flags: torch.Tensor,
    ) -> torch.Tensor:
        """
        Processes a batch of token data to create a batch of embeddings.

        Args:
            name_embeds (torch.Tensor): Pre-computed embeddings for token
                names. Shape: [B, multi_dim]
            symbol_embeds (torch.Tensor): Pre-computed embeddings for token
                symbols. Shape: [B, multi_dim]
            image_embeds (torch.Tensor): Pre-computed embeddings for token
                images. Shape: [B, multi_dim]
            protocol_ids (torch.Tensor): Batch of protocol IDs. Shape: [B]
            is_vanity_flags (torch.Tensor): Batch of boolean flags for
                vanity addresses. Shape: [B]

        Returns:
            torch.Tensor: The fused token embedding batch.
                Shape: [batch_size, output_dim]
        """
        device = name_embeds.device
        batch_size = name_embeds.shape[0]

        # 1. Cast dense inputs to the module dtype (e.g., float16).
        # FIX: without this, float32 inputs crash against fp16 Linear weights.
        name_embeds = name_embeds.to(self.dtype)
        symbol_embeds = symbol_embeds.to(self.dtype)
        image_embeds = image_embeds.to(self.dtype)

        # 2. Look up the categorical embeddings.
        protocol_ids_long = protocol_ids.to(device, dtype=torch.long)
        protocol_emb_raw = self.protocol_embedding(protocol_ids_long)  # [B, protocol_embed_dim]

        # NEW: Get vanity embedding (bool flags cast to {0, 1} indices).
        vanity_ids_long = is_vanity_flags.to(device, dtype=torch.long)
        vanity_emb_raw = self.vanity_embedding(vanity_ids_long)  # [B, vanity_embed_dim]

        # 3. Project all features to internal_dim (e.g., 1024).
        name_emb = self.name_proj(name_embeds)
        symbol_emb = self.symbol_proj(symbol_embeds)
        image_emb = self.image_proj(image_embeds)
        protocol_emb = self.protocol_proj(protocol_emb_raw)
        vanity_emb = self.vanity_proj(vanity_emb_raw)  # NEW

        # 4. Stack all projected features into a length-5 sequence.
        feature_sequence = torch.stack([
            name_emb,
            symbol_emb,
            image_emb,
            protocol_emb,
            vanity_emb,  # NEW: Add the vanity embedding to the sequence
        ], dim=1)

        # 5. Padding mask is all False: every sample has the same fixed
        # number of features, so nothing is masked out.
        padding_mask = torch.zeros(
            batch_size, feature_sequence.shape[1], device=device, dtype=torch.bool
        )

        # 6. Fuse the sequence with the transformer encoder.
        # Returns the [CLS]-style pooled output. Shape: [B, internal_dim]
        fused_embedding = self.fusion_transformer(
            item_embeds=feature_sequence,
            src_key_padding_mask=padding_mask,
        )

        # 7. Project to the final output dimension. Shape: [B, output_dim]
        token_vibe_embedding = self.final_projection(fused_embedding)

        return token_vibe_embedding