# token_encoder.py (FIXED)
import torch
import torch.nn as nn
from typing import List, Any
from PIL import Image
from models.multi_modal_processor import MultiModalEncoder
from models.wallet_set_encoder import WalletSetEncoder # Using your set encoder
from models.vocabulary import NUM_PROTOCOLS
class TokenEncoder(nn.Module):
    """
    Encodes a token's core identity into a single <TokenVibeEmbedding>.

    Fuses five per-token features — name, symbol, and image embeddings from an
    upstream multimodal encoder, plus a protocol ID and an is-vanity flag —
    with a small set transformer, then projects to a compact output dimension
    (e.g. 2048) suitable as input to a larger downstream model.
    """

    def __init__(
        self,
        multi_dim: int,
        output_dim: int = 2048,
        internal_dim: int = 1024,
        protocol_embed_dim: int = 64,
        vanity_embed_dim: int = 32,
        nhead: int = 4,
        num_layers: int = 1,
        dtype: torch.dtype = torch.float16,
    ):
        """
        Initializes the TokenEncoder.

        Args:
            multi_dim (int):
                Embedding dimension of the upstream multimodal encoder
                (e.g. 1152 for SigLIP). FIX: earlier docs called this
                ``siglip_dim``, which is not an actual parameter name.
            output_dim (int):
                Final dimension of the <TokenVibeEmbedding> (e.g. 2048).
            internal_dim (int):
                Shared width of the internal fusion transformer (e.g. 1024).
            protocol_embed_dim (int):
                Embedding size for the protocol ID (e.g. 64).
            vanity_embed_dim (int):
                Embedding size for the is_vanity boolean flag.
            nhead (int):
                Attention heads for the fusion transformer.
            num_layers (int):
                Layers for the fusion transformer.
            dtype (torch.dtype):
                Parameter/compute dtype (e.g. torch.float16).
        """
        super().__init__()
        self.output_dim = output_dim
        self.internal_dim = internal_dim
        self.dtype = dtype
        # Fixed output dim of the multimodal encoder (e.g. 1152)
        self.multi_dim = multi_dim

        # --- 1. Projection layers ---
        # Project each modality's embedding to internal_dim
        self.name_proj = nn.Linear(self.multi_dim, internal_dim)
        self.symbol_proj = nn.Linear(self.multi_dim, internal_dim)
        self.image_proj = nn.Linear(self.multi_dim, internal_dim)

        # --- 2. Categorical & boolean feature embeddings ---
        # Small vocab / small embed dim, then project up to internal_dim
        self.protocol_embedding = nn.Embedding(NUM_PROTOCOLS, protocol_embed_dim)
        self.protocol_proj = nn.Linear(protocol_embed_dim, internal_dim)
        self.vanity_embedding = nn.Embedding(2, vanity_embed_dim)  # 2 classes: True/False
        self.vanity_proj = nn.Linear(vanity_embed_dim, internal_dim)

        # --- 3. Fusion encoder ---
        # Re-use WalletSetEncoder to fuse the sequence of 5 features
        self.fusion_transformer = WalletSetEncoder(
            d_model=internal_dim,
            nhead=nhead,
            num_layers=num_layers,
            dim_feedforward=internal_dim * 4,  # standard 4x expansion
            dtype=dtype,
        )

        # --- 4. Final output projection ---
        # internal_dim -> internal_dim*2 -> output_dim
        self.final_projection = nn.Sequential(
            nn.Linear(internal_dim, internal_dim * 2),
            nn.GELU(),
            nn.LayerNorm(internal_dim * 2),
            nn.Linear(internal_dim * 2, output_dim),
            nn.LayerNorm(output_dim),
        )

        # NOTE(review): moving the module to a device inside __init__ overrides
        # any external placement decision; kept as-is for backward
        # compatibility, but callers may prefer to control .to(...) themselves.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.to(device=device, dtype=dtype)

        # Log parameter counts
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[TokenEncoder] Params: {total_params:,} (Trainable: {trainable_params:,})")

    def forward(
        self,
        name_embeds: torch.Tensor,
        symbol_embeds: torch.Tensor,
        image_embeds: torch.Tensor,
        protocol_ids: torch.Tensor,
        is_vanity_flags: torch.Tensor,
    ) -> torch.Tensor:
        """
        Processes a batch of token data to create a batch of embeddings.

        Args:
            name_embeds (torch.Tensor): Pre-computed name embeddings. Shape: [B, multi_dim]
            symbol_embeds (torch.Tensor): Pre-computed symbol embeddings. Shape: [B, multi_dim]
            image_embeds (torch.Tensor): Pre-computed image embeddings. Shape: [B, multi_dim]
            protocol_ids (torch.Tensor): Integer protocol IDs. Shape: [B]
                (assumed to lie in [0, NUM_PROTOCOLS) — TODO confirm at caller)
            is_vanity_flags (torch.Tensor): Boolean / 0-1 vanity flags. Shape: [B]

        Returns:
            torch.Tensor: The final <TokenVibeEmbedding> batch.
                          Shape: [batch_size, output_dim]
        """
        device = name_embeds.device
        batch_size = name_embeds.shape[0]

        # FIX: cast inputs to the module dtype so float32 inputs do not fail
        # against (e.g.) float16 projection weights; no-op when they match.
        name_embeds = name_embeds.to(dtype=self.dtype)
        symbol_embeds = symbol_embeds.to(dtype=self.dtype)
        image_embeds = image_embeds.to(dtype=self.dtype)

        # 1-2. Look up categorical/boolean embeddings
        protocol_ids_long = protocol_ids.to(device, dtype=torch.long)
        protocol_emb_raw = self.protocol_embedding(protocol_ids_long)  # [B, protocol_embed_dim]
        vanity_ids_long = is_vanity_flags.to(device, dtype=torch.long)
        vanity_emb_raw = self.vanity_embedding(vanity_ids_long)  # [B, vanity_embed_dim]

        # 3. Project all features to internal_dim (e.g. 1024)
        name_emb = self.name_proj(name_embeds)
        symbol_emb = self.symbol_proj(symbol_embeds)
        image_emb = self.image_proj(image_embeds)
        protocol_emb = self.protocol_proj(protocol_emb_raw)
        vanity_emb = self.vanity_proj(vanity_emb_raw)

        # 4. Stack projected features into a sequence: [B, 5, internal_dim]
        feature_sequence = torch.stack(
            [name_emb, symbol_emb, image_emb, protocol_emb, vanity_emb],
            dim=1,
        )

        # 5. Padding mask is all-False: every token has exactly these 5 features
        padding_mask = torch.zeros(
            batch_size, feature_sequence.shape[1], device=device, dtype=torch.bool
        )

        # 6. Fuse the sequence with the transformer encoder.
        #    Returns the pooled ([CLS]-style) output, shape [B, internal_dim].
        fused_embedding = self.fusion_transformer(
            item_embeds=feature_sequence,
            src_key_padding_mask=padding_mask,
        )

        # 7. Project to the final output dimension: [B, output_dim]
        token_vibe_embedding = self.final_projection(fused_embedding)
        return token_vibe_embedding