File size: 61,753 Bytes
858826c e125fa3 858826c 0e3516b 858826c d195287 858826c 7064310 858826c d195287 858826c e125fa3 858826c 0e3516b 858826c 7064310 858826c e125fa3 858826c 4dd4ab4 7064310 858826c d195287 858826c d195287 858826c d195287 858826c e125fa3 858826c 0e3516b d195287 0e3516b d195287 0e3516b 858826c 54191e5 858826c 54191e5 858826c 54191e5 858826c 54191e5 858826c 54191e5 858826c 54191e5 858826c d195287 858826c d195287 858826c d195287 858826c 54191e5 858826c e125fa3 858826c e125fa3 858826c e125fa3 858826c 4dd4ab4 7064310 858826c 4dd4ab4 7064310 858826c 4dd4ab4 7064310 858826c 4dd4ab4 7064310 858826c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 
379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 
879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 | # model.py (REFACTORED AND FIXED)
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, LlamaConfig
from typing import List, Dict, Any, Optional, Tuple
import os
import json
# --- NOW, we import all the encoders ---
from models.helper_encoders import ContextualTimeEncoder
from models.token_encoder import TokenEncoder
from models.wallet_encoder import WalletEncoder
from models.graph_updater import GraphUpdater
from models.ohlc_embedder import OHLCEmbedder
from models.quant_ohlc_embedder import QuantOHLCEmbedder
from models.HoldersEncoder import HolderDistributionEncoder # NEW
from models.SocialEncoders import SocialEncoder # NEW
import models.vocabulary as vocab # For vocab sizes
from data.context_targets import MOVEMENT_CLASS_NAMES
class Oracle(nn.Module):
"""
"""
def __init__(self,
             token_encoder: TokenEncoder,
             wallet_encoder: WalletEncoder,
             graph_updater: GraphUpdater,
             ohlc_embedder: OHLCEmbedder,
             quant_ohlc_embedder: QuantOHLCEmbedder,
             time_encoder: ContextualTimeEncoder,
             num_event_types: int,
             multi_modal_dim: int,
             event_pad_id: int,
             event_type_to_id: Dict[str, int],
             model_config_name: str = "llama3-12l-768d-gqa4-8k-random",
             quantiles: List[float] = [0.1, 0.5, 0.9],
             horizons_seconds: List[int] = [30, 60, 120, 240, 420],
             dtype: torch.dtype = torch.bfloat16):
    """
    Build the full Oracle pipeline: sub-encoders, per-event feature projections,
    a randomly initialised Llama-style decoder backbone and the prediction heads.

    Args:
        token_encoder, wallet_encoder, graph_updater, ohlc_embedder,
        quant_ohlc_embedder, time_encoder: pre-built sub-encoders. They are
            registered as submodules and cast to ``dtype`` with the rest of the model.
        num_event_types: size of the event-type embedding table.
        multi_modal_dim: width of the pre-computed text/image embeddings.
        event_pad_id: ``padding_idx`` of the event-type embedding.
        event_type_to_id: event name -> event-type id mapping.
        model_config_name: stored and serialized as a label only. It is not used
            to build the backbone; the backbone configuration below is hard-coded.
        quantiles: quantile levels predicted per horizon (copied into the instance).
        horizons_seconds: prediction horizons in seconds (copied into the instance).
        dtype: dtype the whole module is cast to at the end of construction.
    """
    super().__init__()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    self.device = torch.device(device)
    self.dtype = dtype
    self.multi_modal_dim = multi_modal_dim
    self.num_event_types = num_event_types
    self.event_pad_id = event_pad_id
    self.model_config_name = model_config_name
    # Copy the list arguments: the defaults are shared mutable lists, and an
    # instance must never alias them (mutating one would change the default).
    self.quantiles = list(quantiles)
    self.horizons_seconds = list(horizons_seconds)
    self.num_outputs = len(self.quantiles) * len(self.horizons_seconds)
    self.num_movement_classes = len(MOVEMENT_CLASS_NAMES)

    # --- Backbone: Llama-style decoder, RANDOM INIT (no pretrained weights) ---
    # Gives RoPE + modern decoder blocks and lets HF pick an optimized attention
    # implementation (SDPA / FlashAttention) without a hand-written transformer.
    # Size target: ~80-120M params, for ~8k sequence caps.
    attn_impl = os.getenv("HF_ATTN_IMPL", "sdpa")  # "sdpa" (safe) or "flash_attention_2" (if installed)
    llama_cfg = LlamaConfig(
        # Model size
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        # GQA-style KV heads (Llama 3-style efficiency knob)
        num_key_value_heads=4,
        # Long context (must be >= the effective max sequence length)
        max_position_embeddings=8192,
        # Llama 3 uses a large RoPE theta; harmless for random init, helps longer contexts.
        rope_theta=500000.0,
        rms_norm_eps=1e-5,
        # Required by the config; the token embedding table is presumably unused
        # because inputs are fed as embeddings (forward() is not visible here).
        vocab_size=32000,
    )
    self.d_model = llama_cfg.hidden_size
    # Older transformers versions may not accept attn_implementation in from_config
    # (TypeError). flash_attention_2 needs optional deps; fall back to SDPA if it fails.
    try:
        self.model = AutoModel.from_config(llama_cfg, attn_implementation=attn_impl)
    except TypeError:
        self.model = AutoModel.from_config(llama_cfg)
    except Exception:
        if attn_impl != "sdpa":
            self.model = AutoModel.from_config(llama_cfg, attn_implementation="sdpa")
        else:
            raise
    # Disable the KV cache: not used for full-sequence training and it costs memory.
    if hasattr(self.model, "config"):
        self.model.config.use_cache = False
    # NOTE(review): only the backbone is moved to self.device here; the remaining
    # modules are only cast by self.to(dtype) below. Callers presumably move the
    # whole module afterwards.
    self.model.to(self.device, dtype=self.dtype)

    # --- Prediction heads ---
    # Quantile head: pooled hidden state -> flattened (horizon x quantile) grid.
    self.quantile_head = nn.Sequential(
        nn.Linear(self.d_model, self.d_model),
        nn.GELU(),
        nn.Linear(self.d_model, self.num_outputs)
    )
    # Single scalar output.
    self.quality_head = nn.Sequential(
        nn.Linear(self.d_model, self.d_model),
        nn.GELU(),
        nn.Linear(self.d_model, 1)
    )
    # One set of movement-class logits per horizon.
    self.movement_head = nn.Sequential(
        nn.Linear(self.d_model, self.d_model),
        nn.GELU(),
        nn.Linear(self.d_model, len(self.horizons_seconds) * self.num_movement_classes)
    )
    self.event_type_to_id = event_type_to_id

    # --- Sub-encoders ---
    # Token roles: main token (no role offset), quote token, trending/boosted token.
    self.token_roles = {'main': 0, 'quote': 1, 'trending': 2}
    self.main_token_role_id = self.token_roles['main']
    self.quote_token_role_id = self.token_roles['quote']
    self.trending_token_role_id = self.token_roles['trending']
    self.token_encoder = token_encoder
    self.wallet_encoder = wallet_encoder
    self.graph_updater = graph_updater
    self.ohlc_embedder = ohlc_embedder
    self.quant_ohlc_embedder = quant_ohlc_embedder
    self.time_encoder = time_encoder
    self.social_encoder = SocialEncoder(d_model=self.d_model, dtype=self.dtype)  # needs self.d_model

    # --- Sequence feature embeddings ---
    self.event_type_embedding = nn.Embedding(num_event_types, self.d_model, padding_idx=event_pad_id)
    self.token_role_embedding = nn.Embedding(len(self.token_roles), self.d_model)

    # --- Learnable entity padding rows (index 0 of each lookup table) ---
    self.pad_wallet_emb = nn.Parameter(torch.zeros(1, self.wallet_encoder.d_model))
    self.pad_token_emb = nn.Parameter(torch.zeros(1, self.token_encoder.output_dim))
    self.pad_ohlc_emb = nn.Parameter(torch.zeros(1, self.quant_ohlc_embedder.output_dim))
    self.pad_precomputed_emb = nn.Parameter(torch.zeros(1, self.multi_modal_dim))  # text/image embeddings
    # Holder-distribution encoder is built internally; it outputs d_model directly,
    # so no separate holder-snapshot projection is needed.
    self.holder_dist_encoder = HolderDistributionEncoder(
        wallet_embedding_dim=self.wallet_encoder.d_model,
        output_dim=self.d_model,
        dtype=self.dtype
    )
    self.pad_holder_snapshot_emb = nn.Parameter(torch.zeros(1, self.d_model))

    # --- Projections to d_model ---
    self.time_proj = nn.Linear(self.time_encoder.projection.out_features, self.d_model)
    self.rel_ts_proj = nn.Linear(1, self.d_model)
    self.rel_ts_norm = nn.LayerNorm(1)
    self.wallet_proj = nn.Linear(self.wallet_encoder.d_model, self.d_model)
    self.token_proj = nn.Linear(self.token_encoder.output_dim, self.d_model)
    self.ohlc_proj = nn.Linear(self.quant_ohlc_embedder.output_dim, self.d_model)
    # Chart fusion: [price-chart emb | quant-chart emb | 32-d interval emb] -> quant width.
    self.chart_interval_fusion_embedding = nn.Embedding(vocab.NUM_OHLC_INTERVALS, 32, padding_idx=0)
    fusion_input_dim = self.ohlc_embedder.output_dim + self.quant_ohlc_embedder.output_dim + 32
    self.chart_fusion = nn.Sequential(
        nn.Linear(fusion_input_dim, self.quant_ohlc_embedder.output_dim),
        nn.GELU(),
        nn.LayerNorm(self.quant_ohlc_embedder.output_dim),
        nn.Linear(self.quant_ohlc_embedder.output_dim, self.quant_ohlc_embedder.output_dim),
        nn.LayerNorm(self.quant_ohlc_embedder.output_dim),
    )

    # --- Per-event numerical / categorical feature layers ---
    # Transfer: 4 numerical features.
    self.transfer_num_norm = nn.LayerNorm(4)
    self.transfer_num_proj = nn.Linear(4, self.d_model)
    # Trade: 8 numerical features.
    self.trade_num_norm = nn.LayerNorm(8)
    self.trade_num_proj = nn.Linear(8, self.d_model)
    self.dex_platform_embedding = nn.Embedding(vocab.NUM_DEX_PLATFORMS, self.d_model)
    self.trade_direction_embedding = nn.Embedding(2, self.d_model)  # 0 for buy, 1 for sell
    self.mev_protection_embedding = nn.Embedding(2, self.d_model)  # 0 for false, 1 for true
    self.is_bundle_embedding = nn.Embedding(2, self.d_model)  # 0 for false, 1 for true
    # Deployer trade: separate 8-feature layers.
    self.deployer_trade_num_norm = nn.LayerNorm(8)
    self.deployer_trade_num_proj = nn.Linear(8, self.d_model)
    # Smart-wallet trade: separate 8-feature layers.
    self.smart_wallet_trade_num_norm = nn.LayerNorm(8)
    self.smart_wallet_trade_num_proj = nn.Linear(8, self.d_model)
    # PoolCreated: 2 numerical features.
    self.pool_created_num_norm = nn.LayerNorm(2)
    self.pool_created_num_proj = nn.Linear(2, self.d_model)
    # LiquidityChange: 1 numerical feature plus a 2-way change type (add / remove).
    self.liquidity_change_num_norm = nn.LayerNorm(1)
    self.liquidity_change_num_proj = nn.Linear(1, self.d_model)
    self.liquidity_change_type_embedding = nn.Embedding(2, self.d_model)
    # FeeCollected: sol_amount only.
    self.fee_collected_num_norm = nn.LayerNorm(1)
    self.fee_collected_num_proj = nn.Linear(1, self.d_model)
    # TokenBurn: amount_pct, amount_tokens.
    self.token_burn_num_norm = nn.LayerNorm(2)
    self.token_burn_num_proj = nn.Linear(2, self.d_model)
    # SupplyLock: amount_pct, lock_duration.
    self.supply_lock_num_norm = nn.LayerNorm(2)
    self.supply_lock_num_proj = nn.Linear(2, self.d_model)
    # OnChain_Snapshot: 14 numerical features.
    self.onchain_snapshot_num_norm = nn.LayerNorm(14)
    self.onchain_snapshot_num_proj = nn.Linear(14, self.d_model)
    # TrendingToken: rank only, plus list-source and timeframe categories.
    self.trending_token_num_norm = nn.LayerNorm(1)
    self.trending_token_num_proj = nn.Linear(1, self.d_model)
    self.trending_list_source_embedding = nn.Embedding(vocab.NUM_TRENDING_LIST_SOURCES, self.d_model)
    self.trending_timeframe_embedding = nn.Embedding(vocab.NUM_TRENDING_LIST_TIMEFRAMES, self.d_model)
    # BoostedToken: total_boost_amount, rank.
    self.boosted_token_num_norm = nn.LayerNorm(2)
    self.boosted_token_num_proj = nn.Linear(2, self.d_model)
    # DexBoost_Paid: amount, total_amount_on_token.
    self.dexboost_paid_num_norm = nn.LayerNorm(2)
    self.dexboost_paid_num_proj = nn.Linear(2, self.d_model)
    # DexProfile_Updated: 4 boolean flags.
    self.dexprofile_updated_flags_proj = nn.Linear(4, self.d_model)
    # Shared projection for all pre-computed (text/image) embeddings.
    self.precomputed_proj = nn.Linear(self.multi_modal_dim, self.d_model)
    # Protocol ids (used by the Migrated event).
    self.protocol_embedding = nn.Embedding(vocab.NUM_PROTOCOLS, self.d_model)
    # Tracker events. NUM_CALL_CHANNELS may need to grow with the vocabulary.
    self.alpha_group_embedding = nn.Embedding(vocab.NUM_ALPHA_GROUPS, self.d_model)
    self.call_channel_embedding = nn.Embedding(vocab.NUM_CALL_CHANNELS, self.d_model)
    self.cex_listing_embedding = nn.Embedding(vocab.NUM_EXCHANGES, self.d_model)
    # Global trending: rank.
    self.global_trending_num_norm = nn.LayerNorm(1)
    self.global_trending_num_proj = nn.Linear(1, self.d_model)
    # ChainSnapshot: native_token_price_usd, gas_fee.
    self.chainsnapshot_num_norm = nn.LayerNorm(2)
    self.chainsnapshot_num_proj = nn.Linear(2, self.d_model)
    # Lighthouse_Snapshot: 5 numerical features plus its own timeframe embedding table.
    self.lighthousesnapshot_num_norm = nn.LayerNorm(5)
    self.lighthousesnapshot_num_proj = nn.Linear(5, self.d_model)
    self.lighthouse_timeframe_embedding = nn.Embedding(vocab.NUM_LIGHTHOUSE_TIMEFRAMES, self.d_model)
    # Special context tokens; names must match the vocabulary event names (models/vocabulary.py).
    self.special_context_tokens = {'MIDDLE': 0, 'RECENT': 1}
    self.special_context_embedding = nn.Embedding(len(self.special_context_tokens), self.d_model)

    # Cast every registered module (including the passed-in encoders) to the target dtype.
    self.to(dtype)
    print("Oracle model (full pipeline) initialized.")
def save_pretrained(self, save_directory: str):
    """
    Save the model to ``save_directory`` (created if missing).

    Writes, in order:
      1. The backbone via ``self.model.save_pretrained`` (HF ``config.json`` + backbone weights).
      2. The full Oracle ``state_dict`` (backbone + all custom encoders and heads) to
         ``pytorch_model.bin``. This is the file ``from_pretrained`` loads.
      3. ``oracle_config.json`` with the constructor arguments needed to rebuild the module.

    NOTE(review): when the backbone is saved in HF's legacy (non-safetensors) format it
    also writes ``pytorch_model.bin``. Step 2 then overwrites it with the full Oracle
    state, so the directory can no longer be loaded as a plain HF backbone. Renaming
    the Oracle file requires changing ``from_pretrained`` at the same time.
    """
    # exist_ok avoids the check-then-create race and accepts an existing directory.
    os.makedirs(save_directory, exist_ok=True)
    # 1. Backbone, via its own HF serializer.
    self.model.save_pretrained(save_directory)
    # 2. Full Oracle state (see NOTE above about the shared file name).
    torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
    # 3. Oracle-specific metadata for reconstruction.
    oracle_config = {
        "num_event_types": self.num_event_types,
        "multi_modal_dim": self.multi_modal_dim,
        "event_pad_id": self.event_pad_id,
        "model_config_name": self.model_config_name,
        "quantiles": self.quantiles,
        "horizons_seconds": self.horizons_seconds,
        "dtype": str(self.dtype),
        "event_type_to_id": self.event_type_to_id
    }
    with open(os.path.join(save_directory, "oracle_config.json"), "w", encoding="utf-8") as f:
        json.dump(oracle_config, f, indent=2)
    print(f"✅ Oracle model saved to {save_directory}")
@classmethod
def from_pretrained(cls, load_directory: str,
                    token_encoder, wallet_encoder, graph_updater, ohlc_embedder, quant_ohlc_embedder, time_encoder):
    """
    Rebuild an Oracle from a directory written by ``save_pretrained``.

    Reads the constructor arguments from ``oracle_config.json`` and loads the full
    state dict from ``pytorch_model.bin``. The sub-encoders are not serialized
    separately: pass freshly constructed instances with matching architectures.
    Their weights are then replaced by the checkpoint values via ``load_state_dict``.

    Raises:
        FileNotFoundError: if either file is missing.
        RuntimeError: from ``load_state_dict`` if the checkpoint keys/shapes do not match.
    """
    config_path = os.path.join(load_directory, "oracle_config.json")
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    # save_pretrained stores str(torch.dtype), e.g. "torch.bfloat16". Compare the exact
    # dtype name: a substring test would misread "torch.bfloat16" as float16.
    dtype_name = str(config.get("dtype", "")).rsplit(".", 1)[-1]
    dtype = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }.get(dtype_name, torch.bfloat16)  # bfloat16 is the constructor default
    model = cls(
        token_encoder=token_encoder,
        wallet_encoder=wallet_encoder,
        graph_updater=graph_updater,
        ohlc_embedder=ohlc_embedder,
        quant_ohlc_embedder=quant_ohlc_embedder,
        time_encoder=time_encoder,
        num_event_types=config["num_event_types"],
        multi_modal_dim=config["multi_modal_dim"],
        event_pad_id=config["event_pad_id"],
        event_type_to_id=config["event_type_to_id"],
        model_config_name=config["model_config_name"],
        quantiles=config["quantiles"],
        horizons_seconds=config["horizons_seconds"],
        dtype=dtype
    )
    # Load on CPU first so checkpoints saved on GPU load anywhere.
    weight_path = os.path.join(load_directory, "pytorch_model.bin")
    state_dict = torch.load(weight_path, map_location="cpu")
    model.load_state_dict(state_dict)
    print(f"✅ Oracle model loaded from {load_directory}")
    return model
def _normalize_and_project(self,
                           features: torch.Tensor,
                           norm_layer: nn.LayerNorm,
                           proj_layer: nn.Linear,
                           log_indices: Optional[List[int]] = None) -> torch.Tensor:
    """
    Sanitize, optionally log-scale, normalize and project a numeric feature tensor.

    Args:
        features: numeric features with the feature axis last. Any number of
            leading (batch / sequence) dimensions is accepted.
        norm_layer: normalization applied over the feature axis.
        proj_layer: projection applied after normalization.
        log_indices: feature positions (negative indices allowed) to transform with
            the signed log ``sign(x) * log1p(|x|)``. Out-of-range positions are ignored.

    Returns:
        ``proj_layer(norm_layer(x))`` computed in ``proj_layer``'s dtype. The input
        tensor is not modified.
    """
    # Work in float32 and clamp non-finite inputs before the log transform.
    # nan_to_num returns a new tensor, so the in-place edit below never touches `features`.
    processed_features = torch.nan_to_num(
        features.to(torch.float32),
        nan=0.0,
        posinf=1e6,
        neginf=-1e6
    )
    if log_indices:
        num_features = processed_features.shape[-1]
        valid_indices = [i for i in log_indices if -num_features <= i < num_features]
        if valid_indices:
            # `...` indexes the last axis regardless of rank (2-D or 3-D inputs).
            log_features = processed_features[..., valid_indices]
            processed_features[..., valid_indices] = torch.sign(log_features) * torch.log1p(torch.abs(log_features))
    normed_features = norm_layer(processed_features.to(norm_layer.weight.dtype))
    # Zero out anything normalization made non-finite (e.g. overflow in low precision).
    normed_features = torch.nan_to_num(normed_features, nan=0.0, posinf=0.0, neginf=0.0)
    return proj_layer(normed_features.to(proj_layer.weight.dtype))
def _run_snapshot_encoders(self,
                           batch: Dict[str, Any],
                           final_wallet_embeddings_raw: torch.Tensor,
                           wallet_addr_to_batch_idx: Dict[str, int]) -> Dict[str, torch.Tensor]:
    """
    Encode every HolderSnapshot event from its raw holder list.

    Each entry of ``batch['holder_snapshot_raw_data']`` is a list of holder dicts
    with ``'wallet'`` (address) and ``'holding_pct'`` keys. A holder whose address
    maps to a batch index > 0 is paired with its row of
    ``final_wallet_embeddings_raw``. Holders with index 0 (padding / unknown
    address) are dropped. Each snapshot's list goes through
    ``self.holder_dist_encoder``.

    Returns:
        ``{"holder_snapshot": tensor}``: the per-snapshot encoder outputs
        concatenated along dim 0, or an empty ``(0, d_model)`` tensor when the
        batch has no snapshots.
    """
    snapshot_embeds = []
    for holders in batch['holder_snapshot_raw_data']:
        encoder_input = []
        for holder in holders:
            batch_idx = wallet_addr_to_batch_idx.get(holder['wallet'], 0)
            if batch_idx <= 0:
                continue  # padding / unknown wallet
            # Batch indices are 1-based (0 = padding); the embedding tensor is 0-based.
            encoder_input.append({
                'wallet_embedding': final_wallet_embeddings_raw[batch_idx - 1],
                'pct': holder['holding_pct'],
            })
        snapshot_embeds.append(self.holder_dist_encoder(encoder_input))
    if not snapshot_embeds:
        return {"holder_snapshot": torch.empty(0, self.d_model, device=self.device, dtype=self.dtype)}
    return {"holder_snapshot": torch.cat(snapshot_embeds, dim=0)}
def _run_dynamic_encoders(self, batch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    """
    Run the token, wallet, chart (price OHLC + quant features) and graph encoders for one batch.

    Returns a dict of raw (not yet projected to d_model) embeddings:
      - "wallet": wallet embeddings after the graph update, padding row stripped
      - "token":  token embeddings after the graph update, padding row stripped
      - "ohlc":   fused chart-segment embeddings, ``quant_ohlc_embedder.output_dim`` wide
    """
    device = self.device
    # Encoder input dicts assembled upstream (schema not visible here).
    token_encoder_inputs = batch['token_encoder_inputs']
    wallet_encoder_inputs = batch['wallet_encoder_inputs']
    # The pre-computed (text/image) embedding pool for the whole batch; non-finite values are zeroed.
    embedding_pool = torch.nan_to_num(
        batch['embedding_pool'].to(device, self.dtype),
        nan=0.0,
        posinf=0.0,
        neginf=0.0
    )
    ohlc_price_tensors = torch.nan_to_num(
        batch['ohlc_price_tensors'].to(device, self.dtype),
        nan=0.0,
        posinf=0.0,
        neginf=0.0
    )
    ohlc_interval_ids = batch['ohlc_interval_ids'].to(device)
    quant_ohlc_feature_tensors = torch.nan_to_num(
        batch['quant_ohlc_feature_tensors'].to(device, self.dtype),
        nan=0.0,
        posinf=0.0,
        neginf=0.0
    )
    quant_ohlc_feature_mask = batch['quant_ohlc_feature_mask'].to(device)
    quant_ohlc_feature_version_ids = batch['quant_ohlc_feature_version_ids'].to(device)
    graph_updater_links = batch['graph_updater_links']
    # 1a. Encode Tokens
    # An empty name-index tensor is treated as "no tokens in this batch".
    if token_encoder_inputs['name_embed_indices'].numel() > 0:
        # Copy the inputs and drop keys that are not TokenEncoder keyword arguments:
        # the address list is consumed by the wallet step below, and the three index
        # tensors are resolved to embeddings here instead.
        encoder_args = token_encoder_inputs.copy()
        encoder_args.pop('_addresses_for_lookup', None)  # This key is for the WalletEncoder
        encoder_args.pop('name_embed_indices', None)
        encoder_args.pop('symbol_embed_indices', None)
        encoder_args.pop('image_embed_indices', None)
        # Prepend a zero row to the pool so that missing indices (< 0) resolve to zeros.
        if embedding_pool.numel() > 0:
            pad_row = torch.zeros(1, embedding_pool.size(1), device=device, dtype=embedding_pool.dtype)
            pool_padded = torch.cat([pad_row, embedding_pool], dim=0)
            def pad_and_lookup(idx_tensor: torch.Tensor) -> torch.Tensor:
                # Map valid indices >=0 to +1 (shift), invalid (<0) to 0 (pad)
                shifted = torch.where(idx_tensor >= 0, idx_tensor + 1, torch.zeros_like(idx_tensor))
                return F.embedding(shifted, pool_padded)
            name_embeds = pad_and_lookup(token_encoder_inputs['name_embed_indices'])
            symbol_embeds = pad_and_lookup(token_encoder_inputs['symbol_embed_indices'])
            image_embeds = pad_and_lookup(token_encoder_inputs['image_embed_indices'])
        else:
            # Empty pool: provide zeros with correct shapes.
            # NOTE(review): the three names share one zeros tensor; safe only as long as
            # the token encoder does not modify its inputs in place.
            n = token_encoder_inputs['name_embed_indices'].shape[0]
            d = self.multi_modal_dim
            zeros = torch.zeros(n, d, device=device, dtype=self.dtype)
            name_embeds = zeros
            symbol_embeds = zeros
            image_embeds = zeros
        batch_token_embeddings_unupd = self.token_encoder(
            name_embeds=name_embeds,
            symbol_embeds=symbol_embeds,
            image_embeds=image_embeds,
            # Pass all other keys like protocol_ids, is_vanity_flags, etc.
            **encoder_args
        )
    else:
        batch_token_embeddings_unupd = torch.empty(0, self.token_encoder.output_dim, device=device, dtype=self.dtype)
    # 1b. Encode Wallets
    if wallet_encoder_inputs['profile_rows']:
        # Address -> token embedding (before the graph update); row i pairs with _addresses_for_lookup[i].
        # NOTE(review): raises IndexError if _addresses_for_lookup is longer than the number of
        # encoded tokens (e.g. the token branch above was skipped) — confirm the collator guarantees this.
        temp_token_lookup = {
            addr: batch_token_embeddings_unupd[i]
            for i, addr in enumerate(batch['token_encoder_inputs']['_addresses_for_lookup'])  # Use helper key
        }
        initial_wallet_embeddings = self.wallet_encoder(
            **wallet_encoder_inputs,
            token_vibe_lookup=temp_token_lookup,
            embedding_pool=embedding_pool
        )
    else:
        initial_wallet_embeddings = torch.empty(0, self.wallet_encoder.d_model, device=device, dtype=self.dtype)
    # 1c. Encode OHLC
    # Two chart encoders run independently: the price-series embedder and the quant-feature embedder.
    if ohlc_price_tensors.shape[0] > 0:
        raw_chart_embeddings = self.ohlc_embedder(ohlc_price_tensors, ohlc_interval_ids)
    else:
        raw_chart_embeddings = torch.empty(0, self.ohlc_embedder.output_dim, device=device, dtype=self.dtype)
    if quant_ohlc_feature_tensors.shape[0] > 0:
        quant_chart_embeddings = self.quant_ohlc_embedder(
            quant_ohlc_feature_tensors,
            quant_ohlc_feature_mask,
            quant_ohlc_feature_version_ids,
        )
    else:
        quant_chart_embeddings = torch.empty(0, self.quant_ohlc_embedder.output_dim, device=device, dtype=self.dtype)
    # If only one of the two encoders produced rows, substitute zeros for the other so they can be concatenated.
    num_chart_segments = max(raw_chart_embeddings.shape[0], quant_chart_embeddings.shape[0])
    if num_chart_segments > 0:
        if raw_chart_embeddings.shape[0] == 0:
            raw_chart_embeddings = torch.zeros(
                num_chart_segments,
                self.ohlc_embedder.output_dim,
                device=device,
                dtype=self.dtype,
            )
        if quant_chart_embeddings.shape[0] == 0:
            quant_chart_embeddings = torch.zeros(
                num_chart_segments,
                self.quant_ohlc_embedder.output_dim,
                device=device,
                dtype=self.dtype,
            )
        # Fuse [price emb | quant emb | interval emb] back down to quant_ohlc_embedder.output_dim.
        # NOTE(review): when both encoders return rows they must agree on the row count, and
        # ohlc_interval_ids must hold at least num_chart_segments entries; neither is checked here.
        interval_embeds = self.chart_interval_fusion_embedding(ohlc_interval_ids[:num_chart_segments]).to(self.dtype)
        batch_ohlc_embeddings_raw = self.chart_fusion(
            torch.cat([raw_chart_embeddings, quant_chart_embeddings, interval_embeds], dim=-1)
        )
    else:
        batch_ohlc_embeddings_raw = torch.empty(0, self.quant_ohlc_embedder.output_dim, device=device, dtype=self.dtype)
    # 1d. Run Graph Updater
    # Prepend the learnable padding row (index 0) to each entity table before message passing.
    pad_wallet_raw = self.pad_wallet_emb.to(self.dtype)
    pad_token_raw = self.pad_token_emb.to(self.dtype)
    padded_wallet_tensor = torch.cat([pad_wallet_raw, initial_wallet_embeddings], dim=0)
    padded_token_tensor = torch.cat([pad_token_raw, batch_token_embeddings_unupd], dim=0)
    # Only node types with at least one real entity (more than the padding row) are passed in.
    x_dict_initial = {}
    if padded_wallet_tensor.shape[0] > 1: x_dict_initial['wallet'] = padded_wallet_tensor
    if padded_token_tensor.shape[0] > 1: x_dict_initial['token'] = padded_token_tensor
    if x_dict_initial and graph_updater_links:
        final_entity_embeddings_dict = self.graph_updater(x_dict_initial, graph_updater_links)
        # Node types the updater does not return keep their pre-graph embeddings.
        final_padded_wallet_embs = final_entity_embeddings_dict.get('wallet', padded_wallet_tensor)
        final_padded_token_embs = final_entity_embeddings_dict.get('token', padded_token_tensor)
    else:
        final_padded_wallet_embs = padded_wallet_tensor
        final_padded_token_embs = padded_token_tensor
    # Strip the padding row before returning.
    final_wallet_embeddings_raw = final_padded_wallet_embs[1:]
    final_token_embeddings_raw = final_padded_token_embs[1:]
    return {
        "wallet": final_wallet_embeddings_raw,
        "token": final_token_embeddings_raw,
        "ohlc": batch_ohlc_embeddings_raw
    }
def _project_and_gather_embeddings(self, raw_embeds: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Project raw entity embeddings to d_model and gather them per sequence position.

    Each entity type gets a lookup table whose row 0 is its learnable padding
    embedding (projected) and whose rows 1.. are the real entities. Index 0 in the
    ``batch`` index tensors therefore selects padding.

    Args:
        raw_embeds: unprojected "wallet", "token" and "ohlc" tables, plus
            "holder_snapshot" (already d_model wide).
        batch: the ``*_indices`` tensors for each slot plus the raw ``embedding_pool``.

    Returns:
        Dict of gathered embeddings keyed by slot:
          - quote tokens get the quote-role embedding added;
          - trending and boosted tokens both get the trending-role embedding;
          - main tokens get no role offset.
        "precomputed" is the full projected lookup table (padding row first),
        not a gathered tensor.
    """
    # Project real entities to d_model.
    final_wallet_proj = self.wallet_proj(raw_embeds['wallet'])
    final_token_proj = self.token_proj(raw_embeds['token'])
    final_ohlc_proj = self.ohlc_proj(raw_embeds['ohlc'])
    # Project the padding rows with the same layers.
    pad_wallet = self.wallet_proj(self.pad_wallet_emb.to(self.dtype))
    pad_token = self.token_proj(self.pad_token_emb.to(self.dtype))
    pad_ohlc = self.ohlc_proj(self.pad_ohlc_emb.to(self.dtype))
    pad_holder_snapshot = self.pad_holder_snapshot_emb.to(self.dtype)  # already d_model
    # Pre-computed text/image pool: zero non-finite values, project, prepend the padding row.
    precomputed_pool = torch.nan_to_num(
        batch['embedding_pool'].to(self.device, self.dtype),
        nan=0.0,
        posinf=0.0,
        neginf=0.0
    )
    final_precomputed_proj = self.precomputed_proj(precomputed_pool)
    pad_precomputed = self.precomputed_proj(self.pad_precomputed_emb.to(self.dtype))
    final_precomputed_lookup = torch.cat([pad_precomputed, final_precomputed_proj], dim=0)
    # Lookup tables with padding at index 0.
    final_wallet_lookup = torch.cat([pad_wallet, final_wallet_proj], dim=0)
    final_token_lookup = torch.cat([pad_token, final_token_proj], dim=0)
    final_ohlc_lookup = torch.cat([pad_ohlc, final_ohlc_proj], dim=0)
    final_holder_snapshot_lookup = torch.cat([pad_holder_snapshot, raw_embeds['holder_snapshot']], dim=0)
    # Role embeddings (single d_model rows, broadcast over the gathered tensors).
    # The main token is the baseline and gets no role offset, so its role row is not looked up.
    quote_role_emb = self.token_role_embedding(torch.tensor(self.quote_token_role_id, device=self.device))
    trending_role_emb = self.token_role_embedding(torch.tensor(self.trending_token_role_id, device=self.device))
    gathered_main_token_embs = F.embedding(batch['token_indices'], final_token_lookup)
    gathered_quote_token_embs = F.embedding(batch['quote_token_indices'], final_token_lookup)
    gathered_trending_token_embs = F.embedding(batch['trending_token_indices'], final_token_lookup)
    gathered_boosted_token_embs = F.embedding(batch['boosted_token_indices'], final_token_lookup)
    return {
        "wallet": F.embedding(batch['wallet_indices'], final_wallet_lookup),
        "token": gathered_main_token_embs,
        "ohlc": F.embedding(batch['ohlc_indices'], final_ohlc_lookup),
        "original_author": F.embedding(batch['original_author_indices'], final_wallet_lookup),
        "dest_wallet": F.embedding(batch['dest_wallet_indices'], final_wallet_lookup),
        "quote_token": gathered_quote_token_embs + quote_role_emb,
        "trending_token": gathered_trending_token_embs + trending_role_emb,
        "boosted_token": gathered_boosted_token_embs + trending_role_emb,  # same role as trending
        "holder_snapshot": F.embedding(batch['holder_snapshot_indices'], final_holder_snapshot_lookup),
        "precomputed": final_precomputed_lookup  # full lookup table, not gathered
    }
def _get_transfer_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Build the per-position embedding contribution of Transfer / LargeTransfer events.

    The destination-wallet embedding is added to the projected numerical
    features, and every position whose event type is neither ``Transfer`` nor
    ``LargeTransfer`` is zeroed out.

    Scaling requested from ``_normalize_and_project``:
      * log-scaled: token_amount (0), priority_fee (3)
      * linear: transfer_pct_of_total_supply (1), transfer_pct_of_holding (2)
    """
    numeric_inputs = batch['transfer_numerical_features']
    type_ids = batch['event_type_ids']

    numeric_embeds = self._normalize_and_project(
        numeric_inputs,
        self.transfer_num_norm,
        self.transfer_num_proj,
        log_indices=[0, 3],
    )

    # Unknown event names map to -1, which never matches a real id.
    accepted = torch.tensor(
        [self.event_type_to_id.get(event_name, -1) for event_name in ('Transfer', 'LargeTransfer')],
        device=self.device,
    )
    keep = torch.isin(type_ids, accepted).unsqueeze(-1)

    with_wallet = gathered_embeds['dest_wallet'] + numeric_embeds
    return with_wallet * keep
def _get_trade_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Build the per-position embedding contribution of Trade / LargeTrade events.

    Generic and large trades share this path. The projected numerical
    features are summed with four categorical embeddings: dex platform, trade
    direction, MEV protection and bundle flag. The result is zeroed wherever
    the event type is neither ``Trade`` nor ``LargeTrade``.

    Scaling requested from ``_normalize_and_project``:
      * log-scaled: sol_amount (0), priority_fee (1), total_usd (7)
      * linear: percentages, slippage, price impact, success flags
    """
    numeric_inputs = batch['trade_numerical_features']
    dex_ids = batch['trade_dex_ids']
    direction_ids = batch['trade_direction_ids']
    mev_ids = batch['trade_mev_protection_ids']
    bundle_ids = batch['trade_is_bundle_ids']
    type_ids = batch['event_type_ids']

    numeric_embeds = self._normalize_and_project(
        numeric_inputs, self.trade_num_norm, self.trade_num_proj, log_indices=[0, 1, 7]
    )

    accepted = torch.tensor(
        [self.event_type_to_id.get(event_name, -1) for event_name in ('Trade', 'LargeTrade')],
        device=self.device,
    )
    keep = torch.isin(type_ids, accepted).unsqueeze(-1)

    pieces = (
        self.dex_platform_embedding(dex_ids),
        self.trade_direction_embedding(direction_ids),
        self.mev_protection_embedding(mev_ids),
        self.is_bundle_embedding(bundle_ids),
    )
    total = numeric_embeds
    for piece in pieces:  # left fold: same association as a + b + c + d + e
        total = total + piece
    return total * keep
def _get_deployer_trade_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Calculate the event-specific embeddings for Deployer_Trade events.

    Deployer trades have their own numerical normalisation/projection layers
    (``deployer_trade_num_norm`` / ``deployer_trade_num_proj``). They re-use
    the categorical ID tensors and embedding tables of regular trades: dex
    platform, direction, MEV protection and bundle flag.

    Args:
        batch: Collated batch. Reads ``deployer_trade_numerical_features``,
            ``trade_dex_ids``, ``trade_direction_ids``,
            ``trade_mev_protection_ids``, ``trade_is_bundle_ids`` and
            ``event_type_ids``.

    Returns:
        Projected numericals plus the four categorical embeddings, multiplied
        by a mask that zeroes every position whose event type is not
        ``Deployer_Trade``.
    """
    deployer_trade_numerical_features = batch['deployer_trade_numerical_features']
    trade_dex_ids = batch['trade_dex_ids']  # shared with regular Trade events
    trade_direction_ids = batch['trade_direction_ids']
    trade_mev_protection_ids = batch['trade_mev_protection_ids']
    trade_is_bundle_ids = batch['trade_is_bundle_ids']
    event_type_ids = batch['event_type_ids']
    # Log-scaled: sol_amount (idx 0), priority_fee (idx 1), total_usd (idx 7)
    projected_deployer_trade_features = self._normalize_and_project(
        deployer_trade_numerical_features, self.deployer_trade_num_norm, self.deployer_trade_num_proj, log_indices=[0, 1, 7]
    )
    dex_id_embeds = self.dex_platform_embedding(trade_dex_ids)
    direction_embeds = self.trade_direction_embedding(trade_direction_ids)
    mev_embeds = self.mev_protection_embedding(trade_mev_protection_ids)
    bundle_embeds = self.is_bundle_embedding(trade_is_bundle_ids)
    # A name missing from the vocabulary maps to -1, so the mask is then all False.
    deployer_trade_mask = (event_type_ids == self.event_type_to_id.get('Deployer_Trade', -1)).unsqueeze(-1)
    return (projected_deployer_trade_features + dex_id_embeds + direction_embeds + mev_embeds + bundle_embeds) * deployer_trade_mask
def _get_smart_wallet_trade_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Calculate the event-specific embeddings for SmartWallet_Trade events.

    Smart-wallet trades have their own numerical normalisation/projection
    layers (``smart_wallet_trade_num_norm`` / ``smart_wallet_trade_num_proj``).
    They re-use the categorical ID tensors and embedding tables of regular
    trades: dex platform, direction, MEV protection and bundle flag.

    Args:
        batch: Collated batch. Reads ``smart_wallet_trade_numerical_features``,
            ``trade_dex_ids``, ``trade_direction_ids``,
            ``trade_mev_protection_ids``, ``trade_is_bundle_ids`` and
            ``event_type_ids``.

    Returns:
        Projected numericals plus the four categorical embeddings, zeroed at
        every position whose event type is not ``SmartWallet_Trade``.
    """
    smart_wallet_trade_numerical_features = batch['smart_wallet_trade_numerical_features']
    trade_dex_ids = batch['trade_dex_ids']  # shared with regular Trade events
    trade_direction_ids = batch['trade_direction_ids']
    trade_mev_protection_ids = batch['trade_mev_protection_ids']
    trade_is_bundle_ids = batch['trade_is_bundle_ids']
    event_type_ids = batch['event_type_ids']
    # Log-scaled: sol_amount (idx 0), priority_fee (idx 1), total_usd (idx 7)
    projected_features = self._normalize_and_project(
        smart_wallet_trade_numerical_features, self.smart_wallet_trade_num_norm, self.smart_wallet_trade_num_proj, log_indices=[0, 1, 7]
    )
    dex_id_embeds = self.dex_platform_embedding(trade_dex_ids)
    direction_embeds = self.trade_direction_embedding(trade_direction_ids)
    mev_embeds = self.mev_protection_embedding(trade_mev_protection_ids)
    bundle_embeds = self.is_bundle_embedding(trade_is_bundle_ids)
    mask = (event_type_ids == self.event_type_to_id.get('SmartWallet_Trade', -1)).unsqueeze(-1)
    return (projected_features + dex_id_embeds + direction_embeds + mev_embeds + bundle_embeds) * mask
def _get_pool_created_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Calculate the event-specific embeddings for PoolCreated events.

    Args:
        batch: Collated batch. Reads ``pool_created_numerical_features``,
            ``pool_created_protocol_ids`` and ``event_type_ids``.
        gathered_embeds: Per-position entity embeddings; ``quote_token`` is used.

    Returns:
        Quote-token embedding + projected numericals + protocol embedding,
        zeroed at every position whose event type is not ``PoolCreated``.
    """
    pool_created_numerical_features = batch['pool_created_numerical_features']
    pool_created_protocol_ids = batch['pool_created_protocol_ids']
    event_type_ids = batch['event_type_ids']
    # Log-scaled: base_amount (idx 0), quote_amount (idx 1); linear: pcts (idx 2, 3)
    projected_features = self._normalize_and_project(
        pool_created_numerical_features, self.pool_created_num_norm, self.pool_created_num_proj, log_indices=[0, 1]
    )
    protocol_id_embeds = self.protocol_embedding(pool_created_protocol_ids)
    mask = (event_type_ids == self.event_type_to_id.get('PoolCreated', -1)).unsqueeze(-1)
    return (gathered_embeds['quote_token'] + projected_features + protocol_id_embeds) * mask
def _get_liquidity_change_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Calculate the event-specific embeddings for LiquidityChange events.

    Args:
        batch: Collated batch. Reads ``liquidity_change_numerical_features``,
            ``liquidity_change_type_ids`` and ``event_type_ids``.
        gathered_embeds: Per-position entity embeddings; ``quote_token`` is used.

    Returns:
        Quote-token embedding + projected numericals + change-type embedding,
        zeroed at every position whose event type is not ``LiquidityChange``.
    """
    liquidity_change_numerical_features = batch['liquidity_change_numerical_features']
    liquidity_change_type_ids = batch['liquidity_change_type_ids']
    event_type_ids = batch['event_type_ids']
    # Log-scaled: quote_amount (idx 0)
    projected_features = self._normalize_and_project(
        liquidity_change_numerical_features, self.liquidity_change_num_norm, self.liquidity_change_num_proj, log_indices=[0]
    )
    change_type_embeds = self.liquidity_change_type_embedding(liquidity_change_type_ids)
    mask = (event_type_ids == self.event_type_to_id.get('LiquidityChange', -1)).unsqueeze(-1)
    return (gathered_embeds['quote_token'] + projected_features + change_type_embeds) * mask
def _get_fee_collected_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Calculate the event-specific embeddings for FeeCollected events.

    Args:
        batch: Collated batch. Reads ``fee_collected_numerical_features`` and
            ``event_type_ids``.

    Returns:
        Projected numerical features, zeroed at every position whose event
        type is not ``FeeCollected``.
    """
    fee_collected_numerical_features = batch['fee_collected_numerical_features']
    event_type_ids = batch['event_type_ids']
    # Single fee amount (idx 0), log-scaled.
    projected_features = self._normalize_and_project(
        fee_collected_numerical_features, self.fee_collected_num_norm, self.fee_collected_num_proj, log_indices=[0]
    )
    mask = (event_type_ids == self.event_type_to_id.get('FeeCollected', -1)).unsqueeze(-1)
    return projected_features * mask
def _get_token_burn_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Calculate the event-specific embeddings for TokenBurn events.

    Args:
        batch: Collated batch. Reads ``token_burn_numerical_features`` and
            ``event_type_ids``.

    Returns:
        Projected numerical features, zeroed at every position whose event
        type is not ``TokenBurn``.
    """
    token_burn_numerical_features = batch['token_burn_numerical_features']
    event_type_ids = batch['event_type_ids']
    # Log-scaled: amount_tokens_burned (idx 1); linear: amount_pct_of_total_supply (idx 0)
    projected_features = self._normalize_and_project(
        token_burn_numerical_features, self.token_burn_num_norm, self.token_burn_num_proj, log_indices=[1]
    )
    mask = (event_type_ids == self.event_type_to_id.get('TokenBurn', -1)).unsqueeze(-1)
    return projected_features * mask
def _get_supply_lock_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Calculate the event-specific embeddings for SupplyLock events.

    Args:
        batch: Collated batch. Reads ``supply_lock_numerical_features`` and
            ``event_type_ids``.

    Returns:
        Projected numerical features, zeroed at every position whose event
        type is not ``SupplyLock``.
    """
    supply_lock_numerical_features = batch['supply_lock_numerical_features']
    event_type_ids = batch['event_type_ids']
    # Log-scaled: lock_duration (idx 1); linear: amount_pct_of_total_supply (idx 0)
    projected_features = self._normalize_and_project(
        supply_lock_numerical_features, self.supply_lock_num_norm, self.supply_lock_num_proj, log_indices=[1]
    )
    mask = (event_type_ids == self.event_type_to_id.get('SupplyLock', -1)).unsqueeze(-1)
    return projected_features * mask
def _get_onchain_snapshot_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Calculate the event-specific embeddings for OnChain_Snapshot events.

    Args:
        batch: Collated batch. Reads ``onchain_snapshot_numerical_features``
            and ``event_type_ids``.

    Returns:
        Projected numerical features, zeroed at every position whose event
        type is not ``OnChain_Snapshot``.
    """
    onchain_snapshot_numerical_features = batch['onchain_snapshot_numerical_features']
    event_type_ids = batch['event_type_ids']
    # Log-scaled: counts, market_cap, liquidity, volume, fees (idx 0-2, 8-13).
    # Linear: growth_rate and holder percentages (idx 3-7).
    projected_features = self._normalize_and_project(
        onchain_snapshot_numerical_features, self.onchain_snapshot_num_norm, self.onchain_snapshot_num_proj, log_indices=[0, 1, 2, 8, 9, 10, 11, 12, 13]
    )
    mask = (event_type_ids == self.event_type_to_id.get('OnChain_Snapshot', -1)).unsqueeze(-1)
    return projected_features * mask
def _get_trending_token_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Calculate the event-specific embeddings for TrendingToken events.

    Args:
        batch: Collated batch. Reads ``trending_token_numerical_features``,
            ``trending_token_source_ids``, ``trending_token_timeframe_ids``
            and ``event_type_ids``.
        gathered_embeds: Per-position entity embeddings; ``trending_token`` is used.

    Returns:
        Trending-token embedding + projected numericals + list-source and
        timeframe embeddings, zeroed at every position whose event type is
        not ``TrendingToken``.
    """
    trending_token_numerical_features = batch['trending_token_numerical_features']
    trending_token_source_ids = batch['trending_token_source_ids']
    trending_token_timeframe_ids = batch['trending_token_timeframe_ids']
    event_type_ids = batch['event_type_ids']
    # The rank feature is already inverted into [0, 1], so nothing is log-scaled.
    projected_features = self._normalize_and_project(
        trending_token_numerical_features, self.trending_token_num_norm, self.trending_token_num_proj, log_indices=None
    )
    source_embeds = self.trending_list_source_embedding(trending_token_source_ids)
    timeframe_embeds = self.trending_timeframe_embedding(trending_token_timeframe_ids)
    mask = (event_type_ids == self.event_type_to_id.get('TrendingToken', -1)).unsqueeze(-1)
    return (gathered_embeds['trending_token'] + projected_features + source_embeds + timeframe_embeds) * mask
def _get_boosted_token_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Calculate the event-specific embeddings for BoostedToken events.

    Args:
        batch: Collated batch. Reads ``boosted_token_numerical_features`` and
            ``event_type_ids``.
        gathered_embeds: Per-position entity embeddings; ``boosted_token`` is used.

    Returns:
        Boosted-token embedding + projected numericals, zeroed at every
        position whose event type is not ``BoostedToken``.
    """
    boosted_token_numerical_features = batch['boosted_token_numerical_features']
    event_type_ids = batch['event_type_ids']
    # Log-scaled: total_boost_amount (idx 0); linear: inverted rank (idx 1)
    projected_features = self._normalize_and_project(
        boosted_token_numerical_features, self.boosted_token_num_norm, self.boosted_token_num_proj, log_indices=[0]
    )
    mask = (event_type_ids == self.event_type_to_id.get('BoostedToken', -1)).unsqueeze(-1)
    return (gathered_embeds['boosted_token'] + projected_features) * mask
def _get_dexboost_paid_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Calculate the event-specific embeddings for DexBoost_Paid events.

    Args:
        batch: Collated batch. Reads ``dexboost_paid_numerical_features`` and
            ``event_type_ids``.

    Returns:
        Projected numerical features, zeroed at every position whose event
        type is not ``DexBoost_Paid``.
    """
    dexboost_paid_numerical_features = batch['dexboost_paid_numerical_features']
    event_type_ids = batch['event_type_ids']
    # Both features (idx 0, 1) are amounts and are log-scaled.
    projected_features = self._normalize_and_project(
        dexboost_paid_numerical_features, self.dexboost_paid_num_norm, self.dexboost_paid_num_proj, log_indices=[0, 1]
    )
    mask = (event_type_ids == self.event_type_to_id.get('DexBoost_Paid', -1)).unsqueeze(-1)
    return projected_features * mask
def _get_alphagroup_call_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Handle AlphaGroup_Call events by looking up the group-id embedding.

    Args:
        batch: Collated batch. Reads ``alpha_group_ids`` and ``event_type_ids``.

    Returns:
        Alpha-group embeddings, zeroed at every position whose event type is
        not ``AlphaGroup_Call``.
    """
    group_ids = batch['alpha_group_ids']
    event_type_ids = batch['event_type_ids']
    group_embeds = self.alpha_group_embedding(group_ids)
    mask = (event_type_ids == self.event_type_to_id.get('AlphaGroup_Call', -1)).unsqueeze(-1)
    return group_embeds * mask
def _get_channel_call_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Handle Channel_Call events by looking up the channel-id embedding.

    Args:
        batch: Collated batch. Reads ``channel_ids`` and ``event_type_ids``.

    Returns:
        Call-channel embeddings, zeroed at every position whose event type is
        not ``Channel_Call``.
    """
    channel_ids = batch['channel_ids']
    event_type_ids = batch['event_type_ids']
    channel_embeds = self.call_channel_embedding(channel_ids)
    mask = (event_type_ids == self.event_type_to_id.get('Channel_Call', -1)).unsqueeze(-1)
    return channel_embeds * mask
def _get_cexlisting_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Handle CexListing events by looking up the exchange-id embedding.

    Args:
        batch: Collated batch. Reads ``exchange_ids`` and ``event_type_ids``.

    Returns:
        Exchange embeddings, zeroed at every position whose event type is not
        ``CexListing``.
    """
    exchange_ids = batch['exchange_ids']
    event_type_ids = batch['event_type_ids']
    exchange_embeds = self.cex_listing_embedding(exchange_ids)
    mask = (event_type_ids == self.event_type_to_id.get('CexListing', -1)).unsqueeze(-1)
    return exchange_embeds * mask
def _get_chainsnapshot_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Handle ChainSnapshot events.

    Args:
        batch: Collated batch. Reads ``chainsnapshot_numerical_features`` and
            ``event_type_ids``.

    Returns:
        Projected numerical features, zeroed at every position whose event
        type is not ``ChainSnapshot``.
    """
    numerical_features = batch['chainsnapshot_numerical_features']
    event_type_ids = batch['event_type_ids']
    # Both features (idx 0, 1) are amounts/prices and are log-scaled.
    projected_features = self._normalize_and_project(
        numerical_features, self.chainsnapshot_num_norm, self.chainsnapshot_num_proj, log_indices=[0, 1]
    )
    mask = (event_type_ids == self.event_type_to_id.get('ChainSnapshot', -1)).unsqueeze(-1)
    return projected_features * mask
def _get_lighthousesnapshot_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Handle Lighthouse_Snapshot events.

    Args:
        batch: Collated batch. Reads ``lighthousesnapshot_numerical_features``,
            ``lighthousesnapshot_protocol_ids``,
            ``lighthousesnapshot_timeframe_ids`` and ``event_type_ids``.

    Returns:
        Projected numericals + protocol embedding + timeframe embedding,
        zeroed at every position whose event type is not ``Lighthouse_Snapshot``.
    """
    numerical_features = batch['lighthousesnapshot_numerical_features']
    protocol_ids = batch['lighthousesnapshot_protocol_ids']
    timeframe_ids = batch['lighthousesnapshot_timeframe_ids']
    event_type_ids = batch['event_type_ids']
    # All five features (idx 0-4) are counts/volumes and are log-scaled.
    projected_features = self._normalize_and_project(
        numerical_features, self.lighthousesnapshot_num_norm, self.lighthousesnapshot_num_proj, log_indices=[0, 1, 2, 3, 4]
    )
    # Re-uses the shared protocol embedding table (also used by PoolCreated / Migrated).
    protocol_embeds = self.protocol_embedding(protocol_ids)
    timeframe_embeds = self.lighthouse_timeframe_embedding(timeframe_ids)
    mask = (event_type_ids == self.event_type_to_id.get('Lighthouse_Snapshot', -1)).unsqueeze(-1)
    return (projected_features + protocol_embeds + timeframe_embeds) * mask
def _get_migrated_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Handle Migrated events by looking up the protocol-id embedding.

    Args:
        batch: Collated batch. Reads ``migrated_protocol_ids`` and ``event_type_ids``.

    Returns:
        Protocol embeddings, zeroed at every position whose event type is not
        ``Migrated``.
    """
    protocol_ids = batch['migrated_protocol_ids']
    event_type_ids = batch['event_type_ids']
    protocol_embeds = self.protocol_embedding(protocol_ids)
    mask = (event_type_ids == self.event_type_to_id.get('Migrated', -1)).unsqueeze(-1)
    return protocol_embeds * mask
def _get_special_context_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Inject the learnable embeddings of the special context markers.

    Positions whose event type is ``MIDDLE`` receive the ``MIDDLE`` context
    vector, and positions whose event type is ``RECENT`` receive the
    ``RECENT`` one. All other positions receive zeros.

    Args:
        batch: Collated batch; only ``event_type_ids`` (a 2-D tensor) is read.

    Returns:
        The broadcast sum of both masked context vectors.
    """
    type_ids = batch['event_type_ids']
    _batch_size, _seq_len = type_ids.shape  # unpacking doubles as a 2-D shape check

    middle_hits = type_ids == self.event_type_to_id.get('MIDDLE', -1)
    recent_hits = type_ids == self.event_type_to_id.get('RECENT', -1)

    def _context_vector(marker: str) -> torch.Tensor:
        # Look up the single learnable vector assigned to this marker.
        marker_idx = torch.tensor(self.special_context_tokens[marker], device=self.device)
        return self.special_context_embedding(marker_idx)

    middle_vec = _context_vector('MIDDLE')
    recent_vec = _context_vector('RECENT')
    return middle_hits.unsqueeze(-1) * middle_vec + recent_hits.unsqueeze(-1) * recent_vec
def _pool_hidden_states(self,
                        hidden_states: torch.Tensor,
                        attention_mask: torch.Tensor) -> torch.Tensor:
    """
    Pool variable-length hidden states into a single embedding per sequence.

    For every batch element, this returns the hidden state at the *last*
    position whose attention-mask entry is non-zero. The position is read
    directly from the mask, so right-padded, left-padded and gapped masks are
    all handled. The earlier ``mask.sum() - 1`` formulation was only correct
    for right-padded, contiguous masks.

    Args:
        hidden_states: ``(batch, seq_len, d_model)`` tensor.
        attention_mask: ``(batch, seq_len)`` tensor; non-zero marks real tokens.

    Returns:
        ``(batch, d_model)`` tensor. Rows whose mask is entirely zero fall back
        to position 0. An empty batch yields an empty ``(0, d_model)`` tensor.
    """
    if hidden_states.size(0) == 0:
        return torch.empty(0, self.d_model, device=hidden_states.device, dtype=hidden_states.dtype)
    seq_len = hidden_states.size(1)
    positions = torch.arange(seq_len, device=attention_mask.device)
    valid = attention_mask.bool()
    # Masked-out positions get -1, so they never win the max. An all-zero row
    # yields -1, which the clamp maps to position 0.
    scored = torch.where(valid, positions, torch.full_like(positions, -1))
    last_indices = scored.max(dim=1).values.clamp(min=0)
    batch_indices = torch.arange(hidden_states.size(0), device=hidden_states.device)
    return hidden_states[batch_indices, last_indices]
def forward(self, batch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    """
    Run the full model on a collated batch of event sequences.

    Pipeline:
      1. dynamic entity encoders -> raw entity embeddings,
      2. snapshot encoders (fed the dynamic wallet embeddings),
      3. projection + per-position gathering of entity embeddings,
      4. summation of every per-event feature embedding into ``inputs_embeds``,
      5. the backbone ``self.model`` followed by last-valid-token pooling,
      6. quantile / quality / movement heads on the pooled state.

    Args:
        batch: Collated batch dict. Must contain ``event_type_ids``,
            ``timestamps_float``, ``relative_ts`` and ``attention_mask``
            (``event_type_ids`` is unpacked as ``(B, L)``). It must also hold
            every key read by the encoders and the ``_get_*`` helpers.

    Returns:
        Dict with ``quantile_logits`` ``(B, num_outputs)``, ``quality_logits``
        ``(B,)``, ``movement_logits`` ``(B, len(horizons_seconds),
        num_movement_classes)``, ``pooled_states`` ``(B, d_model)``,
        ``hidden_states`` (backbone last hidden state) and ``attention_mask``
        (long). If ``B == 0`` or ``L == 0``, every returned tensor is empty
        with batch size 0.
    """
    device = self.device
    # Unpack core sequence tensors. Relative timestamps are read further down,
    # directly in float32; no model-dtype copy is needed.
    event_type_ids = batch['event_type_ids'].to(device)
    timestamps_float = batch['timestamps_float'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    B, L = event_type_ids.shape
    if B == 0 or L == 0:
        # NOTE(review): for B > 0 with L == 0 the outputs still have a batch
        # dimension of 0 rather than B — confirm downstream code expects this.
        print("Warning: Received empty batch in Oracle forward.")
        empty_hidden = torch.empty(0, L, self.d_model, device=device, dtype=self.dtype)
        empty_mask = torch.empty(0, L, device=device, dtype=torch.long)
        empty_quantiles = torch.empty(0, self.num_outputs, device=device, dtype=self.dtype)
        empty_quality = torch.empty(0, device=device, dtype=self.dtype)
        empty_movement = torch.empty(0, len(self.horizons_seconds), self.num_movement_classes, device=device, dtype=self.dtype)
        return {
            'quantile_logits': empty_quantiles,
            'quality_logits': empty_quality,
            'movement_logits': empty_movement,
            'pooled_states': torch.empty(0, self.d_model, device=device, dtype=self.dtype),
            'hidden_states': empty_hidden,
            'attention_mask': empty_mask
        }
    # === 1. Run dynamic encoders (produces graph-updated entity embeddings) ===
    dynamic_raw_embeds = self._run_dynamic_encoders(batch)
    # === 2. Run snapshot encoders (uses dynamic_raw_embeds) ===
    wallet_addr_to_batch_idx = batch['wallet_addr_to_batch_idx']
    snapshot_raw_embeds = self._run_snapshot_encoders(batch, dynamic_raw_embeds['wallet'], wallet_addr_to_batch_idx)
    # === 3. Project raw embeddings and gather them per sequence position ===
    raw_embeds = {**dynamic_raw_embeds, **snapshot_raw_embeds}
    gathered_embeds = self._project_and_gather_embeddings(raw_embeds, batch)
    # === 4. Assemble the final `inputs_embeds` ===
    event_embeds = self.event_type_embedding(event_type_ids)
    ts_embeds = self.time_proj(self.time_encoder(timestamps_float))
    # Stabilize relative time: minutes scale + signed log1p + LayerNorm before projection
    relative_ts_fp32 = batch['relative_ts'].to(device, torch.float32)
    rel_ts_minutes = relative_ts_fp32 / 60.0
    rel_ts_processed = torch.sign(rel_ts_minutes) * torch.log1p(torch.abs(rel_ts_minutes))
    # Match LayerNorm parameter dtype, then match Linear parameter dtype
    norm_dtype = self.rel_ts_norm.weight.dtype
    proj_dtype = self.rel_ts_proj.weight.dtype
    rel_ts_normed = self.rel_ts_norm(rel_ts_processed.to(norm_dtype))
    rel_ts_embeds = self.rel_ts_proj(rel_ts_normed.to(proj_dtype))
    # Event-specific embeddings. Each _get_* helper multiplies its output by an
    # event-type mask, so positions of other event types contribute zeros.
    transfer_specific_embeds = self._get_transfer_specific_embeddings(batch, gathered_embeds)
    trade_specific_embeds = self._get_trade_specific_embeddings(batch)
    deployer_trade_specific_embeds = self._get_deployer_trade_specific_embeddings(batch)
    smart_wallet_trade_specific_embeds = self._get_smart_wallet_trade_specific_embeddings(batch)
    pool_created_specific_embeds = self._get_pool_created_specific_embeddings(batch, gathered_embeds)
    liquidity_change_specific_embeds = self._get_liquidity_change_specific_embeddings(batch, gathered_embeds)
    fee_collected_specific_embeds = self._get_fee_collected_specific_embeddings(batch)
    token_burn_specific_embeds = self._get_token_burn_specific_embeddings(batch)
    supply_lock_specific_embeds = self._get_supply_lock_specific_embeddings(batch)
    onchain_snapshot_specific_embeds = self._get_onchain_snapshot_specific_embeddings(batch)
    trending_token_specific_embeds = self._get_trending_token_specific_embeddings(batch, gathered_embeds)
    boosted_token_specific_embeds = self._get_boosted_token_specific_embeddings(batch, gathered_embeds)
    dexboost_paid_specific_embeds = self._get_dexboost_paid_specific_embeddings(batch)
    # Tracker events
    alphagroup_call_specific_embeds = self._get_alphagroup_call_specific_embeddings(batch)
    channel_call_specific_embeds = self._get_channel_call_specific_embeddings(batch)
    cexlisting_specific_embeds = self._get_cexlisting_specific_embeddings(batch)
    # Chain and Lighthouse snapshots, plus migrations
    chainsnapshot_specific_embeds = self._get_chainsnapshot_specific_embeddings(batch)
    lighthousesnapshot_specific_embeds = self._get_lighthousesnapshot_specific_embeddings(batch)
    migrated_specific_embeds = self._get_migrated_specific_embeddings(batch)
    # DexProfile_Updated flags are projected directly after a cast to the model dtype.
    dexprofile_updated_flags = batch['dexprofile_updated_flags']
    dexprofile_flags_embeds = self.dexprofile_updated_flags_proj(dexprofile_updated_flags.to(self.dtype))
    # All text-based events (social, dex profile, global trending) go through the SocialEncoder.
    textual_event_embeds = self.social_encoder(
        batch=batch,
        gathered_embeds=gathered_embeds
    )
    # Special context injection tokens (MIDDLE / RECENT)
    special_context_embeds = self._get_special_context_embeddings(batch)
    # --- Combine all features ---
    components = [
        event_embeds, ts_embeds, rel_ts_embeds,
        gathered_embeds['wallet'], gathered_embeds['token'], gathered_embeds['original_author'], gathered_embeds['ohlc'],
        transfer_specific_embeds, trade_specific_embeds, deployer_trade_specific_embeds, smart_wallet_trade_specific_embeds,
        pool_created_specific_embeds, liquidity_change_specific_embeds, fee_collected_specific_embeds,
        token_burn_specific_embeds, supply_lock_specific_embeds, onchain_snapshot_specific_embeds,
        trending_token_specific_embeds, boosted_token_specific_embeds, dexboost_paid_specific_embeds,
        alphagroup_call_specific_embeds, channel_call_specific_embeds, cexlisting_specific_embeds,
        migrated_specific_embeds, special_context_embeds, gathered_embeds['holder_snapshot'], textual_event_embeds,
        dexprofile_flags_embeds, chainsnapshot_specific_embeds, lighthousesnapshot_specific_embeds
    ]
    # Sum in float32 for numerical stability, then cast back to the model dtype.
    # Using a generator instead of a list keeps at most one upcast copy alive at
    # a time besides the running total. The summation order is unchanged.
    inputs_embeds = sum(t.float() for t in components).to(self.dtype)
    hf_attention_mask = attention_mask.to(device=device, dtype=torch.long)
    outputs = self.model(
        inputs_embeds=inputs_embeds,
        attention_mask=hf_attention_mask,
        return_dict=True
    )
    sequence_hidden = outputs.last_hidden_state
    pooled_states = self._pool_hidden_states(sequence_hidden, hf_attention_mask)
    quantile_logits = self.quantile_head(pooled_states)
    quality_logits = self.quality_head(pooled_states).squeeze(-1)
    movement_logits = self.movement_head(pooled_states).view(
        pooled_states.shape[0],
        len(self.horizons_seconds),
        self.num_movement_classes,
    )
    return {
        'quantile_logits': quantile_logits,
        'quality_logits': quality_logits,
        'movement_logits': movement_logits,
        'pooled_states': pooled_states,
        'hidden_states': sequence_hidden,
        'attention_mask': hf_attention_mask
    }
|