# Source path: oracle/data/data_collator.py
# Hosting-page residue kept for provenance: "zirobtc's picture", "Upload folder using huggingface_hub", "d195287 verified".
# memecoin_collator.py (CORRECTED ORDER OF OPERATIONS)
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from typing import List, Dict, Any, Tuple, Optional, Union
from collections import defaultdict
from PIL import Image
# --- GLOBAL SINGLETON FOR WORKER PROCESSES REMOVED ---
import models.vocabulary as vocab
from data.data_loader import EmbeddingPooler
from data.quant_ohlc_feature_schema import FEATURE_VERSION, FEATURE_VERSION_ID, NUM_QUANT_OHLC_FEATURES, TOKENS_PER_SEGMENT
# Mint address this module labels as SOL (see the inline tag inside QUOTE_MINTS).
NATIVE_MINT = "So11111111111111111111111111111111111111112"
# Mint addresses treated as quote currencies. MemecoinCollator._collate_graph_links
# skips TransferLink entries whose 'mint' is in this set when deriving
# wallet -> token ("TransferLinkToken") edges.
QUOTE_MINTS = {
NATIVE_MINT, # SOL
"EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v", # USDC
"Es9vMFrzaCERmJfrF4H2FYD4KCoNkY11McCe8BenwNYB", # USDT
"USD1ttGY1N17NEEHLmELoaybftRBUSErhqYiQzvEmuB", # USD1
}
class MemecoinCollator:
"""
Callable class for PyTorch DataLoader's collate_fn.
... (rest of docstring) ...
"""
def __init__(self,
             event_type_to_id: Dict[str, int],
             device: torch.device,
             dtype: torch.dtype,
             max_seq_len: Optional[int] = None,
             model_id: str = "google/siglip-so400m-patch16-256-i18n"
             ):
    """
    Store the vocabulary, tensor placement and fixed sizes used during collation.

    Args:
        event_type_to_id: Maps event-type names to integer ids. The id stored
            under '__PAD__' (0 when that key is absent) becomes the sequence
            padding id.
        device: Device on which collated tensors are created.
        dtype: Floating dtype used for the numerical feature tensors.
        max_seq_len: When not None, event sequences longer than this are
            truncated in __call__.
        model_id: Identifier string kept on the instance; it is only stored
            here (the former multi-modal encoder attribute is deprecated).
    """
    # Where and in which precision tensors are built.
    self.device = device
    self.dtype = dtype

    # Event vocabulary and padding ids (entity pointers use 0 as "no entity").
    self.event_type_to_id = event_type_to_id
    self.pad_token_id = event_type_to_id.get('__PAD__', 0)
    self.entity_pad_idx = 0

    # Fixed chart sizes: price window length is hardcoded; the quantitative
    # feature sizes come from the feature schema module.
    self.ohlc_seq_len = 300
    self.quant_ohlc_tokens = TOKENS_PER_SEGMENT
    self.quant_ohlc_num_features = NUM_QUANT_OHLC_FEATURES

    # Remaining configuration.
    self.max_seq_len = max_seq_len
    self.model_id = model_id
def _collate_features_for_encoder(self, entities: List[Dict], feature_keys: List[str], device: torch.device, entity_type: str) -> Dict[str, Any]:
""" (Unchanged) """
collated = defaultdict(list)
if not entities:
# --- FIXED: Return a default empty structure for BOTH tokens and wallets ---
if entity_type == "token":
return {
'name_embed_indices': torch.tensor([], device=device, dtype=torch.long),
'symbol_embed_indices': torch.tensor([], device=device, dtype=torch.long),
'image_embed_indices': torch.tensor([], device=device, dtype=torch.long),
'protocol_ids': torch.tensor([], device=device, dtype=torch.long),
'is_vanity_flags': torch.tensor([], device=device, dtype=torch.bool),
'_addresses_for_lookup': []
}
elif entity_type == "wallet":
return {
'username_embed_indices': torch.tensor([], device=device, dtype=torch.long),
'profile_rows': [], 'social_rows': [], 'holdings_batch': []
}
return {} # Should not happen
# NEW: We now gather indices to pre-computed embeddings
if entity_type == "token":
# This indicates a Token entity
# Helper key for WalletEncoder to find token vibes
collated['_addresses_for_lookup'] = [e.get('address', '') for e in entities]
collated['name_embed_indices'] = torch.tensor([e.get('name_emb_idx', 0) for e in entities], device=device, dtype=torch.long)
collated['symbol_embed_indices'] = torch.tensor([e.get('symbol_emb_idx', 0) for e in entities], device=device, dtype=torch.long)
collated['image_embed_indices'] = torch.tensor([e.get('image_emb_idx', 0) for e in entities], device=device, dtype=torch.long)
collated['protocol_ids'] = torch.tensor([e.get('protocol', 0) for e in entities], device=device, dtype=torch.long)
collated['is_vanity_flags'] = torch.tensor([e.get('is_vanity', False) for e in entities], device=device, dtype=torch.bool)
elif entity_type == "wallet":
# NEW: Gather username indices for WalletEncoder
collated['username_embed_indices'] = torch.tensor([e.get('socials', {}).get('username_emb_idx', 0) for e in entities], device=device, dtype=torch.long)
collated['profile_rows'] = [e.get('profile', {}) for e in entities]
collated['social_rows'] = [e.get('socials', {}) for e in entities]
collated['holdings_batch'] = [e.get('holdings', []) for e in entities]
return dict(collated)
def _collate_ohlc_inputs(self, chart_events: List[Dict]) -> Dict[str, torch.Tensor]:
""" (Unchanged from previous correct version) """
if not chart_events:
return {
'price_tensor': torch.empty(0, 2, self.ohlc_seq_len, device=self.device, dtype=self.dtype),
'interval_ids': torch.empty(0, device=self.device, dtype=torch.long),
'quant_feature_tensors': torch.empty(0, self.quant_ohlc_tokens, self.quant_ohlc_num_features, device=self.device, dtype=self.dtype),
'quant_feature_mask': torch.empty(0, self.quant_ohlc_tokens, device=self.device, dtype=self.dtype),
'quant_feature_version_ids': torch.empty(0, device=self.device, dtype=torch.long),
}
ohlc_tensors = []
interval_ids_list = []
quant_feature_tensors = []
quant_feature_masks = []
quant_feature_version_ids = []
seq_len = self.ohlc_seq_len
unknown_id = vocab.INTERVAL_TO_ID.get("Unknown", 0)
for segment_data in chart_events:
opens = segment_data.get('opens', [])
closes = segment_data.get('closes', [])
interval_str = segment_data.get('i', "Unknown")
pad_open = opens[-1] if opens else 0
pad_close = closes[-1] if closes else 0
o = torch.tensor(opens[:seq_len] + [pad_open]*(seq_len-len(opens)), dtype=self.dtype)
c = torch.tensor(closes[:seq_len] + [pad_close]*(seq_len-len(closes)), dtype=self.dtype)
o = torch.nan_to_num(o, nan=0.0, posinf=0.0, neginf=0.0)
c = torch.nan_to_num(c, nan=0.0, posinf=0.0, neginf=0.0)
ohlc_tensors.append(torch.stack([o, c]))
interval_id = vocab.INTERVAL_TO_ID.get(interval_str, unknown_id)
interval_ids_list.append(interval_id)
quant_payload = segment_data.get('quant_ohlc_features')
if quant_payload is None:
raise RuntimeError("Chart_Segment missing quant_ohlc_features. Rebuild cache with quantitative chart features.")
if not isinstance(quant_payload, list):
raise RuntimeError("Chart_Segment quant_ohlc_features must be a list.")
feature_rows = []
feature_mask = []
for token_idx in range(self.quant_ohlc_tokens):
if token_idx < len(quant_payload):
payload = quant_payload[token_idx]
vec = payload.get('feature_vector')
if not isinstance(vec, list) or len(vec) != self.quant_ohlc_num_features:
raise RuntimeError(
f"Chart_Segment quant feature vector must have length {self.quant_ohlc_num_features}."
)
feature_rows.append(vec)
feature_mask.append(1.0)
else:
feature_rows.append([0.0] * self.quant_ohlc_num_features)
feature_mask.append(0.0)
quant_feature_tensors.append(torch.tensor(feature_rows, device=self.device, dtype=self.dtype))
quant_feature_masks.append(torch.tensor(feature_mask, device=self.device, dtype=self.dtype))
version = segment_data.get('quant_feature_version', FEATURE_VERSION)
quant_feature_version_ids.append(FEATURE_VERSION_ID if version == FEATURE_VERSION else 0)
return {
'price_tensor': torch.stack(ohlc_tensors).to(self.device),
'interval_ids': torch.tensor(interval_ids_list, device=self.device, dtype=torch.long),
'quant_feature_tensors': torch.stack(quant_feature_tensors).to(self.device),
'quant_feature_mask': torch.stack(quant_feature_masks).to(self.device),
'quant_feature_version_ids': torch.tensor(quant_feature_version_ids, device=self.device, dtype=torch.long),
}
def _collate_graph_links(self,
batch_items: List[Dict],
wallet_addr_to_batch_idx: Dict[str, int],
token_addr_to_batch_idx: Dict[str, int]) -> Dict[str, Any]:
""" (Unchanged) """
aggregated_links = defaultdict(lambda: {'edge_index_list': [], 'links_list': []})
for item in batch_items:
item_wallets = item.get('wallets', {})
item_tokens = item.get('tokens', {})
item_wallet_addr_to_global_idx = {addr: wallet_addr_to_batch_idx.get(addr, self.entity_pad_idx) for addr in item_wallets.keys()}
item_token_addr_to_global_idx = {addr: token_addr_to_batch_idx.get(addr, self.entity_pad_idx) for addr in item_tokens.keys()}
for link_name, data in item.get('graph_links', {}).items():
# aggregated_links[link_name]['links_list'].extend(data.get('links', [])) - REMOVED: Now handled inside the loop for sync
triplet = vocab.LINK_NAME_TO_TRIPLET.get(link_name)
if not triplet: continue
src_type, _, dst_type = triplet
edges = data.get('edges')
link_props_list = data.get('links', [])
if not edges or not link_props_list: continue
src_map = item_wallet_addr_to_global_idx if src_type == 'wallet' else item_token_addr_to_global_idx
dst_map = item_wallet_addr_to_global_idx if dst_type == 'wallet' else item_token_addr_to_global_idx
remapped_edge_list = []
valid_link_props = []
for (src_addr, dst_addr), props in zip(edges, link_props_list):
src_idx_global = src_map.get(src_addr, self.entity_pad_idx)
dst_idx_global = dst_map.get(dst_addr, self.entity_pad_idx)
if src_idx_global != self.entity_pad_idx and dst_idx_global != self.entity_pad_idx:
remapped_edge_list.append([src_idx_global, dst_idx_global])
valid_link_props.append(props)
if remapped_edge_list:
remapped_edge_tensor = torch.tensor(remapped_edge_list, device=self.device, dtype=torch.long).t()
aggregated_links[link_name]['edge_index_list'].append(remapped_edge_tensor)
aggregated_links[link_name]['links_list'].extend(valid_link_props)
if link_name == "TransferLink":
link_props = data.get('links', [])
derived_edges = []
derived_props = []
for (src_addr, dst_addr), props in zip(edges, link_props):
mint_addr = props.get('mint')
if not mint_addr or mint_addr in QUOTE_MINTS:
continue
token_idx_global = item_token_addr_to_global_idx.get(mint_addr, self.entity_pad_idx)
if token_idx_global == self.entity_pad_idx:
continue
for wallet_addr in (src_addr, dst_addr):
wallet_idx_global = item_wallet_addr_to_global_idx.get(wallet_addr, self.entity_pad_idx)
if wallet_idx_global == self.entity_pad_idx:
continue
derived_edges.append([wallet_idx_global, token_idx_global])
derived_props.append(props)
if derived_edges:
derived_tensor = torch.tensor(derived_edges, device=self.device, dtype=torch.long).t()
aggregated_links["TransferLinkToken"]['edge_index_list'].append(derived_tensor)
aggregated_links["TransferLinkToken"]['links_list'].extend(derived_props)
final_links_dict = {}
for link_name, data in aggregated_links.items():
if data['edge_index_list']:
final_links_dict[link_name] = {
'links': data['links_list'],
'edge_index': torch.cat(data['edge_index_list'], dim=1)
}
return final_links_dict
def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Processes a batch of raw data items into tensors for the model.
"""
# --- NEW ARCHITECTURE ---
# 1. Aggregate all unique embeddable items from the entire batch.
# 2. Create a single embedding pool tensor for the whole batch.
# 3. Create a mapping from original (per-item) indices to the new batch-wide indices.
# 4. Remap all `_emb_idx` fields in the batch data using this new mapping.
batch_size = len(batch)
if batch_size == 0:
return {}
# --- 1. Aggregate all unique items and create index mappings ---
batch_wide_pooler = EmbeddingPooler()
# Map to translate from an item's original pooler to the new batch-wide indices
# Format: { batch_item_index: { original_idx: new_batch_idx } }
idx_remap = defaultdict(dict)
for i, item in enumerate(batch):
pooler = item.get('embedding_pooler')
if not pooler: continue
for pool_item_data in pooler.get_all_items():
original_idx = pool_item_data['idx']
raw_item = pool_item_data['item']
# get_idx will either return an existing index or create a new one
# --- FIX: Convert 1-based pooler index to 0-based tensor index ---
new_batch_idx_1_based = batch_wide_pooler.get_idx(raw_item)
new_batch_idx_0_based = new_batch_idx_1_based - 1
idx_remap[i][original_idx] = new_batch_idx_0_based
# --- 2. Create the single, batch-wide embedding pool tensor ---
all_items_sorted = batch_wide_pooler.get_all_items()
if not all_items_sorted:
# Handle edge case of absolutely no embeddings in batch
# Create a dummy empty tensor
batch_embedding_pool = torch.empty(0, 768, device=self.device, dtype=self.dtype) # Default SigLIP dim is 1152 actually, but standard is 768. Better to infer or default.
# Actually, if empty, it doesn't matter much as long as it's not accessed.
else:
first_item = all_items_sorted[0]['item']
if not isinstance(first_item, torch.Tensor):
raise RuntimeError(f"Collator expects pre-computed embeddings (torch.Tensor), found {type(first_item)}. Please rebuild cache.")
# Stack all embeddings
# They should already be CPU tensors from the loader
# Move to device and cast to dtype
batch_embedding_pool = torch.stack([d['item'] for d in all_items_sorted]).to(device=self.device, dtype=self.dtype)
batch_embedding_pool = torch.nan_to_num(batch_embedding_pool, nan=0.0, posinf=0.0, neginf=0.0)
# --- 3. Remap all indices in the batch data ---
for i, item in enumerate(batch):
remap_dict = idx_remap.get(i, {})
if not remap_dict: continue
# Remap tokens
for token_data in item.get('tokens', {}).values():
for key in ['name_emb_idx', 'symbol_emb_idx', 'image_emb_idx']:
if token_data.get(key, 0) > 0: # Check if it has a valid 1-based index
token_data[key] = remap_dict.get(token_data[key], -1) # Remap to 0-based, default to -1 if not found
# Remap wallets
for wallet_data in item.get('wallets', {}).values():
socials = wallet_data.get('socials', {})
if socials.get('username_emb_idx', 0) > 0:
socials['username_emb_idx'] = remap_dict.get(socials['username_emb_idx'], -1)
# Remap events
for event in item.get('event_sequence', []):
for key in event:
if key.endswith('_emb_idx') and event.get(key, 0) > 0:
event[key] = remap_dict.get(event[key], 0)
# --- 4. Standard Collation (Now that indices are correct) ---
unique_wallets_data = {}
unique_tokens_data = {}
all_event_sequences = []
max_len = 0
for item in batch:
seq = item.get('event_sequence', [])
if self.max_seq_len is not None and len(seq) > self.max_seq_len:
seq = seq[:self.max_seq_len]
all_event_sequences.append(seq)
max_len = max(max_len, len(seq))
unique_wallets_data.update(item.get('wallets', {}))
unique_tokens_data.update(item.get('tokens', {}))
# Create mappings needed for indexing (use dict keys as source of truth)
wallet_items = list(unique_wallets_data.items())
token_items = list(unique_tokens_data.items())
wallet_list_data = []
for addr, feat in wallet_items:
profile = feat.get('profile', {})
if not profile.get('wallet_address'):
profile['wallet_address'] = addr
wallet_list_data.append(feat)
token_list_data = []
for addr, feat in token_items:
if not feat.get('address'):
feat['address'] = addr
token_list_data.append(feat)
wallet_addr_to_batch_idx = {addr: i + 1 for i, (addr, _) in enumerate(wallet_items)}
token_addr_to_batch_idx = {addr: i + 1 for i, (addr, _) in enumerate(token_items)}
# Collate Static Raw Features (Tokens, Wallets, Graph)
token_encoder_inputs = self._collate_features_for_encoder(token_list_data, ['name'], self.device, "token")
# Collate Static Raw Features (Tokens, Wallets, Graph)
token_encoder_inputs = self._collate_features_for_encoder(token_list_data, ['name'], self.device, "token")
wallet_encoder_inputs = self._collate_features_for_encoder(wallet_list_data, ['profile'], self.device, "wallet")
graph_updater_links = self._collate_graph_links(batch, wallet_addr_to_batch_idx, token_addr_to_batch_idx)
# --- 5. Prepare Sequence Tensors & Collect Dynamic Data (OHLC) ---
B = batch_size
L = max_len
PAD_IDX_SEQ = self.pad_token_id
PAD_IDX_ENT = self.entity_pad_idx
# Initialize sequence tensors
event_type_ids = torch.full((B, L), PAD_IDX_SEQ, dtype=torch.long, device=self.device)
# Use float64 to preserve second-level precision for large Unix timestamps.
timestamps_float = torch.zeros((B, L), dtype=torch.float64, device=self.device)
# Store relative_ts in float32 for stability; model will scale/log/normalize
relative_ts = torch.zeros((B, L, 1), dtype=torch.float32, device=self.device)
attention_mask = torch.zeros((B, L), dtype=torch.long, device=self.device)
wallet_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
token_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
ohlc_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
quote_token_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device) # NEW
# --- NEW: Tensors for Transfer/LargeTransfer ---
dest_wallet_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
# --- NEW: Separate tensor for social media original authors ---
original_author_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
# 4 numerical features for transfers
transfer_numerical_features = torch.zeros((B, L, 4), dtype=self.dtype, device=self.device)
# --- NEW: Tensors for Trade ---
# --- FIXED: Size reduced from 10 to 8 ---
trade_numerical_features = torch.zeros((B, L, 8), dtype=self.dtype, device=self.device)
deployer_trade_numerical_features = torch.zeros((B, L, 8), dtype=self.dtype, device=self.device)
smart_wallet_trade_numerical_features = torch.zeros((B, L, 8), dtype=self.dtype, device=self.device)
# --- NEW: Dedicated tensor for categorical dex_platform_id ---
trade_dex_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
# --- NEW: Dedicated tensor for categorical trade_direction ---
trade_direction_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
# --- NEW: Dedicated tensor for categorical mev_protection ---
trade_mev_protection_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
# --- NEW: Dedicated tensor for categorical is_bundle ---
trade_is_bundle_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
# --- NEW: Tensors for PoolCreated ---
# --- UPDATED: Capture raw base/quote deposit amounts only ---
pool_created_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device)
# --- NEW: Dedicated tensor for categorical protocol_id ---
pool_created_protocol_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
# --- NEW: Tensors for LiquidityChange ---
# --- UPDATED: Keep only the raw quote amount deposit/withdraw ---
liquidity_change_numerical_features = torch.zeros((B, L, 1), dtype=self.dtype, device=self.device)
# --- NEW: Dedicated tensor for categorical change_type_id ---
liquidity_change_type_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
# --- NEW: Tensors for FeeCollected ---
fee_collected_numerical_features = torch.zeros((B, L, 1), dtype=self.dtype, device=self.device) # sol_amount only
# --- NEW: Tensors for TokenBurn ---
token_burn_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device) # amount_pct, amount_tokens
# --- NEW: Tensors for SupplyLock ---
supply_lock_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device) # amount_pct, lock_duration
# --- NEW: Tensors for OnChain_Snapshot ---
onchain_snapshot_numerical_features = torch.zeros((B, L, 14), dtype=self.dtype, device=self.device)
# --- NEW: Tensors for TrendingToken ---
trending_token_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
# --- FIXED: Size reduced from 3 to 1 after removing IDs ---
trending_token_numerical_features = torch.zeros((B, L, 1), dtype=self.dtype, device=self.device) # rank
trending_token_source_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
trending_token_timeframe_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
# --- NEW: Tensors for BoostedToken ---
boosted_token_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
boosted_token_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device) # total_boost_amount, rank
# --- NEW: Tensors for DexBoost_Paid ---
dexboost_paid_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device) # amount, total_amount_on_token
# --- NEW: Tensors for DexProfile_Updated ---
dexprofile_updated_flags = torch.zeros((B, L, 4), dtype=torch.float32, device=self.device) # Using float for easier projection
# --- NEW: Tensors for Tracker Events ---
alpha_group_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
channel_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
exchange_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
# --- NEW: Tensors for GlobalTrending Events ---
global_trending_numerical_features = torch.zeros((B, L, 1), dtype=self.dtype, device=self.device) # rank
# --- NEW: Tensors for ChainSnapshot ---
chainsnapshot_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device) # native_token_price_usd, gas_fee
# --- NEW: Tensors for Lighthouse_Snapshot ---
# --- FIXED: Size reduced from 7 to 5 after removing IDs ---
lighthousesnapshot_numerical_features = torch.zeros((B, L, 5), dtype=self.dtype, device=self.device)
lighthousesnapshot_protocol_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
lighthousesnapshot_timeframe_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
# --- NEW: Tensors for Migrated event ---
migrated_protocol_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
# --- NEW: Tensors for HolderSnapshot ---
# This will store the raw holder data for the Oracle to process
holder_snapshot_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
holder_snapshot_raw_data_list = [] # List of lists of dicts
# --- RENAMED: Generic tensors for any event with text/image features ---
textual_event_data_list = [] # List of dicts with text/media indices
textual_event_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
# --- NEW: Pointers for pre-encoded images ---
image_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
original_post_image_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
# --- CORRECTED: Initialize chart event collection here ---
batch_chart_events = []
chart_event_counter = 0
# Loop through sequences to populate tensors and collect chart events
for i, seq in enumerate(all_event_sequences):
seq_len = len(seq)
if seq_len == 0: continue
attention_mask[i, :seq_len] = 1
for j, event in enumerate(seq):
# Populate basic sequence info
event_type = event.get('event_type', '__PAD__')
type_id = self.event_type_to_id.get(event_type, PAD_IDX_SEQ)
event_type_ids[i, j] = type_id
timestamps_float[i, j] = event.get('timestamp', 0)
relative_ts[i, j, 0] = event.get('relative_ts', 0.0)
# Populate pointer indices
w_addr = event.get('wallet_address')
if w_addr:
wallet_indices[i, j] = wallet_addr_to_batch_idx.get(w_addr, PAD_IDX_ENT)
t_addr = event.get('token_address')
if t_addr:
token_indices[i, j] = token_addr_to_batch_idx.get(t_addr, PAD_IDX_ENT)
# If it's a chart event, collect it and record its index
if event_type == 'Chart_Segment':
batch_chart_events.append(event)
ohlc_indices[i, j] = chart_event_counter + 1 # Use 1-based index
chart_event_counter += 1
elif event_type in ['Transfer', 'LargeTransfer']: # ADDED LargeTransfer
# Get destination wallet index
dest_w_addr = event.get('destination_wallet_address') # Assuming this key exists
if dest_w_addr:
dest_wallet_indices[i, j] = wallet_addr_to_batch_idx.get(dest_w_addr, PAD_IDX_ENT)
# Get numerical features (use .get with default 0)
num_feats = [
event.get('token_amount', 0.0),
event.get('transfer_pct_of_total_supply', 0.0),
event.get('transfer_pct_of_holding', 0.0),
event.get('priority_fee', 0.0)
]
transfer_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype)
elif event_type in ['Trade', 'LargeTrade']:
# Get numerical and categorical features for the trade
trade_dex_ids[i, j] = event.get('dex_platform_id', 0)
trade_direction_ids[i, j] = event.get('trade_direction', 0) # 0=buy, 1=sell
trade_mev_protection_ids[i, j] = event.get('mev_protection', 0) # 0, 1, 2...
trade_is_bundle_ids[i, j] = 1 if event.get('is_bundle') else 0 # 0=false, 1=true
num_feats = [
event.get('sol_amount', 0.0),
event.get('priority_fee', 0.0),
event.get('token_amount_pct_of_holding', 0.0),
event.get('quote_amount_pct_of_holding', 0.0),
event.get('slippage', 0.0),
event.get('token_amount_pct_to_total_supply', 0.0), # REPLACED price_impact
1.0 if event.get('success') else 0.0,
event.get('total_usd', 0.0)
]
trade_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype)
elif event_type == 'Deployer_Trade':
# Use the dedicated tensor for deployer trades
trade_dex_ids[i, j] = event.get('dex_platform_id', 0)
trade_direction_ids[i, j] = event.get('trade_direction', 0) # 0=buy, 1=sell
trade_mev_protection_ids[i, j] = event.get('mev_protection', 0) # 0, 1, 2...
trade_is_bundle_ids[i, j] = 1 if event.get('is_bundle') else 0 # 0=false, 1=true
num_feats = [
event.get('sol_amount', 0.0),
event.get('priority_fee', 0.0),
event.get('token_amount_pct_of_holding', 0.0),
event.get('quote_amount_pct_of_holding', 0.0),
event.get('slippage', 0.0),
event.get('token_amount_pct_to_total_supply', 0.0), # REPLACED price_impact
1.0 if event.get('success') else 0.0,
event.get('total_usd', 0.0)
]
deployer_trade_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype)
elif event_type == 'SmartWallet_Trade':
# Use the dedicated tensor for smart wallet trades
trade_dex_ids[i, j] = event.get('dex_platform_id', 0)
trade_direction_ids[i, j] = event.get('trade_direction', 0) # 0=buy, 1=sell
trade_mev_protection_ids[i, j] = event.get('mev_protection', 0) # 0, 1, 2...
trade_is_bundle_ids[i, j] = 1 if event.get('is_bundle') else 0 # 0=false, 1=true
num_feats = [
event.get('sol_amount', 0.0),
event.get('priority_fee', 0.0),
event.get('token_amount_pct_of_holding', 0.0),
event.get('quote_amount_pct_of_holding', 0.0),
event.get('slippage', 0.0),
event.get('token_amount_pct_to_total_supply', 0.0), # REPLACED price_impact
1.0 if event.get('success') else 0.0,
event.get('total_usd', 0.0)
]
smart_wallet_trade_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype)
elif event_type == 'PoolCreated':
# Get the quote token index
quote_t_addr = event.get('quote_token_address')
if quote_t_addr:
quote_token_indices[i, j] = token_addr_to_batch_idx.get(quote_t_addr, PAD_IDX_ENT)
pool_created_protocol_ids[i, j] = event.get('protocol_id', 0)
# Get numerical features
num_feats = [
event.get('base_amount', 0.0),
event.get('quote_amount', 0.0)
]
pool_created_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype)
elif event_type == 'LiquidityChange':
# Get the quote token index
quote_t_addr = event.get('quote_token_address')
if quote_t_addr:
quote_token_indices[i, j] = token_addr_to_batch_idx.get(quote_t_addr, PAD_IDX_ENT)
liquidity_change_type_ids[i, j] = event.get('change_type_id', 0)
# Get numerical features
num_feats = [event.get('quote_amount', 0.0)]
liquidity_change_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype)
elif event_type == 'FeeCollected':
# This event has the recipient wallet plus a single numerical feature (SOL amount).
num_feats = [
event.get('sol_amount', 0.0)
]
fee_collected_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype)
elif event_type == 'TokenBurn':
# This event has a wallet (handled by wallet_indices) and two numerical features.
num_feats = [
event.get('amount_pct_of_total_supply', 0.0),
event.get('amount_tokens_burned', 0.0)
]
token_burn_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype)
elif event_type == 'SupplyLock':
# This event has a wallet and two numerical features.
num_feats = [
event.get('amount_pct_of_total_supply', 0.0),
event.get('lock_duration', 0.0)
]
supply_lock_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype)
elif event_type == 'OnChain_Snapshot':
# This event is a global snapshot with 14 numerical features.
num_feats = [
event.get('total_holders', 0.0),
event.get('smart_traders', 0.0),
event.get('kols', 0.0),
event.get('holder_growth_rate', 0.0),
event.get('top_10_holder_pct', 0.0),
event.get('sniper_holding_pct', 0.0),
event.get('rat_wallets_holding_pct', 0.0),
event.get('bundle_holding_pct', 0.0),
event.get('current_market_cap', 0.0),
event.get('volume', 0.0),
event.get('buy_count', 0.0),
event.get('sell_count', 0.0),
event.get('total_txns', 0.0),
event.get('global_fees_paid', 0.0)
]
onchain_snapshot_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype)
elif event_type == 'TrendingToken':
# Get the trending token index
trending_t_addr = event.get('token_address')
if trending_t_addr:
trending_token_indices[i, j] = token_addr_to_batch_idx.get(trending_t_addr, PAD_IDX_ENT)
trending_token_source_ids[i, j] = event.get('list_source_id', 0)
trending_token_timeframe_ids[i, j] = event.get('timeframe_id', 0)
# --- FIXED: Invert rank so that 1 is the highest value ---
# Get numerical/categorical features
num_feats = [
1.0 / event.get('rank', 1e9) # Use a large number for rank 0 or missing to make it small
]
trending_token_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype)
elif event_type == 'BoostedToken':
# Get the boosted token index
boosted_t_addr = event.get('token_address')
if boosted_t_addr:
boosted_token_indices[i, j] = token_addr_to_batch_idx.get(boosted_t_addr, PAD_IDX_ENT)
# --- FIXED: Invert rank so that 1 is the highest value ---
# Get numerical features
num_feats = [
event.get('total_boost_amount', 0.0),
1.0 / event.get('rank', 1e9)
]
boosted_token_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype)
elif event_type == 'Migrated':
migrated_protocol_ids[i, j] = event.get('protocol_id', 0)
elif event_type == 'HolderSnapshot':
# --- FIXED: Store raw holder data, not an index ---
raw_holders = event.get('holders', [])
holder_snapshot_raw_data_list.append(raw_holders)
holder_snapshot_indices[i, j] = len(holder_snapshot_raw_data_list) # 1-based index to the list
elif event_type == 'Lighthouse_Snapshot':
lighthousesnapshot_protocol_ids[i, j] = event.get('protocol_id', 0)
lighthousesnapshot_timeframe_ids[i, j] = event.get('timeframe_id', 0)
num_feats = [
event.get('total_volume', 0.0),
event.get('total_transactions', 0.0),
event.get('total_traders', 0.0),
event.get('total_tokens_created', 0.0),
event.get('total_migrations', 0.0)
]
lighthousesnapshot_numerical_features[i, j, :] = torch.as_tensor(num_feats, dtype=self.dtype)
# --- UPDATED: Group all events that contain pre-computed text/image indices ---
elif event_type in ['XPost', 'XReply', 'XRetweet', 'XQuoteTweet', 'PumpReply', 'DexProfile_Updated', 'TikTok_Trending_Hashtag', 'XTrending_Hashtag']:
# Store raw event data to look up text/image indices later
# 1. Store raw text/media data
textual_event_data_list.append(event)
textual_event_indices[i, j] = len(textual_event_data_list) # 1-based index
# --- FIXED: Handle rank for trending hashtags ---
if event_type in ['TikTok_Trending_Hashtag', 'XTrending_Hashtag']:
global_trending_numerical_features[i, j, 0] = 1.0 / event.get('rank', 1e9)
# 2. Populate wallet pointer tensors based on the event type
# The main 'wallet_address' is already handled above.
# Here we handle the *other* wallets involved.
if event_type == 'XRetweet' or event_type == 'XQuoteTweet':
orig_author_addr = event.get('original_author_wallet_address')
if orig_author_addr:
# --- FIXED: Use the dedicated tensor for original authors ---
original_author_indices[i, j] = wallet_addr_to_batch_idx.get(orig_author_addr, PAD_IDX_ENT)
# The pre-computed embedding indices are already in the event dict.
# No need to populate image_indices here anymore.
# For XReply, the main tweet is a text/media embedding, not a wallet.
# For XPost, there's only one wallet, already handled.
# --- 4. Collate Dynamic Features (OHLC) AFTER collecting them ---
ohlc_inputs_dict = self._collate_ohlc_inputs(batch_chart_events)
# --- 6. Prepare final output dictionary ---
collated_batch = {
# Sequence Tensors
'event_type_ids': event_type_ids,
'timestamps_float': timestamps_float,
'relative_ts': relative_ts,
'attention_mask': attention_mask,
# Pointer Tensors
'wallet_indices': wallet_indices,
'token_indices': token_indices,
'quote_token_indices': quote_token_indices, # NEW
'trending_token_indices': trending_token_indices, # NEW
'boosted_token_indices': boosted_token_indices, # NEW
'holder_snapshot_indices': holder_snapshot_indices, # This now points to the generated embeddings
'textual_event_indices': textual_event_indices, # RENAMED
'ohlc_indices': ohlc_indices,
# Raw Data for Encoders
'embedding_pool': batch_embedding_pool, # NEW
'token_encoder_inputs': token_encoder_inputs,
'wallet_encoder_inputs': wallet_encoder_inputs, # ADDED BACK
'ohlc_price_tensors': ohlc_inputs_dict['price_tensor'],
'ohlc_interval_ids': ohlc_inputs_dict['interval_ids'],
'quant_ohlc_feature_tensors': ohlc_inputs_dict['quant_feature_tensors'],
'quant_ohlc_feature_mask': ohlc_inputs_dict['quant_feature_mask'],
'quant_ohlc_feature_version_ids': ohlc_inputs_dict['quant_feature_version_ids'],
'graph_updater_links': graph_updater_links,
'wallet_addr_to_batch_idx': wallet_addr_to_batch_idx, # NEW: Pass the mapping
'dest_wallet_indices': dest_wallet_indices, # ADDED THIS LINE
'original_author_indices': original_author_indices, # NEW
# --- NEW: Numerical Features for Events ---
'transfer_numerical_features': transfer_numerical_features,
'trade_numerical_features': trade_numerical_features,
'trade_dex_ids': trade_dex_ids,
'deployer_trade_numerical_features': deployer_trade_numerical_features,
'trade_direction_ids': trade_direction_ids, # NEW
'trade_mev_protection_ids': trade_mev_protection_ids, # NEW
'smart_wallet_trade_numerical_features': smart_wallet_trade_numerical_features,
'trade_is_bundle_ids': trade_is_bundle_ids, # NEW
'pool_created_numerical_features': pool_created_numerical_features,
'pool_created_protocol_ids': pool_created_protocol_ids, # NEW
'liquidity_change_numerical_features': liquidity_change_numerical_features,
'liquidity_change_type_ids': liquidity_change_type_ids, # NEW
'fee_collected_numerical_features': fee_collected_numerical_features, # NEW
'token_burn_numerical_features': token_burn_numerical_features, # NEW
'supply_lock_numerical_features': supply_lock_numerical_features, # NEW
'onchain_snapshot_numerical_features': onchain_snapshot_numerical_features, # NEW
'boosted_token_numerical_features': boosted_token_numerical_features,
'trending_token_numerical_features': trending_token_numerical_features,
'trending_token_source_ids': trending_token_source_ids, # NEW
'trending_token_timeframe_ids': trending_token_timeframe_ids, # NEW
'dexboost_paid_numerical_features': dexboost_paid_numerical_features, # NEW
'dexprofile_updated_flags': dexprofile_updated_flags, # NEW,
'global_trending_numerical_features': global_trending_numerical_features, # NEW
'chainsnapshot_numerical_features': chainsnapshot_numerical_features, # NEW
'lighthousesnapshot_numerical_features': lighthousesnapshot_numerical_features,
'lighthousesnapshot_protocol_ids': lighthousesnapshot_protocol_ids, # NEW
'lighthousesnapshot_timeframe_ids': lighthousesnapshot_timeframe_ids, # NEW
'migrated_protocol_ids': migrated_protocol_ids, # NEW
'alpha_group_ids': alpha_group_ids, # NEW
'channel_ids': channel_ids, # NEW
'exchange_ids': exchange_ids, # NEW
'holder_snapshot_raw_data': holder_snapshot_raw_data_list, # NEW: Raw data for end-to-end processing
'textual_event_data': textual_event_data_list, # RENAMED
# Labels
'labels': torch.stack([item['labels'] for item in batch]) if batch and 'labels' in batch[0] else None,
'labels_mask': torch.stack([item['labels_mask'] for item in batch]) if batch and 'labels_mask' in batch[0] else None,
'movement_class_targets': torch.stack([item['movement_class_targets'] for item in batch]) if batch and 'movement_class_targets' in batch[0] else None,
'movement_class_mask': torch.stack([item['movement_class_mask'] for item in batch]) if batch and 'movement_class_mask' in batch[0] else None,
'quality_score': torch.stack([item['quality_score'] if isinstance(item['quality_score'], torch.Tensor) else torch.tensor(item['quality_score'], dtype=torch.float32) for item in batch]) if batch and 'quality_score' in batch[0] else None,
'class_id': torch.tensor([item.get('class_id', 0) for item in batch], dtype=torch.long),
# Debug info
'token_addresses': [item.get('token_address', 'unknown') for item in batch],
't_cutoffs': [item.get('t_cutoff', 'unknown') for item in batch],
'sample_indices': [item.get('sample_idx', -1) for item in batch]
}
if collated_batch['quality_score'] is None:
raise RuntimeError("FATAL: Missing quality_score in batch items. Rebuild cache with quality_score enabled.")
# Filter out None values (e.g., if no labels provided)
return {k: v for k, v in collated_batch.items() if v is not None}