import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Any

from models.multi_modal_processor import MultiModalEncoder
from models.wallet_set_encoder import WalletSetEncoder


class WalletEncoder(nn.Module):
    """
    Encodes a wallet's full identity into a single <WalletEmbedding>.

    Three views are encoded independently and then fused: profile
    statistics, social metadata, and the wallet's token holdings.
    Aligned with the final feature spec.
    """

    def __init__(
        self,
        encoder: MultiModalEncoder,
        d_model: int = 2048,
        token_vibe_dim: int = 2048,
        set_encoder_nhead: int = 8,
        set_encoder_nlayers: int = 2,
        dtype: torch.dtype = torch.float16,
    ):
| """ |
| Initializes the WalletEncoder. |
| |
| Args: |
| d_model (int): The final output dimension (e.g., 4096). |
| token_vibe_dim (int): The dimension of the pre-computed |
| <TokenVibeEmbedding> (e.g., 1024). |
| encoder (MultiModalEncoder): Instantiated SigLIP encoder. |
| time_encoder (ContextualTimeEncoder): Instantiated time encoder. |
| set_encoder_nhead (int): Attention heads for set encoders. |
| set_encoder_nlayers (int): Transformer layers for set encoders. |
| dtype (torch.dtype): Data type. |
| """ |
        super().__init__()
        self.d_model = d_model
        self.dtype = dtype
        self.encoder = encoder

        self.token_vibe_dim = token_vibe_dim
        self.mmp_dim = self.encoder.embedding_dim

        # --- Profile branch ---
        # 37 numerical features: 5 deployer stats, 1 balance, 4 transfer
        # counts, 3 lifetime trading stats, and 12 stats for each of the
        # 1-day and 7-day windows (see _encode_profile_batch).
        self.profile_numerical_features = 37
        self.profile_num_norm = nn.LayerNorm(self.profile_numerical_features)

        profile_mlp_in_dim = self.profile_numerical_features
        self.profile_encoder_mlp = self._build_mlp(profile_mlp_in_dim, d_model)

        # --- Social branch ---
        # Four boolean flags, each mapped to a learned 16-dim embedding,
        # concatenated with the username embedding from the multi-modal
        # encoder.
        self.social_bool_embed = nn.Embedding(2, 16)

        social_mlp_in_dim = (16 * 4) + self.mmp_dim
        self.social_encoder_mlp = self._build_mlp(social_mlp_in_dim, d_model)

        # --- Holdings branch ---
        # Each holding row = pre-computed <TokenVibeEmbedding> + 12
        # numerical trading features, encoded per row and then pooled by a
        # permutation-invariant set encoder.
        self.holding_numerical_features = 12
        self.holding_num_norm = nn.LayerNorm(self.holding_numerical_features)

        holding_row_in_dim = (
            self.token_vibe_dim +
            self.holding_numerical_features
        )
        self.holding_row_encoder_mlp = self._build_mlp(holding_row_in_dim, d_model)

        self.holdings_set_encoder = WalletSetEncoder(
            d_model, set_encoder_nhead, set_encoder_nlayers, dtype=dtype
        )

        # --- Fusion ---
        # Concatenate the three d_model views and project back down, e.g.
        # with d_model=2048: (B, 6144) -> (B, 4096) -> (B, 2048).
        self.fusion_mlp = nn.Sequential(
            nn.Linear(d_model * 3, d_model * 2),
            nn.GELU(),
            nn.LayerNorm(d_model * 2),
            nn.Linear(d_model * 2, d_model),
            nn.LayerNorm(d_model)
        )
        self.to(dtype)

        my_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[WalletEncoder] Params: {my_params:,} (Trainable: {trainable_params:,})")

    def _build_mlp(self, in_dim, out_dim):
        # Two-layer GELU MLP with a 2x-wide hidden layer.
        return nn.Sequential(
            nn.Linear(in_dim, out_dim * 2),
            nn.GELU(),
            nn.LayerNorm(out_dim * 2),
            nn.Linear(out_dim * 2, out_dim),
        ).to(self.dtype)

    def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
        # Sign-preserving log compression: sign(x) * log(1 + |x|). Tames
        # heavy-tailed features (balances, PnL) while fixing zero at zero
        # and staying monotonic for negative values.
        return torch.sign(x) * torch.log1p(torch.abs(x))
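
    # Illustrative sanity check for _safe_signed_log (values computed by
    # hand, not part of the module): log1p(100) ~= 4.615, so
    #   _safe_signed_log(torch.tensor([100., -100., 0.]))
    # returns approximately tensor([4.615, -4.615, 0.]), i.e. the transform
    # is odd-symmetric and leaves zero unchanged.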

    def _get_device(self) -> torch.device:
        # Assumes the wrapped encoder exposes the device it lives on.
        return self.encoder.device

    def forward(
        self,
        profile_rows: List[Dict[str, Any]],
        social_rows: List[Dict[str, Any]],
        holdings_batch: List[List[Dict[str, Any]]],
        token_vibe_lookup: Dict[str, torch.Tensor],
        embedding_pool: torch.Tensor,
        username_embed_indices: torch.Tensor
    ) -> torch.Tensor:
        """
        Returns a (batch_size, d_model) wallet embedding.

        `token_vibe_lookup` maps mint addresses to pre-computed
        <TokenVibeEmbedding> tensors. `embedding_pool` holds username
        embeddings, indexed per wallet by `username_embed_indices`
        (-1 marks a missing username).
        """
        device = self._get_device()

        profile_embed = self._encode_profile_batch(profile_rows, device)
        social_embed = self._encode_social_batch(social_rows, embedding_pool, username_embed_indices, device)
        holdings_embed = self._encode_holdings_batch(holdings_batch, token_vibe_lookup, device)

        fused = torch.cat([profile_embed, social_embed, holdings_embed], dim=1)
        return self.fusion_mlp(fused)
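
    # Shape walk-through for forward (illustrative, assuming the default
    # d_model=2048 and batch size B): each branch yields (B, 2048); the
    # concat is (B, 6144), and fusion_mlp projects it back to (B, 2048).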

    def _encode_profile_batch(self, profile_rows, device):
        batch_size = len(profile_rows)

        num_tensor = torch.zeros(batch_size, self.profile_numerical_features, device=device, dtype=self.dtype)
        for i, row in enumerate(profile_rows):
            num_data = [
                # Deployer stats (5)
                row.get('deployed_tokens_count', 0.0),
                row.get('deployed_tokens_migrated_pct', 0.0),
                row.get('deployed_tokens_avg_lifetime_sec', 0.0),
                row.get('deployed_tokens_avg_peak_mc_usd', 0.0),
                row.get('deployed_tokens_median_peak_mc_usd', 0.0),
                # Native balance (1)
                row.get('balance', 0.0),
                # Transfer counts (4)
                row.get('transfers_in_count', 0.0), row.get('transfers_out_count', 0.0),
                row.get('spl_transfers_in_count', 0.0), row.get('spl_transfers_out_count', 0.0),
                # Lifetime trading stats (3)
                row.get('total_buys_count', 0.0), row.get('total_sells_count', 0.0),
                row.get('total_winrate', 0.0),
                # 1-day window stats (12)
                row.get('stats_1d_realized_profit_sol', 0.0), row.get('stats_1d_realized_profit_pnl', 0.0),
                row.get('stats_1d_buy_count', 0.0), row.get('stats_1d_sell_count', 0.0),
                row.get('stats_1d_transfer_in_count', 0.0), row.get('stats_1d_transfer_out_count', 0.0),
                row.get('stats_1d_avg_holding_period', 0.0), row.get('stats_1d_total_bought_cost_sol', 0.0),
                row.get('stats_1d_total_sold_income_sol', 0.0), row.get('stats_1d_total_fee', 0.0),
                row.get('stats_1d_winrate', 0.0), row.get('stats_1d_tokens_traded', 0.0),
                # 7-day window stats (12)
                row.get('stats_7d_realized_profit_sol', 0.0), row.get('stats_7d_realized_profit_pnl', 0.0),
                row.get('stats_7d_buy_count', 0.0), row.get('stats_7d_sell_count', 0.0),
                row.get('stats_7d_transfer_in_count', 0.0), row.get('stats_7d_transfer_out_count', 0.0),
                row.get('stats_7d_avg_holding_period', 0.0), row.get('stats_7d_total_bought_cost_sol', 0.0),
                row.get('stats_7d_total_sold_income_sol', 0.0), row.get('stats_7d_total_fee', 0.0),
                row.get('stats_7d_winrate', 0.0), row.get('stats_7d_tokens_traded', 0.0),
            ]
            num_tensor[i] = torch.tensor(num_data, device=device, dtype=self.dtype)

        # Log-compress heavy-tailed magnitudes, then normalize.
        num_embed = self.profile_num_norm(self._safe_signed_log(num_tensor))
        return self.profile_encoder_mlp(num_embed)

    def _encode_social_batch(self, social_rows, embedding_pool, username_embed_indices, device):
        batch_size = len(social_rows)

        bool_tensor = torch.zeros(batch_size, 4, device=device, dtype=torch.long)
        for i, row in enumerate(social_rows):
            # .get(..., False) keeps a missing flag from raising KeyError.
            bool_tensor[i, 0] = 1 if row.get('has_pf_profile', False) else 0
            bool_tensor[i, 1] = 1 if row.get('has_twitter', False) else 0
            bool_tensor[i, 2] = 1 if row.get('has_telegram', False) else 0
            bool_tensor[i, 3] = 1 if row.get('is_exchange_wallet', False) else 0

        # (batch, 4, 16) -> (batch, 64)
        bool_embeds = self.social_bool_embed(bool_tensor).view(batch_size, -1)

        # Assumes embedding_pool and username_embed_indices already live on
        # `device`.
        if embedding_pool.numel() > 0:
            # Prepend a zero "pad" row so wallets without a username
            # (index -1) gather an all-zero embedding: shifting valid
            # indices by +1 maps -1 to the pad row at index 0.
            pad_row = torch.zeros(1, embedding_pool.size(1), device=device, dtype=embedding_pool.dtype)
            pool_padded = torch.cat([pad_row, embedding_pool], dim=0)
            shifted_idx = torch.where(username_embed_indices >= 0, username_embed_indices + 1, torch.zeros_like(username_embed_indices))
            # Cast to the module dtype so the concat below cannot fail on a
            # pool stored in a different precision.
            username_embed = F.embedding(shifted_idx, pool_padded).to(dtype=self.dtype)
        else:
            username_embed = torch.zeros(batch_size, self.mmp_dim, device=device, dtype=self.dtype)

        social_fused = torch.cat([bool_embeds, username_embed], dim=1)
        return self.social_encoder_mlp(social_fused)
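
    # Index-shift example for _encode_social_batch above (illustrative
    # values): with username_embed_indices = tensor([2, -1, 0]),
    # shifted_idx = tensor([3, 0, 1]); row 0 of pool_padded is the zero
    # pad row, so the wallet whose index is -1 receives an all-zero
    # username embedding.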

    def _encode_holdings_batch(self, holdings_batch, token_vibe_lookup, device):
        batch_size = len(holdings_batch)
        max_len = max(len(h) for h in holdings_batch) if any(holdings_batch) else 1
        seq_embeds = torch.zeros(batch_size, max_len, self.d_model, device=device, dtype=self.dtype)
        # True marks padded positions; real rows are set to False below.
        mask = torch.ones(batch_size, max_len, device=device, dtype=torch.bool)
        # Fallback vibe for mints missing from the lookup.
        default_vibe = torch.zeros(self.token_vibe_dim, device=device, dtype=self.dtype)

        for i, holdings in enumerate(holdings_batch):
            if not holdings:
                continue
            h_len = len(holdings)

            # One vibe per row, falling back to default_vibe, so the row
            # count always matches the numerical features below.
            vibes = [token_vibe_lookup.get(row.get('mint_address'), default_vibe) for row in holdings]
            vibe_tensor = torch.stack(vibes).to(device=device, dtype=self.dtype)

            num_data_list = []
            for row in holdings:
                num_data = [
                    row.get('holding_time', 0.0),
                    row.get('balance_pct_to_supply', 0.0),
                    row.get('history_bought_cost_sol', 0.0),
                    row.get('bought_amount_sol_pct_to_native_balance', 0.0),
                    row.get('history_total_buys', 0.0),
                    row.get('history_total_sells', 0.0),
                    row.get('realized_profit_pnl', 0.0),
                    row.get('realized_profit_sol', 0.0),
                    row.get('history_transfer_in', 0.0),
                    row.get('history_transfer_out', 0.0),
                    row.get('avarage_trade_gap_seconds', 0.0),  # key spelled as in the upstream schema
                    row.get('total_fees', 0.0)
                ]
                num_data_list.append(num_data)

            num_tensor = torch.tensor(num_data_list, device=device, dtype=self.dtype)
            num_embed = self.holding_num_norm(self._safe_signed_log(num_tensor))

            # (h_len, token_vibe_dim + 12) -> (h_len, d_model)
            fused_rows = torch.cat([vibe_tensor, num_embed], dim=1)
            encoded_rows = self.holding_row_encoder_mlp(fused_rows)
            seq_embeds[i, :h_len] = encoded_rows
            mask[i, :h_len] = False

        return self.holdings_set_encoder(seq_embeds, mask)
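

# Minimal smoke-test sketch (illustrative only). `DummyEncoder` is a
# hypothetical stand-in for MultiModalEncoder: it stubs just the
# `embedding_dim` attribute and `device` property that WalletEncoder uses.
# All feature keys and sizes below are assumptions chosen to exercise each
# branch; float32 is used so the sketch stays CPU-friendly.
if __name__ == "__main__":
    class DummyEncoder(nn.Module):
        def __init__(self, embedding_dim: int = 64):
            super().__init__()
            self.embedding_dim = embedding_dim
            self._anchor = nn.Parameter(torch.zeros(1))  # tracks the module's device

        @property
        def device(self) -> torch.device:
            return self._anchor.device

    enc = WalletEncoder(DummyEncoder(), d_model=128, token_vibe_dim=32, dtype=torch.float32)
    out = enc(
        profile_rows=[{'balance': 12.5, 'total_winrate': 0.6}],
        social_rows=[{'has_pf_profile': True}],
        holdings_batch=[[{'mint_address': 'SomeMint', 'holding_time': 3600.0}]],
        token_vibe_lookup={},               # unknown mints fall back to the zero vibe
        embedding_pool=torch.empty(0, 64),  # empty pool -> zero username embeddings
        username_embed_indices=torch.tensor([-1]),
    )
    print(out.shape)  # expected: torch.Size([1, 128])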