| |
|
|
| import torch |
| import torch.nn as nn |
| from typing import List, Any |
| from PIL import Image |
|
|
| from models.multi_modal_processor import MultiModalEncoder |
| from models.wallet_set_encoder import WalletSetEncoder |
| from models.vocabulary import NUM_PROTOCOLS |
|
|
class TokenEncoder(nn.Module):
    """
    Encodes a token's core identity into a single <TokenVibeEmbedding>.

    Five feature streams — precomputed name / symbol / image embeddings,
    a protocol ID, and an is-vanity flag — are each projected into a
    shared internal dimension, stacked into a length-5 sequence, fused by
    a small set transformer (WalletSetEncoder), and finally projected to
    ``output_dim``.

    FIXED: This version uses a robust fusion architecture and provides
    a dynamic, smaller output dimension (e.g., 2048) suitable for
    being an input to a larger model.
    """

    def __init__(
        self,
        multi_dim: int,
        output_dim: int = 2048,
        internal_dim: int = 1024,
        protocol_embed_dim: int = 64,
        vanity_embed_dim: int = 32,
        nhead: int = 4,
        num_layers: int = 1,
        dtype: torch.dtype = torch.float16,
    ):
        """
        Initializes the TokenEncoder.

        Args:
            multi_dim (int):
                The embedding dimension of the multimodal encoder outputs
                used for name / symbol / image (e.g., 1152).
            output_dim (int):
                The final dimension of the <TokenVibeEmbedding> (e.g., 2048).
            internal_dim (int):
                The shared dimension for the internal fusion transformer
                (e.g., 1024).
            protocol_embed_dim (int):
                Small embedding dimension for the protocol ID (e.g., 64).
            vanity_embed_dim (int):
                Small embedding dimension for the is_vanity boolean flag.
            nhead (int):
                Attention heads for the fusion transformer.
            num_layers (int):
                Layers for the fusion transformer.
            dtype (torch.dtype):
                The parameter data type (e.g., torch.float16).
        """
        super().__init__()
        self.output_dim = output_dim
        self.internal_dim = internal_dim
        self.dtype = dtype
        self.multi_dim = multi_dim

        # Project each precomputed multimodal stream into the shared
        # fusion space.
        self.name_proj = nn.Linear(self.multi_dim, internal_dim)
        self.symbol_proj = nn.Linear(self.multi_dim, internal_dim)
        self.image_proj = nn.Linear(self.multi_dim, internal_dim)

        # Categorical features: learned embedding tables followed by a
        # linear projection into the same fusion space.
        self.protocol_embedding = nn.Embedding(NUM_PROTOCOLS, protocol_embed_dim)
        self.protocol_proj = nn.Linear(protocol_embed_dim, internal_dim)

        # is_vanity is a boolean, so a 2-row embedding table suffices.
        self.vanity_embedding = nn.Embedding(2, vanity_embed_dim)
        self.vanity_proj = nn.Linear(vanity_embed_dim, internal_dim)

        # Set-style transformer that attends over the 5 feature "tokens"
        # and pools them into a single internal_dim vector.
        self.fusion_transformer = WalletSetEncoder(
            d_model=internal_dim,
            nhead=nhead,
            num_layers=num_layers,
            dim_feedforward=internal_dim * 4,
            dtype=dtype,
        )

        # MLP head mapping the fused vector to the requested output size.
        self.final_projection = nn.Sequential(
            nn.Linear(internal_dim, internal_dim * 2),
            nn.GELU(),
            nn.LayerNorm(internal_dim * 2),
            nn.Linear(internal_dim * 2, output_dim),
            nn.LayerNorm(output_dim),
        )

        # NOTE(review): auto-moving the module inside __init__ preserves the
        # original behavior; callers needing a specific device should call
        # .to() again after construction.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.to(device=device, dtype=dtype)

        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[TokenEncoder] Params: {total_params:,} (Trainable: {trainable_params:,})")

    def forward(
        self,
        name_embeds: torch.Tensor,
        symbol_embeds: torch.Tensor,
        image_embeds: torch.Tensor,
        protocol_ids: torch.Tensor,
        is_vanity_flags: torch.Tensor,
    ) -> torch.Tensor:
        """
        Processes a batch of token data to create a batch of embeddings.

        Args:
            name_embeds (torch.Tensor): Pre-computed embeddings for token names.
                Shape: [B, multi_dim]
            symbol_embeds (torch.Tensor): Pre-computed embeddings for token symbols.
                Shape: [B, multi_dim]
            image_embeds (torch.Tensor): Pre-computed embeddings for token images.
                Shape: [B, multi_dim]
            protocol_ids (torch.Tensor): Batch of protocol IDs. Shape: [B]
            is_vanity_flags (torch.Tensor): Batch of boolean flags for vanity
                addresses. Shape: [B]

        Returns:
            torch.Tensor: The final <TokenVibeEmbedding> batch.
                Shape: [batch_size, output_dim]
        """
        device = name_embeds.device
        batch_size = name_embeds.shape[0]

        # nn.Embedding requires int64 indices; also ensure the IDs live on
        # the same device as the feature tensors.
        protocol_ids_long = protocol_ids.to(device, dtype=torch.long)
        protocol_emb_raw = self.protocol_embedding(protocol_ids_long)

        # Bool flags are cast to {0, 1} long indices for the lookup table.
        vanity_ids_long = is_vanity_flags.to(device, dtype=torch.long)
        vanity_emb_raw = self.vanity_embedding(vanity_ids_long)

        # Bring every stream into the shared internal_dim fusion space.
        name_emb = self.name_proj(name_embeds)
        symbol_emb = self.symbol_proj(symbol_embeds)
        image_emb = self.image_proj(image_embeds)
        protocol_emb = self.protocol_proj(protocol_emb_raw)
        vanity_emb = self.vanity_proj(vanity_emb_raw)

        # Stack into a [B, 5, internal_dim] sequence of feature "tokens".
        feature_sequence = torch.stack([
            name_emb,
            symbol_emb,
            image_emb,
            protocol_emb,
            vanity_emb,
        ], dim=1)

        # All 5 positions are always real features, so the padding mask is
        # all-False (nothing is masked out).
        padding_mask = torch.zeros(
            batch_size, feature_sequence.shape[1], device=device, dtype=torch.bool
        )

        # Fuse the sequence into a single [B, internal_dim] vector.
        fused_embedding = self.fusion_transformer(
            item_embeds=feature_sequence,
            src_key_padding_mask=padding_mask,
        )

        # Project to the final [B, output_dim] embedding.
        token_vibe_embedding = self.final_projection(fused_embedding)

        return token_vibe_embedding
|
|