# oracle/models/wallet_encoder.py
# (Hugging Face upload header: uploaded by zirobtc using huggingface_hub, revision bb2313b)
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Any, Optional
from PIL import Image
# We assume these helper modules are in the same directory
from models.multi_modal_processor import MultiModalEncoder
from models.wallet_set_encoder import WalletSetEncoder
class WalletEncoder(nn.Module):
    """
    Encodes a wallet's full identity into a single <WalletEmbedding>.

    Fuses three per-wallet views -- profile stats, social signals, and token
    holdings -- into one d_model-sized vector via a fusion MLP.
    UPDATED: Aligned with the final feature spec.
    """

    def __init__(
        self,
        encoder: MultiModalEncoder,
        d_model: int = 2048,                 # Standardized to d_model
        token_vibe_dim: int = 2048,          # Expects token vibe of d_model
        set_encoder_nhead: int = 8,
        set_encoder_nlayers: int = 2,
        dtype: torch.dtype = torch.float16
    ):
        """
        Initializes the WalletEncoder.

        Args:
            encoder (MultiModalEncoder): Instantiated SigLIP encoder. Only its
                `embedding_dim` (here) and `device` (in `_get_device`) are read.
            d_model (int): The final output dimension (default 2048).
            token_vibe_dim (int): The dimension of the pre-computed
                <TokenVibeEmbedding> (default 2048).
            set_encoder_nhead (int): Attention heads for the holdings set encoder.
            set_encoder_nlayers (int): Transformer layers for the holdings set encoder.
            dtype (torch.dtype): Data type for this module's parameters.
        """
        super().__init__()
        self.d_model = d_model
        self.dtype = dtype
        self.encoder = encoder

        # --- Dimensions ---
        self.token_vibe_dim = token_vibe_dim
        # Username-embedding width produced by the multimodal encoder
        # (e.g. 1152 -- depends on the encoder instance passed in).
        self.mmp_dim = self.encoder.embedding_dim

        # === 1. Profile Encoder ===
        # 5 deployer_stats + 1 balance + 4 lifetime_counts +
        # 3 lifetime_trading + 12 1d_stats + 12 7d_stats = 37
        self.profile_numerical_features = 37
        self.profile_num_norm = nn.LayerNorm(self.profile_numerical_features)
        # Input is purely numerical (no bool embed / deployed-tokens embed).
        profile_mlp_in_dim = self.profile_numerical_features  # 37
        self.profile_encoder_mlp = self._build_mlp(profile_mlp_in_dim, d_model)

        # === 2. Social Encoder ===
        # 4 booleans: has_pf, has_twitter, has_telegram, is_exchange_wallet
        self.social_bool_embed = nn.Embedding(2, 16)
        # (4 * 16) bool embeddings + username embedding from the mm encoder.
        social_mlp_in_dim = (16 * 4) + self.mmp_dim
        self.social_encoder_mlp = self._build_mlp(social_mlp_in_dim, d_model)

        # === 3. Holdings Encoder ===
        # 11 original stats + 1 holding_time = 12
        self.holding_numerical_features = 12
        self.holding_num_norm = nn.LayerNorm(self.holding_numerical_features)
        holding_row_in_dim = (
            self.token_vibe_dim +               # <TokenVibeEmbedding>
            self.holding_numerical_features     # 12
        )
        self.holding_row_encoder_mlp = self._build_mlp(holding_row_in_dim, d_model)
        self.holdings_set_encoder = WalletSetEncoder(
            d_model, set_encoder_nhead, set_encoder_nlayers, dtype=dtype
        )

        # === 4. Final Fusion Encoder ===
        # Fuses the 3 components: Profile, Social, Holdings (hence d_model * 3).
        self.fusion_mlp = nn.Sequential(
            nn.Linear(d_model * 3, d_model * 2),
            nn.GELU(),
            nn.LayerNorm(d_model * 2),
            nn.Linear(d_model * 2, d_model),
            nn.LayerNorm(d_model)
        )
        self.to(dtype)

        # NOTE: `MultiModalEncoder` is a plain wrapper class, not an nn.Module,
        # so `self.encoder` is not a registered submodule and its weights are
        # NOT included in self.parameters(); the counts below cover only this
        # module's own MLPs / set encoder.
        my_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[WalletEncoder] Params: {my_params:,} (Trainable: {trainable_params:,})")
def _build_mlp(self, in_dim, out_dim):
return nn.Sequential(
nn.Linear(in_dim, out_dim * 2),
nn.GELU(),
nn.LayerNorm(out_dim * 2),
nn.Linear(out_dim * 2, out_dim),
).to(self.dtype)
def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
# Log-normalizes numerical features (like age, stats, etc.)
return torch.sign(x) * torch.log1p(torch.abs(x))
def _get_device(self) -> torch.device:
return self.encoder.device
def forward(
self,
profile_rows: List[Dict[str, Any]],
social_rows: List[Dict[str, Any]],
holdings_batch: List[List[Dict[str, Any]]],
token_vibe_lookup: Dict[str, torch.Tensor],
embedding_pool: torch.Tensor,
username_embed_indices: torch.Tensor
) -> torch.Tensor:
device = self._get_device()
profile_embed = self._encode_profile_batch(profile_rows, device)
social_embed = self._encode_social_batch(social_rows, embedding_pool, username_embed_indices, device)
holdings_embed = self._encode_holdings_batch(holdings_batch, token_vibe_lookup, device)
fused = torch.cat([profile_embed, social_embed, holdings_embed], dim=1)
return self.fusion_mlp(fused)
def _encode_profile_batch(self, profile_rows, device):
batch_size = len(profile_rows)
# FIXED: 37 numerical features
num_tensor = torch.zeros(batch_size, self.profile_numerical_features, device=device, dtype=self.dtype)
# bool_tensor removed
# time_tensor removed
for i, row in enumerate(profile_rows):
# A: Numerical (FIXED: 37 features, MUST be present)
num_data = [
# 1. Deployed Token Aggregates (5)
row.get('deployed_tokens_count', 0.0),
row.get('deployed_tokens_migrated_pct', 0.0),
row.get('deployed_tokens_avg_lifetime_sec', 0.0),
row.get('deployed_tokens_avg_peak_mc_usd', 0.0),
row.get('deployed_tokens_median_peak_mc_usd', 0.0),
# 3. Balance (1)
row.get('balance', 0.0),
# 4. Lifetime Transaction Counts (4)
row.get('transfers_in_count', 0.0), row.get('transfers_out_count', 0.0),
row.get('spl_transfers_in_count', 0.0), row.get('spl_transfers_out_count', 0.0),
# 5. Lifetime Trading Stats (3)
row.get('total_buys_count', 0.0), row.get('total_sells_count', 0.0),
row.get('total_winrate', 0.0),
# 6. 1-Day Stats (12)
row.get('stats_1d_realized_profit_sol', 0.0), row.get('stats_1d_realized_profit_pnl', 0.0),
row.get('stats_1d_buy_count', 0.0), row.get('stats_1d_sell_count', 0.0),
row.get('stats_1d_transfer_in_count', 0.0), row.get('stats_1d_transfer_out_count', 0.0),
row.get('stats_1d_avg_holding_period', 0.0), row.get('stats_1d_total_bought_cost_sol', 0.0),
row.get('stats_1d_total_sold_income_sol', 0.0), row.get('stats_1d_total_fee', 0.0),
row.get('stats_1d_winrate', 0.0), row.get('stats_1d_tokens_traded', 0.0),
# 7. 7-Day Stats (12)
row.get('stats_7d_realized_profit_sol', 0.0), row.get('stats_7d_realized_profit_pnl', 0.0),
row.get('stats_7d_buy_count', 0.0), row.get('stats_7d_sell_count', 0.0),
row.get('stats_7d_transfer_in_count', 0.0), row.get('stats_7d_transfer_out_count', 0.0),
row.get('stats_7d_avg_holding_period', 0.0), row.get('stats_7d_total_bought_cost_sol', 0.0),
row.get('stats_7d_total_sold_income_sol', 0.0), row.get('stats_7d_total_fee', 0.0),
row.get('stats_7d_winrate', 0.0), row.get('stats_7d_tokens_traded', 0.0),
]
num_tensor[i] = torch.tensor(num_data, dtype=self.dtype)
# C: Booleans and deployed_tokens lists are GONE
# Log-normalize all numerical features (stats, etc.)
num_embed = self.profile_num_norm(self._safe_signed_log(num_tensor))
# The profile fused tensor is now just the numerical embeddings
profile_fused = num_embed
return self.profile_encoder_mlp(profile_fused)
def _encode_social_batch(self, social_rows, embedding_pool, username_embed_indices, device):
batch_size = len(social_rows)
# FIXED: 4 boolean features
bool_tensor = torch.zeros(batch_size, 4, device=device, dtype=torch.long)
for i, row in enumerate(social_rows):
# All features MUST be present
bool_tensor[i, 0] = 1 if row['has_pf_profile'] else 0
bool_tensor[i, 1] = 1 if row['has_twitter'] else 0
bool_tensor[i, 2] = 1 if row['has_telegram'] else 0
# FIXED: Added is_exchange_wallet
bool_tensor[i, 3] = 1 if row['is_exchange_wallet'] else 0
bool_embeds = self.social_bool_embed(bool_tensor).view(batch_size, -1) # [B, 64]
# --- NEW: Look up pre-computed username embeddings ---
# --- FIXED: Handle case where embedding_pool is empty ---
if embedding_pool.numel() > 0:
# SAFETY: build a padded view so missing indices (-1) map to a zero vector
pad_row = torch.zeros(1, embedding_pool.size(1), device=device, dtype=embedding_pool.dtype)
pool_padded = torch.cat([pad_row, embedding_pool], dim=0)
shifted_idx = torch.where(username_embed_indices >= 0, username_embed_indices + 1, torch.zeros_like(username_embed_indices))
username_embed = F.embedding(shifted_idx, pool_padded)
else:
# If there are no embeddings, create a zero tensor of the correct shape
username_embed = torch.zeros(batch_size, self.mmp_dim, device=device, dtype=self.dtype)
social_fused = torch.cat([bool_embeds, username_embed], dim=1)
return self.social_encoder_mlp(social_fused)
def _encode_holdings_batch(self, holdings_batch, token_vibe_lookup, device):
batch_size = len(holdings_batch)
max_len = max(len(h) for h in holdings_batch) if any(holdings_batch) else 1
seq_embeds = torch.zeros(batch_size, max_len, self.d_model, device=device, dtype=self.dtype)
mask = torch.ones(batch_size, max_len, device=device, dtype=torch.bool)
default_vibe = torch.zeros(self.token_vibe_dim, device=device, dtype=self.dtype)
for i, holdings in enumerate(holdings_batch):
if not holdings: continue
h_len = min(len(holdings), max_len)
holdings = holdings[:h_len]
# --- FIXED: Safely get vibes, using default if mint_address is missing or not in lookup ---
vibes = [token_vibe_lookup.get(row['mint_address'], default_vibe) for row in holdings if 'mint_address' in row]
if not vibes: continue # Skip if no valid holdings with vibes
vibe_tensor = torch.stack(vibes)
# time_tensor removed
num_data_list = []
for row in holdings:
# FIXED: All 12 numerical features MUST be present
num_data = [
# Use .get() with a 0.0 default for safety
row.get('holding_time', 0.0),
row.get('balance_pct_to_supply', 0.0),
row.get('history_bought_cost_sol', 0.0), # Corrected key from schema
row.get('bought_amount_sol_pct_to_native_balance', 0.0), # This key is not in schema, will default to 0
row.get('history_total_buys', 0.0),
row.get('history_total_sells', 0.0),
row.get('realized_profit_pnl', 0.0),
row.get('realized_profit_sol', 0.0),
row.get('history_transfer_in', 0.0),
row.get('history_transfer_out', 0.0),
row.get('avarage_trade_gap_seconds', 0.0),
row.get('total_fees', 0.0) # Corrected key from schema
]
num_data_list.append(num_data)
num_tensor = torch.tensor(num_data_list, device=device, dtype=self.dtype)
# Log-normalize all numerical features (holding_time, stats, etc.)
num_embed = self.holding_num_norm(self._safe_signed_log(num_tensor))
# time_embed removed
# FIXED: Fused tensor no longer has time_embed
fused_rows = torch.cat([vibe_tensor, num_embed], dim=1)
encoded_rows = self.holding_row_encoder_mlp(fused_rows)
seq_embeds[i, :h_len] = encoded_rows
mask[i, :h_len] = False
return self.holdings_set_encoder(seq_embeds, mask)