# Source: oracle/models/model.py
# (uploaded to the Hugging Face Hub via huggingface_hub; revision d195287)
# model.py (REFACTORED AND FIXED)
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, LlamaConfig
from typing import List, Dict, Any, Optional, Tuple
import os
import json
# --- NOW, we import all the encoders ---
from models.helper_encoders import ContextualTimeEncoder
from models.token_encoder import TokenEncoder
from models.wallet_encoder import WalletEncoder
from models.graph_updater import GraphUpdater
from models.ohlc_embedder import OHLCEmbedder
from models.quant_ohlc_embedder import QuantOHLCEmbedder
from models.HoldersEncoder import HolderDistributionEncoder # NEW
from models.SocialEncoders import SocialEncoder # NEW
import models.vocabulary as vocab # For vocab sizes
from data.context_targets import MOVEMENT_CLASS_NAMES
class Oracle(nn.Module):
"""
"""
def __init__(self,
             token_encoder: TokenEncoder,
             wallet_encoder: WalletEncoder,
             graph_updater: GraphUpdater,
             ohlc_embedder: OHLCEmbedder,  # NEW
             quant_ohlc_embedder: QuantOHLCEmbedder,
             time_encoder: ContextualTimeEncoder,
             num_event_types: int,
             multi_modal_dim: int,
             event_pad_id: int,
             event_type_to_id: Dict[str, int],
             model_config_name: str = "llama3-12l-768d-gqa4-8k-random",
             quantiles: Optional[List[float]] = None,
             horizons_seconds: Optional[List[int]] = None,
             dtype: torch.dtype = torch.bfloat16):
    """
    Assemble the full Oracle pipeline.

    Args:
        token_encoder, wallet_encoder, graph_updater, ohlc_embedder,
        quant_ohlc_embedder, time_encoder: pre-built sub-encoders
            (stored by reference, not copied).
        num_event_types: size of the event-type embedding vocabulary.
        multi_modal_dim: width of the pre-computed text/image embeddings.
        event_pad_id: padding id inside the event-type vocabulary.
        event_type_to_id: event-name -> id map used by the per-event
            masking helpers (e.g. _get_trade_specific_embeddings).
        model_config_name: informational tag stored in oracle_config.json.
        quantiles: predicted quantile levels; defaults to [0.1, 0.5, 0.9].
        horizons_seconds: prediction horizons; defaults to
            [30, 60, 120, 240, 420].
        dtype: parameter dtype for the whole model.
    """
    super().__init__()
    # FIXED: mutable default arguments (lists) are shared across calls;
    # defaults are materialised here instead.
    if quantiles is None:
        quantiles = [0.1, 0.5, 0.9]
    if horizons_seconds is None:
        horizons_seconds = [30, 60, 120, 240, 420]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    self.device = torch.device(device)
    # FIXED: self.dtype was assigned twice (before and after the output
    # bookkeeping); once is enough.
    self.dtype = dtype
    self.multi_modal_dim = multi_modal_dim
    self.num_event_types = num_event_types
    self.event_pad_id = event_pad_id
    self.model_config_name = model_config_name
    self.quantiles = quantiles
    self.horizons_seconds = horizons_seconds
    # Flattened horizon x quantile output grid.
    self.num_outputs = len(quantiles) * len(horizons_seconds)
    self.num_movement_classes = len(MOVEMENT_CLASS_NAMES)
    # --- 2. Backbone: Llama-style decoder, RANDOM INIT (no pretrained weights) ---
    # This gives you RoPE + modern decoder blocks and lets HF use optimized attention
    # implementations (SDPA / FlashAttention) without us implementing a transformer.
    #
    # Size target: ~80-120M params, suitable for 8k-ish seq caps with your data regime.
    attn_impl = os.getenv("HF_ATTN_IMPL", "sdpa")  # "sdpa" (safe) or "flash_attention_2" (if installed)
    llama_cfg = LlamaConfig(
        # Model size
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        # GQA-style KV heads (Llama 3-style efficiency knob)
        num_key_value_heads=4,
        # Long context (must be >= your effective max sequence length)
        max_position_embeddings=8192,
        # Llama 3 uses a large theta; harmless for random init and helps longer contexts.
        rope_theta=500000.0,
        rms_norm_eps=1e-5,
        # Unused when providing inputs_embeds, but required by config
        vocab_size=32000,
    )
    self.d_model = llama_cfg.hidden_size
    # Older transformers versions may not support attn_implementation in from_config.
    # Also, flash_attention_2 requires optional deps; fall back to SDPA if unavailable.
    try:
        self.model = AutoModel.from_config(llama_cfg, attn_implementation=attn_impl)
    except TypeError:
        self.model = AutoModel.from_config(llama_cfg)
    except Exception:
        if attn_impl != "sdpa":
            self.model = AutoModel.from_config(llama_cfg, attn_implementation="sdpa")
        else:
            raise
    # Disable KV cache during training (saves memory; not used for full-seq training).
    if hasattr(self.model, "config"):
        self.model.config.use_cache = False
    # NOTE(review): only the backbone is moved to self.device here; the
    # custom heads/embeddings below are moved to dtype only (self.to(dtype)
    # at the end) — presumably the caller moves the whole module to the
    # device afterwards. Confirm against the training entry point.
    self.model.to(self.device, dtype=self.dtype)
    # Quantile prediction head (maps pooled hidden state -> flattened horizon/quantile grid)
    self.quantile_head = nn.Sequential(
        nn.Linear(self.d_model, self.d_model),
        nn.GELU(),
        nn.Linear(self.d_model, self.num_outputs)
    )
    # Scalar "quality" score head.
    self.quality_head = nn.Sequential(
        nn.Linear(self.d_model, self.d_model),
        nn.GELU(),
        nn.Linear(self.d_model, 1)
    )
    # Per-horizon movement classification head.
    self.movement_head = nn.Sequential(
        nn.Linear(self.d_model, self.d_model),
        nn.GELU(),
        nn.Linear(self.d_model, len(self.horizons_seconds) * self.num_movement_classes)
    )
    self.event_type_to_id = event_type_to_id
    # --- 1. Store All Encoders ---
    # Define Token Roles before using them
    self.token_roles = {'main': 0, 'quote': 1, 'trending': 2}  # Add trending for future use
    self.main_token_role_id = self.token_roles['main']
    self.quote_token_role_id = self.token_roles['quote']
    self.trending_token_role_id = self.token_roles['trending']
    self.token_encoder = token_encoder
    self.wallet_encoder = wallet_encoder
    self.graph_updater = graph_updater
    self.ohlc_embedder = ohlc_embedder
    self.quant_ohlc_embedder = quant_ohlc_embedder
    self.time_encoder = time_encoder  # Store time_encoder
    self.social_encoder = SocialEncoder(d_model=self.d_model, dtype=self.dtype)  # Now self.d_model is defined
    # --- 4. Define Sequence Feature Embeddings ---
    self.event_type_embedding = nn.Embedding(num_event_types, self.d_model, padding_idx=event_pad_id)
    # --- NEW: Token Role Embeddings ---
    self.token_role_embedding = nn.Embedding(len(self.token_roles), self.d_model)
    # --- 5. Define Entity Padding (Learnable) ---
    self.pad_wallet_emb = nn.Parameter(torch.zeros(1, self.wallet_encoder.d_model))
    self.pad_token_emb = nn.Parameter(torch.zeros(1, self.token_encoder.output_dim))
    self.pad_ohlc_emb = nn.Parameter(torch.zeros(1, self.quant_ohlc_embedder.output_dim))
    self.pad_precomputed_emb = nn.Parameter(torch.zeros(1, self.multi_modal_dim))  # NEW: For text/images
    # --- NEW: Instantiate HolderDistributionEncoder internally ---
    self.holder_dist_encoder = HolderDistributionEncoder(
        wallet_embedding_dim=self.wallet_encoder.d_model,
        output_dim=self.d_model,
        dtype=self.dtype  # Pass the correct dtype
    )
    self.pad_holder_snapshot_emb = nn.Parameter(torch.zeros(1, self.d_model))  # Output of holder_dist_encoder is d_model
    # --- 6. Define Projection MLPs ---
    self.time_proj = nn.Linear(self.time_encoder.projection.out_features, self.d_model)
    self.rel_ts_proj = nn.Linear(1, self.d_model)
    self.rel_ts_norm = nn.LayerNorm(1)
    self.wallet_proj = nn.Linear(self.wallet_encoder.d_model, self.d_model)
    self.token_proj = nn.Linear(self.token_encoder.output_dim, self.d_model)
    self.ohlc_proj = nn.Linear(self.quant_ohlc_embedder.output_dim, self.d_model)
    self.chart_interval_fusion_embedding = nn.Embedding(vocab.NUM_OHLC_INTERVALS, 32, padding_idx=0)
    # Fuses raw-price chart embedding + quant-feature embedding + interval id.
    fusion_input_dim = self.ohlc_embedder.output_dim + self.quant_ohlc_embedder.output_dim + 32
    self.chart_fusion = nn.Sequential(
        nn.Linear(fusion_input_dim, self.quant_ohlc_embedder.output_dim),
        nn.GELU(),
        nn.LayerNorm(self.quant_ohlc_embedder.output_dim),
        nn.Linear(self.quant_ohlc_embedder.output_dim, self.quant_ohlc_embedder.output_dim),
        nn.LayerNorm(self.quant_ohlc_embedder.output_dim),
    )
    # self.holder_snapshot_proj is no longer needed as HolderDistributionEncoder outputs directly to d_model
    # --- NEW: Layers for Transfer Numerical Features ---
    self.transfer_num_norm = nn.LayerNorm(4)  # Normalize the 4 features
    self.transfer_num_proj = nn.Linear(4, self.d_model)  # Project to d_model
    # --- NEW: Layers for Trade Numerical Features ---
    # --- FIXED: Size reduced from 10 to 8 ---
    self.trade_num_norm = nn.LayerNorm(8)
    self.trade_num_proj = nn.Linear(8, self.d_model)
    # --- NEW: Embedding for categorical dex_platform_id ---
    self.dex_platform_embedding = nn.Embedding(vocab.NUM_DEX_PLATFORMS, self.d_model)
    # --- NEW: Embedding for categorical trade_direction ---
    self.trade_direction_embedding = nn.Embedding(2, self.d_model)  # 0 for buy, 1 for sell
    # --- FIXED: Embedding for categorical mev_protection is now binary ---
    self.mev_protection_embedding = nn.Embedding(2, self.d_model)  # 0 for false, 1 for true
    # --- NEW: Embedding for categorical is_bundle ---
    self.is_bundle_embedding = nn.Embedding(2, self.d_model)  # 0 for false, 1 for true
    # --- NEW: Separate Layers for Deployer Trade Numerical Features ---
    # --- FIXED: Size reduced from 10 to 8 ---
    self.deployer_trade_num_norm = nn.LayerNorm(8)
    self.deployer_trade_num_proj = nn.Linear(8, self.d_model)
    # --- NEW: Separate Layers for Smart Wallet Trade Numerical Features ---
    # --- FIXED: Size reduced from 10 to 8 ---
    self.smart_wallet_trade_num_norm = nn.LayerNorm(8)
    self.smart_wallet_trade_num_proj = nn.Linear(8, self.d_model)
    # --- NEW: Layers for PoolCreated Numerical Features ---
    self.pool_created_num_norm = nn.LayerNorm(2)
    self.pool_created_num_proj = nn.Linear(2, self.d_model)
    # --- NEW: Layers for LiquidityChange Numerical Features ---
    self.liquidity_change_num_norm = nn.LayerNorm(1)
    self.liquidity_change_num_proj = nn.Linear(1, self.d_model)
    # --- NEW: Embedding for categorical change_type_id (add/remove) ---
    self.liquidity_change_type_embedding = nn.Embedding(2, self.d_model)
    # --- NEW: Layers for FeeCollected Numerical Features ---
    self.fee_collected_num_norm = nn.LayerNorm(1)  # sol_amount only
    self.fee_collected_num_proj = nn.Linear(1, self.d_model)
    # --- NEW: Layers for TokenBurn Numerical Features ---
    self.token_burn_num_norm = nn.LayerNorm(2)  # amount_pct, amount_tokens
    self.token_burn_num_proj = nn.Linear(2, self.d_model)
    # --- NEW: Layers for SupplyLock Numerical Features ---
    self.supply_lock_num_norm = nn.LayerNorm(2)  # amount_pct, lock_duration
    self.supply_lock_num_proj = nn.Linear(2, self.d_model)
    # --- NEW: Layers for OnChain_Snapshot Numerical Features ---
    self.onchain_snapshot_num_norm = nn.LayerNorm(14)
    self.onchain_snapshot_num_proj = nn.Linear(14, self.d_model)
    # --- NEW: Layers for TrendingToken Numerical Features (rank only) ---
    self.trending_token_num_norm = nn.LayerNorm(1)
    self.trending_token_num_proj = nn.Linear(1, self.d_model)
    # --- NEW: Embeddings for categorical IDs ---
    self.trending_list_source_embedding = nn.Embedding(vocab.NUM_TRENDING_LIST_SOURCES, self.d_model)
    self.trending_timeframe_embedding = nn.Embedding(vocab.NUM_TRENDING_LIST_TIMEFRAMES, self.d_model)
    # --- NEW: Layers for BoostedToken Numerical Features ---
    self.boosted_token_num_norm = nn.LayerNorm(2)  # total_boost_amount, rank
    self.boosted_token_num_proj = nn.Linear(2, self.d_model)
    # --- NEW: Layers for DexBoost_Paid Numerical Features ---
    self.dexboost_paid_num_norm = nn.LayerNorm(2)  # amount, total_amount_on_token
    self.dexboost_paid_num_proj = nn.Linear(2, self.d_model)
    # --- NEW: Layers for DexProfile_Updated Features ---
    self.dexprofile_updated_flags_proj = nn.Linear(4, self.d_model)  # Project the 4 boolean flags
    # --- NEW: Projection for all pre-computed embeddings (text/images) ---
    self.precomputed_proj = nn.Linear(self.multi_modal_dim, self.d_model)
    # --- NEW: Embedding for Protocol IDs (used in Migrated event) ---
    self.protocol_embedding = nn.Embedding(vocab.NUM_PROTOCOLS, self.d_model)
    # --- NEW: Embeddings for TrackerEncoder Events ---
    # Note: NUM_CALL_CHANNELS might need to be large and managed as vocab grows.
    self.alpha_group_embedding = nn.Embedding(vocab.NUM_ALPHA_GROUPS, self.d_model)
    self.call_channel_embedding = nn.Embedding(vocab.NUM_CALL_CHANNELS, self.d_model)
    self.cex_listing_embedding = nn.Embedding(vocab.NUM_EXCHANGES, self.d_model)
    # --- NEW: Layers for GlobalTrendingEncoder Events ---
    self.global_trending_num_norm = nn.LayerNorm(1)  # rank
    self.global_trending_num_proj = nn.Linear(1, self.d_model)
    # --- NEW: Layers for ChainSnapshot Events ---
    self.chainsnapshot_num_norm = nn.LayerNorm(2)  # native_token_price_usd, gas_fee
    self.chainsnapshot_num_proj = nn.Linear(2, self.d_model)
    # --- NEW: Layers for Lighthouse_Snapshot Events ---
    self.lighthousesnapshot_num_norm = nn.LayerNorm(5)
    self.lighthousesnapshot_num_proj = nn.Linear(5, self.d_model)
    # --- NEW: Embedding for the lighthouse timeframe ID ---
    self.lighthouse_timeframe_embedding = nn.Embedding(vocab.NUM_LIGHTHOUSE_TIMEFRAMES, self.d_model)
    # --- Embeddings for Special Context Tokens ---
    # Must match vocabulary event names (see models/vocabulary.py).
    self.special_context_tokens = {'MIDDLE': 0, 'RECENT': 1}
    self.special_context_embedding = nn.Embedding(len(self.special_context_tokens), self.d_model)
    # --- 8. Move all new modules to correct dtype ---
    self.to(dtype)
    print("Oracle model (full pipeline) initialized.")
def save_pretrained(self, save_directory: str):
    """
    Save the model in a Hugging Face-compatible layout.

    Writes into ``save_directory``:
      1. the backbone's own HF artifacts (config.json + backbone weights),
      2. ``pytorch_model.bin`` holding the FULL Oracle state dict
         (backbone + every custom encoder/head),
      3. ``oracle_config.json`` with the metadata ``from_pretrained`` needs.
    """
    # FIXED: exist_ok=True replaces the exists()/makedirs() pair, which was
    # racy and would raise if the directory appeared between the two calls.
    os.makedirs(save_directory, exist_ok=True)
    # 1. Save the inner transformer model using its own save_pretrained.
    # This gives us the standard HF config.json and backbone weights.
    # NOTE(review): older transformers versions write the backbone weights
    # as 'pytorch_model.bin', which step 2 below would overwrite; modern
    # versions write 'model.safetensors', so the files do not collide.
    self.model.save_pretrained(save_directory)
    # 2. Save the whole Oracle state dict under 'pytorch_model.bin' — this
    # exact file name is what from_pretrained() loads. (FIXED: the old
    # comment claimed 'oracle_model.bin', which was never what the code wrote.)
    torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
    # 3. Save Oracle-specific metadata for reconstruction.
    oracle_config = {
        "num_event_types": self.num_event_types,
        "multi_modal_dim": self.multi_modal_dim,
        "event_pad_id": self.event_pad_id,
        "model_config_name": self.model_config_name,
        "quantiles": self.quantiles,
        "horizons_seconds": self.horizons_seconds,
        "dtype": str(self.dtype),
        "event_type_to_id": self.event_type_to_id
    }
    with open(os.path.join(save_directory, "oracle_config.json"), "w") as f:
        json.dump(oracle_config, f, indent=2)
    print(f"✅ Oracle model saved to {save_directory}")
@classmethod
def from_pretrained(cls, load_directory: str,
                    token_encoder, wallet_encoder, graph_updater, ohlc_embedder, quant_ohlc_embedder, time_encoder):
    """
    Load the Oracle model from a directory written by ``save_pretrained``.

    Note: you must still provide the initialized sub-encoders; only the
    weights and the Oracle metadata are restored from disk.

    Returns:
        A fully-constructed ``Oracle`` with the saved state dict loaded.
    """
    config_path = os.path.join(load_directory, "oracle_config.json")
    with open(config_path, "r") as f:
        config = json.load(f)
    # --- FIXED: parse the dtype string precisely. The old substring checks
    # tested "float16" before "bfloat16"-awareness existed, and since
    # "float16" is a substring of "torch.bfloat16", every bf16 checkpoint
    # was silently loaded as fp16. Check the most specific name first.
    dtype_str = config["dtype"]
    if "bfloat16" in dtype_str:
        dtype = torch.bfloat16
    elif "float32" in dtype_str:
        dtype = torch.float32
    elif "float16" in dtype_str:
        dtype = torch.float16
    else:
        dtype = torch.bfloat16  # Default, matching the constructor.
    # Instantiate the model with the saved hyper-parameters.
    model = cls(
        token_encoder=token_encoder,
        wallet_encoder=wallet_encoder,
        graph_updater=graph_updater,
        ohlc_embedder=ohlc_embedder,
        quant_ohlc_embedder=quant_ohlc_embedder,
        time_encoder=time_encoder,
        num_event_types=config["num_event_types"],
        multi_modal_dim=config["multi_modal_dim"],
        event_pad_id=config["event_pad_id"],
        event_type_to_id=config["event_type_to_id"],
        model_config_name=config["model_config_name"],
        quantiles=config["quantiles"],
        horizons_seconds=config["horizons_seconds"],
        dtype=dtype
    )
    # Load weights. weights_only=True: a state dict needs no arbitrary
    # pickle execution, so refuse it (safer for downloaded checkpoints).
    weight_path = os.path.join(load_directory, "pytorch_model.bin")
    state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    print(f"✅ Oracle model loaded from {load_directory}")
    return model
def _normalize_and_project(self,
                           features: torch.Tensor,
                           norm_layer: nn.LayerNorm,
                           proj_layer: nn.Linear,
                           log_indices: Optional[List[int]] = None) -> torch.Tensor:
    """
    Selectively apply signed log scaling, then normalize and project.

    Args:
        features: numeric feature tensor; the LAST dimension is the feature
            axis. (FIXED: the old code hard-coded ``[:, :, idx]`` and so
            silently required rank-3 input even though validity was checked
            against ``shape[-1]``; ellipsis indexing now accepts any rank.)
        norm_layer: LayerNorm applied over the feature axis.
        proj_layer: Linear projection applied after normalization.
        log_indices: feature indices to scale with sign(x)*log1p(|x|)
            (magnitude compression that keeps the sign); out-of-range
            indices are ignored.

    Returns:
        Projected tensor in ``proj_layer``'s dtype; NaN/inf scrubbed both
        before and after normalization.
    """
    processed_features = torch.nan_to_num(
        features.to(torch.float32),
        nan=0.0,
        posinf=1e6,
        neginf=-1e6
    )
    # Apply log scaling only to specified (and in-range) indices.
    if log_indices:
        valid_indices = [i for i in log_indices if i < processed_features.shape[-1]]
        if valid_indices:
            log_features = processed_features[..., valid_indices]
            processed_features[..., valid_indices] = (
                torch.sign(log_features) * torch.log1p(torch.abs(log_features))
            )
    # Normalize and project the entire feature set, casting to each layer's
    # own dtype so mixed-precision setups do not break.
    norm_dtype = norm_layer.weight.dtype
    proj_dtype = proj_layer.weight.dtype
    normed_features = norm_layer(processed_features.to(norm_dtype))
    normed_features = torch.nan_to_num(normed_features, nan=0.0, posinf=0.0, neginf=0.0)
    return proj_layer(normed_features.to(proj_dtype))
def _run_snapshot_encoders(self,
                           batch: Dict[str, Any],
                           final_wallet_embeddings_raw: torch.Tensor,
                           wallet_addr_to_batch_idx: Dict[str, int]) -> Dict[str, torch.Tensor]:
    """
    Turn raw HolderSnapshot data into embeddings, end-to-end.

    For every snapshot in the batch, each holder's wallet address is mapped
    (via ``wallet_addr_to_batch_idx``, where 0 means padding/unknown) to its
    graph-updated wallet embedding, paired with its holding percentage, and
    fed to ``self.holder_dist_encoder``.
    """
    snapshot_embeds: List[torch.Tensor] = []
    for holder_list in batch['holder_snapshot_raw_data']:
        known_holders = []
        for entry in holder_list:
            batch_idx = wallet_addr_to_batch_idx.get(entry['wallet'], 0)
            if batch_idx <= 0:
                # Unknown wallet (padding index) — skip it.
                continue
            known_holders.append({
                # Indices in the map are 1-based; shift to 0-based row.
                'wallet_embedding': final_wallet_embeddings_raw[batch_idx - 1],
                'pct': entry['holding_pct'],
            })
        snapshot_embeds.append(self.holder_dist_encoder(known_holders))
    if snapshot_embeds:
        stacked = torch.cat(snapshot_embeds, dim=0)
    else:
        stacked = torch.empty(0, self.d_model, device=self.device, dtype=self.dtype)
    return {"holder_snapshot": stacked}
def _run_dynamic_encoders(self, batch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    """
    Runs all dynamic encoders and returns a dictionary of raw, unprojected embeddings.

    Stages (each skipped when its input is empty):
      1a. TokenEncoder over token metadata plus pre-computed
          name/symbol/image embedding lookups.
      1b. WalletEncoder over wallet profiles (can consult token embeddings
          from 1a via an address-keyed lookup).
      1c. OHLC embedders: raw price charts + quantitative features, fused
          per chart segment together with an interval-id embedding.
      1d. GraphUpdater message passing over wallet/token link structure.

    Returns:
        {'wallet', 'token', 'ohlc'} -> raw embedding tensors with the
        index-0 padding row stripped; any may have 0 rows.
    """
    device = self.device
    # --- NEW: Get pre-computed embedding indices ---
    token_encoder_inputs = batch['token_encoder_inputs']
    wallet_encoder_inputs = batch['wallet_encoder_inputs']
    # The pre-computed embedding pool for the whole batch.
    # NaN/inf are scrubbed so one bad row cannot poison downstream layers.
    embedding_pool = torch.nan_to_num(
        batch['embedding_pool'].to(device, self.dtype),
        nan=0.0,
        posinf=0.0,
        neginf=0.0
    )
    ohlc_price_tensors = torch.nan_to_num(
        batch['ohlc_price_tensors'].to(device, self.dtype),
        nan=0.0,
        posinf=0.0,
        neginf=0.0
    )
    ohlc_interval_ids = batch['ohlc_interval_ids'].to(device)
    quant_ohlc_feature_tensors = torch.nan_to_num(
        batch['quant_ohlc_feature_tensors'].to(device, self.dtype),
        nan=0.0,
        posinf=0.0,
        neginf=0.0
    )
    quant_ohlc_feature_mask = batch['quant_ohlc_feature_mask'].to(device)
    quant_ohlc_feature_version_ids = batch['quant_ohlc_feature_version_ids'].to(device)
    graph_updater_links = batch['graph_updater_links']
    # 1a. Encode Tokens
    # --- FIXED: Check for a key that still exists ---
    if token_encoder_inputs['name_embed_indices'].numel() > 0:
        # --- NEW: Gather pre-computed embeddings and pass to encoder ---
        # --- CRITICAL FIX: Remove keys that are not part of the TokenEncoder's signature ---
        encoder_args = token_encoder_inputs.copy()
        encoder_args.pop('_addresses_for_lookup', None)  # This key is for the WalletEncoder
        encoder_args.pop('name_embed_indices', None)
        encoder_args.pop('symbol_embed_indices', None)
        encoder_args.pop('image_embed_indices', None)
        # --- SAFETY: Create a padded view of the embedding pool and map missing indices (-1) to pad ---
        if embedding_pool.numel() > 0:
            # Prepend an all-zero pad row at index 0.
            pad_row = torch.zeros(1, embedding_pool.size(1), device=device, dtype=embedding_pool.dtype)
            pool_padded = torch.cat([pad_row, embedding_pool], dim=0)
            def pad_and_lookup(idx_tensor: torch.Tensor) -> torch.Tensor:
                # Map valid indices >=0 to +1 (shift), invalid (<0) to 0 (pad)
                shifted = torch.where(idx_tensor >= 0, idx_tensor + 1, torch.zeros_like(idx_tensor))
                return F.embedding(shifted, pool_padded)
            name_embeds = pad_and_lookup(token_encoder_inputs['name_embed_indices'])
            symbol_embeds = pad_and_lookup(token_encoder_inputs['symbol_embed_indices'])
            image_embeds = pad_and_lookup(token_encoder_inputs['image_embed_indices'])
        else:
            # Empty pool: provide zeros with correct shapes
            n = token_encoder_inputs['name_embed_indices'].shape[0]
            d = self.multi_modal_dim
            zeros = torch.zeros(n, d, device=device, dtype=self.dtype)
            name_embeds = zeros
            symbol_embeds = zeros
            image_embeds = zeros
        batch_token_embeddings_unupd = self.token_encoder(
            name_embeds=name_embeds,
            symbol_embeds=symbol_embeds,
            image_embeds=image_embeds,
            # Pass all other keys like protocol_ids, is_vanity_flags, etc.
            **encoder_args
        )
    else:
        # No tokens in this batch: empty placeholder with the right width.
        batch_token_embeddings_unupd = torch.empty(0, self.token_encoder.output_dim, device=device, dtype=self.dtype)
    # 1b. Encode Wallets
    if wallet_encoder_inputs['profile_rows']:
        # Address -> token embedding map so the wallet encoder can attend
        # over the tokens encoded in 1a.
        temp_token_lookup = {
            addr: batch_token_embeddings_unupd[i]
            for i, addr in enumerate(batch['token_encoder_inputs']['_addresses_for_lookup'])  # Use helper key
        }
        initial_wallet_embeddings = self.wallet_encoder(
            **wallet_encoder_inputs,
            token_vibe_lookup=temp_token_lookup,
            embedding_pool=embedding_pool
        )
    else:
        initial_wallet_embeddings = torch.empty(0, self.wallet_encoder.d_model, device=device, dtype=self.dtype)
    # 1c. Encode OHLC (raw price charts and quant features independently).
    if ohlc_price_tensors.shape[0] > 0:
        raw_chart_embeddings = self.ohlc_embedder(ohlc_price_tensors, ohlc_interval_ids)
    else:
        raw_chart_embeddings = torch.empty(0, self.ohlc_embedder.output_dim, device=device, dtype=self.dtype)
    if quant_ohlc_feature_tensors.shape[0] > 0:
        quant_chart_embeddings = self.quant_ohlc_embedder(
            quant_ohlc_feature_tensors,
            quant_ohlc_feature_mask,
            quant_ohlc_feature_version_ids,
        )
    else:
        quant_chart_embeddings = torch.empty(0, self.quant_ohlc_embedder.output_dim, device=device, dtype=self.dtype)
    # Fuse both chart views; if only one side is present the other is
    # zero-filled so chart_fusion always sees both inputs.
    num_chart_segments = max(raw_chart_embeddings.shape[0], quant_chart_embeddings.shape[0])
    if num_chart_segments > 0:
        if raw_chart_embeddings.shape[0] == 0:
            raw_chart_embeddings = torch.zeros(
                num_chart_segments,
                self.ohlc_embedder.output_dim,
                device=device,
                dtype=self.dtype,
            )
        if quant_chart_embeddings.shape[0] == 0:
            quant_chart_embeddings = torch.zeros(
                num_chart_segments,
                self.quant_ohlc_embedder.output_dim,
                device=device,
                dtype=self.dtype,
            )
        interval_embeds = self.chart_interval_fusion_embedding(ohlc_interval_ids[:num_chart_segments]).to(self.dtype)
        batch_ohlc_embeddings_raw = self.chart_fusion(
            torch.cat([raw_chart_embeddings, quant_chart_embeddings, interval_embeds], dim=-1)
        )
    else:
        batch_ohlc_embeddings_raw = torch.empty(0, self.quant_ohlc_embedder.output_dim, device=device, dtype=self.dtype)
    # 1d. Run Graph Updater.
    # A learnable pad row is prepended at index 0 so link indices can use
    # 1-based entity positions; it is stripped again before returning.
    pad_wallet_raw = self.pad_wallet_emb.to(self.dtype)
    pad_token_raw = self.pad_token_emb.to(self.dtype)
    padded_wallet_tensor = torch.cat([pad_wallet_raw, initial_wallet_embeddings], dim=0)
    padded_token_tensor = torch.cat([pad_token_raw, batch_token_embeddings_unupd], dim=0)
    x_dict_initial = {}
    if padded_wallet_tensor.shape[0] > 1: x_dict_initial['wallet'] = padded_wallet_tensor
    if padded_token_tensor.shape[0] > 1: x_dict_initial['token'] = padded_token_tensor
    if x_dict_initial and graph_updater_links:
        final_entity_embeddings_dict = self.graph_updater(x_dict_initial, graph_updater_links)
        # Fall back to the pre-update tensors for any entity type the
        # graph updater did not return.
        final_padded_wallet_embs = final_entity_embeddings_dict.get('wallet', padded_wallet_tensor)
        final_padded_token_embs = final_entity_embeddings_dict.get('token', padded_token_tensor)
    else:
        final_padded_wallet_embs = padded_wallet_tensor
        final_padded_token_embs = padded_token_tensor
    # Strip padding before returning
    final_wallet_embeddings_raw = final_padded_wallet_embs[1:]
    final_token_embeddings_raw = final_padded_token_embs[1:]
    return {
        "wallet": final_wallet_embeddings_raw,
        "token": final_token_embeddings_raw,
        "ohlc": batch_ohlc_embeddings_raw
    }
def _project_and_gather_embeddings(self, raw_embeds: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Projects raw embeddings to d_model and gathers them into sequence-aligned tensors.

    ``raw_embeds`` comes from _run_dynamic_encoders / _run_snapshot_encoders.
    Every lookup table built here gets a (projected) learnable padding row at
    index 0, so the index tensors in ``batch`` (e.g. ``wallet_indices``) use
    0 for "no entity" and 1-based positions otherwise.

    Returns:
        Per-role gathered embedding tensors aligned to the event sequence,
        plus 'precomputed' — the full projected text/image lookup table
        (NOT gathered; callers index it themselves).
    """
    # Project raw embeddings to d_model
    final_wallet_proj = self.wallet_proj(raw_embeds['wallet'])
    final_token_proj = self.token_proj(raw_embeds['token'])
    final_ohlc_proj = self.ohlc_proj(raw_embeds['ohlc'])
    # Project padding embeddings to d_model (through the SAME projections,
    # so pad rows live in the same space as real entities).
    pad_wallet = self.wallet_proj(self.pad_wallet_emb.to(self.dtype))
    pad_token = self.token_proj(self.pad_token_emb.to(self.dtype))
    pad_ohlc = self.ohlc_proj(self.pad_ohlc_emb.to(self.dtype))
    pad_holder_snapshot = self.pad_holder_snapshot_emb.to(self.dtype)  # Already d_model
    # --- NEW: Project pre-computed embeddings and create lookup ---
    precomputed_pool = torch.nan_to_num(
        batch['embedding_pool'].to(self.device, self.dtype),
        nan=0.0,
        posinf=0.0,
        neginf=0.0
    )
    final_precomputed_proj = self.precomputed_proj(precomputed_pool)
    pad_precomputed = self.precomputed_proj(self.pad_precomputed_emb.to(self.dtype))
    final_precomputed_lookup = torch.cat([pad_precomputed, final_precomputed_proj], dim=0)
    # Create final lookup tables with padding at index 0
    final_wallet_lookup = torch.cat([pad_wallet, final_wallet_proj], dim=0)
    final_token_lookup = torch.cat([pad_token, final_token_proj], dim=0)
    final_ohlc_lookup = torch.cat([pad_ohlc, final_ohlc_proj], dim=0)
    # --- NEW: Add Role Embeddings ---
    # Additive role markers distinguishing how a token appears in an event
    # (main vs quote vs trending); the main role gets no additive term below.
    main_role_emb = self.token_role_embedding(torch.tensor(self.main_token_role_id, device=self.device))
    quote_role_emb = self.token_role_embedding(torch.tensor(self.quote_token_role_id, device=self.device))
    trending_role_emb = self.token_role_embedding(torch.tensor(self.trending_token_role_id, device=self.device))
    # Gather base embeddings (F.embedding = row gather by index).
    gathered_main_token_embs = F.embedding(batch['token_indices'], final_token_lookup)
    gathered_quote_token_embs = F.embedding(batch['quote_token_indices'], final_token_lookup)
    gathered_trending_token_embs = F.embedding(batch['trending_token_indices'], final_token_lookup)
    gathered_boosted_token_embs = F.embedding(batch['boosted_token_indices'], final_token_lookup)
    # --- NEW: Handle HolderSnapshot ---
    final_holder_snapshot_lookup = torch.cat([pad_holder_snapshot, raw_embeds['holder_snapshot']], dim=0)
    # Gather embeddings for each event in the sequence
    return {
        "wallet": F.embedding(batch['wallet_indices'], final_wallet_lookup),
        "token": gathered_main_token_embs,  # This is the baseline, no role needed
        "ohlc": F.embedding(batch['ohlc_indices'], final_ohlc_lookup),
        "original_author": F.embedding(batch['original_author_indices'], final_wallet_lookup),  # NEW
        "dest_wallet": F.embedding(batch['dest_wallet_indices'], final_wallet_lookup),  # Also gather dest wallet
        "quote_token": gathered_quote_token_embs + quote_role_emb,
        "trending_token": gathered_trending_token_embs + trending_role_emb,
        "boosted_token": gathered_boosted_token_embs + trending_role_emb,  # Same role as trending
        "holder_snapshot": F.embedding(batch['holder_snapshot_indices'], final_holder_snapshot_lookup),  # NEW
        "precomputed": final_precomputed_lookup  # NEW: Pass the full lookup table
    }
def _get_transfer_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Builds the Transfer/LargeTransfer-specific embedding term.

    Non-transfer sequence positions are zeroed out by the event-type mask.
    """
    # token_amount (0) and priority_fee (3) get signed-log scaling;
    # the two percentage features (1, 2) stay linear.
    numeric_term = self._normalize_and_project(
        batch['transfer_numerical_features'],
        self.transfer_num_norm,
        self.transfer_num_proj,
        log_indices=[0, 3],
    )
    transfer_ids = torch.tensor(
        [self.event_type_to_id.get('Transfer', -1),
         self.event_type_to_id.get('LargeTransfer', -1)],
        device=self.device,
    )
    is_transfer = torch.isin(batch['event_type_ids'], transfer_ids).unsqueeze(-1)
    # Destination wallet embedding + numeric features, masked elsewhere.
    return (gathered_embeds['dest_wallet'] + numeric_term) * is_transfer
def _get_trade_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Builds the Trade/LargeTrade-specific embedding term.

    Positions whose event type is neither Trade nor LargeTrade are zeroed.
    """
    # sol_amount (0), priority_fee (1) and total_usd (7) get signed-log
    # scaling; pcts, slippage, price impact and success flags stay linear.
    numeric_term = self._normalize_and_project(
        batch['trade_numerical_features'],
        self.trade_num_norm,
        self.trade_num_proj,
        log_indices=[0, 1, 7],
    )
    # Sum of the categorical embeddings: DEX platform, trade direction,
    # MEV-protection flag and bundle flag.
    categorical_term = (
        self.dex_platform_embedding(batch['trade_dex_ids'])
        + self.trade_direction_embedding(batch['trade_direction_ids'])
        + self.mev_protection_embedding(batch['trade_mev_protection_ids'])
        + self.is_bundle_embedding(batch['trade_is_bundle_ids'])
    )
    # One shared layer handles both generic and large trades.
    trade_ids = torch.tensor(
        [self.event_type_to_id.get(name, -1) for name in ('Trade', 'LargeTrade')],
        device=self.device,
    )
    is_trade = torch.isin(batch['event_type_ids'], trade_ids).unsqueeze(-1)
    return (numeric_term + categorical_term) * is_trade
def _get_deployer_trade_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Builds the Deployer_Trade-specific embedding term.

    Mirrors _get_trade_specific_embeddings but routes the numeric features
    through the dedicated deployer-trade layers; the categorical ID tensors
    (dex, direction, MEV, bundle) are shared across trade event kinds.
    """
    # sol_amount (0), priority_fee (1) and total_usd (7) get signed-log scaling.
    numeric_term = self._normalize_and_project(
        batch['deployer_trade_numerical_features'],
        self.deployer_trade_num_norm,
        self.deployer_trade_num_proj,
        log_indices=[0, 1, 7],
    )
    categorical_term = (
        self.dex_platform_embedding(batch['trade_dex_ids'])
        + self.trade_direction_embedding(batch['trade_direction_ids'])
        + self.mev_protection_embedding(batch['trade_mev_protection_ids'])
        + self.is_bundle_embedding(batch['trade_is_bundle_ids'])
    )
    is_deployer_trade = (
        batch['event_type_ids'] == self.event_type_to_id.get('Deployer_Trade', -1)
    ).unsqueeze(-1)
    return (numeric_term + categorical_term) * is_deployer_trade
def _get_smart_wallet_trade_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Builds the SmartWallet_Trade-specific embedding term.

    Mirrors _get_trade_specific_embeddings but routes the numeric features
    through the dedicated smart-wallet layers; the categorical ID tensors
    (dex, direction, MEV, bundle) are shared across trade event kinds.
    """
    # sol_amount (0), priority_fee (1) and total_usd (7) get signed-log scaling.
    numeric_term = self._normalize_and_project(
        batch['smart_wallet_trade_numerical_features'],
        self.smart_wallet_trade_num_norm,
        self.smart_wallet_trade_num_proj,
        log_indices=[0, 1, 7],
    )
    categorical_term = (
        self.dex_platform_embedding(batch['trade_dex_ids'])
        + self.trade_direction_embedding(batch['trade_direction_ids'])
        + self.mev_protection_embedding(batch['trade_mev_protection_ids'])
        + self.is_bundle_embedding(batch['trade_is_bundle_ids'])
    )
    is_smart_wallet_trade = (
        batch['event_type_ids'] == self.event_type_to_id.get('SmartWallet_Trade', -1)
    ).unsqueeze(-1)
    return (numeric_term + categorical_term) * is_smart_wallet_trade
def _get_pool_created_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Calculates the special embeddings for PoolCreated events.
"""
device = self.device
pool_created_numerical_features = batch['pool_created_numerical_features']
pool_created_protocol_ids = batch['pool_created_protocol_ids'] # NEW
event_type_ids = batch['event_type_ids']
# --- FIXED: Selectively log-scale features ---
# Log scale: base_amount (idx 0), quote_amount (idx 1)
# Linear scale: pcts (idx 2, 3)
projected_features = self._normalize_and_project(
pool_created_numerical_features, self.pool_created_num_norm, self.pool_created_num_proj, log_indices=[0, 1]
)
# --- NEW: Get embedding for the categorical protocol_id ---
protocol_id_embeds = self.protocol_embedding(pool_created_protocol_ids)
# Create mask for the event
mask = (event_type_ids == self.event_type_to_id.get('PoolCreated', -1)).unsqueeze(-1)
# Combine Quote Token embedding with projected numericals
return (gathered_embeds['quote_token'] + projected_features + protocol_id_embeds) * mask
def _get_liquidity_change_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Calculates the special embeddings for LiquidityChange events.
"""
device = self.device
liquidity_change_numerical_features = batch['liquidity_change_numerical_features']
liquidity_change_type_ids = batch['liquidity_change_type_ids'] # NEW
event_type_ids = batch['event_type_ids']
# --- FIXED: Selectively log-scale features ---
# Log scale: quote_amount (idx 0)
projected_features = self._normalize_and_project(
liquidity_change_numerical_features, self.liquidity_change_num_norm, self.liquidity_change_num_proj, log_indices=[0]
)
# --- NEW: Get embedding for the categorical change_type_id ---
change_type_embeds = self.liquidity_change_type_embedding(liquidity_change_type_ids)
# Create mask for the event
mask = (event_type_ids == self.event_type_to_id.get('LiquidityChange', -1)).unsqueeze(-1)
# Combine Quote Token embedding with projected numericals
return (gathered_embeds['quote_token'] + projected_features + change_type_embeds) * mask
def _get_fee_collected_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Calculates the special embeddings for FeeCollected events.
"""
device = self.device
fee_collected_numerical_features = batch['fee_collected_numerical_features']
event_type_ids = batch['event_type_ids']
# --- FIXED: Single amount, log-scale ---
projected_features = self._normalize_and_project(
fee_collected_numerical_features, self.fee_collected_num_norm, self.fee_collected_num_proj, log_indices=[0]
)
# Create mask for the event
mask = (event_type_ids == self.event_type_to_id.get('FeeCollected', -1)).unsqueeze(-1)
return projected_features * mask
def _get_token_burn_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Calculates the special embeddings for TokenBurn events.
"""
device = self.device
token_burn_numerical_features = batch['token_burn_numerical_features']
event_type_ids = batch['event_type_ids']
# --- FIXED: Selectively log-scale features ---
# Log scale: amount_tokens_burned (idx 1)
# Linear scale: amount_pct_of_total_supply (idx 0)
projected_features = self._normalize_and_project(
token_burn_numerical_features, self.token_burn_num_norm, self.token_burn_num_proj, log_indices=[1]
)
# Create mask for the event
mask = (event_type_ids == self.event_type_to_id.get('TokenBurn', -1)).unsqueeze(-1)
return projected_features * mask
def _get_supply_lock_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Calculates the special embeddings for SupplyLock events.
"""
device = self.device
supply_lock_numerical_features = batch['supply_lock_numerical_features']
event_type_ids = batch['event_type_ids']
# --- FIXED: Selectively log-scale features ---
# Log scale: lock_duration (idx 1)
# Linear scale: amount_pct_of_total_supply (idx 0)
projected_features = self._normalize_and_project(
supply_lock_numerical_features, self.supply_lock_num_norm, self.supply_lock_num_proj, log_indices=[1]
)
# Create mask for the event
mask = (event_type_ids == self.event_type_to_id.get('SupplyLock', -1)).unsqueeze(-1)
return projected_features * mask
def _get_onchain_snapshot_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Calculates the special embeddings for OnChain_Snapshot events.
"""
device = self.device
onchain_snapshot_numerical_features = batch['onchain_snapshot_numerical_features']
event_type_ids = batch['event_type_ids']
# --- FIXED: Selectively log-scale features ---
# Log scale: counts, market_cap, liquidity, volume, fees (almost all)
# Linear scale: growth_rate, holder_pcts (indices 3, 4, 5, 6, 7)
projected_features = self._normalize_and_project(
onchain_snapshot_numerical_features, self.onchain_snapshot_num_norm, self.onchain_snapshot_num_proj, log_indices=[0, 1, 2, 8, 9, 10, 11, 12, 13]
)
# Create mask for the event
mask = (event_type_ids == self.event_type_to_id.get('OnChain_Snapshot', -1)).unsqueeze(-1)
return projected_features * mask
def _get_trending_token_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Calculates the special embeddings for TrendingToken events.
"""
device = self.device
trending_token_numerical_features = batch['trending_token_numerical_features']
trending_token_source_ids = batch['trending_token_source_ids'] # NEW
trending_token_timeframe_ids = batch['trending_token_timeframe_ids'] # NEW
event_type_ids = batch['event_type_ids']
# --- FIXED: Rank is already inverted (0-1), so treat as linear ---
projected_features = self._normalize_and_project(
trending_token_numerical_features, self.trending_token_num_norm, self.trending_token_num_proj, log_indices=None
)
# --- NEW: Get embeddings for categorical IDs ---
source_embeds = self.trending_list_source_embedding(trending_token_source_ids)
timeframe_embeds = self.trending_timeframe_embedding(trending_token_timeframe_ids)
# Create mask for the event
mask = (event_type_ids == self.event_type_to_id.get('TrendingToken', -1)).unsqueeze(-1)
# Combine Trending Token embedding with its projected numericals
return (gathered_embeds['trending_token'] + projected_features + source_embeds + timeframe_embeds) * mask
def _get_boosted_token_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Calculates the special embeddings for BoostedToken events.
"""
device = self.device
boosted_token_numerical_features = batch['boosted_token_numerical_features']
event_type_ids = batch['event_type_ids']
# --- FIXED: Selectively log-scale features ---
# Log scale: total_boost_amount (idx 0)
# Linear scale: inverted rank (idx 1)
projected_features = self._normalize_and_project(
boosted_token_numerical_features, self.boosted_token_num_norm, self.boosted_token_num_proj, log_indices=[0]
)
# Create mask for the event
mask = (event_type_ids == self.event_type_to_id.get('BoostedToken', -1)).unsqueeze(-1)
# Combine Boosted Token embedding with its projected numericals
return (gathered_embeds['boosted_token'] + projected_features) * mask
def _get_dexboost_paid_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Calculates the special embeddings for DexBoost_Paid events.
"""
device = self.device
dexboost_paid_numerical_features = batch['dexboost_paid_numerical_features']
event_type_ids = batch['event_type_ids']
# --- FIXED: All features are amounts, so log-scale all ---
projected_features = self._normalize_and_project(
dexboost_paid_numerical_features, self.dexboost_paid_num_norm, self.dexboost_paid_num_proj, log_indices=[0, 1]
)
# Create mask for the event
mask = (event_type_ids == self.event_type_to_id.get('DexBoost_Paid', -1)).unsqueeze(-1)
return projected_features * mask
def _get_alphagroup_call_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Handles AlphaGroup_Call events by looking up the group_id embedding.
"""
device = self.device
group_ids = batch['alpha_group_ids']
event_type_ids = batch['event_type_ids']
# Look up the embedding for the group ID
group_embeds = self.alpha_group_embedding(group_ids)
# Create mask for the event
mask = (event_type_ids == self.event_type_to_id.get('AlphaGroup_Call', -1)).unsqueeze(-1)
return group_embeds * mask
def _get_channel_call_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Handles Channel_Call events by looking up the channel_id embedding.
"""
device = self.device
channel_ids = batch['channel_ids']
event_type_ids = batch['event_type_ids']
channel_embeds = self.call_channel_embedding(channel_ids)
mask = (event_type_ids == self.event_type_to_id.get('Channel_Call', -1)).unsqueeze(-1)
return channel_embeds * mask
def _get_cexlisting_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Handles CexListing events by looking up the exchange_id embedding.
"""
device = self.device
exchange_ids = batch['exchange_ids']
event_type_ids = batch['event_type_ids']
exchange_embeds = self.cex_listing_embedding(exchange_ids)
mask = (event_type_ids == self.event_type_to_id.get('CexListing', -1)).unsqueeze(-1)
return exchange_embeds * mask
def _get_chainsnapshot_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Handles ChainSnapshot events.
"""
device = self.device
numerical_features = batch['chainsnapshot_numerical_features']
event_type_ids = batch['event_type_ids']
# --- FIXED: All features are amounts/prices, so log-scale all ---
projected_features = self._normalize_and_project(
numerical_features, self.chainsnapshot_num_norm, self.chainsnapshot_num_proj, log_indices=[0, 1]
)
mask = (event_type_ids == self.event_type_to_id.get('ChainSnapshot', -1)).unsqueeze(-1)
return projected_features * mask
def _get_lighthousesnapshot_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Handles Lighthouse_Snapshot events.
"""
device = self.device
numerical_features = batch['lighthousesnapshot_numerical_features']
protocol_ids = batch['lighthousesnapshot_protocol_ids'] # NEW
timeframe_ids = batch['lighthousesnapshot_timeframe_ids'] # NEW
event_type_ids = batch['event_type_ids']
# --- FIXED: All features are counts/volumes, so log-scale all ---
projected_features = self._normalize_and_project(
numerical_features, self.lighthousesnapshot_num_norm, self.lighthousesnapshot_num_proj, log_indices=[0, 1, 2, 3, 4]
)
# --- NEW: Get embeddings for categorical IDs ---
# Re-use the main protocol embedding layer
protocol_embeds = self.protocol_embedding(protocol_ids)
timeframe_embeds = self.lighthouse_timeframe_embedding(timeframe_ids)
mask = (event_type_ids == self.event_type_to_id.get('Lighthouse_Snapshot', -1)).unsqueeze(-1)
return (projected_features + protocol_embeds + timeframe_embeds) * mask
def _get_migrated_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Handles Migrated events by looking up the protocol_id embedding.
"""
device = self.device
protocol_ids = batch['migrated_protocol_ids']
event_type_ids = batch['event_type_ids']
# Look up the embedding for the protocol ID
protocol_embeds = self.protocol_embedding(protocol_ids)
# Create mask for the event
mask = (event_type_ids == self.event_type_to_id.get('Migrated', -1)).unsqueeze(-1)
return protocol_embeds * mask
def _get_special_context_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
Handles special context tokens like 'MIDDLE' and 'RECENT' by adding their unique learnable embeddings.
"""
device = self.device
event_type_ids = batch['event_type_ids']
B, L = event_type_ids.shape
middle_id = self.event_type_to_id.get('MIDDLE', -1)
recent_id = self.event_type_to_id.get('RECENT', -1)
middle_mask = (event_type_ids == middle_id)
recent_mask = (event_type_ids == recent_id)
middle_emb = self.special_context_embedding(torch.tensor(self.special_context_tokens['MIDDLE'], device=device))
recent_emb = self.special_context_embedding(torch.tensor(self.special_context_tokens['RECENT'], device=device))
# Add the embeddings at the correct locations
return middle_mask.unsqueeze(-1) * middle_emb + recent_mask.unsqueeze(-1) * recent_emb
def _pool_hidden_states(self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor) -> torch.Tensor:
"""
Pools variable-length hidden states into a single embedding per sequence by
selecting the last non-masked token for each batch element.
"""
if hidden_states.size(0) == 0:
return torch.empty(0, self.d_model, device=hidden_states.device, dtype=hidden_states.dtype)
seq_lengths = attention_mask.long().sum(dim=1)
last_indices = torch.clamp(seq_lengths - 1, min=0)
batch_indices = torch.arange(hidden_states.size(0), device=hidden_states.device)
return hidden_states[batch_indices, last_indices]
def forward(self, batch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    """
    Encodes a batch of multi-modal event sequences and returns prediction logits.

    Pipeline: run the dynamic/snapshot entity encoders, gather their projected
    embeddings per sequence position, sum all per-event-type embedding
    components into `inputs_embeds`, run the decoder backbone, pool the last
    real token of each sequence, and apply the quantile / quality / movement
    heads.

    Args:
        batch: Collated batch dict; see the `_get_*_specific_embeddings`
            helpers for the keys each component consumes.

    Returns:
        Dict with 'quantile_logits', 'quality_logits', 'movement_logits',
        'pooled_states', 'hidden_states' and 'attention_mask'.
    """
    device = self.device
    # Unpack core sequence tensors
    # (NOTE: the previous dead assignment of batch['relative_ts'] to self.dtype
    # was removed; relative time is processed in float32 below for stability.)
    event_type_ids = batch['event_type_ids'].to(device)
    timestamps_float = batch['timestamps_float'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    B, L = event_type_ids.shape
    if B == 0 or L == 0:
        # Degenerate batch: return empty tensors with the correct trailing dims.
        print("Warning: Received empty batch in Oracle forward.")
        empty_hidden = torch.empty(0, L, self.d_model, device=device, dtype=self.dtype)
        empty_mask = torch.empty(0, L, device=device, dtype=torch.long)
        empty_quantiles = torch.empty(0, self.num_outputs, device=device, dtype=self.dtype)
        empty_quality = torch.empty(0, device=device, dtype=self.dtype)
        empty_movement = torch.empty(0, len(self.horizons_seconds), self.num_movement_classes, device=device, dtype=self.dtype)
        return {
            'quantile_logits': empty_quantiles,
            'quality_logits': empty_quality,
            'movement_logits': empty_movement,
            'pooled_states': torch.empty(0, self.d_model, device=device, dtype=self.dtype),
            'hidden_states': empty_hidden,
            'attention_mask': empty_mask
        }
    # === 1. Run Dynamic Encoders (produces graph-updated entity embeddings) ===
    dynamic_raw_embeds = self._run_dynamic_encoders(batch)
    # === 2. Run Snapshot Encoders (uses dynamic_raw_embeds) ===
    wallet_addr_to_batch_idx = batch['wallet_addr_to_batch_idx']
    snapshot_raw_embeds = self._run_snapshot_encoders(batch, dynamic_raw_embeds['wallet'], wallet_addr_to_batch_idx)
    # === 3. Project Raw Embeddings and Gather for Sequence ===
    raw_embeds = {**dynamic_raw_embeds, **snapshot_raw_embeds}
    gathered_embeds = self._project_and_gather_embeddings(raw_embeds, batch)
    # === 4. Assemble Final `inputs_embeds` ===
    event_embeds = self.event_type_embedding(event_type_ids)
    ts_embeds = self.time_proj(self.time_encoder(timestamps_float))
    # Stabilize relative time: minutes scale + signed log1p + LayerNorm before projection
    relative_ts_fp32 = batch['relative_ts'].to(device, torch.float32)
    rel_ts_minutes = relative_ts_fp32 / 60.0
    rel_ts_processed = torch.sign(rel_ts_minutes) * torch.log1p(torch.abs(rel_ts_minutes))
    # Match LayerNorm parameter dtype, then match Linear parameter dtype
    norm_dtype = self.rel_ts_norm.weight.dtype
    proj_dtype = self.rel_ts_proj.weight.dtype
    rel_ts_normed = self.rel_ts_norm(rel_ts_processed.to(norm_dtype))
    rel_ts_embeds = self.rel_ts_proj(rel_ts_normed.to(proj_dtype))
    # Per-event-type additive components; each helper masks itself to its own
    # event type, so summing them is safe.
    transfer_specific_embeds = self._get_transfer_specific_embeddings(batch, gathered_embeds)
    trade_specific_embeds = self._get_trade_specific_embeddings(batch)
    deployer_trade_specific_embeds = self._get_deployer_trade_specific_embeddings(batch)
    smart_wallet_trade_specific_embeds = self._get_smart_wallet_trade_specific_embeddings(batch)
    pool_created_specific_embeds = self._get_pool_created_specific_embeddings(batch, gathered_embeds)
    liquidity_change_specific_embeds = self._get_liquidity_change_specific_embeddings(batch, gathered_embeds)
    fee_collected_specific_embeds = self._get_fee_collected_specific_embeddings(batch)
    token_burn_specific_embeds = self._get_token_burn_specific_embeddings(batch)
    supply_lock_specific_embeds = self._get_supply_lock_specific_embeddings(batch)
    onchain_snapshot_specific_embeds = self._get_onchain_snapshot_specific_embeddings(batch)
    trending_token_specific_embeds = self._get_trending_token_specific_embeddings(batch, gathered_embeds)
    boosted_token_specific_embeds = self._get_boosted_token_specific_embeddings(batch, gathered_embeds)
    dexboost_paid_specific_embeds = self._get_dexboost_paid_specific_embeddings(batch)
    # Tracker events
    alphagroup_call_specific_embeds = self._get_alphagroup_call_specific_embeddings(batch)
    channel_call_specific_embeds = self._get_channel_call_specific_embeddings(batch)
    cexlisting_specific_embeds = self._get_cexlisting_specific_embeddings(batch)
    # Chain and Lighthouse snapshots
    chainsnapshot_specific_embeds = self._get_chainsnapshot_specific_embeddings(batch)
    lighthousesnapshot_specific_embeds = self._get_lighthousesnapshot_specific_embeddings(batch)
    migrated_specific_embeds = self._get_migrated_specific_embeddings(batch)
    # DexProfile_Updated flags are projected directly (no dedicated helper).
    dexprofile_updated_flags = batch['dexprofile_updated_flags']
    dexprofile_flags_embeds = self.dexprofile_updated_flags_proj(dexprofile_updated_flags.to(self.dtype))
    # All text-based events (social, dexprofile, global trending) are handled
    # by the SocialEncoder in a single batched call.
    textual_event_embeds = self.social_encoder(
        batch=batch,
        gathered_embeds=gathered_embeds
    )
    # Special context injection tokens (MIDDLE / RECENT markers).
    special_context_embeds = self._get_special_context_embeddings(batch)
    # --- Combine all features ---
    # Sum in float32 for numerical stability, then cast back to model dtype
    components = [
        event_embeds, ts_embeds, rel_ts_embeds,
        gathered_embeds['wallet'], gathered_embeds['token'], gathered_embeds['original_author'], gathered_embeds['ohlc'],
        transfer_specific_embeds, trade_specific_embeds, deployer_trade_specific_embeds, smart_wallet_trade_specific_embeds,
        pool_created_specific_embeds, liquidity_change_specific_embeds, fee_collected_specific_embeds,
        token_burn_specific_embeds, supply_lock_specific_embeds, onchain_snapshot_specific_embeds,
        trending_token_specific_embeds, boosted_token_specific_embeds, dexboost_paid_specific_embeds,
        alphagroup_call_specific_embeds, channel_call_specific_embeds, cexlisting_specific_embeds,
        migrated_specific_embeds, special_context_embeds, gathered_embeds['holder_snapshot'], textual_event_embeds,
        dexprofile_flags_embeds, chainsnapshot_specific_embeds, lighthousesnapshot_specific_embeds
    ]
    inputs_embeds = sum([t.float() for t in components]).to(self.dtype)
    hf_attention_mask = attention_mask.to(device=device, dtype=torch.long)
    outputs = self.model(
        inputs_embeds=inputs_embeds,
        attention_mask=hf_attention_mask,
        return_dict=True
    )
    sequence_hidden = outputs.last_hidden_state
    # Pool the last real token of each sequence for the prediction heads.
    pooled_states = self._pool_hidden_states(sequence_hidden, hf_attention_mask)
    quantile_logits = self.quantile_head(pooled_states)
    quality_logits = self.quality_head(pooled_states).squeeze(-1)
    # Reshape flat movement logits to (batch, horizons, classes).
    movement_logits = self.movement_head(pooled_states).view(
        pooled_states.shape[0],
        len(self.horizons_seconds),
        self.num_movement_classes,
    )
    return {
        'quantile_logits': quantile_logits,
        'quality_logits': quality_logits,
        'movement_logits': movement_logits,
        'pooled_states': pooled_states,
        'hidden_states': sequence_hidden,
        'attention_mask': hf_attention_mask
    }