# data_fetcher.py from typing import List, Dict, Any, Tuple, Set, Optional from collections import defaultdict import datetime, time # We need the vocabulary for mapping IDs import models.vocabulary as vocab class DataFetcher: """ A dedicated class to handle all database queries for ClickHouse and Neo4j. This keeps data fetching logic separate from the dataset and model. """ # --- Explicit column definitions for wallet profile & social fetches --- PROFILE_BASE_COLUMNS = [ 'wallet_address', 'updated_at', 'first_seen_ts', 'last_seen_ts', 'tags', 'deployed_tokens', 'funded_from', 'funded_timestamp', 'funded_signature', 'funded_amount' ] PROFILE_METRIC_COLUMNS = [ 'balance', 'transfers_in_count', 'transfers_out_count', 'spl_transfers_in_count', 'spl_transfers_out_count', 'total_buys_count', 'total_sells_count', 'total_winrate', 'stats_1d_realized_profit_sol', 'stats_1d_realized_profit_usd', 'stats_1d_realized_profit_pnl', 'stats_1d_buy_count', 'stats_1d_sell_count', 'stats_1d_transfer_in_count', 'stats_1d_transfer_out_count', 'stats_1d_avg_holding_period', 'stats_1d_total_bought_cost_sol', 'stats_1d_total_bought_cost_usd', 'stats_1d_total_sold_income_sol', 'stats_1d_total_sold_income_usd', 'stats_1d_total_fee', 'stats_1d_winrate', 'stats_1d_tokens_traded', 'stats_7d_realized_profit_sol', 'stats_7d_realized_profit_usd', 'stats_7d_realized_profit_pnl', 'stats_7d_buy_count', 'stats_7d_sell_count', 'stats_7d_transfer_in_count', 'stats_7d_transfer_out_count', 'stats_7d_avg_holding_period', 'stats_7d_total_bought_cost_sol', 'stats_7d_total_bought_cost_usd', 'stats_7d_total_sold_income_sol', 'stats_7d_total_sold_income_usd', 'stats_7d_total_fee', 'stats_7d_winrate', 'stats_7d_tokens_traded', 'stats_30d_realized_profit_sol', 'stats_30d_realized_profit_usd', 'stats_30d_realized_profit_pnl', 'stats_30d_buy_count', 'stats_30d_sell_count', 'stats_30d_transfer_in_count', 'stats_30d_transfer_out_count', 'stats_30d_avg_holding_period', 'stats_30d_total_bought_cost_sol', 'stats_30d_total_bought_cost_usd', 'stats_30d_total_sold_income_sol', 'stats_30d_total_sold_income_usd', 'stats_30d_total_fee', 'stats_30d_winrate', 'stats_30d_tokens_traded' ] PROFILE_COLUMNS_FOR_QUERY = PROFILE_BASE_COLUMNS + PROFILE_METRIC_COLUMNS SOCIAL_COLUMNS_FOR_QUERY = [ 'wallet_address', 'pumpfun_username', 'twitter_username', 'telegram_channel', 'kolscan_name', 'cabalspy_name', 'axiom_kol_name' ] DB_BATCH_SIZE = 5000 def __init__(self, clickhouse_client: Any, neo4j_driver: Any): self.db_client = clickhouse_client self.graph_client = neo4j_driver print("DataFetcher instantiated.") def get_all_mints(self, start_date: Optional[datetime.datetime] = None) -> List[Dict[str, Any]]: """ Fetches a list of all mint events to serve as dataset samples. Can be filtered to only include mints on or after a given start_date. """ query = "SELECT mint_address, timestamp, creator_address, protocol, token_name, token_symbol, token_uri, total_supply, token_decimals FROM mints" params = {} where_clauses = [] if start_date: where_clauses.append("timestamp >= %(start_date)s") params['start_date'] = start_date if where_clauses: query += " WHERE " + " AND ".join(where_clauses) try: rows, columns_info = self.db_client.execute(query, params, with_column_types=True) if not rows: return [] columns = [col[0] for col in columns_info] result = [dict(zip(columns, row)) for row in rows] if not result: return [] return result except Exception as e: print(f"ERROR: Failed to fetch token addresses from ClickHouse: {e}") print("INFO: Falling back to mock token addresses for development.") return [{'mint_address': 'tknA_real', 'timestamp': datetime.datetime.now(datetime.timezone.utc), 'creator_address': 'addr_Creator_Real', 'protocol': 0}] def fetch_mint_record(self, token_address: str) -> Dict[str, Any]: """ Fetches the raw mint record for a token from the 'mints' table. """ query = f"SELECT timestamp, creator_address, mint_address, protocol FROM mints WHERE mint_address = '{token_address}' ORDER BY timestamp ASC LIMIT 1" # Assumes the client returns a list of dicts or can be converted # Using column names from your schema columns = ['timestamp', 'creator_address', 'mint_address', 'protocol'] try: result = self.db_client.execute(query) if not result or not result[0]: raise ValueError(f"No mint event found for token {token_address}") # Convert the tuple result into a dictionary record = dict(zip(columns, result[0])) return record except Exception as e: print(f"ERROR: Failed to fetch mint record for {token_address}: {e}") print("INFO: Falling back to mock mint record for development.") # Fallback for development if DB connection fails return { 'timestamp': datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(days=1), 'creator_address': 'addr_Creator_Real', 'mint_address': token_address, 'protocol': vocab.PROTOCOL_TO_ID.get("Pump V1", 0) } def fetch_wallet_profiles(self, wallet_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, Dict[str, Any]]: """ Convenience wrapper around fetch_wallet_profiles_and_socials for profile-only data. """ profiles, _ = self.fetch_wallet_profiles_and_socials(wallet_addresses, T_cutoff) return profiles def fetch_wallet_socials(self, wallet_addresses: List[str]) -> Dict[str, Dict[str, Any]]: """ Fetches wallet social records for a list of wallet addresses. Batches queries to avoid "Max query size exceeded" errors. Returns a dictionary mapping wallet_address to its social data. """ if not wallet_addresses: return {} BATCH_SIZE = self.DB_BATCH_SIZE socials = {} for i in range(0, len(wallet_addresses), BATCH_SIZE): batch_addresses = wallet_addresses[i : i + BATCH_SIZE] query = "SELECT * FROM wallet_socials WHERE wallet_address IN %(addresses)s" params = {'addresses': batch_addresses} try: rows, columns_info = self.db_client.execute(query, params, with_column_types=True) if not rows: continue columns = [col[0] for col in columns_info] for row in rows: social_dict = dict(zip(columns, row)) wallet_addr = social_dict.get('wallet_address') if wallet_addr: socials[wallet_addr] = social_dict except Exception as e: print(f"ERROR: Failed to fetch wallet socials for batch {i}: {e}") # Continue to next batch return socials def fetch_wallet_profiles_and_socials(self, wallet_addresses: List[str], T_cutoff: datetime.datetime) -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]]: """ Fetches wallet profiles (time-aware) and socials for all requested wallets. Batches queries to avoid "Max query size exceeded" errors. Returns two dictionaries: profiles, socials. """ if not wallet_addresses: return {}, {} social_columns = self.SOCIAL_COLUMNS_FOR_QUERY profile_base_cols = self.PROFILE_BASE_COLUMNS profile_metric_cols = self.PROFILE_METRIC_COLUMNS profile_base_str = ",\n ".join(profile_base_cols) metric_projection_cols = ['wallet_address', 'updated_at'] + profile_metric_cols profile_metric_str = ",\n ".join(metric_projection_cols) profile_base_select_cols = [col for col in profile_base_cols if col != 'wallet_address'] profile_metric_select_cols = [ col for col in profile_metric_cols if col not in ('wallet_address',) ] social_select_cols = [col for col in social_columns if col != 'wallet_address'] select_expressions = [] for col in profile_base_select_cols: select_expressions.append(f"lp.{col} AS profile__{col}") for col in profile_metric_select_cols: select_expressions.append(f"lm.{col} AS profile__{col}") for col in social_select_cols: select_expressions.append(f"ws.{col} AS social__{col}") select_clause = "" if select_expressions: select_clause = ",\n " + ",\n ".join(select_expressions) profile_keys = [f"profile__{col}" for col in (profile_base_select_cols + profile_metric_select_cols)] social_keys = [f"social__{col}" for col in social_select_cols] BATCH_SIZE = self.DB_BATCH_SIZE all_profiles = {} all_socials = {} for i in range(0, len(wallet_addresses), BATCH_SIZE): batch_addresses = wallet_addresses[i : i + BATCH_SIZE] query = f""" WITH ranked_profiles AS ( SELECT {profile_base_str}, ROW_NUMBER() OVER (PARTITION BY wallet_address ORDER BY updated_at DESC) AS rn FROM wallet_profiles WHERE wallet_address IN %(addresses)s ), latest_profiles AS ( SELECT {profile_base_str} FROM ranked_profiles WHERE rn = 1 ), ranked_metrics AS ( SELECT {profile_metric_str}, ROW_NUMBER() OVER (PARTITION BY wallet_address ORDER BY updated_at DESC) AS rn FROM wallet_profile_metrics WHERE wallet_address IN %(addresses)s AND updated_at <= %(T_cutoff)s ), latest_metrics AS ( SELECT {profile_metric_str} FROM ranked_metrics WHERE rn = 1 ), requested_wallets AS ( SELECT DISTINCT wallet_address FROM (SELECT arrayJoin(%(addresses)s) AS wallet_address) ) SELECT rw.wallet_address AS wallet_address {select_clause} FROM requested_wallets AS rw LEFT JOIN latest_profiles AS lp ON rw.wallet_address = lp.wallet_address LEFT JOIN latest_metrics AS lm ON rw.wallet_address = lm.wallet_address LEFT JOIN wallet_socials AS ws ON rw.wallet_address = ws.wallet_address; """ params = {'addresses': batch_addresses, 'T_cutoff': T_cutoff} try: rows, columns_info = self.db_client.execute(query, params, with_column_types=True) if not rows: continue columns = [col[0] for col in columns_info] for row in rows: row_dict = dict(zip(columns, row)) wallet_addr = row_dict.get('wallet_address') if not wallet_addr: continue profile_data = {} if profile_keys: for pref_key in profile_keys: if pref_key in row_dict: value = row_dict[pref_key] profile_data[pref_key.replace('profile__', '')] = value if profile_data and any(value is not None for value in profile_data.values()): profile_data['wallet_address'] = wallet_addr all_profiles[wallet_addr] = profile_data social_data = {} if social_keys: for pref_key in social_keys: if pref_key in row_dict: value = row_dict[pref_key] social_data[pref_key.replace('social__', '')] = value if social_data and any(value is not None for value in social_data.values()): social_data['wallet_address'] = wallet_addr all_socials[wallet_addr] = social_data except Exception as e: print(f"ERROR: Combined profile/social query failed for batch {i}-{i+BATCH_SIZE}: {e}") # We continue to the next batch return all_profiles, all_socials def fetch_wallet_holdings(self, wallet_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, List[Dict[str, Any]]]: """ Fetches top 2 wallet holding records for a list of wallet addresses that were active at T_cutoff. Batches queries to avoid "Max query size exceeded" errors. Returns a dictionary mapping wallet_address to a LIST of its holding data. """ if not wallet_addresses: return {} BATCH_SIZE = self.DB_BATCH_SIZE holdings = defaultdict(list) for i in range(0, len(wallet_addresses), BATCH_SIZE): batch_addresses = wallet_addresses[i : i + BATCH_SIZE] # --- Time-aware query --- # 1. For each holding, find the latest state at or before T_cutoff. # 2. Filter for holdings where the balance was greater than 0. # 3. Rank these active holdings by USD volume and take the top 2 per wallet. query = """ WITH point_in_time_holdings AS ( SELECT *, COALESCE(history_bought_cost_sol, 0) + COALESCE(history_sold_income_sol, 0) AS total_volume_usd, ROW_NUMBER() OVER(PARTITION BY wallet_address, mint_address ORDER BY updated_at DESC) as rn_per_holding FROM wallet_holdings WHERE wallet_address IN %(addresses)s AND updated_at <= %(T_cutoff)s ), ranked_active_holdings AS ( SELECT *, ROW_NUMBER() OVER(PARTITION BY wallet_address ORDER BY total_volume_usd DESC) as rn_per_wallet FROM point_in_time_holdings WHERE rn_per_holding = 1 AND current_balance > 0 ) SELECT * FROM ranked_active_holdings WHERE rn_per_wallet <= 2; """ params = {'addresses': batch_addresses, 'T_cutoff': T_cutoff} try: rows, columns_info = self.db_client.execute(query, params, with_column_types=True) if not rows: continue columns = [col[0] for col in columns_info] for row in rows: holding_dict = dict(zip(columns, row)) wallet_addr = holding_dict.get('wallet_address') if wallet_addr: holdings[wallet_addr].append(holding_dict) except Exception as e: print(f"ERROR: Failed to fetch wallet holdings for batch {i}: {e}") # Continue to next batch return dict(holdings) def fetch_graph_links(self, initial_addresses: List[str], T_cutoff: datetime.datetime, max_degrees: int = 1) -> Tuple[Dict[str, str], Dict[str, Dict[str, Any]]]: """ Fetches graph links from Neo4j, traversing up to a max degree of separation. Args: initial_addresses: A list of starting wallet or token addresses. max_degrees: The maximum number of hops to traverse in the graph. Returns: A tuple containing: - A dictionary mapping entity addresses to their type ('Wallet' or 'Token'). - A dictionary of aggregated links, structured for the GraphUpdater. """ if not initial_addresses: return {}, {} cutoff_ts = int(T_cutoff.timestamp()) max_retries = 3 backoff_sec = 2 for attempt in range(max_retries + 1): try: with self.graph_client.session() as session: all_entities = {addr: 'Token' for addr in initial_addresses} # Assume initial are tokens newly_found_entities = set(initial_addresses) aggregated_links = defaultdict(lambda: {'links': [], 'edges': []}) for i in range(max_degrees): if not newly_found_entities: break # --- TIMING: Query execution --- _t_query_start = time.perf_counter() # Cypher query to find direct neighbors of the current frontier # OPTIMIZED: Filter by timestamp IN Neo4j to avoid transferring 97%+ unused records query = """ MATCH (a)-[r]-(b) WHERE a.address IN $addresses AND r.timestamp <= $cutoff_ts RETURN a.address AS source_address, type(r) AS link_type, properties(r) AS link_props, b.address AS dest_address, labels(b)[0] AS dest_type LIMIT 10000 """ params = {'addresses': list(newly_found_entities), 'cutoff_ts': cutoff_ts} result = session.run(query, params) _t_query_done = time.perf_counter() # --- TIMING: Result processing --- _t_process_start = time.perf_counter() records_total = 0 current_degree_new_entities = set() for record in result: records_total += 1 link_type = record['link_type'] link_props = dict(record['link_props']) source_addr = record['source_address'] dest_addr = record['dest_address'] dest_type = record['dest_type'] # Add the link and edge data aggregated_links[link_type]['links'].append(link_props) aggregated_links[link_type]['edges'].append((source_addr, dest_addr)) # If we found a new entity, add it to the set for the next iteration if dest_addr not in all_entities.keys(): current_degree_new_entities.add(dest_addr) all_entities[dest_addr] = dest_type _t_process_done = time.perf_counter() newly_found_entities = current_degree_new_entities # --- Post-process: rename, map props, strip, cap --- MAX_LINKS_PER_TYPE = 500 # Neo4j type -> collator type name _NEO4J_TO_COLLATOR_NAME = { 'TRANSFERRED_TO': 'TransferLink', 'BUNDLE_TRADE': 'BundleTradeLink', 'COPIED_TRADE': 'CopiedTradeLink', 'COORDINATED_ACTIVITY': 'CoordinatedActivityLink', 'SNIPED': 'SnipedLink', 'MINTED': 'MintedLink', 'LOCKED_SUPPLY': 'LockedSupplyLink', 'BURNED': 'BurnedLink', 'PROVIDED_LIQUIDITY': 'ProvidedLiquidityLink', 'WHALE_OF': 'WhaleOfLink', 'TOP_TRADER_OF': 'TopTraderOfLink', } # Neo4j prop name -> encoder prop name (for fields with mismatched names) _PROP_REMAP = { 'CopiedTradeLink': { 'buy_gap': 'time_gap_on_buy_sec', 'sell_gap': 'time_gap_on_sell_sec', 'f_buy_total': 'follower_buy_total', 'f_sell_total': 'follower_sell_total', 'leader_pnl': 'leader_pnl', 'follower_pnl': 'follower_pnl', }, } # Only keep fields each encoder actually reads _NEEDED_FIELDS = { 'TransferLink': ['amount', 'mint'], 'BundleTradeLink': ['signatures'], # Neo4j has no total_amount; we derive it below 'CopiedTradeLink': ['time_gap_on_buy_sec', 'time_gap_on_sell_sec', 'leader_pnl', 'follower_pnl', 'follower_buy_total', 'follower_sell_total'], 'CoordinatedActivityLink': ['time_gap_on_first_sec', 'time_gap_on_second_sec'], 'SnipedLink': ['rank', 'sniped_amount'], 'MintedLink': ['buy_amount'], 'LockedSupplyLink': ['amount'], 'BurnedLink': ['amount'], 'ProvidedLiquidityLink': ['amount_quote'], 'WhaleOfLink': ['holding_pct_at_creation'], 'TopTraderOfLink': ['pnl_at_creation'], } cleaned_links = {} for neo4j_type, data in aggregated_links.items(): collator_name = _NEO4J_TO_COLLATOR_NAME.get(neo4j_type) if not collator_name: continue # Skip unknown link types links = data['links'] edges = data['edges'] # Cap links = links[:MAX_LINKS_PER_TYPE] edges = edges[:MAX_LINKS_PER_TYPE] # Remap property names if needed remap = _PROP_REMAP.get(collator_name) if remap: links = [{remap.get(k, k): v for k, v in l.items()} for l in links] # Strip to only needed fields needed = _NEEDED_FIELDS.get(collator_name, []) links = [{f: l.get(f, 0) for f in needed} for l in links] # BundleTradeLink: Neo4j has no total_amount; derive from signatures count if collator_name == 'BundleTradeLink': links = [{'total_amount': len(l.get('signatures', []) if isinstance(l.get('signatures'), list) else [])} for l in links] cleaned_links[collator_name] = {'links': links, 'edges': edges} return all_entities, cleaned_links except Exception as e: msg = str(e) is_rate_limit = "AuthenticationRateLimit" in msg or "RateLimit" in msg is_transient = "ServiceUnavailable" in msg or "TransientError" in msg or "SessionExpired" in msg if is_rate_limit or is_transient: if attempt < max_retries: sleep_time = backoff_sec * (2 ** attempt) print(f"WARN: Neo4j error ({type(e).__name__}). Retrying in {sleep_time}s... (Attempt {attempt+1}/{max_retries})") time.sleep(sleep_time) continue # If we're here, it's either not retryable or we ran out of retries # Ensure we use "FATAL" prefix so the caller knows to stop if required raise RuntimeError(f"FATAL: Failed to fetch graph links from Neo4j: {e}") from e def fetch_token_data(self, token_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, Dict[str, Any]]: """ Fetches the latest token data for each address at or before T_cutoff. Batches queries to avoid "Max query size exceeded" errors. Returns a dictionary mapping token_address to its data. """ if not token_addresses: return {} BATCH_SIZE = self.DB_BATCH_SIZE tokens = {} for i in range(0, len(token_addresses), BATCH_SIZE): batch_addresses = token_addresses[i : i + BATCH_SIZE] # --- NEW: Time-aware query for historical token data --- query = """ WITH ranked_tokens AS ( SELECT *, ROW_NUMBER() OVER (PARTITION BY token_address ORDER BY updated_at DESC) as rn FROM tokens WHERE token_address IN %(addresses)s AND updated_at <= %(T_cutoff)s ) SELECT token_address, name, symbol, token_uri, protocol, total_supply, decimals FROM ranked_tokens WHERE rn = 1; """ params = {'addresses': batch_addresses, 'T_cutoff': T_cutoff} try: rows, columns_info = self.db_client.execute(query, params, with_column_types=True) if not rows: continue # Get column names from the query result description columns = [col[0] for col in columns_info] for row in rows: token_dict = dict(zip(columns, row)) token_addr = token_dict.get('token_address') if token_addr: # The 'tokens' table in the schema has 'token_address' but the # collator expects 'address'. We'll add it for compatibility. token_dict['address'] = token_addr tokens[token_addr] = token_dict except Exception as e: print(f"ERROR: Failed to fetch token data for batch {i}: {e}") # Continue next batch return tokens def fetch_deployed_token_details(self, token_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, Dict[str, Any]]: """ Fetches historical details for deployed tokens at or before T_cutoff. Batches queries to avoid "Max query size exceeded" errors. """ if not token_addresses: return {} BATCH_SIZE = self.DB_BATCH_SIZE token_details = {} total_tokens = len(token_addresses) for i in range(0, total_tokens, BATCH_SIZE): batch_addresses = token_addresses[i : i + BATCH_SIZE] # --- NEW: Time-aware query for historical deployed token details --- query = """ WITH ranked_tokens AS ( SELECT *, ROW_NUMBER() OVER (PARTITION BY token_address ORDER BY updated_at DESC) as rn FROM tokens WHERE token_address IN %(addresses)s AND updated_at <= %(T_cutoff)s ), ranked_token_metrics AS ( SELECT token_address, ath_price_usd, ROW_NUMBER() OVER (PARTITION BY token_address ORDER BY updated_at DESC) as rn FROM token_metrics WHERE token_address IN %(addresses)s AND updated_at <= %(T_cutoff)s ), latest_tokens AS ( SELECT * FROM ranked_tokens WHERE rn = 1 ), latest_token_metrics AS ( SELECT * FROM ranked_token_metrics WHERE rn = 1 ) SELECT lt.token_address, lt.created_at, lt.updated_at, ltm.ath_price_usd, lt.total_supply, lt.decimals, (lt.launchpad != lt.protocol) AS has_migrated FROM latest_tokens AS lt LEFT JOIN latest_token_metrics AS ltm ON lt.token_address = ltm.token_address; """ params = {'addresses': batch_addresses, 'T_cutoff': T_cutoff} try: rows, columns_info = self.db_client.execute(query, params, with_column_types=True) if not rows: continue columns = [col[0] for col in columns_info] for row in rows: token_details[row[0]] = dict(zip(columns, row)) except Exception as e: print(f"ERROR: Failed to fetch deployed token details for batch {i}: {e}") # Continue next batch return token_details def fetch_trades_for_token(self, token_address: str, T_cutoff: datetime.datetime, count_threshold: int, early_limit: int, recent_limit: int, full_history: bool = False) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]: """ Fetches ALL trades for a token up to T_cutoff, ordered by time. Notes: - This intentionally does NOT apply the older fetch-time H/B/H (High-Def / Blurry / High-Def) sampling logic. Sequence-length control is handled later in data_loader.py via event-level head/tail sampling with MIDDLE/RECENT markers. - The function signature still includes legacy H/B/H parameters for compatibility. Returns: (all_trades, [], []) """ if not token_address: return [], [], [] params = {'token_address': token_address, 'T_cutoff': T_cutoff} query = "SELECT * FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s ORDER BY timestamp ASC" try: rows, columns_info = self.db_client.execute(query, params, with_column_types=True) if not rows: return [], [], [] columns = [col[0] for col in columns_info] all_trades = [dict(zip(columns, row)) for row in rows] return all_trades, [], [] except Exception as e: print(f"ERROR: Failed to fetch trades for token {token_address}: {e}") return [], [], [] def fetch_future_trades_for_token(self, token_address: str, start_ts: datetime.datetime, end_ts: datetime.datetime) -> List[Dict[str, Any]]: """ Fetches successful trades for a token in the window (start_ts, end_ts]. Used for constructing label targets beyond the cutoff. """ if not token_address or start_ts is None or end_ts is None or start_ts >= end_ts: return [] query = """ SELECT * FROM trades WHERE base_address = %(token_address)s AND success = true AND timestamp > %(start_ts)s AND timestamp <= %(end_ts)s ORDER BY timestamp ASC """ params = { 'token_address': token_address, 'start_ts': start_ts, 'end_ts': end_ts } try: rows, columns_info = self.db_client.execute(query, params, with_column_types=True) if not rows: return [] columns = [col[0] for col in columns_info] return [dict(zip(columns, row)) for row in rows] except Exception as e: print(f"ERROR: Failed to fetch future trades for token {token_address}: {e}") return [] def fetch_transfers_for_token(self, token_address: str, T_cutoff: datetime.datetime, min_amount_threshold: float = 10_000_000) -> List[Dict[str, Any]]: """ Fetches all transfers for a token before T_cutoff, filtering out small amounts. """ if not token_address: return [] query = """ SELECT * FROM transfers WHERE mint_address = %(token_address)s AND timestamp <= %(T_cutoff)s AND amount_decimal >= %(min_amount)s ORDER BY timestamp ASC """ params = {'token_address': token_address, 'T_cutoff': T_cutoff, 'min_amount': min_amount_threshold} try: # This query no longer uses H/B/H, it fetches all significant transfers rows, columns_info = self.db_client.execute(query, params, with_column_types=True) if not rows: return [] columns = [col[0] for col in columns_info] return [dict(zip(columns, row)) for row in rows] except Exception as e: print(f"ERROR: Failed to fetch transfers for token {token_address}: {e}") return [] def fetch_pool_creations_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]: """ Fetches pool creation records where the token is the base asset. """ if not token_address: return [] query = """ SELECT signature, timestamp, slot, success, error, priority_fee, protocol, creator_address, pool_address, base_address, quote_address, lp_token_address, initial_base_liquidity, initial_quote_liquidity, base_decimals, quote_decimals FROM pool_creations WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s ORDER BY timestamp ASC """ params = {'token_address': token_address, 'T_cutoff': T_cutoff} # print(f"INFO: Fetching pool creation events for {token_address}.") try: rows, columns_info = self.db_client.execute(query, params, with_column_types=True) if not rows: return [] columns = [col[0] for col in columns_info] return [dict(zip(columns, row)) for row in rows] except Exception as e: print(f"ERROR: Failed to fetch pool creations for token {token_address}: {e}") return [] def fetch_liquidity_changes_for_pools(self, pool_addresses: List[str], T_cutoff: datetime.datetime) -> List[Dict[str, Any]]: """ Fetches liquidity change records for the given pools up to T_cutoff. """ if not pool_addresses: return [] query = """ SELECT signature, timestamp, slot, success, error, priority_fee, protocol, change_type, lp_provider, pool_address, base_amount, quote_amount FROM liquidity WHERE pool_address IN %(pool_addresses)s AND timestamp <= %(T_cutoff)s ORDER BY timestamp ASC """ params = {'pool_addresses': pool_addresses, 'T_cutoff': T_cutoff} # print(f"INFO: Fetching liquidity change events for {len(pool_addresses)} pool(s).") try: rows, columns_info = self.db_client.execute(query, params, with_column_types=True) if not rows: return [] columns = [col[0] for col in columns_info] return [dict(zip(columns, row)) for row in rows] except Exception as e: print(f"ERROR: Failed to fetch liquidity changes for pools {pool_addresses}: {e}") return [] def fetch_fee_collections_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]: """ Fetches fee collection events where the token appears as either token_0 or token_1. """ if not token_address: return [] query = """ SELECT timestamp, signature, slot, success, error, priority_fee, protocol, recipient_address, token_0_mint_address, token_0_amount, token_1_mint_address, token_1_amount FROM fee_collections WHERE (token_0_mint_address = %(token)s OR token_1_mint_address = %(token)s) AND timestamp <= %(T_cutoff)s ORDER BY timestamp ASC """ params = {'token': token_address, 'T_cutoff': T_cutoff} # print(f"INFO: Fetching fee collection events for {token_address}.") try: rows, columns_info = self.db_client.execute(query, params, with_column_types=True) if not rows: return [] columns = [col[0] for col in columns_info] return [dict(zip(columns, row)) for row in rows] except Exception as e: print(f"ERROR: Failed to fetch fee collections for token {token_address}: {e}") return [] def fetch_migrations_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]: """ Fetches migration records for a given token up to T_cutoff. """ if not token_address: return [] query = """ SELECT timestamp, signature, slot, success, error, priority_fee, protocol, mint_address, virtual_pool_address, pool_address, migrated_base_liquidity, migrated_quote_liquidity FROM migrations WHERE mint_address = %(token)s AND timestamp <= %(T_cutoff)s ORDER BY timestamp ASC """ params = {'token': token_address, 'T_cutoff': T_cutoff} # print(f"INFO: Fetching migrations for {token_address}.") try: rows, columns_info = self.db_client.execute(query, params, with_column_types=True) if not rows: return [] columns = [col[0] for col in columns_info] return [dict(zip(columns, row)) for row in rows] except Exception as e: print(f"ERROR: Failed to fetch migrations for token {token_address}: {e}") return [] def fetch_burns_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]: """ Fetches burn events for a given token up to T_cutoff. Schema: burns(timestamp, signature, slot, success, error, priority_fee, mint_address, source, amount, amount_decimal, source_balance) """ if not token_address: return [] query = """ SELECT timestamp, signature, slot, success, error, priority_fee, mint_address, source, amount, amount_decimal, source_balance FROM burns WHERE mint_address = %(token)s AND timestamp <= %(T_cutoff)s ORDER BY timestamp ASC """ params = {'token': token_address, 'T_cutoff': T_cutoff} # print(f"INFO: Fetching burn events for {token_address}.") try: rows, columns_info = self.db_client.execute(query, params, with_column_types=True) if not rows: return [] columns = [col[0] for col in columns_info] return [dict(zip(columns, row)) for row in rows] except Exception as e: print(f"ERROR: Failed to fetch burns for token {token_address}: {e}") return [] def fetch_supply_locks_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]: """ Fetches supply lock events for a given token up to T_cutoff. Schema: supply_locks(timestamp, signature, slot, success, error, priority_fee, protocol, contract_address, sender, recipient, mint_address, total_locked_amount, final_unlock_timestamp) """ if not token_address: return [] query = """ SELECT timestamp, signature, slot, success, error, priority_fee, protocol, contract_address, sender, recipient, mint_address, total_locked_amount, final_unlock_timestamp FROM supply_locks WHERE mint_address = %(token)s AND timestamp <= %(T_cutoff)s ORDER BY timestamp ASC """ params = {'token': token_address, 'T_cutoff': T_cutoff} # print(f"INFO: Fetching supply lock events for {token_address}.") try: rows, columns_info = self.db_client.execute(query, params, with_column_types=True) if not rows: return [] columns = [col[0] for col in columns_info] return [dict(zip(columns, row)) for row in rows] except Exception as e: print(f"ERROR: Failed to fetch supply locks for token {token_address}: {e}") return [] def fetch_token_holders_for_snapshot(self, token_address: str, T_cutoff: datetime.datetime, limit: int = 200) -> List[Dict[str, Any]]: """ Fetch top holders for a token at or before T_cutoff for snapshot purposes. Reconstructs holdings from trades table (buys - sells) since wallet_holdings may not have full point-in-time history. Returns rows with wallet_address and current_balance (>0), ordered by balance desc. """ if not token_address: return [] query = """ SELECT maker as wallet_address, SUM(CASE WHEN trade_type = 0 THEN toInt64(base_amount) ELSE -toInt64(base_amount) END) / 1000000.0 as current_balance FROM trades WHERE base_address = %(token)s AND timestamp <= %(T_cutoff)s AND success = 1 GROUP BY maker HAVING current_balance > 0 ORDER BY current_balance DESC LIMIT %(limit)s; """ params = {'token': token_address, 'T_cutoff': T_cutoff, 'limit': int(limit)} try: rows, columns_info = self.db_client.execute(query, params, with_column_types=True) if not rows: return [] columns = [col[0] for col in columns_info] return [dict(zip(columns, row)) for row in rows] except Exception as e: print(f"ERROR: Failed to fetch token holders for {token_address}: {e}") return [] def fetch_total_holders_count_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> int: """ Returns the total number of wallets holding the token (balance > 0) at or before T_cutoff. Reconstructs from trades table. """ if not token_address: return 0 query = """ SELECT count() FROM ( SELECT maker FROM trades WHERE base_address = %(token)s AND timestamp <= %(T_cutoff)s AND success = 1 GROUP BY maker HAVING SUM(CASE WHEN trade_type = 0 THEN toInt64(base_amount) ELSE -toInt64(base_amount) END) > 0 ); """ params = {'token': token_address, 'T_cutoff': T_cutoff} try: rows = self.db_client.execute(query, params) if not rows: return 0 return int(rows[0][0]) except Exception as e: print(f"ERROR: Failed to count total holders for token {token_address}: {e}") return 0 def fetch_holder_snapshot_stats_for_token(self, token_address: str, T_cutoff: datetime.datetime, limit: int = 200) -> Tuple[int, List[Dict[str, Any]]]: """ Fetch total holder count and top holders at a point in time. Returns (count, top_holders_list). Uses the indexed wallet_holdings table directly - efficient due to mint_address filter. """ if not token_address: return 0, [] # Fetch actual holder data top_holders = self.fetch_token_holders_for_snapshot(token_address, T_cutoff, limit) holder_count = self.fetch_total_holders_count_for_token(token_address, T_cutoff) return holder_count, top_holders def fetch_raw_token_data( self, token_address: str, creator_address: str, mint_timestamp: datetime.datetime, max_horizon_seconds: int = 3600, include_wallet_data: bool = True, include_graph: bool = True, min_trades: int = 0, full_history: bool = False, prune_failed: bool = False, prune_transfers: bool = False ) -> Optional[Dict[str, Any]]: """ Fetches ALL available data for a token up to the maximum horizon. This data is agnostic of T_cutoff and will be masked/filtered dynamically during training. Wallet/graph data can be skipped to avoid caching T_cutoff-dependent features. Args: full_history: If True, fetches ALL trades ignoring H/B/H limits. prune_failed: If True, filters out failed trades from the result. prune_transfers: If True, skips fetching transfers entirely. """ # 1. Calculate the absolute maximum timestamp we care about (mint + max_horizon) # We fetch everything up to this point. max_limit_time = mint_timestamp + datetime.timedelta(seconds=max_horizon_seconds) # 2. Fetch all trades up to max_limit_time # Note: We pass None as T_cutoff to fetch_trades_for_token if we want *everything*, # but here we likely want to bound it by our max training horizon to avoid fetching months of data. # However, the existing method signature expects T_cutoff. # So we pass max_limit_time as the "cutoff" for the purpose of raw data collection. # We use a large enough limit to get all relevant trades for the session # If full_history is True, these limits are ignored inside the method. early_trades, middle_trades, recent_trades = self.fetch_trades_for_token( token_address, max_limit_time, 30000, 10000, 15000, full_history=full_history ) # Combine and deduplicate trades all_trades = {} for t in early_trades + middle_trades + recent_trades: # key: (slot, tx_idx, instr_idx) key = (t.get('slot'), t.get('transaction_index'), t.get('instruction_index'), t.get('signature')) all_trades[key] = t sorted_trades = sorted(list(all_trades.values()), key=lambda x: x['timestamp']) # --- PRUNING FAILED TRADES --- if prune_failed: original_count = len(sorted_trades) sorted_trades = [t for t in sorted_trades if t.get('success', False)] if len(sorted_trades) < original_count: # print(f" INFO: Pruned {original_count - len(sorted_trades)} failed trades.") pass if len(sorted_trades) < min_trades: print(f" SKIP: Token {token_address} has only {len(sorted_trades)} trades (min required: {min_trades}). skipping fetches.") return None # 3. Fetch other events # --- PRUNING TRANSFERS --- if prune_transfers: transfers = [] # print(" INFO: Pruning transfers (skipping fetch).") else: transfers = self.fetch_transfers_for_token(token_address, max_limit_time, 0.0) # 0.0 means fetch all pool_creations = self.fetch_pool_creations_for_token(token_address, max_limit_time) # Collect pool addresses to fetch liquidity changes pool_addresses = [p['pool_address'] for p in pool_creations if p.get('pool_address')] liquidity_changes = [] if pool_addresses: liquidity_changes = self.fetch_liquidity_changes_for_pools(pool_addresses, max_limit_time) fee_collections = self.fetch_fee_collections_for_token(token_address, max_limit_time) burns = self.fetch_burns_for_token(token_address, max_limit_time) supply_locks = self.fetch_supply_locks_for_token(token_address, max_limit_time) migrations = self.fetch_migrations_for_token(token_address, max_limit_time) profile_data = {} social_data = {} holdings_data = {} deployed_token_details = {} fetched_graph_entities = {} graph_links = {} unique_wallets = set() if include_wallet_data or include_graph: # Identify wallets that interacted with the token up to max_limit_time. unique_wallets.add(creator_address) for t in sorted_trades: if t.get('maker'): unique_wallets.add(t['maker']) for t in transfers: if t.get('source'): unique_wallets.add(t['source']) if t.get('destination'): unique_wallets.add(t['destination']) for p in pool_creations: if p.get('creator_address'): unique_wallets.add(p['creator_address']) for l in liquidity_changes: if l.get('lp_provider'): unique_wallets.add(l['lp_provider']) if include_wallet_data and unique_wallets: # Profiles/holdings are time-dependent; only fetch if explicitly requested. profile_data, social_data = self.fetch_wallet_profiles_and_socials(list(unique_wallets), max_limit_time) holdings_data = self.fetch_wallet_holdings(list(unique_wallets), max_limit_time) all_deployed_tokens = set() for profile in profile_data.values(): all_deployed_tokens.update(profile.get('deployed_tokens', [])) if all_deployed_tokens: deployed_token_details = self.fetch_deployed_token_details(list(all_deployed_tokens), max_limit_time) if include_graph and unique_wallets: graph_seed_wallets = list(unique_wallets) if len(graph_seed_wallets) > 100: pass fetched_graph_entities, graph_links = self.fetch_graph_links( graph_seed_wallets, max_limit_time, max_degrees=1 ) return { "token_address": token_address, "creator_address": creator_address, "mint_timestamp": mint_timestamp, "max_limit_time": max_limit_time, "trades": sorted_trades, "transfers": transfers, "pool_creations": pool_creations, "liquidity_changes": liquidity_changes, "fee_collections": fee_collections, "burns": burns, "supply_locks": supply_locks, "migrations": migrations, "profiles": profile_data, "socials": social_data, "holdings": holdings_data, "deployed_token_details": deployed_token_details, "graph_entities": fetched_graph_entities, "graph_links": graph_links }