"""Oracle: a multi-modal sequence model over on-chain and social token events.

Dedicated sub-encoders (token, wallet, graph, OHLC, holders, social) feed a
Llama-style transformer backbone with prediction heads for price quantiles,
a scalar quality score, and per-horizon movement classification.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, LlamaConfig
from typing import List, Dict, Any, Optional, Tuple
import os
import json

from models.helper_encoders import ContextualTimeEncoder
from models.token_encoder import TokenEncoder
from models.wallet_encoder import WalletEncoder
from models.graph_updater import GraphUpdater
from models.ohlc_embedder import OHLCEmbedder
from models.quant_ohlc_embedder import QuantOHLCEmbedder
from models.HoldersEncoder import HolderDistributionEncoder
from models.SocialEncoders import SocialEncoder
import models.vocabulary as vocab
from data.context_targets import MOVEMENT_CLASS_NAMES


class Oracle(nn.Module):
    """
    End-to-end multi-modal model for token price forecasting.

    Raw event features are encoded by the sub-encoders, projected to a shared
    d_model, summed into per-event input embeddings, and run through a
    Llama-style causal transformer. The last non-masked hidden state of each
    sequence drives three heads: price quantiles per horizon, a scalar quality
    logit, and per-horizon movement-class logits.
    """
    def __init__(self,
                 token_encoder: TokenEncoder,
                 wallet_encoder: WalletEncoder,
                 graph_updater: GraphUpdater,
                 ohlc_embedder: OHLCEmbedder,
                 quant_ohlc_embedder: QuantOHLCEmbedder,
                 time_encoder: ContextualTimeEncoder,
                 num_event_types: int,
                 multi_modal_dim: int,
                 event_pad_id: int,
                 event_type_to_id: Dict[str, int],
                 model_config_name: str = "llama3-12l-768d-gqa4-8k-random",
                 quantiles: List[float] = [0.1, 0.5, 0.9],
                 horizons_seconds: List[int] = [30, 60, 120, 240, 420],
                 dtype: torch.dtype = torch.bfloat16):

        super().__init__()

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.dtype = dtype
        self.multi_modal_dim = multi_modal_dim

        self.num_event_types = num_event_types
        self.event_pad_id = event_pad_id
        self.model_config_name = model_config_name
        self.quantiles = quantiles
        self.horizons_seconds = horizons_seconds
        self.num_outputs = len(quantiles) * len(horizons_seconds)
        self.num_movement_classes = len(MOVEMENT_CLASS_NAMES)
| |
| |
| |
| |
| |
| |
        # Backbone: a Llama-style decoder configured from scratch (no pretrained
        # weights). The attention implementation can be overridden via the
        # HF_ATTN_IMPL environment variable; default is PyTorch SDPA.
        attn_impl = os.getenv("HF_ATTN_IMPL", "sdpa")
        llama_cfg = LlamaConfig(
            hidden_size=768,
            intermediate_size=3072,
            num_hidden_layers=12,
            num_attention_heads=12,
            # Grouped-query attention: 12 query heads share 4 KV heads.
            num_key_value_heads=4,
            max_position_embeddings=8192,
            rope_theta=500000.0,
            rms_norm_eps=1e-5,
            # vocab_size is required by the config but unused: the model always
            # receives inputs_embeds, never token ids.
            vocab_size=32000,
        )
        self.d_model = llama_cfg.hidden_size

        # Older transformers versions do not accept attn_implementation in
        # from_config; fall back gracefully to the default (or to SDPA).
        try:
            self.model = AutoModel.from_config(llama_cfg, attn_implementation=attn_impl)
        except TypeError:
            self.model = AutoModel.from_config(llama_cfg)
        except Exception:
            if attn_impl != "sdpa":
                self.model = AutoModel.from_config(llama_cfg, attn_implementation="sdpa")
            else:
                raise

        if hasattr(self.model, "config"):
            self.model.config.use_cache = False
        self.model.to(self.device, dtype=self.dtype)
|
|
| |
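        # Prediction heads over the pooled transformer state:
        #   quantile_head -> len(quantiles) * len(horizons_seconds) price quantiles
        #   quality_head  -> a single scalar quality logit
        #   movement_head -> per-horizon movement-class logits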
| self.quantile_head = nn.Sequential( |
| nn.Linear(self.d_model, self.d_model), |
| nn.GELU(), |
| nn.Linear(self.d_model, self.num_outputs) |
| ) |
| self.quality_head = nn.Sequential( |
| nn.Linear(self.d_model, self.d_model), |
| nn.GELU(), |
| nn.Linear(self.d_model, 1) |
| ) |
| self.movement_head = nn.Sequential( |
| nn.Linear(self.d_model, self.d_model), |
| nn.GELU(), |
| nn.Linear(self.d_model, len(self.horizons_seconds) * self.num_movement_classes) |
| ) |
|
|
| self.event_type_to_id = event_type_to_id |
|
|
| |
| |
| self.token_roles = {'main': 0, 'quote': 1, 'trending': 2} |
| self.main_token_role_id = self.token_roles['main'] |
| self.quote_token_role_id = self.token_roles['quote'] |
| self.trending_token_role_id = self.token_roles['trending'] |
| |
| |
| self.token_encoder = token_encoder |
| self.wallet_encoder = wallet_encoder |
| self.graph_updater = graph_updater |
| self.ohlc_embedder = ohlc_embedder |
| self.quant_ohlc_embedder = quant_ohlc_embedder |
| self.time_encoder = time_encoder |
|
|
| self.social_encoder = SocialEncoder(d_model=self.d_model, dtype=self.dtype) |
| |
| |
| self.event_type_embedding = nn.Embedding(num_event_types, self.d_model, padding_idx=event_pad_id) |
| |
| |
| self.token_role_embedding = nn.Embedding(len(self.token_roles), self.d_model) |
|
|
|
|
|
|
| |
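        # Learnable "pad" rows, prepended to each embedding table so that index 0
        # in the gathered sequences can safely mean "no entity at this position".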
| self.pad_wallet_emb = nn.Parameter(torch.zeros(1, self.wallet_encoder.d_model)) |
| self.pad_token_emb = nn.Parameter(torch.zeros(1, self.token_encoder.output_dim)) |
| self.pad_ohlc_emb = nn.Parameter(torch.zeros(1, self.quant_ohlc_embedder.output_dim)) |
| self.pad_precomputed_emb = nn.Parameter(torch.zeros(1, self.multi_modal_dim)) |
| |
| |
| self.holder_dist_encoder = HolderDistributionEncoder( |
| wallet_embedding_dim=self.wallet_encoder.d_model, |
| output_dim=self.d_model, |
| dtype=self.dtype |
| ) |
| self.pad_holder_snapshot_emb = nn.Parameter(torch.zeros(1, self.d_model)) |
| |
| |
| self.time_proj = nn.Linear(self.time_encoder.projection.out_features, self.d_model) |
| self.rel_ts_proj = nn.Linear(1, self.d_model) |
| self.rel_ts_norm = nn.LayerNorm(1) |
| self.wallet_proj = nn.Linear(self.wallet_encoder.d_model, self.d_model) |
| self.token_proj = nn.Linear(self.token_encoder.output_dim, self.d_model) |
| self.ohlc_proj = nn.Linear(self.quant_ohlc_embedder.output_dim, self.d_model) |
| self.chart_interval_fusion_embedding = nn.Embedding(vocab.NUM_OHLC_INTERVALS, 32, padding_idx=0) |
| fusion_input_dim = self.ohlc_embedder.output_dim + self.quant_ohlc_embedder.output_dim + 32 |
| self.chart_fusion = nn.Sequential( |
| nn.Linear(fusion_input_dim, self.quant_ohlc_embedder.output_dim), |
| nn.GELU(), |
| nn.LayerNorm(self.quant_ohlc_embedder.output_dim), |
| nn.Linear(self.quant_ohlc_embedder.output_dim, self.quant_ohlc_embedder.output_dim), |
| nn.LayerNorm(self.quant_ohlc_embedder.output_dim), |
| ) |
| |
| |
| |
| |
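        # Per-event-type feature processors: each event family gets its own
        # LayerNorm + Linear projection into d_model for its numeric features,
        # plus categorical embeddings where applicable.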
| self.transfer_num_norm = nn.LayerNorm(4) |
| self.transfer_num_proj = nn.Linear(4, self.d_model) |
| |
| |
| |
| self.trade_num_norm = nn.LayerNorm(8) |
| self.trade_num_proj = nn.Linear(8, self.d_model) |
| |
| self.dex_platform_embedding = nn.Embedding(vocab.NUM_DEX_PLATFORMS, self.d_model) |
| |
| self.trade_direction_embedding = nn.Embedding(2, self.d_model) |
| |
| self.mev_protection_embedding = nn.Embedding(2, self.d_model) |
| |
| self.is_bundle_embedding = nn.Embedding(2, self.d_model) |
|
|
| |
| |
| self.deployer_trade_num_norm = nn.LayerNorm(8) |
| self.deployer_trade_num_proj = nn.Linear(8, self.d_model) |
|
|
| |
| |
| self.smart_wallet_trade_num_norm = nn.LayerNorm(8) |
| self.smart_wallet_trade_num_proj = nn.Linear(8, self.d_model) |
|
|
| |
| |
| self.pool_created_num_norm = nn.LayerNorm(2) |
| self.pool_created_num_proj = nn.Linear(2, self.d_model) |
|
|
| |
| |
| self.liquidity_change_num_norm = nn.LayerNorm(1) |
| self.liquidity_change_num_proj = nn.Linear(1, self.d_model) |
| |
| |
| self.liquidity_change_type_embedding = nn.Embedding(2, self.d_model) |
|
|
| |
| self.fee_collected_num_norm = nn.LayerNorm(1) |
| self.fee_collected_num_proj = nn.Linear(1, self.d_model) |
|
|
| |
| self.token_burn_num_norm = nn.LayerNorm(2) |
| self.token_burn_num_proj = nn.Linear(2, self.d_model) |
|
|
| |
| self.supply_lock_num_norm = nn.LayerNorm(2) |
| self.supply_lock_num_proj = nn.Linear(2, self.d_model) |
|
|
| |
| self.onchain_snapshot_num_norm = nn.LayerNorm(14) |
| self.onchain_snapshot_num_proj = nn.Linear(14, self.d_model) |
|
|
| |
| |
| self.trending_token_num_norm = nn.LayerNorm(1) |
| self.trending_token_num_proj = nn.Linear(1, self.d_model) |
| |
| self.trending_list_source_embedding = nn.Embedding(vocab.NUM_TRENDING_LIST_SOURCES, self.d_model) |
| self.trending_timeframe_embedding = nn.Embedding(vocab.NUM_TRENDING_LIST_TIMEFRAMES, self.d_model) |
|
|
| |
| self.boosted_token_num_norm = nn.LayerNorm(2) |
| self.boosted_token_num_proj = nn.Linear(2, self.d_model) |
| |
| |
| self.dexboost_paid_num_norm = nn.LayerNorm(2) |
| self.dexboost_paid_num_proj = nn.Linear(2, self.d_model) |
|
|
| |
| self.dexprofile_updated_flags_proj = nn.Linear(4, self.d_model) |
|
|
| |
| self.precomputed_proj = nn.Linear(self.multi_modal_dim, self.d_model) |
|
|
| |
| self.protocol_embedding = nn.Embedding(vocab.NUM_PROTOCOLS, self.d_model) |
|
|
| |
| |
| self.alpha_group_embedding = nn.Embedding(vocab.NUM_ALPHA_GROUPS, self.d_model) |
| self.call_channel_embedding = nn.Embedding(vocab.NUM_CALL_CHANNELS, self.d_model) |
| self.cex_listing_embedding = nn.Embedding(vocab.NUM_EXCHANGES, self.d_model) |
|
|
| |
| self.global_trending_num_norm = nn.LayerNorm(1) |
| self.global_trending_num_proj = nn.Linear(1, self.d_model) |
|
|
| |
| self.chainsnapshot_num_norm = nn.LayerNorm(2) |
| self.chainsnapshot_num_proj = nn.Linear(2, self.d_model) |
|
|
| |
| |
| self.lighthousesnapshot_num_norm = nn.LayerNorm(5) |
| self.lighthousesnapshot_num_proj = nn.Linear(5, self.d_model) |
| |
| self.lighthouse_timeframe_embedding = nn.Embedding(vocab.NUM_LIGHTHOUSE_TIMEFRAMES, self.d_model) |
|
|
| |
| |
| self.special_context_tokens = {'MIDDLE': 0, 'RECENT': 1} |
| self.special_context_embedding = nn.Embedding(len(self.special_context_tokens), self.d_model) |
|
|
|
|
| |
| |
|
|
| |
| self.to(dtype) |
| print("Oracle model (full pipeline) initialized.") |
|
|
| def save_pretrained(self, save_directory: str): |
| """ |
| Saves the model in a Hugging Face-compatible way. |
| """ |
| if not os.path.exists(save_directory): |
| os.makedirs(save_directory) |
|
|
| |
| |
| self.model.save_pretrained(save_directory) |
|
|
| |
| |
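        # Full state dict (backbone plus every encoder/head registered on this
        # module). This duplicates the backbone weights saved above, but gives
        # from_pretrained() a single-file load path.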
| torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin")) |
|
|
| |
| oracle_config = { |
| "num_event_types": self.num_event_types, |
| "multi_modal_dim": self.multi_modal_dim, |
| "event_pad_id": self.event_pad_id, |
| "model_config_name": self.model_config_name, |
| "quantiles": self.quantiles, |
| "horizons_seconds": self.horizons_seconds, |
| "dtype": str(self.dtype), |
| "event_type_to_id": self.event_type_to_id |
| } |
| with open(os.path.join(save_directory, "oracle_config.json"), "w") as f: |
| json.dump(oracle_config, f, indent=2) |
| |
| print(f"✅ Oracle model saved to {save_directory}") |
|
|
| @classmethod |
| def from_pretrained(cls, load_directory: str, |
| token_encoder, wallet_encoder, graph_updater, ohlc_embedder, quant_ohlc_embedder, time_encoder): |
| """ |
| Loads the Oracle model from a saved directory. |
| Note: You must still provide the initialized sub-encoders (or we can refactor to save them too). |
| """ |
| config_path = os.path.join(load_directory, "oracle_config.json") |
| with open(config_path, "r") as f: |
| config = json.load(f) |
| |
| |
| dtype = torch.bfloat16 |
| if "float32" in config["dtype"]: dtype = torch.float32 |
| elif "float16" in config["dtype"]: dtype = torch.float16 |
| |
| |
| model = cls( |
| token_encoder=token_encoder, |
| wallet_encoder=wallet_encoder, |
| graph_updater=graph_updater, |
| ohlc_embedder=ohlc_embedder, |
| quant_ohlc_embedder=quant_ohlc_embedder, |
| time_encoder=time_encoder, |
| num_event_types=config["num_event_types"], |
| multi_modal_dim=config["multi_modal_dim"], |
| event_pad_id=config["event_pad_id"], |
| event_type_to_id=config["event_type_to_id"], |
| model_config_name=config["model_config_name"], |
| quantiles=config["quantiles"], |
| horizons_seconds=config["horizons_seconds"], |
| dtype=dtype |
| ) |
| |
| |
| weight_path = os.path.join(load_directory, "pytorch_model.bin") |
| state_dict = torch.load(weight_path, map_location="cpu") |
| model.load_state_dict(state_dict) |
| print(f"✅ Oracle model loaded from {load_directory}") |
| return model |
|
|
| def _normalize_and_project(self, |
| features: torch.Tensor, |
| norm_layer: nn.LayerNorm, |
| proj_layer: nn.Linear, |
| log_indices: Optional[List[int]] = None) -> torch.Tensor: |
| """ |
| A helper function to selectively apply log scaling, then normalize and project. |
| """ |
| processed_features = torch.nan_to_num( |
| features.to(torch.float32), |
| nan=0.0, |
| posinf=1e6, |
| neginf=-1e6 |
| ) |
|
|
| |
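        # Signed log1p compresses heavy-tailed magnitudes while preserving sign:
        # x -> sign(x) * log(1 + |x|). Applied only to the selected columns.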
| if log_indices: |
| |
| valid_indices = [i for i in log_indices if i < processed_features.shape[-1]] |
| if valid_indices: |
| log_features = processed_features[:, :, valid_indices] |
| log_scaled = torch.sign(log_features) * torch.log1p(torch.abs(log_features)) |
| processed_features[:, :, valid_indices] = log_scaled |
|
|
| |
| norm_dtype = norm_layer.weight.dtype |
| proj_dtype = proj_layer.weight.dtype |
| normed_features = norm_layer(processed_features.to(norm_dtype)) |
| normed_features = torch.nan_to_num(normed_features, nan=0.0, posinf=0.0, neginf=0.0) |
| return proj_layer(normed_features.to(proj_dtype)) |
|
|
| def _run_snapshot_encoders(self, |
| batch: Dict[str, Any], |
| final_wallet_embeddings_raw: torch.Tensor, |
| wallet_addr_to_batch_idx: Dict[str, int]) -> Dict[str, torch.Tensor]: |
| """ |
| Runs snapshot-style encoders that process raw data into embeddings. |
| This is now truly end-to-end. |
| """ |
| device = self.device |
| all_holder_snapshot_embeds = [] |
| |
| |
| for raw_holder_list in batch['holder_snapshot_raw_data']: |
| processed_holder_data = [] |
| for holder in raw_holder_list: |
| wallet_addr = holder['wallet'] |
| |
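                # wallet_addr_to_batch_idx is 1-based here (0 means "not in batch"),
                # so shift by -1 when indexing into the raw wallet embeddings.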
| wallet_idx = wallet_addr_to_batch_idx.get(wallet_addr, 0) |
| if wallet_idx > 0: |
| wallet_embedding = final_wallet_embeddings_raw[wallet_idx - 1] |
| processed_holder_data.append({ |
| 'wallet_embedding': wallet_embedding, |
| 'pct': holder['holding_pct'] |
| }) |
| |
| all_holder_snapshot_embeds.append(self.holder_dist_encoder(processed_holder_data)) |
| |
| return {"holder_snapshot": torch.cat(all_holder_snapshot_embeds, dim=0) if all_holder_snapshot_embeds else torch.empty(0, self.d_model, device=device, dtype=self.dtype)} |
|
|
|
|
| def _run_dynamic_encoders(self, batch: Dict[str, Any]) -> Dict[str, torch.Tensor]: |
| """ |
| Runs all dynamic encoders and returns a dictionary of raw, unprojected embeddings. |
| """ |
| device = self.device |
| |
| token_encoder_inputs = batch['token_encoder_inputs'] |
| wallet_encoder_inputs = batch['wallet_encoder_inputs'] |
| |
| embedding_pool = torch.nan_to_num( |
| batch['embedding_pool'].to(device, self.dtype), |
| nan=0.0, |
| posinf=0.0, |
| neginf=0.0 |
| ) |
|
|
| ohlc_price_tensors = torch.nan_to_num( |
| batch['ohlc_price_tensors'].to(device, self.dtype), |
| nan=0.0, |
| posinf=0.0, |
| neginf=0.0 |
| ) |
| ohlc_interval_ids = batch['ohlc_interval_ids'].to(device) |
| quant_ohlc_feature_tensors = torch.nan_to_num( |
| batch['quant_ohlc_feature_tensors'].to(device, self.dtype), |
| nan=0.0, |
| posinf=0.0, |
| neginf=0.0 |
| ) |
| quant_ohlc_feature_mask = batch['quant_ohlc_feature_mask'].to(device) |
| quant_ohlc_feature_version_ids = batch['quant_ohlc_feature_version_ids'].to(device) |
| graph_updater_links = batch['graph_updater_links'] |
|
|
| |
| |
| if token_encoder_inputs['name_embed_indices'].numel() > 0: |
| |
| |
| encoder_args = token_encoder_inputs.copy() |
| encoder_args.pop('_addresses_for_lookup', None) |
| encoder_args.pop('name_embed_indices', None) |
| encoder_args.pop('symbol_embed_indices', None) |
| encoder_args.pop('image_embed_indices', None) |
|
|
| |
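            # Precomputed embeddings (token name / symbol / image) live in a shared
            # pool; an index of -1 means "missing". Prepend a zero pad row and shift
            # valid indices by +1 so that -1 maps to the pad row during lookup.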
| if embedding_pool.numel() > 0: |
| pad_row = torch.zeros(1, embedding_pool.size(1), device=device, dtype=embedding_pool.dtype) |
| pool_padded = torch.cat([pad_row, embedding_pool], dim=0) |
| def pad_and_lookup(idx_tensor: torch.Tensor) -> torch.Tensor: |
| |
| shifted = torch.where(idx_tensor >= 0, idx_tensor + 1, torch.zeros_like(idx_tensor)) |
| return F.embedding(shifted, pool_padded) |
| name_embeds = pad_and_lookup(token_encoder_inputs['name_embed_indices']) |
| symbol_embeds = pad_and_lookup(token_encoder_inputs['symbol_embed_indices']) |
| image_embeds = pad_and_lookup(token_encoder_inputs['image_embed_indices']) |
| else: |
| |
| n = token_encoder_inputs['name_embed_indices'].shape[0] |
| d = self.multi_modal_dim |
| zeros = torch.zeros(n, d, device=device, dtype=self.dtype) |
| name_embeds = zeros |
| symbol_embeds = zeros |
| image_embeds = zeros |
|
|
| batch_token_embeddings_unupd = self.token_encoder( |
| name_embeds=name_embeds, |
| symbol_embeds=symbol_embeds, |
| image_embeds=image_embeds, |
| |
| **encoder_args |
| ) |
| else: |
| batch_token_embeddings_unupd = torch.empty(0, self.token_encoder.output_dim, device=device, dtype=self.dtype) |
|
|
| |
| if wallet_encoder_inputs['profile_rows']: |
| temp_token_lookup = { |
| addr: batch_token_embeddings_unupd[i] |
| for i, addr in enumerate(batch['token_encoder_inputs']['_addresses_for_lookup']) |
| } |
| initial_wallet_embeddings = self.wallet_encoder( |
| **wallet_encoder_inputs, |
| token_vibe_lookup=temp_token_lookup, |
| embedding_pool=embedding_pool |
| ) |
| else: |
| initial_wallet_embeddings = torch.empty(0, self.wallet_encoder.d_model, device=device, dtype=self.dtype) |
|
|
| |
| if ohlc_price_tensors.shape[0] > 0: |
| raw_chart_embeddings = self.ohlc_embedder(ohlc_price_tensors, ohlc_interval_ids) |
| else: |
| raw_chart_embeddings = torch.empty(0, self.ohlc_embedder.output_dim, device=device, dtype=self.dtype) |
| if quant_ohlc_feature_tensors.shape[0] > 0: |
| quant_chart_embeddings = self.quant_ohlc_embedder( |
| quant_ohlc_feature_tensors, |
| quant_ohlc_feature_mask, |
| quant_ohlc_feature_version_ids, |
| ) |
| else: |
| quant_chart_embeddings = torch.empty(0, self.quant_ohlc_embedder.output_dim, device=device, dtype=self.dtype) |
| num_chart_segments = max(raw_chart_embeddings.shape[0], quant_chart_embeddings.shape[0]) |
| if num_chart_segments > 0: |
| if raw_chart_embeddings.shape[0] == 0: |
| raw_chart_embeddings = torch.zeros( |
| num_chart_segments, |
| self.ohlc_embedder.output_dim, |
| device=device, |
| dtype=self.dtype, |
| ) |
| if quant_chart_embeddings.shape[0] == 0: |
| quant_chart_embeddings = torch.zeros( |
| num_chart_segments, |
| self.quant_ohlc_embedder.output_dim, |
| device=device, |
| dtype=self.dtype, |
| ) |
| interval_embeds = self.chart_interval_fusion_embedding(ohlc_interval_ids[:num_chart_segments]).to(self.dtype) |
| batch_ohlc_embeddings_raw = self.chart_fusion( |
| torch.cat([raw_chart_embeddings, quant_chart_embeddings, interval_embeds], dim=-1) |
| ) |
| else: |
| batch_ohlc_embeddings_raw = torch.empty(0, self.quant_ohlc_embedder.output_dim, device=device, dtype=self.dtype) |
|
|
| |
| pad_wallet_raw = self.pad_wallet_emb.to(self.dtype) |
| pad_token_raw = self.pad_token_emb.to(self.dtype) |
| padded_wallet_tensor = torch.cat([pad_wallet_raw, initial_wallet_embeddings], dim=0) |
| padded_token_tensor = torch.cat([pad_token_raw, batch_token_embeddings_unupd], dim=0) |
|
|
| x_dict_initial = {} |
| if padded_wallet_tensor.shape[0] > 1: x_dict_initial['wallet'] = padded_wallet_tensor |
| if padded_token_tensor.shape[0] > 1: x_dict_initial['token'] = padded_token_tensor |
|
|
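        # Refine wallet/token embeddings with the heterogeneous graph updater when
        # both node features and links are present; otherwise pass them through.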
| if x_dict_initial and graph_updater_links: |
| final_entity_embeddings_dict = self.graph_updater(x_dict_initial, graph_updater_links) |
| final_padded_wallet_embs = final_entity_embeddings_dict.get('wallet', padded_wallet_tensor) |
| final_padded_token_embs = final_entity_embeddings_dict.get('token', padded_token_tensor) |
| else: |
| final_padded_wallet_embs = padded_wallet_tensor |
| final_padded_token_embs = padded_token_tensor |
|
|
| |
| final_wallet_embeddings_raw = final_padded_wallet_embs[1:] |
| final_token_embeddings_raw = final_padded_token_embs[1:] |
|
|
| return { |
| "wallet": final_wallet_embeddings_raw, |
| "token": final_token_embeddings_raw, |
| "ohlc": batch_ohlc_embeddings_raw |
| } |
|
|
| def _project_and_gather_embeddings(self, raw_embeds: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: |
| """ |
| Projects raw embeddings to d_model and gathers them into sequence-aligned tensors. |
| """ |
| |
| final_wallet_proj = self.wallet_proj(raw_embeds['wallet']) |
| final_token_proj = self.token_proj(raw_embeds['token']) |
| final_ohlc_proj = self.ohlc_proj(raw_embeds['ohlc']) |
|
|
| |
| pad_wallet = self.wallet_proj(self.pad_wallet_emb.to(self.dtype)) |
| pad_token = self.token_proj(self.pad_token_emb.to(self.dtype)) |
| pad_ohlc = self.ohlc_proj(self.pad_ohlc_emb.to(self.dtype)) |
| pad_holder_snapshot = self.pad_holder_snapshot_emb.to(self.dtype) |
| |
| |
| precomputed_pool = torch.nan_to_num( |
| batch['embedding_pool'].to(self.device, self.dtype), |
| nan=0.0, |
| posinf=0.0, |
| neginf=0.0 |
| ) |
| final_precomputed_proj = self.precomputed_proj(precomputed_pool) |
| pad_precomputed = self.precomputed_proj(self.pad_precomputed_emb.to(self.dtype)) |
| final_precomputed_lookup = torch.cat([pad_precomputed, final_precomputed_proj], dim=0) |
|
|
| |
| final_wallet_lookup = torch.cat([pad_wallet, final_wallet_proj], dim=0) |
| final_token_lookup = torch.cat([pad_token, final_token_proj], dim=0) |
| final_ohlc_lookup = torch.cat([pad_ohlc, final_ohlc_proj], dim=0) |
|
|
|
|
| |
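        # Role embeddings distinguish how a token appears in an event
        # (main subject vs. quote leg vs. trending/boosted mention).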
| main_role_emb = self.token_role_embedding(torch.tensor(self.main_token_role_id, device=self.device)) |
| quote_role_emb = self.token_role_embedding(torch.tensor(self.quote_token_role_id, device=self.device)) |
| trending_role_emb = self.token_role_embedding(torch.tensor(self.trending_token_role_id, device=self.device)) |
|
|
| |
| gathered_main_token_embs = F.embedding(batch['token_indices'], final_token_lookup) |
| gathered_quote_token_embs = F.embedding(batch['quote_token_indices'], final_token_lookup) |
| gathered_trending_token_embs = F.embedding(batch['trending_token_indices'], final_token_lookup) |
| gathered_boosted_token_embs = F.embedding(batch['boosted_token_indices'], final_token_lookup) |
|
|
| |
| final_holder_snapshot_lookup = torch.cat([pad_holder_snapshot, raw_embeds['holder_snapshot']], dim=0) |
|
|
| |
| return { |
| "wallet": F.embedding(batch['wallet_indices'], final_wallet_lookup), |
| "token": gathered_main_token_embs, |
| "ohlc": F.embedding(batch['ohlc_indices'], final_ohlc_lookup), |
| "original_author": F.embedding(batch['original_author_indices'], final_wallet_lookup), |
| "dest_wallet": F.embedding(batch['dest_wallet_indices'], final_wallet_lookup), |
| "quote_token": gathered_quote_token_embs + quote_role_emb, |
| "trending_token": gathered_trending_token_embs + trending_role_emb, |
| "boosted_token": gathered_boosted_token_embs + trending_role_emb, |
| "holder_snapshot": F.embedding(batch['holder_snapshot_indices'], final_holder_snapshot_lookup), |
| "precomputed": final_precomputed_lookup |
| } |
|
|
| def _get_transfer_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor: |
| """ |
| Calculates the special embeddings for Transfer/LargeTransfer events. |
| """ |
| device = self.device |
| transfer_numerical_features = batch['transfer_numerical_features'] |
| event_type_ids = batch['event_type_ids'] |
|
|
| |
| |
| |
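        # Project the transfer numeric features (log-scaling the heavy-tailed
        # columns), then zero out every position that is not a Transfer /
        # LargeTransfer event.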
| projected_transfer_features = self._normalize_and_project( |
| transfer_numerical_features, self.transfer_num_norm, self.transfer_num_proj, log_indices=[0, 3] |
| ) |
| |
| transfer_event_ids = [self.event_type_to_id.get('Transfer', -1), self.event_type_to_id.get('LargeTransfer', -1)] |
| transfer_mask = torch.isin(event_type_ids, torch.tensor(transfer_event_ids, device=device)).unsqueeze(-1) |
|
|
| |
| return (gathered_embeds['dest_wallet'] + projected_transfer_features) * transfer_mask |
|
|
| def _get_trade_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: |
| """ |
| Calculates the special embeddings for Trade events. |
| """ |
| device = self.device |
| trade_numerical_features = batch['trade_numerical_features'] |
| trade_dex_ids = batch['trade_dex_ids'] |
| trade_direction_ids = batch['trade_direction_ids'] |
| trade_mev_protection_ids = batch['trade_mev_protection_ids'] |
| trade_is_bundle_ids = batch['trade_is_bundle_ids'] |
| event_type_ids = batch['event_type_ids'] |
|
|
| |
| |
| |
| projected_trade_features = self._normalize_and_project( |
| trade_numerical_features, self.trade_num_norm, self.trade_num_proj, log_indices=[0, 1, 7] |
| ) |
|
|
| |
| trade_event_names = ['Trade', 'LargeTrade'] |
| trade_event_ids = [self.event_type_to_id.get(name, -1) for name in trade_event_names] |
| |
| |
| trade_mask = torch.isin(event_type_ids, torch.tensor(trade_event_ids, device=device)).unsqueeze(-1) |
|
|
| |
| dex_id_embeds = self.dex_platform_embedding(trade_dex_ids) |
| direction_embeds = self.trade_direction_embedding(trade_direction_ids) |
| mev_embeds = self.mev_protection_embedding(trade_mev_protection_ids) |
| bundle_embeds = self.is_bundle_embedding(trade_is_bundle_ids) |
|
|
| return (projected_trade_features + dex_id_embeds + direction_embeds + mev_embeds + bundle_embeds) * trade_mask |
|
|
| def _get_deployer_trade_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: |
| """ |
| Calculates the special embeddings for Deployer_Trade events using its own layers. |
| """ |
| device = self.device |
| deployer_trade_numerical_features = batch['deployer_trade_numerical_features'] |
| trade_dex_ids = batch['trade_dex_ids'] |
| trade_direction_ids = batch['trade_direction_ids'] |
| trade_mev_protection_ids = batch['trade_mev_protection_ids'] |
| trade_is_bundle_ids = batch['trade_is_bundle_ids'] |
| event_type_ids = batch['event_type_ids'] |
|
|
| |
| |
| projected_deployer_trade_features = self._normalize_and_project( |
| deployer_trade_numerical_features, self.deployer_trade_num_norm, self.deployer_trade_num_proj, log_indices=[0, 1, 7] |
| ) |
| |
| dex_id_embeds = self.dex_platform_embedding(trade_dex_ids) |
| direction_embeds = self.trade_direction_embedding(trade_direction_ids) |
| mev_embeds = self.mev_protection_embedding(trade_mev_protection_ids) |
| bundle_embeds = self.is_bundle_embedding(trade_is_bundle_ids) |
|
|
| deployer_trade_mask = (event_type_ids == self.event_type_to_id.get('Deployer_Trade', -1)).unsqueeze(-1) |
| return (projected_deployer_trade_features + dex_id_embeds + direction_embeds + mev_embeds + bundle_embeds) * deployer_trade_mask |
|
|
| def _get_smart_wallet_trade_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: |
| """ |
| Calculates the special embeddings for SmartWallet_Trade events using its own layers. |
| """ |
| device = self.device |
| smart_wallet_trade_numerical_features = batch['smart_wallet_trade_numerical_features'] |
| trade_dex_ids = batch['trade_dex_ids'] |
| trade_direction_ids = batch['trade_direction_ids'] |
| trade_mev_protection_ids = batch['trade_mev_protection_ids'] |
| trade_is_bundle_ids = batch['trade_is_bundle_ids'] |
| event_type_ids = batch['event_type_ids'] |
|
|
| |
| |
| projected_features = self._normalize_and_project( |
| smart_wallet_trade_numerical_features, self.smart_wallet_trade_num_norm, self.smart_wallet_trade_num_proj, log_indices=[0, 1, 7] |
| ) |
|
|
| dex_id_embeds = self.dex_platform_embedding(trade_dex_ids) |
| direction_embeds = self.trade_direction_embedding(trade_direction_ids) |
| mev_embeds = self.mev_protection_embedding(trade_mev_protection_ids) |
| bundle_embeds = self.is_bundle_embedding(trade_is_bundle_ids) |
|
|
| mask = (event_type_ids == self.event_type_to_id.get('SmartWallet_Trade', -1)).unsqueeze(-1) |
| return (projected_features + dex_id_embeds + direction_embeds + mev_embeds + bundle_embeds) * mask |
|
|
| def _get_pool_created_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor: |
| """ |
| Calculates the special embeddings for PoolCreated events. |
| """ |
| device = self.device |
| pool_created_numerical_features = batch['pool_created_numerical_features'] |
| pool_created_protocol_ids = batch['pool_created_protocol_ids'] |
| event_type_ids = batch['event_type_ids'] |
|
|
| |
| |
| |
| projected_features = self._normalize_and_project( |
| pool_created_numerical_features, self.pool_created_num_norm, self.pool_created_num_proj, log_indices=[0, 1] |
| ) |
| |
| protocol_id_embeds = self.protocol_embedding(pool_created_protocol_ids) |
|
|
| |
| mask = (event_type_ids == self.event_type_to_id.get('PoolCreated', -1)).unsqueeze(-1) |
|
|
| |
| return (gathered_embeds['quote_token'] + projected_features + protocol_id_embeds) * mask |
|
|
| def _get_liquidity_change_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor: |
| """ |
| Calculates the special embeddings for LiquidityChange events. |
| """ |
| device = self.device |
| liquidity_change_numerical_features = batch['liquidity_change_numerical_features'] |
| liquidity_change_type_ids = batch['liquidity_change_type_ids'] |
| event_type_ids = batch['event_type_ids'] |
|
|
| |
| |
| projected_features = self._normalize_and_project( |
| liquidity_change_numerical_features, self.liquidity_change_num_norm, self.liquidity_change_num_proj, log_indices=[0] |
| ) |
| |
| change_type_embeds = self.liquidity_change_type_embedding(liquidity_change_type_ids) |
|
|
| |
| mask = (event_type_ids == self.event_type_to_id.get('LiquidityChange', -1)).unsqueeze(-1) |
|
|
| |
| return (gathered_embeds['quote_token'] + projected_features + change_type_embeds) * mask |
|
|
| def _get_fee_collected_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: |
| """ |
| Calculates the special embeddings for FeeCollected events. |
| """ |
| device = self.device |
| fee_collected_numerical_features = batch['fee_collected_numerical_features'] |
| event_type_ids = batch['event_type_ids'] |
|
|
| |
| projected_features = self._normalize_and_project( |
| fee_collected_numerical_features, self.fee_collected_num_norm, self.fee_collected_num_proj, log_indices=[0] |
| ) |
|
|
| |
| mask = (event_type_ids == self.event_type_to_id.get('FeeCollected', -1)).unsqueeze(-1) |
|
|
| return projected_features * mask |
|
|
| def _get_token_burn_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: |
| """ |
| Calculates the special embeddings for TokenBurn events. |
| """ |
| device = self.device |
| token_burn_numerical_features = batch['token_burn_numerical_features'] |
| event_type_ids = batch['event_type_ids'] |
|
|
| |
| |
| |
| projected_features = self._normalize_and_project( |
| token_burn_numerical_features, self.token_burn_num_norm, self.token_burn_num_proj, log_indices=[1] |
| ) |
| |
| mask = (event_type_ids == self.event_type_to_id.get('TokenBurn', -1)).unsqueeze(-1) |
|
|
| return projected_features * mask |
|
|
| def _get_supply_lock_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: |
| """ |
| Calculates the special embeddings for SupplyLock events. |
| """ |
| device = self.device |
| supply_lock_numerical_features = batch['supply_lock_numerical_features'] |
| event_type_ids = batch['event_type_ids'] |
|
|
| |
| |
| |
| projected_features = self._normalize_and_project( |
| supply_lock_numerical_features, self.supply_lock_num_norm, self.supply_lock_num_proj, log_indices=[1] |
| ) |
| |
| mask = (event_type_ids == self.event_type_to_id.get('SupplyLock', -1)).unsqueeze(-1) |
|
|
| return projected_features * mask |
|
|
| def _get_onchain_snapshot_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: |
| """ |
| Calculates the special embeddings for OnChain_Snapshot events. |
| """ |
| device = self.device |
| onchain_snapshot_numerical_features = batch['onchain_snapshot_numerical_features'] |
| event_type_ids = batch['event_type_ids'] |
|
|
| |
| |
| |
| projected_features = self._normalize_and_project( |
| onchain_snapshot_numerical_features, self.onchain_snapshot_num_norm, self.onchain_snapshot_num_proj, log_indices=[0, 1, 2, 8, 9, 10, 11, 12, 13] |
| ) |
| |
| mask = (event_type_ids == self.event_type_to_id.get('OnChain_Snapshot', -1)).unsqueeze(-1) |
|
|
| return projected_features * mask |
|
|
| def _get_trending_token_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor: |
| """ |
| Calculates the special embeddings for TrendingToken events. |
| """ |
| device = self.device |
| trending_token_numerical_features = batch['trending_token_numerical_features'] |
| trending_token_source_ids = batch['trending_token_source_ids'] |
| trending_token_timeframe_ids = batch['trending_token_timeframe_ids'] |
| event_type_ids = batch['event_type_ids'] |
|
|
| |
| projected_features = self._normalize_and_project( |
| trending_token_numerical_features, self.trending_token_num_norm, self.trending_token_num_proj, log_indices=None |
| ) |
| |
| |
| source_embeds = self.trending_list_source_embedding(trending_token_source_ids) |
| timeframe_embeds = self.trending_timeframe_embedding(trending_token_timeframe_ids) |
|
|
| |
| mask = (event_type_ids == self.event_type_to_id.get('TrendingToken', -1)).unsqueeze(-1) |
|
|
| |
| return (gathered_embeds['trending_token'] + projected_features + source_embeds + timeframe_embeds) * mask |
|
|
| def _get_boosted_token_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor: |
| """ |
| Calculates the special embeddings for BoostedToken events. |
| """ |
| device = self.device |
| boosted_token_numerical_features = batch['boosted_token_numerical_features'] |
| event_type_ids = batch['event_type_ids'] |
|
|
| |
| |
| |
| projected_features = self._normalize_and_project( |
| boosted_token_numerical_features, self.boosted_token_num_norm, self.boosted_token_num_proj, log_indices=[0] |
| ) |
| |
| mask = (event_type_ids == self.event_type_to_id.get('BoostedToken', -1)).unsqueeze(-1) |
|
|
| |
| return (gathered_embeds['boosted_token'] + projected_features) * mask |
|
|
| def _get_dexboost_paid_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: |
| """ |
| Calculates the special embeddings for DexBoost_Paid events. |
| """ |
| device = self.device |
| dexboost_paid_numerical_features = batch['dexboost_paid_numerical_features'] |
| event_type_ids = batch['event_type_ids'] |
|
|
| |
| projected_features = self._normalize_and_project( |
| dexboost_paid_numerical_features, self.dexboost_paid_num_norm, self.dexboost_paid_num_proj, log_indices=[0, 1] |
| ) |
| |
| mask = (event_type_ids == self.event_type_to_id.get('DexBoost_Paid', -1)).unsqueeze(-1) |
|
|
| return projected_features * mask |
|
|
| def _get_alphagroup_call_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: |
| """ |
| Handles AlphaGroup_Call events by looking up the group_id embedding. |
| """ |
| device = self.device |
| group_ids = batch['alpha_group_ids'] |
| event_type_ids = batch['event_type_ids'] |
|
|
| |
| group_embeds = self.alpha_group_embedding(group_ids) |
|
|
| |
| mask = (event_type_ids == self.event_type_to_id.get('AlphaGroup_Call', -1)).unsqueeze(-1) |
| return group_embeds * mask |
|
|
| def _get_channel_call_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: |
| """ |
| Handles Channel_Call events by looking up the channel_id embedding. |
| """ |
| device = self.device |
| channel_ids = batch['channel_ids'] |
| event_type_ids = batch['event_type_ids'] |
|
|
| channel_embeds = self.call_channel_embedding(channel_ids) |
| mask = (event_type_ids == self.event_type_to_id.get('Channel_Call', -1)).unsqueeze(-1) |
| return channel_embeds * mask |
|
|
| def _get_cexlisting_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: |
| """ |
| Handles CexListing events by looking up the exchange_id embedding. |
| """ |
| device = self.device |
| exchange_ids = batch['exchange_ids'] |
| event_type_ids = batch['event_type_ids'] |
|
|
| exchange_embeds = self.cex_listing_embedding(exchange_ids) |
| mask = (event_type_ids == self.event_type_to_id.get('CexListing', -1)).unsqueeze(-1) |
| return exchange_embeds * mask |
|
|
| def _get_chainsnapshot_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: |
| """ |
| Handles ChainSnapshot events. |
| """ |
| device = self.device |
| numerical_features = batch['chainsnapshot_numerical_features'] |
| event_type_ids = batch['event_type_ids'] |
|
|
| |
| projected_features = self._normalize_and_project( |
| numerical_features, self.chainsnapshot_num_norm, self.chainsnapshot_num_proj, log_indices=[0, 1] |
| ) |
| mask = (event_type_ids == self.event_type_to_id.get('ChainSnapshot', -1)).unsqueeze(-1) |
| return projected_features * mask |
|
|
| def _get_lighthousesnapshot_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: |
| """ |
| Handles Lighthouse_Snapshot events. |
| """ |
| device = self.device |
| numerical_features = batch['lighthousesnapshot_numerical_features'] |
| protocol_ids = batch['lighthousesnapshot_protocol_ids'] |
| timeframe_ids = batch['lighthousesnapshot_timeframe_ids'] |
| event_type_ids = batch['event_type_ids'] |
|
|
| |
| projected_features = self._normalize_and_project( |
| numerical_features, self.lighthousesnapshot_num_norm, self.lighthousesnapshot_num_proj, log_indices=[0, 1, 2, 3, 4] |
| ) |
| |
| |
| protocol_embeds = self.protocol_embedding(protocol_ids) |
| timeframe_embeds = self.lighthouse_timeframe_embedding(timeframe_ids) |
|
|
| mask = (event_type_ids == self.event_type_to_id.get('Lighthouse_Snapshot', -1)).unsqueeze(-1) |
| return (projected_features + protocol_embeds + timeframe_embeds) * mask |
|
|
| def _get_migrated_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: |
| """ |
| Handles Migrated events by looking up the protocol_id embedding. |
| """ |
| device = self.device |
| protocol_ids = batch['migrated_protocol_ids'] |
| event_type_ids = batch['event_type_ids'] |
|
|
| |
| protocol_embeds = self.protocol_embedding(protocol_ids) |
|
|
| |
| mask = (event_type_ids == self.event_type_to_id.get('Migrated', -1)).unsqueeze(-1) |
| return protocol_embeds * mask |
|
|
| def _get_special_context_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: |
| """ |
| Handles special context tokens like 'MIDDLE' and 'RECENT' by adding their unique learnable embeddings. |
| """ |
| device = self.device |
| event_type_ids = batch['event_type_ids'] |
| B, L = event_type_ids.shape |
|
|
| middle_id = self.event_type_to_id.get('MIDDLE', -1) |
| recent_id = self.event_type_to_id.get('RECENT', -1) |
|
|
| middle_mask = (event_type_ids == middle_id) |
| recent_mask = (event_type_ids == recent_id) |
|
|
| middle_emb = self.special_context_embedding(torch.tensor(self.special_context_tokens['MIDDLE'], device=device)) |
| recent_emb = self.special_context_embedding(torch.tensor(self.special_context_tokens['RECENT'], device=device)) |
|
|
| |
| return middle_mask.unsqueeze(-1) * middle_emb + recent_mask.unsqueeze(-1) * recent_emb |
| |
| def _pool_hidden_states(self, |
| hidden_states: torch.Tensor, |
| attention_mask: torch.Tensor) -> torch.Tensor: |
| """ |
| Pools variable-length hidden states into a single embedding per sequence by |
| selecting the last non-masked token for each batch element. |
| """ |
| if hidden_states.size(0) == 0: |
| return torch.empty(0, self.d_model, device=hidden_states.device, dtype=hidden_states.dtype) |
|
|
| seq_lengths = attention_mask.long().sum(dim=1) |
| last_indices = torch.clamp(seq_lengths - 1, min=0) |
| batch_indices = torch.arange(hidden_states.size(0), device=hidden_states.device) |
| return hidden_states[batch_indices, last_indices] |
|
|
| def forward(self, batch: Dict[str, Any]) -> Dict[str, torch.Tensor]: |
| device = self.device |
| |
| |
| event_type_ids = batch['event_type_ids'].to(device) |
| timestamps_float = batch['timestamps_float'].to(device) |
| relative_ts = batch['relative_ts'].to(device, self.dtype) |
| attention_mask = batch['attention_mask'].to(device) |
|
|
| B, L = event_type_ids.shape |
| if B == 0 or L == 0: |
| print("Warning: Received empty batch in Oracle forward.") |
| empty_hidden = torch.empty(0, L, self.d_model, device=device, dtype=self.dtype) |
| empty_mask = torch.empty(0, L, device=device, dtype=torch.long) |
| empty_quantiles = torch.empty(0, self.num_outputs, device=device, dtype=self.dtype) |
| empty_quality = torch.empty(0, device=device, dtype=self.dtype) |
| empty_movement = torch.empty(0, len(self.horizons_seconds), self.num_movement_classes, device=device, dtype=self.dtype) |
| return { |
| 'quantile_logits': empty_quantiles, |
| 'quality_logits': empty_quality, |
| 'movement_logits': empty_movement, |
| 'pooled_states': torch.empty(0, self.d_model, device=device, dtype=self.dtype), |
| 'hidden_states': empty_hidden, |
| 'attention_mask': empty_mask |
| } |
|
|
| |
| dynamic_raw_embeds = self._run_dynamic_encoders(batch) |
|
|
|
|
| |
| wallet_addr_to_batch_idx = batch['wallet_addr_to_batch_idx'] |
| snapshot_raw_embeds = self._run_snapshot_encoders(batch, dynamic_raw_embeds['wallet'], wallet_addr_to_batch_idx) |
|
|
| |
| raw_embeds = {**dynamic_raw_embeds, **snapshot_raw_embeds} |
| gathered_embeds = self._project_and_gather_embeddings(raw_embeds, batch) |
|
|
| |
| event_embeds = self.event_type_embedding(event_type_ids) |
| ts_embeds = self.time_proj(self.time_encoder(timestamps_float)) |
| |
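        # Relative timestamps: convert to minutes, signed-log compress, then
        # LayerNorm and project to d_model.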
| relative_ts_fp32 = batch['relative_ts'].to(device, torch.float32) |
| rel_ts_minutes = relative_ts_fp32 / 60.0 |
| rel_ts_processed = torch.sign(rel_ts_minutes) * torch.log1p(torch.abs(rel_ts_minutes)) |
| |
| norm_dtype = self.rel_ts_norm.weight.dtype |
| proj_dtype = self.rel_ts_proj.weight.dtype |
| rel_ts_normed = self.rel_ts_norm(rel_ts_processed.to(norm_dtype)) |
| rel_ts_embeds = self.rel_ts_proj(rel_ts_normed.to(proj_dtype)) |
|
|
| |
| transfer_specific_embeds = self._get_transfer_specific_embeddings(batch, gathered_embeds) |
|
|
| |
| trade_specific_embeds = self._get_trade_specific_embeddings(batch) |
|
|
| |
| deployer_trade_specific_embeds = self._get_deployer_trade_specific_embeddings(batch) |
|
|
| |
| smart_wallet_trade_specific_embeds = self._get_smart_wallet_trade_specific_embeddings(batch) |
|
|
| |
| pool_created_specific_embeds = self._get_pool_created_specific_embeddings(batch, gathered_embeds) |
|
|
| |
| liquidity_change_specific_embeds = self._get_liquidity_change_specific_embeddings(batch, gathered_embeds) |
|
|
| |
| fee_collected_specific_embeds = self._get_fee_collected_specific_embeddings(batch) |
|
|
| |
| token_burn_specific_embeds = self._get_token_burn_specific_embeddings(batch) |
|
|
| |
| supply_lock_specific_embeds = self._get_supply_lock_specific_embeddings(batch) |
|
|
| |
| onchain_snapshot_specific_embeds = self._get_onchain_snapshot_specific_embeddings(batch) |
|
|
| |
| trending_token_specific_embeds = self._get_trending_token_specific_embeddings(batch, gathered_embeds) |
|
|
| |
| boosted_token_specific_embeds = self._get_boosted_token_specific_embeddings(batch, gathered_embeds) |
|
|
| |
| dexboost_paid_specific_embeds = self._get_dexboost_paid_specific_embeddings(batch) |
|
|
| |
| alphagroup_call_specific_embeds = self._get_alphagroup_call_specific_embeddings(batch) |
| channel_call_specific_embeds = self._get_channel_call_specific_embeddings(batch) |
| cexlisting_specific_embeds = self._get_cexlisting_specific_embeddings(batch) |
|
|
| |
| chainsnapshot_specific_embeds = self._get_chainsnapshot_specific_embeddings(batch) |
| lighthousesnapshot_specific_embeds = self._get_lighthousesnapshot_specific_embeddings(batch) |
| |
| migrated_specific_embeds = self._get_migrated_specific_embeddings(batch) |
|
|
| |
| dexprofile_updated_flags = batch['dexprofile_updated_flags'] |
| dexprofile_flags_embeds = self.dexprofile_updated_flags_proj(dexprofile_updated_flags.to(self.dtype)) |
|
|
| |
| |
| |
| textual_event_embeds = self.social_encoder( |
| batch=batch, |
| gathered_embeds=gathered_embeds |
| ) |
|
|
| |
| special_context_embeds = self._get_special_context_embeddings(batch) |
|
|
| |
| |
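        # Sum all per-event embedding components (event-type-specific terms are
        # already zero-masked outside their event rows) in fp32 for numerical
        # stability, then cast back to the working dtype.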
| components = [ |
| event_embeds, ts_embeds, rel_ts_embeds, |
| gathered_embeds['wallet'], gathered_embeds['token'], gathered_embeds['original_author'], gathered_embeds['ohlc'], |
| transfer_specific_embeds, trade_specific_embeds, deployer_trade_specific_embeds, smart_wallet_trade_specific_embeds, |
| pool_created_specific_embeds, liquidity_change_specific_embeds, fee_collected_specific_embeds, |
| token_burn_specific_embeds, supply_lock_specific_embeds, onchain_snapshot_specific_embeds, |
| trending_token_specific_embeds, boosted_token_specific_embeds, dexboost_paid_specific_embeds, |
| alphagroup_call_specific_embeds, channel_call_specific_embeds, cexlisting_specific_embeds, |
| migrated_specific_embeds, special_context_embeds, gathered_embeds['holder_snapshot'], textual_event_embeds, |
| dexprofile_flags_embeds, chainsnapshot_specific_embeds, lighthousesnapshot_specific_embeds |
| ] |
| inputs_embeds = sum([t.float() for t in components]).to(self.dtype) |
|
|
| hf_attention_mask = attention_mask.to(device=device, dtype=torch.long) |
| outputs = self.model( |
| inputs_embeds=inputs_embeds, |
| attention_mask=hf_attention_mask, |
| return_dict=True |
| ) |
| sequence_hidden = outputs.last_hidden_state |
| pooled_states = self._pool_hidden_states(sequence_hidden, hf_attention_mask) |
| quantile_logits = self.quantile_head(pooled_states) |
| quality_logits = self.quality_head(pooled_states).squeeze(-1) |
| movement_logits = self.movement_head(pooled_states).view( |
| pooled_states.shape[0], |
| len(self.horizons_seconds), |
| self.num_movement_classes, |
| ) |
|
|
| return { |
| 'quantile_logits': quantile_logits, |
| 'quality_logits': quality_logits, |
| 'movement_logits': movement_logits, |
| 'pooled_states': pooled_states, |
| 'hidden_states': sequence_hidden, |
| 'attention_mask': hf_attention_mask |
| } |
|
|
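
# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of round-tripping a checkpoint. It assumes the caller has
# already constructed the sub-encoders elsewhere with project-specific configs;
# the `encoders` dict is a hypothetical convenience, not part of this repo.
def _example_save_and_reload(oracle: "Oracle", save_dir: str, encoders: Dict[str, nn.Module]) -> "Oracle":
    """Save `oracle` to `save_dir` and load it back with the same sub-encoder instances."""
    oracle.save_pretrained(save_dir)
    return Oracle.from_pretrained(
        save_dir,
        token_encoder=encoders["token_encoder"],
        wallet_encoder=encoders["wallet_encoder"],
        graph_updater=encoders["graph_updater"],
        ohlc_embedder=encoders["ohlc_embedder"],
        quant_ohlc_embedder=encoders["quant_ohlc_embedder"],
        time_encoder=encoders["time_encoder"],
    )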