# oracle/models/token_encoder.py
# (uploaded via huggingface_hub, commit 5800f64)
# token_encoder.py (FIXED)
import torch
import torch.nn as nn
from typing import List, Any
from PIL import Image
from models.multi_modal_processor import MultiModalEncoder
from models.wallet_set_encoder import WalletSetEncoder # Using your set encoder
from models.vocabulary import NUM_PROTOCOLS
class TokenEncoder(nn.Module):
    """
    Encodes a token's core identity into a single <TokenVibeEmbedding>.

    Five features — pre-computed name/symbol/image embeddings, a protocol ID,
    and an is-vanity flag — are each projected to ``internal_dim``, fused as a
    5-element sequence by a WalletSetEncoder transformer, and projected up to
    the final ``output_dim``. Uses a robust fusion architecture and provides
    a dynamic, smaller output dimension (e.g., 2048) suitable for being an
    input to a larger model.
    """

    def __init__(
        self,
        multi_dim: int,
        output_dim: int = 2048,
        internal_dim: int = 1024,
        protocol_embed_dim: int = 64,
        vanity_embed_dim: int = 32,
        nhead: int = 4,
        num_layers: int = 1,
        dtype: torch.dtype = torch.float16,
    ):
        """
        Initializes the TokenEncoder.

        Args:
            multi_dim (int):
                The embedding dimension of the multimodal encoder that
                produced the name/symbol/image embeddings (e.g., 1152).
            output_dim (int):
                The final dimension of the <TokenVibeEmbedding> (e.g., 2048).
            internal_dim (int):
                The shared dimension for the internal fusion transformer
                (e.g., 1024) — a balance between bottleneck and capacity.
            protocol_embed_dim (int):
                Small embedding dimension for the protocol ID (e.g., 64).
            vanity_embed_dim (int):
                Small embedding dimension for the is_vanity boolean flag.
            nhead (int):
                Attention heads for the fusion transformer.
            num_layers (int):
                Layers for the fusion transformer.
            dtype (torch.dtype):
                The parameter/compute data type (e.g., torch.float16).
        """
        super().__init__()
        self.output_dim = output_dim
        self.internal_dim = internal_dim
        self.dtype = dtype
        # Fixed output dim of the upstream multimodal encoder (e.g., 1152).
        self.multi_dim = multi_dim

        # --- 1. Projection layers ---
        # Project each modality embedding to the shared internal_dim.
        self.name_proj = nn.Linear(self.multi_dim, internal_dim)
        self.symbol_proj = nn.Linear(self.multi_dim, internal_dim)
        self.image_proj = nn.Linear(self.multi_dim, internal_dim)

        # --- 2. Categorical & boolean feature embeddings ---
        # Small vocab / small embed dim, then projected up to internal_dim.
        self.protocol_embedding = nn.Embedding(NUM_PROTOCOLS, protocol_embed_dim)
        self.protocol_proj = nn.Linear(protocol_embed_dim, internal_dim)
        # 2 classes for the is_vanity flag: False (0) / True (1).
        self.vanity_embedding = nn.Embedding(2, vanity_embed_dim)
        self.vanity_proj = nn.Linear(vanity_embed_dim, internal_dim)

        # --- 3. Fusion encoder ---
        # Re-use WalletSetEncoder to fuse the sequence of 5 features;
        # it returns a single pooled ([CLS]-style) vector per batch item.
        self.fusion_transformer = WalletSetEncoder(
            d_model=internal_dim,
            nhead=nhead,
            num_layers=num_layers,
            dim_feedforward=internal_dim * 4,  # standard 4x expansion
            dtype=dtype,
        )

        # --- 4. Final output projection ---
        # internal_dim -> internal_dim*2 -> output_dim, with GELU/LayerNorm.
        self.final_projection = nn.Sequential(
            nn.Linear(internal_dim, internal_dim * 2),
            nn.GELU(),
            nn.LayerNorm(internal_dim * 2),
            nn.Linear(internal_dim * 2, output_dim),
            nn.LayerNorm(output_dim),
        )

        # Cast all layers to the target dtype and move to GPU when available.
        # NOTE(review): doing device placement inside __init__ is a side
        # effect callers may not expect; kept for backward compatibility.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.to(device=device, dtype=dtype)

        # Log parameter counts for quick sanity checks at construction time.
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[TokenEncoder] Params: {total_params:,} (Trainable: {trainable_params:,})")

    def forward(
        self,
        name_embeds: torch.Tensor,
        symbol_embeds: torch.Tensor,
        image_embeds: torch.Tensor,
        protocol_ids: torch.Tensor,
        is_vanity_flags: torch.Tensor,
    ) -> torch.Tensor:
        """
        Processes a batch of token data to create a batch of embeddings.

        Args:
            name_embeds (torch.Tensor): Pre-computed embeddings for token names.
                Shape: [B, multi_dim]. Any float dtype; cast to the module dtype.
            symbol_embeds (torch.Tensor): Pre-computed embeddings for token symbols.
                Shape: [B, multi_dim].
            image_embeds (torch.Tensor): Pre-computed embeddings for token images.
                Shape: [B, multi_dim].
            protocol_ids (torch.Tensor): Batch of protocol IDs. Shape: [B].
            is_vanity_flags (torch.Tensor): Batch of boolean (or 0/1) flags for
                vanity addresses. Shape: [B].

        Returns:
            torch.Tensor: The final <TokenVibeEmbedding> batch.
                Shape: [batch_size, output_dim]
        """
        device = name_embeds.device
        batch_size = name_embeds.shape[0]

        # 1. Look up the small categorical embeddings (cast ids to long for
        # nn.Embedding; bool flags convert cleanly to 0/1).
        protocol_ids_long = protocol_ids.to(device, dtype=torch.long)
        protocol_emb_raw = self.protocol_embedding(protocol_ids_long)  # [B, protocol_embed_dim]
        vanity_ids_long = is_vanity_flags.to(device, dtype=torch.long)
        vanity_emb_raw = self.vanity_embedding(vanity_ids_long)  # [B, vanity_embed_dim]

        # 2. Project all features to internal_dim. The float inputs are cast
        # to the module dtype first so float32 embeddings from an upstream
        # encoder do not trip a dtype mismatch against float16 weights.
        name_emb = self.name_proj(name_embeds.to(self.dtype))
        symbol_emb = self.symbol_proj(symbol_embeds.to(self.dtype))
        image_emb = self.image_proj(image_embeds.to(self.dtype))
        protocol_emb = self.protocol_proj(protocol_emb_raw)
        vanity_emb = self.vanity_proj(vanity_emb_raw)

        # 3. Stack the projected features into a length-5 sequence: [B, 5, internal_dim]
        feature_sequence = torch.stack([
            name_emb,
            symbol_emb,
            image_emb,
            protocol_emb,
            vanity_emb,
        ], dim=1)

        # 4. Padding mask is all-False: every item has the same fixed 5 features.
        padding_mask = torch.zeros(batch_size, feature_sequence.shape[1], device=device, dtype=torch.bool)

        # 5. Fuse with the transformer encoder -> pooled vector [B, internal_dim].
        fused_embedding = self.fusion_transformer(
            item_embeds=feature_sequence,
            src_key_padding_mask=padding_mask,
        )

        # 6. Project to the final output dimension: [B, output_dim].
        return self.final_projection(fused_embedding)