| import torch |
| from collections import defaultdict |
| import datetime |
| import random |
| import requests |
| from io import BytesIO |
| from torch.utils.data import Dataset, IterableDataset |
| from PIL import Image |
| from typing import List, Dict, Any, Optional, Union, Tuple |
| from pathlib import Path |
| import numpy as np |
| from bisect import bisect_left, bisect_right |
| from concurrent.futures import ThreadPoolExecutor |
| import json |
|
|
| |
| import models.vocabulary as vocab |
| from models.multi_modal_processor import MultiModalEncoder |
| from data.data_fetcher import DataFetcher |
| from data.context_targets import derive_movement_targets |
| from data.quant_ohlc_feature_schema import ( |
| FEATURE_INDEX, |
| SEGMENT_SECONDS, |
| FEATURE_VERSION, |
| FEATURE_VERSION_ID, |
| LOOKBACK_SECONDS, |
| TOKENS_PER_SEGMENT, |
| WINDOW_SECONDS, |
| empty_feature_dict, |
| feature_dict_to_vector, |
| ) |
| from signals.rolling_quant import compute_rolling_quant_features |
| from signals.support_resistance import compute_support_resistance_features |
| from signals.trendlines import compute_trendline_features |
| from requests.adapters import HTTPAdapter |
| from urllib3.util.retry import Retry |
|
|
| |
# Known quote-token mints mapped to their on-chain decimal precision.
QUOTE_TOKEN_DECIMALS = {
    'So11111111111111111111111111111111111111112': 9,   # Wrapped SOL (wSOL)
    'EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v': 6,  # USDC
    'Es9vMFrzaCERmJfrF4H2FYD4KCoNkY11McCe8BenwNYB': 6,  # USDT
}


# Activity thresholds. NOTE(review): not referenced in this chunk —
# presumably consumed by the event taggers / signal modules; confirm there.
LARGE_TRADE_USD_THRESHOLD = 100.0
LARGE_TRADE_SUPPLY_PCT_THRESHOLD = 0.03
LARGE_TRANSFER_SUPPLY_PCT_THRESHOLD = 0.03
SMART_WALLET_PNL_THRESHOLD = 3.0
SMART_WALLET_USD_THRESHOLD = 20000.0


# Event types that dynamic context sampling always keeps
# (see OracleDataset._apply_dynamic_sampling).
CRITICAL_EVENTS = {
    'Mint', 'Deployer_Trade', 'SmartWallet_Trade', 'LargeTrade', 'LargeTransfer',
    'TokenBurn', 'SupplyLock', 'PoolCreated', 'LiquidityChange', 'Migrated',
    'FeeCollected', 'TrendingToken', 'BoostedToken', 'XPost', 'XRetweet',
    'XReply', 'XQuoteTweet', 'PumpReply', 'DexBoost_Paid', 'DexProfile_Updated',
    'AlphaGroup_Call', 'Channel_Call', 'CexListing', 'TikTok_Trending_Hashtag',
    'XTrending_Hashtag'
}


# Periodic state snapshots — kept during sampling for timeline continuity.
SNAPSHOT_EVENTS = {
    'Chart_Segment', 'OnChain_Snapshot', 'HolderSnapshot',
    'ChainSnapshot', 'Lighthouse_Snapshot'
}


# High-volume event types that may be head/tail-compressed when the
# sequence exceeds max_seq_len.
COMPRESSIBLE_EVENTS = {'Trade', 'Transfer'}


# Length of the OHLC candle sequence. NOTE(review): usage not visible here.
OHLC_SEQ_LEN = 300


# Minimum transfer size (as a supply fraction) to keep; 0.0 keeps everything.
# NOTE(review): usage not visible in this chunk.
MIN_AMOUNT_TRANSFER_SUPPLY = 0.0


# Holder-snapshot cadence / size. NOTE(review): the visible code receives the
# interval as a parameter; confirm these constants are what callers pass.
HOLDER_SNAPSHOT_INTERVAL_SEC = 300
HOLDER_SNAPSHOT_TOP_K = 200
# A URI is considered permanently dead after this many failed fetches
# (see OracleDataset._is_dead_uri / _mark_uri_failure).
DEAD_URI_RETRY_LIMIT = 2
# Fallbacks when token metadata is unavailable (raw base units / decimals).
DEFAULT_TOTAL_SUPPLY_RAW = 1_000_000_000_000_000
DEFAULT_TOKEN_DECIMALS = 6


# Bucket names emitted by summarize_context_window.
CONTEXT_BUCKET_NEGATIVE = "bad"
CONTEXT_BUCKET_POSITIVE = "good"
|
|
|
|
def summarize_context_window(
    labels: Any,
    labels_mask: Any,
) -> Dict[str, Any]:
    """
    Summarize a realized context window using its valid future returns.

    Base rule:
      - each horizon contributes signed terminal PnL from buying at cutoff
      - magnitude matters, so returns are compressed with signed log1p
      - the context is `good` only if the net score is strictly positive

    Raises RuntimeError when either input is missing.
    """
    if labels is None or labels_mask is None:
        raise RuntimeError("Context weighting requires both 'labels' and 'labels_mask'.")

    def _to_list(values):
        # Accept torch tensors or any plain iterable.
        return values.tolist() if isinstance(values, torch.Tensor) else list(values)

    returns = _to_list(labels)
    keep_flags = _to_list(labels_mask)

    # Only horizons whose mask is positive contribute.
    valid = [float(r) for r, keep in zip(returns, keep_flags) if float(keep) > 0.0]

    # Signed log1p compression: sign(r) * log(1 + |r|).
    contributions = [
        np.log1p(abs(r)) if r > 0.0 else -np.log1p(abs(r))
        for r in valid
    ]

    positives = sum(1 for r in valid if r > 0.0)
    negatives = len(valid) - positives
    score = float(sum(contributions) / len(contributions)) if contributions else 0.0
    bucket = CONTEXT_BUCKET_POSITIVE if score > 0.0 else CONTEXT_BUCKET_NEGATIVE

    return {
        "context_bucket": bucket,
        "context_score": score,
        "positive_horizons": positives,
        "negative_horizons": negatives,
        "valid_horizons": len(valid),
    }
|
|
|
|
class EmbeddingPooler:
    """
    Collects the unique text/image items of a single data sample and hands
    out stable 1-based pool indices for them (index 0 means "no item").
    """

    def __init__(self):
        # key -> {'item': original object, 'idx': assigned 1-based index}
        self.pool_map = {}
        self.next_idx = 1

    def get_idx(self, item: Any) -> int:
        """
        Return the pool index for `item`.

        - None and blank/whitespace-only strings map to 0.
        - Equal strings (after strip) share one index; images and tensors
          are deduplicated by object identity only.
        """
        if item is None:
            return 0

        if isinstance(item, str):
            stripped = item.strip()
            if not stripped:
                return 0
            key = stripped
        elif isinstance(item, (Image.Image, torch.Tensor)):
            # Identity-based: two equal-but-distinct objects get two indices.
            key = id(item)
        else:
            key = item

        entry = self.pool_map.get(key)
        if entry is None:
            entry = {'item': item, 'idx': self.next_idx}
            self.pool_map[key] = entry
            self.next_idx += 1
        return entry['idx']

    def get_all_items(self) -> List[Dict[str, Any]]:
        """Return all pooled entries ordered by their assigned index."""
        return sorted(self.pool_map.values(), key=lambda entry: entry['idx'])
|
|
|
|
| class OracleDataset(Dataset): |
| """ |
| Dataset class for the Oracle model. It fetches, processes, and structures |
| all on-chain and off-chain data for a given token to create a comprehensive |
| input sequence for the model. |
| """ |
    def __init__(self,
                 data_fetcher: Optional[DataFetcher] = None,
                 fetcher_config: Optional[Dict[str, Any]] = None,
                 horizons_seconds: List[int] = [],
                 quantiles: List[float] = [],
                 max_samples: Optional[int] = None,
                 min_trades: int = 10,

                 token_allowlist: Optional[List[str]] = None,
                 cache_dir: Optional[Union[str, Path]] = None,
                 start_date: Optional[datetime.datetime] = None,
                 min_trade_usd: float = 0.0,
                 max_seq_len: int = 8192,
                 p99_clamps: Optional[Dict[str, float]] = None,
                 movement_label_config: Optional[Dict[str, float]] = None):
        """
        Build the dataset in one of three modes:

        1. Offline/cached (`cache_dir` given): index pre-generated
           `sample_*.pt` files, build (or load) class/context metadata, and
           derive per-sample sampler weights that balance classes and, within
           a class, context buckets.
        2. Online (`data_fetcher` given): enumerate mints from the fetcher and
           generate samples on the fly.
        3. Stub (neither given): empty placeholder dataset of length 1 (or
           `max_samples`).

        NOTE(review): `horizons_seconds`/`quantiles` use mutable list
        defaults; they are only read here, but callers should not mutate them.
        """
        self.max_seq_len = max_seq_len
        self.min_trades = int(min_trades)
        if self.min_trades < 1:
            raise RuntimeError(f"min_trades must be >= 1, got {self.min_trades}")

        # Clamp bounds for heavy-tailed numeric features; caller overrides
        # merge on top of these defaults per key.
        self.p99_clamps = {
            'slippage': 1.0,
            'total_usd': 100000.0,
            'history_bought_cost_sol': 30.0,
            'realized_profit_sol': 150.0,
        }
        if p99_clamps:
            self.p99_clamps.update(p99_clamps)
        print(f"INFO: Using P99 clamps: {self.p99_clamps}")

        # Retrying HTTP session (also rebuilt after unpickling, see __setstate__).
        self.http_session = None
        self._init_http_session()

        self.fetcher = data_fetcher
        self.fetcher_config = fetcher_config
        self.cache_dir = Path(cache_dir) if cache_dir else None

        # Per-sample sampler weights (populated in cached mode only).
        self.weights_list = []

        self._token_meta_cache = {}
        self._chart_feature_log_count = 0

        self.token_allowlist = set(token_allowlist) if token_allowlist else None

        if self.cache_dir:
            if not self.cache_dir.is_dir():
                raise RuntimeError(
                    f"Cache directory '{self.cache_dir}' was provided but is not a directory. "
                    "Fix the path or disable cached mode."
                )

            print(f"INFO: Initializing dataset in offline (cached) mode from: {self.cache_dir}")

            def _sort_key(p):
                # Sort "sample_<N>.pt" numerically when possible, falling back
                # to lexicographic order for non-numeric stems.
                parts = p.stem.split('_')
                if len(parts) >= 2:
                    try:
                        return (0, int(parts[1]))
                    except ValueError:
                        return (1, parts[1])
                return (2, p.stem)
            self.cached_files = sorted(self.cache_dir.glob("sample_*.pt"), key=_sort_key)
            if not self.cached_files:
                raise RuntimeError(f"Cache directory '{self.cache_dir}' provided but contains no 'sample_*.pt' files.")

            # file name -> class id / context bucket / full context summary.
            file_class_map = {}
            file_context_bucket_map = {}
            file_context_summary_map = {}
            class_counts = defaultdict(int)
            class_context_counts = defaultdict(lambda: defaultdict(int))
            metadata_path = self.cache_dir / "class_metadata.json"

            if metadata_path.exists():
                # Fast path: reuse a previous scan's metadata.
                print(f"INFO: Loading class metadata from cache: {metadata_path}")
                try:
                    with open(metadata_path, 'r') as f:
                        cached_metadata = json.load(f)
                    file_class_map = cached_metadata.get('file_class_map', {})
                    file_context_bucket_map = cached_metadata.get('file_context_bucket_map', {})
                    file_context_summary_map = cached_metadata.get('file_context_summary_map', {})
                    # Invalidate when the on-disk file set no longer matches.
                    cached_file_names = {p.name for p in self.cached_files}
                    metadata_file_names = set(file_class_map.keys())
                    if cached_file_names != metadata_file_names:
                        print(f"WARN: Metadata cache mismatch ({len(cached_file_names)} files vs {len(metadata_file_names)} in metadata). Rebuilding...")
                        file_class_map = {}
                        file_context_bucket_map = {}
                        file_context_summary_map = {}
                    else:
                        # Rebuild the count tables from the cached maps.
                        for fname, cid in file_class_map.items():
                            class_counts[cid] += 1
                            bucket = file_context_bucket_map.get(fname)
                            if bucket is not None:
                                class_context_counts[cid][bucket] += 1
                        print(f"INFO: Loaded metadata for {len(file_class_map)} samples in <1s")
                except Exception as e:
                    print(f"WARN: Failed to load metadata cache: {e}. Rebuilding...")
                    file_class_map = {}
                    file_context_bucket_map = {}
                    file_context_summary_map = {}

            # Slow path: scan every cached sample once and persist the result.
            if not file_class_map:
                print(f"INFO: Building class metadata from {len(self.cached_files)} files (first run only)...")
                for i, p in enumerate(self.cached_files):
                    if i > 0 and i % 1000 == 0:
                        print(f" Scanned {i}/{len(self.cached_files)} files...")
                    try:
                        try:
                            # weights_only is only accepted by newer torch versions.
                            cached_item = torch.load(p, map_location="cpu", weights_only=False)
                        except TypeError:
                            cached_item = torch.load(p, map_location="cpu")
                        cid = cached_item.get("class_id")
                        if cid is None:
                            print(f"WARN: File {p.name} missing class_id. Skipping.")
                            continue
                        context_summary = summarize_context_window(
                            cached_item.get("labels"),
                            cached_item.get("labels_mask"),
                        )
                        bucket = context_summary["context_bucket"]
                        file_class_map[p.name] = cid
                        file_context_bucket_map[p.name] = bucket
                        file_context_summary_map[p.name] = context_summary
                        class_counts[cid] += 1
                        class_context_counts[cid][bucket] += 1
                    except Exception as e:
                        print(f"WARN: Failed to read cached sample {p.name}: {e}")

                # Best-effort persistence of the scan results.
                try:
                    with open(metadata_path, 'w') as f:
                        json.dump({
                            'file_class_map': file_class_map,
                            'file_context_bucket_map': file_context_bucket_map,
                            'file_context_summary_map': file_context_summary_map,
                        }, f)
                    print(f"INFO: Saved class metadata cache to {metadata_path}")
                except Exception as e:
                    print(f"WARN: Failed to save metadata cache: {e}")

            print(f"INFO: Class Distribution: {dict(class_counts)}")
            print(
                "INFO: Context Distribution by Class: "
                f"{ {cid: dict(bucket_counts) for cid, bucket_counts in class_context_counts.items()} }"
            )

            # Keep instance-level copies for later lookups.
            self.file_class_map = {p: cid for p, cid in file_class_map.items()}
            self.file_context_bucket_map = {p: bucket for p, bucket in file_context_bucket_map.items()}
            self.file_context_summary_map = {p: summary for p, summary in file_context_summary_map.items()}

            # First pass: validate metadata and drop files without class ids.
            # NOTE(review): the weights computed in this pass are discarded and
            # fully recomputed after max_samples truncation below — this pass
            # effectively only validates and filters.
            self.weights_list = []
            valid_files = []

            for p in self.cached_files:
                fname = p.name
                if fname not in file_class_map:
                    # Present on disk but unreadable / missing class_id above.
                    print(f"WARN: File {fname} found in cache but missing class_id. Skipping.")
                    continue

                cid = file_class_map[fname]
                bucket = file_context_bucket_map.get(fname)
                if bucket is None:
                    raise RuntimeError(
                        f"Cached sample '{fname}' is missing a context bucket. "
                        "Rebuild metadata or cache before training."
                    )
                class_bucket_counts = class_context_counts[cid]
                present_buckets = [name for name, cnt in class_bucket_counts.items() if cnt > 0]
                if not present_buckets:
                    raise RuntimeError(
                        f"Class {cid} has no valid context buckets recorded. Cannot compute sampler weights."
                    )
                bucket_count = class_bucket_counts[bucket]
                if bucket_count <= 0:
                    raise RuntimeError(
                        f"Class {cid} bucket '{bucket}' has invalid count {bucket_count} for sample '{fname}'."
                    )
                # Weight balances context buckets within a class.
                weight = 1.0 / (len(present_buckets) * bucket_count)
                self.weights_list.append(weight)
                valid_files.append(p)

            self.cached_files = valid_files
            self.num_samples = len(self.cached_files)

            if max_samples is not None:
                self.num_samples = min(max_samples, self.num_samples)
                self.cached_files = self.cached_files[:self.num_samples]
                self.weights_list = self.weights_list[:self.num_samples]

            # Second pass: recompute bucket counts over the final (possibly
            # truncated) population so weights stay properly normalized.
            active_class_context_counts = defaultdict(lambda: defaultdict(int))
            for p in self.cached_files:
                fname = p.name
                cid = file_class_map[fname]
                bucket = file_context_bucket_map[fname]
                active_class_context_counts[cid][bucket] += 1

            self.weights_list = []
            for p in self.cached_files:
                fname = p.name
                cid = file_class_map[fname]
                bucket = file_context_bucket_map[fname]
                class_bucket_counts = active_class_context_counts[cid]
                present_buckets = [name for name, cnt in class_bucket_counts.items() if cnt > 0]
                bucket_count = class_bucket_counts[bucket]
                self.weights_list.append(1.0 / (len(present_buckets) * bucket_count))

            print(f"INFO: Weighted Dataset Ready. {self.num_samples} samples.")
            self.sampled_mints = []
            self.available_mints = []

        elif self.fetcher:
            # Online mode: samples generated from live mint records.
            print(f"INFO: Initializing dataset in online (generation) mode...")
            self.available_mints = self.fetcher.get_all_mints(start_date=start_date)
            if not self.available_mints:
                raise RuntimeError("Dataset initialization failed: no mint records returned from data fetcher.")
            if self.token_allowlist:
                filtered_mints = [
                    mint for mint in self.available_mints
                    if mint.get('mint_address') in self.token_allowlist
                ]
                if not filtered_mints:
                    raise RuntimeError(f"No mint records matched the provided token allowlist: {token_allowlist}")
                self.available_mints = filtered_mints

            total_mints = len(self.available_mints)
            if max_samples is None:
                self.num_samples = total_mints
                self.sampled_mints = self.available_mints
            else:
                self.num_samples = min(max_samples, total_mints)
                if self.num_samples < total_mints:
                    print(f"INFO: Limiting dataset to first {self.num_samples} of {total_mints} available mints.")
                self.sampled_mints = self.available_mints[:self.num_samples]
        else:
            # Stub mode: no data source available.
            self.available_mints = []
            self.sampled_mints = []
            self.num_samples = 1 if max_samples is None else max_samples

        # Label layout: one output per (horizon, quantile) pair.
        self.horizons_seconds = sorted(set(horizons_seconds))
        self.quantiles = quantiles
        self.num_outputs = len(self.horizons_seconds) * len(self.quantiles)
        if self.horizons_seconds:
            self.max_cache_horizon_seconds = max(self.horizons_seconds)
        else:
            self.max_cache_horizon_seconds = 3600

        self.min_trade_usd = min_trade_usd
        # URI -> failure count; used to skip known-dead URIs after retries.
        self._uri_fail_counts: Dict[str, int] = {}
        self.movement_label_config = movement_label_config
|
|
| def _init_http_session(self) -> None: |
| |
| self.http_session = requests.Session() |
| retry_strategy = Retry( |
| total=3, |
| backoff_factor=1, |
| status_forcelist=[429, 500, 502, 503, 504], |
| allowed_methods=["HEAD", "GET", "OPTIONS"] |
| ) |
| adapter = HTTPAdapter(max_retries=retry_strategy) |
| self.http_session.mount("http://", adapter) |
| self.http_session.mount("https://", adapter) |
|
|
| def init_fetcher(self) -> None: |
| """ |
| Initialize DataFetcher from stored config (for DataLoader workers). |
| """ |
| if self.fetcher is not None or not self.fetcher_config: |
| return |
| from clickhouse_driver import Client as ClickHouseClient |
| from neo4j import GraphDatabase |
| cfg = self.fetcher_config |
| clickhouse_client = ClickHouseClient( |
| host=cfg.get("clickhouse_host", "localhost"), |
| port=int(cfg.get("clickhouse_port", 9000)), |
| ) |
| neo4j_driver = GraphDatabase.driver( |
| cfg.get("neo4j_uri", "bolt://localhost:7687"), |
| auth=(cfg.get("neo4j_user", "neo4j"), cfg.get("neo4j_password", "password")) |
| ) |
| self.fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver) |
|
|
| def __getstate__(self): |
| state = self.__dict__.copy() |
| |
| state["fetcher"] = None |
| state["http_session"] = None |
| return state |
|
|
| def __setstate__(self, state): |
| self.__dict__.update(state) |
| if self.http_session is None: |
| self._init_http_session() |
|
|
| def __len__(self) -> int: |
| return self.num_samples |
|
|
| def get_weights(self) -> torch.DoubleTensor: |
| """Returns the sampling weights for the dataset.""" |
| if hasattr(self, 'weights_list') and self.weights_list: |
| return torch.as_tensor(self.weights_list, dtype=torch.double) |
| return None |
|
|
| def _normalize_price_series(self, values: List[float]) -> List[float]: |
| if not values: |
| return values |
| import math |
| return [math.log(float(v)) if float(v) > 1e-9 else 0.0 for v in values] |
|
|
| def _is_dead_uri(self, uri: Optional[str]) -> bool: |
| if not uri: |
| return False |
| return self._uri_fail_counts.get(uri, 0) >= DEAD_URI_RETRY_LIMIT |
|
|
| def _mark_uri_failure(self, uri: Optional[str]) -> None: |
| if not uri: |
| return |
| self._uri_fail_counts[uri] = self._uri_fail_counts.get(uri, 0) + 1 |
|
|
    def _apply_dynamic_sampling(self, events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Applies dynamic context sampling to fit events within max_seq_len.

        Priority:
        1. CRITICAL events (always kept)
        2. SNAPSHOT events (kept for continuity)
        3. COMPRESSIBLE events (Trade/Transfer) - split into Head/Tail with MIDDLE token

        Uses existing 'MIDDLE' and 'RECENT' tokens to mark transitions.
        """
        if len(events) <= self.max_seq_len:
            return events

        # Bucket each event by priority, keeping its original position so
        # chronological order can be restored after sampling.
        critical_events = []
        snapshot_events = []
        compressible_events = []

        for idx, event in enumerate(events):
            event_type = event.get('event_type', '')
            if event_type in CRITICAL_EVENTS:
                critical_events.append((idx, event))
            elif event_type in SNAPSHOT_EVENTS:
                snapshot_events.append((idx, event))
            elif event_type in COMPRESSIBLE_EVENTS:
                compressible_events.append((idx, event))
            else:
                # Unknown event types are treated as critical (never dropped).
                critical_events.append((idx, event))

        # Two slots reserved for the MIDDLE/RECENT marker tokens.
        reserved_tokens = 2
        fixed_count = len(critical_events) + len(snapshot_events) + reserved_tokens
        budget_for_compressible = max(0, self.max_seq_len - fixed_count)

        if budget_for_compressible == 0 or len(compressible_events) <= budget_for_compressible:
            # Everything fits: return all events in original order.
            # NOTE(review): when budget is 0 this also returns ALL events,
            # which can still exceed max_seq_len — confirm downstream
            # truncation handles that case.
            all_events = critical_events + snapshot_events + compressible_events
            all_events.sort(key=lambda x: x[0])
            return [e[1] for e in all_events]

        # Keep the oldest and the newest compressible events; drop the middle.
        head_size = budget_for_compressible // 2
        tail_size = budget_for_compressible - head_size

        head_events = compressible_events[:head_size]
        tail_events = compressible_events[-tail_size:] if tail_size > 0 else []

        # Anchor the markers at the boundaries of the dropped middle region.
        middle_marker_idx = head_events[-1][0] if head_events else 0
        recent_marker_idx = tail_events[0][0] if tail_events else len(events)

        middle_marker = {
            'event_type': 'MIDDLE',
            'relative_ts': events[middle_marker_idx].get('relative_ts', 0) if middle_marker_idx < len(events) else 0,
            'is_marker': True
        }
        recent_marker = {
            'event_type': 'RECENT',
            'relative_ts': events[recent_marker_idx - 1].get('relative_ts', 0) if recent_marker_idx > 0 and recent_marker_idx <= len(events) else 0,
            'is_marker': True
        }

        all_indexed_events = critical_events + snapshot_events + head_events + tail_events

        # Fractional indices slot the markers between real events when sorting.
        middle_idx = middle_marker_idx + 0.5
        recent_idx = recent_marker_idx - 0.5

        all_indexed_events.append((middle_idx, middle_marker))
        all_indexed_events.append((recent_idx, recent_marker))

        # Restore chronological (original-position) order.
        all_indexed_events.sort(key=lambda x: x[0])

        return [e[1] for e in all_indexed_events]
|
|
| def _compute_future_return_labels(self, |
| anchor_price: Optional[float], |
| anchor_timestamp: int, |
| price_series: List[Tuple[int, float]]) -> Tuple[torch.Tensor, torch.Tensor, List[Dict[str, Any]]]: |
| if self.num_outputs == 0: |
| return torch.zeros(0), torch.zeros(0), [] |
|
|
| if anchor_price is None or abs(anchor_price) < 1e-9 or not price_series: |
| return torch.zeros(self.num_outputs), torch.zeros(self.num_outputs), [] |
|
|
| ts_list = [int(entry[0]) for entry in price_series] |
| price_list = [float(entry[1]) for entry in price_series] |
| if not ts_list: |
| return torch.zeros(self.num_outputs), torch.zeros(self.num_outputs), [] |
|
|
| last_ts = ts_list[-1] |
|
|
| label_values: List[float] = [] |
| mask_values: List[float] = [] |
| debug_entries: List[Dict[str, Any]] = [] |
|
|
| for horizon in self.horizons_seconds: |
| target_ts = anchor_timestamp + horizon |
| if target_ts > last_ts: |
| horizon_mask = 0.0 |
| horizon_return = 0.0 |
| future_price = None |
| else: |
| idx = bisect_right(ts_list, target_ts) - 1 |
| if idx < 0: |
| horizon_mask = 0.0 |
| horizon_return = 0.0 |
| future_price = None |
| else: |
| future_price = price_list[idx] |
| horizon_return = (future_price - anchor_price) / anchor_price |
| horizon_return = max(min(horizon_return, 10.0), -10.0) |
| horizon_mask = 1.0 |
|
|
| for _ in self.quantiles: |
| label_values.append(horizon_return) |
| mask_values.append(horizon_mask) |
| debug_entries.append({ |
| 'horizon': horizon, |
| 'target_ts': target_ts, |
| 'future_price': future_price, |
| 'return': horizon_return, |
| 'mask': horizon_mask |
| }) |
|
|
| return (torch.tensor(label_values, dtype=torch.float32), |
| torch.tensor(mask_values, dtype=torch.float32), |
| debug_entries) |
|
|
    def _generate_onchain_snapshots(
        self,
        token_address: str,
        t0_timestamp: int,
        T_cutoff: datetime.datetime,
        interval_sec: int,
        trade_events: List[Dict[str, Any]],
        transfer_events: List[Dict[str, Any]],
        aggregation_trades: List[Dict[str, Any]],
        wallet_data: Dict[str, Any],
        total_supply_dec: float,
        _register_event_fn,
        cached_holders_list: List[Dict[str, Any]] = None
    ) -> None:
        """
        Replay cached holder snapshots up to the cutoff and register, per
        snapshot timestamp, a 'HolderSnapshot' event (holders ranked by supply
        pct) and an 'OnChain_Snapshot' event (aggregate window stats).

        Args:
            token_address: Mint address; used here only for error messages.
            t0_timestamp: Epoch seconds of the sequence origin (relative_ts base).
            T_cutoff: Upper time bound; later snapshots are ignored.
            interval_sec: Lookback window length for per-snapshot aggregates.
            trade_events / transfer_events: Full event lists to aggregate over.
            aggregation_trades: (timestamp, price_usd) reference series.
            wallet_data: Per-wallet metadata; the 'socials' sub-dict drives KOL
                detection.
            total_supply_dec: Token supply in decimal units (pct denominator).
            _register_event_fn: Callback(event, sort_key) appending into the
                caller's ordered event stream.
            cached_holders_list: Pre-fetched holder snapshots; required.

        Raises:
            RuntimeError: when cached_holders_list is missing or a holder
                record is malformed.
        """
        if cached_holders_list is None:
            raise RuntimeError(
                f"Missing holder_snapshots_list for token {token_address} in _generate_onchain_snapshots."
            )

        # "Snipers": the first 70 distinct wallets with a successful buy.
        all_buy_trades = sorted([e for e in trade_events if e.get('trade_direction') == 0 and e.get('success', False)], key=lambda x: x['timestamp'])
        sniper_wallets = []
        seen_buyers = set()
        for e in all_buy_trades:
            wa = e['wallet_address']
            if wa not in seen_buyers:
                sniper_wallets.append(wa)
                seen_buyers.add(wa)
            if len(sniper_wallets) >= 70:
                break
        sniper_set = set(sniper_wallets)

        # Social-profile keys whose presence marks a wallet as a known KOL.
        KOL_NAME_KEYS = ['kolscan_name', 'cabalspy_name', 'axiom_kol_name']

        # Parallel arrays for last-known-price lookups.
        agg_ts = [int(t['timestamp']) for t in aggregation_trades] if aggregation_trades else []
        agg_price = [float(t.get('price_usd', 0.0) or 0.0) for t in aggregation_trades] if aggregation_trades else []

        # NOTE(review): start_ts is assigned but not used in this method.
        start_ts = t0_timestamp
        end_ts = int(self._timestamp_to_order_value(T_cutoff)) if hasattr(self, '_timestamp_to_order_value') else int(T_cutoff.timestamp())

        buyers_seen_global = set()
        prev_holders_count = 0

        for i, snapshot_data in enumerate(cached_holders_list):
            if not isinstance(snapshot_data, dict):
                continue
            ts_value = snapshot_data.get('timestamp')
            # Snapshots are assumed chronological, so stop at the cutoff.
            if ts_value is None or ts_value > end_ts:
                break

            # Activity inside the (window_start, ts_value] lookback window.
            window_start = ts_value - interval_sec
            trades_win = [e for e in trade_events if e.get('success', False) and window_start < e['timestamp'] <= ts_value]
            xfers_win = [e for e in transfer_events if window_start < e['timestamp'] <= ts_value]

            # Quiet windows produce no snapshot events at all.
            if not trades_win and not xfers_win:
                continue

            if 'holders' not in snapshot_data or not isinstance(snapshot_data['holders'], list):
                continue

            # Convert raw holder balances into supply percentages; only
            # wallets with positive holdings are kept/ranked.
            holder_records_ts = snapshot_data['holders']
            holders_end = 0
            holder_entries_ts = []
            for rec in holder_records_ts:
                if not isinstance(rec, dict):
                    raise RuntimeError(
                        f"Malformed holder record for token {token_address} at index {i}: expected dict."
                    )
                if 'wallet_address' not in rec or 'current_balance' not in rec:
                    raise RuntimeError(
                        f"Malformed holder record for token {token_address} at index {i}: requires wallet_address/current_balance."
                    )
                addr = rec['wallet_address']
                bal = float(rec['current_balance'])
                pct = (bal / total_supply_dec) if total_supply_dec and total_supply_dec > 0 else 0.0
                if addr and pct > 0.0:
                    holder_entries_ts.append({'wallet': addr, 'holding_pct': pct})
                # NOTE(review): source indentation is mangled here; holders_end
                # appears to count every well-formed record — confirm whether it
                # should instead count only positive-balance holders.
                holders_end += 1
            holder_entries_ts.sort(key=lambda d: d['holding_pct'], reverse=True)

            hs_event = {
                'event_type': 'HolderSnapshot',
                'timestamp': int(ts_value),
                'relative_ts': ts_value - t0_timestamp,
                'holders': holder_entries_ts
            }
            # Huge slot / tx-index values sort the snapshot after any real
            # on-chain event sharing the same second.
            _register_event_fn(
                hs_event,
                self._event_execution_sort_key(ts_value, slot=10**12, transaction_index=10**9, signature='HolderSnapshot')
                if hasattr(self, '_event_execution_sort_key') else (ts_value, 10**12, 10**9, 0, 'HolderSnapshot')
            )

            holder_pct_map_ts = {d['wallet']: d['holding_pct'] for d in holder_entries_ts}
            top10_holder_pct = sum(d['holding_pct'] for d in holder_entries_ts[:10]) if holder_entries_ts else 0.0

            # Cohorts as of this snapshot: transfer recipients ("rat" wallets)
            # and bundle buyers.
            rat_set_ts = set(ev['destination_wallet_address'] for ev in transfer_events if ev['timestamp'] <= ts_value)
            bundle_buyer_set_ts = set(e['wallet_address'] for e in trade_events if e.get('is_bundle') and e.get('trade_direction') == 0 and e.get('success', False) and e['timestamp'] <= ts_value)

            # Window-level trade aggregates.
            buy_count = sum(1 for e in trades_win if e.get('trade_direction') == 0)
            sell_count = sum(1 for e in trades_win if e.get('trade_direction') == 1)
            volume = sum(float(e.get('total_usd', 0.0) or 0.0) for e in trades_win)
            total_txns = len(trades_win) + len(xfers_win)
            global_fees_paid = sum(
                float(e.get('priority_fee', 0.0) or 0.0) + float(e.get('bribe_fee', 0.0) or 0.0)
                for e in trades_win
            )

            # Smart traders: SmartWallet_Trade wallets that still hold supply.
            smart_trader_addrs = set(
                e['wallet_address'] for e in trade_events
                if e.get('event_type') == 'SmartWallet_Trade'
                and e.get('success', False)
                and e['timestamp'] <= ts_value
                and holder_pct_map_ts.get(e['wallet_address'], 0.0) > 0.0
            )
            smart_traders = len(smart_trader_addrs)

            # KOLs active in this window, identified via social handles.
            kol_addrs = set()
            for e in trades_win:
                wa = e['wallet_address']
                soc = wallet_data.get(wa, {}).get('socials', {})
                if any(soc.get(k) for k in KOL_NAME_KEYS if soc):
                    kol_addrs.add(wa)
            kols = len(kol_addrs)

            new_buyers = [e['wallet_address'] for e in trades_win if e.get('trade_direction') == 0 and e['wallet_address'] not in buyers_seen_global]
            for wa in new_buyers:
                buyers_seen_global.add(wa)

            # Holder growth = raw delta vs. the previously emitted snapshot.
            total_holders = float(holders_end)
            delta_holders = holders_end - prev_holders_count
            holder_growth_rate = float(delta_holders)
            prev_holders_count = holders_end

            # Last known price at or before the snapshot time.
            # NOTE(review): this inner loop reuses `i` as its index variable,
            # shadowing the enumerate index for the rest of this iteration.
            last_price_usd = 0.0
            if agg_ts:
                for i in range(len(agg_ts) - 1, -1, -1):
                    if agg_ts[i] <= ts_value:
                        last_price_usd = agg_price[i]
                        break
            current_market_cap = float(last_price_usd) * float(total_supply_dec)

            oc_event = {
                'event_type': 'OnChain_Snapshot',
                'timestamp': int(ts_value),
                'relative_ts': ts_value - t0_timestamp,
                'total_holders': total_holders,
                'smart_traders': float(smart_traders),
                'kols': float(kols),
                'holder_growth_rate': float(holder_growth_rate),
                'top_10_holder_pct': float(top10_holder_pct),
                'sniper_holding_pct': float(sum(holder_pct_map_ts.get(wa, 0.0) for wa in sniper_set)),
                'rat_wallets_holding_pct': float(sum(holder_pct_map_ts.get(wa, 0.0) for wa in rat_set_ts)),
                'bundle_holding_pct': float(sum(holder_pct_map_ts.get(wa, 0.0) for wa in bundle_buyer_set_ts)),
                'current_market_cap': float(current_market_cap),
                'volume': float(volume),
                'buy_count': float(buy_count),
                'sell_count': float(sell_count),
                'total_txns': float(total_txns),
                'global_fees_paid': float(global_fees_paid)
            }
            _register_event_fn(
                oc_event,
                self._event_execution_sort_key(ts_value, slot=10**12, transaction_index=10**9, signature='OnChain_Snapshot')
                if hasattr(self, '_event_execution_sort_key') else (ts_value, 10**12, 10**9, 0, 'OnChain_Snapshot')
            )
|
|
| def _calculate_deployed_token_stats(self, profiles: Dict[str, Dict[str, Any]], T_cutoff: datetime.datetime): |
| """ |
| Calculates aggregate statistics for wallets based on the tokens they've deployed. |
| This method modifies the `profiles` dictionary in-place. |
| """ |
| if not profiles: return |
|
|
| |
| all_deployed_tokens = set() |
| for addr, profile in profiles.items(): |
| deployed_tokens = profile.get('deployed_tokens', []) |
| all_deployed_tokens.update(deployed_tokens) |
| |
| |
| all_deployed_token_details = {} |
| if all_deployed_tokens and self.fetcher: |
| all_deployed_token_details = self.fetcher.fetch_deployed_token_details(list(all_deployed_tokens), T_cutoff) |
|
|
| for addr, profile in profiles.items(): |
| deployed_tokens = profile.get('deployed_tokens', []) |
| |
| |
| count = len(deployed_tokens) |
| profile['deployed_tokens_count'] = float(count) |
|
|
| if count == 0: |
| profile['deployed_tokens_migrated_pct'] = 0.0 |
| profile['deployed_tokens_avg_lifetime_sec'] = 0.0 |
| profile['deployed_tokens_avg_peak_mc_usd'] = 0.0 |
| profile['deployed_tokens_median_peak_mc_usd'] = 0.0 |
| continue |
|
|
| |
| lifetimes = [] |
| peak_mcs = [] |
| migrated_count = 0 |
| for token_addr in deployed_tokens: |
| details = all_deployed_token_details.get(token_addr) |
| if not details: continue |
|
|
| if details.get('has_migrated'): |
| migrated_count += 1 |
| |
| lifetimes.append((details['updated_at'] - details['created_at']).total_seconds()) |
| peak_mcs.append(details.get('ath_price_usd', 0.0) * details.get('total_supply', 0.0) / (10**details.get('decimals', 9))) |
|
|
| |
| profile['deployed_tokens_migrated_pct'] = (migrated_count / count) if count > 0 else 0.0 |
| |
| profile['deployed_tokens_avg_lifetime_sec'] = torch.mean(torch.tensor(lifetimes)).item() if lifetimes else 0.0 |
| |
| profile['deployed_tokens_avg_peak_mc_usd'] = torch.mean(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0 |
| profile['deployed_tokens_median_peak_mc_usd'] = torch.median(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0 |
| |
| def _process_wallet_data(self, wallet_addresses: List[str], token_data: Dict[str, Any], pooler: EmbeddingPooler, T_cutoff: datetime.datetime, |
| profiles_override: Optional[Dict] = None, socials_override: Optional[Dict] = None, holdings_override: Optional[Dict] = None) -> tuple[Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]]: |
| """ |
| Fetches or uses cached profile, social, and holdings data. |
| """ |
| import time as _time |
| _wd_timings = {} |
|
|
| if not wallet_addresses: |
| return {}, token_data |
|
|
| _t0 = _time.perf_counter() |
| if profiles_override is not None and socials_override is not None: |
| profiles, socials = profiles_override, socials_override |
| holdings = holdings_override if holdings_override is not None else {} |
| else: |
| if self.fetcher: |
| profiles, socials = self.fetcher.fetch_wallet_profiles_and_socials(wallet_addresses, T_cutoff) |
| holdings = self.fetcher.fetch_wallet_holdings(wallet_addresses, T_cutoff) |
| else: |
| profiles, socials, holdings = {}, {}, {} |
| _wd_timings['db_fetch'] = _time.perf_counter() - _t0 |
|
|
| valid_wallets = [addr for addr in wallet_addresses if addr in profiles] |
| if not valid_wallets: |
| return {}, token_data |
| wallet_addresses = valid_wallets |
|
|
| |
| |
| |
| seed_token_addresses = set(token_data.keys()) |
| all_holding_mints = set() |
| top_holding_mints = set() |
| for wallet_addr in wallet_addresses: |
| wallet_holds = holdings.get(wallet_addr, []) |
| for holding_item in wallet_holds: |
| mint_addr = holding_item.get('mint_address') |
| if mint_addr and mint_addr not in seed_token_addresses: |
| all_holding_mints.add(mint_addr) |
| |
| sorted_holds = sorted(wallet_holds, key=lambda h: float(h.get('total_volume_usd', 0) or 0), reverse=True) |
| for h in sorted_holds[:2]: |
| mint_addr = h.get('mint_address') |
| if mint_addr and mint_addr not in seed_token_addresses: |
| top_holding_mints.add(mint_addr) |
|
|
| |
| top_holding_mints = set(list(top_holding_mints)[:10]) |
| rest_holding_mints = all_holding_mints - top_holding_mints |
| _wd_timings['num_holding_tokens'] = len(all_holding_mints) |
|
|
| |
| _t0 = _time.perf_counter() |
| top_tokens = self._process_token_data(list(top_holding_mints), pooler, T_cutoff) if top_holding_mints else {} |
| rest_tokens = self._process_token_data_lightweight(list(rest_holding_mints), pooler, T_cutoff) if rest_holding_mints else {} |
| processed_new_tokens = {**top_tokens, **rest_tokens} |
| _wd_timings['holding_token_processing'] = _time.perf_counter() - _t0 |
| |
| all_token_data = dict(token_data) |
| for addr, data in (processed_new_tokens or {}).items(): |
| if addr in all_token_data: |
| continue |
| all_token_data[addr] = data |
|
|
| |
| self._calculate_deployed_token_stats(profiles, T_cutoff) |
|
|
| |
| final_wallets = {} |
| for addr in wallet_addresses: |
|
|
| |
| expected_profile_keys = [ |
| 'deployed_tokens_count', 'deployed_tokens_migrated_pct', |
| 'deployed_tokens_avg_lifetime_sec', 'deployed_tokens_avg_peak_mc_usd', |
| 'deployed_tokens_median_peak_mc_usd', 'balance', 'transfers_in_count', |
| 'transfers_out_count', 'spl_transfers_in_count', 'spl_transfers_out_count', |
| 'total_buys_count', 'total_sells_count', 'total_winrate', |
| 'stats_1d_realized_profit_sol', 'stats_1d_realized_profit_pnl', 'stats_1d_buy_count', |
| 'stats_1d_sell_count', 'stats_1d_transfer_in_count', 'stats_1d_transfer_out_count', |
| 'stats_1d_avg_holding_period', 'stats_1d_total_bought_cost_sol', 'stats_1d_total_sold_income_sol', |
| 'stats_1d_total_fee', 'stats_1d_winrate', 'stats_1d_tokens_traded', |
| 'stats_7d_realized_profit_sol', 'stats_7d_realized_profit_pnl', 'stats_7d_buy_count', 'stats_7d_sell_count', 'stats_7d_transfer_in_count', 'stats_7d_transfer_out_count', 'stats_7d_avg_holding_period', 'stats_7d_total_bought_cost_sol', 'stats_7d_total_sold_income_sol', 'stats_7d_total_fee', 'stats_7d_winrate', 'stats_7d_tokens_traded' |
| ] |
|
|
| profile_data = profiles.get(addr, None) |
| if not profile_data: |
| continue |
|
|
| for key in expected_profile_keys: |
| profile_data.setdefault(key, 0.0) |
|
|
| social_data = socials.get(addr, {}) |
| |
| |
| social_data['has_pf_profile'] = bool(social_data.get('pumpfun_username')) |
| social_data['has_twitter'] = bool(social_data.get('twitter_username')) |
| social_data['has_telegram'] = bool(social_data.get('telegram_channel')) |
| social_data['is_exchange_wallet'] = 'exchange_wallet' in profile_data.get('tags', []) |
|
|
|
|
|
|
| username = social_data.get('pumpfun_username') or social_data.get('twitter_username') or social_data.get('kolscan_name') |
|
|
| if isinstance(username, str) and username.strip(): |
| social_data['username_emb_idx'] = pooler.get_idx(username.strip()) |
| else: |
| social_data['username_emb_idx'] = 0 |
|
|
| |
| original_holdings = holdings.get(addr, []) |
| valid_wallet_holdings = [] |
| now_ts = datetime.datetime.now(datetime.timezone.utc) |
| for holding_item in original_holdings: |
| |
| start_ts = holding_item.get('start_holding_at') |
| mint_addr = holding_item.get('mint_address') |
| token_info = all_token_data.get(mint_addr) |
|
|
| if not token_info: |
| continue |
|
|
| end_ts = holding_item.get('end_holding_at') |
| if not start_ts: |
| holding_item['holding_time'] = 0.0 |
| else: |
| end_ts = end_ts or now_ts |
| holding_item['holding_time'] = (end_ts - start_ts).total_seconds() |
| |
| |
| if token_info and token_info.get('total_supply', 0) > 0: |
| total_supply = token_info['total_supply'] / (10**token_info.get('decimals', 9)) |
| current_balance = holding_item.get('current_balance', 0.0) |
| holding_item['balance_pct_to_supply'] = (current_balance / total_supply) if total_supply > 0 else 0.0 |
| else: |
| holding_item['balance_pct_to_supply'] = 0.0 |
|
|
| |
| wallet_native_balance = profile_data.get('balance', 0.0) |
| bought_cost_sol = holding_item.get('history_bought_cost_sol', 0.0) |
| if wallet_native_balance > 1e-9: |
| holding_item['bought_amount_sol_pct_to_native_balance'] = bought_cost_sol / wallet_native_balance |
| else: |
| holding_item['bought_amount_sol_pct_to_native_balance'] = 0.0 |
|
|
| |
| compact_holding = { |
| 'mint_address': mint_addr, |
| 'holding_time': float(holding_item.get('holding_time', 0.0) or 0.0), |
| 'balance_pct_to_supply': min(1.0, float(holding_item.get('balance_pct_to_supply', 0.0) or 0.0)), |
| 'history_bought_cost_sol': min(self.p99_clamps['history_bought_cost_sol'], float(holding_item.get('history_bought_cost_sol', 0.0) or 0.0)), |
| 'bought_amount_sol_pct_to_native_balance': min(1.0, float(holding_item.get('bought_amount_sol_pct_to_native_balance', 0.0) or 0.0)), |
| 'history_total_buys': float(holding_item.get('history_total_buys', 0.0) or 0.0), |
| 'history_total_sells': float(holding_item.get('history_total_sells', 0.0) or 0.0), |
| 'realized_profit_pnl': float(holding_item.get('realized_profit_pnl', 0.0) or 0.0), |
| 'realized_profit_sol': max(-self.p99_clamps['realized_profit_sol'], min(self.p99_clamps['realized_profit_sol'], float(holding_item.get('realized_profit_sol', 0.0) or 0.0))), |
| 'history_transfer_in': float(holding_item.get('history_transfer_in', 0.0) or 0.0), |
| 'history_transfer_out': float(holding_item.get('history_transfer_out', 0.0) or 0.0), |
| 'avarage_trade_gap_seconds': float(holding_item.get('avarage_trade_gap_seconds', 0.0) or 0.0), |
| 'total_fees': float(holding_item.get('total_fees', 0.0) or 0.0), |
| } |
| valid_wallet_holdings.append(compact_holding) |
|
|
| |
| compact_profile = {'wallet_address': addr} |
| for key in expected_profile_keys: |
| compact_profile[key] = float(profile_data.get(key, 0.0) or 0.0) |
|
|
|
|
| compact_social = { |
| 'has_pf_profile': bool(social_data.get('has_pf_profile', False)), |
| 'has_twitter': bool(social_data.get('has_twitter', False)), |
| 'has_telegram': bool(social_data.get('has_telegram', False)), |
| 'is_exchange_wallet': bool(social_data.get('is_exchange_wallet', False)), |
| 'username_emb_idx': int(social_data.get('username_emb_idx', 0) or 0), |
| } |
|
|
| final_wallets[addr] = { |
| 'profile': compact_profile, |
| 'socials': compact_social, |
| 'holdings': valid_wallet_holdings |
| } |
|
|
| return final_wallets, all_token_data |
|
|
| def _process_token_data(self, token_addresses: List[str], pooler: EmbeddingPooler, T_cutoff: datetime.datetime, token_data: Optional[Dict] = None) -> Dict[str, Dict[str, Any]]: |
| """ |
| Fetches and processes static data for a list of tokens. |
| """ |
| if not token_addresses: |
| return {} |
| |
| if token_data is None: |
| |
| if not token_addresses: |
| token_data = {} |
| else: |
| valid_token_data = {} |
| missing_tokens = [] |
| |
| for addr in token_addresses: |
| if addr in self._token_meta_cache: |
| valid_token_data[addr] = self._token_meta_cache[addr].copy() |
| else: |
| missing_tokens.append(addr) |
| |
| |
| if missing_tokens and self.fetcher: |
| fetched = self.fetcher.fetch_token_data(missing_tokens, T_cutoff) |
| |
| for addr, data in fetched.items(): |
| if addr: |
| self._token_meta_cache[addr] = data |
| valid_token_data[addr] = data.copy() |
| |
| token_data = valid_token_data |
| |
| |
|
|
| |
| |
| valid_token_data = {} |
| for addr, data in token_data.items(): |
| |
| |
| image = None |
| try: |
| bullx_image_url = f"https://image.bullx.io/1399811149/{addr}?retry=0" |
| direct_resp = self.http_session.get(bullx_image_url, timeout=5) |
| if direct_resp.status_code == 200: |
| try: |
| image = Image.open(BytesIO(direct_resp.content)) |
| except Exception as e: |
| print(f"WARN: Failed to process image from Bullx for {addr}: {e}") |
| image = None |
| else: |
| print(f"WARN: Bullx image fetch failed for {addr}: status {direct_resp.status_code}") |
| except Exception as e: |
| print(f"WARN: Bullx image fetch exception for {addr}: {e}") |
| image = None |
|
|
| |
| |
| if image is None: |
| token_uri = data.get('token_uri') |
| if self._is_dead_uri(token_uri): |
| image = None |
| token_uri = None |
| |
| |
| if '_cached_image_pil' in data: |
| image = data['_cached_image_pil'] |
| |
| if image is None: |
| |
| pass |
|
|
| |
| token_name = data.get('name') if data.get('name') and data.get('name').strip() else None |
| token_symbol = data.get('symbol') if data.get('symbol') and data.get('symbol').strip() else None |
|
|
| |
| |
| |
| if not token_name: |
| token_name = None |
| if not token_symbol: |
| token_symbol = None |
| |
| |
| if not image: |
| image = None |
|
|
| |
| if not addr: |
| print(f"WARN: Token {addr} has no address?? Skipping.") |
| continue |
|
|
| |
| data['is_vanity'] = addr.lower().endswith("pump") |
|
|
| data['image_emb_idx'] = pooler.get_idx(image) |
| data['name_emb_idx'] = pooler.get_idx(token_name) |
| data['symbol_emb_idx'] = pooler.get_idx(token_symbol) |
| data.pop('_cached_image_pil', None) |
| |
| |
| |
| |
| raw_protocol_id = data.get('protocol') |
| if raw_protocol_id is not None and 0 <= raw_protocol_id < vocab.NUM_PROTOCOLS: |
| data['protocol'] = raw_protocol_id |
| else: |
| data['protocol'] = vocab.PROTOCOL_TO_ID.get('Unknown', 0) |
|
|
| valid_token_data[addr] = data |
|
|
| return valid_token_data |
|
|
| def _process_token_data_lightweight(self, token_addresses: List[str], pooler: EmbeddingPooler, T_cutoff: datetime.datetime) -> Dict[str, Dict[str, Any]]: |
| """ |
| Lightweight version of _process_token_data for non-top holding tokens. |
| Fetches metadata from ClickHouse only (cached). NO HTTP image fetches. |
| Sets image_emb_idx=0 (zero embedding). Still encodes name/symbol text. |
| """ |
| if not token_addresses: |
| return {} |
|
|
| |
| missing_tokens = [addr for addr in token_addresses if addr not in self._token_meta_cache] |
|
|
| |
| if missing_tokens and self.fetcher: |
| fetched_data = self.fetcher.fetch_token_data(missing_tokens, T_cutoff) |
| |
| for addr, data in fetched_data.items(): |
| if addr: |
| self._token_meta_cache[addr] = data |
|
|
| |
| valid_token_data = {} |
| for addr in token_addresses: |
| |
| raw_data = self._token_meta_cache.get(addr) |
| if not raw_data: |
| continue |
|
|
| |
| data = raw_data.copy() |
|
|
| token_name = data.get('name') if data.get('name') and data.get('name').strip() else None |
| token_symbol = data.get('symbol') if data.get('symbol') and data.get('symbol').strip() else None |
|
|
| data['is_vanity'] = addr.lower().endswith("pump") |
| data['image_emb_idx'] = 0 |
| data['name_emb_idx'] = pooler.get_idx(token_name) |
| data['symbol_emb_idx'] = pooler.get_idx(token_symbol) |
|
|
| raw_protocol_id = data.get('protocol') |
| if raw_protocol_id is not None and 0 <= raw_protocol_id < vocab.NUM_PROTOCOLS: |
| data['protocol'] = raw_protocol_id |
| else: |
| data['protocol'] = vocab.PROTOCOL_TO_ID.get('Unknown', 0) |
|
|
| valid_token_data[addr] = data |
|
|
| return valid_token_data |
|
|
| def _generate_ohlc(self, aggregation_trades: List[Dict[str, Any]], T_cutoff: datetime.datetime, interval_seconds: int, t0_timestamp: float = None) -> List[tuple]: |
| """ |
| Generates an OHLC series from a list of aggregated trades with a dynamic interval. |
| It forward-fills gaps and extends the series up to T_cutoff. |
| Returns a list of (timestamp, open, close) tuples. |
| |
| Args: |
| t0_timestamp: If provided, OHLC will start from max(first_trade, t0_timestamp) to ensure |
| chart data never precedes the mint event. |
| """ |
| if not aggregation_trades: |
| return [] |
|
|
| trades_by_interval = defaultdict(list) |
| for trade in aggregation_trades: |
| |
| interval_start_ts = (trade['timestamp'] // interval_seconds) * interval_seconds |
| trades_by_interval[interval_start_ts].append(trade['price_usd']) |
|
|
| sorted_intervals = sorted(trades_by_interval.keys()) |
|
|
| if not sorted_intervals: |
| return [] |
|
|
| full_ohlc = [] |
| |
| start_ts = sorted_intervals[0] |
| if t0_timestamp is not None: |
| |
| t0_aligned = (int(t0_timestamp) // interval_seconds) * interval_seconds |
| if t0_aligned < t0_timestamp: |
| t0_aligned += interval_seconds |
| start_ts = max(start_ts, t0_aligned) |
| |
| end_ts = int(T_cutoff.timestamp()) |
| |
| for interval_ts in sorted_intervals: |
| if start_ts <= interval_ts <= end_ts: |
| prices = trades_by_interval[interval_ts] |
| open_price = prices[0] |
| close_price = prices[-1] |
| full_ohlc.append((interval_ts, open_price, close_price)) |
| |
| return full_ohlc |
|
|
    def _compute_quant_rolling_features(
        self,
        closes: List[float],
        end_idx: int,
    ) -> Dict[str, float]:
        """Delegate to signals.rolling_quant.compute_rolling_quant_features.

        Computes rolling quant features over the close-price series up to
        index `end_idx`; the feature set is defined in the signals module.
        """
        return compute_rolling_quant_features(closes, end_idx)
|
|
    def _compute_support_resistance_features(
        self,
        closes: List[float],
        highs: List[float],
        lows: List[float],
        end_idx: int,
        window_start: int,
        window_end: int,
        timestamps: List[int],
    ) -> Dict[str, float]:
        """Delegate to signals.support_resistance.compute_support_resistance_features.

        Computes support/resistance level features for the window
        [window_start, window_end) evaluated at bar `end_idx`; the feature
        set is defined in the signals module.
        """
        return compute_support_resistance_features(
            closes=closes,
            highs=highs,
            lows=lows,
            end_idx=end_idx,
            window_start=window_start,
            window_end=window_end,
            timestamps=timestamps,
        )
|
|
    def _compute_trendline_features(
        self,
        closes: List[float],
        highs: List[float],
        lows: List[float],
        end_idx: int,
    ) -> Dict[str, float]:
        """Delegate to signals.trendlines.compute_trendline_features.

        Computes trendline features from the price series evaluated at bar
        `end_idx`; the feature set is defined in the signals module.
        """
        return compute_trendline_features(closes, highs, lows, end_idx)
|
|
    def _extract_quant_ohlc_features_for_segment(
        self,
        segment: List[tuple],
        interval_label: str,
        token_address: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """
        Build per-window quant feature payloads from one OHLC segment.

        `segment` is a list of (timestamp, open, close) tuples; highs/lows
        are approximated as max/min of open and close. The segment is
        partitioned into windows of WINDOW_SECONDS worth of bars. Each window
        collects: in-window log-return statistics, multi-horizon lookback
        range/breakout features (LOOKBACK_SECONDS), support/resistance,
        trendline, and rolling quant features, serialized to a vector via
        feature_dict_to_vector.

        Returns:
            List of per-window dicts with start/end timestamps, the feature
            vector, schema version tags, and key-level snapshot/flag maps.
            Empty list for an empty segment (logged).
        """
        if not segment:
            print(
                f"INFO: Chart quant skipped | token={token_address or 'unknown'} "
                "reason=empty_segment"
            )
            return []

        # Parse "<N>s" interval labels; any parse failure falls back to 1s.
        try:
            interval_seconds = max(1, int(str(interval_label).rstrip("s")))
        except Exception:
            interval_seconds = 1
        window_bar_count = max(1, WINDOW_SECONDS // interval_seconds)
        effective_window_seconds = max(WINDOW_SECONDS, interval_seconds)
        max_windows = max(1, SEGMENT_SECONDS // effective_window_seconds)

        # Per-bar series; highs/lows are proxied from open/close because the
        # segment rows carry only (ts, open, close).
        timestamps = [int(row[0]) for row in segment]
        opens = [float(row[1]) for row in segment]
        closes = [float(row[2]) for row in segment]
        highs = [max(o, c) for o, c in zip(opens, closes)]
        lows = [min(o, c) for o, c in zip(opens, closes)]
        # Clip before log to guard against zero/negative prices.
        log_closes = np.log(np.clip(np.asarray(closes, dtype=np.float64), 1e-8, None))
        one_sec_returns = np.diff(log_closes)
        feature_windows: List[Dict[str, Any]] = []

        for window_idx in range(max_windows):
            window_start = window_idx * window_bar_count
            if window_start >= len(segment):
                break
            window_end = min(len(segment), window_start + window_bar_count)
            current_end_idx = window_end - 1
            # NOTE(review): this slice stops at current_end_idx, so it appears
            # to exclude the return into the window's final bar — confirm
            # whether that off-by-one is intentional.
            window_returns = one_sec_returns[window_start:max(window_start, current_end_idx)]
            window_closes = closes[window_start:window_end]
            window_highs = highs[window_start:window_end]
            window_lows = lows[window_start:window_end]
            features = empty_feature_dict()

            # --- In-window return / range statistics ----------------------
            if window_closes:
                window_close_arr = np.asarray(window_closes, dtype=np.float64)
                window_return_sum = float(np.sum(window_returns)) if window_returns.size > 0 else 0.0
                range_width = max(max(window_highs) - min(window_lows), 0.0)
                first_close = float(window_close_arr[0])
                last_close = float(window_close_arr[-1])
                # Crude acceleration proxy: last minus first per-bar return.
                accel_proxy = 0.0
                if window_returns.size >= 2:
                    accel_proxy = float(window_returns[-1] - window_returns[0])
                features.update({
                    "cum_log_return": window_return_sum,
                    "mean_log_return_1s": float(np.mean(window_returns)) if window_returns.size > 0 else 0.0,
                    "std_log_return_1s": float(np.std(window_returns)) if window_returns.size > 0 else 0.0,
                    "max_up_1s": float(np.max(window_returns)) if window_returns.size > 0 else 0.0,
                    "max_down_1s": float(np.min(window_returns)) if window_returns.size > 0 else 0.0,
                    "realized_vol": float(np.sqrt(np.sum(np.square(window_returns)))) if window_returns.size > 0 else 0.0,
                    "window_range_frac": range_width / max(abs(last_close), 1e-8),
                    "close_to_close_slope": (last_close - first_close) / max(abs(first_close), 1e-8),
                    "accel_proxy": accel_proxy,
                    "frac_pos_1s": float(np.mean(window_returns > 0)) if window_returns.size > 0 else 0.0,
                    "frac_neg_1s": float(np.mean(window_returns < 0)) if window_returns.size > 0 else 0.0,
                })

            # --- Multi-horizon lookback range / breakout features ---------
            current_price = closes[current_end_idx]
            current_high = highs[current_end_idx]
            current_low = lows[current_end_idx]
            for lookback in LOOKBACK_SECONDS:
                prefix = f"lb_{lookback}s"
                lookback_start = max(0, current_end_idx - lookback + 1)
                hist_closes = closes[lookback_start: current_end_idx + 1]
                hist_highs = highs[lookback_start: current_end_idx + 1]
                hist_lows = lows[lookback_start: current_end_idx + 1]
                # 1e-8 floors keep the ratio denominators finite.
                hist_range = max(max(hist_highs) - min(hist_lows), 1e-8)
                rolling_high = max(hist_highs)
                rolling_low = min(hist_lows)
                hist_returns = np.diff(np.log(np.clip(np.asarray(hist_closes, dtype=np.float64), 1e-8, None)))
                current_width = max(max(window_highs) - min(window_lows), 0.0)
                # NOTE(review): the "previous history" width excludes a tail
                # sized by the *current window* length — verify this mixing of
                # window and lookback extents is intended.
                prev_hist_width = max(max(hist_highs[:-len(window_highs)]) - min(hist_lows[:-len(window_lows)]), 0.0) if len(hist_highs) > len(window_highs) else current_width
                prev_close = closes[current_end_idx - 1] if current_end_idx > 0 else current_price

                features.update({
                    f"{prefix}_dist_high": (rolling_high - current_price) / max(abs(current_price), 1e-8),
                    f"{prefix}_dist_low": (current_price - rolling_low) / max(abs(current_price), 1e-8),
                    f"{prefix}_drawdown_high": (current_price - rolling_high) / max(abs(rolling_high), 1e-8),
                    f"{prefix}_rebound_low": (current_price - rolling_low) / max(abs(rolling_low), 1e-8),
                    f"{prefix}_pos_in_range": (current_price - rolling_low) / hist_range,
                    f"{prefix}_range_width": hist_range / max(abs(current_price), 1e-8),
                    f"{prefix}_compression_ratio": current_width / max(prev_hist_width, 1e-8),
                    f"{prefix}_breakout_high": 1.0 if current_high > rolling_high and prev_close <= rolling_high else 0.0,
                    f"{prefix}_breakdown_low": 1.0 if current_low < rolling_low and prev_close >= rolling_low else 0.0,
                    f"{prefix}_reclaim_breakdown": 1.0 if current_low < rolling_low and current_price >= rolling_low else 0.0,
                    f"{prefix}_rejection_breakout": 1.0 if current_high > rolling_high and current_price <= rolling_high else 0.0,
                })

            # --- Signal-module features (S/R, trendlines, rolling quant) --
            features.update(self._compute_support_resistance_features(
                closes=closes,
                highs=highs,
                lows=lows,
                end_idx=current_end_idx,
                window_start=window_start,
                window_end=window_end,
                timestamps=timestamps,
            ))
            features.update(self._compute_trendline_features(
                closes=closes,
                highs=highs,
                lows=lows,
                end_idx=current_end_idx,
            ))
            features.update(self._compute_quant_rolling_features(
                closes=closes,
                end_idx=current_end_idx,
            ))

            # Serialize the window: vector + version tags + key-level maps.
            feature_windows.append({
                "start_ts": timestamps[window_start],
                "end_ts": timestamps[current_end_idx],
                "window_seconds": effective_window_seconds,
                "feature_vector": feature_dict_to_vector(features),
                "feature_names_version": FEATURE_VERSION,
                "feature_version_id": FEATURE_VERSION_ID,
                "level_snapshot": {
                    "support_distance": features.get("nearest_support_dist", 0.0),
                    "resistance_distance": features.get("nearest_resistance_dist", 0.0),
                    "support_strength": features.get("support_strength", 0.0),
                    "resistance_strength": features.get("resistance_strength", 0.0),
                    "breakout_up": features.get("keylevel_breakout_up", 0.0),
                    "breakout_down": features.get("keylevel_breakout_down", 0.0),
                    "hold_above": features.get("keylevel_hold_above", 0.0),
                    "hold_below": features.get("keylevel_hold_below", 0.0),
                    "flip_to_support": features.get("keylevel_flip_to_support", 0.0),
                    "flip_to_resistance": features.get("keylevel_flip_to_resistance", 0.0),
                },
                "keylevel_flags": {
                    "breakout_up": features.get("keylevel_breakout_up", 0.0),
                    "breakout_down": features.get("keylevel_breakout_down", 0.0),
                    "hold_above": features.get("keylevel_hold_above", 0.0),
                    "hold_below": features.get("keylevel_hold_below", 0.0),
                    "failed_breakout_up": features.get("keylevel_failed_breakout_up", 0.0),
                    "failed_breakout_down": features.get("keylevel_failed_breakout_down", 0.0),
                    "flip_to_support": features.get("keylevel_flip_to_support", 0.0),
                    "flip_to_resistance": features.get("keylevel_flip_to_resistance", 0.0),
                },
            })

        # Summary counters for the INFO log line below.
        sr_windows = sum(
            1 for window in feature_windows
            if float(window["feature_vector"][FEATURE_INDEX["sr_available"]]) > 0.0
        )
        trendline_windows = sum(
            1 for window in feature_windows
            if float(window["feature_vector"][FEATURE_INDEX["trendline_available"]]) > 0.0
        )
        breakout_windows = sum(
            1 for window in feature_windows
            if (
                float(window["feature_vector"][FEATURE_INDEX["keylevel_breakout_up"]]) > 0.0
                or float(window["feature_vector"][FEATURE_INDEX["keylevel_breakout_down"]]) > 0.0
                or float(window["feature_vector"][FEATURE_INDEX["keylevel_flip_to_support"]]) > 0.0
                or float(window["feature_vector"][FEATURE_INDEX["keylevel_flip_to_resistance"]]) > 0.0
            )
        )
        keylevel_break_events = sum(
            1 for window in feature_windows
            if (
                float(window["feature_vector"][FEATURE_INDEX["keylevel_breakout_up"]]) > 0.0
                or float(window["feature_vector"][FEATURE_INDEX["keylevel_breakout_down"]]) > 0.0
            )
        )
        self._chart_feature_log_count += 1
        print(
            f"INFO: Chart quant built | token={token_address or 'unknown'} "
            f"interval={interval_label} segment={self._chart_feature_log_count} "
            f"windows={len(feature_windows)}/{max_windows} "
            f"sr={sr_windows} trend={trendline_windows} breaks={breakout_windows} "
            f"break_events={keylevel_break_events}"
        )
        return feature_windows
|
|
| def __getitem__(self, idx: int) -> Optional[Dict[str, Any]]: |
| """ |
| Loads data from cache. |
| """ |
| import time as _time |
| _timings = {} |
| _total_start = _time.perf_counter() |
|
|
| |
| _t0 = _time.perf_counter() |
| if not self.cache_dir: |
| raise RuntimeError("Offline mode required. No cache directory provided.") |
|
|
| if idx >= len(self.cached_files): |
| raise IndexError(f"Index {idx} out of range for {len(self.cached_files)} cached files.") |
|
|
| filepath = self.cached_files[idx] |
| try: |
| cached_data = torch.load(filepath, map_location='cpu', weights_only=False) |
| except Exception as e: |
| raise RuntimeError(f"ERROR: Could not load cached item {filepath}: {e}") |
| _timings['cache_load'] = _time.perf_counter() - _t0 |
|
|
| if not cached_data: |
| raise RuntimeError(f"No data loaded for index {idx}") |
|
|
| has_context_shape = ( |
| isinstance(cached_data, dict) and |
| 'event_sequence' in cached_data and |
| 'tokens' in cached_data and |
| 'wallets' in cached_data and |
| 'labels' in cached_data and |
| 'labels_mask' in cached_data |
| ) |
|
|
| if has_context_shape: |
| |
| _timings['total'] = _time.perf_counter() - _total_start |
|
|
| if 'movement_class_targets' not in cached_data and 'labels' in cached_data and 'labels_mask' in cached_data: |
| labels = cached_data['labels'] |
| labels_mask = cached_data['labels_mask'] |
| movement_targets = derive_movement_targets( |
| labels.tolist() if isinstance(labels, torch.Tensor) else labels, |
| labels_mask.tolist() if isinstance(labels_mask, torch.Tensor) else labels_mask, |
| movement_label_config=self.movement_label_config, |
| ) |
| cached_data['movement_class_targets'] = torch.tensor( |
| movement_targets['movement_class_targets'], |
| dtype=torch.long, |
| ) |
| cached_data['movement_class_mask'] = torch.tensor( |
| movement_targets['movement_class_mask'], |
| dtype=torch.long, |
| ) |
|
|
| if idx % 100 == 0: |
| print(f"[Sample {idx}] context cache | cache_load: {_timings['cache_load']*1000:.1f}ms | " |
| f"total: {_timings['total']*1000:.1f}ms | events: {len(cached_data.get('event_sequence', []))}") |
|
|
| return cached_data |
|
|
| raise RuntimeError( |
| f"Cached item at {filepath} is not a valid context cache. " |
| "Rebuild the cache with scripts/cache_dataset.py." |
| ) |
|
|
| def _process_token_data_offline(self, token_addresses: List[str], pooler: EmbeddingPooler, |
| T_cutoff: datetime.datetime, token_data: Optional[Dict] = None) -> Dict[str, Dict[str, Any]]: |
| """ |
| Processes token data in OFFLINE mode - no HTTP calls for images. |
| Uses pre-cached image bytes from the cache file. |
| """ |
| if not token_addresses: |
| return {} |
|
|
| if token_data is None: |
| token_data = {} |
|
|
| valid_token_data = {} |
| for addr, data in token_data.items(): |
| |
| image = data.get('_cached_image_pil', None) |
|
|
| |
| token_name = data.get('name') if data.get('name') and data.get('name').strip() else None |
| token_symbol = data.get('symbol') if data.get('symbol') and data.get('symbol').strip() else None |
|
|
| if not addr: |
| continue |
|
|
| |
| data['is_vanity'] = addr.lower().endswith("pump") |
|
|
| |
| data['image_emb_idx'] = pooler.get_idx(image) |
| data['name_emb_idx'] = pooler.get_idx(token_name) |
| data['symbol_emb_idx'] = pooler.get_idx(token_symbol) |
| |
| data.pop('_cached_image_pil', None) |
|
|
| |
| raw_protocol_id = data.get('protocol') |
| if raw_protocol_id is not None and 0 <= raw_protocol_id < vocab.NUM_PROTOCOLS: |
| data['protocol'] = raw_protocol_id |
| else: |
| data['protocol'] = vocab.PROTOCOL_TO_ID.get('Unknown', 0) |
|
|
| valid_token_data[addr] = data |
|
|
| return valid_token_data |
|
|
| def _build_main_token_seed(self, token_address: str, raw_data: Dict[str, Any]) -> Dict[str, Any]: |
| """ |
| Build a minimal token metadata payload for the main token. |
| Prevents raw cache blobs (trades/snapshots/etc.) from leaking into |
| sample['tokens'][main_token]. |
| """ |
| return { |
| 'token_address': token_address, |
| 'address': token_address, |
| 'name': raw_data.get('name', ''), |
| 'symbol': raw_data.get('symbol', ''), |
| 'token_uri': raw_data.get('token_uri', ''), |
| 'protocol': raw_data.get('protocol', 1), |
| 'total_supply': raw_data.get('total_supply', 0), |
| 'decimals': raw_data.get('decimals', 6), |
| } |
|
|
| def __cacheitem__(self, idx: int) -> Optional[Dict[str, Any]]: |
| """ |
| Fetches cutoff-agnostic raw token data for caching/online sampling. |
| Generates dense time-series (1s OHLC, Snapshots) and prunes raw logs. |
| |
| NEW: Also pre-fetches and caches ALL wallet profiles, socials, holdings, |
| graph links, and token images to enable fully offline training. |
| """ |
|
|
| if not self.sampled_mints: |
| raise RuntimeError("Dataset has no mint records loaded; ensure fetcher returned data during initialization.") |
| if idx >= len(self.sampled_mints): |
| raise IndexError(f"Requested sample index {idx} exceeds loaded mint count {len(self.sampled_mints)}.") |
| initial_mint_record = self.sampled_mints[idx] |
| t0 = initial_mint_record["timestamp"] |
| if isinstance(t0, datetime.datetime) and t0.tzinfo is None: |
| t0 = t0.replace(tzinfo=datetime.timezone.utc) |
| creator_address = initial_mint_record['creator_address'] |
| token_address = initial_mint_record['mint_address'] |
| |
|
|
| if not self.fetcher: |
| raise RuntimeError("Dataset has no data fetcher; cannot load raw data.") |
|
|
| |
| raw_data = self.fetcher.fetch_raw_token_data( |
| token_address=token_address, |
| creator_address=creator_address, |
| mint_timestamp=t0, |
| max_horizon_seconds=self.max_cache_horizon_seconds, |
| include_wallet_data=False, |
| include_graph=False, |
| min_trades=self.min_trades, |
| full_history=True, |
| prune_failed=False, |
| prune_transfers=False |
| ) |
| if raw_data is None: |
| return None |
|
|
| |
| |
| print(f" DEBUG: initial_mint_record keys: {list(initial_mint_record.keys())}") |
| print(f" DEBUG: token_name='{initial_mint_record.get('token_name')}', token_symbol='{initial_mint_record.get('token_symbol')}'") |
| raw_data['name'] = initial_mint_record.get('token_name', '') |
| raw_data['symbol'] = initial_mint_record.get('token_symbol', '') |
| raw_data['token_uri'] = initial_mint_record.get('token_uri', '') |
| raw_total_supply = initial_mint_record.get('total_supply', DEFAULT_TOTAL_SUPPLY_RAW) |
| raw_token_decimals = initial_mint_record.get('token_decimals', DEFAULT_TOKEN_DECIMALS) |
| raw_data['total_supply'] = ( |
| int(raw_total_supply) if raw_total_supply and int(raw_total_supply) > 0 else DEFAULT_TOTAL_SUPPLY_RAW |
| ) |
| raw_data['decimals'] = ( |
| int(raw_token_decimals) if raw_token_decimals is not None and int(raw_token_decimals) >= 0 else DEFAULT_TOKEN_DECIMALS |
| ) |
| raw_data['protocol'] = initial_mint_record.get('protocol', 1) |
|
|
| def _timestamp_to_order_value(ts_value: Any) -> float: |
| if isinstance(ts_value, datetime.datetime): |
| if ts_value.tzinfo is None: |
| ts_value = ts_value.replace(tzinfo=datetime.timezone.utc) |
| return ts_value.timestamp() |
| try: |
| return float(ts_value) |
| except (TypeError, ValueError): |
| return 0.0 |
|
|
| trades = raw_data.get('trades', []) |
| trade_ts_values = [_timestamp_to_order_value(t.get('timestamp')) for t in trades] |
|
|
| if not trade_ts_values: |
| print(f" SKIP: No valid trades found for {token_address}.") |
| return None |
|
|
| t0_val = _timestamp_to_order_value(t0) |
| last_trade_ts_val = max(trade_ts_values) |
|
|
| |
| duration_seconds = int(last_trade_ts_val - t0_val) + 120 |
| ohlc_1s = torch.zeros((duration_seconds, 2), dtype=torch.float32) |
|
|
| |
| |
| trades.sort(key=lambda x: _timestamp_to_order_value(x['timestamp'])) |
|
|
| |
| |
| |
| |
|
|
| trades_by_sec = defaultdict(list) |
| for t in trades: |
| ts = _timestamp_to_order_value(t['timestamp']) |
| sec_idx = int(ts - t0_val) |
| if 0 <= sec_idx < duration_seconds: |
| trades_by_sec[sec_idx].append(t['price_usd']) |
|
|
| last_close = float(trades[0]['price_usd']) |
|
|
| for i in range(duration_seconds): |
| if i in trades_by_sec: |
| prices = trades_by_sec[i] |
| op = prices[0] |
| cl = prices[-1] |
| last_close = cl |
| else: |
| op = cl = last_close |
|
|
| ohlc_1s[i, 0] = float(op) |
| ohlc_1s[i, 1] = float(cl) |
|
|
| raw_data['ohlc_1s'] = ohlc_1s |
|
|
| |
| interval = 300 |
| num_intervals = (duration_seconds // interval) + 1 |
| |
| |
|
|
| snapshot_stats = torch.zeros((num_intervals, 6), dtype=torch.float32) |
|
|
| cum_volume = 0.0 |
| cum_tx = 0 |
| cum_buys = 0 |
| cum_sells = 0 |
|
|
| |
| buckets = defaultdict(list) |
| for t in trades: |
| ts = _timestamp_to_order_value(t['timestamp']) |
| bucket_idx = int(ts - t0_val) // interval |
| if bucket_idx >= 0: |
| buckets[bucket_idx].append(t) |
|
|
| |
| holder_counts_by_interval = {} |
| try: |
| all_holdings = self.fetcher.db_client.execute(""" |
| SELECT wallet_address, current_balance, updated_at |
| FROM wallet_holdings |
| WHERE mint_address = %(token)s |
| ORDER BY wallet_address, updated_at |
| """, {'token': token_address}) |
|
|
| if all_holdings: |
| wallet_latest = {} |
| all_holdings.sort(key=lambda x: x[2]) |
| holding_idx = 0 |
| for i in range(num_intervals): |
| snap_ts = t0 + datetime.timedelta(seconds=(i + 1) * interval) |
| while holding_idx < len(all_holdings) and all_holdings[holding_idx][2] <= snap_ts: |
| wallet, balance, _ = all_holdings[holding_idx] |
| wallet_latest[wallet] = balance |
| holding_idx += 1 |
| holder_counts_by_interval[i] = sum(1 for b in wallet_latest.values() if b and b > 0) |
| except Exception: |
| pass |
|
|
| holder_snapshots_list = [] |
|
|
| for i in range(num_intervals): |
| bucket_trades = buckets[i] |
|
|
| |
| vol = sum(t.get('total_usd', 0.0) for t in bucket_trades) |
| tx = len(bucket_trades) |
| buys = sum(1 for t in bucket_trades if t.get('trade_direction') == 0 or t.get('trade_type') == 0) |
| sells = tx - buys |
|
|
| count = holder_counts_by_interval.get(i, 0) |
|
|
| snapshot_stats[i, 0] = float(vol) |
| snapshot_stats[i, 1] = float(tx) |
| snapshot_stats[i, 2] = float(buys) |
| snapshot_stats[i, 3] = float(sells) |
| snapshot_stats[i, 4] = float(count) |
| snapshot_stats[i, 5] = 0.0 |
|
|
| snapshot_ts = t0 + datetime.timedelta(seconds=(i+1)*interval) |
| holder_snapshots_list.append({ |
| 'timestamp': int(snapshot_ts.timestamp()), |
| 'holders': [] |
| }) |
|
|
| raw_data['snapshots_5m'] = snapshot_stats |
| raw_data['holder_snapshots_list'] = holder_snapshots_list |
|
|
| raw_data['holder_snapshots_list'] = holder_snapshots_list |
| raw_data["protocol_id"] = initial_mint_record.get("protocol") |
|
|
| |
| |
| |
| |
| |
|
|
| |
|
|
| |
| all_wallets = set() |
| all_wallets.add(creator_address) |
|
|
| for trade in raw_data.get('trades', []): |
| if trade.get('maker'): |
| all_wallets.add(trade['maker']) |
|
|
| for transfer in raw_data.get('transfers', []): |
| if transfer.get('source'): |
| all_wallets.add(transfer['source']) |
| if transfer.get('destination'): |
| all_wallets.add(transfer['destination']) |
|
|
| for pool in raw_data.get('pool_creations', []): |
| if pool.get('creator_address'): |
| all_wallets.add(pool['creator_address']) |
|
|
| for liq in raw_data.get('liquidity_changes', []): |
| if liq.get('lp_provider'): |
| all_wallets.add(liq['lp_provider']) |
|
|
| |
| for snapshot in holder_snapshots_list: |
| for holder in snapshot.get('holders', []): |
| if holder.get('wallet_address'): |
| all_wallets.add(holder['wallet_address']) |
|
|
| all_wallets.discard(None) |
| all_wallets.discard('') |
| wallet_list = list(all_wallets) |
|
|
| print(f" INFO: Found {len(wallet_list)} unique wallets to cache") |
|
|
| |
| |
| max_T_cutoff = datetime.datetime.fromtimestamp(last_trade_ts_val, tz=datetime.timezone.utc) |
|
|
| try: |
| cached_profiles, cached_socials = self.fetcher.fetch_wallet_profiles_and_socials( |
| wallet_list, max_T_cutoff |
| ) |
| print(f" INFO: Cached {len(cached_profiles)} wallet profiles, {len(cached_socials)} socials") |
| except Exception as e: |
| print(f" WARN: Failed to fetch wallet profiles/socials: {e}") |
| cached_profiles, cached_socials = {}, {} |
|
|
| |
| try: |
| cached_holdings = self.fetcher.fetch_wallet_holdings(wallet_list, max_T_cutoff) |
| print(f" INFO: Cached holdings for {len(cached_holdings)} wallets") |
| except Exception as e: |
| print(f" WARN: Failed to fetch wallet holdings: {e}") |
| cached_holdings = {} |
|
|
| |
| try: |
| cached_graph_entities, cached_graph_links = self.fetcher.fetch_graph_links( |
| wallet_list, max_T_cutoff, max_degrees=1 |
| ) |
| print(f" INFO: Cached {len(cached_graph_links)} graph link types, {len(cached_graph_entities)} graph entities") |
| except Exception as e: |
| print(f" WARN: Failed to fetch graph links: {e}") |
| cached_graph_entities, cached_graph_links = {}, {} |
|
|
| |
| cached_image_bytes = None |
| try: |
| |
| bullx_image_url = f"https://image.bullx.io/1399811149/{token_address}?retry=0" |
| resp = self.http_session.get(bullx_image_url, timeout=2) |
| if resp.status_code == 200: |
| cached_image_bytes = resp.content |
| print(f" INFO: Cached token image from Bullx ({len(cached_image_bytes)} bytes)") |
| else: |
| |
| token_uri = raw_data.get('token_uri') |
| if token_uri and 'ipfs/' in token_uri: |
| ipfs_gateways = [ |
| "https://pump.mypinata.cloud/ipfs/", |
| "https://dweb.link/ipfs/", |
| "https://cloudflare-ipfs.com/ipfs/", |
| ] |
| metadata_hash = token_uri.split('ipfs/')[-1] |
| for gateway in ipfs_gateways: |
| try: |
| metadata_resp = self.http_session.get(f"{gateway}{metadata_hash}", timeout=5) |
| if metadata_resp.status_code == 200: |
| metadata = metadata_resp.json() |
| image_url = metadata.get('image', '') |
| if image_url and 'ipfs/' in image_url: |
| image_hash = image_url.split('ipfs/')[-1] |
| for img_gateway in ipfs_gateways: |
| try: |
| img_resp = self.http_session.get(f"{img_gateway}{image_hash}", timeout=5) |
| if img_resp.status_code == 200: |
| cached_image_bytes = img_resp.content |
| print(f" INFO: Cached token image from IPFS ({len(cached_image_bytes)} bytes)") |
| break |
| except: |
| continue |
| break |
| except: |
| continue |
| except Exception as e: |
| print(f" WARN: Failed to cache token image: {e}") |
|
|
| |
| raw_data['cached_wallet_data'] = { |
| 'profiles': cached_profiles, |
| 'socials': cached_socials, |
| 'holdings': cached_holdings, |
| } |
| raw_data['cached_graph_data'] = { |
| 'entities': cached_graph_entities, |
| 'links': cached_graph_links, |
| } |
| raw_data['cached_image_bytes'] = cached_image_bytes |
| raw_data['cached_max_T_cutoff'] = max_T_cutoff.timestamp() |
|
|
| print(f" INFO: Cache complete for {token_address}") |
|
|
| return raw_data |
|
|
    def _generate_dataset_item(self,
                               token_address: str,
                               t0: datetime.datetime,
                               T_cutoff: datetime.datetime,
                               mint_event: Dict[str, Any],
                               trade_records: List[Dict[str, Any]],
                               transfer_records: List[Dict[str, Any]],
                               pool_creation_records: List[Dict[str, Any]],
                               liquidity_change_records: List[Dict[str, Any]],
                               fee_collection_records: List[Dict[str, Any]],
                               burn_records: List[Dict[str, Any]],
                               supply_lock_records: List[Dict[str, Any]],
                               migration_records: List[Dict[str, Any]],
                               wallet_data: Dict[str, Dict[str, Any]],
                               all_token_data: Dict[str, Any],
                               graph_links: Dict[str, Any],
                               graph_seed_entities: set,
                               all_graph_entities: Dict[str, str],
                               future_trades_for_labels: List[Dict[str, Any]],
                               pooler: EmbeddingPooler,
                               sample_idx: Optional[int] = None,
                               cached_holders_list: List[Dict[str, Any]] = None,
                               cached_ohlc_1s: Optional[torch.Tensor] = None,
                               quality_score: Optional[float] = None
                               ) -> Optional[Dict[str, Any]]:
        """
        Processes raw token data into a structured dataset item for a specific T_cutoff.

        Filters events beyond T_cutoff, computes derived features, and builds the
        final sample. Concretely, this method:
          1. Drops every record timestamped after ``T_cutoff``.
          2. Converts mint / trade / transfer / pool / liquidity / fee / burn /
             supply-lock / migration records into typed event dicts, each paired
             with a deterministic execution sort key (time, slot, tx index,
             instruction index, signature).
          3. Builds OHLC chart segments at 1s and 30s resolution, then picks the
             resolution that fits within ``self.max_seq_len``.
          4. Reconstructs top-K holder balances from trade flow and appends a
             holder snapshot at the cutoff.
          5. Computes forward-return labels per horizon from
             ``future_trades_for_labels`` (forward-filling the last known price
             for horizons without trades, masking horizons with no future data).

        Returns:
            A dict with keys ``event_sequence``, ``wallets``, ``tokens``,
            ``graph_links``, ``embedding_pooler``, ``quant_ohlc_features``,
            ``quant_feature_version``, ``labels``, ``labels_mask``,
            ``quality_score`` (plus ``sample_idx``/``token_address``/``t_cutoff``
            on the non-degenerate path). When no valid future trades exist the
            labels/mask tensors are all-zero.

        NOTE(review): ``graph_seed_entities``, ``all_graph_entities`` and
        ``cached_ohlc_1s`` are accepted but not read anywhere in this body —
        presumably consumed by callers/other revisions; confirm before removal.
        """

        # ---- Local ordering helpers -------------------------------------

        def _safe_int(value: Any) -> int:
            # Best-effort int conversion; anything unparseable collapses to 0.
            try: return int(value)
            except: return 0

        def _timestamp_to_order_value(ts_value: Any) -> float:
            # Normalize datetimes (naive assumed UTC), ISO-8601 strings
            # (with trailing 'Z'), or raw numerics to epoch seconds.
            # Unparseable values become 0.0 so sorting never raises.
            if isinstance(ts_value, datetime.datetime):
                if ts_value.tzinfo is None: ts_value = ts_value.replace(tzinfo=datetime.timezone.utc)
                return ts_value.timestamp()
            elif isinstance(ts_value, str):
                try: return datetime.datetime.fromisoformat(ts_value.replace('Z', '+00:00')).timestamp()
                except ValueError: pass
            try: return float(ts_value)
            except: return 0.0

        def _event_execution_sort_key(timestamp_value: Any, slot=0, transaction_index=0, instruction_index=0, signature='') -> tuple:
            # Total order for on-chain events: wall time first, then chain
            # position (slot / tx / instruction), then signature as tiebreaker.
            return (_timestamp_to_order_value(timestamp_value), _safe_int(slot), _safe_int(transaction_index), _safe_int(instruction_index), signature or '')

        def _trade_execution_sort_key(trade: Dict[str, Any]) -> tuple:
            # Same ordering as _event_execution_sort_key, read off a trade dict.
            return (
                _timestamp_to_order_value(trade.get('timestamp')),
                _safe_int(trade.get('slot')),
                _safe_int(trade.get('transaction_index')),
                _safe_int(trade.get('instruction_index')),
                trade.get('signature', '')
            )

        t0_timestamp = _timestamp_to_order_value(t0)

        # ---- Enforce the observation cutoff -----------------------------
        # Everything after T_cutoff is future information and must not leak
        # into the input sequence (it is only used for labels, separately).
        def filter_by_time(records):
            return [r for r in records if _timestamp_to_order_value(r.get('timestamp')) <= T_cutoff.timestamp()]

        trade_records = filter_by_time(trade_records)
        transfer_records = filter_by_time(transfer_records)
        pool_creation_records = filter_by_time(pool_creation_records)
        liquidity_change_records = filter_by_time(liquidity_change_records)
        fee_collection_records = filter_by_time(fee_collection_records)
        burn_records = filter_by_time(burn_records)
        supply_lock_records = filter_by_time(supply_lock_records)
        migration_records = filter_by_time(migration_records)

        # Accumulator of (sort_key, event_dict); sorted once at the end.
        event_sequence_entries: List[Tuple[tuple, Dict[str, Any]]] = []
        def _register_event(event: Dict[str, Any], sort_key: tuple):
            event_sequence_entries.append((sort_key, event))

        # The mint is always the first event conceptually; 'Mint' signature
        # acts as its tiebreaker.
        _register_event(mint_event, _event_execution_sort_key(mint_event['timestamp'], signature='Mint'))

        # ---- Trades -----------------------------------------------------
        trade_events = []
        transfer_events = []
        aggregation_trades = []       # feeds the 5m snapshot aggregation
        high_def_chart_trades = []    # feeds the 1s chart
        middle_chart_trades = []      # feeds the 30s chart

        main_token_info = all_token_data.get(token_address, {})
        base_decimals = main_token_info.get('decimals', 6)
        raw_total_supply = main_token_info.get('total_supply', 0)
        total_supply_dec = (raw_total_supply / (10**base_decimals)) if base_decimals > 0 else raw_total_supply

        # Fallback supply so pct-of-supply features never divide by zero
        # (1B is the common pump.fun default — TODO confirm).
        if not total_supply_dec or total_supply_dec == 0:
            total_supply_dec = 1_000_000_000.0

        # NOTE(review): these locals SHADOW the module-level constants of the
        # same names (see file head) with different values — e.g. module
        # LARGE_TRADE_USD_THRESHOLD is 100.0 but 330.0 here. Confirm which
        # set is canonical; the module-level ones are dead for this method.
        QUOTE_TOKEN_DECIMALS = {'So11111111111111111111111111111111111111112': 9}
        SMART_WALLET_PNL_THRESHOLD = 2.0
        SMART_WALLET_USD_THRESHOLD = 1000.0
        LARGE_TRADE_SUPPLY_PCT_THRESHOLD = 0.019
        LARGE_TRADE_USD_THRESHOLD = 330.0
        LARGE_TRANSFER_SUPPLY_PCT_THRESHOLD = 0.0028
        for trade in trade_records:
            # Dust filter: ignore trades below the configured USD floor.
            if trade.get('total_usd', 0.0) < self.min_trade_usd: continue

            trade_sort_key = _trade_execution_sort_key(trade)
            trade_ts_int = int(_timestamp_to_order_value(trade.get('timestamp')))

            trader_addr = trade['maker']

            # Classify the trader: deployer > smart wallet (KOL or
            # profitable) > large trade > plain trade.
            trader_wallet = wallet_data.get(trader_addr, {})
            trader_profile = trader_wallet.get('profile', {})

            KOL_NAME_KEYS = ['kolscan_name', 'cabalspy_name', 'axiom_kol_name']
            is_kol = any(trader_wallet.get('socials', {}).get(key) for key in KOL_NAME_KEYS)
            is_profitable = (trader_profile.get('stats_30d_realized_profit_pnl', 0.0) > SMART_WALLET_PNL_THRESHOLD)

            base_amount_dec = trade.get('base_amount', 0) / (10**base_decimals)
            is_large_amount = (total_supply_dec > 0 and (base_amount_dec / total_supply_dec) > LARGE_TRADE_SUPPLY_PCT_THRESHOLD)

            if trader_addr == mint_event['wallet_address']: event_type = 'Deployer_Trade'
            elif is_kol or is_profitable: event_type = 'SmartWallet_Trade'
            elif trade.get('total_usd', 0.0) > LARGE_TRADE_USD_THRESHOLD or is_large_amount: event_type = 'LargeTrade'
            else: event_type = 'Trade'

            # Quote-side amount, scaled by the quote mint's decimals
            # (unknown quote mints assumed 9 decimals, i.e. SOL-like).
            quote_address = trade.get('quote_address')
            quote_decimals = QUOTE_TOKEN_DECIMALS.get(quote_address, 9)
            quote_amount_dec = trade.get('quote_amount', 0) / (10**quote_decimals)

            # trade_type == 1 means sell (buys counted as direction 0 below).
            is_sell = trade.get('trade_type') == 1
            # Reconstruct the pre-trade balance on the side that shrank,
            # so pct-of-holding is relative to what the wallet held BEFORE
            # executing this trade. Assumes `*_balance` fields are
            # post-trade — TODO confirm against the fetcher schema.
            pre_trade_base = (trade.get('base_balance', 0) + base_amount_dec) if is_sell else trade.get('base_balance', 0)
            pre_trade_quote = (trade.get('quote_balance', 0) + quote_amount_dec) if not is_sell else trade.get('quote_balance', 0)

            # Ratios clamped to [0, 1]; near-zero denominators map to 1.0
            # (treated as "spent everything") for holdings, 0.0 for supply.
            token_pct_hold = min(1.0, (base_amount_dec / pre_trade_base) if pre_trade_base > 1e-9 else 1.0)
            quote_pct_hold = min(1.0, (quote_amount_dec / pre_trade_quote) if pre_trade_quote > 1e-9 else 1.0)
            token_pct_supply = min(1.0, (base_amount_dec / total_supply_dec) if total_supply_dec > 0 else 0.0)

            is_success = trade.get('success', False)
            price_valid = float(trade.get('price_usd', 0.0) or 0.0) > 0

            # Only successful, priced trades contribute to charts/aggregation;
            # failed trades are still emitted as events (success=False).
            if is_success and price_valid:
                chart_entry = {
                    'trade_direction': 1 if is_sell else 0,
                    'price_usd': trade.get('price_usd', 0.0),
                    'timestamp': trade_ts_int,
                    'sort_key': trade_sort_key
                }
                aggregation_trades.append(chart_entry)
                high_def_chart_trades.append(chart_entry.copy())
                middle_chart_trades.append(chart_entry.copy())

            trade_event = {
                'event_type': event_type,
                'timestamp': trade_ts_int,
                'relative_ts': _timestamp_to_order_value(trade.get('timestamp')) - t0_timestamp,
                'wallet_address': trader_addr,
                'token_address': token_address,
                'trade_direction': 1 if is_sell else 0,
                'sol_amount': trade.get('total', 0.0),
                'dex_platform_id': trade.get('platform', 0),
                'priority_fee': trade.get('priority_fee', 0.0),
                'mev_protection': 1 if trade.get('mev_protection', 0) > 0 else 0,
                'token_amount_pct_of_holding': token_pct_hold,
                'quote_amount_pct_of_holding': quote_pct_hold,
                'slippage': min(self.p99_clamps['slippage'], float(trade.get('slippage', 0.0) or 0.0)),
                'token_amount_pct_to_total_supply': token_pct_supply,
                'success': is_success,
                'is_bundle': trade.get('is_bundle', False),
                'total_usd': trade.get('total_usd', 0.0)
            }

            _register_event(trade_event, trade_sort_key)
            trade_events.append(trade_event)

        # ---- Transfers --------------------------------------------------
        for transfer in transfer_records:
            transfer_ts_val = _timestamp_to_order_value(transfer.get('timestamp'))
            transfer_ts_int = int(transfer_ts_val)
            amount_dec = float(transfer.get('amount_decimal', 0.0) or 0.0)
            source_balance = float(transfer.get('source_balance', 0.0) or 0.0)
            # Denominator reconstructs the sender's pre-transfer balance;
            # a non-positive source_balance disables the ratio (-> 0.0).
            denom = source_balance + amount_dec if source_balance > 0 else 0.0
            transfer_pct_of_holding = min(1.0, (amount_dec / denom) if denom > 1e-9 else 0.0)
            transfer_pct_of_supply = min(1.0, (amount_dec / total_supply_dec) if total_supply_dec > 0 else 0.0)
            is_large_transfer = transfer_pct_of_supply >= LARGE_TRANSFER_SUPPLY_PCT_THRESHOLD

            transfer_event = {
                'event_type': 'LargeTransfer' if is_large_transfer else 'Transfer',
                'timestamp': transfer_ts_int,
                'relative_ts': transfer_ts_val - t0_timestamp,
                'wallet_address': transfer.get('source'),
                'destination_wallet_address': transfer.get('destination'),
                'token_address': token_address,
                'token_amount': amount_dec,
                'transfer_pct_of_total_supply': transfer_pct_of_supply,
                'transfer_pct_of_holding': transfer_pct_of_holding,
                'priority_fee': transfer.get('priority_fee', 0.0),
                'success': transfer.get('success', False)
            }
            _register_event(
                transfer_event,
                _event_execution_sort_key(
                    transfer.get('timestamp'),
                    slot=transfer.get('slot', 0),
                    signature=transfer.get('signature', '')
                )
            )
            transfer_events.append(transfer_event)

        # ---- Chart series preparation -----------------------------------
        # Order chart trades by execution key, then strip the key (it is an
        # internal ordering aid, not a feature).
        def _finalize_chart(t_list):
            t_list.sort(key=lambda x: x['sort_key'])
            for e in t_list: e.pop('sort_key', None)

        _finalize_chart(aggregation_trades)
        _finalize_chart(high_def_chart_trades)
        _finalize_chart(middle_chart_trades)

        # (label, seconds-per-candle) for the two candidate resolutions.
        HIGH_DEF_INTERVAL = ("1s", 1)
        MIDDLE_INTERVAL = ("30s", 30)

        def _emit_chart_segments(trades: List[Dict[str, Any]], interval: tuple, precomputed_ohlc: List[tuple] = None):
            # Slice the OHLC series into Chart_Segment events of at most
            # OHLC_SEQ_LEN candles each, with normalized price series and
            # per-segment quant features attached.
            if not trades and precomputed_ohlc is None:
                return []
            interval_label, interval_seconds = interval

            if precomputed_ohlc is not None:
                ohlc_series = precomputed_ohlc
            else:
                ohlc_series = self._generate_ohlc(trades, T_cutoff, interval_seconds, t0_timestamp=t0_timestamp)
            emitted_events = []
            for idx in range(0, len(ohlc_series), OHLC_SEQ_LEN):
                segment = ohlc_series[idx:idx + OHLC_SEQ_LEN]
                if not segment:
                    continue
                # Each entry is assumed (timestamp, open, close, ...) — the
                # segment event is stamped with its LAST candle's time.
                last_ts = segment[-1][0]
                opens_raw = [s[1] for s in segment]
                closes_raw = [s[2] for s in segment]
                chart_event = {
                    'event_type': 'Chart_Segment',
                    'timestamp': int(last_ts),
                    'relative_ts': int(last_ts) - int(t0_timestamp),
                    'opens': self._normalize_price_series(opens_raw),
                    'closes': self._normalize_price_series(closes_raw),
                    'i': interval_label,
                    'quant_ohlc_features': self._extract_quant_ohlc_features_for_segment(segment, interval_label, token_address=token_address),
                    'quant_feature_version': FEATURE_VERSION,
                }
                emitted_events.append(chart_event)
            return emitted_events

        chart_events_1s = []
        chart_events_30s = []

        # Build both resolutions up front; selection happens after the
        # non-chart event count is known (see "would_exceed" below).
        chart_events_1s = _emit_chart_segments(high_def_chart_trades, HIGH_DEF_INTERVAL)
        chart_events_30s = _emit_chart_segments(middle_chart_trades, MIDDLE_INTERVAL)

        # ---- Pool creations ---------------------------------------------
        # pool_meta_by_address retains the creation record per pool so later
        # liquidity events can resolve quote mint/decimals.
        pool_meta_by_address = {}
        for pool_record in pool_creation_records:
            pool_addr = pool_record.get('pool_address')
            if pool_addr:
                pool_meta_by_address[pool_addr] = pool_record

            pool_ts_val = _timestamp_to_order_value(pool_record.get('timestamp'))
            pool_ts = int(pool_ts_val)
            # NOTE(review): this REBINDS the outer `base_decimals` used by the
            # trade loop above and by `token_scale` further below — after this
            # loop runs, those consumers see the LAST pool's decimals, not the
            # token's. Confirm whether that is intended.
            base_decimals = pool_record.get('base_decimals')
            quote_decimals = pool_record.get('quote_decimals')
            base_decimals = int(base_decimals) if base_decimals is not None else 0
            quote_decimals = int(quote_decimals) if quote_decimals is not None else 0

            base_amount_raw = pool_record.get('initial_base_liquidity', 0) or 0
            quote_amount_raw = pool_record.get('initial_quote_liquidity', 0) or 0
            base_amount = float(base_amount_raw) / (10 ** base_decimals) if base_decimals > 0 else float(base_amount_raw)
            quote_amount = float(quote_amount_raw) / (10 ** quote_decimals) if quote_decimals > 0 else float(quote_amount_raw)

            pool_event = {
                'event_type': 'PoolCreated',
                'timestamp': pool_ts,
                'relative_ts': pool_ts_val - t0_timestamp,
                'wallet_address': pool_record.get('creator_address'),
                'token_address': token_address,
                'quote_token_address': pool_record.get('quote_address'),
                'protocol_id': pool_record.get('protocol', 0),
                'pool_address': pool_addr,
                'base_amount': base_amount,
                'quote_amount': quote_amount,
                'priority_fee': pool_record.get('priority_fee', 0.0),
                'success': pool_record.get('success', False)
            }
            _register_event(
                pool_event,
                _event_execution_sort_key(
                    pool_record.get('timestamp'),
                    slot=pool_record.get('slot', 0),
                    signature=pool_record.get('signature', '')
                )
            )

        # ---- Liquidity changes ------------------------------------------
        for liq_record in liquidity_change_records:
            liq_ts_val = _timestamp_to_order_value(liq_record.get('timestamp'))
            liq_ts = int(liq_ts_val)
            pool_addr = liq_record.get('pool_address')
            # Decimals come from the pool's creation record; unknown pools
            # fall back to 0 (raw amount passed through unscaled).
            pool_meta = pool_meta_by_address.get(pool_addr, {})
            quote_decimals = pool_meta.get('quote_decimals')
            quote_decimals = int(quote_decimals) if quote_decimals is not None else 0

            quote_amount_raw = liq_record.get('quote_amount', 0) or 0
            quote_amount = float(quote_amount_raw) / (10 ** quote_decimals) if quote_decimals > 0 else float(quote_amount_raw)

            liq_event = {
                'event_type': 'LiquidityChange',
                'timestamp': liq_ts,
                'relative_ts': liq_ts_val - t0_timestamp,
                'wallet_address': liq_record.get('lp_provider'),
                'token_address': token_address,
                'quote_token_address': pool_meta.get('quote_address'),
                'protocol_id': liq_record.get('protocol', 0),
                'change_type_id': liq_record.get('change_type', 0),
                'quote_amount': quote_amount,
                'priority_fee': liq_record.get('priority_fee', 0.0),
                'success': liq_record.get('success', False)
            }
            _register_event(
                liq_event,
                _event_execution_sort_key(
                    liq_record.get('timestamp'),
                    slot=liq_record.get('slot', 0),
                    signature=liq_record.get('signature', '')
                )
            )

        # ---- Fee collections --------------------------------------------
        for fee_record in fee_collection_records:
            fee_ts_val = _timestamp_to_order_value(fee_record.get('timestamp'))
            fee_ts = int(fee_ts_val)
            # Take the fee amount for whichever side of the pool is this
            # token; other-token fees contribute 0.
            amount = 0.0
            if fee_record.get('token_0_mint_address') == token_address:
                amount = float(fee_record.get('token_0_amount', 0.0) or 0.0)
            elif fee_record.get('token_1_mint_address') == token_address:
                amount = float(fee_record.get('token_1_amount', 0.0) or 0.0)

            fee_event = {
                'event_type': 'FeeCollected',
                'timestamp': fee_ts,
                'relative_ts': fee_ts_val - t0_timestamp,
                'wallet_address': fee_record.get('recipient_address'),
                'token_address': token_address,
                'sol_amount': amount,
                'protocol_id': fee_record.get('protocol', 0),
                'priority_fee': fee_record.get('priority_fee', 0.0),
                'success': fee_record.get('success', False)
            }
            _register_event(
                fee_event,
                _event_execution_sort_key(
                    fee_record.get('timestamp'),
                    slot=fee_record.get('slot', 0),
                    signature=fee_record.get('signature', '')
                )
            )

        # ---- Burns ------------------------------------------------------
        for burn_record in burn_records:
            burn_ts_val = _timestamp_to_order_value(burn_record.get('timestamp'))
            burn_ts = int(burn_ts_val)
            amount_dec = float(burn_record.get('amount_decimal', 0.0) or 0.0)
            amount_pct = (amount_dec / total_supply_dec) if total_supply_dec > 0 else 0.0

            burn_event = {
                'event_type': 'TokenBurn',
                'timestamp': burn_ts,
                'relative_ts': burn_ts_val - t0_timestamp,
                'wallet_address': burn_record.get('source'),
                'token_address': token_address,
                'amount_pct_of_total_supply': amount_pct,
                'amount_tokens_burned': amount_dec,
                'priority_fee': burn_record.get('priority_fee', 0.0),
                'success': burn_record.get('success', False)
            }
            _register_event(
                burn_event,
                _event_execution_sort_key(
                    burn_record.get('timestamp'),
                    slot=burn_record.get('slot', 0),
                    signature=burn_record.get('signature', '')
                )
            )

        # ---- Supply locks -----------------------------------------------
        for lock_record in supply_lock_records:
            lock_ts_val = _timestamp_to_order_value(lock_record.get('timestamp'))
            lock_ts = int(lock_ts_val)
            total_locked_amount = float(lock_record.get('total_locked_amount', 0.0) or 0.0)
            amount_pct = (total_locked_amount / total_supply_dec) if total_supply_dec > 0 else 0.0
            final_unlock_ts = lock_record.get('final_unlock_timestamp', 0) or 0
            # Lock duration in seconds; clamped so a missing/past unlock
            # timestamp never produces a negative duration.
            lock_duration = float(final_unlock_ts) - float(lock_ts_val)
            if lock_duration < 0:
                lock_duration = 0.0

            lock_event = {
                'event_type': 'SupplyLock',
                'timestamp': lock_ts,
                'relative_ts': lock_ts_val - t0_timestamp,
                'wallet_address': lock_record.get('sender'),
                'token_address': token_address,
                'amount_pct_of_total_supply': amount_pct,
                'lock_duration': lock_duration,
                'protocol_id': lock_record.get('protocol', 0),
                'priority_fee': lock_record.get('priority_fee', 0.0),
                'success': lock_record.get('success', False)
            }
            _register_event(
                lock_event,
                _event_execution_sort_key(
                    lock_record.get('timestamp'),
                    slot=lock_record.get('slot', 0),
                    signature=lock_record.get('signature', '')
                )
            )

        # ---- Migrations -------------------------------------------------
        for migration_record in migration_records:
            mig_ts_val = _timestamp_to_order_value(migration_record.get('timestamp'))
            mig_ts = int(mig_ts_val)
            mig_event = {
                'event_type': 'Migrated',
                'timestamp': mig_ts,
                'relative_ts': mig_ts_val - t0_timestamp,
                'wallet_address': None,
                'token_address': token_address,
                'protocol_id': migration_record.get('protocol', 0),
                'priority_fee': migration_record.get('priority_fee', 0.0),
                'success': migration_record.get('success', False)
            }
            _register_event(
                mig_event,
                _event_execution_sort_key(
                    migration_record.get('timestamp'),
                    slot=migration_record.get('slot', 0),
                    signature=migration_record.get('signature', '')
                )
            )

        # ---- Holder reconstruction from trade flow ----------------------
        # Net signed base-token delta per maker (raw units): buys (type 0)
        # add, sells (type 1) subtract. Failed/malformed trades are skipped.
        wallet_balances_raw = {}
        for trade in trade_records:
            if not trade.get('success', False):
                continue
            maker = trade.get('maker')
            if not maker:
                continue
            try:
                trade_type = int(trade.get('trade_type', 0))
                base_amount_raw = int(trade.get('base_amount', 0))
            except:
                continue
            if trade_type not in (0, 1) or base_amount_raw < 0:
                continue
            signed_delta = base_amount_raw if trade_type == 0 else -base_amount_raw
            wallet_balances_raw[maker] = wallet_balances_raw.get(maker, 0) + signed_delta

        # Top-K positive holders, ordered by balance desc then address for
        # deterministic tie-breaking.
        positive_holders_raw = [(w, b) for w, b in wallet_balances_raw.items() if b > 0]
        positive_holders_raw.sort(key=lambda item: (-item[1], item[0]))
        holders_topk_raw = positive_holders_raw[:HOLDER_SNAPSHOT_TOP_K]

        cutoff_ts_epoch = int(T_cutoff.timestamp())
        # NOTE(review): `base_decimals` may have been rebound by the pool
        # loop above — see the note there.
        token_scale = 10 ** base_decimals if base_decimals else 1

        # Snapshot of the reconstructed holder set exactly at the cutoff.
        cutoff_snapshot = {
            'timestamp': cutoff_ts_epoch,
            'holders': [
                {
                    'wallet_address': w,
                    'current_balance': float(b) / float(token_scale)
                }
                for w, b in holders_topk_raw
            ]
        }

        # Keep only cached snapshots strictly before the cutoff, then ensure
        # the cutoff snapshot itself terminates the list.
        local_holders_list = [
            snap for snap in (cached_holders_list or [])
            if snap.get('timestamp', 0) < cutoff_ts_epoch
        ]

        if not local_holders_list or local_holders_list[-1]['timestamp'] != cutoff_ts_epoch:
            local_holders_list.append(cutoff_snapshot)

        # Emit 5-minute on-chain aggregate snapshots as events via the
        # _register_event callback.
        self._generate_onchain_snapshots(
            token_address, int(t0_timestamp), T_cutoff,
            300,
            trade_events, transfer_events,
            aggregation_trades,
            wallet_data,
            total_supply_dec,
            _register_event,
            cached_holders_list=local_holders_list
        )

        # ---- Chart resolution selection ---------------------------------
        # Prefer the 1s chart; fall back to 30s when 1s segments would push
        # the sequence past max_seq_len. Chart events get an artificially
        # huge slot/tx index so they sort AFTER real events at the same time.
        non_chart_event_count = len(event_sequence_entries)
        would_exceed = (non_chart_event_count + len(chart_events_1s)) > self.max_seq_len
        selected_chart_events = chart_events_30s if would_exceed else chart_events_1s
        selected_chart_signature = "chart-mid" if would_exceed else "chart-hd"
        for chart_idx, chart_event in enumerate(selected_chart_events):
            _register_event(
                chart_event,
                _event_execution_sort_key(chart_event['timestamp'], slot=10**12, transaction_index=10**9, signature=f"{selected_chart_signature}-{chart_idx}")
            )

        # Final deterministic ordering, then length-aware down-sampling.
        event_sequence_entries.sort(key=lambda x: x[0])
        raw_event_sequence = [entry[1] for entry in event_sequence_entries]

        event_sequence = self._apply_dynamic_sampling(raw_event_sequence)

        # ---- Label computation ------------------------------------------
        horizons = sorted(self.horizons_seconds)

        # Only successful, positively-priced trades can anchor labels.
        all_trades = [
            t for t in future_trades_for_labels
            if t.get('success', False) and float(t.get('price_usd', 0) or 0) > 0
        ]

        if not all_trades:
            # Degenerate sample: no usable trades at all -> zero labels with
            # a fully-zero mask (every horizon is "unknown").
            quant_payload = [
                event.get('quant_ohlc_features', [])
                for event in event_sequence
                if event.get('event_type') == 'Chart_Segment'
            ]
            return {
                'event_sequence': event_sequence,
                'wallets': wallet_data,
                'tokens': all_token_data,
                'graph_links': graph_links,
                'embedding_pooler': pooler,
                'quant_ohlc_features': quant_payload,
                'quant_feature_version': FEATURE_VERSION,
                'labels': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
                'labels_mask': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
                'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32),
            }

        all_trades.sort(key=lambda x: _timestamp_to_order_value(x['timestamp']))

        current_price = 0.0
        cutoff_ts_val = T_cutoff.timestamp()
        # NOTE(review): last_trade_ts_val is unused below — likely a leftover.
        last_trade_ts_val = _timestamp_to_order_value(all_trades[-1]['timestamp'])

        # NOTE(review): this re-filter is redundant — `all_trades` was
        # already filtered on the same predicate above.
        valid_trades_for_labels = [t for t in all_trades if t.get('success', False) and float(t.get('price_usd', 0) or 0) > 0]

        # Walk forward to the last trade at-or-before the cutoff: that
        # trade's price is the label baseline ("current" price).
        current_price_idx = -1
        for i, t in enumerate(valid_trades_for_labels):
            if _timestamp_to_order_value(t['timestamp']) <= cutoff_ts_val:
                current_price = float(t['price_usd'])
                current_price_idx = i
            else:
                break

        label_values = []
        mask_values = []

        if current_price_idx < 0 or current_price <= 0:
            # No priced trade before the cutoff -> all horizons unlabeled.
            for h in horizons:
                label_values.append(0.0)
                mask_values.append(0.0)
        else:
            # Per-horizon bucketed forward return. Each horizon h takes the
            # LAST trade price in (prev_horizon, cutoff+h]; horizons with no
            # trades forward-fill the previous bucket's price, and are
            # masked valid only if some trade exists beyond them.
            last_bucket_price = current_price
            prev_target_ts = cutoff_ts_val

            def _has_future_trades(after_ts: float) -> bool:
                # True if any post-cutoff trade is strictly after after_ts.
                for j in range(current_price_idx + 1, len(valid_trades_for_labels)):
                    if _timestamp_to_order_value(valid_trades_for_labels[j]['timestamp']) > after_ts:
                        return True
                return False

            for h in horizons:
                target_ts = cutoff_ts_val + h

                bucket_price = last_bucket_price
                found_future_trade = False

                # Scan post-cutoff trades; the last one inside this
                # horizon's bucket wins.
                for j in range(current_price_idx + 1, len(valid_trades_for_labels)):
                    t = valid_trades_for_labels[j]
                    t_ts = _timestamp_to_order_value(t['timestamp'])

                    if prev_target_ts < t_ts <= target_ts:
                        bucket_price = float(t['price_usd'])
                        found_future_trade = True
                    elif t_ts > target_ts:
                        break

                if found_future_trade:
                    # Simple return relative to the cutoff price.
                    ret = (bucket_price - current_price) / current_price
                    label_values.append(ret)
                    mask_values.append(1.0)
                    last_bucket_price = bucket_price
                else:
                    # Forward-fill: no trade in this bucket, so reuse the
                    # last known bucket price; only trust the label if the
                    # market traded again later (otherwise mask it out).
                    ret = (last_bucket_price - current_price) / current_price
                    label_values.append(ret)
                    mask_values.append(1.0 if _has_future_trades(target_ts) else 0.0)

                prev_target_ts = target_ts

        # ---- Assemble the final sample ----------------------------------
        quant_payload = [
            event.get('quant_ohlc_features', [])
            for event in event_sequence
            if event.get('event_type') == 'Chart_Segment'
        ]
        return {
            'sample_idx': sample_idx if sample_idx is not None else -1,
            'token_address': token_address,
            't_cutoff': T_cutoff.isoformat() if T_cutoff else None,
            'event_sequence': event_sequence,
            'wallets': wallet_data,
            'tokens': all_token_data,
            'graph_links': graph_links,
            'embedding_pooler': pooler,
            'quant_ohlc_features': quant_payload,
            'quant_feature_version': FEATURE_VERSION,
            'labels': torch.tensor(label_values, dtype=torch.float32),
            'labels_mask': torch.tensor(mask_values, dtype=torch.float32),
            'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32),
        }
|
|
| def _embed_context(self, context: Dict[str, Any], encoder: Any) -> None: |
| """ |
| Helper to replace raw items in the embedding pooler with pre-computed embeddings |
| using the provided encoder (on GPU). |
| """ |
| pooler = context.get('embedding_pooler') |
| if not pooler: |
| return |
|
|
| |
| keys_to_embed_img = [] |
| images_to_embed = [] |
| |
| keys_to_embed_text = [] |
| texts_to_embed = [] |
|
|
| for key, entry in pooler.pool_map.items(): |
| item = entry['item'] |
| if isinstance(item, str): |
| |
| keys_to_embed_text.append(key) |
| texts_to_embed.append(item) |
| elif hasattr(item, 'resize') and not isinstance(item, torch.Tensor): |
| keys_to_embed_img.append(key) |
| images_to_embed.append(item) |
|
|
| |
| if images_to_embed: |
| |
| with torch.no_grad(): |
| img_embeddings = encoder(images_to_embed) |
| |
| |
| for i, (key, emb) in enumerate(zip(keys_to_embed_img, img_embeddings)): |
| if key in pooler.pool_map: |
| old_entry = pooler.pool_map[key] |
| pooler.pool_map[key] = {'item': emb.cpu().clone(), 'idx': old_entry['idx']} |
|
|
| |
| if texts_to_embed: |
| |
| with torch.no_grad(): |
| text_embeddings = encoder(texts_to_embed) |
| |
| |
| for i, (key, emb) in enumerate(zip(keys_to_embed_text, text_embeddings)): |
| if key in pooler.pool_map: |
| old_entry = pooler.pool_map[key] |
| pooler.pool_map[key] = {'item': emb.cpu().clone(), 'idx': old_entry['idx']} |
|
|
|
|
| def __cacheitem_context__(self, idx: int, num_samples_per_token: int = 1, encoder: Optional[Any] = None, forced_cutoff_trade_idx: Optional[int] = None) -> List[Optional[Dict[str, Any]]]: |
| """ |
| Generates fully processed training contexts for caching. |
| |
| This method: |
| 1. Fetches raw token data (like __cacheitem__) |
| 2. Samples T_cutoff(s) using the weight sampling logic |
| 3. Applies H/B/H dynamic sampling based on max_seq_len |
| 4. Returns complete training-ready samples that can be loaded directly |
| |
| This moves ALL non-determinism into cache time, making training fully offline |
| and avoiding caching tokens that would never be seen (98% garbage filtered out |
| by weight sampling and T_cutoff eligibility). |
| |
| Args: |
| idx: Index into sampled_mints |
| num_samples_per_token: Number of different T_cutoff samples to generate per token |
| |
| Returns: |
| List of training-ready samples (may be fewer than num_samples_per_token if |
| some T_cutoffs are invalid) |
| """ |
| import time as _time |
|
|
| if not self.sampled_mints: |
| raise RuntimeError("Dataset has no mint records loaded.") |
| if idx >= len(self.sampled_mints): |
| raise IndexError(f"Index {idx} exceeds mint count {len(self.sampled_mints)}.") |
|
|
| initial_mint_record = self.sampled_mints[idx] |
| t0 = initial_mint_record["timestamp"] |
| if isinstance(t0, datetime.datetime) and t0.tzinfo is None: |
| t0 = t0.replace(tzinfo=datetime.timezone.utc) |
| creator_address = initial_mint_record['creator_address'] |
| token_address = initial_mint_record['mint_address'] |
|
|
| |
|
|
| if not self.fetcher: |
| raise RuntimeError("Dataset has no data fetcher.") |
|
|
| |
| raw_data = self.fetcher.fetch_raw_token_data( |
| token_address=token_address, |
| creator_address=creator_address, |
| mint_timestamp=t0, |
| max_horizon_seconds=self.max_cache_horizon_seconds, |
| include_wallet_data=False, |
| include_graph=False, |
| min_trades=self.min_trades, |
| full_history=True, |
| prune_failed=False, |
| prune_transfers=False |
| ) |
|
|
| if raw_data is None: |
| print(f" SKIP: No raw data for {token_address}") |
| return [] |
|
|
| |
| |
| if idx == 0: |
| print(f" DEBUG: initial_mint_record keys: {list(initial_mint_record.keys())}") |
| print(f" DEBUG: token_name='{initial_mint_record.get('token_name')}', token_symbol='{initial_mint_record.get('token_symbol')}'") |
| raw_data['name'] = initial_mint_record.get('token_name', '') |
| raw_data['symbol'] = initial_mint_record.get('token_symbol', '') |
| raw_data['token_uri'] = initial_mint_record.get('token_uri', '') |
| raw_total_supply = initial_mint_record.get('total_supply', DEFAULT_TOTAL_SUPPLY_RAW) |
| raw_token_decimals = initial_mint_record.get('token_decimals', DEFAULT_TOKEN_DECIMALS) |
| raw_data['total_supply'] = ( |
| int(raw_total_supply) if raw_total_supply and int(raw_total_supply) > 0 else DEFAULT_TOTAL_SUPPLY_RAW |
| ) |
| raw_data['decimals'] = ( |
| int(raw_token_decimals) if raw_token_decimals is not None and int(raw_token_decimals) >= 0 else DEFAULT_TOKEN_DECIMALS |
| ) |
| raw_data['protocol'] = initial_mint_record.get('protocol', 1) |
|
|
| def _timestamp_to_order_value(ts_value) -> float: |
| if isinstance(ts_value, datetime.datetime): |
| if ts_value.tzinfo is None: |
| ts_value = ts_value.replace(tzinfo=datetime.timezone.utc) |
| return ts_value.timestamp() |
| elif isinstance(ts_value, str): |
| try: |
| return datetime.datetime.fromisoformat(ts_value.replace('Z', '+00:00')).timestamp() |
| except ValueError: |
| pass |
| try: |
| return float(ts_value) |
| except: |
| return 0.0 |
|
|
| |
| all_trades_raw = raw_data.get('trades', []) |
| if not all_trades_raw: |
| print(f" SKIP: No trades for {token_address}") |
| return [] |
|
|
| all_trades_sorted = sorted( |
| [t for t in all_trades_raw if t.get('timestamp') is not None], |
| key=lambda t: _timestamp_to_order_value(t.get('timestamp')) |
| ) |
|
|
| min_context_trades = self.min_trades |
| if len(all_trades_sorted) < (min_context_trades + 1): |
| print(f" SKIP: Not enough trades ({len(all_trades_sorted)}) for {token_address}") |
| return [] |
|
|
| |
| successful_indices = [ |
| i for i, t in enumerate(all_trades_sorted) |
| if t.get('success', False) and float(t.get('price_usd', 0) or 0) > 0 |
| ] |
|
|
| if len(successful_indices) < 2: |
| print(f" SKIP: Not enough successful trades for {token_address}") |
| return [] |
|
|
| max_horizon_seconds = max(self.horizons_seconds) if self.horizons_seconds else 0 |
| min_idx = min_context_trades - 1 |
| max_idx = len(all_trades_sorted) - 2 |
|
|
| if max_idx < min_idx: |
| print(f" SKIP: Invalid index range for {token_address}") |
| return [] |
|
|
| |
| last_successful_before = [-1] * len(all_trades_sorted) |
| last_seen = -1 |
| succ_set = set(successful_indices) |
| for i in range(len(all_trades_sorted)): |
| if i in succ_set: |
| last_seen = i |
| last_successful_before[i] = last_seen |
|
|
| next_successful_after = [-1] * len(all_trades_sorted) |
| next_seen = -1 |
| for i in range(len(all_trades_sorted) - 1, -1, -1): |
| if i in succ_set: |
| next_seen = i |
| next_successful_after[i] = next_seen |
|
|
| |
| eligible_indices = [] |
| for i in range(min_idx, max_idx + 1): |
| anchor_idx = last_successful_before[i] |
| next_idx = next_successful_after[i + 1] if i + 1 < len(all_trades_sorted) else -1 |
| if anchor_idx < 0 or next_idx < 0: |
| continue |
| cutoff_ts = _timestamp_to_order_value(all_trades_sorted[i].get('timestamp')) |
| next_ts = _timestamp_to_order_value(all_trades_sorted[next_idx].get('timestamp')) |
| if next_ts <= cutoff_ts + max_horizon_seconds: |
| eligible_indices.append(i) |
|
|
| if not eligible_indices: |
| print(f" SKIP: No eligible T_cutoff indices for {token_address}") |
| return [] |
|
|
| |
|
|
| |
| trades = raw_data.get('trades', []) |
| trade_ts_values = [_timestamp_to_order_value(t.get('timestamp')) for t in trades] |
| t0_val = _timestamp_to_order_value(t0) |
| last_trade_ts_val = max(trade_ts_values) |
|
|
| |
| |
| duration_seconds = int(last_trade_ts_val - t0_val) + 120 |
| raw_data['ohlc_1s'] = None |
|
|
| |
| interval = 300 |
| num_intervals = (duration_seconds // interval) + 1 |
| snapshot_stats = torch.zeros((num_intervals, 6), dtype=torch.float32) |
|
|
| buckets = defaultdict(list) |
| for t in trades: |
| ts = _timestamp_to_order_value(t['timestamp']) |
| bucket_idx = int(ts - t0_val) // interval |
| if bucket_idx >= 0: |
| buckets[bucket_idx].append(t) |
|
|
| raw_total_supply = raw_data.get('total_supply') |
| raw_decimals = raw_data.get('decimals') |
| if raw_total_supply is None or raw_decimals is None: |
| raise RuntimeError("Missing token total_supply/decimals required for holder snapshot reconstruction.") |
| total_supply_raw = int(raw_total_supply) |
| token_decimals = int(raw_decimals) |
| if total_supply_raw <= 0: |
| total_supply_raw = DEFAULT_TOTAL_SUPPLY_RAW |
| if token_decimals < 0: |
| token_decimals = DEFAULT_TOKEN_DECIMALS |
| token_scale = 10 ** token_decimals |
|
|
| def _strict_int(v: Any, field_name: str) -> int: |
| if v is None: |
| raise RuntimeError(f"Missing {field_name} in trade record for {token_address}.") |
| try: |
| return int(v) |
| except Exception as e: |
| raise RuntimeError(f"Invalid {field_name} in trade record for {token_address}: {v}") from e |
|
|
| def _trade_sort_key_for_ledger(trade: Dict[str, Any]) -> tuple: |
| return ( |
| _timestamp_to_order_value(trade.get('timestamp')), |
| _strict_int(trade.get('slot', 0), 'slot'), |
| _strict_int(trade.get('transaction_index', 0), 'transaction_index'), |
| _strict_int(trade.get('instruction_index', 0), 'instruction_index'), |
| str(trade.get('signature') or '') |
| ) |
|
|
| ledger_trades = [] |
| for trade in trades: |
| if not trade.get('success', False): |
| continue |
| maker = trade.get('maker') |
| if not maker: |
| raise RuntimeError(f"Missing maker in successful trade for {token_address}.") |
| trade_type = _strict_int(trade.get('trade_type'), 'trade_type') |
| if trade_type not in (0, 1): |
| raise RuntimeError(f"Invalid trade_type={trade_type} for {token_address}; expected 0/1.") |
| base_amount_raw = _strict_int(trade.get('base_amount'), 'base_amount') |
| if base_amount_raw < 0: |
| raise RuntimeError(f"Invalid negative base_amount={base_amount_raw} for {token_address}.") |
| ledger_trades.append((trade, maker, trade_type, base_amount_raw)) |
|
|
| ledger_trades.sort(key=lambda x: _trade_sort_key_for_ledger(x[0])) |
| wallet_balances_raw: Dict[str, int] = {} |
| ledger_idx = 0 |
| holder_snapshots_list = [] |
|
|
| for i in range(num_intervals): |
| bucket_trades = buckets[i] |
| vol = sum(t.get('total_usd', 0.0) for t in bucket_trades) |
| tx = len(bucket_trades) |
| buys = sum(1 for t in bucket_trades if t.get('trade_direction') == 0 or t.get('trade_type') == 0) |
| sells = tx - buys |
| snapshot_ts_epoch = t0_val + ((i + 1) * interval) |
|
|
| while ledger_idx < len(ledger_trades): |
| trade, maker, trade_type, base_amount_raw = ledger_trades[ledger_idx] |
| trade_ts = _timestamp_to_order_value(trade.get('timestamp')) |
| if trade_ts > snapshot_ts_epoch: |
| break |
| signed_delta = base_amount_raw if trade_type == 0 else -base_amount_raw |
| wallet_balances_raw[maker] = wallet_balances_raw.get(maker, 0) + signed_delta |
| ledger_idx += 1 |
|
|
| positive_holders_raw = [(wallet, bal) for wallet, bal in wallet_balances_raw.items() if bal > 0] |
| positive_holders_raw.sort(key=lambda item: (-item[1], item[0])) |
| holders_topk_raw = positive_holders_raw[:HOLDER_SNAPSHOT_TOP_K] |
| count = len(positive_holders_raw) |
| top10_sum_raw = sum(bal for _, bal in positive_holders_raw[:10]) |
| top10_pct = float(top10_sum_raw) / float(total_supply_raw) |
|
|
| snapshot_stats[i, 0] = float(vol) |
| snapshot_stats[i, 1] = float(tx) |
| snapshot_stats[i, 2] = float(buys) |
| snapshot_stats[i, 3] = float(sells) |
| snapshot_stats[i, 4] = float(count) |
| snapshot_stats[i, 5] = float(top10_pct) |
|
|
| snapshot_ts = t0 + datetime.timedelta(seconds=(i+1)*interval) |
| holder_snapshots_list.append({ |
| 'timestamp': int(snapshot_ts.timestamp()), |
| 'holders': [ |
| { |
| 'wallet_address': wallet, |
| 'current_balance': float(balance_raw) / float(token_scale) |
| } |
| for wallet, balance_raw in holders_topk_raw |
| ] |
| }) |
|
|
| raw_data['snapshots_5m'] = snapshot_stats |
| raw_data['holder_snapshots_list'] = holder_snapshots_list |
| raw_data['protocol_id'] = initial_mint_record.get('protocol') |
|
|
| |
| all_wallets = set() |
| all_wallets.add(creator_address) |
|
|
| for trade in raw_data.get('trades', []): |
| if trade.get('maker'): |
| all_wallets.add(trade['maker']) |
| for transfer in raw_data.get('transfers', []): |
| if transfer.get('source'): |
| all_wallets.add(transfer['source']) |
| if transfer.get('destination'): |
| all_wallets.add(transfer['destination']) |
| for pool in raw_data.get('pool_creations', []): |
| if pool.get('creator_address'): |
| all_wallets.add(pool['creator_address']) |
| for liq in raw_data.get('liquidity_changes', []): |
| if liq.get('lp_provider'): |
| all_wallets.add(liq['lp_provider']) |
| for snapshot in holder_snapshots_list: |
| if not isinstance(snapshot, dict) or not isinstance(snapshot.get('holders'), list): |
| raise RuntimeError("Invalid holder_snapshots_list entry during wallet collection.") |
| for holder in snapshot['holders']: |
| if not isinstance(holder, dict) or 'wallet_address' not in holder or 'current_balance' not in holder: |
| raise RuntimeError("Invalid holder record during wallet collection.") |
| all_wallets.add(holder['wallet_address']) |
|
|
| all_wallets.discard(None) |
| all_wallets.discard('') |
| wallet_list = list(all_wallets) |
|
|
| max_T_cutoff = datetime.datetime.fromtimestamp(last_trade_ts_val, tz=datetime.timezone.utc) |
|
|
| |
| cached_profiles, cached_socials = {}, {} |
| cached_holdings = {} |
| cached_graph_entities, cached_graph_links = {}, {} |
| cached_image_bytes = None |
|
|
| def _fetch_clickhouse_data(): |
| """ClickHouse client is not thread-safe, so both queries run sequentially in one thread.""" |
| profiles, socials = self.fetcher.fetch_wallet_profiles_and_socials(wallet_list, max_T_cutoff) |
| holdings = self.fetcher.fetch_wallet_holdings(wallet_list, max_T_cutoff) |
| return profiles, socials, holdings |
|
|
| def _fetch_graph_data(): |
| return self.fetcher.fetch_graph_links(wallet_list, max_T_cutoff, max_degrees=1) |
|
|
| def _fetch_token_image(): |
| try: |
| bullx_image_url = f"https://image.bullx.io/1399811149/{token_address}?retry=0" |
| resp = self.http_session.get(bullx_image_url, timeout=2) |
| if resp.status_code == 200: |
| return resp.content |
| except: |
| pass |
| return None |
|
|
| with ThreadPoolExecutor(max_workers=3) as executor: |
| future_ch = executor.submit(_fetch_clickhouse_data) |
| future_graph = executor.submit(_fetch_graph_data) |
| future_image = executor.submit(_fetch_token_image) |
|
|
| try: |
| cached_profiles, cached_socials, cached_holdings = future_ch.result() |
| except Exception as e: |
| print(f" WARN: Failed to fetch ClickHouse data: {e}") |
|
|
| try: |
| cached_graph_entities, cached_graph_links = future_graph.result() |
| except Exception as e: |
| print(f" WARN: Failed to fetch graph links: {e}") |
|
|
| cached_image_bytes = future_image.result() |
|
|
| |
| results = [] |
|
|
| |
| if forced_cutoff_trade_idx is not None: |
| |
| if forced_cutoff_trade_idx >= len(all_trades_sorted): |
| print(f" WARN: forced_cutoff_trade_idx={forced_cutoff_trade_idx} >= total trades {len(all_trades_sorted)}, clamping.") |
| forced_cutoff_trade_idx = len(all_trades_sorted) - 2 |
| sampled_indices = [forced_cutoff_trade_idx] |
| print(f" Using forced T_cutoff at trade index {forced_cutoff_trade_idx}") |
| elif num_samples_per_token >= len(eligible_indices): |
| sampled_indices = eligible_indices.copy() |
| else: |
| sampled_indices = random.sample(eligible_indices, num_samples_per_token) |
|
|
| |
|
|
| for sample_num, sample_idx in enumerate(sampled_indices): |
| sample_trade = all_trades_sorted[sample_idx] |
| sample_offset_ts = _timestamp_to_order_value(sample_trade.get('timestamp')) |
| T_cutoff = datetime.datetime.fromtimestamp(sample_offset_ts, tz=datetime.timezone.utc) |
| cutoff_ts = sample_offset_ts |
|
|
| |
| wallets_to_fetch = set() |
| wallets_to_fetch.add(creator_address) |
|
|
| for trade in raw_data.get('trades', []): |
| if _timestamp_to_order_value(trade.get('timestamp')) <= cutoff_ts: |
| if trade.get('maker'): |
| wallets_to_fetch.add(trade['maker']) |
|
|
| for transfer in raw_data.get('transfers', []): |
| if _timestamp_to_order_value(transfer.get('timestamp')) <= cutoff_ts: |
| if transfer.get('source'): |
| wallets_to_fetch.add(transfer['source']) |
| if transfer.get('destination'): |
| wallets_to_fetch.add(transfer['destination']) |
|
|
| for pool in raw_data.get('pool_creations', []): |
| if _timestamp_to_order_value(pool.get('timestamp')) <= cutoff_ts: |
| if pool.get('creator_address'): |
| wallets_to_fetch.add(pool['creator_address']) |
|
|
| for liq in raw_data.get('liquidity_changes', []): |
| if _timestamp_to_order_value(liq.get('timestamp')) <= cutoff_ts: |
| if liq.get('lp_provider'): |
| wallets_to_fetch.add(liq['lp_provider']) |
|
|
| |
| elapsed = (T_cutoff - t0).total_seconds() |
| snap_idx = int(elapsed // 300) |
| if not (0 <= snap_idx < len(holder_snapshots_list)): |
| raise RuntimeError( |
| f"holder_snapshots_list index out of range in __cacheitem_context__ " |
| f"(snap_idx={snap_idx}, len={len(holder_snapshots_list)})." |
| ) |
| snapshot_data = holder_snapshots_list[snap_idx] |
| if not isinstance(snapshot_data, dict) or not isinstance(snapshot_data.get('holders'), list): |
| raise RuntimeError("Invalid holder snapshot entry in __cacheitem_context__.") |
| for holder in snapshot_data['holders']: |
| if not isinstance(holder, dict) or 'wallet_address' not in holder or 'current_balance' not in holder: |
| raise RuntimeError("Invalid holder record in __cacheitem_context__.") |
| wallets_to_fetch.add(holder['wallet_address']) |
|
|
| wallets_to_fetch.discard(None) |
| wallets_to_fetch.discard('') |
|
|
| |
| pooler = EmbeddingPooler() |
|
|
| |
| offline_token_data = {token_address: self._build_main_token_seed(token_address, raw_data)} |
| if cached_image_bytes: |
| try: |
| cached_image = Image.open(BytesIO(cached_image_bytes)) |
| offline_token_data[token_address]['_cached_image_pil'] = cached_image |
| except: |
| pass |
|
|
| main_token_data = self._process_token_data_offline( |
| [token_address], pooler, T_cutoff, token_data=offline_token_data |
| ) |
|
|
| if not main_token_data: |
| continue |
|
|
| |
| wallet_data, all_token_data = self._process_wallet_data( |
| list(wallets_to_fetch), |
| main_token_data.copy(), |
| pooler, |
| T_cutoff, |
| profiles_override=cached_profiles, |
| socials_override=cached_socials, |
| holdings_override=cached_holdings |
| ) |
|
|
| |
| mint_event = { |
| 'event_type': 'Mint', |
| 'timestamp': int(t0.timestamp()), |
| 'relative_ts': 0, |
| 'wallet_address': creator_address, |
| 'token_address': token_address, |
| 'protocol_id': raw_data.get('protocol_id', 0) |
| } |
|
|
| result = self._generate_dataset_item( |
| token_address=token_address, |
| t0=t0, |
| T_cutoff=T_cutoff, |
| mint_event=mint_event, |
| trade_records=raw_data['trades'], |
| transfer_records=raw_data['transfers'], |
| pool_creation_records=raw_data['pool_creations'], |
| liquidity_change_records=raw_data['liquidity_changes'], |
| fee_collection_records=raw_data['fee_collections'], |
| burn_records=raw_data['burns'], |
| supply_lock_records=raw_data['supply_locks'], |
| migration_records=raw_data['migrations'], |
| wallet_data=wallet_data, |
| all_token_data=all_token_data, |
| graph_links=cached_graph_links, |
| graph_seed_entities=wallets_to_fetch, |
| all_graph_entities=cached_graph_entities, |
| future_trades_for_labels=raw_data['trades'], |
| pooler=pooler, |
| sample_idx=idx, |
| cached_holders_list=holder_snapshots_list, |
| cached_ohlc_1s=None, |
| quality_score=None |
| ) |
|
|
| if result is not None: |
| results.append(result) |
| pass |
|
|
| |
| if encoder is not None: |
| |
| for ctx in results: |
| self._embed_context(ctx, encoder) |
| else: |
| if idx == 0: |
| print("DEBUG: No encoder provided to __cacheitem_context__", flush=True) |
|
|
| |
| return results |
|
|