# model.py (REFACTORED AND FIXED) import torch import torch.nn as nn import torch.nn.functional as F from transformers import AutoModel, LlamaConfig from typing import List, Dict, Any, Optional, Tuple import os import json # --- NOW, we import all the encoders --- from models.helper_encoders import ContextualTimeEncoder from models.token_encoder import TokenEncoder from models.wallet_encoder import WalletEncoder from models.graph_updater import GraphUpdater from models.ohlc_embedder import OHLCEmbedder from models.quant_ohlc_embedder import QuantOHLCEmbedder from models.HoldersEncoder import HolderDistributionEncoder # NEW from models.SocialEncoders import SocialEncoder # NEW import models.vocabulary as vocab # For vocab sizes from data.context_targets import MOVEMENT_CLASS_NAMES class Oracle(nn.Module): """ """ def __init__(self, token_encoder: TokenEncoder, wallet_encoder: WalletEncoder, graph_updater: GraphUpdater, ohlc_embedder: OHLCEmbedder, # NEW quant_ohlc_embedder: QuantOHLCEmbedder, time_encoder: ContextualTimeEncoder, num_event_types: int, multi_modal_dim: int, event_pad_id: int, event_type_to_id: Dict[str, int], model_config_name: str = "llama3-12l-768d-gqa4-8k-random", quantiles: List[float] = [0.1, 0.5, 0.9], horizons_seconds: List[int] = [30, 60, 120, 240, 420], dtype: torch.dtype = torch.bfloat16): super().__init__() device = "cuda" if torch.cuda.is_available() else "cpu" self.device = torch.device(device) self.dtype = dtype self.multi_modal_dim = multi_modal_dim self.num_event_types = num_event_types self.event_pad_id = event_pad_id self.model_config_name = model_config_name self.quantiles = quantiles self.horizons_seconds = horizons_seconds self.num_outputs = len(quantiles) * len(horizons_seconds) self.num_movement_classes = len(MOVEMENT_CLASS_NAMES) self.dtype = dtype # --- 2. Backbone: Llama-style decoder, RANDOM INIT (no pretrained weights) --- # This gives you RoPE + modern decoder blocks and lets HF use optimized attention # implementations (SDPA / FlashAttention) without us implementing a transformer. # # Size target: ~80-120M params, suitable for 8k-ish seq caps with your data regime. attn_impl = os.getenv("HF_ATTN_IMPL", "sdpa") # "sdpa" (safe) or "flash_attention_2" (if installed) llama_cfg = LlamaConfig( # Model size hidden_size=768, intermediate_size=3072, num_hidden_layers=12, num_attention_heads=12, # GQA-style KV heads (Llama 3-style efficiency knob) num_key_value_heads=4, # Long context (must be >= your effective max sequence length) max_position_embeddings=8192, # Llama 3 uses a large theta; harmless for random init and helps longer contexts. rope_theta=500000.0, rms_norm_eps=1e-5, # Unused when providing inputs_embeds, but required by config vocab_size=32000, ) self.d_model = llama_cfg.hidden_size # Older transformers versions may not support attn_implementation in from_config. # Also, flash_attention_2 requires optional deps; fall back to SDPA if unavailable. try: self.model = AutoModel.from_config(llama_cfg, attn_implementation=attn_impl) except TypeError: self.model = AutoModel.from_config(llama_cfg) except Exception: if attn_impl != "sdpa": self.model = AutoModel.from_config(llama_cfg, attn_implementation="sdpa") else: raise # Disable KV cache during training (saves memory; not used for full-seq training). if hasattr(self.model, "config"): self.model.config.use_cache = False self.model.to(self.device, dtype=self.dtype) # Quantile prediction head (maps pooled hidden state -> flattened horizon/quantile grid) self.quantile_head = nn.Sequential( nn.Linear(self.d_model, self.d_model), nn.GELU(), nn.Linear(self.d_model, self.num_outputs) ) self.quality_head = nn.Sequential( nn.Linear(self.d_model, self.d_model), nn.GELU(), nn.Linear(self.d_model, 1) ) self.movement_head = nn.Sequential( nn.Linear(self.d_model, self.d_model), nn.GELU(), nn.Linear(self.d_model, len(self.horizons_seconds) * self.num_movement_classes) ) self.event_type_to_id = event_type_to_id # --- 1. Store All Encoders --- # Define Token Roles before using them self.token_roles = {'main': 0, 'quote': 1, 'trending': 2} # Add trending for future use self.main_token_role_id = self.token_roles['main'] self.quote_token_role_id = self.token_roles['quote'] self.trending_token_role_id = self.token_roles['trending'] self.token_encoder = token_encoder self.wallet_encoder = wallet_encoder self.graph_updater = graph_updater self.ohlc_embedder = ohlc_embedder self.quant_ohlc_embedder = quant_ohlc_embedder self.time_encoder = time_encoder # Store time_encoder self.social_encoder = SocialEncoder(d_model=self.d_model, dtype=self.dtype) # Now self.d_model is defined # --- 4. Define Sequence Feature Embeddings --- self.event_type_embedding = nn.Embedding(num_event_types, self.d_model, padding_idx=event_pad_id) # --- NEW: Token Role Embeddings --- self.token_role_embedding = nn.Embedding(len(self.token_roles), self.d_model) # --- 5. Define Entity Padding (Learnable) --- self.pad_wallet_emb = nn.Parameter(torch.zeros(1, self.wallet_encoder.d_model)) self.pad_token_emb = nn.Parameter(torch.zeros(1, self.token_encoder.output_dim)) self.pad_ohlc_emb = nn.Parameter(torch.zeros(1, self.quant_ohlc_embedder.output_dim)) self.pad_precomputed_emb = nn.Parameter(torch.zeros(1, self.multi_modal_dim)) # NEW: For text/images # --- NEW: Instantiate HolderDistributionEncoder internally --- self.holder_dist_encoder = HolderDistributionEncoder( wallet_embedding_dim=self.wallet_encoder.d_model, output_dim=self.d_model, dtype=self.dtype # Pass the correct dtype ) self.pad_holder_snapshot_emb = nn.Parameter(torch.zeros(1, self.d_model)) # Output of holder_dist_encoder is d_model # --- 6. Define Projection MLPs --- self.time_proj = nn.Linear(self.time_encoder.projection.out_features, self.d_model) self.rel_ts_proj = nn.Linear(1, self.d_model) self.rel_ts_norm = nn.LayerNorm(1) self.wallet_proj = nn.Linear(self.wallet_encoder.d_model, self.d_model) self.token_proj = nn.Linear(self.token_encoder.output_dim, self.d_model) self.ohlc_proj = nn.Linear(self.quant_ohlc_embedder.output_dim, self.d_model) self.chart_interval_fusion_embedding = nn.Embedding(vocab.NUM_OHLC_INTERVALS, 32, padding_idx=0) fusion_input_dim = self.ohlc_embedder.output_dim + self.quant_ohlc_embedder.output_dim + 32 self.chart_fusion = nn.Sequential( nn.Linear(fusion_input_dim, self.quant_ohlc_embedder.output_dim), nn.GELU(), nn.LayerNorm(self.quant_ohlc_embedder.output_dim), nn.Linear(self.quant_ohlc_embedder.output_dim, self.quant_ohlc_embedder.output_dim), nn.LayerNorm(self.quant_ohlc_embedder.output_dim), ) # self.holder_snapshot_proj is no longer needed as HolderDistributionEncoder outputs directly to d_model # --- NEW: Layers for Transfer Numerical Features --- self.transfer_num_norm = nn.LayerNorm(4) # Normalize the 4 features self.transfer_num_proj = nn.Linear(4, self.d_model) # Project to d_model # --- NEW: Layers for Trade Numerical Features --- # --- FIXED: Size reduced from 10 to 8 --- self.trade_num_norm = nn.LayerNorm(8) self.trade_num_proj = nn.Linear(8, self.d_model) # --- NEW: Embedding for categorical dex_platform_id --- self.dex_platform_embedding = nn.Embedding(vocab.NUM_DEX_PLATFORMS, self.d_model) # --- NEW: Embedding for categorical trade_direction --- self.trade_direction_embedding = nn.Embedding(2, self.d_model) # 0 for buy, 1 for sell # --- FIXED: Embedding for categorical mev_protection is now binary --- self.mev_protection_embedding = nn.Embedding(2, self.d_model) # 0 for false, 1 for true # --- NEW: Embedding for categorical is_bundle --- self.is_bundle_embedding = nn.Embedding(2, self.d_model) # 0 for false, 1 for true # --- NEW: Separate Layers for Deployer Trade Numerical Features --- # --- FIXED: Size reduced from 10 to 8 --- self.deployer_trade_num_norm = nn.LayerNorm(8) self.deployer_trade_num_proj = nn.Linear(8, self.d_model) # --- NEW: Separate Layers for Smart Wallet Trade Numerical Features --- # --- FIXED: Size reduced from 10 to 8 --- self.smart_wallet_trade_num_norm = nn.LayerNorm(8) self.smart_wallet_trade_num_proj = nn.Linear(8, self.d_model) # --- NEW: Layers for PoolCreated Numerical Features --- # --- FIXED: Size reduced from 5 to 4 --- self.pool_created_num_norm = nn.LayerNorm(2) self.pool_created_num_proj = nn.Linear(2, self.d_model) # --- NEW: Layers for LiquidityChange Numerical Features --- # --- FIXED: Size reduced from 3 to 2 --- self.liquidity_change_num_norm = nn.LayerNorm(1) self.liquidity_change_num_proj = nn.Linear(1, self.d_model) # --- NEW: Embedding for categorical change_type_id --- # --- FIXED: Hardcoded the number of types (add/remove) as per user instruction --- self.liquidity_change_type_embedding = nn.Embedding(2, self.d_model) # --- NEW: Layers for FeeCollected Numerical Features --- self.fee_collected_num_norm = nn.LayerNorm(1) # sol_amount only self.fee_collected_num_proj = nn.Linear(1, self.d_model) # --- NEW: Layers for TokenBurn Numerical Features --- self.token_burn_num_norm = nn.LayerNorm(2) # amount_pct, amount_tokens self.token_burn_num_proj = nn.Linear(2, self.d_model) # --- NEW: Layers for SupplyLock Numerical Features --- self.supply_lock_num_norm = nn.LayerNorm(2) # amount_pct, lock_duration self.supply_lock_num_proj = nn.Linear(2, self.d_model) # --- NEW: Layers for OnChain_Snapshot Numerical Features --- self.onchain_snapshot_num_norm = nn.LayerNorm(14) self.onchain_snapshot_num_proj = nn.Linear(14, self.d_model) # --- NEW: Layers for TrendingToken Numerical Features --- # --- FIXED: Size reduced from 3 to 1 (rank only) --- self.trending_token_num_norm = nn.LayerNorm(1) self.trending_token_num_proj = nn.Linear(1, self.d_model) # --- NEW: Embeddings for categorical IDs --- self.trending_list_source_embedding = nn.Embedding(vocab.NUM_TRENDING_LIST_SOURCES, self.d_model) self.trending_timeframe_embedding = nn.Embedding(vocab.NUM_TRENDING_LIST_TIMEFRAMES, self.d_model) # --- NEW: Layers for BoostedToken Numerical Features --- self.boosted_token_num_norm = nn.LayerNorm(2) # total_boost_amount, rank self.boosted_token_num_proj = nn.Linear(2, self.d_model) # --- NEW: Layers for DexBoost_Paid Numerical Features --- self.dexboost_paid_num_norm = nn.LayerNorm(2) # amount, total_amount_on_token self.dexboost_paid_num_proj = nn.Linear(2, self.d_model) # --- NEW: Layers for DexProfile_Updated Features --- self.dexprofile_updated_flags_proj = nn.Linear(4, self.d_model) # Project the 4 boolean flags # --- NEW: Projection for all pre-computed embeddings (text/images) --- self.precomputed_proj = nn.Linear(self.multi_modal_dim, self.d_model) # --- NEW: Embedding for Protocol IDs (used in Migrated event) --- self.protocol_embedding = nn.Embedding(vocab.NUM_PROTOCOLS, self.d_model) # --- NEW: Embeddings for TrackerEncoder Events --- # Note: NUM_CALL_CHANNELS might need to be large and managed as vocab grows. self.alpha_group_embedding = nn.Embedding(vocab.NUM_ALPHA_GROUPS, self.d_model) self.call_channel_embedding = nn.Embedding(vocab.NUM_CALL_CHANNELS, self.d_model) self.cex_listing_embedding = nn.Embedding(vocab.NUM_EXCHANGES, self.d_model) # --- NEW: Layers for GlobalTrendingEncoder Events --- self.global_trending_num_norm = nn.LayerNorm(1) # rank self.global_trending_num_proj = nn.Linear(1, self.d_model) # --- NEW: Layers for ChainSnapshot Events --- self.chainsnapshot_num_norm = nn.LayerNorm(2) # native_token_price_usd, gas_fee self.chainsnapshot_num_proj = nn.Linear(2, self.d_model) # --- NEW: Layers for Lighthouse_Snapshot Events --- # --- FIXED: Size reduced from 7 to 5 --- self.lighthousesnapshot_num_norm = nn.LayerNorm(5) self.lighthousesnapshot_num_proj = nn.Linear(5, self.d_model) # --- NEW: Embedding for timeframe ID (re-uses protocol_embedding) --- self.lighthouse_timeframe_embedding = nn.Embedding(vocab.NUM_LIGHTHOUSE_TIMEFRAMES, self.d_model) # --- Embeddings for Special Context Tokens --- # Must match vocabulary event names (see models/vocabulary.py). self.special_context_tokens = {'MIDDLE': 0, 'RECENT': 1} self.special_context_embedding = nn.Embedding(len(self.special_context_tokens), self.d_model) # --- 7. Prediction Head --- (Unchanged) # self.prediction_head = nn.Linear(self.d_model, self.num_outputs) # --- 8. Move all new modules to correct dtype --- self.to(dtype) print("Oracle model (full pipeline) initialized.") def save_pretrained(self, save_directory: str): """ Saves the model in a Hugging Face-compatible way. """ if not os.path.exists(save_directory): os.makedirs(save_directory) # 1. Save the inner transformer model using its own save_pretrained # This gives us the standard HF config.json and pytorch_model.bin for the backbone self.model.save_pretrained(save_directory) # 2. Save the whole Oracle state dict (includes transformer + all custom encoders) # We use 'oracle_model.bin' for the full state. torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin")) # 3. Save Oracle specific metadata for reconstruction oracle_config = { "num_event_types": self.num_event_types, "multi_modal_dim": self.multi_modal_dim, "event_pad_id": self.event_pad_id, "model_config_name": self.model_config_name, "quantiles": self.quantiles, "horizons_seconds": self.horizons_seconds, "dtype": str(self.dtype), "event_type_to_id": self.event_type_to_id } with open(os.path.join(save_directory, "oracle_config.json"), "w") as f: json.dump(oracle_config, f, indent=2) print(f"✅ Oracle model saved to {save_directory}") @classmethod def from_pretrained(cls, load_directory: str, token_encoder, wallet_encoder, graph_updater, ohlc_embedder, quant_ohlc_embedder, time_encoder): """ Loads the Oracle model from a saved directory. Note: You must still provide the initialized sub-encoders (or we can refactor to save them too). """ config_path = os.path.join(load_directory, "oracle_config.json") with open(config_path, "r") as f: config = json.load(f) # Determine dtype from string dtype = torch.bfloat16 # Default if "float32" in config["dtype"]: dtype = torch.float32 elif "float16" in config["dtype"]: dtype = torch.float16 # Instantiate model model = cls( token_encoder=token_encoder, wallet_encoder=wallet_encoder, graph_updater=graph_updater, ohlc_embedder=ohlc_embedder, quant_ohlc_embedder=quant_ohlc_embedder, time_encoder=time_encoder, num_event_types=config["num_event_types"], multi_modal_dim=config["multi_modal_dim"], event_pad_id=config["event_pad_id"], event_type_to_id=config["event_type_to_id"], model_config_name=config["model_config_name"], quantiles=config["quantiles"], horizons_seconds=config["horizons_seconds"], dtype=dtype ) # Load weights weight_path = os.path.join(load_directory, "pytorch_model.bin") state_dict = torch.load(weight_path, map_location="cpu") model.load_state_dict(state_dict) print(f"✅ Oracle model loaded from {load_directory}") return model def _normalize_and_project(self, features: torch.Tensor, norm_layer: nn.LayerNorm, proj_layer: nn.Linear, log_indices: Optional[List[int]] = None) -> torch.Tensor: """ A helper function to selectively apply log scaling, then normalize and project. """ processed_features = torch.nan_to_num( features.to(torch.float32), nan=0.0, posinf=1e6, neginf=-1e6 ) # Apply log scaling only to specified indices if log_indices: # Ensure log_indices are valid valid_indices = [i for i in log_indices if i < processed_features.shape[-1]] if valid_indices: log_features = processed_features[:, :, valid_indices] log_scaled = torch.sign(log_features) * torch.log1p(torch.abs(log_features)) processed_features[:, :, valid_indices] = log_scaled # Normalize and project the entire feature set norm_dtype = norm_layer.weight.dtype proj_dtype = proj_layer.weight.dtype normed_features = norm_layer(processed_features.to(norm_dtype)) normed_features = torch.nan_to_num(normed_features, nan=0.0, posinf=0.0, neginf=0.0) return proj_layer(normed_features.to(proj_dtype)) def _run_snapshot_encoders(self, batch: Dict[str, Any], final_wallet_embeddings_raw: torch.Tensor, wallet_addr_to_batch_idx: Dict[str, int]) -> Dict[str, torch.Tensor]: """ Runs snapshot-style encoders that process raw data into embeddings. This is now truly end-to-end. """ device = self.device all_holder_snapshot_embeds = [] # Iterate through each HolderSnapshot event's raw data for raw_holder_list in batch['holder_snapshot_raw_data']: processed_holder_data = [] for holder in raw_holder_list: wallet_addr = holder['wallet'] # Get the graph-updated wallet embedding using its index wallet_idx = wallet_addr_to_batch_idx.get(wallet_addr, 0) # 0 is padding if wallet_idx > 0: # If it's a valid wallet wallet_embedding = final_wallet_embeddings_raw[wallet_idx - 1] # Adjust for 1-based indexing processed_holder_data.append({ 'wallet_embedding': wallet_embedding, 'pct': holder['holding_pct'] }) # Pass the processed data to the HolderDistributionEncoder all_holder_snapshot_embeds.append(self.holder_dist_encoder(processed_holder_data)) return {"holder_snapshot": torch.cat(all_holder_snapshot_embeds, dim=0) if all_holder_snapshot_embeds else torch.empty(0, self.d_model, device=device, dtype=self.dtype)} def _run_dynamic_encoders(self, batch: Dict[str, Any]) -> Dict[str, torch.Tensor]: """ Runs all dynamic encoders and returns a dictionary of raw, unprojected embeddings. """ device = self.device # --- NEW: Get pre-computed embedding indices --- token_encoder_inputs = batch['token_encoder_inputs'] wallet_encoder_inputs = batch['wallet_encoder_inputs'] # The pre-computed embedding pool for the whole batch embedding_pool = torch.nan_to_num( batch['embedding_pool'].to(device, self.dtype), nan=0.0, posinf=0.0, neginf=0.0 ) ohlc_price_tensors = torch.nan_to_num( batch['ohlc_price_tensors'].to(device, self.dtype), nan=0.0, posinf=0.0, neginf=0.0 ) ohlc_interval_ids = batch['ohlc_interval_ids'].to(device) quant_ohlc_feature_tensors = torch.nan_to_num( batch['quant_ohlc_feature_tensors'].to(device, self.dtype), nan=0.0, posinf=0.0, neginf=0.0 ) quant_ohlc_feature_mask = batch['quant_ohlc_feature_mask'].to(device) quant_ohlc_feature_version_ids = batch['quant_ohlc_feature_version_ids'].to(device) graph_updater_links = batch['graph_updater_links'] # 1a. Encode Tokens # --- FIXED: Check for a key that still exists --- if token_encoder_inputs['name_embed_indices'].numel() > 0: # --- NEW: Gather pre-computed embeddings and pass to encoder --- # --- CRITICAL FIX: Remove keys that are not part of the TokenEncoder's signature --- encoder_args = token_encoder_inputs.copy() encoder_args.pop('_addresses_for_lookup', None) # This key is for the WalletEncoder encoder_args.pop('name_embed_indices', None) encoder_args.pop('symbol_embed_indices', None) encoder_args.pop('image_embed_indices', None) # --- SAFETY: Create a padded view of the embedding pool and map missing indices (-1) to pad --- if embedding_pool.numel() > 0: pad_row = torch.zeros(1, embedding_pool.size(1), device=device, dtype=embedding_pool.dtype) pool_padded = torch.cat([pad_row, embedding_pool], dim=0) def pad_and_lookup(idx_tensor: torch.Tensor) -> torch.Tensor: # Map valid indices >=0 to +1 (shift), invalid (<0) to 0 (pad) shifted = torch.where(idx_tensor >= 0, idx_tensor + 1, torch.zeros_like(idx_tensor)) return F.embedding(shifted, pool_padded) name_embeds = pad_and_lookup(token_encoder_inputs['name_embed_indices']) symbol_embeds = pad_and_lookup(token_encoder_inputs['symbol_embed_indices']) image_embeds = pad_and_lookup(token_encoder_inputs['image_embed_indices']) else: # Empty pool: provide zeros with correct shapes n = token_encoder_inputs['name_embed_indices'].shape[0] d = self.multi_modal_dim zeros = torch.zeros(n, d, device=device, dtype=self.dtype) name_embeds = zeros symbol_embeds = zeros image_embeds = zeros batch_token_embeddings_unupd = self.token_encoder( name_embeds=name_embeds, symbol_embeds=symbol_embeds, image_embeds=image_embeds, # Pass all other keys like protocol_ids, is_vanity_flags, etc. **encoder_args ) else: batch_token_embeddings_unupd = torch.empty(0, self.token_encoder.output_dim, device=device, dtype=self.dtype) # 1b. Encode Wallets if wallet_encoder_inputs['profile_rows']: temp_token_lookup = { addr: batch_token_embeddings_unupd[i] for i, addr in enumerate(batch['token_encoder_inputs']['_addresses_for_lookup']) # Use helper key } initial_wallet_embeddings = self.wallet_encoder( **wallet_encoder_inputs, token_vibe_lookup=temp_token_lookup, embedding_pool=embedding_pool ) else: initial_wallet_embeddings = torch.empty(0, self.wallet_encoder.d_model, device=device, dtype=self.dtype) # 1c. Encode OHLC if ohlc_price_tensors.shape[0] > 0: raw_chart_embeddings = self.ohlc_embedder(ohlc_price_tensors, ohlc_interval_ids) else: raw_chart_embeddings = torch.empty(0, self.ohlc_embedder.output_dim, device=device, dtype=self.dtype) if quant_ohlc_feature_tensors.shape[0] > 0: quant_chart_embeddings = self.quant_ohlc_embedder( quant_ohlc_feature_tensors, quant_ohlc_feature_mask, quant_ohlc_feature_version_ids, ) else: quant_chart_embeddings = torch.empty(0, self.quant_ohlc_embedder.output_dim, device=device, dtype=self.dtype) num_chart_segments = max(raw_chart_embeddings.shape[0], quant_chart_embeddings.shape[0]) if num_chart_segments > 0: if raw_chart_embeddings.shape[0] == 0: raw_chart_embeddings = torch.zeros( num_chart_segments, self.ohlc_embedder.output_dim, device=device, dtype=self.dtype, ) if quant_chart_embeddings.shape[0] == 0: quant_chart_embeddings = torch.zeros( num_chart_segments, self.quant_ohlc_embedder.output_dim, device=device, dtype=self.dtype, ) interval_embeds = self.chart_interval_fusion_embedding(ohlc_interval_ids[:num_chart_segments]).to(self.dtype) batch_ohlc_embeddings_raw = self.chart_fusion( torch.cat([raw_chart_embeddings, quant_chart_embeddings, interval_embeds], dim=-1) ) else: batch_ohlc_embeddings_raw = torch.empty(0, self.quant_ohlc_embedder.output_dim, device=device, dtype=self.dtype) # 1d. Run Graph Updater pad_wallet_raw = self.pad_wallet_emb.to(self.dtype) pad_token_raw = self.pad_token_emb.to(self.dtype) padded_wallet_tensor = torch.cat([pad_wallet_raw, initial_wallet_embeddings], dim=0) padded_token_tensor = torch.cat([pad_token_raw, batch_token_embeddings_unupd], dim=0) x_dict_initial = {} if padded_wallet_tensor.shape[0] > 1: x_dict_initial['wallet'] = padded_wallet_tensor if padded_token_tensor.shape[0] > 1: x_dict_initial['token'] = padded_token_tensor if x_dict_initial and graph_updater_links: final_entity_embeddings_dict = self.graph_updater(x_dict_initial, graph_updater_links) final_padded_wallet_embs = final_entity_embeddings_dict.get('wallet', padded_wallet_tensor) final_padded_token_embs = final_entity_embeddings_dict.get('token', padded_token_tensor) else: final_padded_wallet_embs = padded_wallet_tensor final_padded_token_embs = padded_token_tensor # Strip padding before returning final_wallet_embeddings_raw = final_padded_wallet_embs[1:] final_token_embeddings_raw = final_padded_token_embs[1:] return { "wallet": final_wallet_embeddings_raw, "token": final_token_embeddings_raw, "ohlc": batch_ohlc_embeddings_raw } def _project_and_gather_embeddings(self, raw_embeds: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Projects raw embeddings to d_model and gathers them into sequence-aligned tensors. """ # Project raw embeddings to d_model final_wallet_proj = self.wallet_proj(raw_embeds['wallet']) final_token_proj = self.token_proj(raw_embeds['token']) final_ohlc_proj = self.ohlc_proj(raw_embeds['ohlc']) # Project padding embeddings to d_model pad_wallet = self.wallet_proj(self.pad_wallet_emb.to(self.dtype)) pad_token = self.token_proj(self.pad_token_emb.to(self.dtype)) pad_ohlc = self.ohlc_proj(self.pad_ohlc_emb.to(self.dtype)) pad_holder_snapshot = self.pad_holder_snapshot_emb.to(self.dtype) # Already d_model # --- NEW: Project pre-computed embeddings and create lookup --- precomputed_pool = torch.nan_to_num( batch['embedding_pool'].to(self.device, self.dtype), nan=0.0, posinf=0.0, neginf=0.0 ) final_precomputed_proj = self.precomputed_proj(precomputed_pool) pad_precomputed = self.precomputed_proj(self.pad_precomputed_emb.to(self.dtype)) final_precomputed_lookup = torch.cat([pad_precomputed, final_precomputed_proj], dim=0) # Create final lookup tables with padding at index 0 final_wallet_lookup = torch.cat([pad_wallet, final_wallet_proj], dim=0) final_token_lookup = torch.cat([pad_token, final_token_proj], dim=0) final_ohlc_lookup = torch.cat([pad_ohlc, final_ohlc_proj], dim=0) # --- NEW: Add Role Embeddings --- main_role_emb = self.token_role_embedding(torch.tensor(self.main_token_role_id, device=self.device)) quote_role_emb = self.token_role_embedding(torch.tensor(self.quote_token_role_id, device=self.device)) trending_role_emb = self.token_role_embedding(torch.tensor(self.trending_token_role_id, device=self.device)) # Gather base embeddings gathered_main_token_embs = F.embedding(batch['token_indices'], final_token_lookup) gathered_quote_token_embs = F.embedding(batch['quote_token_indices'], final_token_lookup) gathered_trending_token_embs = F.embedding(batch['trending_token_indices'], final_token_lookup) gathered_boosted_token_embs = F.embedding(batch['boosted_token_indices'], final_token_lookup) # --- NEW: Handle HolderSnapshot --- final_holder_snapshot_lookup = torch.cat([pad_holder_snapshot, raw_embeds['holder_snapshot']], dim=0) # Gather embeddings for each event in the sequence return { "wallet": F.embedding(batch['wallet_indices'], final_wallet_lookup), "token": gathered_main_token_embs, # This is the baseline, no role needed "ohlc": F.embedding(batch['ohlc_indices'], final_ohlc_lookup), "original_author": F.embedding(batch['original_author_indices'], final_wallet_lookup), # NEW "dest_wallet": F.embedding(batch['dest_wallet_indices'], final_wallet_lookup), # Also gather dest wallet "quote_token": gathered_quote_token_embs + quote_role_emb, "trending_token": gathered_trending_token_embs + trending_role_emb, "boosted_token": gathered_boosted_token_embs + trending_role_emb, # Same role as trending "holder_snapshot": F.embedding(batch['holder_snapshot_indices'], final_holder_snapshot_lookup), # NEW "precomputed": final_precomputed_lookup # NEW: Pass the full lookup table } def _get_transfer_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor: """ Calculates the special embeddings for Transfer/LargeTransfer events. """ device = self.device transfer_numerical_features = batch['transfer_numerical_features'] event_type_ids = batch['event_type_ids'] # --- FIXED: Selectively log-scale features --- # Log scale: token_amount (idx 0), priority_fee (idx 3) # Linear scale: transfer_pct_of_total_supply (idx 1), transfer_pct_of_holding (idx 2) projected_transfer_features = self._normalize_and_project( transfer_numerical_features, self.transfer_num_norm, self.transfer_num_proj, log_indices=[0, 3] ) # Create a mask for Transfer/LargeTransfer events transfer_event_ids = [self.event_type_to_id.get('Transfer', -1), self.event_type_to_id.get('LargeTransfer', -1)] # ADDED LargeTransfer transfer_mask = torch.isin(event_type_ids, torch.tensor(transfer_event_ids, device=device)).unsqueeze(-1) # Combine destination wallet and numerical features, then apply mask return (gathered_embeds['dest_wallet'] + projected_transfer_features) * transfer_mask def _get_trade_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: """ Calculates the special embeddings for Trade events. """ device = self.device trade_numerical_features = batch['trade_numerical_features'] trade_dex_ids = batch['trade_dex_ids'] # NEW trade_direction_ids = batch['trade_direction_ids'] trade_mev_protection_ids = batch['trade_mev_protection_ids'] # NEW trade_is_bundle_ids = batch['trade_is_bundle_ids'] # NEW event_type_ids = batch['event_type_ids'] # --- FIXED: Selectively log-scale features --- # Log scale: sol_amount (idx 0), priority_fee (idx 1), total_usd (idx 7) # Linear scale: pcts, slippage, price_impact, success flags projected_trade_features = self._normalize_and_project( trade_numerical_features, self.trade_num_norm, self.trade_num_proj, log_indices=[0, 1, 7] ) # --- CORRECTED: This layer now handles both generic and large trades --- trade_event_names = ['Trade', 'LargeTrade'] trade_event_ids = [self.event_type_to_id.get(name, -1) for name in trade_event_names] # Create mask where event_type_id is one of the trade event ids trade_mask = torch.isin(event_type_ids, torch.tensor(trade_event_ids, device=device)).unsqueeze(-1) # --- NEW: Get embedding for the categorical dex_id --- dex_id_embeds = self.dex_platform_embedding(trade_dex_ids) direction_embeds = self.trade_direction_embedding(trade_direction_ids) mev_embeds = self.mev_protection_embedding(trade_mev_protection_ids) # NEW bundle_embeds = self.is_bundle_embedding(trade_is_bundle_ids) # NEW return (projected_trade_features + dex_id_embeds + direction_embeds + mev_embeds + bundle_embeds) * trade_mask def _get_deployer_trade_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: """ Calculates the special embeddings for Deployer_Trade events using its own layers. """ device = self.device deployer_trade_numerical_features = batch['deployer_trade_numerical_features'] trade_dex_ids = batch['trade_dex_ids'] # NEW: Re-use the same ID tensor trade_direction_ids = batch['trade_direction_ids'] trade_mev_protection_ids = batch['trade_mev_protection_ids'] # NEW trade_is_bundle_ids = batch['trade_is_bundle_ids'] # NEW event_type_ids = batch['event_type_ids'] # --- FIXED: Selectively log-scale features --- # Log scale: sol_amount (idx 0), priority_fee (idx 1), total_usd (idx 7) projected_deployer_trade_features = self._normalize_and_project( deployer_trade_numerical_features, self.deployer_trade_num_norm, self.deployer_trade_num_proj, log_indices=[0, 1, 7] ) dex_id_embeds = self.dex_platform_embedding(trade_dex_ids) direction_embeds = self.trade_direction_embedding(trade_direction_ids) mev_embeds = self.mev_protection_embedding(trade_mev_protection_ids) # NEW bundle_embeds = self.is_bundle_embedding(trade_is_bundle_ids) # NEW deployer_trade_mask = (event_type_ids == self.event_type_to_id.get('Deployer_Trade', -1)).unsqueeze(-1) return (projected_deployer_trade_features + dex_id_embeds + direction_embeds + mev_embeds + bundle_embeds) * deployer_trade_mask def _get_smart_wallet_trade_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: """ Calculates the special embeddings for SmartWallet_Trade events using its own layers. """ device = self.device smart_wallet_trade_numerical_features = batch['smart_wallet_trade_numerical_features'] trade_dex_ids = batch['trade_dex_ids'] # NEW: Re-use the same ID tensor trade_direction_ids = batch['trade_direction_ids'] trade_mev_protection_ids = batch['trade_mev_protection_ids'] # NEW trade_is_bundle_ids = batch['trade_is_bundle_ids'] # NEW event_type_ids = batch['event_type_ids'] # --- FIXED: Selectively log-scale features --- # Log scale: sol_amount (idx 0), priority_fee (idx 1), total_usd (idx 7) projected_features = self._normalize_and_project( smart_wallet_trade_numerical_features, self.smart_wallet_trade_num_norm, self.smart_wallet_trade_num_proj, log_indices=[0, 1, 7] ) dex_id_embeds = self.dex_platform_embedding(trade_dex_ids) direction_embeds = self.trade_direction_embedding(trade_direction_ids) mev_embeds = self.mev_protection_embedding(trade_mev_protection_ids) # NEW bundle_embeds = self.is_bundle_embedding(trade_is_bundle_ids) # NEW mask = (event_type_ids == self.event_type_to_id.get('SmartWallet_Trade', -1)).unsqueeze(-1) return (projected_features + dex_id_embeds + direction_embeds + mev_embeds + bundle_embeds) * mask def _get_pool_created_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor: """ Calculates the special embeddings for PoolCreated events. """ device = self.device pool_created_numerical_features = batch['pool_created_numerical_features'] pool_created_protocol_ids = batch['pool_created_protocol_ids'] # NEW event_type_ids = batch['event_type_ids'] # --- FIXED: Selectively log-scale features --- # Log scale: base_amount (idx 0), quote_amount (idx 1) # Linear scale: pcts (idx 2, 3) projected_features = self._normalize_and_project( pool_created_numerical_features, self.pool_created_num_norm, self.pool_created_num_proj, log_indices=[0, 1] ) # --- NEW: Get embedding for the categorical protocol_id --- protocol_id_embeds = self.protocol_embedding(pool_created_protocol_ids) # Create mask for the event mask = (event_type_ids == self.event_type_to_id.get('PoolCreated', -1)).unsqueeze(-1) # Combine Quote Token embedding with projected numericals return (gathered_embeds['quote_token'] + projected_features + protocol_id_embeds) * mask def _get_liquidity_change_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor: """ Calculates the special embeddings for LiquidityChange events. """ device = self.device liquidity_change_numerical_features = batch['liquidity_change_numerical_features'] liquidity_change_type_ids = batch['liquidity_change_type_ids'] # NEW event_type_ids = batch['event_type_ids'] # --- FIXED: Selectively log-scale features --- # Log scale: quote_amount (idx 0) projected_features = self._normalize_and_project( liquidity_change_numerical_features, self.liquidity_change_num_norm, self.liquidity_change_num_proj, log_indices=[0] ) # --- NEW: Get embedding for the categorical change_type_id --- change_type_embeds = self.liquidity_change_type_embedding(liquidity_change_type_ids) # Create mask for the event mask = (event_type_ids == self.event_type_to_id.get('LiquidityChange', -1)).unsqueeze(-1) # Combine Quote Token embedding with projected numericals return (gathered_embeds['quote_token'] + projected_features + change_type_embeds) * mask def _get_fee_collected_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: """ Calculates the special embeddings for FeeCollected events. """ device = self.device fee_collected_numerical_features = batch['fee_collected_numerical_features'] event_type_ids = batch['event_type_ids'] # --- FIXED: Single amount, log-scale --- projected_features = self._normalize_and_project( fee_collected_numerical_features, self.fee_collected_num_norm, self.fee_collected_num_proj, log_indices=[0] ) # Create mask for the event mask = (event_type_ids == self.event_type_to_id.get('FeeCollected', -1)).unsqueeze(-1) return projected_features * mask def _get_token_burn_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: """ Calculates the special embeddings for TokenBurn events. """ device = self.device token_burn_numerical_features = batch['token_burn_numerical_features'] event_type_ids = batch['event_type_ids'] # --- FIXED: Selectively log-scale features --- # Log scale: amount_tokens_burned (idx 1) # Linear scale: amount_pct_of_total_supply (idx 0) projected_features = self._normalize_and_project( token_burn_numerical_features, self.token_burn_num_norm, self.token_burn_num_proj, log_indices=[1] ) # Create mask for the event mask = (event_type_ids == self.event_type_to_id.get('TokenBurn', -1)).unsqueeze(-1) return projected_features * mask def _get_supply_lock_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: """ Calculates the special embeddings for SupplyLock events. """ device = self.device supply_lock_numerical_features = batch['supply_lock_numerical_features'] event_type_ids = batch['event_type_ids'] # --- FIXED: Selectively log-scale features --- # Log scale: lock_duration (idx 1) # Linear scale: amount_pct_of_total_supply (idx 0) projected_features = self._normalize_and_project( supply_lock_numerical_features, self.supply_lock_num_norm, self.supply_lock_num_proj, log_indices=[1] ) # Create mask for the event mask = (event_type_ids == self.event_type_to_id.get('SupplyLock', -1)).unsqueeze(-1) return projected_features * mask def _get_onchain_snapshot_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: """ Calculates the special embeddings for OnChain_Snapshot events. """ device = self.device onchain_snapshot_numerical_features = batch['onchain_snapshot_numerical_features'] event_type_ids = batch['event_type_ids'] # --- FIXED: Selectively log-scale features --- # Log scale: counts, market_cap, liquidity, volume, fees (almost all) # Linear scale: growth_rate, holder_pcts (indices 3, 4, 5, 6, 7) projected_features = self._normalize_and_project( onchain_snapshot_numerical_features, self.onchain_snapshot_num_norm, self.onchain_snapshot_num_proj, log_indices=[0, 1, 2, 8, 9, 10, 11, 12, 13] ) # Create mask for the event mask = (event_type_ids == self.event_type_to_id.get('OnChain_Snapshot', -1)).unsqueeze(-1) return projected_features * mask def _get_trending_token_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor: """ Calculates the special embeddings for TrendingToken events. """ device = self.device trending_token_numerical_features = batch['trending_token_numerical_features'] trending_token_source_ids = batch['trending_token_source_ids'] # NEW trending_token_timeframe_ids = batch['trending_token_timeframe_ids'] # NEW event_type_ids = batch['event_type_ids'] # --- FIXED: Rank is already inverted (0-1), so treat as linear --- projected_features = self._normalize_and_project( trending_token_numerical_features, self.trending_token_num_norm, self.trending_token_num_proj, log_indices=None ) # --- NEW: Get embeddings for categorical IDs --- source_embeds = self.trending_list_source_embedding(trending_token_source_ids) timeframe_embeds = self.trending_timeframe_embedding(trending_token_timeframe_ids) # Create mask for the event mask = (event_type_ids == self.event_type_to_id.get('TrendingToken', -1)).unsqueeze(-1) # Combine Trending Token embedding with its projected numericals return (gathered_embeds['trending_token'] + projected_features + source_embeds + timeframe_embeds) * mask def _get_boosted_token_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor: """ Calculates the special embeddings for BoostedToken events. """ device = self.device boosted_token_numerical_features = batch['boosted_token_numerical_features'] event_type_ids = batch['event_type_ids'] # --- FIXED: Selectively log-scale features --- # Log scale: total_boost_amount (idx 0) # Linear scale: inverted rank (idx 1) projected_features = self._normalize_and_project( boosted_token_numerical_features, self.boosted_token_num_norm, self.boosted_token_num_proj, log_indices=[0] ) # Create mask for the event mask = (event_type_ids == self.event_type_to_id.get('BoostedToken', -1)).unsqueeze(-1) # Combine Boosted Token embedding with its projected numericals return (gathered_embeds['boosted_token'] + projected_features) * mask def _get_dexboost_paid_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: """ Calculates the special embeddings for DexBoost_Paid events. """ device = self.device dexboost_paid_numerical_features = batch['dexboost_paid_numerical_features'] event_type_ids = batch['event_type_ids'] # --- FIXED: All features are amounts, so log-scale all --- projected_features = self._normalize_and_project( dexboost_paid_numerical_features, self.dexboost_paid_num_norm, self.dexboost_paid_num_proj, log_indices=[0, 1] ) # Create mask for the event mask = (event_type_ids == self.event_type_to_id.get('DexBoost_Paid', -1)).unsqueeze(-1) return projected_features * mask def _get_alphagroup_call_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: """ Handles AlphaGroup_Call events by looking up the group_id embedding. """ device = self.device group_ids = batch['alpha_group_ids'] event_type_ids = batch['event_type_ids'] # Look up the embedding for the group ID group_embeds = self.alpha_group_embedding(group_ids) # Create mask for the event mask = (event_type_ids == self.event_type_to_id.get('AlphaGroup_Call', -1)).unsqueeze(-1) return group_embeds * mask def _get_channel_call_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: """ Handles Channel_Call events by looking up the channel_id embedding. """ device = self.device channel_ids = batch['channel_ids'] event_type_ids = batch['event_type_ids'] channel_embeds = self.call_channel_embedding(channel_ids) mask = (event_type_ids == self.event_type_to_id.get('Channel_Call', -1)).unsqueeze(-1) return channel_embeds * mask def _get_cexlisting_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: """ Handles CexListing events by looking up the exchange_id embedding. """ device = self.device exchange_ids = batch['exchange_ids'] event_type_ids = batch['event_type_ids'] exchange_embeds = self.cex_listing_embedding(exchange_ids) mask = (event_type_ids == self.event_type_to_id.get('CexListing', -1)).unsqueeze(-1) return exchange_embeds * mask def _get_chainsnapshot_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: """ Handles ChainSnapshot events. """ device = self.device numerical_features = batch['chainsnapshot_numerical_features'] event_type_ids = batch['event_type_ids'] # --- FIXED: All features are amounts/prices, so log-scale all --- projected_features = self._normalize_and_project( numerical_features, self.chainsnapshot_num_norm, self.chainsnapshot_num_proj, log_indices=[0, 1] ) mask = (event_type_ids == self.event_type_to_id.get('ChainSnapshot', -1)).unsqueeze(-1) return projected_features * mask def _get_lighthousesnapshot_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: """ Handles Lighthouse_Snapshot events. """ device = self.device numerical_features = batch['lighthousesnapshot_numerical_features'] protocol_ids = batch['lighthousesnapshot_protocol_ids'] # NEW timeframe_ids = batch['lighthousesnapshot_timeframe_ids'] # NEW event_type_ids = batch['event_type_ids'] # --- FIXED: All features are counts/volumes, so log-scale all --- projected_features = self._normalize_and_project( numerical_features, self.lighthousesnapshot_num_norm, self.lighthousesnapshot_num_proj, log_indices=[0, 1, 2, 3, 4] ) # --- NEW: Get embeddings for categorical IDs --- # Re-use the main protocol embedding layer protocol_embeds = self.protocol_embedding(protocol_ids) timeframe_embeds = self.lighthouse_timeframe_embedding(timeframe_ids) mask = (event_type_ids == self.event_type_to_id.get('Lighthouse_Snapshot', -1)).unsqueeze(-1) return (projected_features + protocol_embeds + timeframe_embeds) * mask def _get_migrated_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: """ Handles Migrated events by looking up the protocol_id embedding. """ device = self.device protocol_ids = batch['migrated_protocol_ids'] event_type_ids = batch['event_type_ids'] # Look up the embedding for the protocol ID protocol_embeds = self.protocol_embedding(protocol_ids) # Create mask for the event mask = (event_type_ids == self.event_type_to_id.get('Migrated', -1)).unsqueeze(-1) return protocol_embeds * mask def _get_special_context_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: """ Handles special context tokens like 'MIDDLE' and 'RECENT' by adding their unique learnable embeddings. """ device = self.device event_type_ids = batch['event_type_ids'] B, L = event_type_ids.shape middle_id = self.event_type_to_id.get('MIDDLE', -1) recent_id = self.event_type_to_id.get('RECENT', -1) middle_mask = (event_type_ids == middle_id) recent_mask = (event_type_ids == recent_id) middle_emb = self.special_context_embedding(torch.tensor(self.special_context_tokens['MIDDLE'], device=device)) recent_emb = self.special_context_embedding(torch.tensor(self.special_context_tokens['RECENT'], device=device)) # Add the embeddings at the correct locations return middle_mask.unsqueeze(-1) * middle_emb + recent_mask.unsqueeze(-1) * recent_emb def _pool_hidden_states(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: """ Pools variable-length hidden states into a single embedding per sequence by selecting the last non-masked token for each batch element. """ if hidden_states.size(0) == 0: return torch.empty(0, self.d_model, device=hidden_states.device, dtype=hidden_states.dtype) seq_lengths = attention_mask.long().sum(dim=1) last_indices = torch.clamp(seq_lengths - 1, min=0) batch_indices = torch.arange(hidden_states.size(0), device=hidden_states.device) return hidden_states[batch_indices, last_indices] def forward(self, batch: Dict[str, Any]) -> Dict[str, torch.Tensor]: device = self.device # Unpack core sequence tensors event_type_ids = batch['event_type_ids'].to(device) timestamps_float = batch['timestamps_float'].to(device) relative_ts = batch['relative_ts'].to(device, self.dtype) attention_mask = batch['attention_mask'].to(device) B, L = event_type_ids.shape if B == 0 or L == 0: print("Warning: Received empty batch in Oracle forward.") empty_hidden = torch.empty(0, L, self.d_model, device=device, dtype=self.dtype) empty_mask = torch.empty(0, L, device=device, dtype=torch.long) empty_quantiles = torch.empty(0, self.num_outputs, device=device, dtype=self.dtype) empty_quality = torch.empty(0, device=device, dtype=self.dtype) empty_movement = torch.empty(0, len(self.horizons_seconds), self.num_movement_classes, device=device, dtype=self.dtype) return { 'quantile_logits': empty_quantiles, 'quality_logits': empty_quality, 'movement_logits': empty_movement, 'pooled_states': torch.empty(0, self.d_model, device=device, dtype=self.dtype), 'hidden_states': empty_hidden, 'attention_mask': empty_mask } # === 1. Run Dynamic Encoders (produces graph-updated entity embeddings) === dynamic_raw_embeds = self._run_dynamic_encoders(batch) # === 2. Run Snapshot Encoders (uses dynamic_raw_embeds) === wallet_addr_to_batch_idx = batch['wallet_addr_to_batch_idx'] snapshot_raw_embeds = self._run_snapshot_encoders(batch, dynamic_raw_embeds['wallet'], wallet_addr_to_batch_idx) # === 3. Project Raw Embeddings and Gather for Sequence === raw_embeds = {**dynamic_raw_embeds, **snapshot_raw_embeds} gathered_embeds = self._project_and_gather_embeddings(raw_embeds, batch) # === 4. Assemble Final `inputs_embeds` === event_embeds = self.event_type_embedding(event_type_ids) ts_embeds = self.time_proj(self.time_encoder(timestamps_float)) # Stabilize relative time: minutes scale + signed log1p + LayerNorm before projection relative_ts_fp32 = batch['relative_ts'].to(device, torch.float32) rel_ts_minutes = relative_ts_fp32 / 60.0 rel_ts_processed = torch.sign(rel_ts_minutes) * torch.log1p(torch.abs(rel_ts_minutes)) # Match LayerNorm parameter dtype, then match Linear parameter dtype norm_dtype = self.rel_ts_norm.weight.dtype proj_dtype = self.rel_ts_proj.weight.dtype rel_ts_normed = self.rel_ts_norm(rel_ts_processed.to(norm_dtype)) rel_ts_embeds = self.rel_ts_proj(rel_ts_normed.to(proj_dtype)) # Get special embeddings for Transfer events transfer_specific_embeds = self._get_transfer_specific_embeddings(batch, gathered_embeds) # Get special embeddings for Trade events trade_specific_embeds = self._get_trade_specific_embeddings(batch) # Get special embeddings for Deployer Trade events deployer_trade_specific_embeds = self._get_deployer_trade_specific_embeddings(batch) # Get special embeddings for Smart Wallet Trade events smart_wallet_trade_specific_embeds = self._get_smart_wallet_trade_specific_embeddings(batch) # Get special embeddings for PoolCreated events pool_created_specific_embeds = self._get_pool_created_specific_embeddings(batch, gathered_embeds) # Get special embeddings for LiquidityChange events liquidity_change_specific_embeds = self._get_liquidity_change_specific_embeddings(batch, gathered_embeds) # Get special embeddings for FeeCollected events fee_collected_specific_embeds = self._get_fee_collected_specific_embeddings(batch) # Get special embeddings for TokenBurn events token_burn_specific_embeds = self._get_token_burn_specific_embeddings(batch) # Get special embeddings for SupplyLock events supply_lock_specific_embeds = self._get_supply_lock_specific_embeddings(batch) # Get special embeddings for OnChain_Snapshot events onchain_snapshot_specific_embeds = self._get_onchain_snapshot_specific_embeddings(batch) # Get special embeddings for TrendingToken events trending_token_specific_embeds = self._get_trending_token_specific_embeddings(batch, gathered_embeds) # Get special embeddings for BoostedToken events boosted_token_specific_embeds = self._get_boosted_token_specific_embeddings(batch, gathered_embeds) # Get special embeddings for DexBoost_Paid events dexboost_paid_specific_embeds = self._get_dexboost_paid_specific_embeddings(batch) # --- NEW: Get embeddings for Tracker events --- alphagroup_call_specific_embeds = self._get_alphagroup_call_specific_embeddings(batch) channel_call_specific_embeds = self._get_channel_call_specific_embeddings(batch) cexlisting_specific_embeds = self._get_cexlisting_specific_embeddings(batch) # --- NEW: Get embeddings for Chain and Lighthouse Snapshots --- chainsnapshot_specific_embeds = self._get_chainsnapshot_specific_embeddings(batch) lighthousesnapshot_specific_embeds = self._get_lighthousesnapshot_specific_embeddings(batch) migrated_specific_embeds = self._get_migrated_specific_embeddings(batch) # --- NEW: Handle DexProfile_Updated flags separately --- dexprofile_updated_flags = batch['dexprofile_updated_flags'] dexprofile_flags_embeds = self.dexprofile_updated_flags_proj(dexprofile_updated_flags.to(self.dtype)) # --- REFACTORED: All text-based events are handled by the SocialEncoder --- # This single call will replace the inefficient loops for social, dexprofile, and global trending events. # The SocialEncoder's forward pass will need to be updated to handle this. textual_event_embeds = self.social_encoder( batch=batch, gathered_embeds=gathered_embeds ) # --- NEW: Get embeddings for special context injection tokens --- special_context_embeds = self._get_special_context_embeddings(batch) # --- Combine all features --- # Sum in float32 for numerical stability, then cast back to model dtype components = [ event_embeds, ts_embeds, rel_ts_embeds, gathered_embeds['wallet'], gathered_embeds['token'], gathered_embeds['original_author'], gathered_embeds['ohlc'], transfer_specific_embeds, trade_specific_embeds, deployer_trade_specific_embeds, smart_wallet_trade_specific_embeds, pool_created_specific_embeds, liquidity_change_specific_embeds, fee_collected_specific_embeds, token_burn_specific_embeds, supply_lock_specific_embeds, onchain_snapshot_specific_embeds, trending_token_specific_embeds, boosted_token_specific_embeds, dexboost_paid_specific_embeds, alphagroup_call_specific_embeds, channel_call_specific_embeds, cexlisting_specific_embeds, migrated_specific_embeds, special_context_embeds, gathered_embeds['holder_snapshot'], textual_event_embeds, dexprofile_flags_embeds, chainsnapshot_specific_embeds, lighthousesnapshot_specific_embeds ] inputs_embeds = sum([t.float() for t in components]).to(self.dtype) hf_attention_mask = attention_mask.to(device=device, dtype=torch.long) outputs = self.model( inputs_embeds=inputs_embeds, attention_mask=hf_attention_mask, return_dict=True ) sequence_hidden = outputs.last_hidden_state pooled_states = self._pool_hidden_states(sequence_hidden, hf_attention_mask) quantile_logits = self.quantile_head(pooled_states) quality_logits = self.quality_head(pooled_states).squeeze(-1) movement_logits = self.movement_head(pooled_states).view( pooled_states.shape[0], len(self.horizons_seconds), self.num_movement_classes, ) return { 'quantile_logits': quantile_logits, 'quality_logits': quality_logits, 'movement_logits': movement_logits, 'pooled_states': pooled_states, 'hidden_states': sequence_hidden, 'attention_mask': hf_attention_mask }