# oracle/data/data_fetcher.py
# (Hugging Face Hub artifact: uploaded by zirobtc via huggingface_hub, commit a547253 verified)
# data_fetcher.py
from typing import List, Dict, Any, Tuple, Set, Optional
from collections import defaultdict
import datetime, time
# We need the vocabulary for mapping IDs
import models.vocabulary as vocab
class DataFetcher:
    """
    A dedicated class to handle all database queries for ClickHouse and Neo4j.
    This keeps data fetching logic separate from the dataset and model.
    """
    # --- Explicit column definitions for wallet profile & social fetches ---
    # Identity / lineage columns read from the wallet_profiles table.
    PROFILE_BASE_COLUMNS = [
        'wallet_address',
        'updated_at',
        'first_seen_ts',
        'last_seen_ts',
        'tags',
        'deployed_tokens',
        'funded_from',
        'funded_timestamp',
        'funded_signature',
        'funded_amount'
    ]
    # Point-in-time trading metrics read from wallet_profile_metrics.
    # The 1d / 7d / 30d rolling windows repeat the same field set.
    PROFILE_METRIC_COLUMNS = [
        'balance',
        'transfers_in_count',
        'transfers_out_count',
        'spl_transfers_in_count',
        'spl_transfers_out_count',
        'total_buys_count',
        'total_sells_count',
        'total_winrate',
        'stats_1d_realized_profit_sol',
        'stats_1d_realized_profit_usd',
        'stats_1d_realized_profit_pnl',
        'stats_1d_buy_count',
        'stats_1d_sell_count',
        'stats_1d_transfer_in_count',
        'stats_1d_transfer_out_count',
        'stats_1d_avg_holding_period',
        'stats_1d_total_bought_cost_sol',
        'stats_1d_total_bought_cost_usd',
        'stats_1d_total_sold_income_sol',
        'stats_1d_total_sold_income_usd',
        'stats_1d_total_fee',
        'stats_1d_winrate',
        'stats_1d_tokens_traded',
        'stats_7d_realized_profit_sol',
        'stats_7d_realized_profit_usd',
        'stats_7d_realized_profit_pnl',
        'stats_7d_buy_count',
        'stats_7d_sell_count',
        'stats_7d_transfer_in_count',
        'stats_7d_transfer_out_count',
        'stats_7d_avg_holding_period',
        'stats_7d_total_bought_cost_sol',
        'stats_7d_total_bought_cost_usd',
        'stats_7d_total_sold_income_sol',
        'stats_7d_total_sold_income_usd',
        'stats_7d_total_fee',
        'stats_7d_winrate',
        'stats_7d_tokens_traded',
        'stats_30d_realized_profit_sol',
        'stats_30d_realized_profit_usd',
        'stats_30d_realized_profit_pnl',
        'stats_30d_buy_count',
        'stats_30d_sell_count',
        'stats_30d_transfer_in_count',
        'stats_30d_transfer_out_count',
        'stats_30d_avg_holding_period',
        'stats_30d_total_bought_cost_sol',
        'stats_30d_total_bought_cost_usd',
        'stats_30d_total_sold_income_sol',
        'stats_30d_total_sold_income_usd',
        'stats_30d_total_fee',
        'stats_30d_winrate',
        'stats_30d_tokens_traded'
    ]
    # Full profile projection (base + metric columns).
    PROFILE_COLUMNS_FOR_QUERY = PROFILE_BASE_COLUMNS + PROFILE_METRIC_COLUMNS
    # Columns read from the wallet_socials table.
    SOCIAL_COLUMNS_FOR_QUERY = [
        'wallet_address',
        'pumpfun_username',
        'twitter_username',
        'telegram_channel',
        'kolscan_name',
        'cabalspy_name',
        'axiom_kol_name'
    ]
    # Max addresses per IN-list; keeps each query under ClickHouse's
    # "Max query size exceeded" limit.
    DB_BATCH_SIZE = 5000
def __init__(self, clickhouse_client: Any, neo4j_driver: Any):
self.db_client = clickhouse_client
self.graph_client = neo4j_driver
print("DataFetcher instantiated.")
def get_all_mints(self, start_date: Optional[datetime.datetime] = None) -> List[Dict[str, Any]]:
"""
Fetches a list of all mint events to serve as dataset samples.
Can be filtered to only include mints on or after a given start_date.
"""
query = "SELECT mint_address, timestamp, creator_address, protocol, token_name, token_symbol, token_uri, total_supply, token_decimals FROM mints"
params = {}
where_clauses = []
if start_date:
where_clauses.append("timestamp >= %(start_date)s")
params['start_date'] = start_date
if where_clauses:
query += " WHERE " + " AND ".join(where_clauses)
try:
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
if not rows:
return []
columns = [col[0] for col in columns_info]
result = [dict(zip(columns, row)) for row in rows]
if not result:
return []
return result
except Exception as e:
print(f"ERROR: Failed to fetch token addresses from ClickHouse: {e}")
print("INFO: Falling back to mock token addresses for development.")
return [{'mint_address': 'tknA_real', 'timestamp': datetime.datetime.now(datetime.timezone.utc), 'creator_address': 'addr_Creator_Real', 'protocol': 0}]
def fetch_mint_record(self, token_address: str) -> Dict[str, Any]:
"""
Fetches the raw mint record for a token from the 'mints' table.
"""
query = f"SELECT timestamp, creator_address, mint_address, protocol FROM mints WHERE mint_address = '{token_address}' ORDER BY timestamp ASC LIMIT 1"
# Assumes the client returns a list of dicts or can be converted
# Using column names from your schema
columns = ['timestamp', 'creator_address', 'mint_address', 'protocol']
try:
result = self.db_client.execute(query)
if not result or not result[0]:
raise ValueError(f"No mint event found for token {token_address}")
# Convert the tuple result into a dictionary
record = dict(zip(columns, result[0]))
return record
except Exception as e:
print(f"ERROR: Failed to fetch mint record for {token_address}: {e}")
print("INFO: Falling back to mock mint record for development.")
# Fallback for development if DB connection fails
return {
'timestamp': datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(days=1),
'creator_address': 'addr_Creator_Real',
'mint_address': token_address,
'protocol': vocab.PROTOCOL_TO_ID.get("Pump V1", 0)
}
def fetch_wallet_profiles(self, wallet_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, Dict[str, Any]]:
"""
Convenience wrapper around fetch_wallet_profiles_and_socials for profile-only data.
"""
profiles, _ = self.fetch_wallet_profiles_and_socials(wallet_addresses, T_cutoff)
return profiles
def fetch_wallet_socials(self, wallet_addresses: List[str]) -> Dict[str, Dict[str, Any]]:
"""
Fetches wallet social records for a list of wallet addresses.
Batches queries to avoid "Max query size exceeded" errors.
Returns a dictionary mapping wallet_address to its social data.
"""
if not wallet_addresses:
return {}
BATCH_SIZE = self.DB_BATCH_SIZE
socials = {}
for i in range(0, len(wallet_addresses), BATCH_SIZE):
batch_addresses = wallet_addresses[i : i + BATCH_SIZE]
query = "SELECT * FROM wallet_socials WHERE wallet_address IN %(addresses)s"
params = {'addresses': batch_addresses}
try:
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
if not rows:
continue
columns = [col[0] for col in columns_info]
for row in rows:
social_dict = dict(zip(columns, row))
wallet_addr = social_dict.get('wallet_address')
if wallet_addr:
socials[wallet_addr] = social_dict
except Exception as e:
print(f"ERROR: Failed to fetch wallet socials for batch {i}: {e}")
# Continue to next batch
return socials
def fetch_wallet_profiles_and_socials(self,
                                      wallet_addresses: List[str],
                                      T_cutoff: datetime.datetime) -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]]:
    """
    Fetches wallet profiles (time-aware) and socials for all requested wallets.
    Batches queries to avoid "Max query size exceeded" errors.

    One combined query per batch joins three sources:
      - latest wallet_profiles row per wallet (NOTE(review): no T_cutoff
        filter is applied to this CTE — confirm that is intended),
      - latest wallet_profile_metrics row at or before T_cutoff,
      - wallet_socials (plain join, not time-filtered).
    Columns are aliased with 'profile__' / 'social__' prefixes so the two
    groups can be split apart from the single result set afterwards.

    Returns two dictionaries keyed by wallet_address: (profiles, socials).
    A wallet is omitted from a dict when every value in that group is NULL.
    """
    if not wallet_addresses:
        return {}, {}
    social_columns = self.SOCIAL_COLUMNS_FOR_QUERY
    profile_base_cols = self.PROFILE_BASE_COLUMNS
    profile_metric_cols = self.PROFILE_METRIC_COLUMNS
    # Projections used inside the CTEs below.
    profile_base_str = ",\n ".join(profile_base_cols)
    # Metrics CTE needs wallet_address (partition key) and updated_at (ordering).
    metric_projection_cols = ['wallet_address', 'updated_at'] + profile_metric_cols
    profile_metric_str = ",\n ".join(metric_projection_cols)
    # Columns selected out of each joined source (wallet_address comes from rw).
    profile_base_select_cols = [col for col in profile_base_cols if col != 'wallet_address']
    profile_metric_select_cols = [
        col for col in profile_metric_cols if col not in ('wallet_address',)
    ]
    social_select_cols = [col for col in social_columns if col != 'wallet_address']
    select_expressions = []
    for col in profile_base_select_cols:
        select_expressions.append(f"lp.{col} AS profile__{col}")
    for col in profile_metric_select_cols:
        select_expressions.append(f"lm.{col} AS profile__{col}")
    for col in social_select_cols:
        select_expressions.append(f"ws.{col} AS social__{col}")
    select_clause = ""
    if select_expressions:
        select_clause = ",\n " + ",\n ".join(select_expressions)
    # Prefixed keys used to split the flat result row back into groups.
    profile_keys = [f"profile__{col}" for col in (profile_base_select_cols + profile_metric_select_cols)]
    social_keys = [f"social__{col}" for col in social_select_cols]
    BATCH_SIZE = self.DB_BATCH_SIZE
    all_profiles = {}
    all_socials = {}
    for i in range(0, len(wallet_addresses), BATCH_SIZE):
        batch_addresses = wallet_addresses[i : i + BATCH_SIZE]
        # requested_wallets (via arrayJoin) guarantees one output row per
        # requested address even when no profile/metrics/social row exists.
        query = f"""
        WITH ranked_profiles AS (
            SELECT
                {profile_base_str},
                ROW_NUMBER() OVER (PARTITION BY wallet_address ORDER BY updated_at DESC) AS rn
            FROM wallet_profiles
            WHERE wallet_address IN %(addresses)s
        ),
        latest_profiles AS (
            SELECT
                {profile_base_str}
            FROM ranked_profiles
            WHERE rn = 1
        ),
        ranked_metrics AS (
            SELECT
                {profile_metric_str},
                ROW_NUMBER() OVER (PARTITION BY wallet_address ORDER BY updated_at DESC) AS rn
            FROM wallet_profile_metrics
            WHERE
                wallet_address IN %(addresses)s
                AND updated_at <= %(T_cutoff)s
        ),
        latest_metrics AS (
            SELECT
                {profile_metric_str}
            FROM ranked_metrics
            WHERE rn = 1
        ),
        requested_wallets AS (
            SELECT DISTINCT wallet_address
            FROM (SELECT arrayJoin(%(addresses)s) AS wallet_address)
        )
        SELECT
            rw.wallet_address AS wallet_address
            {select_clause}
        FROM requested_wallets AS rw
        LEFT JOIN latest_profiles AS lp ON rw.wallet_address = lp.wallet_address
        LEFT JOIN latest_metrics AS lm ON rw.wallet_address = lm.wallet_address
        LEFT JOIN wallet_socials AS ws ON rw.wallet_address = ws.wallet_address;
        """
        params = {'addresses': batch_addresses, 'T_cutoff': T_cutoff}
        try:
            rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
            if not rows:
                continue
            columns = [col[0] for col in columns_info]
            for row in rows:
                row_dict = dict(zip(columns, row))
                wallet_addr = row_dict.get('wallet_address')
                if not wallet_addr:
                    continue
                # Split the flat row into the profile group (strip the prefix).
                profile_data = {}
                if profile_keys:
                    for pref_key in profile_keys:
                        if pref_key in row_dict:
                            value = row_dict[pref_key]
                            profile_data[pref_key.replace('profile__', '')] = value
                # Keep the wallet only if at least one profile value is non-NULL.
                if profile_data and any(value is not None for value in profile_data.values()):
                    profile_data['wallet_address'] = wallet_addr
                    all_profiles[wallet_addr] = profile_data
                # Same split for the social group.
                social_data = {}
                if social_keys:
                    for pref_key in social_keys:
                        if pref_key in row_dict:
                            value = row_dict[pref_key]
                            social_data[pref_key.replace('social__', '')] = value
                if social_data and any(value is not None for value in social_data.values()):
                    social_data['wallet_address'] = wallet_addr
                    all_socials[wallet_addr] = social_data
        except Exception as e:
            print(f"ERROR: Combined profile/social query failed for batch {i}-{i+BATCH_SIZE}: {e}")
            # We continue to the next batch
    return all_profiles, all_socials
def fetch_wallet_holdings(self, wallet_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, List[Dict[str, Any]]]:
"""
Fetches top 2 wallet holding records for a list of wallet addresses that were active at T_cutoff.
Batches queries to avoid "Max query size exceeded" errors.
Returns a dictionary mapping wallet_address to a LIST of its holding data.
"""
if not wallet_addresses:
return {}
BATCH_SIZE = self.DB_BATCH_SIZE
holdings = defaultdict(list)
for i in range(0, len(wallet_addresses), BATCH_SIZE):
batch_addresses = wallet_addresses[i : i + BATCH_SIZE]
# --- Time-aware query ---
# 1. For each holding, find the latest state at or before T_cutoff.
# 2. Filter for holdings where the balance was greater than 0.
# 3. Rank these active holdings by USD volume and take the top 2 per wallet.
query = """
WITH point_in_time_holdings AS (
SELECT
*,
COALESCE(history_bought_cost_sol, 0) + COALESCE(history_sold_income_sol, 0) AS total_volume_usd,
ROW_NUMBER() OVER(PARTITION BY wallet_address, mint_address ORDER BY updated_at DESC) as rn_per_holding
FROM wallet_holdings
WHERE
wallet_address IN %(addresses)s
AND updated_at <= %(T_cutoff)s
),
ranked_active_holdings AS (
SELECT *,
ROW_NUMBER() OVER(PARTITION BY wallet_address ORDER BY total_volume_usd DESC) as rn_per_wallet
FROM point_in_time_holdings
WHERE rn_per_holding = 1 AND current_balance > 0
)
SELECT *
FROM ranked_active_holdings
WHERE rn_per_wallet <= 2;
"""
params = {'addresses': batch_addresses, 'T_cutoff': T_cutoff}
try:
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
if not rows:
continue
columns = [col[0] for col in columns_info]
for row in rows:
holding_dict = dict(zip(columns, row))
wallet_addr = holding_dict.get('wallet_address')
if wallet_addr:
holdings[wallet_addr].append(holding_dict)
except Exception as e:
print(f"ERROR: Failed to fetch wallet holdings for batch {i}: {e}")
# Continue to next batch
return dict(holdings)
def fetch_graph_links(self,
                      initial_addresses: List[str],
                      T_cutoff: datetime.datetime,
                      max_degrees: int = 1) -> Tuple[Dict[str, str], Dict[str, Dict[str, Any]]]:
    """
    Fetches graph links from Neo4j, traversing up to a max degree of separation.

    Only relationships with r.timestamp <= T_cutoff are followed (filtering
    happens inside Cypher). Transient Neo4j failures (rate limits, service
    unavailable, expired sessions) are retried with exponential backoff;
    anything else — or running out of retries — raises RuntimeError.

    Args:
        initial_addresses: A list of starting wallet or token addresses.
        T_cutoff: Point-in-time cutoff; converted to a unix timestamp.
        max_degrees: The maximum number of hops to traverse in the graph.
    Returns:
        A tuple containing:
        - A dictionary mapping entity addresses to their type ('Wallet' or 'Token').
        - A dictionary of aggregated links keyed by collator link-type name,
          each entry holding 'links' (property dicts) and 'edges'
          ((source, dest) address tuples), structured for the GraphUpdater.
    Raises:
        RuntimeError: on non-retryable Neo4j failure or retry exhaustion.
    """
    if not initial_addresses:
        return {}, {}
    cutoff_ts = int(T_cutoff.timestamp())
    max_retries = 3
    backoff_sec = 2
    for attempt in range(max_retries + 1):
        try:
            with self.graph_client.session() as session:
                all_entities = {addr: 'Token' for addr in initial_addresses} # Assume initial are tokens
                newly_found_entities = set(initial_addresses)
                aggregated_links = defaultdict(lambda: {'links': [], 'edges': []})
                # Breadth-first expansion: each iteration queries the current
                # frontier and the next frontier is the set of unseen neighbors.
                for i in range(max_degrees):
                    if not newly_found_entities:
                        break
                    # --- TIMING: Query execution ---
                    _t_query_start = time.perf_counter()
                    # Cypher query to find direct neighbors of the current frontier
                    # OPTIMIZED: Filter by timestamp IN Neo4j to avoid transferring 97%+ unused records
                    query = """
                    MATCH (a)-[r]-(b)
                    WHERE a.address IN $addresses AND r.timestamp <= $cutoff_ts
                    RETURN a.address AS source_address, type(r) AS link_type, properties(r) AS link_props, b.address AS dest_address, labels(b)[0] AS dest_type
                    LIMIT 10000
                    """
                    params = {'addresses': list(newly_found_entities), 'cutoff_ts': cutoff_ts}
                    result = session.run(query, params)
                    _t_query_done = time.perf_counter()
                    # --- TIMING: Result processing ---
                    _t_process_start = time.perf_counter()
                    records_total = 0
                    current_degree_new_entities = set()
                    for record in result:
                        records_total += 1
                        link_type = record['link_type']
                        link_props = dict(record['link_props'])
                        source_addr = record['source_address']
                        dest_addr = record['dest_address']
                        dest_type = record['dest_type']
                        # Add the link and edge data
                        aggregated_links[link_type]['links'].append(link_props)
                        aggregated_links[link_type]['edges'].append((source_addr, dest_addr))
                        # If we found a new entity, add it to the set for the next iteration
                        if dest_addr not in all_entities.keys():
                            current_degree_new_entities.add(dest_addr)
                            all_entities[dest_addr] = dest_type
                    _t_process_done = time.perf_counter()
                    newly_found_entities = current_degree_new_entities
                # --- Post-process: rename, map props, strip, cap ---
                MAX_LINKS_PER_TYPE = 500
                # Neo4j type -> collator type name
                _NEO4J_TO_COLLATOR_NAME = {
                    'TRANSFERRED_TO': 'TransferLink',
                    'BUNDLE_TRADE': 'BundleTradeLink',
                    'COPIED_TRADE': 'CopiedTradeLink',
                    'COORDINATED_ACTIVITY': 'CoordinatedActivityLink',
                    'SNIPED': 'SnipedLink',
                    'MINTED': 'MintedLink',
                    'LOCKED_SUPPLY': 'LockedSupplyLink',
                    'BURNED': 'BurnedLink',
                    'PROVIDED_LIQUIDITY': 'ProvidedLiquidityLink',
                    'WHALE_OF': 'WhaleOfLink',
                    'TOP_TRADER_OF': 'TopTraderOfLink',
                }
                # Neo4j prop name -> encoder prop name (for fields with mismatched names)
                _PROP_REMAP = {
                    'CopiedTradeLink': {
                        'buy_gap': 'time_gap_on_buy_sec',
                        'sell_gap': 'time_gap_on_sell_sec',
                        'f_buy_total': 'follower_buy_total',
                        'f_sell_total': 'follower_sell_total',
                        'leader_pnl': 'leader_pnl',
                        'follower_pnl': 'follower_pnl',
                    },
                }
                # Only keep fields each encoder actually reads
                _NEEDED_FIELDS = {
                    'TransferLink': ['amount', 'mint'],
                    'BundleTradeLink': ['signatures'], # Neo4j has no total_amount; we derive it below
                    'CopiedTradeLink': ['time_gap_on_buy_sec', 'time_gap_on_sell_sec', 'leader_pnl', 'follower_pnl', 'follower_buy_total', 'follower_sell_total'],
                    'CoordinatedActivityLink': ['time_gap_on_first_sec', 'time_gap_on_second_sec'],
                    'SnipedLink': ['rank', 'sniped_amount'],
                    'MintedLink': ['buy_amount'],
                    'LockedSupplyLink': ['amount'],
                    'BurnedLink': ['amount'],
                    'ProvidedLiquidityLink': ['amount_quote'],
                    'WhaleOfLink': ['holding_pct_at_creation'],
                    'TopTraderOfLink': ['pnl_at_creation'],
                }
                cleaned_links = {}
                for neo4j_type, data in aggregated_links.items():
                    collator_name = _NEO4J_TO_COLLATOR_NAME.get(neo4j_type)
                    if not collator_name:
                        continue # Skip unknown link types
                    links = data['links']
                    edges = data['edges']
                    # Cap
                    links = links[:MAX_LINKS_PER_TYPE]
                    edges = edges[:MAX_LINKS_PER_TYPE]
                    # Remap property names if needed
                    remap = _PROP_REMAP.get(collator_name)
                    if remap:
                        links = [{remap.get(k, k): v for k, v in l.items()} for l in links]
                    # Strip to only needed fields
                    needed = _NEEDED_FIELDS.get(collator_name, [])
                    links = [{f: l.get(f, 0) for f in needed} for l in links]
                    # BundleTradeLink: Neo4j has no total_amount; derive from signatures count
                    if collator_name == 'BundleTradeLink':
                        links = [{'total_amount': len(l.get('signatures', []) if isinstance(l.get('signatures'), list) else [])} for l in links]
                    cleaned_links[collator_name] = {'links': links, 'edges': edges}
                return all_entities, cleaned_links
        except Exception as e:
            msg = str(e)
            is_rate_limit = "AuthenticationRateLimit" in msg or "RateLimit" in msg
            is_transient = "ServiceUnavailable" in msg or "TransientError" in msg or "SessionExpired" in msg
            if is_rate_limit or is_transient:
                if attempt < max_retries:
                    sleep_time = backoff_sec * (2 ** attempt)
                    print(f"WARN: Neo4j error ({type(e).__name__}). Retrying in {sleep_time}s... (Attempt {attempt+1}/{max_retries})")
                    time.sleep(sleep_time)
                    continue
            # If we're here, it's either not retryable or we ran out of retries
            # Ensure we use "FATAL" prefix so the caller knows to stop if required
            raise RuntimeError(f"FATAL: Failed to fetch graph links from Neo4j: {e}") from e
def fetch_token_data(self, token_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, Dict[str, Any]]:
"""
Fetches the latest token data for each address at or before T_cutoff.
Batches queries to avoid "Max query size exceeded" errors.
Returns a dictionary mapping token_address to its data.
"""
if not token_addresses:
return {}
BATCH_SIZE = self.DB_BATCH_SIZE
tokens = {}
for i in range(0, len(token_addresses), BATCH_SIZE):
batch_addresses = token_addresses[i : i + BATCH_SIZE]
# --- NEW: Time-aware query for historical token data ---
query = """
WITH ranked_tokens AS (
SELECT
*,
ROW_NUMBER() OVER (PARTITION BY token_address ORDER BY updated_at DESC) as rn
FROM tokens
WHERE
token_address IN %(addresses)s
AND updated_at <= %(T_cutoff)s
)
SELECT token_address, name, symbol, token_uri, protocol, total_supply, decimals
FROM ranked_tokens
WHERE rn = 1;
"""
params = {'addresses': batch_addresses, 'T_cutoff': T_cutoff}
try:
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
if not rows:
continue
# Get column names from the query result description
columns = [col[0] for col in columns_info]
for row in rows:
token_dict = dict(zip(columns, row))
token_addr = token_dict.get('token_address')
if token_addr:
# The 'tokens' table in the schema has 'token_address' but the
# collator expects 'address'. We'll add it for compatibility.
token_dict['address'] = token_addr
tokens[token_addr] = token_dict
except Exception as e:
print(f"ERROR: Failed to fetch token data for batch {i}: {e}")
# Continue next batch
return tokens
def fetch_deployed_token_details(self, token_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, Dict[str, Any]]:
    """
    Fetches historical details for deployed tokens at or before T_cutoff.
    Batches queries to avoid "Max query size exceeded" errors.

    Joins the latest 'tokens' row per token with the latest 'token_metrics'
    row (for ath_price_usd), both at or before T_cutoff. 'has_migrated' is
    derived as launchpad != protocol on the token row.

    Returns:
        Mapping of token_address -> detail dict (token_address, created_at,
        updated_at, ath_price_usd, total_supply, decimals, has_migrated).
    """
    if not token_addresses:
        return {}
    BATCH_SIZE = self.DB_BATCH_SIZE
    token_details = {}
    total_tokens = len(token_addresses)
    for i in range(0, total_tokens, BATCH_SIZE):
        batch_addresses = token_addresses[i : i + BATCH_SIZE]
        # --- NEW: Time-aware query for historical deployed token details ---
        query = """
        WITH ranked_tokens AS (
            SELECT
                *,
                ROW_NUMBER() OVER (PARTITION BY token_address ORDER BY updated_at DESC) as rn
            FROM tokens
            WHERE
                token_address IN %(addresses)s
                AND updated_at <= %(T_cutoff)s
        ),
        ranked_token_metrics AS (
            SELECT
                token_address,
                ath_price_usd,
                ROW_NUMBER() OVER (PARTITION BY token_address ORDER BY updated_at DESC) as rn
            FROM token_metrics
            WHERE
                token_address IN %(addresses)s
                AND updated_at <= %(T_cutoff)s
        ),
        latest_tokens AS (
            SELECT *
            FROM ranked_tokens
            WHERE rn = 1
        ),
        latest_token_metrics AS (
            SELECT *
            FROM ranked_token_metrics
            WHERE rn = 1
        )
        SELECT
            lt.token_address,
            lt.created_at,
            lt.updated_at,
            ltm.ath_price_usd,
            lt.total_supply,
            lt.decimals,
            (lt.launchpad != lt.protocol) AS has_migrated
        FROM latest_tokens AS lt
        LEFT JOIN latest_token_metrics AS ltm
            ON lt.token_address = ltm.token_address;
        """
        params = {'addresses': batch_addresses, 'T_cutoff': T_cutoff}
        try:
            rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
            if not rows:
                continue
            columns = [col[0] for col in columns_info]
            for row in rows:
                # row[0] is lt.token_address (first projected column).
                token_details[row[0]] = dict(zip(columns, row))
        except Exception as e:
            print(f"ERROR: Failed to fetch deployed token details for batch {i}: {e}")
            # Continue next batch
    return token_details
def fetch_trades_for_token(self, token_address: str, T_cutoff: datetime.datetime, count_threshold: int, early_limit: int, recent_limit: int, full_history: bool = False) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
"""
Fetches ALL trades for a token up to T_cutoff, ordered by time.
Notes:
- This intentionally does NOT apply the older fetch-time H/B/H (High-Def / Blurry / High-Def)
sampling logic. Sequence-length control is handled later in data_loader.py via event-level
head/tail sampling with MIDDLE/RECENT markers.
- The function signature still includes legacy H/B/H parameters for compatibility.
Returns: (all_trades, [], [])
"""
if not token_address:
return [], [], []
params = {'token_address': token_address, 'T_cutoff': T_cutoff}
query = "SELECT * FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s ORDER BY timestamp ASC"
try:
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
if not rows:
return [], [], []
columns = [col[0] for col in columns_info]
all_trades = [dict(zip(columns, row)) for row in rows]
return all_trades, [], []
except Exception as e:
print(f"ERROR: Failed to fetch trades for token {token_address}: {e}")
return [], [], []
def fetch_future_trades_for_token(self,
token_address: str,
start_ts: datetime.datetime,
end_ts: datetime.datetime) -> List[Dict[str, Any]]:
"""
Fetches successful trades for a token in the window (start_ts, end_ts].
Used for constructing label targets beyond the cutoff.
"""
if not token_address or start_ts is None or end_ts is None or start_ts >= end_ts:
return []
query = """
SELECT *
FROM trades
WHERE base_address = %(token_address)s
AND success = true
AND timestamp > %(start_ts)s
AND timestamp <= %(end_ts)s
ORDER BY timestamp ASC
"""
params = {
'token_address': token_address,
'start_ts': start_ts,
'end_ts': end_ts
}
try:
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
if not rows:
return []
columns = [col[0] for col in columns_info]
return [dict(zip(columns, row)) for row in rows]
except Exception as e:
print(f"ERROR: Failed to fetch future trades for token {token_address}: {e}")
return []
def fetch_transfers_for_token(self, token_address: str, T_cutoff: datetime.datetime, min_amount_threshold: float = 10_000_000) -> List[Dict[str, Any]]:
"""
Fetches all transfers for a token before T_cutoff, filtering out small amounts.
"""
if not token_address:
return []
query = """
SELECT * FROM transfers
WHERE mint_address = %(token_address)s
AND timestamp <= %(T_cutoff)s
AND amount_decimal >= %(min_amount)s
ORDER BY timestamp ASC
"""
params = {'token_address': token_address, 'T_cutoff': T_cutoff, 'min_amount': min_amount_threshold}
try:
# This query no longer uses H/B/H, it fetches all significant transfers
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
if not rows: return []
columns = [col[0] for col in columns_info]
return [dict(zip(columns, row)) for row in rows]
except Exception as e:
print(f"ERROR: Failed to fetch transfers for token {token_address}: {e}")
return []
def fetch_pool_creations_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
    """
    Fetches pool creation records where the token is the base asset.

    Returns:
        Time-ordered list of pool_creations rows (as dicts) with
        timestamp <= T_cutoff; empty list on missing token or query failure.
    """
    if not token_address:
        return []
    query = """
    SELECT
        signature,
        timestamp,
        slot,
        success,
        error,
        priority_fee,
        protocol,
        creator_address,
        pool_address,
        base_address,
        quote_address,
        lp_token_address,
        initial_base_liquidity,
        initial_quote_liquidity,
        base_decimals,
        quote_decimals
    FROM pool_creations
    WHERE base_address = %(token_address)s
        AND timestamp <= %(T_cutoff)s
    ORDER BY timestamp ASC
    """
    params = {'token_address': token_address, 'T_cutoff': T_cutoff}
    # print(f"INFO: Fetching pool creation events for {token_address}.")
    try:
        rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
        if not rows:
            return []
        columns = [col[0] for col in columns_info]
        return [dict(zip(columns, row)) for row in rows]
    except Exception as e:
        print(f"ERROR: Failed to fetch pool creations for token {token_address}: {e}")
        return []
def fetch_liquidity_changes_for_pools(self, pool_addresses: List[str], T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
"""
Fetches liquidity change records for the given pools up to T_cutoff.
"""
if not pool_addresses:
return []
query = """
SELECT
signature,
timestamp,
slot,
success,
error,
priority_fee,
protocol,
change_type,
lp_provider,
pool_address,
base_amount,
quote_amount
FROM liquidity
WHERE pool_address IN %(pool_addresses)s
AND timestamp <= %(T_cutoff)s
ORDER BY timestamp ASC
"""
params = {'pool_addresses': pool_addresses, 'T_cutoff': T_cutoff}
# print(f"INFO: Fetching liquidity change events for {len(pool_addresses)} pool(s).")
try:
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
if not rows:
return []
columns = [col[0] for col in columns_info]
return [dict(zip(columns, row)) for row in rows]
except Exception as e:
print(f"ERROR: Failed to fetch liquidity changes for pools {pool_addresses}: {e}")
return []
def fetch_fee_collections_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
    """
    Fetches fee collection events where the token appears as either token_0 or token_1.

    Returns:
        Time-ordered list of fee_collections rows (as dicts) with
        timestamp <= T_cutoff; empty list on missing token or query failure.
    """
    if not token_address:
        return []
    query = """
    SELECT
        timestamp,
        signature,
        slot,
        success,
        error,
        priority_fee,
        protocol,
        recipient_address,
        token_0_mint_address,
        token_0_amount,
        token_1_mint_address,
        token_1_amount
    FROM fee_collections
    WHERE (token_0_mint_address = %(token)s OR token_1_mint_address = %(token)s)
        AND timestamp <= %(T_cutoff)s
    ORDER BY timestamp ASC
    """
    params = {'token': token_address, 'T_cutoff': T_cutoff}
    # print(f"INFO: Fetching fee collection events for {token_address}.")
    try:
        rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
        if not rows:
            return []
        columns = [col[0] for col in columns_info]
        return [dict(zip(columns, row)) for row in rows]
    except Exception as e:
        print(f"ERROR: Failed to fetch fee collections for token {token_address}: {e}")
        return []
def fetch_migrations_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
    """
    Fetches migration records for a given token up to T_cutoff.

    Returns:
        Time-ordered list of migrations rows (as dicts); empty list on
        missing token or query failure.
    """
    if not token_address:
        return []
    query = """
    SELECT
        timestamp,
        signature,
        slot,
        success,
        error,
        priority_fee,
        protocol,
        mint_address,
        virtual_pool_address,
        pool_address,
        migrated_base_liquidity,
        migrated_quote_liquidity
    FROM migrations
    WHERE mint_address = %(token)s
        AND timestamp <= %(T_cutoff)s
    ORDER BY timestamp ASC
    """
    params = {'token': token_address, 'T_cutoff': T_cutoff}
    # print(f"INFO: Fetching migrations for {token_address}.")
    try:
        rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
        if not rows:
            return []
        columns = [col[0] for col in columns_info]
        return [dict(zip(columns, row)) for row in rows]
    except Exception as e:
        print(f"ERROR: Failed to fetch migrations for token {token_address}: {e}")
        return []
def fetch_burns_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
    """
    Fetches burn events for a given token up to T_cutoff.
    Schema: burns(timestamp, signature, slot, success, error, priority_fee, mint_address, source, amount, amount_decimal, source_balance)

    Returns:
        Time-ordered list of burns rows (as dicts); empty list on missing
        token or query failure.
    """
    if not token_address:
        return []
    query = """
    SELECT
        timestamp,
        signature,
        slot,
        success,
        error,
        priority_fee,
        mint_address,
        source,
        amount,
        amount_decimal,
        source_balance
    FROM burns
    WHERE mint_address = %(token)s
        AND timestamp <= %(T_cutoff)s
    ORDER BY timestamp ASC
    """
    params = {'token': token_address, 'T_cutoff': T_cutoff}
    # print(f"INFO: Fetching burn events for {token_address}.")
    try:
        rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
        if not rows:
            return []
        columns = [col[0] for col in columns_info]
        return [dict(zip(columns, row)) for row in rows]
    except Exception as e:
        print(f"ERROR: Failed to fetch burns for token {token_address}: {e}")
        return []
def fetch_supply_locks_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
    """
    Fetches supply lock events for a given token up to T_cutoff.
    Schema: supply_locks(timestamp, signature, slot, success, error, priority_fee, protocol, contract_address, sender, recipient, mint_address, total_locked_amount, final_unlock_timestamp)

    Returns:
        Time-ordered list of supply_locks rows (as dicts); empty list on
        missing token or query failure.
    """
    if not token_address:
        return []
    query = """
    SELECT
        timestamp,
        signature,
        slot,
        success,
        error,
        priority_fee,
        protocol,
        contract_address,
        sender,
        recipient,
        mint_address,
        total_locked_amount,
        final_unlock_timestamp
    FROM supply_locks
    WHERE mint_address = %(token)s
        AND timestamp <= %(T_cutoff)s
    ORDER BY timestamp ASC
    """
    params = {'token': token_address, 'T_cutoff': T_cutoff}
    # print(f"INFO: Fetching supply lock events for {token_address}.")
    try:
        rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
        if not rows:
            return []
        columns = [col[0] for col in columns_info]
        return [dict(zip(columns, row)) for row in rows]
    except Exception as e:
        print(f"ERROR: Failed to fetch supply locks for token {token_address}: {e}")
        return []
def fetch_token_holders_for_snapshot(self, token_address: str, T_cutoff: datetime.datetime, limit: int = 200) -> List[Dict[str, Any]]:
"""
Fetch top holders for a token at or before T_cutoff for snapshot purposes.
Reconstructs holdings from trades table (buys - sells) since wallet_holdings
may not have full point-in-time history.
Returns rows with wallet_address and current_balance (>0), ordered by balance desc.
"""
if not token_address:
return []
query = """
SELECT
maker as wallet_address,
SUM(CASE WHEN trade_type = 0 THEN toInt64(base_amount) ELSE -toInt64(base_amount) END) / 1000000.0 as current_balance
FROM trades
WHERE base_address = %(token)s
AND timestamp <= %(T_cutoff)s
AND success = 1
GROUP BY maker
HAVING current_balance > 0
ORDER BY current_balance DESC
LIMIT %(limit)s;
"""
params = {'token': token_address, 'T_cutoff': T_cutoff, 'limit': int(limit)}
try:
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
if not rows:
return []
columns = [col[0] for col in columns_info]
return [dict(zip(columns, row)) for row in rows]
except Exception as e:
print(f"ERROR: Failed to fetch token holders for {token_address}: {e}")
return []
def fetch_total_holders_count_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> int:
"""
Returns the total number of wallets holding the token (balance > 0)
at or before T_cutoff. Reconstructs from trades table.
"""
if not token_address:
return 0
query = """
SELECT count() FROM (
SELECT maker
FROM trades
WHERE base_address = %(token)s
AND timestamp <= %(T_cutoff)s
AND success = 1
GROUP BY maker
HAVING SUM(CASE WHEN trade_type = 0 THEN toInt64(base_amount) ELSE -toInt64(base_amount) END) > 0
);
"""
params = {'token': token_address, 'T_cutoff': T_cutoff}
try:
rows = self.db_client.execute(query, params)
if not rows:
return 0
return int(rows[0][0])
except Exception as e:
print(f"ERROR: Failed to count total holders for token {token_address}: {e}")
return 0
def fetch_holder_snapshot_stats_for_token(self, token_address: str, T_cutoff: datetime.datetime, limit: int = 200) -> Tuple[int, List[Dict[str, Any]]]:
    """
    Fetch the total holder count and the top holders at a point in time.

    Both values are reconstructed from the trades table via
    fetch_token_holders_for_snapshot / fetch_total_holders_count_for_token
    (NOT the wallet_holdings table, which may lack point-in-time history).

    Args:
        token_address: Mint address of the token; empty string yields (0, []).
        T_cutoff: Point in time of the snapshot.
        limit: Maximum number of top holders to return.

    Returns:
        (holder_count, top_holders) where top_holders is a list of rows
        with wallet_address and current_balance.
    """
    if not token_address:
        return 0, []
    top_holders = self.fetch_token_holders_for_snapshot(token_address, T_cutoff, limit)
    holder_count = self.fetch_total_holders_count_for_token(token_address, T_cutoff)
    return holder_count, top_holders
def fetch_raw_token_data(
    self,
    token_address: str,
    creator_address: str,
    mint_timestamp: datetime.datetime,
    max_horizon_seconds: int = 3600,
    include_wallet_data: bool = True,
    include_graph: bool = True,
    min_trades: int = 0,
    full_history: bool = False,
    prune_failed: bool = False,
    prune_transfers: bool = False
) -> Optional[Dict[str, Any]]:
    """
    Fetch ALL available data for a token up to mint + max_horizon_seconds.

    The result is agnostic of any training-time T_cutoff; callers mask or
    filter it dynamically. Wallet/graph data can be skipped to avoid
    caching T_cutoff-dependent features.

    Args:
        token_address: Mint address of the token.
        creator_address: Wallet that deployed the token.
        mint_timestamp: Token creation time; fetches are bounded by
            mint_timestamp + max_horizon_seconds.
        max_horizon_seconds: Size of the fetch window after mint.
        include_wallet_data: If True, fetch wallet profiles, socials,
            holdings, and deployed-token details for involved wallets.
        include_graph: If True, fetch Neo4j graph entities and links.
        min_trades: If fewer trades remain (after optional pruning),
            skip all further fetches and return None.
        full_history: If True, fetches ALL trades ignoring H/B/H limits.
        prune_failed: If True, filters out failed trades from the result.
        prune_transfers: If True, skips fetching transfers entirely.

    Returns:
        Dict of raw event and wallet/graph context data, or None when the
        token has fewer than min_trades trades.
    """
    # Absolute maximum timestamp we care about; bounding by the training
    # horizon avoids fetching months of irrelevant data.
    max_limit_time = mint_timestamp + datetime.timedelta(seconds=max_horizon_seconds)
    # Trades come back in three segments (early/middle/recent); the
    # per-segment limits are ignored inside the method when full_history.
    early_trades, middle_trades, recent_trades = self.fetch_trades_for_token(
        token_address, max_limit_time, 30000, 10000, 15000, full_history=full_history
    )
    # Deduplicate: the same trade can appear in more than one segment.
    # key: (slot, transaction_index, instruction_index, signature)
    all_trades = {}
    for t in early_trades + middle_trades + recent_trades:
        key = (t.get('slot'), t.get('transaction_index'), t.get('instruction_index'), t.get('signature'))
        all_trades[key] = t
    sorted_trades = sorted(all_trades.values(), key=lambda x: x['timestamp'])
    if prune_failed:
        sorted_trades = [t for t in sorted_trades if t.get('success', False)]
    if len(sorted_trades) < min_trades:
        print(f" SKIP: Token {token_address} has only {len(sorted_trades)} trades (min required: {min_trades}). skipping fetches.")
        return None
    # Other on-chain events within the window.
    if prune_transfers:
        transfers = []
    else:
        transfers = self.fetch_transfers_for_token(token_address, max_limit_time, 0.0)  # 0.0 means fetch all
    pool_creations = self.fetch_pool_creations_for_token(token_address, max_limit_time)
    # Liquidity changes are keyed by pool, so collect pool addresses first.
    pool_addresses = [p['pool_address'] for p in pool_creations if p.get('pool_address')]
    liquidity_changes = []
    if pool_addresses:
        liquidity_changes = self.fetch_liquidity_changes_for_pools(pool_addresses, max_limit_time)
    fee_collections = self.fetch_fee_collections_for_token(token_address, max_limit_time)
    burns = self.fetch_burns_for_token(token_address, max_limit_time)
    supply_locks = self.fetch_supply_locks_for_token(token_address, max_limit_time)
    migrations = self.fetch_migrations_for_token(token_address, max_limit_time)
    profile_data = {}
    social_data = {}
    holdings_data = {}
    deployed_token_details = {}
    fetched_graph_entities = {}
    graph_links = {}
    unique_wallets = set()
    if include_wallet_data or include_graph:
        # Every wallet that touched the token up to max_limit_time.
        unique_wallets.add(creator_address)
        for t in sorted_trades:
            if t.get('maker'):
                unique_wallets.add(t['maker'])
        for t in transfers:
            if t.get('source'):
                unique_wallets.add(t['source'])
            if t.get('destination'):
                unique_wallets.add(t['destination'])
        for p in pool_creations:
            if p.get('creator_address'):
                unique_wallets.add(p['creator_address'])
        for l in liquidity_changes:
            if l.get('lp_provider'):
                unique_wallets.add(l['lp_provider'])
    if include_wallet_data and unique_wallets:
        # Profiles/holdings are time-dependent; only fetch if explicitly requested.
        profile_data, social_data = self.fetch_wallet_profiles_and_socials(list(unique_wallets), max_limit_time)
        holdings_data = self.fetch_wallet_holdings(list(unique_wallets), max_limit_time)
        all_deployed_tokens = set()
        for profile in profile_data.values():
            all_deployed_tokens.update(profile.get('deployed_tokens', []))
        if all_deployed_tokens:
            deployed_token_details = self.fetch_deployed_token_details(list(all_deployed_tokens), max_limit_time)
    if include_graph and unique_wallets:
        # NOTE: the seed list is intentionally not capped; graph fan-out is
        # limited by max_degrees instead.
        fetched_graph_entities, graph_links = self.fetch_graph_links(
            list(unique_wallets),
            max_limit_time,
            max_degrees=1
        )
    return {
        "token_address": token_address,
        "creator_address": creator_address,
        "mint_timestamp": mint_timestamp,
        "max_limit_time": max_limit_time,
        "trades": sorted_trades,
        "transfers": transfers,
        "pool_creations": pool_creations,
        "liquidity_changes": liquidity_changes,
        "fee_collections": fee_collections,
        "burns": burns,
        "supply_locks": supply_locks,
        "migrations": migrations,
        "profiles": profile_data,
        "socials": social_data,
        "holdings": holdings_data,
        "deployed_token_details": deployed_token_details,
        "graph_entities": fetched_graph_entities,
        "graph_links": graph_links
    }