import torch
from collections import defaultdict
import datetime
import random
import requests
from io import BytesIO
from torch.utils.data import Dataset, IterableDataset
from PIL import Image
from typing import List, Dict, Any, Optional, Union, Tuple
from pathlib import Path
import numpy as np
from bisect import bisect_left, bisect_right
from concurrent.futures import ThreadPoolExecutor
import json

# We need the vocabulary for IDs and the processor for the pooler
import models.vocabulary as vocab
from models.multi_modal_processor import MultiModalEncoder
from data.data_fetcher import DataFetcher  # NEW: Import the DataFetcher
from data.context_targets import derive_movement_targets
from data.quant_ohlc_feature_schema import (
    FEATURE_INDEX,
    SEGMENT_SECONDS,
    FEATURE_VERSION,
    FEATURE_VERSION_ID,
    LOOKBACK_SECONDS,
    TOKENS_PER_SEGMENT,
    WINDOW_SECONDS,
    empty_feature_dict,
    feature_dict_to_vector,
)
from signals.rolling_quant import compute_rolling_quant_features
from signals.support_resistance import compute_support_resistance_features
from signals.trendlines import compute_trendline_features
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

# --- NEW: Hardcoded decimals for common quote tokens ---
QUOTE_TOKEN_DECIMALS = {
    'So11111111111111111111111111111111111111112': 9,  # SOL
    'EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v': 6,  # USDC
    'Es9vMFrzaCERmJfrF4H2FYD4KCoNkY11McCe8BenwNYB': 6,  # USDT
}

# --- NEW: Hyperparameters for trade event classification ---
LARGE_TRADE_USD_THRESHOLD = 100.0
LARGE_TRADE_SUPPLY_PCT_THRESHOLD = 0.03  # 3% of supply
LARGE_TRANSFER_SUPPLY_PCT_THRESHOLD = 0.03  # 3% of supply
SMART_WALLET_PNL_THRESHOLD = 3.0  # 300% PNL
SMART_WALLET_USD_THRESHOLD = 20000.0

# --- Event Categorization for Dynamic Sampling ---
# Events that are rare and should ALWAYS be kept
CRITICAL_EVENTS = {
    'Mint', 'Deployer_Trade', 'SmartWallet_Trade', 'LargeTrade', 'LargeTransfer',
    'TokenBurn', 'SupplyLock', 'PoolCreated', 'LiquidityChange',
    'Migrated', 'FeeCollected', 'TrendingToken', 'BoostedToken', 'XPost', 'XRetweet',
    'XReply', 'XQuoteTweet', 'PumpReply', 'DexBoost_Paid', 'DexProfile_Updated',
    'AlphaGroup_Call', 'Channel_Call', 'CexListing', 'TikTok_Trending_Hashtag',
    'XTrending_Hashtag'
}

# Periodic snapshots - kept for context continuity
SNAPSHOT_EVENTS = {
    'Chart_Segment', 'OnChain_Snapshot', 'HolderSnapshot', 'ChainSnapshot', 'Lighthouse_Snapshot'
}

# High-volume events that can be compressed (Head/Tail)
COMPRESSIBLE_EVENTS = {'Trade', 'Transfer'}

# --- NEW: OHLC Sequence Length Constant ---
# NOTE(review): previously commented as "4 minutes of chart"; confirm the unit of this length.
OHLC_SEQ_LEN = 300

# NOTE(review): value is 0.0 (no minimum) although it was previously commented as
# "1.0% of total supply" - confirm which is intended.
MIN_AMOUNT_TRANSFER_SUPPLY = 0.0

# Interval for HolderSnapshot events (seconds)
HOLDER_SNAPSHOT_INTERVAL_SEC = 300
HOLDER_SNAPSHOT_TOP_K = 200
# Number of recorded failures after which a URI is treated as dead (see _is_dead_uri).
DEAD_URI_RETRY_LIMIT = 2
DEFAULT_TOTAL_SUPPLY_RAW = 1_000_000_000_000_000
DEFAULT_TOKEN_DECIMALS = 6
# Context bucket names produced by summarize_context_window.
CONTEXT_BUCKET_NEGATIVE = "bad"
CONTEXT_BUCKET_POSITIVE = "good"


def summarize_context_window(
    labels: Any,
    labels_mask: Any,
) -> Dict[str, Any]:
    """
    Summarize a realized context window using its valid future returns.

    Base rule:
    - each horizon contributes signed terminal PnL from buying at cutoff
    - magnitude matters, so we compress returns with signed log1p
    - the context is `good` only if the mean signed score is strictly positive

    Args:
        labels: Return values (tensor or iterable of numbers).
        labels_mask: Parallel validity mask; entries > 0 mark valid returns.

    Returns:
        Dict with ``context_bucket`` ("good"/"bad"), ``context_score`` (mean of
        signed log1p magnitudes, 0.0 when nothing is valid), and counts of
        positive / non-positive / valid horizons.

    Raises:
        RuntimeError: if either ``labels`` or ``labels_mask`` is None.
    """
    if labels is None or labels_mask is None:
        raise RuntimeError("Context weighting requires both 'labels' and 'labels_mask'.")
    if isinstance(labels, torch.Tensor):
        label_vals = labels.tolist()
    else:
        label_vals = list(labels)
    if isinstance(labels_mask, torch.Tensor):
        mask_vals = labels_mask.tolist()
    else:
        mask_vals = list(labels_mask)
    valid_returns = [
        float(ret)
        for ret, keep in zip(label_vals, mask_vals)
        if float(keep) > 0.0
    ]
    signed_contributions = []
    for ret in valid_returns:
        magnitude = np.log1p(abs(ret))
        signed_contributions.append(magnitude if ret > 0.0 else -magnitude)
    positive_count = sum(1 for ret in valid_returns if ret > 0.0)
    # A return of exactly 0.0 counts as negative here.
    negative_count = len(valid_returns) - positive_count
    context_score = float(sum(signed_contributions) / len(signed_contributions)) if signed_contributions else 0.0
    context_bucket = (
        CONTEXT_BUCKET_POSITIVE if context_score > 0.0 else CONTEXT_BUCKET_NEGATIVE
    )
    return {
        "context_bucket": context_bucket,
        "context_score": context_score,
        "positive_horizons": positive_count,
        "negative_horizons": negative_count,
        "valid_horizons": len(valid_returns),
    }


class EmbeddingPooler:
    """
    A helper class to manage the collection and encoding of unique text/image items for a single data sample.
    """

    def __init__(self):
        self.pool_map = {}
        self.next_idx = 1  # 0 is padding

    def get_idx(self, item: Any) -> int:
        """
        Returns a unique index for a given item (string or image).
        - Returns 0 for None or empty strings.
        - Deduplicates identical text and image objects.
        """
        if item is None:
            return 0
        # Handle text case
        if isinstance(item, str):
            if not item.strip():  # skip empty or whitespace-only strings
                return 0
            key = item.strip()  # use normalized text key
        elif isinstance(item, Image.Image):
            key = id(item)  # identity key; the object is kept alive in pool_map so the id stays unique
        elif isinstance(item, torch.Tensor):
            key = id(item)  # identity key for tensors (same reasoning as images)
        else:
            key = item  # fallback: use object itself if hashable
        if key not in self.pool_map:
            self.pool_map[key] = {'item': item, 'idx': self.next_idx}
            self.next_idx += 1
        return self.pool_map[key]['idx']

    def get_all_items(self) -> List[Dict[str, Any]]:
        """
        Returns a list of all unique items, sorted by their assigned index.
        """
        if not self.pool_map:
            return []
        return sorted(self.pool_map.values(), key=lambda x: x['idx'])


class OracleDataset(Dataset):
    """
    Dataset class for the Oracle model. It fetches, processes, and structures
    all on-chain and off-chain data for a given token to create a comprehensive
    input sequence for the model.
    """

    def __init__(self,
                 data_fetcher: Optional[DataFetcher] = None,  # OPTIONAL: Only needed for caching (Writer)
                 fetcher_config: Optional[Dict[str, Any]] = None,
                 horizons_seconds: List[int] = [],
                 quantiles: List[float] = [],
                 max_samples: Optional[int] = None,
                 min_trades: int = 10,
                 token_allowlist: Optional[List[str]] = None,
                 cache_dir: Optional[Union[str, Path]] = None,
                 start_date: Optional[datetime.datetime] = None,
                 min_trade_usd: float = 0.0,
                 max_seq_len: int = 8192,
                 p99_clamps: Optional[Dict[str, float]] = None,
                 movement_label_config: Optional[Dict[str, float]] = None):
        """
        Build the dataset in one of three modes:

        - cached (``cache_dir`` given): index ``sample_*.pt`` files and compute
          class/context-balanced sampler weights;
        - online (``data_fetcher`` given): list mint records from the fetcher;
        - neither: an empty placeholder of length ``max_samples`` (or 1).

        ``horizons_seconds``/``quantiles`` are copied, never stored by
        reference, so the shared list defaults cannot leak between instances.

        Raises:
            RuntimeError: on invalid ``min_trades``, a bad/empty cache dir, or
                when the fetcher/allowlist yields no mints.
        """
        self.max_seq_len = max_seq_len
        self.min_trades = int(min_trades)
        if self.min_trades < 1:
            raise RuntimeError(f"min_trades must be >= 1, got {self.min_trades}")
        # --- P99 data-driven clamp values (replace hardcoded min/max) ---
        self.p99_clamps = {
            'slippage': 1.0,
            'total_usd': 100000.0,
            'history_bought_cost_sol': 30.0,
            'realized_profit_sol': 150.0,
        }
        if p99_clamps:
            self.p99_clamps.update(p99_clamps)
            print(f"INFO: Using P99 clamps: {self.p99_clamps}")

        # --- NEW: Create a persistent requests session for efficiency ---
        # Configure robust HTTP session
        self.http_session = None
        self._init_http_session()

        self.fetcher = data_fetcher
        self.fetcher_config = fetcher_config
        self.cache_dir = Path(cache_dir) if cache_dir else None
        # Always define these so DataLoader workers don't crash with AttributeError if
        # initialization falls through an unexpected branch.
        self.weights_list = []
        # Cache for lightweight token metadata to avoid redundant DB fetches
        self._token_meta_cache = {}
        self._chart_feature_log_count = 0
        self.token_allowlist = set(token_allowlist) if token_allowlist else None

        if self.cache_dir:
            if not self.cache_dir.is_dir():
                raise RuntimeError(
                    f"Cache directory '{self.cache_dir}' was provided but is not a directory. "
                    "Fix the path or disable cached mode."
                )
            # Cached/offline mode
            print(f"INFO: Initializing dataset in offline (cached) mode from: {self.cache_dir}")

            # Scan for cached files to determine length
            def _sort_key(p):
                # Handle both formats: sample_0.pt (numeric) and sample_ABC123.pt (token address)
                parts = p.stem.split('_')
                if len(parts) >= 2:
                    try:
                        return (0, int(parts[1]))  # Numeric: sort by number
                    except ValueError:
                        return (1, parts[1])  # String: sort alphabetically
                return (2, p.stem)

            self.cached_files = sorted(self.cache_dir.glob("sample_*.pt"), key=_sort_key)
            if not self.cached_files:
                raise RuntimeError(f"Cache directory '{self.cache_dir}' provided but contains no 'sample_*.pt' files.")

            # --- OPTIMIZED: Load cached metadata if available ---
            file_class_map = {}
            file_context_bucket_map = {}
            file_context_summary_map = {}
            class_counts = defaultdict(int)
            class_context_counts = defaultdict(lambda: defaultdict(int))
            metadata_path = self.cache_dir / "class_metadata.json"

            if metadata_path.exists():
                # Fast path: load from cached metadata
                print(f"INFO: Loading class metadata from cache: {metadata_path}")
                try:
                    with open(metadata_path, 'r') as f:
                        cached_metadata = json.load(f)
                    file_class_map = cached_metadata.get('file_class_map', {})
                    file_context_bucket_map = cached_metadata.get('file_context_bucket_map', {})
                    file_context_summary_map = cached_metadata.get('file_context_summary_map', {})
                    # Validate that cached files match metadata
                    # NOTE(review): files skipped during the slow path (unreadable or missing
                    # class_id) are never written to the metadata, so this equality check fails
                    # and the full rescan runs again on every start while such files exist.
                    cached_file_names = {p.name for p in self.cached_files}
                    metadata_file_names = set(file_class_map.keys())
                    if cached_file_names != metadata_file_names:
                        print(f"WARN: Metadata cache mismatch ({len(cached_file_names)} files vs {len(metadata_file_names)} in metadata). Rebuilding...")
                        file_class_map = {}
                        file_context_bucket_map = {}
                        file_context_summary_map = {}
                    else:
                        # Rebuild class_counts from loaded map
                        for fname, cid in file_class_map.items():
                            class_counts[cid] += 1
                            bucket = file_context_bucket_map.get(fname)
                            if bucket is not None:
                                class_context_counts[cid][bucket] += 1
                        print(f"INFO: Loaded metadata for {len(file_class_map)} samples in <1s")
                except Exception as e:
                    print(f"WARN: Failed to load metadata cache: {e}. Rebuilding...")
                    file_class_map = {}
                    file_context_bucket_map = {}
                    file_context_summary_map = {}

            # Slow path: scan all files and build metadata cache
            if not file_class_map:
                print(f"INFO: Building class metadata from {len(self.cached_files)} files (first run only)...")
                for i, p in enumerate(self.cached_files):
                    if i > 0 and i % 1000 == 0:
                        print(f" Scanned {i}/{len(self.cached_files)} files...")
                    try:
                        try:
                            cached_item = torch.load(p, map_location="cpu", weights_only=False)
                        except TypeError:
                            # Older torch versions do not accept `weights_only`.
                            cached_item = torch.load(p, map_location="cpu")
                        cid = cached_item.get("class_id")
                        if cid is None:
                            print(f"WARN: File {p.name} missing class_id. Skipping.")
                            continue
                        context_summary = summarize_context_window(
                            cached_item.get("labels"),
                            cached_item.get("labels_mask"),
                        )
                        bucket = context_summary["context_bucket"]
                        file_class_map[p.name] = cid
                        file_context_bucket_map[p.name] = bucket
                        file_context_summary_map[p.name] = context_summary
                        class_counts[cid] += 1
                        class_context_counts[cid][bucket] += 1
                    except Exception as e:
                        print(f"WARN: Failed to read cached sample {p.name}: {e}")

                # Save metadata cache for future runs
                try:
                    with open(metadata_path, 'w') as f:
                        json.dump({
                            'file_class_map': file_class_map,
                            'file_context_bucket_map': file_context_bucket_map,
                            'file_context_summary_map': file_context_summary_map,
                        }, f)
                    print(f"INFO: Saved class metadata cache to {metadata_path}")
                except Exception as e:
                    print(f"WARN: Failed to save metadata cache: {e}")

            print(f"INFO: Class Distribution: {dict(class_counts)}")
            print(
                "INFO: Context Distribution by Class: "
                f"{ {cid: dict(bucket_counts) for cid, bucket_counts in class_context_counts.items()} }"
            )

            # Store file_class_map for fast lookup by train.py's create_balanced_split
            self.file_class_map = dict(file_class_map)
            self.file_context_bucket_map = dict(file_context_bucket_map)
            self.file_context_summary_map = dict(file_context_summary_map)

            # Compute Weights (this pass also validates every sample has a usable bucket)
            self.weights_list = []
            valid_files = []

            # We iterate properly sorted cached files to align with __getitem__ index
            for p in self.cached_files:
                fname = p.name
                if fname not in file_class_map:
                    # If file exists but missing class_id, it might be stale or from an older cache.
                    print(f"WARN: File {fname} found in cache but missing class_id. Skipping.")
                    continue
                cid = file_class_map[fname]
                bucket = file_context_bucket_map.get(fname)
                if bucket is None:
                    raise RuntimeError(
                        f"Cached sample '{fname}' is missing a context bucket. "
                        "Rebuild metadata or cache before training."
                    )
                class_bucket_counts = class_context_counts[cid]
                present_buckets = [name for name, cnt in class_bucket_counts.items() if cnt > 0]
                if not present_buckets:
                    raise RuntimeError(
                        f"Class {cid} has no valid context buckets recorded. Cannot compute sampler weights."
                    )
                bucket_count = class_bucket_counts[bucket]
                if bucket_count <= 0:
                    raise RuntimeError(
                        f"Class {cid} bucket '{bucket}' has invalid count {bucket_count} for sample '{fname}'."
                    )
                # Each class's mass is split evenly across its present buckets, then
                # evenly across the samples inside a bucket.
                weight = 1.0 / (len(present_buckets) * bucket_count)
                self.weights_list.append(weight)
                valid_files.append(p)

            self.cached_files = valid_files
            self.num_samples = len(self.cached_files)
            if max_samples is not None:
                self.num_samples = min(max_samples, self.num_samples)
                self.cached_files = self.cached_files[:self.num_samples]
                self.weights_list = self.weights_list[:self.num_samples]

            # Recompute sampler weights against the active cached file subset so the
            # class/context balancing reflects the actual dataset seen by training.
            active_class_context_counts = defaultdict(lambda: defaultdict(int))
            for p in self.cached_files:
                fname = p.name
                cid = file_class_map[fname]
                bucket = file_context_bucket_map[fname]
                active_class_context_counts[cid][bucket] += 1
            self.weights_list = []
            for p in self.cached_files:
                fname = p.name
                cid = file_class_map[fname]
                bucket = file_context_bucket_map[fname]
                class_bucket_counts = active_class_context_counts[cid]
                present_buckets = [name for name, cnt in class_bucket_counts.items() if cnt > 0]
                bucket_count = class_bucket_counts[bucket]
                self.weights_list.append(1.0 / (len(present_buckets) * bucket_count))

            print(f"INFO: Weighted Dataset Ready. {self.num_samples} samples.")
            self.sampled_mints = []  # Not needed in cached mode
            self.available_mints = []
        elif self.fetcher:
            print(f"INFO: Initializing dataset in online (generation) mode...")
            self.available_mints = self.fetcher.get_all_mints(start_date=start_date)
            if not self.available_mints:
                raise RuntimeError("Dataset initialization failed: no mint records returned from data fetcher.")
            if self.token_allowlist:
                filtered_mints = [
                    mint for mint in self.available_mints
                    if mint.get('mint_address') in self.token_allowlist
                ]
                if not filtered_mints:
                    raise RuntimeError(f"No mint records matched the provided token allowlist: {token_allowlist}")
                self.available_mints = filtered_mints

            total_mints = len(self.available_mints)
            if max_samples is None:
                self.num_samples = total_mints
                self.sampled_mints = self.available_mints
            else:
                self.num_samples = min(max_samples, total_mints)
                if self.num_samples < total_mints:
                    print(f"INFO: Limiting dataset to first {self.num_samples} of {total_mints} available mints.")
                self.sampled_mints = self.available_mints[:self.num_samples]
        else:
            self.available_mints = []
            self.sampled_mints = []
            self.num_samples = 1 if max_samples is None else max_samples

        self.horizons_seconds = sorted(set(horizons_seconds))
        # Copy so the shared mutable default (or a caller's list) is never aliased.
        self.quantiles = list(quantiles)
        self.num_outputs = len(self.horizons_seconds) * len(self.quantiles)
        if self.horizons_seconds:
            self.max_cache_horizon_seconds = max(self.horizons_seconds)
        else:
            self.max_cache_horizon_seconds = 3600
        self.min_trade_usd = min_trade_usd
        self._uri_fail_counts: Dict[str, int] = {}
        self.movement_label_config = movement_label_config

    def _init_http_session(self) -> None:
        """Create ``self.http_session`` with retry/backoff on transient HTTP errors."""
        self.http_session = requests.Session()
        retry_strategy = Retry(
            total=3,
            backoff_factor=1,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["HEAD", "GET", "OPTIONS"]
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        self.http_session.mount("http://", adapter)
        self.http_session.mount("https://", adapter)

    def init_fetcher(self) -> None:
        """
        Initialize DataFetcher from stored config (for DataLoader workers).

        No-op when a fetcher already exists or no ``fetcher_config`` was given.
        """
        if self.fetcher is not None or not self.fetcher_config:
            return
        from clickhouse_driver import Client as ClickHouseClient
        from neo4j import GraphDatabase
        cfg = self.fetcher_config
        clickhouse_client = ClickHouseClient(
            host=cfg.get("clickhouse_host", "localhost"),
            port=int(cfg.get("clickhouse_port", 9000)),
        )
        neo4j_driver = GraphDatabase.driver(
            cfg.get("neo4j_uri", "bolt://localhost:7687"),
            auth=(cfg.get("neo4j_user", "neo4j"), cfg.get("neo4j_password", "password"))
        )
        self.fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)

    def __getstate__(self):
        """Pickle support: drop the fetcher and HTTP session (recreated per process)."""
        state = self.__dict__.copy()
        # Drop non-pickleable objects
        state["fetcher"] = None
        state["http_session"] = None
        return state

    def __setstate__(self, state):
        """Restore state and rebuild the HTTP session dropped by ``__getstate__``."""
        self.__dict__.update(state)
        if self.http_session is None:
            self._init_http_session()

    def __len__(self) -> int:
        return self.num_samples

    def get_weights(self) -> Optional[torch.DoubleTensor]:
        """Returns the sampling weights for the dataset, or None when there are none."""
        if hasattr(self, 'weights_list') and self.weights_list:
            return torch.as_tensor(self.weights_list, dtype=torch.double)
        return None

    def _normalize_price_series(self, values: List[float]) -> List[float]:
        """Natural-log each value; values <= 1e-9 map to 0.0. Empty input is returned as-is."""
        if not values:
            return values
        import math
        return [math.log(float(v)) if float(v) > 1e-9 else 0.0 for v in values]

    def _is_dead_uri(self, uri: Optional[str]) -> bool:
        """True once ``uri`` has failed at least DEAD_URI_RETRY_LIMIT times."""
        if not uri:
            return False
        return self._uri_fail_counts.get(uri, 0) >= DEAD_URI_RETRY_LIMIT

    def _mark_uri_failure(self, uri: Optional[str]) -> None:
        """Increment the failure count for ``uri`` (ignored when empty/None)."""
        if not uri:
            return
        self._uri_fail_counts[uri] = self._uri_fail_counts.get(uri, 0) + 1

    def _apply_dynamic_sampling(self, events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Applies dynamic context sampling to fit events within max_seq_len.

        Priority:
        1. CRITICAL events (always kept)
        2. SNAPSHOT events (kept for continuity)
        3. COMPRESSIBLE events (Trade/Transfer) - split into Head/Tail with MIDDLE token

        Uses existing 'MIDDLE' and 'RECENT' tokens to mark transitions. When
        critical + snapshot events leave no room, compressible events are
        dropped entirely (the result may still exceed max_seq_len because
        critical events are never dropped). Original relative order is kept.
        """
        if len(events) <= self.max_seq_len:
            return events

        # Categorize events by type
        critical_events = []  # (original_idx, event)
        snapshot_events = []
        compressible_events = []

        for idx, event in enumerate(events):
            event_type = event.get('event_type', '')
            if event_type in CRITICAL_EVENTS:
                critical_events.append((idx, event))
            elif event_type in SNAPSHOT_EVENTS:
                snapshot_events.append((idx, event))
            elif event_type in COMPRESSIBLE_EVENTS:
                compressible_events.append((idx, event))
            else:
                # Unknown event types go to critical (safe default)
                critical_events.append((idx, event))

        # Calculate budget for compressible events
        # Reserve 2 tokens for MIDDLE and RECENT markers
        reserved_tokens = 2
        fixed_count = len(critical_events) + len(snapshot_events) + reserved_tokens
        budget_for_compressible = max(0, self.max_seq_len - fixed_count)

        # If no budget for compressible, just return critical + snapshots.
        # (Previously this branch also returned every compressible event,
        # overflowing max_seq_len.)
        if budget_for_compressible == 0:
            kept_events = critical_events + snapshot_events
            kept_events.sort(key=lambda x: x[0])
            return [e[1] for e in kept_events]

        if len(compressible_events) <= budget_for_compressible:
            # All compressible fit, just return sorted
            all_events = critical_events + snapshot_events + compressible_events
            all_events.sort(key=lambda x: x[0])
            return [e[1] for e in all_events]

        # Apply Head/Tail split for compressible events
        head_size = budget_for_compressible // 2
        tail_size = budget_for_compressible - head_size

        head_events = compressible_events[:head_size]
        tail_events = compressible_events[-tail_size:] if tail_size > 0 else []

        # Find the index boundary for MIDDLE/RECENT markers
        # MIDDLE goes after head, RECENT goes before tail
        middle_marker_idx = head_events[-1][0] if head_events else 0
        recent_marker_idx = tail_events[0][0] if tail_events else len(events)

        # Create marker events
        middle_marker = {
            'event_type': 'MIDDLE',
            'relative_ts': events[middle_marker_idx].get('relative_ts', 0) if middle_marker_idx < len(events) else 0,
            'is_marker': True
        }
        recent_marker = {
            'event_type': 'RECENT',
            'relative_ts': events[recent_marker_idx - 1].get('relative_ts', 0) if recent_marker_idx > 0 and recent_marker_idx <= len(events) else 0,
            'is_marker': True
        }

        # Combine all events with markers
        # We need to maintain chronological order
        all_indexed_events = critical_events + snapshot_events + head_events + tail_events

        # Add markers with synthetic (half-step) indices so they sort between neighbours
        middle_idx = middle_marker_idx + 0.5  # After last head event
        recent_idx = recent_marker_idx - 0.5  # Before first tail event

        all_indexed_events.append((middle_idx, middle_marker))
        all_indexed_events.append((recent_idx, recent_marker))

        # Sort by original index to maintain chronological order
        all_indexed_events.sort(key=lambda x: x[0])

        return [e[1] for e in all_indexed_events]

    def _compute_future_return_labels(self,
                                      anchor_price: Optional[float],
                                      anchor_timestamp: int,
                                      price_series: List[Tuple[int, float]]) -> Tuple[torch.Tensor, torch.Tensor, List[Dict[str, Any]]]:
        """
        Build return labels and a validity mask for every (horizon, quantile) slot.

        For each horizon, the return is taken from the last price at or before
        ``anchor_timestamp + horizon`` relative to ``anchor_price``, clamped to
        [-10, 10], and repeated once per quantile. The mask is 0 when the target
        time lies after the last price point or before the first one.
        ``bisect`` is used, so ``price_series`` is assumed sorted by timestamp.

        Returns:
            (labels, mask, debug_entries) with labels/mask of length
            ``self.num_outputs`` (all zeros when the anchor or series is unusable).
        """
        if self.num_outputs == 0:
            return torch.zeros(0), torch.zeros(0), []
        if anchor_price is None or abs(anchor_price) < 1e-9 or not price_series:
            return torch.zeros(self.num_outputs), torch.zeros(self.num_outputs), []
        ts_list = [int(entry[0]) for entry in price_series]
        price_list = [float(entry[1]) for entry in price_series]
        if not ts_list:
            return torch.zeros(self.num_outputs), torch.zeros(self.num_outputs), []
        last_ts = ts_list[-1]
        label_values: List[float] = []
        mask_values: List[float] = []
        debug_entries: List[Dict[str, Any]] = []
        for horizon in self.horizons_seconds:
            target_ts = anchor_timestamp + horizon
            if target_ts > last_ts:
                horizon_mask = 0.0
                horizon_return = 0.0
                future_price = None
            else:
                idx = bisect_right(ts_list, target_ts) - 1
                if idx < 0:
                    horizon_mask = 0.0
                    horizon_return = 0.0
                    future_price = None
                else:
                    future_price = price_list[idx]
                    horizon_return = (future_price - anchor_price) / anchor_price
                    horizon_return = max(min(horizon_return, 10.0), -10.0)
                    horizon_mask = 1.0
            for _ in self.quantiles:
                label_values.append(horizon_return)
                mask_values.append(horizon_mask)
            debug_entries.append({
                'horizon': horizon,
                'target_ts': target_ts,
                'future_price': future_price,
                'return': horizon_return,
                'mask': horizon_mask
            })
        return (torch.tensor(label_values, dtype=torch.float32),
                torch.tensor(mask_values, dtype=torch.float32),
                debug_entries)

    def _generate_onchain_snapshots(
        self,
        token_address: str,
        t0_timestamp: int,
        T_cutoff: datetime.datetime,
        interval_sec: int,
        trade_events: List[Dict[str, Any]],
        transfer_events: List[Dict[str, Any]],
        aggregation_trades: List[Dict[str, Any]],
        wallet_data: Dict[str, Any],
        total_supply_dec: float,
        _register_event_fn,
        cached_holders_list: List[Dict[str, Any]] = None
    ) -> None:
        """
        Emit HolderSnapshot and OnChain_Snapshot events from cached holder snapshots.

        Walks ``cached_holders_list`` in order, stopping at the first entry
        without a timestamp or past ``T_cutoff``. Snapshots whose preceding
        ``interval_sec`` window contains no trades and no transfers are skipped.
        Each emitted event is passed to ``_register_event_fn(event, sort_key)``.

        Raises:
            RuntimeError: if ``cached_holders_list`` is None or a holder record
                is malformed.
        """
        if cached_holders_list is None:
            raise RuntimeError(
                f"Missing holder_snapshots_list for token {token_address} in _generate_onchain_snapshots."
            )
        # Prepare helper sets and maps (static sniper set based on earliest buyers)
        all_buy_trades = sorted(
            [e for e in trade_events if e.get('trade_direction') == 0 and e.get('success', False)],
            key=lambda x: x['timestamp'],
        )
        sniper_wallets = []
        seen_buyers = set()
        for e in all_buy_trades:
            wa = e['wallet_address']
            if wa not in seen_buyers:
                sniper_wallets.append(wa)
                seen_buyers.add(wa)
            if len(sniper_wallets) >= 70:
                break
        sniper_set = set(sniper_wallets)
        KOL_NAME_KEYS = ['kolscan_name', 'cabalspy_name', 'axiom_kol_name']

        # Build time arrays for price lookup
        agg_ts = [int(t['timestamp']) for t in aggregation_trades] if aggregation_trades else []
        agg_price = [float(t.get('price_usd', 0.0) or 0.0) for t in aggregation_trades] if aggregation_trades else []

        end_ts = int(self._timestamp_to_order_value(T_cutoff)) if hasattr(self, '_timestamp_to_order_value') else int(T_cutoff.timestamp())
        prev_holders_count = 0
        for i, snapshot_data in enumerate(cached_holders_list):
            if not isinstance(snapshot_data, dict):
                continue
            ts_value = snapshot_data.get('timestamp')
            if ts_value is None or ts_value > end_ts:
                break

            window_start = ts_value - interval_sec
            trades_win = [e for e in trade_events if e.get('success', False) and window_start < e['timestamp'] <= ts_value]
            xfers_win = [e for e in transfer_events if window_start < e['timestamp'] <= ts_value]

            # SPARSE SNAPSHOTS: Skip if absolutely nothing happened in this interval_sec window
            if not trades_win and not xfers_win:
                continue

            if 'holders' not in snapshot_data or not isinstance(snapshot_data['holders'], list):
                continue
            holder_records_ts = snapshot_data['holders']
            holders_end = 0
            holder_entries_ts = []
            for rec in holder_records_ts:
                if not isinstance(rec, dict):
                    raise RuntimeError(
                        f"Malformed holder record for token {token_address} at index {i}: expected dict."
                    )
                if 'wallet_address' not in rec or 'current_balance' not in rec:
                    raise RuntimeError(
                        f"Malformed holder record for token {token_address} at index {i}: requires wallet_address/current_balance."
                    )
                addr = rec['wallet_address']
                bal = float(rec['current_balance'])
                pct = (bal / total_supply_dec) if total_supply_dec and total_supply_dec > 0 else 0.0
                if addr and pct > 0.0:
                    holder_entries_ts.append({'wallet': addr, 'holding_pct': pct})
                    holders_end += 1
            holder_entries_ts.sort(key=lambda d: d['holding_pct'], reverse=True)

            # Emit HolderSnapshot for this ts_value
            hs_event = {
                'event_type': 'HolderSnapshot',
                'timestamp': int(ts_value),
                'relative_ts': ts_value - t0_timestamp,
                'holders': holder_entries_ts
            }
            _register_event_fn(
                hs_event,
                self._event_execution_sort_key(ts_value, slot=10**12, transaction_index=10**9, signature='HolderSnapshot')
                if hasattr(self, '_event_execution_sort_key')
                else (ts_value, 10**12, 10**9, 0, 'HolderSnapshot')
            )

            holder_pct_map_ts = {d['wallet']: d['holding_pct'] for d in holder_entries_ts}
            top10_holder_pct = sum(d['holding_pct'] for d in holder_entries_ts[:10]) if holder_entries_ts else 0.0

            # Cumulative sets up to ts_value
            rat_set_ts = set(ev['destination_wallet_address'] for ev in transfer_events if ev['timestamp'] <= ts_value)
            bundle_buyer_set_ts = set(
                e['wallet_address'] for e in trade_events
                if e.get('is_bundle') and e.get('trade_direction') == 0
                and e.get('success', False) and e['timestamp'] <= ts_value
            )

            buy_count = sum(1 for e in trades_win if e.get('trade_direction') == 0)
            sell_count = sum(1 for e in trades_win if e.get('trade_direction') == 1)
            volume = sum(float(e.get('total_usd', 0.0) or 0.0) for e in trades_win)
            total_txns = len(trades_win) + len(xfers_win)
            global_fees_paid = sum(
                float(e.get('priority_fee', 0.0) or 0.0) + float(e.get('bribe_fee', 0.0) or 0.0)
                for e in trades_win
            )

            smart_trader_addrs = set(
                e['wallet_address'] for e in trade_events
                if e.get('event_type') == 'SmartWallet_Trade' and e.get('success', False)
                and e['timestamp'] <= ts_value and holder_pct_map_ts.get(e['wallet_address'], 0.0) > 0.0
            )
            smart_traders = len(smart_trader_addrs)

            kol_addrs = set()
            for e in trades_win:
                wa = e['wallet_address']
                soc = wallet_data.get(wa, {}).get('socials', {})
                if any(soc.get(k) for k in KOL_NAME_KEYS if soc):
                    kol_addrs.add(wa)
            kols = len(kol_addrs)

            # Compute growth against the previous *emitted* snapshot
            # (skipped snapshots do not update prev_holders_count).
            total_holders = float(holders_end)
            delta_holders = holders_end - prev_holders_count
            holder_growth_rate = float(delta_holders)
            prev_holders_count = holders_end

            # Market cap from last price at or before ts (reverse linear scan; does
            # not assume aggregation_trades is sorted).
            last_price_usd = 0.0
            if agg_ts:
                for j in range(len(agg_ts) - 1, -1, -1):
                    if agg_ts[j] <= ts_value:
                        last_price_usd = agg_price[j]
                        break
            current_market_cap = float(last_price_usd) * float(total_supply_dec)

            oc_event = {
                'event_type': 'OnChain_Snapshot',
                'timestamp': int(ts_value),
                'relative_ts': ts_value - t0_timestamp,
                'total_holders': total_holders,
                'smart_traders': float(smart_traders),
                'kols': float(kols),
                'holder_growth_rate': float(holder_growth_rate),
                'top_10_holder_pct': float(top10_holder_pct),
                'sniper_holding_pct': float(sum(holder_pct_map_ts.get(wa, 0.0) for wa in sniper_set)),
                'rat_wallets_holding_pct': float(sum(holder_pct_map_ts.get(wa, 0.0) for wa in rat_set_ts)),
                'bundle_holding_pct': float(sum(holder_pct_map_ts.get(wa, 0.0) for wa in bundle_buyer_set_ts)),
                'current_market_cap': float(current_market_cap),
                'volume': float(volume),
                'buy_count': float(buy_count),
                'sell_count': float(sell_count),
                'total_txns': float(total_txns),
                'global_fees_paid': float(global_fees_paid)
            }
            _register_event_fn(
                oc_event,
                self._event_execution_sort_key(ts_value, slot=10**12, transaction_index=10**9, signature='OnChain_Snapshot')
                if hasattr(self, '_event_execution_sort_key')
                else (ts_value, 10**12, 10**9, 0, 'OnChain_Snapshot')
            )

    def _calculate_deployed_token_stats(self, profiles: Dict[str, Dict[str, Any]], T_cutoff: datetime.datetime):
        """
        Calculates aggregate statistics for wallets based on the tokens they've deployed.
        This method modifies the `profiles` dictionary in-place.

        Sets ``deployed_tokens_count``, ``deployed_tokens_migrated_pct``,
        ``deployed_tokens_avg_lifetime_sec``, ``deployed_tokens_avg_peak_mc_usd``
        and ``deployed_tokens_median_peak_mc_usd`` on every profile. A missing or
        None ``deployed_tokens`` entry is treated as no deployed tokens.
        """
        if not profiles:
            return

        # --- FIX: Batch all deployed tokens upfront to avoid N+1 query problem ---
        all_deployed_tokens = set()
        for profile in profiles.values():
            all_deployed_tokens.update(profile.get('deployed_tokens') or [])

        # Fetch all token details in ONE batch query
        all_deployed_token_details = {}
        if all_deployed_tokens and self.fetcher:
            all_deployed_token_details = self.fetcher.fetch_deployed_token_details(list(all_deployed_tokens), T_cutoff)

        for profile in profiles.values():
            deployed_tokens = profile.get('deployed_tokens') or []

            # 1. Deployed Tokens Count
            count = len(deployed_tokens)
            profile['deployed_tokens_count'] = float(count)

            if count == 0:
                profile['deployed_tokens_migrated_pct'] = 0.0
                profile['deployed_tokens_avg_lifetime_sec'] = 0.0
                profile['deployed_tokens_avg_peak_mc_usd'] = 0.0
                profile['deployed_tokens_median_peak_mc_usd'] = 0.0
                continue

            # Collect stats for all deployed tokens of this wallet (using pre-fetched data)
            lifetimes = []
            peak_mcs = []
            migrated_count = 0
            for token_addr in deployed_tokens:
                details = all_deployed_token_details.get(token_addr)
                if not details:
                    continue

                if details.get('has_migrated'):
                    migrated_count += 1

                lifetimes.append((details['updated_at'] - details['created_at']).total_seconds())
                # Simplified MC: ATH price * supply scaled by token decimals (default 9).
                peak_mcs.append(details.get('ath_price_usd', 0.0) * details.get('total_supply', 0.0) / (10**details.get('decimals', 9)))

            # 2. Migrated Pct (denominator counts tokens even when their details are missing)
            profile['deployed_tokens_migrated_pct'] = (migrated_count / count) if count > 0 else 0.0
            # 3. Avg Lifetime
            profile['deployed_tokens_avg_lifetime_sec'] = torch.mean(torch.tensor(lifetimes)).item() if lifetimes else 0.0
            # 4. Avg & Median Peak MC (torch.median returns the lower middle value for even lengths)
            profile['deployed_tokens_avg_peak_mc_usd'] = torch.mean(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0
            profile['deployed_tokens_median_peak_mc_usd'] = torch.median(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0

    def _process_wallet_data(self, wallet_addresses: List[str], token_data: Dict[str, Any], pooler: EmbeddingPooler, T_cutoff: datetime.datetime, profiles_override: Optional[Dict] = None, socials_override: Optional[Dict] = None, holdings_override: Optional[Dict] = None) -> tuple[Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]]:
        """
        Fetches or uses cached profile, social, and holdings data.
        """
        import time as _time
        _wd_timings = {}

        if not wallet_addresses:
            return {}, token_data

        _t0 = _time.perf_counter()
        if profiles_override is not None and socials_override is not None:
            profiles, socials = profiles_override, socials_override
            holdings = holdings_override if holdings_override is not None else {}
        else:
            if self.fetcher:
                profiles, socials = self.fetcher.fetch_wallet_profiles_and_socials(wallet_addresses, T_cutoff)
                holdings = self.fetcher.fetch_wallet_holdings(wallet_addresses, T_cutoff)
            else:
                profiles, socials, holdings = {}, {}, {}
        _wd_timings['db_fetch'] = _time.perf_counter() - _t0

        valid_wallets = [addr for addr in wallet_addresses if addr in profiles]
        if not valid_wallets:
            return {}, token_data
        wallet_addresses = valid_wallets

        # --- Collect all unique mints from holdings, split into top 10 + rest ---
        # Preserve seed token metadata (main token from mint record) and avoid refetching it
        # from holdings/token snapshots, which may be sparse at early cutoffs.
        seed_token_addresses = set(token_data.keys())
        all_holding_mints = set()
        top_holding_mints = set()
        for wallet_addr in wallet_addresses:
            wallet_holds = holdings.get(wallet_addr, [])
            for holding_item in wallet_holds:
                mint_addr = holding_item.get('mint_address')
                if mint_addr and mint_addr not in seed_token_addresses:
                    all_holding_mints.add(mint_addr)
            # Pick top holdings by volume for full image processing
            sorted_holds = sorted(wallet_holds, key=lambda h: float(h.get('total_volume_usd', 0) or 0), reverse=True)
            for h in sorted_holds[:2]:
                mint_addr = h.get('mint_address')
                if mint_addr and mint_addr not in seed_token_addresses:
                    top_holding_mints.add(mint_addr)

        # Cap top mints at 10 for full image processing
        top_holding_mints = set(list(top_holding_mints)[:10])
        rest_holding_mints = all_holding_mints - top_holding_mints
        _wd_timings['num_holding_tokens'] = len(all_holding_mints)

        # --- Process holdings: top 10 with images, rest lightweight (no HTTP) ---
        _t0 = _time.perf_counter()
        top_tokens = self._process_token_data(list(top_holding_mints), pooler, T_cutoff) if top_holding_mints else {}
        rest_tokens = self._process_token_data_lightweight(list(rest_holding_mints), pooler, T_cutoff) if rest_holding_mints else {}
        processed_new_tokens = {**top_tokens, **rest_tokens}
        _wd_timings['holding_token_processing'] = _time.perf_counter() - _t0

        # Defensive merge: never overwrite seed token metadata with holding-token fetches.
        all_token_data = dict(token_data)
        for addr, data in (processed_new_tokens or {}).items():
            if addr in all_token_data:
                continue
            all_token_data[addr] = data

        # --- Calculate deployed token stats using point-in-time logic ---
        self._calculate_deployed_token_stats(profiles, T_cutoff)

        # --- Assemble the final wallet dictionary ---
        final_wallets = {}
        for addr in wallet_addresses:

            # --- Define all expected numerical keys for a profile ---
            expected_profile_keys = [
                'deployed_tokens_count', 'deployed_tokens_migrated_pct',
                'deployed_tokens_avg_lifetime_sec', 'deployed_tokens_avg_peak_mc_usd',
                'deployed_tokens_median_peak_mc_usd', 'balance', 'transfers_in_count',
                'transfers_out_count', 'spl_transfers_in_count', 'spl_transfers_out_count',
                'total_buys_count', 'total_sells_count', 'total_winrate',
                'stats_1d_realized_profit_sol', 'stats_1d_realized_profit_pnl', 'stats_1d_buy_count',
                'stats_1d_sell_count', 'stats_1d_transfer_in_count', 'stats_1d_transfer_out_count',
                'stats_1d_avg_holding_period', 'stats_1d_total_bought_cost_sol',
                'stats_1d_total_sold_income_sol', 'stats_1d_total_fee', 'stats_1d_winrate',
                'stats_1d_tokens_traded', 'stats_7d_realized_profit_sol',
                'stats_7d_realized_profit_pnl', 'stats_7d_buy_count', 'stats_7d_sell_count',
                'stats_7d_transfer_in_count', 'stats_7d_transfer_out_count',
                'stats_7d_avg_holding_period', 'stats_7d_total_bought_cost_sol',
                'stats_7d_total_sold_income_sol', 'stats_7d_total_fee', 'stats_7d_winrate',
                'stats_7d_tokens_traded'
            ]

            profile_data = profiles.get(addr, None)
            if not profile_data:
                continue

            for key in expected_profile_keys:
                profile_data.setdefault(key, 0.0)

            social_data = socials.get(addr, {})

            # --- Derive boolean social flags based on schema ---
            social_data['has_pf_profile'] = bool(social_data.get('pumpfun_username'))
            social_data['has_twitter'] = bool(social_data.get('twitter_username'))
            social_data['has_telegram'] = bool(social_data.get('telegram_channel'))
            social_data['is_exchange_wallet'] = 'exchange_wallet' in profile_data.get('tags', [])
username = social_data.get('pumpfun_username') or social_data.get('twitter_username') or social_data.get('kolscan_name') if isinstance(username, str) and username.strip(): social_data['username_emb_idx'] = pooler.get_idx(username.strip()) else: social_data['username_emb_idx'] = 0 # --- Filter holdings and calculate derived features --- original_holdings = holdings.get(addr, []) valid_wallet_holdings = [] now_ts = datetime.datetime.now(datetime.timezone.utc) for holding_item in original_holdings: # 1. Calculate holding_time start_ts = holding_item.get('start_holding_at') mint_addr = holding_item.get('mint_address') token_info = all_token_data.get(mint_addr) if not token_info: continue end_ts = holding_item.get('end_holding_at') if not start_ts: holding_item['holding_time'] = 0.0 else: end_ts = end_ts or now_ts holding_item['holding_time'] = (end_ts - start_ts).total_seconds() # 2. Calculate balance_pct_to_supply if token_info and token_info.get('total_supply', 0) > 0: total_supply = token_info['total_supply'] / (10**token_info.get('decimals', 9)) current_balance = holding_item.get('current_balance', 0.0) holding_item['balance_pct_to_supply'] = (current_balance / total_supply) if total_supply > 0 else 0.0 else: holding_item['balance_pct_to_supply'] = 0.0 # 3. --- NEW: Calculate bought_amount_sol_pct_to_native_balance --- wallet_native_balance = profile_data.get('balance', 0.0) bought_cost_sol = holding_item.get('history_bought_cost_sol', 0.0) if wallet_native_balance > 1e-9: holding_item['bought_amount_sol_pct_to_native_balance'] = bought_cost_sol / wallet_native_balance else: holding_item['bought_amount_sol_pct_to_native_balance'] = 0.0 # Keep only fields used by WalletEncoder to minimize cache size. 
compact_holding = { 'mint_address': mint_addr, 'holding_time': float(holding_item.get('holding_time', 0.0) or 0.0), 'balance_pct_to_supply': min(1.0, float(holding_item.get('balance_pct_to_supply', 0.0) or 0.0)), 'history_bought_cost_sol': min(self.p99_clamps['history_bought_cost_sol'], float(holding_item.get('history_bought_cost_sol', 0.0) or 0.0)), 'bought_amount_sol_pct_to_native_balance': min(1.0, float(holding_item.get('bought_amount_sol_pct_to_native_balance', 0.0) or 0.0)), 'history_total_buys': float(holding_item.get('history_total_buys', 0.0) or 0.0), 'history_total_sells': float(holding_item.get('history_total_sells', 0.0) or 0.0), 'realized_profit_pnl': float(holding_item.get('realized_profit_pnl', 0.0) or 0.0), 'realized_profit_sol': max(-self.p99_clamps['realized_profit_sol'], min(self.p99_clamps['realized_profit_sol'], float(holding_item.get('realized_profit_sol', 0.0) or 0.0))), 'history_transfer_in': float(holding_item.get('history_transfer_in', 0.0) or 0.0), 'history_transfer_out': float(holding_item.get('history_transfer_out', 0.0) or 0.0), 'avarage_trade_gap_seconds': float(holding_item.get('avarage_trade_gap_seconds', 0.0) or 0.0), 'total_fees': float(holding_item.get('total_fees', 0.0) or 0.0), } valid_wallet_holdings.append(compact_holding) # Keep only fields consumed by WalletEncoder. 
compact_profile = {'wallet_address': addr} for key in expected_profile_keys: compact_profile[key] = float(profile_data.get(key, 0.0) or 0.0) compact_social = { 'has_pf_profile': bool(social_data.get('has_pf_profile', False)), 'has_twitter': bool(social_data.get('has_twitter', False)), 'has_telegram': bool(social_data.get('has_telegram', False)), 'is_exchange_wallet': bool(social_data.get('is_exchange_wallet', False)), 'username_emb_idx': int(social_data.get('username_emb_idx', 0) or 0), } final_wallets[addr] = { 'profile': compact_profile, 'socials': compact_social, 'holdings': valid_wallet_holdings } return final_wallets, all_token_data def _process_token_data(self, token_addresses: List[str], pooler: EmbeddingPooler, T_cutoff: datetime.datetime, token_data: Optional[Dict] = None) -> Dict[str, Dict[str, Any]]: """ Fetches and processes static data for a list of tokens. """ if not token_addresses: return {} if token_data is None: # 1. Check metadata cache first if not token_addresses: token_data = {} else: valid_token_data = {} missing_tokens = [] # Use cached metadata if available for addr in token_addresses: if addr in self._token_meta_cache: valid_token_data[addr] = self._token_meta_cache[addr].copy() else: missing_tokens.append(addr) # Fetch missing tokens if missing_tokens and self.fetcher: fetched = self.fetcher.fetch_token_data(missing_tokens, T_cutoff) # Update cache for addr, data in fetched.items(): if addr: self._token_meta_cache[addr] = data valid_token_data[addr] = data.copy() token_data = valid_token_data # Add pre-computed embedding indices to the token data # --- CRITICAL FIX: This function now returns None if the main token is invalid --- valid_token_data = {} for addr, data in token_data.items(): # --- FIXED: Only add to pooler if data is valid --- # --- NEW: Primary Image Fetch (Direct from Bullx) --- image = None # Initialize image for this iteration try: bullx_image_url = f"https://image.bullx.io/1399811149/{addr}?retry=0" direct_resp = 
self.http_session.get(bullx_image_url, timeout=5) if direct_resp.status_code == 200: try: image = Image.open(BytesIO(direct_resp.content)) except Exception as e: print(f"WARN: Failed to process image from Bullx for {addr}: {e}") image = None else: print(f"WARN: Bullx image fetch failed for {addr}: status {direct_resp.status_code}") except Exception as e: print(f"WARN: Bullx image fetch exception for {addr}: {e}") image = None # --- Fallback: Existing Metadata Fetching --- # REMOVED: IPFS fallback logic to rely solely on BullX. if image is None: token_uri = data.get('token_uri') if self._is_dead_uri(token_uri): image = None token_uri = None # Check for cached image passed in token_data if '_cached_image_pil' in data: image = data['_cached_image_pil'] if image is None: # Log failure if significant pass # --- FIXED: Check for valid metadata before adding to pooler --- token_name = data.get('name') if data.get('name') and data.get('name').strip() else None token_symbol = data.get('symbol') if data.get('symbol') and data.get('symbol').strip() else None # --- RELAXED: Allow missing metadata (pass None -> Zero Embedding) --- # The collator's EmbeddingPooler and logic handles non-str/non-image items # by skipping encoding and leaving their embedding vector as zeros. if not token_name: token_name = None #(Zeroed) if not token_symbol: token_symbol = None #(Zeroed) # If image failed or missing, pass None if not image: image = None #(Zeroed) # Only skip if we somehow have NO address (should technically fail earlier) if not addr: print(f"WARN: Token {addr} has no address?? 
Skipping.") continue # --- NEW: Add is_vanity feature based on the token address --- data['is_vanity'] = addr.lower().endswith("pump") data['image_emb_idx'] = pooler.get_idx(image) data['name_emb_idx'] = pooler.get_idx(token_name) data['symbol_emb_idx'] = pooler.get_idx(token_symbol) data.pop('_cached_image_pil', None) # FIX: Validate the protocol ID --- # The DB might return an ID that is out of bounds for our nn.Embedding layer. # We must ensure the ID is valid or map it to a default 'Unknown' ID. raw_protocol_id = data.get('protocol') if raw_protocol_id is not None and 0 <= raw_protocol_id < vocab.NUM_PROTOCOLS: data['protocol'] = raw_protocol_id else: data['protocol'] = vocab.PROTOCOL_TO_ID.get('Unknown', 0) valid_token_data[addr] = data return valid_token_data def _process_token_data_lightweight(self, token_addresses: List[str], pooler: EmbeddingPooler, T_cutoff: datetime.datetime) -> Dict[str, Dict[str, Any]]: """ Lightweight version of _process_token_data for non-top holding tokens. Fetches metadata from ClickHouse only (cached). NO HTTP image fetches. Sets image_emb_idx=0 (zero embedding). Still encodes name/symbol text. """ if not token_addresses: return {} # 1. Identify missing tokens not in cache missing_tokens = [addr for addr in token_addresses if addr not in self._token_meta_cache] # 2. Fetch missing tokens if missing_tokens and self.fetcher: fetched_data = self.fetcher.fetch_token_data(missing_tokens, T_cutoff) # Update cache with RAW fetched data (before pooler modifications) for addr, data in fetched_data.items(): if addr: self._token_meta_cache[addr] = data # 3. 
Process all tokens using cached data + current pooler valid_token_data = {} for addr in token_addresses: # Get raw data from cache raw_data = self._token_meta_cache.get(addr) if not raw_data: continue # Create a copy to modify for this specific sample/pooler context data = raw_data.copy() token_name = data.get('name') if data.get('name') and data.get('name').strip() else None token_symbol = data.get('symbol') if data.get('symbol') and data.get('symbol').strip() else None data['is_vanity'] = addr.lower().endswith("pump") data['image_emb_idx'] = 0 # Zero embedding — skip HTTP image fetch data['name_emb_idx'] = pooler.get_idx(token_name) data['symbol_emb_idx'] = pooler.get_idx(token_symbol) raw_protocol_id = data.get('protocol') if raw_protocol_id is not None and 0 <= raw_protocol_id < vocab.NUM_PROTOCOLS: data['protocol'] = raw_protocol_id else: data['protocol'] = vocab.PROTOCOL_TO_ID.get('Unknown', 0) valid_token_data[addr] = data return valid_token_data def _generate_ohlc(self, aggregation_trades: List[Dict[str, Any]], T_cutoff: datetime.datetime, interval_seconds: int, t0_timestamp: float = None) -> List[tuple]: """ Generates an OHLC series from a list of aggregated trades with a dynamic interval. It forward-fills gaps and extends the series up to T_cutoff. Returns a list of (timestamp, open, close) tuples. Args: t0_timestamp: If provided, OHLC will start from max(first_trade, t0_timestamp) to ensure chart data never precedes the mint event. 
""" if not aggregation_trades: return [] trades_by_interval = defaultdict(list) for trade in aggregation_trades: # Group trades into interval buckets interval_start_ts = (trade['timestamp'] // interval_seconds) * interval_seconds trades_by_interval[interval_start_ts].append(trade['price_usd']) sorted_intervals = sorted(trades_by_interval.keys()) if not sorted_intervals: return [] full_ohlc = [] # Ensure chart starts AFTER mint (t0_timestamp) to prevent Chart_Segment before Mint in event ordering start_ts = sorted_intervals[0] if t0_timestamp is not None: # Align to interval boundary at or after t0 t0_aligned = (int(t0_timestamp) // interval_seconds) * interval_seconds if t0_aligned < t0_timestamp: t0_aligned += interval_seconds # Move to next interval to ensure it's after t0 start_ts = max(start_ts, t0_aligned) end_ts = int(T_cutoff.timestamp()) for interval_ts in sorted_intervals: if start_ts <= interval_ts <= end_ts: prices = trades_by_interval[interval_ts] open_price = prices[0] close_price = prices[-1] full_ohlc.append((interval_ts, open_price, close_price)) return full_ohlc def _compute_quant_rolling_features( self, closes: List[float], end_idx: int, ) -> Dict[str, float]: return compute_rolling_quant_features(closes, end_idx) def _compute_support_resistance_features( self, closes: List[float], highs: List[float], lows: List[float], end_idx: int, window_start: int, window_end: int, timestamps: List[int], ) -> Dict[str, float]: return compute_support_resistance_features( closes=closes, highs=highs, lows=lows, end_idx=end_idx, window_start=window_start, window_end=window_end, timestamps=timestamps, ) def _compute_trendline_features( self, closes: List[float], highs: List[float], lows: List[float], end_idx: int, ) -> Dict[str, float]: return compute_trendline_features(closes, highs, lows, end_idx) def _extract_quant_ohlc_features_for_segment( self, segment: List[tuple], interval_label: str, token_address: Optional[str] = None, ) -> List[Dict[str, Any]]: if not 
segment: print( f"INFO: Chart quant skipped | token={token_address or 'unknown'} " "reason=empty_segment" ) return [] try: interval_seconds = max(1, int(str(interval_label).rstrip("s"))) except Exception: interval_seconds = 1 window_bar_count = max(1, WINDOW_SECONDS // interval_seconds) effective_window_seconds = max(WINDOW_SECONDS, interval_seconds) max_windows = max(1, SEGMENT_SECONDS // effective_window_seconds) timestamps = [int(row[0]) for row in segment] opens = [float(row[1]) for row in segment] closes = [float(row[2]) for row in segment] highs = [max(o, c) for o, c in zip(opens, closes)] lows = [min(o, c) for o, c in zip(opens, closes)] log_closes = np.log(np.clip(np.asarray(closes, dtype=np.float64), 1e-8, None)) one_sec_returns = np.diff(log_closes) feature_windows: List[Dict[str, Any]] = [] for window_idx in range(max_windows): window_start = window_idx * window_bar_count if window_start >= len(segment): break window_end = min(len(segment), window_start + window_bar_count) current_end_idx = window_end - 1 window_returns = one_sec_returns[window_start:max(window_start, current_end_idx)] window_closes = closes[window_start:window_end] window_highs = highs[window_start:window_end] window_lows = lows[window_start:window_end] features = empty_feature_dict() if window_closes: window_close_arr = np.asarray(window_closes, dtype=np.float64) window_return_sum = float(np.sum(window_returns)) if window_returns.size > 0 else 0.0 range_width = max(max(window_highs) - min(window_lows), 0.0) first_close = float(window_close_arr[0]) last_close = float(window_close_arr[-1]) accel_proxy = 0.0 if window_returns.size >= 2: accel_proxy = float(window_returns[-1] - window_returns[0]) features.update({ "cum_log_return": window_return_sum, "mean_log_return_1s": float(np.mean(window_returns)) if window_returns.size > 0 else 0.0, "std_log_return_1s": float(np.std(window_returns)) if window_returns.size > 0 else 0.0, "max_up_1s": float(np.max(window_returns)) if window_returns.size 
> 0 else 0.0, "max_down_1s": float(np.min(window_returns)) if window_returns.size > 0 else 0.0, "realized_vol": float(np.sqrt(np.sum(np.square(window_returns)))) if window_returns.size > 0 else 0.0, "window_range_frac": range_width / max(abs(last_close), 1e-8), "close_to_close_slope": (last_close - first_close) / max(abs(first_close), 1e-8), "accel_proxy": accel_proxy, "frac_pos_1s": float(np.mean(window_returns > 0)) if window_returns.size > 0 else 0.0, "frac_neg_1s": float(np.mean(window_returns < 0)) if window_returns.size > 0 else 0.0, }) current_price = closes[current_end_idx] current_high = highs[current_end_idx] current_low = lows[current_end_idx] for lookback in LOOKBACK_SECONDS: prefix = f"lb_{lookback}s" lookback_start = max(0, current_end_idx - lookback + 1) hist_closes = closes[lookback_start: current_end_idx + 1] hist_highs = highs[lookback_start: current_end_idx + 1] hist_lows = lows[lookback_start: current_end_idx + 1] hist_range = max(max(hist_highs) - min(hist_lows), 1e-8) rolling_high = max(hist_highs) rolling_low = min(hist_lows) hist_returns = np.diff(np.log(np.clip(np.asarray(hist_closes, dtype=np.float64), 1e-8, None))) current_width = max(max(window_highs) - min(window_lows), 0.0) prev_hist_width = max(max(hist_highs[:-len(window_highs)]) - min(hist_lows[:-len(window_lows)]), 0.0) if len(hist_highs) > len(window_highs) else current_width prev_close = closes[current_end_idx - 1] if current_end_idx > 0 else current_price features.update({ f"{prefix}_dist_high": (rolling_high - current_price) / max(abs(current_price), 1e-8), f"{prefix}_dist_low": (current_price - rolling_low) / max(abs(current_price), 1e-8), f"{prefix}_drawdown_high": (current_price - rolling_high) / max(abs(rolling_high), 1e-8), f"{prefix}_rebound_low": (current_price - rolling_low) / max(abs(rolling_low), 1e-8), f"{prefix}_pos_in_range": (current_price - rolling_low) / hist_range, f"{prefix}_range_width": hist_range / max(abs(current_price), 1e-8), 
f"{prefix}_compression_ratio": current_width / max(prev_hist_width, 1e-8), f"{prefix}_breakout_high": 1.0 if current_high > rolling_high and prev_close <= rolling_high else 0.0, f"{prefix}_breakdown_low": 1.0 if current_low < rolling_low and prev_close >= rolling_low else 0.0, f"{prefix}_reclaim_breakdown": 1.0 if current_low < rolling_low and current_price >= rolling_low else 0.0, f"{prefix}_rejection_breakout": 1.0 if current_high > rolling_high and current_price <= rolling_high else 0.0, }) features.update(self._compute_support_resistance_features( closes=closes, highs=highs, lows=lows, end_idx=current_end_idx, window_start=window_start, window_end=window_end, timestamps=timestamps, )) features.update(self._compute_trendline_features( closes=closes, highs=highs, lows=lows, end_idx=current_end_idx, )) features.update(self._compute_quant_rolling_features( closes=closes, end_idx=current_end_idx, )) feature_windows.append({ "start_ts": timestamps[window_start], "end_ts": timestamps[current_end_idx], "window_seconds": effective_window_seconds, "feature_vector": feature_dict_to_vector(features), "feature_names_version": FEATURE_VERSION, "feature_version_id": FEATURE_VERSION_ID, "level_snapshot": { "support_distance": features.get("nearest_support_dist", 0.0), "resistance_distance": features.get("nearest_resistance_dist", 0.0), "support_strength": features.get("support_strength", 0.0), "resistance_strength": features.get("resistance_strength", 0.0), "breakout_up": features.get("keylevel_breakout_up", 0.0), "breakout_down": features.get("keylevel_breakout_down", 0.0), "hold_above": features.get("keylevel_hold_above", 0.0), "hold_below": features.get("keylevel_hold_below", 0.0), "flip_to_support": features.get("keylevel_flip_to_support", 0.0), "flip_to_resistance": features.get("keylevel_flip_to_resistance", 0.0), }, "keylevel_flags": { "breakout_up": features.get("keylevel_breakout_up", 0.0), "breakout_down": features.get("keylevel_breakout_down", 0.0), "hold_above": 
features.get("keylevel_hold_above", 0.0), "hold_below": features.get("keylevel_hold_below", 0.0), "failed_breakout_up": features.get("keylevel_failed_breakout_up", 0.0), "failed_breakout_down": features.get("keylevel_failed_breakout_down", 0.0), "flip_to_support": features.get("keylevel_flip_to_support", 0.0), "flip_to_resistance": features.get("keylevel_flip_to_resistance", 0.0), }, }) sr_windows = sum( 1 for window in feature_windows if float(window["feature_vector"][FEATURE_INDEX["sr_available"]]) > 0.0 ) trendline_windows = sum( 1 for window in feature_windows if float(window["feature_vector"][FEATURE_INDEX["trendline_available"]]) > 0.0 ) breakout_windows = sum( 1 for window in feature_windows if ( float(window["feature_vector"][FEATURE_INDEX["keylevel_breakout_up"]]) > 0.0 or float(window["feature_vector"][FEATURE_INDEX["keylevel_breakout_down"]]) > 0.0 or float(window["feature_vector"][FEATURE_INDEX["keylevel_flip_to_support"]]) > 0.0 or float(window["feature_vector"][FEATURE_INDEX["keylevel_flip_to_resistance"]]) > 0.0 ) ) keylevel_break_events = sum( 1 for window in feature_windows if ( float(window["feature_vector"][FEATURE_INDEX["keylevel_breakout_up"]]) > 0.0 or float(window["feature_vector"][FEATURE_INDEX["keylevel_breakout_down"]]) > 0.0 ) ) self._chart_feature_log_count += 1 print( f"INFO: Chart quant built | token={token_address or 'unknown'} " f"interval={interval_label} segment={self._chart_feature_log_count} " f"windows={len(feature_windows)}/{max_windows} " f"sr={sr_windows} trend={trendline_windows} breaks={breakout_windows} " f"break_events={keylevel_break_events}" ) return feature_windows def __getitem__(self, idx: int) -> Optional[Dict[str, Any]]: """ Loads data from cache. """ import time as _time _timings = {} _total_start = _time.perf_counter() # --- TIMING: Cache load --- _t0 = _time.perf_counter() if not self.cache_dir: raise RuntimeError("Offline mode required. 
No cache directory provided.") if idx >= len(self.cached_files): raise IndexError(f"Index {idx} out of range for {len(self.cached_files)} cached files.") filepath = self.cached_files[idx] try: cached_data = torch.load(filepath, map_location='cpu', weights_only=False) except Exception as e: raise RuntimeError(f"ERROR: Could not load cached item {filepath}: {e}") _timings['cache_load'] = _time.perf_counter() - _t0 if not cached_data: raise RuntimeError(f"No data loaded for index {idx}") has_context_shape = ( isinstance(cached_data, dict) and 'event_sequence' in cached_data and 'tokens' in cached_data and 'wallets' in cached_data and 'labels' in cached_data and 'labels_mask' in cached_data ) if has_context_shape: # Return pre-computed training context directly. _timings['total'] = _time.perf_counter() - _total_start if 'movement_class_targets' not in cached_data and 'labels' in cached_data and 'labels_mask' in cached_data: labels = cached_data['labels'] labels_mask = cached_data['labels_mask'] movement_targets = derive_movement_targets( labels.tolist() if isinstance(labels, torch.Tensor) else labels, labels_mask.tolist() if isinstance(labels_mask, torch.Tensor) else labels_mask, movement_label_config=self.movement_label_config, ) cached_data['movement_class_targets'] = torch.tensor( movement_targets['movement_class_targets'], dtype=torch.long, ) cached_data['movement_class_mask'] = torch.tensor( movement_targets['movement_class_mask'], dtype=torch.long, ) if idx % 100 == 0: print(f"[Sample {idx}] context cache | cache_load: {_timings['cache_load']*1000:.1f}ms | " f"total: {_timings['total']*1000:.1f}ms | events: {len(cached_data.get('event_sequence', []))}") return cached_data raise RuntimeError( f"Cached item at {filepath} is not a valid context cache. " "Rebuild the cache with scripts/cache_dataset.py." 
) def _process_token_data_offline(self, token_addresses: List[str], pooler: EmbeddingPooler, T_cutoff: datetime.datetime, token_data: Optional[Dict] = None) -> Dict[str, Dict[str, Any]]: """ Processes token data in OFFLINE mode - no HTTP calls for images. Uses pre-cached image bytes from the cache file. """ if not token_addresses: return {} if token_data is None: token_data = {} valid_token_data = {} for addr, data in token_data.items(): # Use cached PIL image if available (set by __getitem__) image = data.get('_cached_image_pil', None) # Get token metadata token_name = data.get('name') if data.get('name') and data.get('name').strip() else None token_symbol = data.get('symbol') if data.get('symbol') and data.get('symbol').strip() else None if not addr: continue # Add is_vanity feature data['is_vanity'] = addr.lower().endswith("pump") # Add embedding indices data['image_emb_idx'] = pooler.get_idx(image) data['name_emb_idx'] = pooler.get_idx(token_name) data['symbol_emb_idx'] = pooler.get_idx(token_symbol) # Drop transient in-memory image object from cache payload. data.pop('_cached_image_pil', None) # Validate protocol ID raw_protocol_id = data.get('protocol') if raw_protocol_id is not None and 0 <= raw_protocol_id < vocab.NUM_PROTOCOLS: data['protocol'] = raw_protocol_id else: data['protocol'] = vocab.PROTOCOL_TO_ID.get('Unknown', 0) valid_token_data[addr] = data return valid_token_data def _build_main_token_seed(self, token_address: str, raw_data: Dict[str, Any]) -> Dict[str, Any]: """ Build a minimal token metadata payload for the main token. Prevents raw cache blobs (trades/snapshots/etc.) from leaking into sample['tokens'][main_token]. 
""" return { 'token_address': token_address, 'address': token_address, 'name': raw_data.get('name', ''), 'symbol': raw_data.get('symbol', ''), 'token_uri': raw_data.get('token_uri', ''), 'protocol': raw_data.get('protocol', 1), 'total_supply': raw_data.get('total_supply', 0), 'decimals': raw_data.get('decimals', 6), } def __cacheitem__(self, idx: int) -> Optional[Dict[str, Any]]: """ Fetches cutoff-agnostic raw token data for caching/online sampling. Generates dense time-series (1s OHLC, Snapshots) and prunes raw logs. NEW: Also pre-fetches and caches ALL wallet profiles, socials, holdings, graph links, and token images to enable fully offline training. """ if not self.sampled_mints: raise RuntimeError("Dataset has no mint records loaded; ensure fetcher returned data during initialization.") if idx >= len(self.sampled_mints): raise IndexError(f"Requested sample index {idx} exceeds loaded mint count {len(self.sampled_mints)}.") initial_mint_record = self.sampled_mints[idx] t0 = initial_mint_record["timestamp"] if isinstance(t0, datetime.datetime) and t0.tzinfo is None: t0 = t0.replace(tzinfo=datetime.timezone.utc) creator_address = initial_mint_record['creator_address'] token_address = initial_mint_record['mint_address'] # Per-token header logging removed for caching speed if not self.fetcher: raise RuntimeError("Dataset has no data fetcher; cannot load raw data.") # --- FETCH FULL HISTORY with PRUNING --- raw_data = self.fetcher.fetch_raw_token_data( token_address=token_address, creator_address=creator_address, mint_timestamp=t0, max_horizon_seconds=self.max_cache_horizon_seconds, include_wallet_data=False, include_graph=False, min_trades=self.min_trades, full_history=True, # Bypass H/B/H limits prune_failed=False, # Keep failed trades for realistic simulation prune_transfers=False # Keep transfers for snapshot reconstruction ) if raw_data is None: return None # --- FIX: Add token metadata from mint record to raw_data --- # DEBUG: Print what's in the mint record 
print(f" DEBUG: initial_mint_record keys: {list(initial_mint_record.keys())}") print(f" DEBUG: token_name='{initial_mint_record.get('token_name')}', token_symbol='{initial_mint_record.get('token_symbol')}'") raw_data['name'] = initial_mint_record.get('token_name', '') raw_data['symbol'] = initial_mint_record.get('token_symbol', '') raw_data['token_uri'] = initial_mint_record.get('token_uri', '') raw_total_supply = initial_mint_record.get('total_supply', DEFAULT_TOTAL_SUPPLY_RAW) raw_token_decimals = initial_mint_record.get('token_decimals', DEFAULT_TOKEN_DECIMALS) raw_data['total_supply'] = ( int(raw_total_supply) if raw_total_supply and int(raw_total_supply) > 0 else DEFAULT_TOTAL_SUPPLY_RAW ) raw_data['decimals'] = ( int(raw_token_decimals) if raw_token_decimals is not None and int(raw_token_decimals) >= 0 else DEFAULT_TOKEN_DECIMALS ) raw_data['protocol'] = initial_mint_record.get('protocol', 1) def _timestamp_to_order_value(ts_value: Any) -> float: if isinstance(ts_value, datetime.datetime): if ts_value.tzinfo is None: ts_value = ts_value.replace(tzinfo=datetime.timezone.utc) return ts_value.timestamp() try: return float(ts_value) except (TypeError, ValueError): return 0.0 trades = raw_data.get('trades', []) trade_ts_values = [_timestamp_to_order_value(t.get('timestamp')) for t in trades] if not trade_ts_values: print(f" SKIP: No valid trades found for {token_address}.") return None t0_val = _timestamp_to_order_value(t0) last_trade_ts_val = max(trade_ts_values) # --- GENERATE DENSE 1s OHLC --- duration_seconds = int(last_trade_ts_val - t0_val) + 120 # Add buffer ohlc_1s = torch.zeros((duration_seconds, 2), dtype=torch.float32) # Sort trades by time # raw_data trades are already sorted by fetcher, but let's be safe trades.sort(key=lambda x: _timestamp_to_order_value(x['timestamp'])) # Fill OHLC # A faster way: group by second # We can use a simple loop update or numpy accumulation. # Given standard density, simple loop is fine for caching. 
trades_by_sec = defaultdict(list) for t in trades: ts = _timestamp_to_order_value(t['timestamp']) sec_idx = int(ts - t0_val) if 0 <= sec_idx < duration_seconds: trades_by_sec[sec_idx].append(t['price_usd']) last_close = float(trades[0]['price_usd']) for i in range(duration_seconds): if i in trades_by_sec: prices = trades_by_sec[i] op = prices[0] cl = prices[-1] last_close = cl else: op = cl = last_close ohlc_1s[i, 0] = float(op) ohlc_1s[i, 1] = float(cl) raw_data['ohlc_1s'] = ohlc_1s # --- GENERATE ON-CHAIN SNAPSHOTS (5m Interval) --- interval = 300 # 5 minutes num_intervals = (duration_seconds // interval) + 1 # Feature columns: [volume, tx_count, buy_count, sell_count, total_holders, top_10_holder_pct] # We start with basic trade stats. Holder stats require DB queries. snapshot_stats = torch.zeros((num_intervals, 6), dtype=torch.float32) cum_volume = 0.0 cum_tx = 0 cum_buys = 0 cum_sells = 0 # Pre-group trades into 5m buckets for windowed volume buckets = defaultdict(list) for t in trades: ts = _timestamp_to_order_value(t['timestamp']) bucket_idx = int(ts - t0_val) // interval if bucket_idx >= 0: buckets[bucket_idx].append(t) # Batch-fetch ALL holder counts in ONE query (replaces N per-interval queries) holder_counts_by_interval = {} try: all_holdings = self.fetcher.db_client.execute(""" SELECT wallet_address, current_balance, updated_at FROM wallet_holdings WHERE mint_address = %(token)s ORDER BY wallet_address, updated_at """, {'token': token_address}) if all_holdings: wallet_latest = {} all_holdings.sort(key=lambda x: x[2]) holding_idx = 0 for i in range(num_intervals): snap_ts = t0 + datetime.timedelta(seconds=(i + 1) * interval) while holding_idx < len(all_holdings) and all_holdings[holding_idx][2] <= snap_ts: wallet, balance, _ = all_holdings[holding_idx] wallet_latest[wallet] = balance holding_idx += 1 holder_counts_by_interval[i] = sum(1 for b in wallet_latest.values() if b and b > 0) except Exception: pass holder_snapshots_list = [] for i in 
range(num_intervals): bucket_trades = buckets[i] # Windowed Stats vol = sum(t.get('total_usd', 0.0) for t in bucket_trades) tx = len(bucket_trades) buys = sum(1 for t in bucket_trades if t.get('trade_direction') == 0 or t.get('trade_type') == 0) # 0=Buy sells = tx - buys count = holder_counts_by_interval.get(i, 0) snapshot_stats[i, 0] = float(vol) snapshot_stats[i, 1] = float(tx) snapshot_stats[i, 2] = float(buys) snapshot_stats[i, 3] = float(sells) snapshot_stats[i, 4] = float(count) snapshot_stats[i, 5] = 0.0 # top10_pct not available in batch mode snapshot_ts = t0 + datetime.timedelta(seconds=(i+1)*interval) holder_snapshots_list.append({ 'timestamp': int(snapshot_ts.timestamp()), 'holders': [] }) raw_data['snapshots_5m'] = snapshot_stats raw_data['holder_snapshots_list'] = holder_snapshots_list # Save the list raw_data['holder_snapshots_list'] = holder_snapshots_list # Save the list raw_data["protocol_id"] = initial_mint_record.get("protocol") # ======================================================================= # NEW: PRE-FETCH AND CACHE ALL WALLET/GRAPH/IMAGE DATA FOR OFFLINE MODE # ======================================================================= # This enables fully offline training with zero DB calls during __getitem__ # ======================================================================= # Wallet/graph pre-fetch logging removed for caching speed # 1. 
Collect ALL unique wallet addresses from all events all_wallets = set() all_wallets.add(creator_address) for trade in raw_data.get('trades', []): if trade.get('maker'): all_wallets.add(trade['maker']) for transfer in raw_data.get('transfers', []): if transfer.get('source'): all_wallets.add(transfer['source']) if transfer.get('destination'): all_wallets.add(transfer['destination']) for pool in raw_data.get('pool_creations', []): if pool.get('creator_address'): all_wallets.add(pool['creator_address']) for liq in raw_data.get('liquidity_changes', []): if liq.get('lp_provider'): all_wallets.add(liq['lp_provider']) # Add wallets from holder snapshots for snapshot in holder_snapshots_list: for holder in snapshot.get('holders', []): if holder.get('wallet_address'): all_wallets.add(holder['wallet_address']) all_wallets.discard(None) all_wallets.discard('') wallet_list = list(all_wallets) print(f" INFO: Found {len(wallet_list)} unique wallets to cache") # 2. Fetch wallet profiles and socials (time-independent for caching) # Use the last trade timestamp as T_cutoff for the cache (max point in time) max_T_cutoff = datetime.datetime.fromtimestamp(last_trade_ts_val, tz=datetime.timezone.utc) try: cached_profiles, cached_socials = self.fetcher.fetch_wallet_profiles_and_socials( wallet_list, max_T_cutoff ) print(f" INFO: Cached {len(cached_profiles)} wallet profiles, {len(cached_socials)} socials") except Exception as e: print(f" WARN: Failed to fetch wallet profiles/socials: {e}") cached_profiles, cached_socials = {}, {} # 3. Fetch wallet holdings (at max T_cutoff) try: cached_holdings = self.fetcher.fetch_wallet_holdings(wallet_list, max_T_cutoff) print(f" INFO: Cached holdings for {len(cached_holdings)} wallets") except Exception as e: print(f" WARN: Failed to fetch wallet holdings: {e}") cached_holdings = {} # 4. 
Fetch graph links (at max T_cutoff) try: cached_graph_entities, cached_graph_links = self.fetcher.fetch_graph_links( wallet_list, max_T_cutoff, max_degrees=1 ) print(f" INFO: Cached {len(cached_graph_links)} graph link types, {len(cached_graph_entities)} graph entities") except Exception as e: print(f" WARN: Failed to fetch graph links: {e}") cached_graph_entities, cached_graph_links = {}, {} # 5. Fetch and cache token image as bytes (not PIL Image to avoid pickle issues) cached_image_bytes = None try: # Try Bullx first bullx_image_url = f"https://image.bullx.io/1399811149/{token_address}?retry=0" resp = self.http_session.get(bullx_image_url, timeout=2) if resp.status_code == 200: cached_image_bytes = resp.content print(f" INFO: Cached token image from Bullx ({len(cached_image_bytes)} bytes)") else: # Fallback to token_uri metadata token_uri = raw_data.get('token_uri') if token_uri and 'ipfs/' in token_uri: ipfs_gateways = [ "https://pump.mypinata.cloud/ipfs/", "https://dweb.link/ipfs/", "https://cloudflare-ipfs.com/ipfs/", ] metadata_hash = token_uri.split('ipfs/')[-1] for gateway in ipfs_gateways: try: metadata_resp = self.http_session.get(f"{gateway}{metadata_hash}", timeout=5) if metadata_resp.status_code == 200: metadata = metadata_resp.json() image_url = metadata.get('image', '') if image_url and 'ipfs/' in image_url: image_hash = image_url.split('ipfs/')[-1] for img_gateway in ipfs_gateways: try: img_resp = self.http_session.get(f"{img_gateway}{image_hash}", timeout=5) if img_resp.status_code == 200: cached_image_bytes = img_resp.content print(f" INFO: Cached token image from IPFS ({len(cached_image_bytes)} bytes)") break except: continue break except: continue except Exception as e: print(f" WARN: Failed to cache token image: {e}") # 6. 
Store all cached data in raw_data raw_data['cached_wallet_data'] = { 'profiles': cached_profiles, 'socials': cached_socials, 'holdings': cached_holdings, } raw_data['cached_graph_data'] = { 'entities': cached_graph_entities, 'links': cached_graph_links, } raw_data['cached_image_bytes'] = cached_image_bytes raw_data['cached_max_T_cutoff'] = max_T_cutoff.timestamp() print(f" INFO: Cache complete for {token_address}") return raw_data def _generate_dataset_item(self, token_address: str, t0: datetime.datetime, T_cutoff: datetime.datetime, mint_event: Dict[str, Any], trade_records: List[Dict[str, Any]], transfer_records: List[Dict[str, Any]], pool_creation_records: List[Dict[str, Any]], liquidity_change_records: List[Dict[str, Any]], fee_collection_records: List[Dict[str, Any]], burn_records: List[Dict[str, Any]], supply_lock_records: List[Dict[str, Any]], migration_records: List[Dict[str, Any]], wallet_data: Dict[str, Dict[str, Any]], all_token_data: Dict[str, Any], graph_links: Dict[str, Any], graph_seed_entities: set, all_graph_entities: Dict[str, str], future_trades_for_labels: List[Dict[str, Any]], pooler: EmbeddingPooler, sample_idx: Optional[int] = None, cached_holders_list: List[Dict[str, Any]] = None, cached_ohlc_1s: Optional[torch.Tensor] = None, quality_score: Optional[float] = None ) -> Optional[Dict[str, Any]]: """ Processes raw token data into a structured dataset item for a specific T_cutoff. Filters events beyond T_cutoff, computes derived features, and builds the final sample. """ # Helper functions (re-defined here to be accessible within this scope or passed as args if refactoring further) # For simplicity, assuming helper functions like _timestamp_to_order_value are available as self methods or inner functions # We will duplicate small helpers for self-containment or assume class methods if we moved them. # But wait, looking at the previous code, they were inner functions of __cacheitem__. # We'll make them class methods or redefining them. 
Redefining for safety. def _safe_int(value: Any) -> int: try: return int(value) except: return 0 def _timestamp_to_order_value(ts_value: Any) -> float: if isinstance(ts_value, datetime.datetime): if ts_value.tzinfo is None: ts_value = ts_value.replace(tzinfo=datetime.timezone.utc) return ts_value.timestamp() elif isinstance(ts_value, str): try: return datetime.datetime.fromisoformat(ts_value.replace('Z', '+00:00')).timestamp() except ValueError: pass try: return float(ts_value) except: return 0.0 def _event_execution_sort_key(timestamp_value: Any, slot=0, transaction_index=0, instruction_index=0, signature='') -> tuple: return (_timestamp_to_order_value(timestamp_value), _safe_int(slot), _safe_int(transaction_index), _safe_int(instruction_index), signature or '') def _trade_execution_sort_key(trade: Dict[str, Any]) -> tuple: return ( _timestamp_to_order_value(trade.get('timestamp')), _safe_int(trade.get('slot')), _safe_int(trade.get('transaction_index')), _safe_int(trade.get('instruction_index')), trade.get('signature', '') ) t0_timestamp = _timestamp_to_order_value(t0) # 1. Filter events by T_cutoff # We need to filter 'records' lists to only include items <= T_cutoff # AND we need to be careful about which features we compute based on this subset. def filter_by_time(records): return [r for r in records if _timestamp_to_order_value(r.get('timestamp')) <= T_cutoff.timestamp()] trade_records = filter_by_time(trade_records) transfer_records = filter_by_time(transfer_records) pool_creation_records = filter_by_time(pool_creation_records) liquidity_change_records = filter_by_time(liquidity_change_records) fee_collection_records = filter_by_time(fee_collection_records) burn_records = filter_by_time(burn_records) supply_lock_records = filter_by_time(supply_lock_records) migration_records = filter_by_time(migration_records) # 2. 
Main Event Registry event_sequence_entries: List[Tuple[tuple, Dict[str, Any]]] = [] def _register_event(event: Dict[str, Any], sort_key: tuple): event_sequence_entries.append((sort_key, event)) # Register Anchor Mint Event (always present) _register_event(mint_event, _event_execution_sort_key(mint_event['timestamp'], signature='Mint')) # 3. Process Trades (Events + Chart) trade_events = [] transfer_events = [] aggregation_trades = [] high_def_chart_trades = [] middle_chart_trades = [] main_token_info = all_token_data.get(token_address, {}) base_decimals = main_token_info.get('decimals', 6) raw_total_supply = main_token_info.get('total_supply', 0) total_supply_dec = (raw_total_supply / (10**base_decimals)) if base_decimals > 0 else raw_total_supply # Fallback to 1B supply if total_supply is missing (standard for Pump.fun tokens) if not total_supply_dec or total_supply_dec == 0: total_supply_dec = 1_000_000_000.0 # Constants from your code QUOTE_TOKEN_DECIMALS = {'So11111111111111111111111111111111111111112': 9} # Simplified SMART_WALLET_PNL_THRESHOLD = 2.0 SMART_WALLET_USD_THRESHOLD = 1000.0 LARGE_TRADE_SUPPLY_PCT_THRESHOLD = 0.019 LARGE_TRADE_USD_THRESHOLD = 330.0 LARGE_TRANSFER_SUPPLY_PCT_THRESHOLD = 0.0028 for trade in trade_records: if trade.get('total_usd', 0.0) < self.min_trade_usd: continue trade_sort_key = _trade_execution_sort_key(trade) trade_ts_int = int(_timestamp_to_order_value(trade.get('timestamp'))) # Identify Event Type trader_addr = trade['maker'] # NOTE: wallet_data might contain future info if we didn't mask it carefully in fetch_raw # But here we are processing relative to T_cutoff. # In a perfect world, we'd roll back wallet stats. # For now, we use the "static" wallet features we have. 
trader_wallet = wallet_data.get(trader_addr, {}) trader_profile = trader_wallet.get('profile', {}) KOL_NAME_KEYS = ['kolscan_name', 'cabalspy_name', 'axiom_kol_name'] is_kol = any(trader_wallet.get('socials', {}).get(key) for key in KOL_NAME_KEYS) is_profitable = (trader_profile.get('stats_30d_realized_profit_pnl', 0.0) > SMART_WALLET_PNL_THRESHOLD) base_amount_dec = trade.get('base_amount', 0) / (10**base_decimals) is_large_amount = (total_supply_dec > 0 and (base_amount_dec / total_supply_dec) > LARGE_TRADE_SUPPLY_PCT_THRESHOLD) if trader_addr == mint_event['wallet_address']: event_type = 'Deployer_Trade' elif is_kol or is_profitable: event_type = 'SmartWallet_Trade' elif trade.get('total_usd', 0.0) > LARGE_TRADE_USD_THRESHOLD or is_large_amount: event_type = 'LargeTrade' else: event_type = 'Trade' # Calcs quote_address = trade.get('quote_address') quote_decimals = QUOTE_TOKEN_DECIMALS.get(quote_address, 9) quote_amount_dec = trade.get('quote_amount', 0) / (10**quote_decimals) is_sell = trade.get('trade_type') == 1 pre_trade_base = (trade.get('base_balance', 0) + base_amount_dec) if is_sell else trade.get('base_balance', 0) pre_trade_quote = (trade.get('quote_balance', 0) + quote_amount_dec) if not is_sell else trade.get('quote_balance', 0) token_pct_hold = min(1.0, (base_amount_dec / pre_trade_base) if pre_trade_base > 1e-9 else 1.0) quote_pct_hold = min(1.0, (quote_amount_dec / pre_trade_quote) if pre_trade_quote > 1e-9 else 1.0) token_pct_supply = min(1.0, (base_amount_dec / total_supply_dec) if total_supply_dec > 0 else 0.0) is_success = trade.get('success', False) price_valid = float(trade.get('price_usd', 0.0) or 0.0) > 0 if is_success and price_valid: chart_entry = { 'trade_direction': 1 if is_sell else 0, 'price_usd': trade.get('price_usd', 0.0), 'timestamp': trade_ts_int, 'sort_key': trade_sort_key } aggregation_trades.append(chart_entry) high_def_chart_trades.append(chart_entry.copy()) # Simplified: Just use all trades for mid for now or split if needed 
middle_chart_trades.append(chart_entry.copy()) trade_event = { 'event_type': event_type, 'timestamp': trade_ts_int, 'relative_ts': _timestamp_to_order_value(trade.get('timestamp')) - t0_timestamp, 'wallet_address': trader_addr, 'token_address': token_address, 'trade_direction': 1 if is_sell else 0, 'sol_amount': trade.get('total', 0.0), 'dex_platform_id': trade.get('platform', 0), 'priority_fee': trade.get('priority_fee', 0.0), 'mev_protection': 1 if trade.get('mev_protection', 0) > 0 else 0, 'token_amount_pct_of_holding': token_pct_hold, 'quote_amount_pct_of_holding': quote_pct_hold, 'slippage': min(self.p99_clamps['slippage'], float(trade.get('slippage', 0.0) or 0.0)), 'token_amount_pct_to_total_supply': token_pct_supply, 'success': is_success, 'is_bundle': trade.get('is_bundle', False), 'total_usd': trade.get('total_usd', 0.0) } # Add to registry _register_event(trade_event, trade_sort_key) trade_events.append(trade_event) # 3b. Process Transfers for transfer in transfer_records: transfer_ts_val = _timestamp_to_order_value(transfer.get('timestamp')) transfer_ts_int = int(transfer_ts_val) amount_dec = float(transfer.get('amount_decimal', 0.0) or 0.0) source_balance = float(transfer.get('source_balance', 0.0) or 0.0) denom = source_balance + amount_dec if source_balance > 0 else 0.0 transfer_pct_of_holding = min(1.0, (amount_dec / denom) if denom > 1e-9 else 0.0) transfer_pct_of_supply = min(1.0, (amount_dec / total_supply_dec) if total_supply_dec > 0 else 0.0) is_large_transfer = transfer_pct_of_supply >= LARGE_TRANSFER_SUPPLY_PCT_THRESHOLD transfer_event = { 'event_type': 'LargeTransfer' if is_large_transfer else 'Transfer', 'timestamp': transfer_ts_int, 'relative_ts': transfer_ts_val - t0_timestamp, 'wallet_address': transfer.get('source'), 'destination_wallet_address': transfer.get('destination'), 'token_address': token_address, 'token_amount': amount_dec, 'transfer_pct_of_total_supply': transfer_pct_of_supply, 'transfer_pct_of_holding': 
transfer_pct_of_holding, 'priority_fee': transfer.get('priority_fee', 0.0), 'success': transfer.get('success', False) } _register_event( transfer_event, _event_execution_sort_key( transfer.get('timestamp'), slot=transfer.get('slot', 0), signature=transfer.get('signature', '') ) ) transfer_events.append(transfer_event) # 4. Generate Chart Events def _finalize_chart(t_list): t_list.sort(key=lambda x: x['sort_key']) for e in t_list: e.pop('sort_key', None) _finalize_chart(aggregation_trades) _finalize_chart(high_def_chart_trades) _finalize_chart(middle_chart_trades) HIGH_DEF_INTERVAL = ("1s", 1) MIDDLE_INTERVAL = ("30s", 30) def _emit_chart_segments(trades: List[Dict[str, Any]], interval: tuple, precomputed_ohlc: List[tuple] = None): if not trades and precomputed_ohlc is None: return [] interval_label, interval_seconds = interval if precomputed_ohlc is not None: ohlc_series = precomputed_ohlc else: # Pass t0_timestamp to ensure OHLC starts after mint, preventing Chart_Segment before Mint ohlc_series = self._generate_ohlc(trades, T_cutoff, interval_seconds, t0_timestamp=t0_timestamp) emitted_events = [] for idx in range(0, len(ohlc_series), OHLC_SEQ_LEN): segment = ohlc_series[idx:idx + OHLC_SEQ_LEN] if not segment: continue last_ts = segment[-1][0] opens_raw = [s[1] for s in segment] closes_raw = [s[2] for s in segment] chart_event = { 'event_type': 'Chart_Segment', 'timestamp': int(last_ts), 'relative_ts': int(last_ts) - int(t0_timestamp), 'opens': self._normalize_price_series(opens_raw), 'closes': self._normalize_price_series(closes_raw), 'i': interval_label, 'quant_ohlc_features': self._extract_quant_ohlc_features_for_segment(segment, interval_label, token_address=token_address), 'quant_feature_version': FEATURE_VERSION, } emitted_events.append(chart_event) return emitted_events # Build chart candidates (registration deferred until we choose exactly one interval mode) chart_events_1s = [] chart_events_30s = [] # Build chart candidates (registration deferred until 
we choose exactly one interval mode) # We process sparse native charts using _generate_ohlc for both 1s and 30s chart_events_1s = _emit_chart_segments(high_def_chart_trades, HIGH_DEF_INTERVAL) chart_events_30s = _emit_chart_segments(middle_chart_trades, MIDDLE_INTERVAL) # 5. Process Other Records (Pool, Liquidity, Fees, Burns, Locks, Migrations) pool_meta_by_address = {} for pool_record in pool_creation_records: pool_addr = pool_record.get('pool_address') if pool_addr: pool_meta_by_address[pool_addr] = pool_record pool_ts_val = _timestamp_to_order_value(pool_record.get('timestamp')) pool_ts = int(pool_ts_val) base_decimals = pool_record.get('base_decimals') quote_decimals = pool_record.get('quote_decimals') base_decimals = int(base_decimals) if base_decimals is not None else 0 quote_decimals = int(quote_decimals) if quote_decimals is not None else 0 base_amount_raw = pool_record.get('initial_base_liquidity', 0) or 0 quote_amount_raw = pool_record.get('initial_quote_liquidity', 0) or 0 base_amount = float(base_amount_raw) / (10 ** base_decimals) if base_decimals > 0 else float(base_amount_raw) quote_amount = float(quote_amount_raw) / (10 ** quote_decimals) if quote_decimals > 0 else float(quote_amount_raw) pool_event = { 'event_type': 'PoolCreated', 'timestamp': pool_ts, 'relative_ts': pool_ts_val - t0_timestamp, 'wallet_address': pool_record.get('creator_address'), 'token_address': token_address, 'quote_token_address': pool_record.get('quote_address'), 'protocol_id': pool_record.get('protocol', 0), 'pool_address': pool_addr, 'base_amount': base_amount, 'quote_amount': quote_amount, 'priority_fee': pool_record.get('priority_fee', 0.0), 'success': pool_record.get('success', False) } _register_event( pool_event, _event_execution_sort_key( pool_record.get('timestamp'), slot=pool_record.get('slot', 0), signature=pool_record.get('signature', '') ) ) for liq_record in liquidity_change_records: liq_ts_val = _timestamp_to_order_value(liq_record.get('timestamp')) liq_ts = 
int(liq_ts_val) pool_addr = liq_record.get('pool_address') pool_meta = pool_meta_by_address.get(pool_addr, {}) quote_decimals = pool_meta.get('quote_decimals') quote_decimals = int(quote_decimals) if quote_decimals is not None else 0 quote_amount_raw = liq_record.get('quote_amount', 0) or 0 quote_amount = float(quote_amount_raw) / (10 ** quote_decimals) if quote_decimals > 0 else float(quote_amount_raw) liq_event = { 'event_type': 'LiquidityChange', 'timestamp': liq_ts, 'relative_ts': liq_ts_val - t0_timestamp, 'wallet_address': liq_record.get('lp_provider'), 'token_address': token_address, 'quote_token_address': pool_meta.get('quote_address'), 'protocol_id': liq_record.get('protocol', 0), 'change_type_id': liq_record.get('change_type', 0), 'quote_amount': quote_amount, 'priority_fee': liq_record.get('priority_fee', 0.0), 'success': liq_record.get('success', False) } _register_event( liq_event, _event_execution_sort_key( liq_record.get('timestamp'), slot=liq_record.get('slot', 0), signature=liq_record.get('signature', '') ) ) for fee_record in fee_collection_records: fee_ts_val = _timestamp_to_order_value(fee_record.get('timestamp')) fee_ts = int(fee_ts_val) amount = 0.0 if fee_record.get('token_0_mint_address') == token_address: amount = float(fee_record.get('token_0_amount', 0.0) or 0.0) elif fee_record.get('token_1_mint_address') == token_address: amount = float(fee_record.get('token_1_amount', 0.0) or 0.0) fee_event = { 'event_type': 'FeeCollected', 'timestamp': fee_ts, 'relative_ts': fee_ts_val - t0_timestamp, 'wallet_address': fee_record.get('recipient_address'), 'token_address': token_address, 'sol_amount': amount, 'protocol_id': fee_record.get('protocol', 0), 'priority_fee': fee_record.get('priority_fee', 0.0), 'success': fee_record.get('success', False) } _register_event( fee_event, _event_execution_sort_key( fee_record.get('timestamp'), slot=fee_record.get('slot', 0), signature=fee_record.get('signature', '') ) ) for burn_record in burn_records: 
burn_ts_val = _timestamp_to_order_value(burn_record.get('timestamp')) burn_ts = int(burn_ts_val) amount_dec = float(burn_record.get('amount_decimal', 0.0) or 0.0) amount_pct = (amount_dec / total_supply_dec) if total_supply_dec > 0 else 0.0 burn_event = { 'event_type': 'TokenBurn', 'timestamp': burn_ts, 'relative_ts': burn_ts_val - t0_timestamp, 'wallet_address': burn_record.get('source'), 'token_address': token_address, 'amount_pct_of_total_supply': amount_pct, 'amount_tokens_burned': amount_dec, 'priority_fee': burn_record.get('priority_fee', 0.0), 'success': burn_record.get('success', False) } _register_event( burn_event, _event_execution_sort_key( burn_record.get('timestamp'), slot=burn_record.get('slot', 0), signature=burn_record.get('signature', '') ) ) for lock_record in supply_lock_records: lock_ts_val = _timestamp_to_order_value(lock_record.get('timestamp')) lock_ts = int(lock_ts_val) total_locked_amount = float(lock_record.get('total_locked_amount', 0.0) or 0.0) amount_pct = (total_locked_amount / total_supply_dec) if total_supply_dec > 0 else 0.0 final_unlock_ts = lock_record.get('final_unlock_timestamp', 0) or 0 lock_duration = float(final_unlock_ts) - float(lock_ts_val) if lock_duration < 0: lock_duration = 0.0 lock_event = { 'event_type': 'SupplyLock', 'timestamp': lock_ts, 'relative_ts': lock_ts_val - t0_timestamp, 'wallet_address': lock_record.get('sender'), 'token_address': token_address, 'amount_pct_of_total_supply': amount_pct, 'lock_duration': lock_duration, 'protocol_id': lock_record.get('protocol', 0), 'priority_fee': lock_record.get('priority_fee', 0.0), 'success': lock_record.get('success', False) } _register_event( lock_event, _event_execution_sort_key( lock_record.get('timestamp'), slot=lock_record.get('slot', 0), signature=lock_record.get('signature', '') ) ) for migration_record in migration_records: mig_ts_val = _timestamp_to_order_value(migration_record.get('timestamp')) mig_ts = int(mig_ts_val) mig_event = { 'event_type': 'Migrated', 
'timestamp': mig_ts, 'relative_ts': mig_ts_val - t0_timestamp, 'wallet_address': None, 'token_address': token_address, 'protocol_id': migration_record.get('protocol', 0), 'priority_fee': migration_record.get('priority_fee', 0.0), 'success': migration_record.get('success', False) } _register_event( mig_event, _event_execution_sort_key( migration_record.get('timestamp'), slot=migration_record.get('slot', 0), signature=migration_record.get('signature', '') ) ) # --- ADD DYNAMIC T_CUTOFF SNAPSHOT --- # Evaluate balances exactly up to T_cutoff using the filtered trade_records wallet_balances_raw = {} for trade in trade_records: if not trade.get('success', False): continue maker = trade.get('maker') if not maker: continue try: trade_type = int(trade.get('trade_type', 0)) base_amount_raw = int(trade.get('base_amount', 0)) except: continue if trade_type not in (0, 1) or base_amount_raw < 0: continue signed_delta = base_amount_raw if trade_type == 0 else -base_amount_raw wallet_balances_raw[maker] = wallet_balances_raw.get(maker, 0) + signed_delta positive_holders_raw = [(w, b) for w, b in wallet_balances_raw.items() if b > 0] positive_holders_raw.sort(key=lambda item: (-item[1], item[0])) holders_topk_raw = positive_holders_raw[:HOLDER_SNAPSHOT_TOP_K] cutoff_ts_epoch = int(T_cutoff.timestamp()) token_scale = 10 ** base_decimals if base_decimals else 1 cutoff_snapshot = { 'timestamp': cutoff_ts_epoch, 'holders': [ { 'wallet_address': w, 'current_balance': float(b) / float(token_scale) } for w, b in holders_topk_raw ] } # Create a local copy of cached_holders_list up to T_cutoff local_holders_list = [ snap for snap in (cached_holders_list or []) if snap.get('timestamp', 0) < cutoff_ts_epoch ] # Append our precise T_cutoff snapshot at the end if not local_holders_list or local_holders_list[-1]['timestamp'] != cutoff_ts_epoch: local_holders_list.append(cutoff_snapshot) # 6. 
Generate Snapshots self._generate_onchain_snapshots( token_address, int(t0_timestamp), T_cutoff, 300, # Interval trade_events, transfer_events, aggregation_trades, wallet_data, total_supply_dec, _register_event, cached_holders_list=local_holders_list ) # Choose exactly one chart resolution per sample: # - no pressure -> 1s # - pressure -> 30s non_chart_event_count = len(event_sequence_entries) would_exceed = (non_chart_event_count + len(chart_events_1s)) > self.max_seq_len selected_chart_events = chart_events_30s if would_exceed else chart_events_1s selected_chart_signature = "chart-mid" if would_exceed else "chart-hd" for chart_idx, chart_event in enumerate(selected_chart_events): # Assign an artificially extremely high slot/tx index to ensure Chart_Segment always sorts AFTER all trades on the same timestamp _register_event( chart_event, _event_execution_sort_key(chart_event['timestamp'], slot=10**12, transaction_index=10**9, signature=f"{selected_chart_signature}-{chart_idx}") ) # 7. Finalize Sequence with Dynamic Sampling event_sequence_entries.sort(key=lambda x: x[0]) raw_event_sequence = [entry[1] for entry in event_sequence_entries] # Apply dynamic context sampling if needed event_sequence = self._apply_dynamic_sampling(raw_event_sequence) # 8. Compute Labels using future data # Define horizons (e.g., [60, 120, ...]) horizons = sorted(self.horizons_seconds) # Pre-sort future trades for efficient searching # Note: future_trades_for_labels contains ALL trades (past & future relative to T_cutoff) # We need to find the price at T_cutoff and at T_cutoff + h # ============================================================================ # CRITICAL: Filter for successful trades with valid prices ONLY! # ============================================================================ # Failed trades (success=False) often have price_usd=0 or invalid values. 
# Using these for label computation causes mathematically impossible returns # like -1.0 (price went to 0) or 0.0 (no price change despite trading). # ALWAYS filter by: success=True AND price_usd > 0 # ============================================================================ all_trades = [ t for t in future_trades_for_labels if t.get('success', False) and float(t.get('price_usd', 0) or 0) > 0 ] if not all_trades: # No valid trades for label computation quant_payload = [ event.get('quant_ohlc_features', []) for event in event_sequence if event.get('event_type') == 'Chart_Segment' ] return { 'event_sequence': event_sequence, 'wallets': wallet_data, 'tokens': all_token_data, 'graph_links': graph_links, 'embedding_pooler': pooler, 'quant_ohlc_features': quant_payload, 'quant_feature_version': FEATURE_VERSION, 'labels': torch.zeros(len(self.horizons_seconds), dtype=torch.float32), 'labels_mask': torch.zeros(len(self.horizons_seconds), dtype=torch.float32), 'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32), } # Ensure sorted all_trades.sort(key=lambda x: _timestamp_to_order_value(x['timestamp'])) # Find price at T_cutoff (Current Price) # It's the last trade before or at T_cutoff current_price = 0.0 cutoff_ts_val = T_cutoff.timestamp() last_trade_ts_val = _timestamp_to_order_value(all_trades[-1]['timestamp']) # Filter to only successful, positive priced trades for label generation valid_trades_for_labels = [t for t in all_trades if t.get('success', False) and float(t.get('price_usd', 0) or 0) > 0] current_price_idx = -1 for i, t in enumerate(valid_trades_for_labels): if _timestamp_to_order_value(t['timestamp']) <= cutoff_ts_val: current_price = float(t['price_usd']) current_price_idx = i else: break # DEBUG: Label computation details removed after validation label_values = [] mask_values = [] # Edge case: no trades before cutoff means we have no anchor price if current_price_idx < 0 or current_price <= 0: # No 
valid anchor price - mask all labels for h in horizons: label_values.append(0.0) mask_values.append(0.0) else: # The price from the previous valid bucket last_bucket_price = current_price prev_target_ts = cutoff_ts_val # Check if there are ANY trades left after target_ts to validate empty buckets in between def _has_future_trades(after_ts: float) -> bool: for j in range(current_price_idx + 1, len(valid_trades_for_labels)): if _timestamp_to_order_value(valid_trades_for_labels[j]['timestamp']) > after_ts: return True return False for h in horizons: target_ts = cutoff_ts_val + h bucket_price = last_bucket_price # Forward fill by default found_future_trade = False for j in range(current_price_idx + 1, len(valid_trades_for_labels)): t = valid_trades_for_labels[j] t_ts = _timestamp_to_order_value(t['timestamp']) if prev_target_ts < t_ts <= target_ts: bucket_price = float(t['price_usd']) found_future_trade = True elif t_ts > target_ts: break if found_future_trade: # New trade exists in bucket, use its price and valid mask ret = (bucket_price - current_price) / current_price label_values.append(ret) mask_values.append(1.0) last_bucket_price = bucket_price else: # Bucket is empty. Fill return with the last valid price. # Mask is 1.0 if there are STILL trades occurring later (price held steady). # Mask is 0.0 only if the token is completely dead and no trades ever occur again. 
ret = (last_bucket_price - current_price) / current_price label_values.append(ret) mask_values.append(1.0 if _has_future_trades(target_ts) else 0.0) prev_target_ts = target_ts # DEBUG: Mask summaries removed after validation quant_payload = [ event.get('quant_ohlc_features', []) for event in event_sequence if event.get('event_type') == 'Chart_Segment' ] return { 'sample_idx': sample_idx if sample_idx is not None else -1, # Debug trace 'token_address': token_address, # For debugging 't_cutoff': T_cutoff.isoformat() if T_cutoff else None, # For debugging 'event_sequence': event_sequence, 'wallets': wallet_data, 'tokens': all_token_data, 'graph_links': graph_links, 'embedding_pooler': pooler, 'quant_ohlc_features': quant_payload, 'quant_feature_version': FEATURE_VERSION, 'labels': torch.tensor(label_values, dtype=torch.float32), 'labels_mask': torch.tensor(mask_values, dtype=torch.float32), 'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32), } def _embed_context(self, context: Dict[str, Any], encoder: Any) -> None: """ Helper to replace raw items in the embedding pooler with pre-computed embeddings using the provided encoder (on GPU). 
""" pooler = context.get('embedding_pooler') if not pooler: return # Direct access to pool_map keys_to_embed_img = [] images_to_embed = [] keys_to_embed_text = [] texts_to_embed = [] for key, entry in pooler.pool_map.items(): item = entry['item'] if isinstance(item, str): # Strings (text) keys_to_embed_text.append(key) texts_to_embed.append(item) elif hasattr(item, 'resize') and not isinstance(item, torch.Tensor): # Duck typing to catch all PIL images keys_to_embed_img.append(key) images_to_embed.append(item) # Batch encode images if images_to_embed: # print(f"DEBUG: Found {len(images_to_embed)} images to embed", flush=True) with torch.no_grad(): img_embeddings = encoder(images_to_embed) # Update pool_map directly for images for i, (key, emb) in enumerate(zip(keys_to_embed_img, img_embeddings)): if key in pooler.pool_map: old_entry = pooler.pool_map[key] pooler.pool_map[key] = {'item': emb.cpu().clone(), 'idx': old_entry['idx']} # Batch encode text if texts_to_embed: # print(f"DEBUG: Found {len(texts_to_embed)} text items to embed", flush=True) with torch.no_grad(): text_embeddings = encoder(texts_to_embed) # Update pool_map directly for text for i, (key, emb) in enumerate(zip(keys_to_embed_text, text_embeddings)): if key in pooler.pool_map: old_entry = pooler.pool_map[key] pooler.pool_map[key] = {'item': emb.cpu().clone(), 'idx': old_entry['idx']} def __cacheitem_context__(self, idx: int, num_samples_per_token: int = 1, encoder: Optional[Any] = None, forced_cutoff_trade_idx: Optional[int] = None) -> List[Optional[Dict[str, Any]]]: """ Generates fully processed training contexts for caching. This method: 1. Fetches raw token data (like __cacheitem__) 2. Samples T_cutoff(s) using the weight sampling logic 3. Applies H/B/H dynamic sampling based on max_seq_len 4. 
Returns complete training-ready samples that can be loaded directly This moves ALL non-determinism into cache time, making training fully offline and avoiding caching tokens that would never be seen (98% garbage filtered out by weight sampling and T_cutoff eligibility). Args: idx: Index into sampled_mints num_samples_per_token: Number of different T_cutoff samples to generate per token Returns: List of training-ready samples (may be fewer than num_samples_per_token if some T_cutoffs are invalid) """ import time as _time if not self.sampled_mints: raise RuntimeError("Dataset has no mint records loaded.") if idx >= len(self.sampled_mints): raise IndexError(f"Index {idx} exceeds mint count {len(self.sampled_mints)}.") initial_mint_record = self.sampled_mints[idx] t0 = initial_mint_record["timestamp"] if isinstance(t0, datetime.datetime) and t0.tzinfo is None: t0 = t0.replace(tzinfo=datetime.timezone.utc) creator_address = initial_mint_record['creator_address'] token_address = initial_mint_record['mint_address'] # Verbose per-token logging removed for caching speed (was printing for every token) if not self.fetcher: raise RuntimeError("Dataset has no data fetcher.") # --- STEP 1: Fetch raw data (same as __cacheitem__) --- raw_data = self.fetcher.fetch_raw_token_data( token_address=token_address, creator_address=creator_address, mint_timestamp=t0, max_horizon_seconds=self.max_cache_horizon_seconds, include_wallet_data=False, include_graph=False, min_trades=self.min_trades, full_history=True, prune_failed=False, prune_transfers=False ) if raw_data is None: print(f" SKIP: No raw data for {token_address}") return [] # --- FIX: Add token metadata from mint record to raw_data --- # DEBUG: Print what's in the mint record (first token only) if idx == 0: print(f" DEBUG: initial_mint_record keys: {list(initial_mint_record.keys())}") print(f" DEBUG: token_name='{initial_mint_record.get('token_name')}', token_symbol='{initial_mint_record.get('token_symbol')}'") raw_data['name'] = 
initial_mint_record.get('token_name', '') raw_data['symbol'] = initial_mint_record.get('token_symbol', '') raw_data['token_uri'] = initial_mint_record.get('token_uri', '') raw_total_supply = initial_mint_record.get('total_supply', DEFAULT_TOTAL_SUPPLY_RAW) raw_token_decimals = initial_mint_record.get('token_decimals', DEFAULT_TOKEN_DECIMALS) raw_data['total_supply'] = ( int(raw_total_supply) if raw_total_supply and int(raw_total_supply) > 0 else DEFAULT_TOTAL_SUPPLY_RAW ) raw_data['decimals'] = ( int(raw_token_decimals) if raw_token_decimals is not None and int(raw_token_decimals) >= 0 else DEFAULT_TOKEN_DECIMALS ) raw_data['protocol'] = initial_mint_record.get('protocol', 1) def _timestamp_to_order_value(ts_value) -> float: if isinstance(ts_value, datetime.datetime): if ts_value.tzinfo is None: ts_value = ts_value.replace(tzinfo=datetime.timezone.utc) return ts_value.timestamp() elif isinstance(ts_value, str): try: return datetime.datetime.fromisoformat(ts_value.replace('Z', '+00:00')).timestamp() except ValueError: pass try: return float(ts_value) except: return 0.0 # --- STEP 2: Validate trades and find eligible T_cutoff indices --- all_trades_raw = raw_data.get('trades', []) if not all_trades_raw: print(f" SKIP: No trades for {token_address}") return [] all_trades_sorted = sorted( [t for t in all_trades_raw if t.get('timestamp') is not None], key=lambda t: _timestamp_to_order_value(t.get('timestamp')) ) min_context_trades = self.min_trades if len(all_trades_sorted) < (min_context_trades + 1): print(f" SKIP: Not enough trades ({len(all_trades_sorted)}) for {token_address}") return [] # Find successful trade indices successful_indices = [ i for i, t in enumerate(all_trades_sorted) if t.get('success', False) and float(t.get('price_usd', 0) or 0) > 0 ] if len(successful_indices) < 2: print(f" SKIP: Not enough successful trades for {token_address}") return [] max_horizon_seconds = max(self.horizons_seconds) if self.horizons_seconds else 0 min_idx = 
min_context_trades - 1 max_idx = len(all_trades_sorted) - 2 if max_idx < min_idx: print(f" SKIP: Invalid index range for {token_address}") return [] # Build lookup arrays last_successful_before = [-1] * len(all_trades_sorted) last_seen = -1 succ_set = set(successful_indices) for i in range(len(all_trades_sorted)): if i in succ_set: last_seen = i last_successful_before[i] = last_seen next_successful_after = [-1] * len(all_trades_sorted) next_seen = -1 for i in range(len(all_trades_sorted) - 1, -1, -1): if i in succ_set: next_seen = i next_successful_after[i] = next_seen # Find all eligible T_cutoff indices eligible_indices = [] for i in range(min_idx, max_idx + 1): anchor_idx = last_successful_before[i] next_idx = next_successful_after[i + 1] if i + 1 < len(all_trades_sorted) else -1 if anchor_idx < 0 or next_idx < 0: continue cutoff_ts = _timestamp_to_order_value(all_trades_sorted[i].get('timestamp')) next_ts = _timestamp_to_order_value(all_trades_sorted[next_idx].get('timestamp')) if next_ts <= cutoff_ts + max_horizon_seconds: eligible_indices.append(i) if not eligible_indices: print(f" SKIP: No eligible T_cutoff indices for {token_address}") return [] # Eligible positions count logged via tqdm in cache_dataset.py # --- STEP 3: Generate OHLC and holder snapshots (same as __cacheitem__) --- trades = raw_data.get('trades', []) trade_ts_values = [_timestamp_to_order_value(t.get('timestamp')) for t in trades] t0_val = _timestamp_to_order_value(t0) last_trade_ts_val = max(trade_ts_values) # Disable dense OHLC 1s precomputation. # Chart_Segment will now generate sparse OHLC at runtime. duration_seconds = int(last_trade_ts_val - t0_val) + 120 raw_data['ohlc_1s'] = None # Generate holder snapshots from deterministic trade-ledger reconstruction. 
interval = 300 num_intervals = (duration_seconds // interval) + 1 snapshot_stats = torch.zeros((num_intervals, 6), dtype=torch.float32) buckets = defaultdict(list) for t in trades: ts = _timestamp_to_order_value(t['timestamp']) bucket_idx = int(ts - t0_val) // interval if bucket_idx >= 0: buckets[bucket_idx].append(t) raw_total_supply = raw_data.get('total_supply') raw_decimals = raw_data.get('decimals') if raw_total_supply is None or raw_decimals is None: raise RuntimeError("Missing token total_supply/decimals required for holder snapshot reconstruction.") total_supply_raw = int(raw_total_supply) token_decimals = int(raw_decimals) if total_supply_raw <= 0: total_supply_raw = DEFAULT_TOTAL_SUPPLY_RAW if token_decimals < 0: token_decimals = DEFAULT_TOKEN_DECIMALS token_scale = 10 ** token_decimals def _strict_int(v: Any, field_name: str) -> int: if v is None: raise RuntimeError(f"Missing {field_name} in trade record for {token_address}.") try: return int(v) except Exception as e: raise RuntimeError(f"Invalid {field_name} in trade record for {token_address}: {v}") from e def _trade_sort_key_for_ledger(trade: Dict[str, Any]) -> tuple: return ( _timestamp_to_order_value(trade.get('timestamp')), _strict_int(trade.get('slot', 0), 'slot'), _strict_int(trade.get('transaction_index', 0), 'transaction_index'), _strict_int(trade.get('instruction_index', 0), 'instruction_index'), str(trade.get('signature') or '') ) ledger_trades = [] for trade in trades: if not trade.get('success', False): continue maker = trade.get('maker') if not maker: raise RuntimeError(f"Missing maker in successful trade for {token_address}.") trade_type = _strict_int(trade.get('trade_type'), 'trade_type') if trade_type not in (0, 1): raise RuntimeError(f"Invalid trade_type={trade_type} for {token_address}; expected 0/1.") base_amount_raw = _strict_int(trade.get('base_amount'), 'base_amount') if base_amount_raw < 0: raise RuntimeError(f"Invalid negative base_amount={base_amount_raw} for {token_address}.") 
ledger_trades.append((trade, maker, trade_type, base_amount_raw)) ledger_trades.sort(key=lambda x: _trade_sort_key_for_ledger(x[0])) wallet_balances_raw: Dict[str, int] = {} ledger_idx = 0 holder_snapshots_list = [] for i in range(num_intervals): bucket_trades = buckets[i] vol = sum(t.get('total_usd', 0.0) for t in bucket_trades) tx = len(bucket_trades) buys = sum(1 for t in bucket_trades if t.get('trade_direction') == 0 or t.get('trade_type') == 0) sells = tx - buys snapshot_ts_epoch = t0_val + ((i + 1) * interval) while ledger_idx < len(ledger_trades): trade, maker, trade_type, base_amount_raw = ledger_trades[ledger_idx] trade_ts = _timestamp_to_order_value(trade.get('timestamp')) if trade_ts > snapshot_ts_epoch: break signed_delta = base_amount_raw if trade_type == 0 else -base_amount_raw wallet_balances_raw[maker] = wallet_balances_raw.get(maker, 0) + signed_delta ledger_idx += 1 positive_holders_raw = [(wallet, bal) for wallet, bal in wallet_balances_raw.items() if bal > 0] positive_holders_raw.sort(key=lambda item: (-item[1], item[0])) holders_topk_raw = positive_holders_raw[:HOLDER_SNAPSHOT_TOP_K] count = len(positive_holders_raw) top10_sum_raw = sum(bal for _, bal in positive_holders_raw[:10]) top10_pct = float(top10_sum_raw) / float(total_supply_raw) snapshot_stats[i, 0] = float(vol) snapshot_stats[i, 1] = float(tx) snapshot_stats[i, 2] = float(buys) snapshot_stats[i, 3] = float(sells) snapshot_stats[i, 4] = float(count) snapshot_stats[i, 5] = float(top10_pct) snapshot_ts = t0 + datetime.timedelta(seconds=(i+1)*interval) holder_snapshots_list.append({ 'timestamp': int(snapshot_ts.timestamp()), 'holders': [ { 'wallet_address': wallet, 'current_balance': float(balance_raw) / float(token_scale) } for wallet, balance_raw in holders_topk_raw ] }) raw_data['snapshots_5m'] = snapshot_stats raw_data['holder_snapshots_list'] = holder_snapshots_list raw_data['protocol_id'] = initial_mint_record.get('protocol') # --- STEP 4: Collect ALL wallets and pre-fetch their 
data --- all_wallets = set() all_wallets.add(creator_address) for trade in raw_data.get('trades', []): if trade.get('maker'): all_wallets.add(trade['maker']) for transfer in raw_data.get('transfers', []): if transfer.get('source'): all_wallets.add(transfer['source']) if transfer.get('destination'): all_wallets.add(transfer['destination']) for pool in raw_data.get('pool_creations', []): if pool.get('creator_address'): all_wallets.add(pool['creator_address']) for liq in raw_data.get('liquidity_changes', []): if liq.get('lp_provider'): all_wallets.add(liq['lp_provider']) for snapshot in holder_snapshots_list: if not isinstance(snapshot, dict) or not isinstance(snapshot.get('holders'), list): raise RuntimeError("Invalid holder_snapshots_list entry during wallet collection.") for holder in snapshot['holders']: if not isinstance(holder, dict) or 'wallet_address' not in holder or 'current_balance' not in holder: raise RuntimeError("Invalid holder record during wallet collection.") all_wallets.add(holder['wallet_address']) all_wallets.discard(None) all_wallets.discard('') wallet_list = list(all_wallets) max_T_cutoff = datetime.datetime.fromtimestamp(last_trade_ts_val, tz=datetime.timezone.utc) # --- Run independent I/O queries concurrently --- cached_profiles, cached_socials = {}, {} cached_holdings = {} cached_graph_entities, cached_graph_links = {}, {} cached_image_bytes = None def _fetch_clickhouse_data(): """ClickHouse client is not thread-safe, so both queries run sequentially in one thread.""" profiles, socials = self.fetcher.fetch_wallet_profiles_and_socials(wallet_list, max_T_cutoff) holdings = self.fetcher.fetch_wallet_holdings(wallet_list, max_T_cutoff) return profiles, socials, holdings def _fetch_graph_data(): return self.fetcher.fetch_graph_links(wallet_list, max_T_cutoff, max_degrees=1) def _fetch_token_image(): try: bullx_image_url = f"https://image.bullx.io/1399811149/{token_address}?retry=0" resp = self.http_session.get(bullx_image_url, timeout=2) if 
resp.status_code == 200: return resp.content except: pass return None with ThreadPoolExecutor(max_workers=3) as executor: future_ch = executor.submit(_fetch_clickhouse_data) future_graph = executor.submit(_fetch_graph_data) future_image = executor.submit(_fetch_token_image) try: cached_profiles, cached_socials, cached_holdings = future_ch.result() except Exception as e: print(f" WARN: Failed to fetch ClickHouse data: {e}") try: cached_graph_entities, cached_graph_links = future_graph.result() except Exception as e: print(f" WARN: Failed to fetch graph links: {e}") cached_image_bytes = future_image.result() # --- STEP 5: Sample T_cutoffs and generate complete training contexts --- results = [] # Sample indices (with replacement if needed) if forced_cutoff_trade_idx is not None: # Forced mode: use the exact trade index provided (for evaluation) if forced_cutoff_trade_idx >= len(all_trades_sorted): print(f" WARN: forced_cutoff_trade_idx={forced_cutoff_trade_idx} >= total trades {len(all_trades_sorted)}, clamping.") forced_cutoff_trade_idx = len(all_trades_sorted) - 2 sampled_indices = [forced_cutoff_trade_idx] print(f" Using forced T_cutoff at trade index {forced_cutoff_trade_idx}") elif num_samples_per_token >= len(eligible_indices): sampled_indices = eligible_indices.copy() else: sampled_indices = random.sample(eligible_indices, num_samples_per_token) # Per-token sample count logged via tqdm in cache_dataset.py for sample_num, sample_idx in enumerate(sampled_indices): sample_trade = all_trades_sorted[sample_idx] sample_offset_ts = _timestamp_to_order_value(sample_trade.get('timestamp')) T_cutoff = datetime.datetime.fromtimestamp(sample_offset_ts, tz=datetime.timezone.utc) cutoff_ts = sample_offset_ts # Collect wallets visible at T_cutoff wallets_to_fetch = set() wallets_to_fetch.add(creator_address) for trade in raw_data.get('trades', []): if _timestamp_to_order_value(trade.get('timestamp')) <= cutoff_ts: if trade.get('maker'): wallets_to_fetch.add(trade['maker']) 
for transfer in raw_data.get('transfers', []): if _timestamp_to_order_value(transfer.get('timestamp')) <= cutoff_ts: if transfer.get('source'): wallets_to_fetch.add(transfer['source']) if transfer.get('destination'): wallets_to_fetch.add(transfer['destination']) for pool in raw_data.get('pool_creations', []): if _timestamp_to_order_value(pool.get('timestamp')) <= cutoff_ts: if pool.get('creator_address'): wallets_to_fetch.add(pool['creator_address']) for liq in raw_data.get('liquidity_changes', []): if _timestamp_to_order_value(liq.get('timestamp')) <= cutoff_ts: if liq.get('lp_provider'): wallets_to_fetch.add(liq['lp_provider']) # Get holder snapshot at T_cutoff elapsed = (T_cutoff - t0).total_seconds() snap_idx = int(elapsed // 300) if not (0 <= snap_idx < len(holder_snapshots_list)): raise RuntimeError( f"holder_snapshots_list index out of range in __cacheitem_context__ " f"(snap_idx={snap_idx}, len={len(holder_snapshots_list)})." ) snapshot_data = holder_snapshots_list[snap_idx] if not isinstance(snapshot_data, dict) or not isinstance(snapshot_data.get('holders'), list): raise RuntimeError("Invalid holder snapshot entry in __cacheitem_context__.") for holder in snapshot_data['holders']: if not isinstance(holder, dict) or 'wallet_address' not in holder or 'current_balance' not in holder: raise RuntimeError("Invalid holder record in __cacheitem_context__.") wallets_to_fetch.add(holder['wallet_address']) wallets_to_fetch.discard(None) wallets_to_fetch.discard('') # Build offline data for this context pooler = EmbeddingPooler() # Process token data offline (minimal main token metadata only) offline_token_data = {token_address: self._build_main_token_seed(token_address, raw_data)} if cached_image_bytes: try: cached_image = Image.open(BytesIO(cached_image_bytes)) offline_token_data[token_address]['_cached_image_pil'] = cached_image except: pass main_token_data = self._process_token_data_offline( [token_address], pooler, T_cutoff, token_data=offline_token_data ) if 
not main_token_data: continue # Process wallet data offline wallet_data, all_token_data = self._process_wallet_data( list(wallets_to_fetch), main_token_data.copy(), pooler, T_cutoff, profiles_override=cached_profiles, socials_override=cached_socials, holdings_override=cached_holdings ) # Generate the complete training item (with H/B/H applied via _generate_dataset_item) mint_event = { 'event_type': 'Mint', 'timestamp': int(t0.timestamp()), 'relative_ts': 0, 'wallet_address': creator_address, 'token_address': token_address, 'protocol_id': raw_data.get('protocol_id', 0) } result = self._generate_dataset_item( token_address=token_address, t0=t0, T_cutoff=T_cutoff, mint_event=mint_event, trade_records=raw_data['trades'], transfer_records=raw_data['transfers'], pool_creation_records=raw_data['pool_creations'], liquidity_change_records=raw_data['liquidity_changes'], fee_collection_records=raw_data['fee_collections'], burn_records=raw_data['burns'], supply_lock_records=raw_data['supply_locks'], migration_records=raw_data['migrations'], wallet_data=wallet_data, all_token_data=all_token_data, graph_links=cached_graph_links, graph_seed_entities=wallets_to_fetch, all_graph_entities=cached_graph_entities, future_trades_for_labels=raw_data['trades'], pooler=pooler, sample_idx=idx, cached_holders_list=holder_snapshots_list, cached_ohlc_1s=None, quality_score=None # Will be injected by cache_dataset.py ) if result is not None: results.append(result) pass # Per-context verbose logging removed for caching speed # --- OPTIONAL: Pre-compute Embeddings (if encoder provided) --- if encoder is not None: # print(f"DEBUG: Encoder provided to loader for {len(results)} contexts", flush=True) for ctx in results: self._embed_context(ctx, encoder) else: if idx == 0: print("DEBUG: No encoder provided to __cacheitem_context__", flush=True) # Final count logged via tqdm in cache_dataset.py return results