import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Any, Optional
from PIL import Image

# We assume these helper modules are in the same directory
from models.multi_modal_processor import MultiModalEncoder
from models.wallet_set_encoder import WalletSetEncoder


class WalletEncoder(nn.Module):
    """
    Encodes a wallet's full identity into a single d_model-sized embedding.

    Three sub-encoders (profile numericals, social signals, token holdings)
    each produce a [B, d_model] vector; a fusion MLP projects their
    concatenation back down to d_model.

    UPDATED: Aligned with the final feature spec — the time encoder, profile
    boolean embeds, and per-deployed-token embeds have been removed.
    """

    def __init__(
        self,
        encoder: MultiModalEncoder,
        d_model: int = 2048,
        token_vibe_dim: int = 2048,
        set_encoder_nhead: int = 8,
        set_encoder_nlayers: int = 2,
        dtype: torch.dtype = torch.float16,
    ):
        """
        Initializes the WalletEncoder.

        Args:
            encoder (MultiModalEncoder): Instantiated SigLIP encoder. Only its
                `embedding_dim` and `device` are used here. NOTE: it is not an
                nn.Module, so assigning it to `self.encoder` does NOT register
                its weights as parameters of this module.
            d_model (int): The final output dimension (default 2048).
            token_vibe_dim (int): Dimension of the pre-computed token vibe
                embeddings (expected to match d_model).
            set_encoder_nhead (int): Attention heads for the holdings set encoder.
            set_encoder_nlayers (int): Transformer layers for the holdings set encoder.
            dtype (torch.dtype): Data type for all parameters/activations.
        """
        super().__init__()
        self.d_model = d_model
        self.dtype = dtype
        self.encoder = encoder

        # --- Dimensions ---
        self.token_vibe_dim = token_vibe_dim
        # Username-embedding width, taken from the SigLIP encoder (1152 per spec).
        self.mmp_dim = self.encoder.embedding_dim

        # === 1. Profile Encoder ===
        # 5 deployer_stats + 1 balance + 4 lifetime_counts +
        # 3 lifetime_trading + 12 1d_stats + 12 7d_stats = 37
        self.profile_numerical_features = 37
        self.profile_num_norm = nn.LayerNorm(self.profile_numerical_features)
        # Input is purely numerical now: no bool embed / deployed-tokens embed.
        self.profile_encoder_mlp = self._build_mlp(
            self.profile_numerical_features, d_model
        )

        # === 2. Social Encoder ===
        # 4 booleans: has_pf, has_twitter, has_telegram, is_exchange_wallet
        self.social_bool_embed = nn.Embedding(2, 16)
        # (4 booleans x 16-dim embed) + pre-computed username embedding.
        social_mlp_in_dim = (16 * 4) + self.mmp_dim
        self.social_encoder_mlp = self._build_mlp(social_mlp_in_dim, d_model)

        # === 3. Holdings Encoder ===
        # 11 original stats + 1 holding_time = 12
        self.holding_numerical_features = 12
        self.holding_num_norm = nn.LayerNorm(self.holding_numerical_features)
        # Per-row input: token vibe + 12 normalized numericals (time encoder removed).
        holding_row_in_dim = self.token_vibe_dim + self.holding_numerical_features
        self.holding_row_encoder_mlp = self._build_mlp(holding_row_in_dim, d_model)
        self.holdings_set_encoder = WalletSetEncoder(
            d_model, set_encoder_nhead, set_encoder_nlayers, dtype=dtype
        )

        # === 4. Final Fusion Encoder ===
        # FIXED comment: fuses the 3 components actually concatenated in
        # forward() — Profile, Social, Holdings — hence d_model * 3 input.
        self.fusion_mlp = nn.Sequential(
            nn.Linear(d_model * 3, d_model * 2),
            nn.GELU(),
            nn.LayerNorm(d_model * 2),
            nn.Linear(d_model * 2, d_model),
            nn.LayerNorm(d_model),
        )

        self.to(dtype)

        # `MultiModalEncoder` is a plain wrapper class (not an nn.Module), so
        # `self.encoder = encoder` is an ordinary attribute assignment and its
        # weights are NOT included in self.parameters(). The counts below
        # therefore cover only this module's own MLPs and set encoder.
        my_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )
        print(f"[WalletEncoder] Params: {my_params:,} (Trainable: {trainable_params:,})")

    def _build_mlp(self, in_dim: int, out_dim: int) -> nn.Sequential:
        """Two-layer GELU MLP (in -> 2*out -> out), cast to self.dtype."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim * 2),
            nn.GELU(),
            nn.LayerNorm(out_dim * 2),
            nn.Linear(out_dim * 2, out_dim),
        ).to(self.dtype)

    def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
        """Sign-preserving log1p compression for heavy-tailed numerical features."""
        return torch.sign(x) * torch.log1p(torch.abs(x))

    def _get_device(self) -> torch.device:
        # All tensors are created on the shared encoder's device.
        return self.encoder.device

    def forward(
        self,
        profile_rows: List[Dict[str, Any]],
        social_rows: List[Dict[str, Any]],
        holdings_batch: List[List[Dict[str, Any]]],
        token_vibe_lookup: Dict[str, torch.Tensor],
        embedding_pool: torch.Tensor,
        username_embed_indices: torch.Tensor,
    ) -> torch.Tensor:
        """
        Encodes a batch of wallets.

        Args:
            profile_rows: One feature dict per wallet (37 numerical fields).
            social_rows: One dict per wallet with the 4 social booleans.
            holdings_batch: Per-wallet list of holding-row dicts (ragged).
            token_vibe_lookup: mint_address -> pre-computed vibe tensor.
            embedding_pool: [P, mmp_dim] pre-computed username embeddings
                (may be empty).
            username_embed_indices: [B] row indices into embedding_pool;
                -1 marks "no username embedding".

        Returns:
            [B, d_model] fused wallet embeddings.
        """
        device = self._get_device()

        profile_embed = self._encode_profile_batch(profile_rows, device)
        social_embed = self._encode_social_batch(
            social_rows, embedding_pool, username_embed_indices, device
        )
        holdings_embed = self._encode_holdings_batch(
            holdings_batch, token_vibe_lookup, device
        )

        fused = torch.cat([profile_embed, social_embed, holdings_embed], dim=1)
        return self.fusion_mlp(fused)

    def _encode_profile_batch(
        self, profile_rows: List[Dict[str, Any]], device: torch.device
    ) -> torch.Tensor:
        """Encodes the 37 numerical profile features into [B, d_model].

        Missing keys default to 0.0. Booleans and deployed-token lists are gone
        per the final spec.
        """
        # Build all rows first, then materialize one tensor (avoids a per-row
        # CPU tensor + device copy inside the loop).
        num_rows = []
        for row in profile_rows:
            num_rows.append([
                # 1. Deployed Token Aggregates (5)
                row.get('deployed_tokens_count', 0.0),
                row.get('deployed_tokens_migrated_pct', 0.0),
                row.get('deployed_tokens_avg_lifetime_sec', 0.0),
                row.get('deployed_tokens_avg_peak_mc_usd', 0.0),
                row.get('deployed_tokens_median_peak_mc_usd', 0.0),
                # 2. Balance (1)
                row.get('balance', 0.0),
                # 3. Lifetime Transaction Counts (4)
                row.get('transfers_in_count', 0.0),
                row.get('transfers_out_count', 0.0),
                row.get('spl_transfers_in_count', 0.0),
                row.get('spl_transfers_out_count', 0.0),
                # 4. Lifetime Trading Stats (3)
                row.get('total_buys_count', 0.0),
                row.get('total_sells_count', 0.0),
                row.get('total_winrate', 0.0),
                # 5. 1-Day Stats (12)
                row.get('stats_1d_realized_profit_sol', 0.0),
                row.get('stats_1d_realized_profit_pnl', 0.0),
                row.get('stats_1d_buy_count', 0.0),
                row.get('stats_1d_sell_count', 0.0),
                row.get('stats_1d_transfer_in_count', 0.0),
                row.get('stats_1d_transfer_out_count', 0.0),
                row.get('stats_1d_avg_holding_period', 0.0),
                row.get('stats_1d_total_bought_cost_sol', 0.0),
                row.get('stats_1d_total_sold_income_sol', 0.0),
                row.get('stats_1d_total_fee', 0.0),
                row.get('stats_1d_winrate', 0.0),
                row.get('stats_1d_tokens_traded', 0.0),
                # 6. 7-Day Stats (12)
                row.get('stats_7d_realized_profit_sol', 0.0),
                row.get('stats_7d_realized_profit_pnl', 0.0),
                row.get('stats_7d_buy_count', 0.0),
                row.get('stats_7d_sell_count', 0.0),
                row.get('stats_7d_transfer_in_count', 0.0),
                row.get('stats_7d_transfer_out_count', 0.0),
                row.get('stats_7d_avg_holding_period', 0.0),
                row.get('stats_7d_total_bought_cost_sol', 0.0),
                row.get('stats_7d_total_sold_income_sol', 0.0),
                row.get('stats_7d_total_fee', 0.0),
                row.get('stats_7d_winrate', 0.0),
                row.get('stats_7d_tokens_traded', 0.0),
            ])

        # reshape(-1, 37) keeps the empty-batch case well-formed as [0, 37].
        num_tensor = torch.tensor(
            num_rows, device=device, dtype=self.dtype
        ).reshape(-1, self.profile_numerical_features)

        # Log-normalize all numerical features (counts, stats, etc.).
        num_embed = self.profile_num_norm(self._safe_signed_log(num_tensor))
        return self.profile_encoder_mlp(num_embed)

    def _encode_social_batch(
        self,
        social_rows: List[Dict[str, Any]],
        embedding_pool: torch.Tensor,
        username_embed_indices: torch.Tensor,
        device: torch.device,
    ) -> torch.Tensor:
        """Encodes social booleans + username embedding into [B, d_model].

        The 4 boolean keys MUST be present in every row (KeyError otherwise,
        by design). Indices of -1 map to a zero username embedding.
        """
        batch_size = len(social_rows)
        bool_tensor = torch.zeros(batch_size, 4, device=device, dtype=torch.long)

        for i, row in enumerate(social_rows):
            bool_tensor[i, 0] = 1 if row['has_pf_profile'] else 0
            bool_tensor[i, 1] = 1 if row['has_twitter'] else 0
            bool_tensor[i, 2] = 1 if row['has_telegram'] else 0
            bool_tensor[i, 3] = 1 if row['is_exchange_wallet'] else 0

        # [B, 4, 16] -> [B, 64]; explicit size (not -1) so batch_size == 0 works.
        bool_embeds = self.social_bool_embed(bool_tensor).reshape(batch_size, 4 * 16)

        if embedding_pool.numel() > 0:
            # SAFETY: prepend a zero row so missing indices (-1) select a zero
            # vector instead of indexing out of range.
            pad_row = torch.zeros(
                1, embedding_pool.size(1), device=device, dtype=embedding_pool.dtype
            )
            pool_padded = torch.cat([pad_row, embedding_pool], dim=0)
            shifted_idx = torch.where(
                username_embed_indices >= 0,
                username_embed_indices + 1,
                torch.zeros_like(username_embed_indices),
            )
            # Cast defensively in case the pool dtype differs from self.dtype.
            username_embed = F.embedding(shifted_idx, pool_padded).to(self.dtype)
        else:
            # No embeddings at all: zero tensor of the expected shape.
            username_embed = torch.zeros(
                batch_size, self.mmp_dim, device=device, dtype=self.dtype
            )

        social_fused = torch.cat([bool_embeds, username_embed], dim=1)
        return self.social_encoder_mlp(social_fused)

    def _encode_holdings_batch(
        self,
        holdings_batch: List[List[Dict[str, Any]]],
        token_vibe_lookup: Dict[str, torch.Tensor],
        device: torch.device,
    ) -> torch.Tensor:
        """Encodes each wallet's ragged holdings list via the set encoder.

        Rows without a 'mint_address' key are dropped; rows whose mint is not
        in the lookup get a zero "default vibe". Wallets with no usable rows
        stay fully masked (the set encoder is assumed to tolerate an all-True
        mask row — TODO confirm).
        """
        batch_size = len(holdings_batch)
        max_len = max(len(h) for h in holdings_batch) if any(holdings_batch) else 1

        seq_embeds = torch.zeros(
            batch_size, max_len, self.d_model, device=device, dtype=self.dtype
        )
        mask = torch.ones(batch_size, max_len, device=device, dtype=torch.bool)
        default_vibe = torch.zeros(
            self.token_vibe_dim, device=device, dtype=self.dtype
        )

        for i, holdings in enumerate(holdings_batch):
            # FIXED: filter ONCE, so vibes and numericals stay row-aligned.
            # Previously the vibe list dropped rows missing 'mint_address'
            # while the numerical list kept them, crashing torch.cat below
            # whenever the key was absent.
            rows = [r for r in holdings[:max_len] if 'mint_address' in r]
            if not rows:
                continue  # nothing usable — leave this wallet fully masked
            h_len = len(rows)

            vibe_tensor = torch.stack([
                token_vibe_lookup.get(r['mint_address'], default_vibe)
                for r in rows
            ])

            # All 12 numerical features; .get() with 0.0 default for safety.
            num_data_list = [
                [
                    r.get('holding_time', 0.0),
                    r.get('balance_pct_to_supply', 0.0),
                    r.get('history_bought_cost_sol', 0.0),  # Corrected key from schema
                    r.get('bought_amount_sol_pct_to_native_balance', 0.0),  # Not in schema; defaults to 0
                    r.get('history_total_buys', 0.0),
                    r.get('history_total_sells', 0.0),
                    r.get('realized_profit_pnl', 0.0),
                    r.get('realized_profit_sol', 0.0),
                    r.get('history_transfer_in', 0.0),
                    r.get('history_transfer_out', 0.0),
                    r.get('avarage_trade_gap_seconds', 0.0),  # NOTE: typo matches upstream schema key
                    r.get('total_fees', 0.0),  # Corrected key from schema
                ]
                for r in rows
            ]
            num_tensor = torch.tensor(
                num_data_list, device=device, dtype=self.dtype
            )

            # Log-normalize (holding_time, stats, etc.); time encoder removed.
            num_embed = self.holding_num_norm(self._safe_signed_log(num_tensor))

            fused_rows = torch.cat([vibe_tensor, num_embed], dim=1)
            encoded_rows = self.holding_row_encoder_mlp(fused_rows)

            seq_embeds[i, :h_len] = encoded_rows
            mask[i, :h_len] = False

        return self.holdings_set_encoder(seq_embeds, mask)