# memecoin_collator.py (CORRECTED ORDER OF OPERATIONS)
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from typing import List, Dict, Any, Tuple, Optional, Union
from collections import defaultdict
from PIL import Image

import models.vocabulary as vocab
from data.data_loader import EmbeddingPooler
from data.quant_ohlc_feature_schema import FEATURE_VERSION, FEATURE_VERSION_ID, NUM_QUANT_OHLC_FEATURES, TOKENS_PER_SEGMENT

NATIVE_MINT = "So11111111111111111111111111111111111111112"
QUOTE_MINTS = {
    NATIVE_MINT,  # SOL
    "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v",  # USDC
    "Es9vMFrzaCERmJfrF4H2FYD4KCoNkY11McCe8BenwNYB",  # USDT
    "USD1ttGY1N17NEEHLmELoaybftRBUSErhqYiQzvEmuB",  # USD1
}


class MemecoinCollator:
    """Callable ``collate_fn`` that turns a list of cached samples into model-ready tensors.

    For every batch it:
      1. merges each sample's ``embedding_pooler`` into one batch-wide embedding pool and
         rewrites the ``*_emb_idx`` fields of tokens, wallets and events to point into it;
      2. deduplicates wallets/tokens across samples and gives them 1-based batch indices
         (0 is the entity padding index);
      3. builds padded ``(B, L)`` / ``(B, L, F)`` tensors for the event sequences, per-event-type
         feature tensors, OHLC chart inputs and remapped graph links.

    Sample dicts are mutated in place: embedding indices are remapped, and missing token
    ``address`` / wallet profile ``wallet_address`` fields are filled from the dict key.
    """

    # Events that carry pre-computed text/image embedding indices; they are passed through raw.
    _TEXTUAL_EVENT_TYPES = frozenset({
        'XPost', 'XReply', 'XRetweet', 'XQuoteTweet', 'PumpReply',
        'DexProfile_Updated', 'TikTok_Trending_Hashtag', 'XTrending_Hashtag',
    })
    # Textual events that additionally carry a trending rank.
    _HASHTAG_EVENT_TYPES = frozenset({'TikTok_Trending_Hashtag', 'XTrending_Hashtag'})

    def __init__(self,
                 event_type_to_id: Dict[str, int],
                 device: torch.device,
                 dtype: torch.dtype,
                 max_seq_len: Optional[int] = None,
                 model_id: str = "google/siglip-so400m-patch16-256-i18n",
                 ohlc_seq_len: int = 300,
                 ):
        """
        Args:
            event_type_to_id: event-type name -> id; '__PAD__' (default id 0) is the sequence pad.
            device: device on which batch tensors are created.
            dtype: floating dtype for numerical feature tensors and the embedding pool.
            max_seq_len: if set, each event sequence is truncated to its first ``max_seq_len`` events.
            model_id: stored on the instance; not used by the collation code in this file.
            ohlc_seq_len: number of open/close points per chart segment (truncate or pad to this).
        """
        self.event_type_to_id = event_type_to_id
        self.pad_token_id = event_type_to_id.get('__PAD__', 0)
        self.model_id = model_id
        self.entity_pad_idx = 0
        self.device = device
        self.dtype = dtype
        self.ohlc_seq_len = ohlc_seq_len
        self.quant_ohlc_tokens = TOKENS_PER_SEGMENT
        self.quant_ohlc_num_features = NUM_QUANT_OHLC_FEATURES
        self.max_seq_len = max_seq_len

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _inverse_rank(event: Dict[str, Any]) -> float:
        """Return ``1 / rank`` so that rank 1 maps to the largest value (1.0).

        A missing, ``None`` or ``0`` rank is treated as rank 1e9 (value 1e-9) instead of
        raising ``ZeroDivisionError`` / ``TypeError``.
        """
        rank = event.get('rank')
        return 1.0 / (rank if rank else 1e9)

    @staticmethod
    def _trade_numerical_features(event: Dict[str, Any]) -> List[float]:
        """Return the 8 numerical trade features shared by all trade-like event types."""
        return [
            event.get('sol_amount', 0.0),
            event.get('priority_fee', 0.0),
            event.get('token_amount_pct_of_holding', 0.0),
            event.get('quote_amount_pct_of_holding', 0.0),
            event.get('slippage', 0.0),
            event.get('token_amount_pct_to_total_supply', 0.0),  # REPLACED price_impact
            1.0 if event.get('success') else 0.0,
            event.get('total_usd', 0.0),
        ]

    def _pad_price_series(self, values: List[float], seq_len: int) -> torch.Tensor:
        """Truncate ``values`` to ``seq_len`` or right-pad it by repeating the last value (0 if empty).

        NaN/inf entries are replaced by 0. Returns a CPU tensor of ``self.dtype``.
        """
        pad_value = values[-1] if values else 0
        padded = values[:seq_len] + [pad_value] * (seq_len - len(values))
        series = torch.tensor(padded, dtype=self.dtype)
        return torch.nan_to_num(series, nan=0.0, posinf=0.0, neginf=0.0)

    # ---------------------------------------------------------------- collators

    def _collate_features_for_encoder(self, entities: List[Dict], feature_keys: List[str], device: torch.device, entity_type: str) -> Dict[str, Any]:
        """Gather per-entity encoder inputs (embedding-pool indices and raw rows).

        Args:
            entities: token or wallet feature dicts, in batch-index order.
            feature_keys: unused; kept for call-site compatibility.
            device: device for the returned tensors.
            entity_type: "token" or "wallet"; any other value yields ``{}``.

        Returns:
            For tokens: long tensors of name/symbol/image embedding indices and protocol ids,
            a bool ``is_vanity_flags`` tensor and the ``_addresses_for_lookup`` list.
            For wallets: long ``username_embed_indices`` plus raw profile/social/holding lists.
            Missing indices default to 0. Empty ``entities`` give empty tensors/lists.
        """
        entities = entities or []

        def long_tensor(values: List[Any]) -> torch.Tensor:
            return torch.tensor(values, device=device, dtype=torch.long)

        if entity_type == "token":
            return {
                # Helper key for WalletEncoder to find token vibes
                '_addresses_for_lookup': [e.get('address', '') for e in entities],
                'name_embed_indices': long_tensor([e.get('name_emb_idx', 0) for e in entities]),
                'symbol_embed_indices': long_tensor([e.get('symbol_emb_idx', 0) for e in entities]),
                'image_embed_indices': long_tensor([e.get('image_emb_idx', 0) for e in entities]),
                'protocol_ids': long_tensor([e.get('protocol', 0) for e in entities]),
                'is_vanity_flags': torch.tensor([e.get('is_vanity', False) for e in entities], device=device, dtype=torch.bool),
            }
        if entity_type == "wallet":
            return {
                'username_embed_indices': long_tensor([e.get('socials', {}).get('username_emb_idx', 0) for e in entities]),
                'profile_rows': [e.get('profile', {}) for e in entities],
                'social_rows': [e.get('socials', {}) for e in entities],
                'holdings_batch': [e.get('holdings', []) for e in entities],
            }
        return {}

    def _collate_ohlc_inputs(self, chart_events: List[Dict]) -> Dict[str, torch.Tensor]:
        """Stack Chart_Segment events into OHLC price and quantitative-feature tensors.

        Returns (N = number of chart events):
            'price_tensor': (N, 2, ohlc_seq_len) opens/closes, padded with the last value.
            'interval_ids': (N,) ids from ``vocab.INTERVAL_TO_ID`` ("Unknown" as fallback).
            'quant_feature_tensors': (N, TOKENS_PER_SEGMENT, NUM_QUANT_OHLC_FEATURES), zero-padded.
            'quant_feature_mask': (N, TOKENS_PER_SEGMENT), 1.0 for real rows, 0.0 for padding.
            'quant_feature_version_ids': (N,) FEATURE_VERSION_ID if the segment's version matches, else 0.

        Raises:
            RuntimeError: if a segment lacks ``quant_ohlc_features``, it is not a list, or a
                feature vector is not a list of length NUM_QUANT_OHLC_FEATURES.
        """
        if not chart_events:
            return {
                'price_tensor': torch.empty(0, 2, self.ohlc_seq_len, device=self.device, dtype=self.dtype),
                'interval_ids': torch.empty(0, device=self.device, dtype=torch.long),
                'quant_feature_tensors': torch.empty(0, self.quant_ohlc_tokens, self.quant_ohlc_num_features, device=self.device, dtype=self.dtype),
                'quant_feature_mask': torch.empty(0, self.quant_ohlc_tokens, device=self.device, dtype=self.dtype),
                'quant_feature_version_ids': torch.empty(0, device=self.device, dtype=torch.long),
            }

        seq_len = self.ohlc_seq_len
        num_features = self.quant_ohlc_num_features
        unknown_id = vocab.INTERVAL_TO_ID.get("Unknown", 0)
        zero_row = [0.0] * num_features

        price_series = []
        interval_ids = []
        quant_rows = []   # built as nested Python lists and moved to the device once
        quant_masks = []
        version_ids = []

        for segment in chart_events:
            opens = segment.get('opens', [])
            closes = segment.get('closes', [])
            price_series.append(torch.stack([
                self._pad_price_series(opens, seq_len),
                self._pad_price_series(closes, seq_len),
            ]))
            interval_ids.append(vocab.INTERVAL_TO_ID.get(segment.get('i', "Unknown"), unknown_id))

            quant_payload = segment.get('quant_ohlc_features')
            if quant_payload is None:
                raise RuntimeError("Chart_Segment missing quant_ohlc_features. Rebuild cache with quantitative chart features.")
            if not isinstance(quant_payload, list):
                raise RuntimeError("Chart_Segment quant_ohlc_features must be a list.")

            rows, mask = [], []
            # Payload entries beyond quant_ohlc_tokens are ignored; missing ones are zero-padded.
            for token_idx in range(self.quant_ohlc_tokens):
                if token_idx >= len(quant_payload):
                    rows.append(zero_row)
                    mask.append(0.0)
                    continue
                vec = quant_payload[token_idx].get('feature_vector')
                if not isinstance(vec, list) or len(vec) != num_features:
                    raise RuntimeError(
                        f"Chart_Segment quant feature vector must have length {self.quant_ohlc_num_features}."
                    )
                rows.append(vec)
                mask.append(1.0)
            quant_rows.append(rows)
            quant_masks.append(mask)

            version = segment.get('quant_feature_version', FEATURE_VERSION)
            version_ids.append(FEATURE_VERSION_ID if version == FEATURE_VERSION else 0)

        return {
            'price_tensor': torch.stack(price_series).to(self.device),
            'interval_ids': torch.tensor(interval_ids, device=self.device, dtype=torch.long),
            'quant_feature_tensors': torch.tensor(quant_rows, device=self.device, dtype=self.dtype),
            'quant_feature_mask': torch.tensor(quant_masks, device=self.device, dtype=self.dtype),
            'quant_feature_version_ids': torch.tensor(version_ids, device=self.device, dtype=torch.long),
        }

    def _collate_graph_links(self, batch_items: List[Dict], wallet_addr_to_batch_idx: Dict[str, int], token_addr_to_batch_idx: Dict[str, int]) -> Dict[str, Any]:
        """Remap each sample's graph edges from addresses to batch-wide entity indices.

        An edge is kept only when both endpoints are entities of the same sample (present in its
        'wallets'/'tokens'). For "TransferLink" edges, extra "TransferLinkToken" edges
        (wallet -> transferred mint) are derived for non-quote mints.

        Returns:
            ``{link_name: {'links': [props...], 'edge_index': LongTensor(2, E)}}`` for link types
            with at least one surviving edge; 'links' is aligned with the edge columns.
        """
        pad = self.entity_pad_idx
        aggregated_links = defaultdict(lambda: {'edge_index_list': [], 'links_list': []})

        def add_edges(link_name: str, pairs: List[List[int]], props: List[Any]) -> None:
            # pairs is a list of [src, dst]; transpose to the (2, E) edge_index layout.
            aggregated_links[link_name]['edge_index_list'].append(
                torch.tensor(pairs, device=self.device, dtype=torch.long).t()
            )
            aggregated_links[link_name]['links_list'].extend(props)

        for item in batch_items:
            wallet_map = {addr: wallet_addr_to_batch_idx.get(addr, pad) for addr in item.get('wallets', {})}
            token_map = {addr: token_addr_to_batch_idx.get(addr, pad) for addr in item.get('tokens', {})}

            for link_name, data in item.get('graph_links', {}).items():
                triplet = vocab.LINK_NAME_TO_TRIPLET.get(link_name)
                if not triplet:
                    continue
                src_type, _, dst_type = triplet
                edges = data.get('edges')
                link_props = data.get('links', [])
                if not edges or not link_props:
                    continue

                src_map = wallet_map if src_type == 'wallet' else token_map
                dst_map = wallet_map if dst_type == 'wallet' else token_map

                # Filter edges and their props together so they stay aligned.
                pairs, kept_props = [], []
                for (src_addr, dst_addr), props in zip(edges, link_props):
                    src_idx = src_map.get(src_addr, pad)
                    dst_idx = dst_map.get(dst_addr, pad)
                    if src_idx != pad and dst_idx != pad:
                        pairs.append([src_idx, dst_idx])
                        kept_props.append(props)
                if pairs:
                    add_edges(link_name, pairs, kept_props)

                if link_name == "TransferLink":
                    derived_pairs, derived_props = [], []
                    for (src_addr, dst_addr), props in zip(edges, link_props):
                        mint_addr = props.get('mint')
                        if not mint_addr or mint_addr in QUOTE_MINTS:
                            continue
                        token_idx = token_map.get(mint_addr, pad)
                        if token_idx == pad:
                            continue
                        # Both the sender and the receiver get an edge to the transferred token.
                        for wallet_addr in (src_addr, dst_addr):
                            wallet_idx = wallet_map.get(wallet_addr, pad)
                            if wallet_idx != pad:
                                derived_pairs.append([wallet_idx, token_idx])
                                derived_props.append(props)
                    if derived_pairs:
                        add_edges("TransferLinkToken", derived_pairs, derived_props)

        return {
            link_name: {
                'links': data['links_list'],
                'edge_index': torch.cat(data['edge_index_list'], dim=1),
            }
            for link_name, data in aggregated_links.items()
            if data['edge_index_list']
        }

    def _build_batch_embedding_pool(self, batch: List[Dict[str, Any]]) -> torch.Tensor:
        """Merge every sample's embedding pooler into one pool and remap ``*_emb_idx`` in place.

        Per-sample pooler indices are 1-based; after remapping, token/wallet/event embedding
        indices are 0-based rows of the returned pool. Fields whose index is 0 are left unchanged.
        Samples without a (non-empty) pooler are not remapped.

        NOTE(review): an index missing from a sample's pooler becomes -1 for tokens/wallets but
        0 for events, and 0 is also a valid 0-based row of the pool; confirm the intended
        sentinels against the model's lookup.
        NOTE(review): sample dicts are mutated; if the dataset hands out cached dict objects
        across epochs they would be remapped twice — confirm samples are rebuilt per access.

        Returns:
            ``(num_embeddings, dim)`` tensor on ``self.device`` / ``self.dtype`` (NaN/inf -> 0).

        Raises:
            RuntimeError: if the pooled items are not ``torch.Tensor`` embeddings.
        """
        batch_wide_pooler = EmbeddingPooler()
        # {sample position: {original 1-based pooler idx: 0-based row in the batch pool}}
        idx_remap: Dict[int, Dict[Any, int]] = {}
        for i, item in enumerate(batch):
            pooler = item.get('embedding_pooler')
            if not pooler:
                continue
            sample_remap = idx_remap.setdefault(i, {})
            for entry in pooler.get_all_items():
                # get_idx returns the existing index or creates a new one (1-based).
                sample_remap[entry['idx']] = batch_wide_pooler.get_idx(entry['item']) - 1

        pooled = batch_wide_pooler.get_all_items()
        if not pooled:
            # No embeddings anywhere in the batch. The width 768 is only a placeholder;
            # NOTE(review): confirm the model never indexes into an empty pool.
            pool = torch.empty(0, 768, device=self.device, dtype=self.dtype)
        else:
            first_item = pooled[0]['item']
            if not isinstance(first_item, torch.Tensor):
                raise RuntimeError(f"Collator expects pre-computed embeddings (torch.Tensor), found {type(first_item)}. Please rebuild cache.")
            pool = torch.stack([entry['item'] for entry in pooled]).to(device=self.device, dtype=self.dtype)
            pool = torch.nan_to_num(pool, nan=0.0, posinf=0.0, neginf=0.0)

        # Remap only after the pool was built successfully, so a failure leaves samples untouched.
        for i, item in enumerate(batch):
            remap = idx_remap.get(i)
            if not remap:
                continue
            for token_data in item.get('tokens', {}).values():
                for key in ('name_emb_idx', 'symbol_emb_idx', 'image_emb_idx'):
                    if token_data.get(key, 0) > 0:
                        token_data[key] = remap.get(token_data[key], -1)
            for wallet_data in item.get('wallets', {}).values():
                socials = wallet_data.get('socials', {})
                if socials.get('username_emb_idx', 0) > 0:
                    socials['username_emb_idx'] = remap.get(socials['username_emb_idx'], -1)
            for event in item.get('event_sequence', []):
                for key in event:
                    if key.endswith('_emb_idx') and event.get(key, 0) > 0:
                        event[key] = remap.get(event[key], 0)
        return pool

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Collate a list of samples into one batch dict for the model.

        Args:
            batch: samples from the dataset. Each may carry 'embedding_pooler', 'tokens',
                'wallets', 'event_sequence', 'graph_links', label tensors and debug fields,
                and must carry 'quality_score'.

        Returns:
            Dict of padded ``(B, L, ...)`` tensors plus encoder inputs, OHLC inputs, graph links,
            raw holder/textual event data, labels and debug info. Keys whose value would be
            ``None`` (e.g. absent labels) are omitted. An empty batch returns ``{}``.

        Raises:
            RuntimeError: on non-tensor pooled embeddings, malformed Chart_Segment quant
                features, or missing 'quality_score'.
        """
        if not batch:
            return {}

        # --- 1-3. Batch-wide embedding pool + in-place remap of every *_emb_idx field ---
        batch_embedding_pool = self._build_batch_embedding_pool(batch)

        # --- 4. Truncate sequences and deduplicate entities across samples ---
        all_event_sequences = []
        unique_wallets_data: Dict[str, Dict] = {}
        unique_tokens_data: Dict[str, Dict] = {}
        for item in batch:
            seq = item.get('event_sequence', [])
            if self.max_seq_len is not None and len(seq) > self.max_seq_len:
                seq = seq[:self.max_seq_len]
            all_event_sequences.append(seq)
            # Same address in several samples: the last sample's dict wins, order follows first appearance.
            unique_wallets_data.update(item.get('wallets', {}))
            unique_tokens_data.update(item.get('tokens', {}))
        max_len = max((len(seq) for seq in all_event_sequences), default=0)

        for addr, feat in unique_wallets_data.items():
            profile = feat.get('profile', {})
            if not profile.get('wallet_address'):
                profile['wallet_address'] = addr
        for addr, feat in unique_tokens_data.items():
            if not feat.get('address'):
                feat['address'] = addr
        wallet_list_data = list(unique_wallets_data.values())
        token_list_data = list(unique_tokens_data.values())

        # Entity indices are 1-based so that 0 can serve as the padding index.
        wallet_addr_to_batch_idx = {addr: pos for pos, addr in enumerate(unique_wallets_data, start=1)}
        token_addr_to_batch_idx = {addr: pos for pos, addr in enumerate(unique_tokens_data, start=1)}

        # Static raw features (tokens, wallets, graph)
        token_encoder_inputs = self._collate_features_for_encoder(token_list_data, ['name'], self.device, "token")
        wallet_encoder_inputs = self._collate_features_for_encoder(wallet_list_data, ['profile'], self.device, "wallet")
        graph_updater_links = self._collate_graph_links(batch, wallet_addr_to_batch_idx, token_addr_to_batch_idx)

        # --- 5. Allocate sequence tensors ---
        B, L = len(batch), max_len
        PAD_IDX_SEQ = self.pad_token_id
        PAD_IDX_ENT = self.entity_pad_idx
        device = self.device

        def entity_index_tensor() -> torch.Tensor:
            return torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=device)

        def category_tensor() -> torch.Tensor:
            return torch.zeros((B, L), dtype=torch.long, device=device)

        def feature_tensor(width: int, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
            return torch.zeros((B, L, width), dtype=self.dtype if dtype is None else dtype, device=device)

        def to_features(values: List[float]) -> torch.Tensor:
            return torch.as_tensor(values, dtype=self.dtype)

        event_type_ids = torch.full((B, L), PAD_IDX_SEQ, dtype=torch.long, device=device)
        # float64 preserves second-level precision for large Unix timestamps.
        timestamps_float = torch.zeros((B, L), dtype=torch.float64, device=device)
        # relative_ts is kept in float32; the model scales/normalizes it.
        relative_ts = feature_tensor(1, torch.float32)
        attention_mask = torch.zeros((B, L), dtype=torch.long, device=device)

        # Pointer tensors (1-based indices, PAD_IDX_ENT where absent)
        wallet_indices = entity_index_tensor()
        token_indices = entity_index_tensor()
        ohlc_indices = entity_index_tensor()
        quote_token_indices = entity_index_tensor()
        dest_wallet_indices = entity_index_tensor()       # Transfer / LargeTransfer
        original_author_indices = entity_index_tensor()   # XRetweet / XQuoteTweet
        trending_token_indices = entity_index_tensor()
        boosted_token_indices = entity_index_tensor()
        holder_snapshot_indices = entity_index_tensor()   # 1-based index into holder_snapshot_raw_data_list
        textual_event_indices = entity_index_tensor()     # 1-based index into textual_event_data_list

        # Per-event-type numerical features
        transfer_numerical_features = feature_tensor(4)
        trade_numerical_features = feature_tensor(8)
        deployer_trade_numerical_features = feature_tensor(8)
        smart_wallet_trade_numerical_features = feature_tensor(8)
        pool_created_numerical_features = feature_tensor(2)        # base_amount, quote_amount
        liquidity_change_numerical_features = feature_tensor(1)    # quote_amount
        fee_collected_numerical_features = feature_tensor(1)       # sol_amount
        token_burn_numerical_features = feature_tensor(2)          # amount_pct, amount_tokens
        supply_lock_numerical_features = feature_tensor(2)         # amount_pct, lock_duration
        onchain_snapshot_numerical_features = feature_tensor(14)
        trending_token_numerical_features = feature_tensor(1)      # inverse rank
        boosted_token_numerical_features = feature_tensor(2)       # total_boost_amount, inverse rank
        dexboost_paid_numerical_features = feature_tensor(2)       # not populated in this file
        dexprofile_updated_flags = feature_tensor(4, torch.float32)  # not populated in this file
        global_trending_numerical_features = feature_tensor(1)     # inverse rank (hashtags)
        chainsnapshot_numerical_features = feature_tensor(2)       # not populated in this file
        lighthousesnapshot_numerical_features = feature_tensor(5)

        # Categorical ids (0 when absent)
        trade_dex_ids = category_tensor()
        trade_direction_ids = category_tensor()
        trade_mev_protection_ids = category_tensor()
        trade_is_bundle_ids = category_tensor()
        pool_created_protocol_ids = category_tensor()
        liquidity_change_type_ids = category_tensor()
        trending_token_source_ids = category_tensor()
        trending_token_timeframe_ids = category_tensor()
        alpha_group_ids = category_tensor()   # not populated in this file
        channel_ids = category_tensor()       # not populated in this file
        exchange_ids = category_tensor()      # not populated in this file
        lighthousesnapshot_protocol_ids = category_tensor()
        lighthousesnapshot_timeframe_ids = category_tensor()
        migrated_protocol_ids = category_tensor()

        holder_snapshot_raw_data_list = []  # list of raw holder lists, one per HolderSnapshot event
        textual_event_data_list = []        # raw textual event dicts
        batch_chart_events = []             # Chart_Segment events, collated after the loop

        # Trade-like events share categorical/numerical layout but write to separate tensors.
        trade_feature_targets = {
            'Trade': trade_numerical_features,
            'LargeTrade': trade_numerical_features,
            'Deployer_Trade': deployer_trade_numerical_features,
            'SmartWallet_Trade': smart_wallet_trade_numerical_features,
        }

        # --- 6. Populate tensors and collect chart events ---
        for i, seq in enumerate(all_event_sequences):
            if not seq:
                continue
            attention_mask[i, :len(seq)] = 1

            for j, event in enumerate(seq):
                event_type = event.get('event_type', '__PAD__')
                event_type_ids[i, j] = self.event_type_to_id.get(event_type, PAD_IDX_SEQ)
                timestamps_float[i, j] = event.get('timestamp', 0)
                relative_ts[i, j, 0] = event.get('relative_ts', 0.0)

                w_addr = event.get('wallet_address')
                if w_addr:
                    wallet_indices[i, j] = wallet_addr_to_batch_idx.get(w_addr, PAD_IDX_ENT)
                t_addr = event.get('token_address')
                if t_addr:
                    token_indices[i, j] = token_addr_to_batch_idx.get(t_addr, PAD_IDX_ENT)

                if event_type == 'Chart_Segment':
                    batch_chart_events.append(event)
                    ohlc_indices[i, j] = len(batch_chart_events)  # 1-based index into the OHLC batch

                elif event_type in ('Transfer', 'LargeTransfer'):
                    dest_w_addr = event.get('destination_wallet_address')
                    if dest_w_addr:
                        dest_wallet_indices[i, j] = wallet_addr_to_batch_idx.get(dest_w_addr, PAD_IDX_ENT)
                    transfer_numerical_features[i, j, :] = to_features([
                        event.get('token_amount', 0.0),
                        event.get('transfer_pct_of_total_supply', 0.0),
                        event.get('transfer_pct_of_holding', 0.0),
                        event.get('priority_fee', 0.0),
                    ])

                elif event_type in trade_feature_targets:
                    trade_dex_ids[i, j] = event.get('dex_platform_id', 0)
                    trade_direction_ids[i, j] = event.get('trade_direction', 0)  # 0=buy, 1=sell
                    trade_mev_protection_ids[i, j] = event.get('mev_protection', 0)
                    trade_is_bundle_ids[i, j] = 1 if event.get('is_bundle') else 0
                    trade_feature_targets[event_type][i, j, :] = to_features(self._trade_numerical_features(event))

                elif event_type == 'PoolCreated':
                    quote_t_addr = event.get('quote_token_address')
                    if quote_t_addr:
                        quote_token_indices[i, j] = token_addr_to_batch_idx.get(quote_t_addr, PAD_IDX_ENT)
                    pool_created_protocol_ids[i, j] = event.get('protocol_id', 0)
                    pool_created_numerical_features[i, j, :] = to_features([
                        event.get('base_amount', 0.0),
                        event.get('quote_amount', 0.0),
                    ])

                elif event_type == 'LiquidityChange':
                    quote_t_addr = event.get('quote_token_address')
                    if quote_t_addr:
                        quote_token_indices[i, j] = token_addr_to_batch_idx.get(quote_t_addr, PAD_IDX_ENT)
                    liquidity_change_type_ids[i, j] = event.get('change_type_id', 0)
                    liquidity_change_numerical_features[i, j, :] = to_features([event.get('quote_amount', 0.0)])

                elif event_type == 'FeeCollected':
                    fee_collected_numerical_features[i, j, :] = to_features([event.get('sol_amount', 0.0)])

                elif event_type == 'TokenBurn':
                    token_burn_numerical_features[i, j, :] = to_features([
                        event.get('amount_pct_of_total_supply', 0.0),
                        event.get('amount_tokens_burned', 0.0),
                    ])

                elif event_type == 'SupplyLock':
                    supply_lock_numerical_features[i, j, :] = to_features([
                        event.get('amount_pct_of_total_supply', 0.0),
                        event.get('lock_duration', 0.0),
                    ])

                elif event_type == 'OnChain_Snapshot':
                    onchain_snapshot_numerical_features[i, j, :] = to_features([
                        event.get('total_holders', 0.0),
                        event.get('smart_traders', 0.0),
                        event.get('kols', 0.0),
                        event.get('holder_growth_rate', 0.0),
                        event.get('top_10_holder_pct', 0.0),
                        event.get('sniper_holding_pct', 0.0),
                        event.get('rat_wallets_holding_pct', 0.0),
                        event.get('bundle_holding_pct', 0.0),
                        event.get('current_market_cap', 0.0),
                        event.get('volume', 0.0),
                        event.get('buy_count', 0.0),
                        event.get('sell_count', 0.0),
                        event.get('total_txns', 0.0),
                        event.get('global_fees_paid', 0.0),
                    ])

                elif event_type == 'TrendingToken':
                    trending_t_addr = event.get('token_address')
                    if trending_t_addr:
                        trending_token_indices[i, j] = token_addr_to_batch_idx.get(trending_t_addr, PAD_IDX_ENT)
                    trending_token_source_ids[i, j] = event.get('list_source_id', 0)
                    trending_token_timeframe_ids[i, j] = event.get('timeframe_id', 0)
                    trending_token_numerical_features[i, j, :] = to_features([self._inverse_rank(event)])

                elif event_type == 'BoostedToken':
                    boosted_t_addr = event.get('token_address')
                    if boosted_t_addr:
                        boosted_token_indices[i, j] = token_addr_to_batch_idx.get(boosted_t_addr, PAD_IDX_ENT)
                    boosted_token_numerical_features[i, j, :] = to_features([
                        event.get('total_boost_amount', 0.0),
                        self._inverse_rank(event),
                    ])

                elif event_type == 'Migrated':
                    migrated_protocol_ids[i, j] = event.get('protocol_id', 0)

                elif event_type == 'HolderSnapshot':
                    # Raw holder rows are passed through; the index is 1-based into the list.
                    holder_snapshot_raw_data_list.append(event.get('holders', []))
                    holder_snapshot_indices[i, j] = len(holder_snapshot_raw_data_list)

                elif event_type == 'Lighthouse_Snapshot':
                    lighthousesnapshot_protocol_ids[i, j] = event.get('protocol_id', 0)
                    lighthousesnapshot_timeframe_ids[i, j] = event.get('timeframe_id', 0)
                    lighthousesnapshot_numerical_features[i, j, :] = to_features([
                        event.get('total_volume', 0.0),
                        event.get('total_transactions', 0.0),
                        event.get('total_traders', 0.0),
                        event.get('total_tokens_created', 0.0),
                        event.get('total_migrations', 0.0),
                    ])

                elif event_type in self._TEXTUAL_EVENT_TYPES:
                    # The event dict already holds its (remapped) text/image embedding indices.
                    textual_event_data_list.append(event)
                    textual_event_indices[i, j] = len(textual_event_data_list)  # 1-based
                    if event_type in self._HASHTAG_EVENT_TYPES:
                        global_trending_numerical_features[i, j, 0] = self._inverse_rank(event)
                    # The main 'wallet_address' is handled above; retweets/quotes also point
                    # at the original author's wallet.
                    if event_type in ('XRetweet', 'XQuoteTweet'):
                        orig_author_addr = event.get('original_author_wallet_address')
                        if orig_author_addr:
                            original_author_indices[i, j] = wallet_addr_to_batch_idx.get(orig_author_addr, PAD_IDX_ENT)

        # --- 7. Collate dynamic OHLC features after all chart events were collected ---
        ohlc_inputs_dict = self._collate_ohlc_inputs(batch_chart_events)

        # --- 8. Labels and final dict ---
        first_item = batch[0]

        def stack_optional(key: str) -> Optional[torch.Tensor]:
            # Presence is decided by the first sample, as every sample shares the cache schema.
            if key not in first_item:
                return None
            return torch.stack([item[key] for item in batch])

        def stack_quality_scores() -> Optional[torch.Tensor]:
            if 'quality_score' not in first_item:
                return None
            return torch.stack([
                score if isinstance(score, torch.Tensor) else torch.tensor(score, dtype=torch.float32)
                for score in (item['quality_score'] for item in batch)
            ])

        collated_batch = {
            # Sequence tensors
            'event_type_ids': event_type_ids,
            'timestamps_float': timestamps_float,
            'relative_ts': relative_ts,
            'attention_mask': attention_mask,
            # Pointer tensors
            'wallet_indices': wallet_indices,
            'token_indices': token_indices,
            'quote_token_indices': quote_token_indices,
            'trending_token_indices': trending_token_indices,
            'boosted_token_indices': boosted_token_indices,
            'holder_snapshot_indices': holder_snapshot_indices,
            'textual_event_indices': textual_event_indices,
            'ohlc_indices': ohlc_indices,
            # Raw data for encoders
            'embedding_pool': batch_embedding_pool,
            'token_encoder_inputs': token_encoder_inputs,
            'wallet_encoder_inputs': wallet_encoder_inputs,
            'ohlc_price_tensors': ohlc_inputs_dict['price_tensor'],
            'ohlc_interval_ids': ohlc_inputs_dict['interval_ids'],
            'quant_ohlc_feature_tensors': ohlc_inputs_dict['quant_feature_tensors'],
            'quant_ohlc_feature_mask': ohlc_inputs_dict['quant_feature_mask'],
            'quant_ohlc_feature_version_ids': ohlc_inputs_dict['quant_feature_version_ids'],
            'graph_updater_links': graph_updater_links,
            'wallet_addr_to_batch_idx': wallet_addr_to_batch_idx,
            'dest_wallet_indices': dest_wallet_indices,
            'original_author_indices': original_author_indices,
            # Numerical / categorical features per event type
            'transfer_numerical_features': transfer_numerical_features,
            'trade_numerical_features': trade_numerical_features,
            'trade_dex_ids': trade_dex_ids,
            'deployer_trade_numerical_features': deployer_trade_numerical_features,
            'trade_direction_ids': trade_direction_ids,
            'trade_mev_protection_ids': trade_mev_protection_ids,
            'smart_wallet_trade_numerical_features': smart_wallet_trade_numerical_features,
            'trade_is_bundle_ids': trade_is_bundle_ids,
            'pool_created_numerical_features': pool_created_numerical_features,
            'pool_created_protocol_ids': pool_created_protocol_ids,
            'liquidity_change_numerical_features': liquidity_change_numerical_features,
            'liquidity_change_type_ids': liquidity_change_type_ids,
            'fee_collected_numerical_features': fee_collected_numerical_features,
            'token_burn_numerical_features': token_burn_numerical_features,
            'supply_lock_numerical_features': supply_lock_numerical_features,
            'onchain_snapshot_numerical_features': onchain_snapshot_numerical_features,
            'boosted_token_numerical_features': boosted_token_numerical_features,
            'trending_token_numerical_features': trending_token_numerical_features,
            'trending_token_source_ids': trending_token_source_ids,
            'trending_token_timeframe_ids': trending_token_timeframe_ids,
            'dexboost_paid_numerical_features': dexboost_paid_numerical_features,
            'dexprofile_updated_flags': dexprofile_updated_flags,
            'global_trending_numerical_features': global_trending_numerical_features,
            'chainsnapshot_numerical_features': chainsnapshot_numerical_features,
            'lighthousesnapshot_numerical_features': lighthousesnapshot_numerical_features,
            'lighthousesnapshot_protocol_ids': lighthousesnapshot_protocol_ids,
            'lighthousesnapshot_timeframe_ids': lighthousesnapshot_timeframe_ids,
            'migrated_protocol_ids': migrated_protocol_ids,
            'alpha_group_ids': alpha_group_ids,
            'channel_ids': channel_ids,
            'exchange_ids': exchange_ids,
            'holder_snapshot_raw_data': holder_snapshot_raw_data_list,
            'textual_event_data': textual_event_data_list,
            # Labels
            'labels': stack_optional('labels'),
            'labels_mask': stack_optional('labels_mask'),
            'movement_class_targets': stack_optional('movement_class_targets'),
            'movement_class_mask': stack_optional('movement_class_mask'),
            'quality_score': stack_quality_scores(),
            'class_id': torch.tensor([item.get('class_id', 0) for item in batch], dtype=torch.long),
            # Debug info
            'token_addresses': [item.get('token_address', 'unknown') for item in batch],
            't_cutoffs': [item.get('t_cutoff', 'unknown') for item in batch],
            'sample_indices': [item.get('sample_idx', -1) for item in batch],
        }

        if collated_batch['quality_score'] is None:
            raise RuntimeError("FATAL: Missing quality_score in batch items. Rebuild cache with quality_score enabled.")

        # Drop optional entries that are absent (e.g. labels at inference time).
        return {k: v for k, v in collated_batch.items() if v is not None}