import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Any, Optional
from PIL import Image
# We assume these helper modules are in the same directory
from models.multi_modal_processor import MultiModalEncoder
from models.wallet_set_encoder import WalletSetEncoder
class WalletEncoder(nn.Module):
    """
    Encodes a wallet's full identity into a single <WalletEmbedding>.

    Three sub-encoders each produce a ``d_model`` vector:
      1. Profile  -- 37 log-normalized numerical wallet statistics.
      2. Social   -- 4 boolean flags + a pre-computed username embedding.
      3. Holdings -- per-token rows (<TokenVibeEmbedding> + 12 stats),
                     pooled by a WalletSetEncoder.
    The three vectors are concatenated and fused by an MLP into the final
    ``d_model`` wallet embedding.
    """

    # 37 numerical profile features, in fixed order (order is load-bearing:
    # it defines the LayerNorm feature positions):
    # 5 deployer stats + 1 balance + 4 lifetime transfer counts +
    # 3 lifetime trading stats + 12 one-day stats + 12 seven-day stats.
    _PROFILE_FEATURE_KEYS = (
        # 1. Deployed Token Aggregates (5)
        'deployed_tokens_count', 'deployed_tokens_migrated_pct',
        'deployed_tokens_avg_lifetime_sec', 'deployed_tokens_avg_peak_mc_usd',
        'deployed_tokens_median_peak_mc_usd',
        # 2. Balance (1)
        'balance',
        # 3. Lifetime Transaction Counts (4)
        'transfers_in_count', 'transfers_out_count',
        'spl_transfers_in_count', 'spl_transfers_out_count',
        # 4. Lifetime Trading Stats (3)
        'total_buys_count', 'total_sells_count', 'total_winrate',
        # 5. 1-Day Stats (12)
        'stats_1d_realized_profit_sol', 'stats_1d_realized_profit_pnl',
        'stats_1d_buy_count', 'stats_1d_sell_count',
        'stats_1d_transfer_in_count', 'stats_1d_transfer_out_count',
        'stats_1d_avg_holding_period', 'stats_1d_total_bought_cost_sol',
        'stats_1d_total_sold_income_sol', 'stats_1d_total_fee',
        'stats_1d_winrate', 'stats_1d_tokens_traded',
        # 6. 7-Day Stats (12)
        'stats_7d_realized_profit_sol', 'stats_7d_realized_profit_pnl',
        'stats_7d_buy_count', 'stats_7d_sell_count',
        'stats_7d_transfer_in_count', 'stats_7d_transfer_out_count',
        'stats_7d_avg_holding_period', 'stats_7d_total_bought_cost_sol',
        'stats_7d_total_sold_income_sol', 'stats_7d_total_fee',
        'stats_7d_winrate', 'stats_7d_tokens_traded',
    )

    # 12 numerical per-holding features, in fixed order.
    _HOLDING_FEATURE_KEYS = (
        'holding_time',
        'balance_pct_to_supply',
        'history_bought_cost_sol',                  # corrected key from schema
        'bought_amount_sol_pct_to_native_balance',  # not in schema; defaults to 0
        'history_total_buys',
        'history_total_sells',
        'realized_profit_pnl',
        'realized_profit_sol',
        'history_transfer_in',
        'history_transfer_out',
        'avarage_trade_gap_seconds',                # (sic) upstream key is spelled this way
        'total_fees',                               # corrected key from schema
    )

    def __init__(
        self,
        encoder: "MultiModalEncoder",
        d_model: int = 2048,
        token_vibe_dim: int = 2048,
        set_encoder_nhead: int = 8,
        set_encoder_nlayers: int = 2,
        dtype: torch.dtype = torch.float16
    ):
        """
        Initializes the WalletEncoder.

        Args:
            encoder (MultiModalEncoder): Instantiated SigLIP encoder. Only its
                ``embedding_dim`` and ``device`` are read here. NOTE(review):
                MultiModalEncoder is a plain wrapper class, NOT an nn.Module,
                so its parameters are neither registered on nor counted by
                this module.
            d_model (int): The final output dimension (default 2048).
            token_vibe_dim (int): Dimension of the pre-computed
                <TokenVibeEmbedding> (default 2048).
            set_encoder_nhead (int): Attention heads for the holdings set encoder.
            set_encoder_nlayers (int): Transformer layers for the holdings set encoder.
            dtype (torch.dtype): Data type for all parameters.
        """
        super().__init__()
        self.d_model = d_model
        self.dtype = dtype
        self.encoder = encoder

        # --- Dimensions ---
        self.token_vibe_dim = token_vibe_dim
        self.mmp_dim = self.encoder.embedding_dim  # e.g. 1152 for SigLIP

        # === 1. Profile Encoder ===
        self.profile_numerical_features = len(self._PROFILE_FEATURE_KEYS)  # 37
        self.profile_num_norm = nn.LayerNorm(self.profile_numerical_features)
        # Input is the numerical features only (no bool / deployed-token embeds).
        self.profile_encoder_mlp = self._build_mlp(self.profile_numerical_features, d_model)

        # === 2. Social Encoder ===
        # 4 booleans: has_pf_profile, has_twitter, has_telegram, is_exchange_wallet
        self.social_bool_embed = nn.Embedding(2, 16)
        social_mlp_in_dim = (16 * 4) + self.mmp_dim  # bool embeds + username embed
        self.social_encoder_mlp = self._build_mlp(social_mlp_in_dim, d_model)

        # === 3. Holdings Encoder ===
        self.holding_numerical_features = len(self._HOLDING_FEATURE_KEYS)  # 12
        self.holding_num_norm = nn.LayerNorm(self.holding_numerical_features)
        # Per-row input: <TokenVibeEmbedding> + 12 numerical features.
        holding_row_in_dim = self.token_vibe_dim + self.holding_numerical_features
        self.holding_row_encoder_mlp = self._build_mlp(holding_row_in_dim, d_model)
        self.holdings_set_encoder = WalletSetEncoder(
            d_model, set_encoder_nhead, set_encoder_nlayers, dtype=dtype
        )

        # === 4. Final Fusion Encoder ===
        # FIXED: fuses the 3 components above (Profile, Social, Holdings);
        # the old comment claimed 4 including a Graph component that the
        # d_model * 3 input dimension shows does not exist.
        self.fusion_mlp = nn.Sequential(
            nn.Linear(d_model * 3, d_model * 2),
            nn.GELU(),
            nn.LayerNorm(d_model * 2),
            nn.Linear(d_model * 2, d_model),
            nn.LayerNorm(d_model)
        )
        self.to(dtype)

        # self.encoder is not an nn.Module, so self.parameters() covers only
        # the MLPs / set encoder defined in this class (no double counting).
        my_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[WalletEncoder] Params: {my_params:,} (Trainable: {trainable_params:,})")

    def _build_mlp(self, in_dim, out_dim):
        """Two-layer GELU MLP (in_dim -> 2*out_dim -> out_dim) in self.dtype."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim * 2),
            nn.GELU(),
            nn.LayerNorm(out_dim * 2),
            nn.Linear(out_dim * 2, out_dim),
        ).to(self.dtype)

    def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
        """Signed log1p: compresses heavy-tailed stats while preserving sign."""
        return torch.sign(x) * torch.log1p(torch.abs(x))

    def _get_device(self) -> torch.device:
        """All batch tensors are built on the shared encoder's device."""
        return self.encoder.device

    def forward(
        self,
        profile_rows: List[Dict[str, Any]],
        social_rows: List[Dict[str, Any]],
        holdings_batch: List[List[Dict[str, Any]]],
        token_vibe_lookup: Dict[str, torch.Tensor],
        embedding_pool: torch.Tensor,
        username_embed_indices: torch.Tensor
    ) -> torch.Tensor:
        """
        Encodes a batch of wallets into [B, d_model] embeddings.

        Args:
            profile_rows: One dict of profile stats per wallet.
            social_rows: One dict of social flags per wallet.
            holdings_batch: Per wallet, a list of holding dicts (may be empty).
            token_vibe_lookup: mint_address -> <TokenVibeEmbedding> tensor.
            embedding_pool: [P, mmp_dim] pre-computed username embeddings
                (may be empty).
            username_embed_indices: [B] indices into the pool; -1 = no embedding.

        Returns:
            torch.Tensor of shape [B, d_model].
        """
        device = self._get_device()
        profile_embed = self._encode_profile_batch(profile_rows, device)
        social_embed = self._encode_social_batch(social_rows, embedding_pool, username_embed_indices, device)
        holdings_embed = self._encode_holdings_batch(holdings_batch, token_vibe_lookup, device)
        fused = torch.cat([profile_embed, social_embed, holdings_embed], dim=1)
        return self.fusion_mlp(fused)

    def _encode_profile_batch(self, profile_rows, device):
        """Encodes the 37 numerical profile features into [B, d_model]."""
        if profile_rows:
            # FIXED: build one nested list and transfer once, instead of one
            # host->device tensor copy per row. Missing keys default to 0.0.
            num_data = [
                [row.get(key, 0.0) for key in self._PROFILE_FEATURE_KEYS]
                for row in profile_rows
            ]
            num_tensor = torch.tensor(num_data, device=device, dtype=self.dtype)
        else:
            # Empty batch: keep the [0, 37] shape so downstream layers still work.
            num_tensor = torch.zeros(
                0, self.profile_numerical_features, device=device, dtype=self.dtype
            )
        # Log-normalize all numerical features, then LayerNorm.
        num_embed = self.profile_num_norm(self._safe_signed_log(num_tensor))
        return self.profile_encoder_mlp(num_embed)

    def _encode_social_batch(self, social_rows, embedding_pool, username_embed_indices, device):
        """Encodes 4 social booleans + username embedding into [B, d_model]."""
        batch_size = len(social_rows)
        bool_tensor = torch.zeros(batch_size, 4, device=device, dtype=torch.long)
        for i, row in enumerate(social_rows):
            # All four flags MUST be present (a KeyError here is intentional).
            bool_tensor[i, 0] = 1 if row['has_pf_profile'] else 0
            bool_tensor[i, 1] = 1 if row['has_twitter'] else 0
            bool_tensor[i, 2] = 1 if row['has_telegram'] else 0
            bool_tensor[i, 3] = 1 if row['is_exchange_wallet'] else 0
        bool_embeds = self.social_bool_embed(bool_tensor).view(batch_size, -1)  # [B, 64]

        if embedding_pool.numel() > 0:
            # Prepend a zero row so missing indices (-1) map to a zero vector.
            pad_row = torch.zeros(1, embedding_pool.size(1), device=device, dtype=embedding_pool.dtype)
            pool_padded = torch.cat([pad_row, embedding_pool], dim=0)
            shifted_idx = torch.where(
                username_embed_indices >= 0,
                username_embed_indices + 1,
                torch.zeros_like(username_embed_indices)
            )
            # FIXED: cast to self.dtype -- torch.cat below requires matching
            # dtypes and the pool may be stored in a different precision.
            username_embed = F.embedding(shifted_idx, pool_padded).to(self.dtype)
        else:
            # No embeddings at all: zero tensor of the correct shape.
            username_embed = torch.zeros(batch_size, self.mmp_dim, device=device, dtype=self.dtype)
        social_fused = torch.cat([bool_embeds, username_embed], dim=1)
        return self.social_encoder_mlp(social_fused)

    def _encode_holdings_batch(self, holdings_batch, token_vibe_lookup, device):
        """Encodes each wallet's holdings and pools them into [B, d_model]."""
        batch_size = len(holdings_batch)
        max_len = max(len(h) for h in holdings_batch) if any(holdings_batch) else 1
        seq_embeds = torch.zeros(batch_size, max_len, self.d_model, device=device, dtype=self.dtype)
        # True marks padding positions; valid rows are flipped to False below.
        mask = torch.ones(batch_size, max_len, device=device, dtype=torch.bool)
        default_vibe = torch.zeros(self.token_vibe_dim, device=device, dtype=self.dtype)
        for i, holdings in enumerate(holdings_batch):
            if not holdings:
                continue
            # FIXED: drop rows missing 'mint_address' BEFORE building BOTH
            # tensors. The original filtered only the vibe list, so one bad
            # row desynchronized vibe_tensor / num_tensor row counts (crashing
            # torch.cat) and set the mask from the unfiltered length.
            rows = [row for row in holdings[:max_len] if 'mint_address' in row]
            if not rows:
                continue  # no valid holdings with vibes for this wallet
            h_len = len(rows)
            # FIXED: move looked-up vibes onto this module's device/dtype so
            # torch.stack/cat cannot fail on a CPU- or fp32-stored lookup.
            vibe_tensor = torch.stack([
                token_vibe_lookup.get(row['mint_address'], default_vibe).to(device=device, dtype=self.dtype)
                for row in rows
            ])
            num_data_list = [
                [row.get(key, 0.0) for key in self._HOLDING_FEATURE_KEYS]
                for row in rows
            ]
            num_tensor = torch.tensor(num_data_list, device=device, dtype=self.dtype)
            # Log-normalize all numerical features (holding_time, stats, etc.)
            num_embed = self.holding_num_norm(self._safe_signed_log(num_tensor))
            fused_rows = torch.cat([vibe_tensor, num_embed], dim=1)
            encoded_rows = self.holding_row_encoder_mlp(fused_rows)
            seq_embeds[i, :h_len] = encoded_rows
            mask[i, :h_len] = False
        return self.holdings_set_encoder(seq_embeds, mask)