zirobtc commited on
Commit
18eb93c
·
1 Parent(s): d4195aa

Upload folder using huggingface_hub

Browse files
.ipynb_checkpoints/.gitignore-checkpoint ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ignore the __pycache__ directory anywhere in the repository
2
+ __pycache__/
3
+
4
+
5
+ # Ignore the 'runs' directory anywhere in the repository, regardless of nesting
6
+ runs/
7
+
8
+ data/pump_fun
9
+ data/cache
10
+ .env
11
+
12
+ data/cache
13
+ .tmp/
14
+ .cache/
15
+ checkpoints/
16
+ metadata/
17
+ store/
18
+ preprocessed_configs/
19
+ .early.coverage
data/.ipynb_checkpoints/data_fetcher-checkpoint.py ADDED
@@ -0,0 +1,1263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data_fetcher.py
2
+
3
+ from typing import List, Dict, Any, Tuple, Set, Optional
4
+ from collections import defaultdict
5
+ import datetime, time
6
+
7
+ # We need the vocabulary for mapping IDs
8
+ import models.vocabulary as vocab
9
+
10
class DataFetcher:
    """Central access point for all ClickHouse and Neo4j reads.

    Keeping every database query in one class isolates data-fetching
    concerns from the dataset construction and model code.
    """

    # --- Explicit column definitions for wallet profile & social fetches ---

    # Slowly-changing identity fields stored in `wallet_profiles`.
    PROFILE_BASE_COLUMNS = [
        'wallet_address',
        'updated_at',
        'first_seen_ts',
        'last_seen_ts',
        'tags',
        'deployed_tokens',
        'funded_from',
        'funded_timestamp',
        'funded_signature',
        'funded_amount',
    ]

    # Activity metrics stored in `wallet_profile_metrics`. The rolling-window
    # stats repeat the same 15 fields for the 1d / 7d / 30d horizons, so they
    # are generated instead of written out 45 times.
    PROFILE_METRIC_COLUMNS = [
        'balance',
        'transfers_in_count',
        'transfers_out_count',
        'spl_transfers_in_count',
        'spl_transfers_out_count',
        'total_buys_count',
        'total_sells_count',
        'total_winrate',
    ] + [
        f'stats_{window}_{field}'
        for window in ('1d', '7d', '30d')
        for field in (
            'realized_profit_sol',
            'realized_profit_usd',
            'realized_profit_pnl',
            'buy_count',
            'sell_count',
            'transfer_in_count',
            'transfer_out_count',
            'avg_holding_period',
            'total_bought_cost_sol',
            'total_bought_cost_usd',
            'total_sold_income_sol',
            'total_sold_income_usd',
            'total_fee',
            'winrate',
            'tokens_traded',
        )
    ]

    # Full profile projection: identity fields followed by metric fields.
    PROFILE_COLUMNS_FOR_QUERY = PROFILE_BASE_COLUMNS + PROFILE_METRIC_COLUMNS

    # Columns read from the `wallet_socials` table.
    SOCIAL_COLUMNS_FOR_QUERY = [
        'wallet_address',
        'pumpfun_username',
        'twitter_username',
        'telegram_channel',
        'kolscan_name',
        'cabalspy_name',
        'axiom_kol_name',
    ]
97
+ def __init__(self, clickhouse_client: Any, neo4j_driver: Any):
98
+ self.db_client = clickhouse_client
99
+ self.graph_client = neo4j_driver
100
+ print("DataFetcher instantiated.")
101
+
102
+ def get_all_mints(self, start_date: Optional[datetime.datetime] = None) -> List[Dict[str, Any]]:
103
+ """
104
+ Fetches a list of all mint events to serve as dataset samples.
105
+ Can be filtered to only include mints on or after a given start_date.
106
+ """
107
+ query = "SELECT mint_address, timestamp, creator_address, protocol, token_name, token_symbol, token_uri, total_supply, token_decimals FROM mints"
108
+ params = {}
109
+ where_clauses = []
110
+
111
+ if start_date:
112
+ where_clauses.append("timestamp >= %(start_date)s")
113
+ params['start_date'] = start_date
114
+
115
+ if where_clauses:
116
+ query += " WHERE " + " AND ".join(where_clauses)
117
+
118
+ print(f"INFO: Executing query to get all mints: `{query}` with params: {params}")
119
+ try:
120
+ rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
121
+ if not rows:
122
+ return []
123
+ columns = [col[0] for col in columns_info]
124
+ result = [dict(zip(columns, row)) for row in rows]
125
+ if not result:
126
+ return []
127
+ return result
128
+ except Exception as e:
129
+ print(f"ERROR: Failed to fetch token addresses from ClickHouse: {e}")
130
+ print("INFO: Falling back to mock token addresses for development.")
131
+ return [{'mint_address': 'tknA_real', 'timestamp': datetime.datetime.now(datetime.timezone.utc), 'creator_address': 'addr_Creator_Real', 'protocol': 0}]
132
+
133
+
134
+ def fetch_mint_record(self, token_address: str) -> Dict[str, Any]:
135
+ """
136
+ Fetches the raw mint record for a token from the 'mints' table.
137
+ """
138
+ query = f"SELECT timestamp, creator_address, mint_address, protocol FROM mints WHERE mint_address = '{token_address}' ORDER BY timestamp ASC LIMIT 1"
139
+ print(f"INFO: Executing query to fetch mint record: `{query}`")
140
+
141
+ # Assumes the client returns a list of dicts or can be converted
142
+ # Using column names from your schema
143
+ columns = ['timestamp', 'creator_address', 'mint_address', 'protocol']
144
+ try:
145
+ result = self.db_client.execute(query)
146
+
147
+ if not result or not result[0]:
148
+ raise ValueError(f"No mint event found for token {token_address}")
149
+
150
+ # Convert the tuple result into a dictionary
151
+ record = dict(zip(columns, result[0]))
152
+ return record
153
+ except Exception as e:
154
+ print(f"ERROR: Failed to fetch mint record for {token_address}: {e}")
155
+ print("INFO: Falling back to mock mint record for development.")
156
+ # Fallback for development if DB connection fails
157
+ return {
158
+ 'timestamp': datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(days=1),
159
+ 'creator_address': 'addr_Creator_Real',
160
+ 'mint_address': token_address,
161
+ 'protocol': vocab.PROTOCOL_TO_ID.get("Pump V1", 0)
162
+ }
163
+
164
+ def fetch_wallet_profiles(self, wallet_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, Dict[str, Any]]:
165
+ """
166
+ Convenience wrapper around fetch_wallet_profiles_and_socials for profile-only data.
167
+ """
168
+ profiles, _ = self.fetch_wallet_profiles_and_socials(wallet_addresses, T_cutoff)
169
+ return profiles
170
+
171
+ def fetch_wallet_socials(self, wallet_addresses: List[str]) -> Dict[str, Dict[str, Any]]:
172
+ """
173
+ Fetches wallet social records for a list of wallet addresses.
174
+ Batches queries to avoid "Max query size exceeded" errors.
175
+ Returns a dictionary mapping wallet_address to its social data.
176
+ """
177
+ if not wallet_addresses:
178
+ return {}
179
+
180
+ BATCH_SIZE = 1000
181
+ socials = {}
182
+ total_wallets = len(wallet_addresses)
183
+ print(f"INFO: Executing query to fetch wallet socials for {total_wallets} wallets in batches of {BATCH_SIZE}.")
184
+
185
+ for i in range(0, total_wallets, BATCH_SIZE):
186
+ batch_addresses = wallet_addresses[i : i + BATCH_SIZE]
187
+
188
+ query = "SELECT * FROM wallet_socials WHERE wallet_address IN %(addresses)s"
189
+ params = {'addresses': batch_addresses}
190
+
191
+ try:
192
+ rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
193
+ if not rows:
194
+ continue
195
+
196
+ columns = [col[0] for col in columns_info]
197
+ for row in rows:
198
+ social_dict = dict(zip(columns, row))
199
+ wallet_addr = social_dict.get('wallet_address')
200
+ if wallet_addr:
201
+ socials[wallet_addr] = social_dict
202
+
203
+ except Exception as e:
204
+ print(f"ERROR: Failed to fetch wallet socials for batch {i}: {e}")
205
+ # Continue to next batch
206
+
207
+ return socials
208
+
209
+ def fetch_wallet_profiles_and_socials(self,
210
+ wallet_addresses: List[str],
211
+ T_cutoff: datetime.datetime) -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]]:
212
+ """
213
+ Fetches wallet profiles (time-aware) and socials for all requested wallets.
214
+ Batches queries to avoid "Max query size exceeded" errors.
215
+ Returns two dictionaries: profiles, socials.
216
+ """
217
+ if not wallet_addresses:
218
+ return {}, {}
219
+
220
+ social_columns = self.SOCIAL_COLUMNS_FOR_QUERY
221
+ profile_base_cols = self.PROFILE_BASE_COLUMNS
222
+ profile_metric_cols = self.PROFILE_METRIC_COLUMNS
223
+
224
+ profile_base_str = ",\n ".join(profile_base_cols)
225
+ metric_projection_cols = ['wallet_address', 'updated_at'] + profile_metric_cols
226
+ profile_metric_str = ",\n ".join(metric_projection_cols)
227
+
228
+ profile_base_select_cols = [col for col in profile_base_cols if col != 'wallet_address']
229
+ profile_metric_select_cols = [
230
+ col for col in profile_metric_cols if col not in ('wallet_address',)
231
+ ]
232
+ social_select_cols = [col for col in social_columns if col != 'wallet_address']
233
+
234
+ select_expressions = []
235
+ for col in profile_base_select_cols:
236
+ select_expressions.append(f"lp.{col} AS profile__{col}")
237
+ for col in profile_metric_select_cols:
238
+ select_expressions.append(f"lm.{col} AS profile__{col}")
239
+ for col in social_select_cols:
240
+ select_expressions.append(f"ws.{col} AS social__{col}")
241
+ select_clause = ""
242
+ if select_expressions:
243
+ select_clause = ",\n " + ",\n ".join(select_expressions)
244
+
245
+ profile_keys = [f"profile__{col}" for col in (profile_base_select_cols + profile_metric_select_cols)]
246
+ social_keys = [f"social__{col}" for col in social_select_cols]
247
+
248
+ BATCH_SIZE = 1000
249
+ all_profiles = {}
250
+ all_socials = {}
251
+
252
+ total_wallets = len(wallet_addresses)
253
+ print(f"INFO: Fetching profiles+socials for {total_wallets} wallets in batches of {BATCH_SIZE}...")
254
+
255
+ for i in range(0, total_wallets, BATCH_SIZE):
256
+ batch_addresses = wallet_addresses[i : i + BATCH_SIZE]
257
+
258
+ query = f"""
259
+ WITH ranked_profiles AS (
260
+ SELECT
261
+ {profile_base_str},
262
+ ROW_NUMBER() OVER (PARTITION BY wallet_address ORDER BY updated_at DESC) AS rn
263
+ FROM wallet_profiles
264
+ WHERE wallet_address IN %(addresses)s
265
+ ),
266
+ latest_profiles AS (
267
+ SELECT
268
+ {profile_base_str}
269
+ FROM ranked_profiles
270
+ WHERE rn = 1
271
+ ),
272
+ ranked_metrics AS (
273
+ SELECT
274
+ {profile_metric_str},
275
+ ROW_NUMBER() OVER (PARTITION BY wallet_address ORDER BY updated_at DESC) AS rn
276
+ FROM wallet_profile_metrics
277
+ WHERE
278
+ wallet_address IN %(addresses)s
279
+ AND updated_at <= %(T_cutoff)s
280
+ ),
281
+ latest_metrics AS (
282
+ SELECT
283
+ {profile_metric_str}
284
+ FROM ranked_metrics
285
+ WHERE rn = 1
286
+ ),
287
+ requested_wallets AS (
288
+ SELECT DISTINCT wallet_address
289
+ FROM (SELECT arrayJoin(%(addresses)s) AS wallet_address)
290
+ )
291
+ SELECT
292
+ rw.wallet_address AS wallet_address
293
+ {select_clause}
294
+ FROM requested_wallets AS rw
295
+ LEFT JOIN latest_profiles AS lp ON rw.wallet_address = lp.wallet_address
296
+ LEFT JOIN latest_metrics AS lm ON rw.wallet_address = lm.wallet_address
297
+ LEFT JOIN wallet_socials AS ws ON rw.wallet_address = ws.wallet_address;
298
+ """
299
+
300
+ params = {'addresses': batch_addresses, 'T_cutoff': T_cutoff}
301
+
302
+ try:
303
+ rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
304
+ if not rows:
305
+ continue
306
+
307
+ columns = [col[0] for col in columns_info]
308
+
309
+ for row in rows:
310
+ row_dict = dict(zip(columns, row))
311
+ wallet_addr = row_dict.get('wallet_address')
312
+ if not wallet_addr:
313
+ continue
314
+
315
+ profile_data = {}
316
+ if profile_keys:
317
+ for pref_key in profile_keys:
318
+ if pref_key in row_dict:
319
+ value = row_dict[pref_key]
320
+ profile_data[pref_key.replace('profile__', '')] = value
321
+
322
+ if profile_data and any(value is not None for value in profile_data.values()):
323
+ profile_data['wallet_address'] = wallet_addr
324
+ all_profiles[wallet_addr] = profile_data
325
+
326
+ social_data = {}
327
+ if social_keys:
328
+ for pref_key in social_keys:
329
+ if pref_key in row_dict:
330
+ value = row_dict[pref_key]
331
+ social_data[pref_key.replace('social__', '')] = value
332
+
333
+ if social_data and any(value is not None for value in social_data.values()):
334
+ social_data['wallet_address'] = wallet_addr
335
+ all_socials[wallet_addr] = social_data
336
+
337
+ except Exception as e:
338
+ print(f"ERROR: Combined profile/social query failed for batch {i}-{i+BATCH_SIZE}: {e}")
339
+ # We continue to the next batch
340
+
341
+ return all_profiles, all_socials
342
+
343
+ def fetch_wallet_holdings(self, wallet_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, List[Dict[str, Any]]]:
344
+ """
345
+ Fetches top 2 wallet holding records for a list of wallet addresses that were active at T_cutoff.
346
+ Batches queries to avoid "Max query size exceeded" errors.
347
+ Returns a dictionary mapping wallet_address to a LIST of its holding data.
348
+ """
349
+ if not wallet_addresses:
350
+ return {}
351
+
352
+ BATCH_SIZE = 1000
353
+ holdings = defaultdict(list)
354
+ total_wallets = len(wallet_addresses)
355
+ print(f"INFO: Executing query to fetch wallet holdings for {total_wallets} wallets in batches of {BATCH_SIZE}.")
356
+
357
+ for i in range(0, total_wallets, BATCH_SIZE):
358
+ batch_addresses = wallet_addresses[i : i + BATCH_SIZE]
359
+
360
+ # --- Time-aware query ---
361
+ # 1. For each holding, find the latest state at or before T_cutoff.
362
+ # 2. Filter for holdings where the balance was greater than 0.
363
+ # 3. Rank these active holdings by USD volume and take the top 2 per wallet.
364
+ query = """
365
+ WITH point_in_time_holdings AS (
366
+ SELECT
367
+ *,
368
+ COALESCE(history_bought_cost_sol, 0) + COALESCE(history_sold_income_sol, 0) AS total_volume_usd,
369
+ ROW_NUMBER() OVER(PARTITION BY wallet_address, mint_address ORDER BY updated_at DESC) as rn_per_holding
370
+ FROM wallet_holdings
371
+ WHERE
372
+ wallet_address IN %(addresses)s
373
+ AND updated_at <= %(T_cutoff)s
374
+ ),
375
+ ranked_active_holdings AS (
376
+ SELECT *,
377
+ ROW_NUMBER() OVER(PARTITION BY wallet_address ORDER BY total_volume_usd DESC) as rn_per_wallet
378
+ FROM point_in_time_holdings
379
+ WHERE rn_per_holding = 1 AND current_balance > 0
380
+ )
381
+ SELECT *
382
+ FROM ranked_active_holdings
383
+ WHERE rn_per_wallet <= 2;
384
+ """
385
+ params = {'addresses': batch_addresses, 'T_cutoff': T_cutoff}
386
+
387
+ try:
388
+ rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
389
+ if not rows:
390
+ continue
391
+
392
+ columns = [col[0] for col in columns_info]
393
+ for row in rows:
394
+ holding_dict = dict(zip(columns, row))
395
+ wallet_addr = holding_dict.get('wallet_address')
396
+ if wallet_addr:
397
+ holdings[wallet_addr].append(holding_dict)
398
+
399
+ except Exception as e:
400
+ print(f"ERROR: Failed to fetch wallet holdings for batch {i}: {e}")
401
+ # Continue to next batch
402
+
403
+ return dict(holdings)
404
+
405
+ def fetch_graph_links(self,
406
+ initial_addresses: List[str],
407
+ T_cutoff: datetime.datetime,
408
+ max_degrees: int = 1) -> Tuple[Dict[str, str], Dict[str, Dict[str, Any]]]:
409
+ """
410
+ Fetches graph links from Neo4j, traversing up to a max degree of separation.
411
+
412
+ Args:
413
+ initial_addresses: A list of starting wallet or token addresses.
414
+ max_degrees: The maximum number of hops to traverse in the graph.
415
+
416
+ Returns:
417
+ A tuple containing:
418
+ - A dictionary mapping entity addresses to their type ('Wallet' or 'Token').
419
+ - A dictionary of aggregated links, structured for the GraphUpdater.
420
+ """
421
+ if not initial_addresses:
422
+ return {}, {}
423
+
424
+ cutoff_ts = int(T_cutoff.timestamp())
425
+
426
+ print(f"INFO: Fetching graph links up to {max_degrees} degrees for {len(initial_addresses)} initial entities...")
427
+
428
+ max_retries = 3
429
+ backoff_sec = 2
430
+
431
+ for attempt in range(max_retries + 1):
432
+ try:
433
+ with self.graph_client.session() as session:
434
+ all_entities = {addr: 'Token' for addr in initial_addresses} # Assume initial are tokens
435
+ newly_found_entities = set(initial_addresses)
436
+ aggregated_links = defaultdict(lambda: {'links': [], 'edges': []})
437
+
438
+ for i in range(max_degrees):
439
+ if not newly_found_entities:
440
+ break
441
+
442
+ print(f" - Degree {i+1}: Traversing from {len(newly_found_entities)} new entities...")
443
+
444
+ # --- TIMING: Query execution ---
445
+ _t_query_start = time.perf_counter()
446
+
447
+ # Cypher query to find direct neighbors of the current frontier
448
+ # OPTIMIZED: Filter by timestamp IN Neo4j to avoid transferring 97%+ unused records
449
+ query = """
450
+ MATCH (a)-[r]-(b)
451
+ WHERE a.address IN $addresses AND r.timestamp <= $cutoff_ts
452
+ RETURN a.address AS source_address, type(r) AS link_type, properties(r) AS link_props, b.address AS dest_address, labels(b)[0] AS dest_type
453
+ LIMIT 10000
454
+ """
455
+ params = {'addresses': list(newly_found_entities), 'cutoff_ts': cutoff_ts}
456
+ result = session.run(query, params)
457
+
458
+ _t_query_done = time.perf_counter()
459
+
460
+ # --- TIMING: Result processing ---
461
+ _t_process_start = time.perf_counter()
462
+ records_total = 0
463
+
464
+ current_degree_new_entities = set()
465
+ for record in result:
466
+ records_total += 1
467
+ link_type = record['link_type']
468
+ link_props = dict(record['link_props'])
469
+ source_addr = record['source_address']
470
+ dest_addr = record['dest_address']
471
+ dest_type = record['dest_type']
472
+
473
+ # Add the link and edge data
474
+ aggregated_links[link_type]['links'].append(link_props)
475
+ aggregated_links[link_type]['edges'].append((source_addr, dest_addr))
476
+
477
+ # If we found a new entity, add it to the set for the next iteration
478
+ if dest_addr not in all_entities.keys():
479
+ current_degree_new_entities.add(dest_addr)
480
+ all_entities[dest_addr] = dest_type
481
+
482
+ _t_process_done = time.perf_counter()
483
+
484
+ # --- TIMING: Print detailed stats ---
485
+ print(f" [NEO4J TIMING] query_exec: {(_t_query_done - _t_query_start)*1000:.1f}ms, "
486
+ f"result_process: {(_t_process_done - _t_process_start)*1000:.1f}ms")
487
+ print(f" [NEO4J STATS] records_returned: {records_total}, "
488
+ f"new_entities: {len(current_degree_new_entities)}")
489
+
490
+ newly_found_entities = current_degree_new_entities
491
+
492
+ # --- Post-process: rename, map props, strip, cap ---
493
+ MAX_LINKS_PER_TYPE = 500
494
+
495
+ # Neo4j type -> collator type name
496
+ _NEO4J_TO_COLLATOR_NAME = {
497
+ 'TRANSFERRED_TO': 'TransferLink',
498
+ 'BUNDLE_TRADE': 'BundleTradeLink',
499
+ 'COPIED_TRADE': 'CopiedTradeLink',
500
+ 'COORDINATED_ACTIVITY': 'CoordinatedActivityLink',
501
+ 'SNIPED': 'SnipedLink',
502
+ 'MINTED': 'MintedLink',
503
+ 'LOCKED_SUPPLY': 'LockedSupplyLink',
504
+ 'BURNED': 'BurnedLink',
505
+ 'PROVIDED_LIQUIDITY': 'ProvidedLiquidityLink',
506
+ 'WHALE_OF': 'WhaleOfLink',
507
+ 'TOP_TRADER_OF': 'TopTraderOfLink',
508
+ }
509
+
510
+ # Neo4j prop name -> encoder prop name (for fields with mismatched names)
511
+ _PROP_REMAP = {
512
+ 'CopiedTradeLink': {
513
+ 'buy_gap': 'time_gap_on_buy_sec',
514
+ 'sell_gap': 'time_gap_on_sell_sec',
515
+ 'f_buy_total': 'follower_buy_total',
516
+ 'f_sell_total': 'follower_sell_total',
517
+ 'leader_pnl': 'leader_pnl',
518
+ 'follower_pnl': 'follower_pnl',
519
+ },
520
+ }
521
+
522
+ # Only keep fields each encoder actually reads
523
+ _NEEDED_FIELDS = {
524
+ 'TransferLink': ['amount', 'mint'],
525
+ 'BundleTradeLink': ['signatures'], # Neo4j has no total_amount; we derive it below
526
+ 'CopiedTradeLink': ['time_gap_on_buy_sec', 'time_gap_on_sell_sec', 'leader_pnl', 'follower_pnl', 'follower_buy_total', 'follower_sell_total'],
527
+ 'CoordinatedActivityLink': ['time_gap_on_first_sec', 'time_gap_on_second_sec'],
528
+ 'SnipedLink': ['rank', 'sniped_amount'],
529
+ 'MintedLink': ['buy_amount'],
530
+ 'LockedSupplyLink': ['amount'],
531
+ 'BurnedLink': ['amount'],
532
+ 'ProvidedLiquidityLink': ['amount_quote'],
533
+ 'WhaleOfLink': ['holding_pct_at_creation'],
534
+ 'TopTraderOfLink': ['pnl_at_creation'],
535
+ }
536
+
537
+ cleaned_links = {}
538
+ for neo4j_type, data in aggregated_links.items():
539
+ collator_name = _NEO4J_TO_COLLATOR_NAME.get(neo4j_type)
540
+ if not collator_name:
541
+ continue # Skip unknown link types
542
+
543
+ links = data['links']
544
+ edges = data['edges']
545
+
546
+ # Cap
547
+ links = links[:MAX_LINKS_PER_TYPE]
548
+ edges = edges[:MAX_LINKS_PER_TYPE]
549
+
550
+ # Remap property names if needed
551
+ remap = _PROP_REMAP.get(collator_name)
552
+ if remap:
553
+ links = [{remap.get(k, k): v for k, v in l.items()} for l in links]
554
+
555
+ # Strip to only needed fields
556
+ needed = _NEEDED_FIELDS.get(collator_name, [])
557
+ links = [{f: l.get(f, 0) for f in needed} for l in links]
558
+
559
+ # BundleTradeLink: Neo4j has no total_amount; derive from signatures count
560
+ if collator_name == 'BundleTradeLink':
561
+ links = [{'total_amount': len(l.get('signatures', []) if isinstance(l.get('signatures'), list) else [])} for l in links]
562
+
563
+ cleaned_links[collator_name] = {'links': links, 'edges': edges}
564
+
565
+ return all_entities, cleaned_links
566
+
567
+ except Exception as e:
568
+ msg = str(e)
569
+ is_rate_limit = "AuthenticationRateLimit" in msg or "RateLimit" in msg
570
+ is_transient = "ServiceUnavailable" in msg or "TransientError" in msg or "SessionExpired" in msg
571
+
572
+ if is_rate_limit or is_transient:
573
+ if attempt < max_retries:
574
+ sleep_time = backoff_sec * (2 ** attempt)
575
+ print(f"WARN: Neo4j error ({type(e).__name__}). Retrying in {sleep_time}s... (Attempt {attempt+1}/{max_retries})")
576
+ time.sleep(sleep_time)
577
+ continue
578
+
579
+ # If we're here, it's either not retryable or we ran out of retries
580
+ # Ensure we use "FATAL" prefix so the caller knows to stop if required
581
+ raise RuntimeError(f"FATAL: Failed to fetch graph links from Neo4j: {e}") from e
582
+
583
+ def fetch_token_data(self, token_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, Dict[str, Any]]:
584
+ """
585
+ Fetches the latest token data for each address at or before T_cutoff.
586
+ Batches queries to avoid "Max query size exceeded" errors.
587
+ Returns a dictionary mapping token_address to its data.
588
+ """
589
+ if not token_addresses:
590
+ return {}
591
+
592
+ BATCH_SIZE = 1000
593
+ tokens = {}
594
+ total_tokens = len(token_addresses)
595
+ print(f"INFO: Executing query to fetch token data for {total_tokens} tokens in batches of {BATCH_SIZE}.")
596
+
597
+ for i in range(0, total_tokens, BATCH_SIZE):
598
+ batch_addresses = token_addresses[i : i + BATCH_SIZE]
599
+
600
+ # --- NEW: Time-aware query for historical token data ---
601
+ query = """
602
+ WITH ranked_tokens AS (
603
+ SELECT
604
+ *,
605
+ ROW_NUMBER() OVER (PARTITION BY token_address ORDER BY updated_at DESC) as rn
606
+ FROM tokens
607
+ WHERE
608
+ token_address IN %(addresses)s
609
+ AND updated_at <= %(T_cutoff)s
610
+ )
611
+ SELECT token_address, name, symbol, token_uri, protocol, total_supply, decimals
612
+ FROM ranked_tokens
613
+ WHERE rn = 1;
614
+ """
615
+ params = {'addresses': batch_addresses, 'T_cutoff': T_cutoff}
616
+
617
+ try:
618
+ rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
619
+ if not rows:
620
+ continue
621
+
622
+ # Get column names from the query result description
623
+ columns = [col[0] for col in columns_info]
624
+
625
+ for row in rows:
626
+ token_dict = dict(zip(columns, row))
627
+ token_addr = token_dict.get('token_address')
628
+ if token_addr:
629
+ # The 'tokens' table in the schema has 'token_address' but the
630
+ # collator expects 'address'. We'll add it for compatibility.
631
+ token_dict['address'] = token_addr
632
+ tokens[token_addr] = token_dict
633
+
634
+ except Exception as e:
635
+ print(f"ERROR: Failed to fetch token data for batch {i}: {e}")
636
+ # Continue next batch
637
+
638
+ return tokens
639
+
640
+ def fetch_deployed_token_details(self, token_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, Dict[str, Any]]:
641
+ """
642
+ Fetches historical details for deployed tokens at or before T_cutoff.
643
+ Batches queries to avoid "Max query size exceeded" errors.
644
+ """
645
+ if not token_addresses:
646
+ return {}
647
+
648
+ BATCH_SIZE = 1000
649
+ token_details = {}
650
+ total_tokens = len(token_addresses)
651
+ print(f"INFO: Executing query to fetch deployed token details for {total_tokens} tokens in batches of {BATCH_SIZE}.")
652
+
653
+ for i in range(0, total_tokens, BATCH_SIZE):
654
+ batch_addresses = token_addresses[i : i + BATCH_SIZE]
655
+
656
+ # --- NEW: Time-aware query for historical deployed token details ---
657
+ query = """
658
+ WITH ranked_tokens AS (
659
+ SELECT
660
+ *,
661
+ ROW_NUMBER() OVER (PARTITION BY token_address ORDER BY updated_at DESC) as rn
662
+ FROM tokens
663
+ WHERE
664
+ token_address IN %(addresses)s
665
+ AND updated_at <= %(T_cutoff)s
666
+ ),
667
+ ranked_token_metrics AS (
668
+ SELECT
669
+ token_address,
670
+ ath_price_usd,
671
+ ROW_NUMBER() OVER (PARTITION BY token_address ORDER BY updated_at DESC) as rn
672
+ FROM token_metrics
673
+ WHERE
674
+ token_address IN %(addresses)s
675
+ AND updated_at <= %(T_cutoff)s
676
+ ),
677
+ latest_tokens AS (
678
+ SELECT *
679
+ FROM ranked_tokens
680
+ WHERE rn = 1
681
+ ),
682
+ latest_token_metrics AS (
683
+ SELECT *
684
+ FROM ranked_token_metrics
685
+ WHERE rn = 1
686
+ )
687
+ SELECT
688
+ lt.token_address,
689
+ lt.created_at,
690
+ lt.updated_at,
691
+ ltm.ath_price_usd,
692
+ lt.total_supply,
693
+ lt.decimals,
694
+ (lt.launchpad != lt.protocol) AS has_migrated
695
+ FROM latest_tokens AS lt
696
+ LEFT JOIN latest_token_metrics AS ltm
697
+ ON lt.token_address = ltm.token_address;
698
+ """
699
+ params = {'addresses': batch_addresses, 'T_cutoff': T_cutoff}
700
+
701
+ try:
702
+ rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
703
+ if not rows:
704
+ continue
705
+
706
+ columns = [col[0] for col in columns_info]
707
+ for row in rows:
708
+ token_details[row[0]] = dict(zip(columns, row))
709
+ except Exception as e:
710
+ print(f"ERROR: Failed to fetch deployed token details for batch {i}: {e}")
711
+ # Continue next batch
712
+
713
+ return token_details
714
+
715
+ def fetch_trades_for_token(self, token_address: str, T_cutoff: datetime.datetime, count_threshold: int, early_limit: int, recent_limit: int, full_history: bool = False) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
716
+ """
717
+ Fetches ALL trades for a token up to T_cutoff, ordered by time.
718
+
719
+ Notes:
720
+ - This intentionally does NOT apply the older fetch-time H/B/H (High-Def / Blurry / High-Def)
721
+ sampling logic. Sequence-length control is handled later in data_loader.py via event-level
722
+ head/tail sampling with MIDDLE/RECENT markers.
723
+ - The function signature still includes legacy H/B/H parameters for compatibility.
724
+ Returns: (all_trades, [], [])
725
+ """
726
+ if not token_address:
727
+ return [], [], []
728
+
729
+ params = {'token_address': token_address, 'T_cutoff': T_cutoff}
730
+ query = "SELECT * FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s ORDER BY timestamp ASC"
731
+ try:
732
+ rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
733
+ if not rows:
734
+ return [], [], []
735
+ columns = [col[0] for col in columns_info]
736
+ all_trades = [dict(zip(columns, row)) for row in rows]
737
+ return all_trades, [], []
738
+ except Exception as e:
739
+ print(f"ERROR: Failed to fetch trades for token {token_address}: {e}")
740
+ return [], [], []
741
+
742
+ def fetch_future_trades_for_token(self,
743
+ token_address: str,
744
+ start_ts: datetime.datetime,
745
+ end_ts: datetime.datetime) -> List[Dict[str, Any]]:
746
+ """
747
+ Fetches successful trades for a token in the window (start_ts, end_ts].
748
+ Used for constructing label targets beyond the cutoff.
749
+ """
750
+ if not token_address or start_ts is None or end_ts is None or start_ts >= end_ts:
751
+ return []
752
+
753
+ query = """
754
+ SELECT *
755
+ FROM trades
756
+ WHERE base_address = %(token_address)s
757
+ AND success = true
758
+ AND timestamp > %(start_ts)s
759
+ AND timestamp <= %(end_ts)s
760
+ ORDER BY timestamp ASC
761
+ """
762
+ params = {
763
+ 'token_address': token_address,
764
+ 'start_ts': start_ts,
765
+ 'end_ts': end_ts
766
+ }
767
+
768
+ try:
769
+ rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
770
+ if not rows:
771
+ return []
772
+ columns = [col[0] for col in columns_info]
773
+ return [dict(zip(columns, row)) for row in rows]
774
+ except Exception as e:
775
+ print(f"ERROR: Failed to fetch future trades for token {token_address}: {e}")
776
+ return []
777
+
778
+ def fetch_transfers_for_token(self, token_address: str, T_cutoff: datetime.datetime, min_amount_threshold: float = 10_000_000) -> List[Dict[str, Any]]:
779
+ """
780
+ Fetches all transfers for a token before T_cutoff, filtering out small amounts.
781
+ """
782
+ if not token_address:
783
+ return []
784
+
785
+ query = """
786
+ SELECT * FROM transfers
787
+ WHERE mint_address = %(token_address)s
788
+ AND timestamp <= %(T_cutoff)s
789
+ AND amount_decimal >= %(min_amount)s
790
+ ORDER BY timestamp ASC
791
+ """
792
+ params = {'token_address': token_address, 'T_cutoff': T_cutoff, 'min_amount': min_amount_threshold}
793
+ print(f"INFO: Fetching significant transfers for {token_address} (amount >= {min_amount_threshold}).")
794
+
795
+ try:
796
+ # This query no longer uses H/B/H, it fetches all significant transfers
797
+ rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
798
+ if not rows: return []
799
+ columns = [col[0] for col in columns_info]
800
+ return [dict(zip(columns, row)) for row in rows]
801
+ except Exception as e:
802
+ print(f"ERROR: Failed to fetch transfers for token {token_address}: {e}")
803
+ return []
804
+
805
+ def fetch_pool_creations_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
806
+ """
807
+ Fetches pool creation records where the token is the base asset.
808
+ """
809
+ if not token_address:
810
+ return []
811
+
812
+ query = """
813
+ SELECT
814
+ signature,
815
+ timestamp,
816
+ slot,
817
+ success,
818
+ error,
819
+ priority_fee,
820
+ protocol,
821
+ creator_address,
822
+ pool_address,
823
+ base_address,
824
+ quote_address,
825
+ lp_token_address,
826
+ initial_base_liquidity,
827
+ initial_quote_liquidity,
828
+ base_decimals,
829
+ quote_decimals
830
+ FROM pool_creations
831
+ WHERE base_address = %(token_address)s
832
+ AND timestamp <= %(T_cutoff)s
833
+ ORDER BY timestamp ASC
834
+ """
835
+ params = {'token_address': token_address, 'T_cutoff': T_cutoff}
836
+ # print(f"INFO: Fetching pool creation events for {token_address}.")
837
+
838
+ try:
839
+ rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
840
+ if not rows:
841
+ return []
842
+ columns = [col[0] for col in columns_info]
843
+ return [dict(zip(columns, row)) for row in rows]
844
+ except Exception as e:
845
+ print(f"ERROR: Failed to fetch pool creations for token {token_address}: {e}")
846
+ return []
847
+
848
+ def fetch_liquidity_changes_for_pools(self, pool_addresses: List[str], T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
849
+ """
850
+ Fetches liquidity change records for the given pools up to T_cutoff.
851
+ """
852
+ if not pool_addresses:
853
+ return []
854
+
855
+ query = """
856
+ SELECT
857
+ signature,
858
+ timestamp,
859
+ slot,
860
+ success,
861
+ error,
862
+ priority_fee,
863
+ protocol,
864
+ change_type,
865
+ lp_provider,
866
+ pool_address,
867
+ base_amount,
868
+ quote_amount
869
+ FROM liquidity
870
+ WHERE pool_address IN %(pool_addresses)s
871
+ AND timestamp <= %(T_cutoff)s
872
+ ORDER BY timestamp ASC
873
+ """
874
+ params = {'pool_addresses': pool_addresses, 'T_cutoff': T_cutoff}
875
+ # print(f"INFO: Fetching liquidity change events for {len(pool_addresses)} pool(s).")
876
+
877
+ try:
878
+ rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
879
+ if not rows:
880
+ return []
881
+ columns = [col[0] for col in columns_info]
882
+ return [dict(zip(columns, row)) for row in rows]
883
+ except Exception as e:
884
+ print(f"ERROR: Failed to fetch liquidity changes for pools {pool_addresses}: {e}")
885
+ return []
886
+
887
+ def fetch_fee_collections_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
888
+ """
889
+ Fetches fee collection events where the token appears as either token_0 or token_1.
890
+ """
891
+ if not token_address:
892
+ return []
893
+
894
+ query = """
895
+ SELECT
896
+ timestamp,
897
+ signature,
898
+ slot,
899
+ success,
900
+ error,
901
+ priority_fee,
902
+ protocol,
903
+ recipient_address,
904
+ token_0_mint_address,
905
+ token_0_amount,
906
+ token_1_mint_address,
907
+ token_1_amount
908
+ FROM fee_collections
909
+ WHERE (token_0_mint_address = %(token)s OR token_1_mint_address = %(token)s)
910
+ AND timestamp <= %(T_cutoff)s
911
+ ORDER BY timestamp ASC
912
+ """
913
+ params = {'token': token_address, 'T_cutoff': T_cutoff}
914
+ # print(f"INFO: Fetching fee collection events for {token_address}.")
915
+
916
+ try:
917
+ rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
918
+ if not rows:
919
+ return []
920
+ columns = [col[0] for col in columns_info]
921
+ return [dict(zip(columns, row)) for row in rows]
922
+ except Exception as e:
923
+ print(f"ERROR: Failed to fetch fee collections for token {token_address}: {e}")
924
+ return []
925
+
926
+ def fetch_migrations_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
927
+ """
928
+ Fetches migration records for a given token up to T_cutoff.
929
+ """
930
+ if not token_address:
931
+ return []
932
+ query = """
933
+ SELECT
934
+ timestamp,
935
+ signature,
936
+ slot,
937
+ success,
938
+ error,
939
+ priority_fee,
940
+ protocol,
941
+ mint_address,
942
+ virtual_pool_address,
943
+ pool_address,
944
+ migrated_base_liquidity,
945
+ migrated_quote_liquidity
946
+ FROM migrations
947
+ WHERE mint_address = %(token)s
948
+ AND timestamp <= %(T_cutoff)s
949
+ ORDER BY timestamp ASC
950
+ """
951
+ params = {'token': token_address, 'T_cutoff': T_cutoff}
952
+ # print(f"INFO: Fetching migrations for {token_address}.")
953
+ try:
954
+ rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
955
+ if not rows:
956
+ return []
957
+ columns = [col[0] for col in columns_info]
958
+ return [dict(zip(columns, row)) for row in rows]
959
+ except Exception as e:
960
+ print(f"ERROR: Failed to fetch migrations for token {token_address}: {e}")
961
+ return []
962
+
963
+ def fetch_burns_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
964
+ """
965
+ Fetches burn events for a given token up to T_cutoff.
966
+ Schema: burns(timestamp, signature, slot, success, error, priority_fee, mint_address, source, amount, amount_decimal, source_balance)
967
+ """
968
+ if not token_address:
969
+ return []
970
+
971
+ query = """
972
+ SELECT
973
+ timestamp,
974
+ signature,
975
+ slot,
976
+ success,
977
+ error,
978
+ priority_fee,
979
+ mint_address,
980
+ source,
981
+ amount,
982
+ amount_decimal,
983
+ source_balance
984
+ FROM burns
985
+ WHERE mint_address = %(token)s
986
+ AND timestamp <= %(T_cutoff)s
987
+ ORDER BY timestamp ASC
988
+ """
989
+ params = {'token': token_address, 'T_cutoff': T_cutoff}
990
+ # print(f"INFO: Fetching burn events for {token_address}.")
991
+
992
+ try:
993
+ rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
994
+ if not rows:
995
+ return []
996
+ columns = [col[0] for col in columns_info]
997
+ return [dict(zip(columns, row)) for row in rows]
998
+ except Exception as e:
999
+ print(f"ERROR: Failed to fetch burns for token {token_address}: {e}")
1000
+ return []
1001
+
1002
+ def fetch_supply_locks_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
1003
+ """
1004
+ Fetches supply lock events for a given token up to T_cutoff.
1005
+ Schema: supply_locks(timestamp, signature, slot, success, error, priority_fee, protocol, contract_address, sender, recipient, mint_address, total_locked_amount, final_unlock_timestamp)
1006
+ """
1007
+ if not token_address:
1008
+ return []
1009
+
1010
+ query = """
1011
+ SELECT
1012
+ timestamp,
1013
+ signature,
1014
+ slot,
1015
+ success,
1016
+ error,
1017
+ priority_fee,
1018
+ protocol,
1019
+ contract_address,
1020
+ sender,
1021
+ recipient,
1022
+ mint_address,
1023
+ total_locked_amount,
1024
+ final_unlock_timestamp
1025
+ FROM supply_locks
1026
+ WHERE mint_address = %(token)s
1027
+ AND timestamp <= %(T_cutoff)s
1028
+ ORDER BY timestamp ASC
1029
+ """
1030
+ params = {'token': token_address, 'T_cutoff': T_cutoff}
1031
+ # print(f"INFO: Fetching supply lock events for {token_address}.")
1032
+
1033
+ try:
1034
+ rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
1035
+ if not rows:
1036
+ return []
1037
+ columns = [col[0] for col in columns_info]
1038
+ return [dict(zip(columns, row)) for row in rows]
1039
+ except Exception as e:
1040
+ print(f"ERROR: Failed to fetch supply locks for token {token_address}: {e}")
1041
+ return []
1042
+
1043
+ def fetch_token_holders_for_snapshot(self, token_address: str, T_cutoff: datetime.datetime, limit: int = 200) -> List[Dict[str, Any]]:
1044
+ """
1045
+ Fetch top holders for a token at or before T_cutoff for snapshot purposes.
1046
+ Returns rows with wallet_address and current_balance (>0), ordered by balance desc.
1047
+ """
1048
+ if not token_address:
1049
+ return []
1050
+ query = """
1051
+ WITH point_in_time_holdings AS (
1052
+ SELECT *,
1053
+ ROW_NUMBER() OVER(PARTITION BY wallet_address, mint_address ORDER BY updated_at DESC) as rn_per_holding
1054
+ FROM wallet_holdings
1055
+ WHERE mint_address = %(token)s AND updated_at <= %(T_cutoff)s
1056
+ )
1057
+ SELECT wallet_address, current_balance
1058
+ FROM point_in_time_holdings
1059
+ WHERE rn_per_holding = 1 AND current_balance > 0
1060
+ ORDER BY current_balance DESC
1061
+ LIMIT %(limit)s;
1062
+ """
1063
+ params = {'token': token_address, 'T_cutoff': T_cutoff, 'limit': int(limit)}
1064
+ # print(f"INFO: Fetching top holders for snapshot for {token_address} (limit {limit}).")
1065
+ try:
1066
+ rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
1067
+ if not rows:
1068
+ return []
1069
+ columns = [col[0] for col in columns_info]
1070
+ return [dict(zip(columns, row)) for row in rows]
1071
+ except Exception as e:
1072
+ print(f"ERROR: Failed to fetch token holders for {token_address}: {e}")
1073
+ return []
1074
+
1075
+ def fetch_total_holders_count_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> int:
1076
+ """
1077
+ Returns the total number of wallets holding the token (current_balance > 0)
1078
+ at or before T_cutoff.
1079
+ """
1080
+ if not token_address:
1081
+ return 0
1082
+ query = """
1083
+ WITH point_in_time_holdings AS (
1084
+ SELECT *,
1085
+ ROW_NUMBER() OVER(PARTITION BY wallet_address, mint_address ORDER BY updated_at DESC) as rn_per_holding
1086
+ FROM wallet_holdings
1087
+ WHERE mint_address = %(token)s AND updated_at <= %(T_cutoff)s
1088
+ )
1089
+ SELECT count()
1090
+ FROM point_in_time_holdings
1091
+ WHERE rn_per_holding = 1 AND current_balance > 0;
1092
+ """
1093
+ params = {'token': token_address, 'T_cutoff': T_cutoff}
1094
+ # print(f"INFO: Counting total holders for {token_address} at timestamp {T_cutoff}.")
1095
+ try:
1096
+ rows = self.db_client.execute(query, params)
1097
+ if not rows:
1098
+ return 0
1099
+ return int(rows[0][0])
1100
+ except Exception as e:
1101
+ print(f"ERROR: Failed to count total holders for token {token_address}: {e}")
1102
+ return 0
1103
+
1104
+ def fetch_holder_snapshot_stats_for_token(self, token_address: str, T_cutoff: datetime.datetime, limit: int = 200) -> Tuple[int, List[Dict[str, Any]]]:
1105
+ """
1106
+ Fetch total holder count at a point in time.
1107
+ Returns (count, top_holders_list).
1108
+ Uses the indexed wallet_holdings table directly - efficient due to mint_address filter.
1109
+ """
1110
+ if not token_address:
1111
+ return 0, []
1112
+
1113
+ holder_count = self.fetch_total_holders_count_for_token(token_address, T_cutoff)
1114
+ return holder_count, []
1115
+ def fetch_raw_token_data(
1116
+ self,
1117
+ token_address: str,
1118
+ creator_address: str,
1119
+ mint_timestamp: datetime.datetime,
1120
+ max_horizon_seconds: int = 3600,
1121
+ include_wallet_data: bool = True,
1122
+ include_graph: bool = True,
1123
+ min_trades: int = 0,
1124
+ full_history: bool = False,
1125
+ prune_failed: bool = False,
1126
+ prune_transfers: bool = False
1127
+ ) -> Optional[Dict[str, Any]]:
1128
+ """
1129
+ Fetches ALL available data for a token up to the maximum horizon.
1130
+ This data is agnostic of T_cutoff and will be masked/filtered dynamically during training.
1131
+ Wallet/graph data can be skipped to avoid caching T_cutoff-dependent features.
1132
+
1133
+ Args:
1134
+ full_history: If True, fetches ALL trades ignoring H/B/H limits.
1135
+ prune_failed: If True, filters out failed trades from the result.
1136
+ prune_transfers: If True, skips fetching transfers entirely.
1137
+ """
1138
+
1139
+ # 1. Calculate the absolute maximum timestamp we care about (mint + max_horizon)
1140
+ # We fetch everything up to this point.
1141
+ max_limit_time = mint_timestamp + datetime.timedelta(seconds=max_horizon_seconds)
1142
+
1143
+ # 2. Fetch all trades up to max_limit_time
1144
+ # Note: We pass None as T_cutoff to fetch_trades_for_token if we want *everything*,
1145
+ # but here we likely want to bound it by our max training horizon to avoid fetching months of data.
1146
+ # However, the existing method signature expects T_cutoff.
1147
+ # So we pass max_limit_time as the "cutoff" for the purpose of raw data collection.
1148
+
1149
+ # We use a large enough limit to get all relevant trades for the session
1150
+ # If full_history is True, these limits are ignored inside the method.
1151
+ early_trades, middle_trades, recent_trades = self.fetch_trades_for_token(
1152
+ token_address, max_limit_time, 30000, 10000, 15000, full_history=full_history
1153
+ )
1154
+
1155
+ # Combine and deduplicate trades
1156
+ all_trades = {}
1157
+ for t in early_trades + middle_trades + recent_trades:
1158
+ # key: (slot, tx_idx, instr_idx)
1159
+ key = (t.get('slot'), t.get('transaction_index'), t.get('instruction_index'), t.get('signature'))
1160
+ all_trades[key] = t
1161
+
1162
+ sorted_trades = sorted(list(all_trades.values()), key=lambda x: x['timestamp'])
1163
+
1164
+ # --- PRUNING FAILED TRADES ---
1165
+ if prune_failed:
1166
+ original_count = len(sorted_trades)
1167
+ sorted_trades = [t for t in sorted_trades if t.get('success', False)]
1168
+ if len(sorted_trades) < original_count:
1169
+ # print(f" INFO: Pruned {original_count - len(sorted_trades)} failed trades.")
1170
+ pass
1171
+
1172
+ if len(sorted_trades) < min_trades:
1173
+ print(f" SKIP: Token {token_address} has only {len(sorted_trades)} trades (min required: {min_trades}). skipping fetches.")
1174
+ return None
1175
+
1176
+ # 3. Fetch other events
1177
+ # --- PRUNING TRANSFERS ---
1178
+ if prune_transfers:
1179
+ transfers = []
1180
+ # print(" INFO: Pruning transfers (skipping fetch).")
1181
+ else:
1182
+ transfers = self.fetch_transfers_for_token(token_address, max_limit_time, 0.0) # 0.0 means fetch all
1183
+
1184
+ pool_creations = self.fetch_pool_creations_for_token(token_address, max_limit_time)
1185
+
1186
+ # Collect pool addresses to fetch liquidity changes
1187
+ pool_addresses = [p['pool_address'] for p in pool_creations if p.get('pool_address')]
1188
+ liquidity_changes = []
1189
+ if pool_addresses:
1190
+ liquidity_changes = self.fetch_liquidity_changes_for_pools(pool_addresses, max_limit_time)
1191
+
1192
+ fee_collections = self.fetch_fee_collections_for_token(token_address, max_limit_time)
1193
+ burns = self.fetch_burns_for_token(token_address, max_limit_time)
1194
+ supply_locks = self.fetch_supply_locks_for_token(token_address, max_limit_time)
1195
+ migrations = self.fetch_migrations_for_token(token_address, max_limit_time)
1196
+
1197
+ profile_data = {}
1198
+ social_data = {}
1199
+ holdings_data = {}
1200
+ deployed_token_details = {}
1201
+ fetched_graph_entities = {}
1202
+ graph_links = {}
1203
+
1204
+ unique_wallets = set()
1205
+ if include_wallet_data or include_graph:
1206
+ # Identify wallets that interacted with the token up to max_limit_time.
1207
+ unique_wallets.add(creator_address)
1208
+ for t in sorted_trades:
1209
+ if t.get('maker'):
1210
+ unique_wallets.add(t['maker'])
1211
+ for t in transfers:
1212
+ if t.get('source'):
1213
+ unique_wallets.add(t['source'])
1214
+ if t.get('destination'):
1215
+ unique_wallets.add(t['destination'])
1216
+ for p in pool_creations:
1217
+ if p.get('creator_address'):
1218
+ unique_wallets.add(p['creator_address'])
1219
+ for l in liquidity_changes:
1220
+ if l.get('lp_provider'):
1221
+ unique_wallets.add(l['lp_provider'])
1222
+
1223
+ if include_wallet_data and unique_wallets:
1224
+ # Profiles/holdings are time-dependent; only fetch if explicitly requested.
1225
+ profile_data, social_data = self.fetch_wallet_profiles_and_socials(list(unique_wallets), max_limit_time)
1226
+ holdings_data = self.fetch_wallet_holdings(list(unique_wallets), max_limit_time)
1227
+
1228
+ all_deployed_tokens = set()
1229
+ for profile in profile_data.values():
1230
+ all_deployed_tokens.update(profile.get('deployed_tokens', []))
1231
+ if all_deployed_tokens:
1232
+ deployed_token_details = self.fetch_deployed_token_details(list(all_deployed_tokens), max_limit_time)
1233
+
1234
+ if include_graph and unique_wallets:
1235
+ graph_seed_wallets = list(unique_wallets)
1236
+ if len(graph_seed_wallets) > 100:
1237
+ pass
1238
+ fetched_graph_entities, graph_links = self.fetch_graph_links(
1239
+ graph_seed_wallets,
1240
+ max_limit_time,
1241
+ max_degrees=1
1242
+ )
1243
+
1244
+ return {
1245
+ "token_address": token_address,
1246
+ "creator_address": creator_address,
1247
+ "mint_timestamp": mint_timestamp,
1248
+ "max_limit_time": max_limit_time,
1249
+ "trades": sorted_trades,
1250
+ "transfers": transfers,
1251
+ "pool_creations": pool_creations,
1252
+ "liquidity_changes": liquidity_changes,
1253
+ "fee_collections": fee_collections,
1254
+ "burns": burns,
1255
+ "supply_locks": supply_locks,
1256
+ "migrations": migrations,
1257
+ "profiles": profile_data,
1258
+ "socials": social_data,
1259
+ "holdings": holdings_data,
1260
+ "deployed_token_details": deployed_token_details,
1261
+ "graph_entities": fetched_graph_entities,
1262
+ "graph_links": graph_links
1263
+ }
data/data_collator.py CHANGED
@@ -144,23 +144,32 @@ class MemecoinCollator:
144
  item_wallet_addr_to_global_idx = {addr: wallet_addr_to_batch_idx.get(addr, self.entity_pad_idx) for addr in item_wallets.keys()}
145
  item_token_addr_to_global_idx = {addr: token_addr_to_batch_idx.get(addr, self.entity_pad_idx) for addr in item_tokens.keys()}
146
  for link_name, data in item.get('graph_links', {}).items():
147
- aggregated_links[link_name]['links_list'].extend(data.get('links', []))
148
  triplet = vocab.LINK_NAME_TO_TRIPLET.get(link_name)
149
  if not triplet: continue
150
  src_type, _, dst_type = triplet
151
  edges = data.get('edges')
152
- if not edges: continue
 
 
153
  src_map = item_wallet_addr_to_global_idx if src_type == 'wallet' else item_token_addr_to_global_idx
154
  dst_map = item_wallet_addr_to_global_idx if dst_type == 'wallet' else item_token_addr_to_global_idx
 
155
  remapped_edge_list = []
156
- for src_addr, dst_addr in edges:
 
 
157
  src_idx_global = src_map.get(src_addr, self.entity_pad_idx)
158
  dst_idx_global = dst_map.get(dst_addr, self.entity_pad_idx)
 
159
  if src_idx_global != self.entity_pad_idx and dst_idx_global != self.entity_pad_idx:
160
  remapped_edge_list.append([src_idx_global, dst_idx_global])
 
 
161
  if remapped_edge_list:
162
  remapped_edge_tensor = torch.tensor(remapped_edge_list, device=self.device, dtype=torch.long).t()
163
  aggregated_links[link_name]['edge_index_list'].append(remapped_edge_tensor)
 
164
  if link_name == "TransferLink":
165
  link_props = data.get('links', [])
166
  derived_edges = []
@@ -737,7 +746,7 @@ class MemecoinCollator:
737
  # Labels
738
  'labels': torch.stack([item['labels'] for item in batch]) if batch and 'labels' in batch[0] else None,
739
  'labels_mask': torch.stack([item['labels_mask'] for item in batch]) if batch and 'labels_mask' in batch[0] else None,
740
- 'quality_score': torch.stack([item['quality_score'] for item in batch]) if batch and 'quality_score' in batch[0] else None,
741
  'class_id': torch.tensor([item.get('class_id', 0) for item in batch], dtype=torch.long),
742
  # Debug info
743
  'token_addresses': [item.get('token_address', 'unknown') for item in batch],
 
144
  item_wallet_addr_to_global_idx = {addr: wallet_addr_to_batch_idx.get(addr, self.entity_pad_idx) for addr in item_wallets.keys()}
145
  item_token_addr_to_global_idx = {addr: token_addr_to_batch_idx.get(addr, self.entity_pad_idx) for addr in item_tokens.keys()}
146
  for link_name, data in item.get('graph_links', {}).items():
147
+ # aggregated_links[link_name]['links_list'].extend(data.get('links', [])) - REMOVED: Now handled inside the loop for sync
148
  triplet = vocab.LINK_NAME_TO_TRIPLET.get(link_name)
149
  if not triplet: continue
150
  src_type, _, dst_type = triplet
151
  edges = data.get('edges')
152
+ link_props_list = data.get('links', [])
153
+ if not edges or not link_props_list: continue
154
+
155
  src_map = item_wallet_addr_to_global_idx if src_type == 'wallet' else item_token_addr_to_global_idx
156
  dst_map = item_wallet_addr_to_global_idx if dst_type == 'wallet' else item_token_addr_to_global_idx
157
+
158
  remapped_edge_list = []
159
+ valid_link_props = []
160
+
161
+ for (src_addr, dst_addr), props in zip(edges, link_props_list):
162
  src_idx_global = src_map.get(src_addr, self.entity_pad_idx)
163
  dst_idx_global = dst_map.get(dst_addr, self.entity_pad_idx)
164
+
165
  if src_idx_global != self.entity_pad_idx and dst_idx_global != self.entity_pad_idx:
166
  remapped_edge_list.append([src_idx_global, dst_idx_global])
167
+ valid_link_props.append(props)
168
+
169
  if remapped_edge_list:
170
  remapped_edge_tensor = torch.tensor(remapped_edge_list, device=self.device, dtype=torch.long).t()
171
  aggregated_links[link_name]['edge_index_list'].append(remapped_edge_tensor)
172
+ aggregated_links[link_name]['links_list'].extend(valid_link_props)
173
  if link_name == "TransferLink":
174
  link_props = data.get('links', [])
175
  derived_edges = []
 
746
  # Labels
747
  'labels': torch.stack([item['labels'] for item in batch]) if batch and 'labels' in batch[0] else None,
748
  'labels_mask': torch.stack([item['labels_mask'] for item in batch]) if batch and 'labels_mask' in batch[0] else None,
749
+ 'quality_score': torch.stack([item['quality_score'] if isinstance(item['quality_score'], torch.Tensor) else torch.tensor(item['quality_score'], dtype=torch.float32) for item in batch]) if batch and 'quality_score' in batch[0] else None,
750
  'class_id': torch.tensor([item.get('class_id', 0) for item in batch], dtype=torch.long),
751
  # Debug info
752
  'token_addresses': [item.get('token_address', 'unknown') for item in batch],
data/ohlc_stats.npz CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:499481b1050c456cb48eddbfd2a4437c8b686715e8eec7c74e8edf2b43191591
3
  size 1660
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3af6751fb5666ccfd4c61d27c549e5fcd71d964090836f9d3646d6f1d63224c0
3
  size 1660
log.log CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9a102d8b3b6d3be1c81eac0be542ca3f91e17b4612ac00b50843669fa4e38ba5
3
- size 57319
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:41a901f956af52a553855651ff68f78a817ad4fa5b108efde1034e22a16724a0
3
+ size 4577
models/graph_updater.py CHANGED
@@ -400,10 +400,10 @@ class GraphUpdater(nn.Module):
400
 
401
  # Use vocabulary to get the triplet (src, rel, dst)
402
  # Make sure ID_TO_LINK_TYPE is correctly populated
403
- if link_name not in vocabulary.LINK_NAME_TO_TRIPLET:
404
  print(f"Warning: Link name '{link_name}' not found in vocabulary.LINK_NAME_TO_TRIPLET. Skipping.")
405
  continue
406
- src_type, rel_type, dst_type = vocabulary.LINK_NAME_TO_TRIPLET[link_name]
407
 
408
  # Check if encoder exists for this link name
409
  if link_name not in self.edge_encoders:
@@ -466,10 +466,9 @@ class GraphUpdater(nn.Module):
466
  print(f"Warning: Relation '{rel_type}' missing in block {block_key}. Skipping.")
467
  continue
468
 
469
- # *** THE FIX ***
470
- # Use scatter_add_ to accumulate messages for the destination node type.
471
- # This correctly handles multiple edge types pointing to the same node type.
472
- msg_aggregates[dst_type].scatter_add_(0, edge_index[1].unsqueeze(1).expand_as(messages), messages)
473
 
474
  # --- Aggregation & Update (Residual Connection) ---
475
  x_next = {}
 
400
 
401
  # Use vocabulary to get the triplet (src, rel, dst)
402
  # Make sure ID_TO_LINK_TYPE is correctly populated
403
+ if link_name not in models.vocabulary.LINK_NAME_TO_TRIPLET:
404
  print(f"Warning: Link name '{link_name}' not found in vocabulary.LINK_NAME_TO_TRIPLET. Skipping.")
405
  continue
406
+ src_type, rel_type, dst_type = models.vocabulary.LINK_NAME_TO_TRIPLET[link_name]
407
 
408
  # Check if encoder exists for this link name
409
  if link_name not in self.edge_encoders:
 
466
  print(f"Warning: Relation '{rel_type}' missing in block {block_key}. Skipping.")
467
  continue
468
 
469
+ # GATv2Conv output is already per-destination-node (shape [num_dst_nodes, node_dim])
470
+ # NOT per-edge. So we directly accumulate, no scatter needed.
471
+ msg_aggregates[dst_type] += messages
 
472
 
473
  # --- Aggregation & Update (Residual Connection) ---
474
  x_next = {}
sample_12LJX4a83B4tCuZ1_3.json ADDED
The diff for this file is too large to render. See raw diff
 
scripts/.ipynb_checkpoints/cache_dataset-checkpoint.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import sys
4
+ import argparse
5
+ import numpy as np
6
+ import datetime
7
+ import torch
8
+ import json
9
+ import math
10
+ from pathlib import Path
11
+ from tqdm import tqdm
12
+ from dotenv import load_dotenv
13
+ import huggingface_hub
14
+ import logging
15
+ from concurrent.futures import ProcessPoolExecutor, as_completed
16
+ import multiprocessing as mp
17
+
18
+ logging.getLogger("httpx").setLevel(logging.WARNING)
19
+ logging.getLogger("transformers").setLevel(logging.ERROR)
20
+ logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
21
+
22
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
23
+
24
+ from scripts.analyze_distribution import get_return_class_map
25
+ from scripts.compute_quality_score import get_token_quality_scores, fetch_token_metrics, _bucket_id, _midrank_percentiles, EPS
26
+
27
+ from clickhouse_driver import Client as ClickHouseClient
28
+ from neo4j import GraphDatabase
29
+
30
+ _worker_dataset = None
31
+ _worker_return_class_map = None
32
+ _worker_quality_scores_map = None
33
+
34
+
35
+ def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
36
+ global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
37
+ from data.data_loader import OracleDataset
38
+ from data.data_fetcher import DataFetcher
39
+
40
+ clickhouse_client = ClickHouseClient(host=db_config['clickhouse_host'], port=db_config['clickhouse_port'])
41
+ neo4j_driver = GraphDatabase.driver(db_config['neo4j_uri'], auth=(db_config['neo4j_user'], db_config['neo4j_password']))
42
+ data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
43
+
44
+ _worker_dataset = OracleDataset(
45
+ data_fetcher=data_fetcher,
46
+ max_samples=dataset_config['max_samples'],
47
+ start_date=dataset_config['start_date'],
48
+ ohlc_stats_path=dataset_config['ohlc_stats_path'],
49
+ horizons_seconds=dataset_config['horizons_seconds'],
50
+ quantiles=dataset_config['quantiles'],
51
+ min_trade_usd=dataset_config['min_trade_usd'],
52
+ max_seq_len=dataset_config['max_seq_len']
53
+ )
54
+ _worker_dataset.sampled_mints = dataset_config['sampled_mints']
55
+ _worker_return_class_map = return_class_map
56
+ _worker_quality_scores_map = quality_scores_map
57
+
58
+
59
+ def _process_single_token_context(args):
60
+ idx, mint_addr, samples_per_token, output_dir = args
61
+ global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
62
+ try:
63
+ class_id = _worker_return_class_map.get(mint_addr)
64
+ if class_id is None:
65
+ return {'status': 'skipped', 'reason': 'not in class map', 'mint': mint_addr}
66
+ contexts = _worker_dataset.__cacheitem_context__(idx, num_samples_per_token=samples_per_token)
67
+ if not contexts:
68
+ return {'status': 'skipped', 'reason': 'no valid contexts', 'mint': mint_addr}
69
+ q_score = _worker_quality_scores_map.get(mint_addr)
70
+ if q_score is None:
71
+ return {'status': 'skipped', 'reason': 'no quality score', 'mint': mint_addr}
72
+ saved_files = []
73
+ for ctx_idx, ctx in enumerate(contexts):
74
+ ctx["quality_score"] = q_score
75
+ ctx["class_id"] = class_id
76
+ ctx["source_token"] = mint_addr
77
+ ctx["cache_mode"] = "context"
78
+ filename = f"sample_{mint_addr[:16]}_{ctx_idx}.pt"
79
+ output_path = Path(output_dir) / filename
80
+ torch.save(ctx, output_path)
81
+ saved_files.append(filename)
82
+ return {'status': 'success', 'mint': mint_addr, 'class_id': class_id, 'q_score': q_score, 'n_contexts': len(contexts), 'n_events': len(contexts[0].get('event_sequence', [])) if contexts else 0, 'files': saved_files}
83
+ except Exception as e:
84
+ import traceback
85
+ return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()}
86
+
87
+
88
+ def _process_single_token_raw(args):
89
+ idx, mint_addr, output_dir = args
90
+ global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
91
+ try:
92
+ class_id = _worker_return_class_map.get(mint_addr)
93
+ if class_id is None:
94
+ return {'status': 'skipped', 'reason': 'not in class map', 'mint': mint_addr}
95
+ item = _worker_dataset.__cacheitem__(idx)
96
+ if item is None:
97
+ return {'status': 'skipped', 'reason': 'cacheitem returned None', 'mint': mint_addr}
98
+ q_score = _worker_quality_scores_map.get(mint_addr)
99
+ if q_score is None:
100
+ return {'status': 'skipped', 'reason': 'no quality score', 'mint': mint_addr}
101
+ item["quality_score"] = q_score
102
+ item["class_id"] = class_id
103
+ item["cache_mode"] = "raw"
104
+ filename = f"sample_{mint_addr[:16]}.pt"
105
+ output_path = Path(output_dir) / filename
106
+ torch.save(item, output_path)
107
+ return {'status': 'success', 'mint': mint_addr, 'class_id': class_id, 'q_score': q_score, 'n_trades': len(item.get('trades', [])), 'files': [filename]}
108
+ except Exception as e:
109
+ import traceback
110
+ return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()}
111
+
112
+
113
def compute_save_ohlc_stats(client, output_path):
    """Compute global price/trade-value normalization stats and save as .npz.

    Queries mean/stddev of USD price, native price, and trade USD value over
    all trades with positive price and value. Falls back to identity
    normalization (mean=0, std=1) when the query returns nothing. Failures
    are logged and swallowed — this is best-effort, never fatal.
    """
    print(f"INFO: Computing OHLC stats...")
    query = """SELECT AVG(t.price_usd), stddevPop(t.price_usd), AVG(t.price), stddevPop(t.price), AVG(t.total_usd), stddevPop(t.total_usd) FROM trades AS t WHERE t.price_usd > 0 AND t.total_usd > 0"""
    # Identity-normalization defaults, used when the query yields no row.
    stats = {"mean_price_usd": 0.0, "std_price_usd": 1.0, "mean_price_native": 0.0, "std_price_native": 1.0, "mean_trade_value_usd": 0.0, "std_trade_value_usd": 1.0}
    try:
        result = client.execute(query)
        if result and result[0]:
            keys = ("mean_price_usd", "std_price_usd", "mean_price_native",
                    "std_price_native", "mean_trade_value_usd", "std_trade_value_usd")
            fallbacks = (0, 1, 0, 1, 0, 1)
            # `v or d` mirrors the original NULL/zero fallback: NULL mean -> 0, NULL/0 std -> 1.
            stats = {k: float(v or d) for k, v, d in zip(keys, result[0], fallbacks)}
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        np.savez(output_path, **stats)
        print(f"INFO: Saved OHLC stats to {output_path}")
    except Exception as e:
        print(f"ERROR: Failed to compute OHLC stats: {e}")
128
+
129
+
130
def main():
    """CLI entry point: build a class-balanced on-disk cache of training samples.

    Pipeline: connect to ClickHouse/Neo4j, compute OHLC normalization stats,
    filter tokens (class map -> trade count -> quality score), build a
    class-balanced task list with per-class multi-sampling, then cache each
    token via a single process or a ProcessPoolExecutor. Supports resume
    (skips tokens that already have cached files) and writes class metadata
    plus an error log into the output directory.
    """
    load_dotenv()
    # 'spawn' avoids forking live DB connections/sockets into worker processes.
    mp.set_start_method('spawn', force=True)

    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        print(f"INFO: Logging in to Hugging Face...")
        huggingface_hub.login(token=hf_token)

    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="data/cache")
    parser.add_argument("--max_samples", type=int, default=None)
    parser.add_argument("--start_date", type=str, default=None)
    parser.add_argument("--ohlc_stats_path", type=str, default="data/ohlc_stats.npz")
    parser.add_argument("--min_trade_usd", type=float, default=0.0)
    parser.add_argument("--cache_mode", type=str, default="raw", choices=["raw", "context"])
    parser.add_argument("--context_length", type=int, default=8192)
    parser.add_argument("--min_trades", type=int, default=10)
    parser.add_argument("--samples_per_token", type=int, default=1)
    parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
    parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument("--clickhouse_host", type=str, default=os.getenv("CLICKHOUSE_HOST", "localhost"))
    parser.add_argument("--clickhouse_port", type=int, default=int(os.getenv("CLICKHOUSE_PORT", 9000)))
    parser.add_argument("--neo4j_uri", type=str, default=os.getenv("NEO4J_URI", "bolt://localhost:7687"))
    parser.add_argument("--neo4j_user", type=str, default=os.getenv("NEO4J_USER", "neo4j"))
    parser.add_argument("--neo4j_password", type=str, default=os.getenv("NEO4J_PASSWORD", "password"))
    args = parser.parse_args()

    # num_workers == 0 means "auto": use all cores minus a few for the DBs/OS.
    if args.num_workers == 0:
        args.num_workers = max(1, mp.cpu_count() - 4)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    start_date_dt = datetime.datetime.strptime(args.start_date, "%Y-%m-%d") if args.start_date else None

    print(f"INFO: Initializing DB Connections...")
    clickhouse_client = ClickHouseClient(host=args.clickhouse_host, port=args.clickhouse_port)
    neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=(args.neo4j_user, args.neo4j_password))

    try:
        compute_save_ohlc_stats(clickhouse_client, args.ohlc_stats_path)

        # Lazy imports: keep module import light / avoid circular imports.
        from data.data_loader import OracleDataset
        from data.data_fetcher import DataFetcher
        data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)

        print("INFO: Fetching Return Classification Map...")
        return_class_map, _ = get_return_class_map(clickhouse_client)
        print(f"INFO: Loaded {len(return_class_map)} classified tokens.")

        print("INFO: Fetching Quality Scores...")
        quality_scores_map = get_token_quality_scores(clickhouse_client)
        print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")

        dataset = OracleDataset(data_fetcher=data_fetcher, max_samples=args.max_samples, start_date=start_date_dt, ohlc_stats_path=args.ohlc_stats_path, horizons_seconds=args.horizons_seconds, quantiles=args.quantiles, min_trade_usd=args.min_trade_usd, max_seq_len=args.context_length)

        if len(dataset) == 0:
            print("WARNING: No samples. Exiting.")
            return

        # Filter mints by return_class_map (only tokens with a class label)
        original_size = len(dataset.sampled_mints)
        filtered_mints = [m for m in dataset.sampled_mints if m['mint_address'] in return_class_map]
        print(f"INFO: Filtered by class map: {original_size} -> {len(filtered_mints)} tokens")

        # Pre-filter: only keep tokens with >= min_trades trades (fast ClickHouse count query)
        print(f"INFO: Pre-filtering tokens by trade count (>= {args.min_trades} trades)...")
        trade_counts = clickhouse_client.execute("""
            SELECT base_address, count() as cnt
            FROM trades
            GROUP BY base_address
            HAVING cnt >= %(min_trades)s
        """, {'min_trades': args.min_trades})
        valid_tokens = {row[0] for row in trade_counts}
        pre_filter_size = len(filtered_mints)
        filtered_mints = [m for m in filtered_mints if m['mint_address'] in valid_tokens]
        print(f"INFO: Pre-filtered by trade count: {pre_filter_size} -> {len(filtered_mints)} tokens (removed {pre_filter_size - len(filtered_mints)} with < {args.min_trades} trades)")

        # Also filter by quality score availability
        pre_quality_size = len(filtered_mints)
        filtered_mints = [m for m in filtered_mints if m['mint_address'] in quality_scores_map]
        print(f"INFO: Filtered by quality score: {pre_quality_size} -> {len(filtered_mints)} tokens")

        if len(filtered_mints) == 0:
            print("WARNING: No tokens after filtering.")
            return

        print(f"INFO: Cache mode: {args.cache_mode}, Workers: {args.num_workers}")

        # Workers re-create their own DB connections/dataset from these configs
        # (live connections cannot be pickled across the spawn boundary).
        db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
        dataset_config = {'max_samples': args.max_samples, 'start_date': start_date_dt, 'ohlc_stats_path': args.ohlc_stats_path, 'horizons_seconds': args.horizons_seconds, 'quantiles': args.quantiles, 'min_trade_usd': args.min_trade_usd, 'max_seq_len': args.context_length, 'sampled_mints': filtered_mints}

        # Build tasks with class-aware multi-sampling for balanced cache
        import random
        from collections import Counter, defaultdict

        # Count eligible tokens per class
        eligible_class_counts = Counter()
        mints_by_class = defaultdict(list)
        for i, m in enumerate(filtered_mints):
            cid = return_class_map.get(m['mint_address'])
            if cid is not None:
                eligible_class_counts[cid] += 1
                mints_by_class[cid].append((i, m))

        print(f"INFO: Eligible tokens per class: {dict(sorted(eligible_class_counts.items()))}")

        # Compute balanced samples_per_token for each class
        num_classes = len(eligible_class_counts)
        if args.max_samples:
            target_total = args.max_samples
        else:
            target_total = 15000  # Default target: 15k balanced files
        target_per_class = target_total // max(num_classes, 1)

        class_multipliers = {}
        class_token_caps = {}
        for cid, count in eligible_class_counts.items():
            if count >= target_per_class:
                # Enough tokens — 1 sample each, cap token count
                class_multipliers[cid] = 1
                class_token_caps[cid] = target_per_class
            else:
                # Not enough tokens — multi-sample, use all tokens
                # (capped at 10 samples per token to bound duplication)
                class_multipliers[cid] = min(10, max(1, math.ceil(target_per_class / max(count, 1))))
                class_token_caps[cid] = count

        print(f"INFO: Target total: {target_total}, Target per class: {target_per_class}")
        print(f"INFO: Class multipliers: {dict(sorted(class_multipliers.items()))}")
        print(f"INFO: Class token caps: {dict(sorted(class_token_caps.items()))}")

        # Build balanced task list
        tasks = []
        for cid, mint_list in mints_by_class.items():
            random.shuffle(mint_list)
            cap = class_token_caps.get(cid, len(mint_list))
            spt = class_multipliers.get(cid, 1)
            # Override with CLI --samples_per_token if explicitly set > 1
            if args.samples_per_token > 1:
                spt = args.samples_per_token
            for i, m in mint_list[:cap]:
                mint_addr = m['mint_address']
                if args.cache_mode == "context":
                    tasks.append((i, mint_addr, spt, str(output_dir)))
                else:
                    tasks.append((i, mint_addr, str(output_dir)))

        random.shuffle(tasks)  # Shuffle tasks for even load distribution across workers
        expected_files = sum(
            class_multipliers.get(cid, 1) * min(class_token_caps.get(cid, len(ml)), len(ml))
            for cid, ml in mints_by_class.items()
        )
        print(f"INFO: Total tasks: {len(tasks)} (expected ~{expected_files} output files, target ~{target_total})")

        success_count, skipped_count, error_count = 0, 0, 0
        class_distribution = {}

        # --- Resume support: skip tokens that already have cached files ---
        existing_files = set(f.name for f in output_dir.glob("sample_*.pt"))
        if existing_files:
            pre_resume = len(tasks)
            filtered_tasks = []
            already_cached = 0
            for task in tasks:
                mint_addr = task[1]  # task = (idx, mint_addr, ...)
                # Check if any file exists for this mint (context mode: sample_MINT_0.pt, raw mode: sample_MINT.pt)
                mint_prefix = f"sample_{mint_addr[:16]}"
                has_cached = any(ef.startswith(mint_prefix) for ef in existing_files)
                if has_cached:
                    already_cached += 1
                    # Count existing files toward class distribution
                    cid = return_class_map.get(mint_addr)
                    if cid is not None:
                        class_distribution[cid] = class_distribution.get(cid, 0) + 1
                    success_count += 1
                else:
                    filtered_tasks.append(task)
            tasks = filtered_tasks
            print(f"INFO: Resume: {already_cached} tokens already cached, {len(tasks)} remaining (was {pre_resume})")

        print(f"INFO: Starting to cache {len(tasks)} tokens...")
        process_fn = _process_single_token_context if args.cache_mode == "context" else _process_single_token_raw

        import time as _time

        def _log_progress(task_num, total, start_time, recent_times, success_count, skipped_count, error_count):
            """Print progress with rolling ETA every 10 tokens."""
            if (task_num + 1) % 10 == 0 and recent_times:
                avg_time = sum(recent_times) / len(recent_times)
                remaining = total - (task_num + 1)
                eta_seconds = avg_time * remaining
                eta_hours = eta_seconds / 3600
                wall_elapsed = _time.perf_counter() - start_time
                speed = (task_num + 1) / wall_elapsed
                tqdm.write(
                    f" [PROGRESS] {task_num+1}/{total} | "
                    f"Speed: {speed:.1f} tok/s ({speed*60:.0f} tok/min) | "
                    f"Avg: {avg_time:.1f}s/tok | "
                    f"ETA: {eta_hours:.1f}h | "
                    f"OK: {success_count} Skip: {skipped_count} Err: {error_count}"
                )

        # Error log file for diagnosing failures
        error_log_path = Path(args.output_dir) / "cache_errors.log"
        error_samples = []  # First 20 unique error messages

        if args.num_workers == 1:
            print("INFO: Single-threaded mode...")
            _init_worker(db_config, dataset_config, return_class_map, quality_scores_map)
            start_time = _time.perf_counter()
            recent_times = []
            for task_num, task in enumerate(tqdm(tasks, desc="Caching", unit="tok")):
                t0 = _time.perf_counter()
                result = process_fn(task)
                elapsed = _time.perf_counter() - t0
                # Rolling window of the last 50 per-token times for the ETA.
                recent_times.append(elapsed)
                if len(recent_times) > 50:
                    recent_times.pop(0)
                if result['status'] == 'success':
                    success_count += 1
                    class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
                elif result['status'] == 'skipped':
                    skipped_count += 1
                else:
                    error_count += 1
                    err_msg = result.get('error', 'unknown')
                    tqdm.write(f"ERROR: {result['mint'][:16]} - {err_msg}")
                    if len(error_samples) < 20:
                        error_samples.append({'mint': result.get('mint'), 'error': err_msg, 'traceback': result.get('traceback', '')})
                _log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
        else:
            print(f"INFO: Running with {args.num_workers} workers...")
            start_time = _time.perf_counter()
            recent_times = []
            with ProcessPoolExecutor(max_workers=args.num_workers, initializer=_init_worker, initargs=(db_config, dataset_config, return_class_map, quality_scores_map)) as executor:
                futures = {executor.submit(process_fn, task): task for task in tasks}
                for task_num, future in enumerate(tqdm(as_completed(futures), total=len(futures), desc="Caching", unit="tok")):
                    # NOTE(review): t0 is taken after as_completed yields, so
                    # `elapsed` measures only result retrieval, not task runtime.
                    t0 = _time.perf_counter()
                    try:
                        result = future.result(timeout=300)
                        elapsed = _time.perf_counter() - t0
                        recent_times.append(elapsed)
                        if len(recent_times) > 50:
                            recent_times.pop(0)
                        if result['status'] == 'success':
                            success_count += 1
                            class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
                        elif result['status'] == 'skipped':
                            skipped_count += 1
                        else:
                            error_count += 1
                            err_msg = result.get('error', 'unknown')
                            if len(error_samples) < 20:
                                error_samples.append({'mint': result.get('mint'), 'error': err_msg, 'traceback': result.get('traceback', '')})
                            if error_count <= 5:
                                tqdm.write(f"ERROR: {result.get('mint', '?')[:16]} - {err_msg}")
                    except Exception as e:
                        error_count += 1
                        tqdm.write(f"WORKER ERROR: {e}")
                    _log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)

        # Write error log
        if error_samples:
            with open(error_log_path, 'w') as ef:
                for i, es in enumerate(error_samples):
                    ef.write(f"=== Error {i+1} === Token: {es['mint']}\n")
                    ef.write(f"Error: {es['error']}\n")
                    ef.write(f"Traceback:\n{es['traceback']}\n\n")
            print(f"INFO: First {len(error_samples)} error tracebacks saved to {error_log_path}")

        print("INFO: Building metadata...")
        # Re-scan the output dir so metadata covers resumed (pre-existing) files too.
        file_class_map = {}
        for f in sorted(output_dir.glob("sample_*.pt")):
            try:
                file_class_map[f.name] = torch.load(f, map_location="cpu", weights_only=False).get("class_id", 0)
            except:
                pass

        with open(output_dir / "class_metadata.json", 'w') as f:
            json.dump({
                'file_class_map': file_class_map,
                'class_distribution': {str(k): v for k, v in class_distribution.items()},
                'cache_mode': args.cache_mode,
                'num_workers': args.num_workers,
                'horizons_seconds': args.horizons_seconds,
                'quantiles': args.quantiles,
                'class_multipliers': {str(k): v for k, v in class_multipliers.items()},
                'class_token_caps': {str(k): v for k, v in class_token_caps.items()},
                'target_total': target_total,
                'target_per_class': target_per_class,
            }, f, indent=2)

        print(f"\n--- Done ---\nSuccess: {success_count}, Skipped: {skipped_count}, Errors: {error_count}\nFiles: {len(file_class_map)}\nLocation: {output_dir.resolve()}")

    finally:
        # Always release DB resources, even on early return or error.
        clickhouse_client.disconnect()
        neo4j_driver.close()
428
+
429
+
430
# Script entry point: run the cache-building pipeline when executed directly.
if __name__ == "__main__":
    main()
scripts/analyze_distribution.py CHANGED
@@ -313,8 +313,108 @@ def print_stats(name, values):
313
 
314
  print(f" {name}: mean={mean:.4f} p50={p50:.4f} p90={p90:.4f} p99={p99:.4f} nonzero_rate={nonzero_rate:.3f} (n={len(vals)})")
315
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
  def analyze():
317
  client = get_client()
 
 
 
 
318
  data = fetch_all_metrics(client)
319
  final_buckets, thresholds, count_manipulated = _classify_tokens(data)
320
 
 
313
 
314
  print(f" {name}: mean={mean:.4f} p50={p50:.4f} p90={p90:.4f} p99={p99:.4f} nonzero_rate={nonzero_rate:.3f} (n={len(vals)})")
315
 
316
def fetch_wallet_pnl_stats(client):
    """Return the single aggregate row of wallet PnL quantiles (7d/30d), or None if empty."""
    print(" -> Fetching Wallet PnL Quantiles (7d, 30d) - Unique per wallet...")
    # Use argMax to get latest entry per wallet (table is a time-series dump)
    query = """
    WITH unique_wallets AS (
        SELECT
            wallet_address,
            argMax(stats_30d_realized_profit_pnl, updated_at) as pnl_30d,
            argMax(stats_7d_realized_profit_pnl, updated_at) as pnl_7d
        FROM wallet_profile_metrics
        GROUP BY wallet_address
    )
    SELECT
        count() as n,
        countIf(pnl_30d > 0.001) as pos_30d,
        quantiles(0.5, 0.9, 0.95, 0.99, 0.999)(pnl_30d) as q_30d,
        max(pnl_30d) as max_30d,

        countIf(pnl_7d > 0.001) as pos_7d,
        quantiles(0.5, 0.9, 0.95, 0.99, 0.999)(pnl_7d) as q_7d,
        max(pnl_7d) as max_7d
    FROM unique_wallets
    WHERE pnl_30d > -999 OR pnl_7d > -999
    """
    rows = client.execute(query)
    # Single aggregate row expected; None signals "nothing to report".
    return rows[0] if rows else None
343
+
344
def fetch_trade_stats(client):
    """Return the single aggregate row of trade-size quantiles (USD and % of supply), or None."""
    print(" -> Fetching Trade Quantiles (USD & Supply %)...")
    query = """
    SELECT
        count() as n,
        quantiles(0.5, 0.9, 0.95, 0.99, 0.999)(t.total_usd) as q_usd,
        quantiles(0.5, 0.9, 0.95, 0.99, 0.999)((t.base_amount / m.total_supply) * 100) as q_sup
    FROM trades t
    JOIN mints m ON t.base_address = m.mint_address
    WHERE m.total_supply > 0
    """
    rows = client.execute(query)
    # Single aggregate row expected; None signals "nothing to report".
    return rows[0] if rows else None
358
+
359
def fetch_kol_stats(client):
    """Return (total wallets with socials, identified KOL count); (0, 0) when the table is empty."""
    print(" -> Fetching KOL stats from wallet_socials...")
    query = """
    SELECT
        uniq(wallet_address) as total_wallets,
        uniqIf(wallet_address, kolscan_name != '' OR cabalspy_name != '' OR axiom_kol_name != '') as kols
    FROM wallet_socials
    """
    rows = client.execute(query)
    print(f" (DEBUG) KOL query result: {rows}")
    # Empty result collapses to zero counts so callers can divide safely.
    return rows[0] if rows else (0, 0)
372
+
373
def print_quantiles(name, n, pos_rate, q, max_val=None):
    """Pretty-print one distribution block.

    q is the quantile list [p50, p90, p95, p99, p99.9]; pos_rate and max_val
    are optional (None suppresses their lines).
    """
    print(f"\n[{name}] (n={n})")
    if pos_rate is not None:
        print(f" Positive Rate: {pos_rate*100:.1f}%")
    for label, value in zip(("p50", "p90", "p95", "p99", "p99.9"), q):
        print(f" {label}={value:.4f}")
    if max_val is not None:
        print(f" Max={max_val:.4f}")
385
+
386
def analyze_thresholds(client):
    """Run the DB-side distribution analysis: wallet PnL, trade sizes, and KOL counts."""
    print("\n=== THRESHOLD DISTRIBUTION ANALYSIS (DB-Side) ===")

    # 1. PnL
    pnl = fetch_wallet_pnl_stats(client)
    if pnl:
        n, pos_30d, q_30d, max_30d, pos_7d, q_7d, max_7d = pnl
        rate_30d = pos_30d / n if n > 0 else 0
        rate_7d = pos_7d / n if n > 0 else 0
        print_quantiles("Wallet PnL (30d)", n, rate_30d, q_30d, max_30d)
        print_quantiles("Wallet PnL (7d)", n, rate_7d, q_7d, max_7d)

    # 2. Trades
    trades = fetch_trade_stats(client)
    if trades:
        n, q_usd, q_sup = trades
        print_quantiles("Trade USD Size", n, None, q_usd)
        print_quantiles("Trade Supply %", n, None, q_sup)

    # 3. KOLs
    total, kols = fetch_kol_stats(client)
    if total > 0:
        print("\n[KOL Statistics]")
        print(f" Total Wallets with Socials: {total}")
        print(f" Identified KOLs: {kols}")
        print(f" KOL Ratio: {(kols/total)*100:.2f}%")
410
+
411
+
412
  def analyze():
413
  client = get_client()
414
+
415
+ # Run new analysis first
416
+ analyze_thresholds(client)
417
+
418
  data = fetch_all_metrics(client)
419
  final_buckets, thresholds, count_manipulated = _classify_tokens(data)
420
 
scripts/dump_cache_sample.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Dump a cached .pt sample to JSON for manual debugging.
4
+
5
+ Usage:
6
+ python scripts/dump_cache_sample.py # Dump first sample
7
+ python scripts/dump_cache_sample.py --index 5 # Dump sample at index 5
8
+ python scripts/dump_cache_sample.py --file data/cache/sample_ABC123.pt # Dump specific file
9
+ python scripts/dump_cache_sample.py --output debug.json # Custom output path
10
+ """
11
+
12
+ import argparse
13
+ import json
14
+ import sys
15
+ import os
16
+
17
+ # Add project root to path so torch.load can find project modules when unpickling
18
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
19
+
20
+ import torch
21
+ import numpy as np
22
+ from pathlib import Path
23
+ from datetime import datetime
24
+
25
+
26
def convert_to_serializable(obj):
    """Recursively convert an arbitrary object into JSON-serializable form.

    Plain scalars pass through unchanged; numpy scalars become Python
    int/float; ndarrays, tensors, datetimes, bytes, and sets are wrapped in
    tagged dicts (``{"__type__": ...}``); dicts, lists, and tuples are
    converted element-wise (dict keys are stringified, tuples become lists).
    Anything else falls back to a truncated ``repr``-style dict, so the
    result is always safe to pass to ``json.dumps``.
    """
    if obj is None:
        return None
    # bool is a subclass of int, so this branch also passes bools through.
    if isinstance(obj, (str, int, float, bool)):
        return obj
    if isinstance(obj, (np.integer,)):
        return int(obj)
    if isinstance(obj, (np.floating,)):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return {"__type__": "ndarray", "shape": list(obj.shape), "dtype": str(obj.dtype), "data": obj.tolist()}
    if isinstance(obj, torch.Tensor):
        return {"__type__": "tensor", "shape": list(obj.shape), "dtype": str(obj.dtype), "data": obj.tolist()}
    if isinstance(obj, datetime):
        return {"__type__": "datetime", "value": obj.isoformat()}
    if isinstance(obj, bytes):
        return {"__type__": "bytes", "length": len(obj), "preview": obj[:100].hex() if len(obj) > 0 else ""}
    if isinstance(obj, dict):
        return {str(k): convert_to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_to_serializable(item) for item in obj]
    # BUG FIX: elements must be recursed — a set of datetimes/arrays was
    # previously returned unconverted and broke json.dumps downstream.
    # Also accept frozenset (previously fell through to the repr fallback).
    if isinstance(obj, (set, frozenset)):
        return {"__type__": "set", "data": [convert_to_serializable(item) for item in obj]}
    # Fallback: try str representation
    try:
        return {"__type__": type(obj).__name__, "repr": str(obj)[:500]}
    except Exception:  # narrowed from bare except: don't swallow KeyboardInterrupt/SystemExit
        return {"__type__": "unknown", "repr": "<not serializable>"}
+ return {"__type__": "unknown", "repr": "<not serializable>"}
55
+
56
+
57
def main():
    """Load one cached .pt sample, dump it to JSON, and print a short summary.

    The target file is chosen either explicitly (--file) or by index into the
    sorted ``sample_*.pt`` files in --cache_dir. Returns a process exit code:
    0 on success, 1 on any user-facing error (missing file/dir, bad index,
    unloadable file).
    """
    parser = argparse.ArgumentParser(description="Dump cached .pt sample to JSON")
    parser.add_argument("--index", "-i", type=int, default=0, help="Index of sample to dump (default: 0)")
    parser.add_argument("--file", "-f", type=str, default=None, help="Direct path to .pt file (overrides --index)")
    parser.add_argument("--cache_dir", "-c", type=str, default="data/cache", help="Cache directory (default: data/cache)")
    parser.add_argument("--output", "-o", type=str, default=None, help="Output JSON path (default: auto-generated)")
    parser.add_argument("--compact", action="store_true", help="Compact JSON output (no indentation)")
    args = parser.parse_args()

    # Determine which file to load
    if args.file:
        filepath = Path(args.file)
        if not filepath.exists():
            print(f"ERROR: File not found: {filepath}")
            return 1
    else:
        cache_dir = Path(args.cache_dir)
        if not cache_dir.is_dir():
            print(f"ERROR: Cache directory not found: {cache_dir}")
            return 1

        cached_files = sorted(cache_dir.glob("sample_*.pt"))
        if not cached_files:
            print(f"ERROR: No sample_*.pt files found in {cache_dir}")
            return 1

        if args.index >= len(cached_files):
            print(f"ERROR: Index {args.index} out of range. Found {len(cached_files)} files.")
            return 1

        filepath = cached_files[args.index]

    print(f"Loading: {filepath}")

    # Load the .pt file
    try:
        # weights_only=False: cached samples contain arbitrary Python objects,
        # not just tensors (only safe because these are locally produced files).
        data = torch.load(filepath, map_location="cpu", weights_only=False)
    except Exception as e:
        print(f"ERROR: Failed to load file: {e}")
        return 1

    # Convert to JSON-serializable format
    print("Converting to JSON-serializable format...")
    serializable_data = convert_to_serializable(data)

    # Add metadata
    output_data = {
        "__metadata__": {
            "source_file": str(filepath.absolute()),
            "dumped_at": datetime.now().isoformat(),
            "cache_mode": data.get("cache_mode", "unknown") if isinstance(data, dict) else "unknown"
        },
        "data": serializable_data
    }

    # Determine output path
    if args.output:
        output_path = Path(args.output)
    else:
        # Default: Save to current directory (root) instead of inside cache dir
        output_path = Path.cwd() / filepath.with_suffix(".json").name

    # Write JSON
    print(f"Writing to: {output_path}")
    indent = None if args.compact else 2
    with open(output_path, "w") as f:
        json.dump(output_data, f, indent=indent, ensure_ascii=False)

    # Print summary (only meaningful when the sample is a dict payload)
    if isinstance(data, dict):
        print(f"\n=== Summary ===")
        print(f"Top-level keys: {list(data.keys())}")
        print(f"Cache mode: {data.get('cache_mode', 'not specified')}")
        if 'event_sequence' in data:
            print(f"Event count: {len(data['event_sequence'])}")
        if 'trades' in data:
            print(f"Trade count: {len(data['trades'])}")
        if 'source_token' in data:
            print(f"Source token: {data['source_token']}")
        if 'class_id' in data:
            print(f"Class ID: {data['class_id']}")
        if 'quality_score' in data:
            print(f"Quality score: {data['quality_score']}")

    print(f"\nDone! JSON saved to: {output_path}")
    return 0
+ return 0
143
+
144
+
145
# Propagate main()'s integer status to the shell as the process exit code.
if __name__ == "__main__":
    exit(main())
train.py CHANGED
@@ -406,7 +406,7 @@ def main() -> None:
406
  hf_token = os.getenv("HF_TOKEN")
407
  if hf_token:
408
  print(f"Logging in to Hugging Face with token starting with: {hf_token[:4]}...")
409
- huggingface_hub.login(token=hf_token)
410
  else:
411
  print("WARNING: HF_TOKEN not found in environment.")
412
 
@@ -437,7 +437,7 @@ def main() -> None:
437
  collator_encoder = CollatorEncoder(
438
  model_id=collator.model_id,
439
  dtype=init_dtype,
440
- device="cpu" # Collator runs on CPU to save VRAM
441
  )
442
  _set_worker_encoder(collator_encoder)
443
  logger.info("SigLIP encoder pre-loaded successfully.")
 
406
  hf_token = os.getenv("HF_TOKEN")
407
  if hf_token:
408
  print(f"Logging in to Hugging Face with token starting with: {hf_token[:4]}...")
409
+ pass # huggingface_hub.login(token=hf_token)
410
  else:
411
  print("WARNING: HF_TOKEN not found in environment.")
412
 
 
437
  collator_encoder = CollatorEncoder(
438
  model_id=collator.model_id,
439
  dtype=init_dtype,
440
+ device="cuda" # Use GPU for encoding (requires num_workers=0)
441
  )
442
  _set_worker_encoder(collator_encoder)
443
  logger.info("SigLIP encoder pre-loaded successfully.")
train.sh CHANGED
@@ -1,12 +1,12 @@
1
  accelerate launch train.py \
2
- --epochs 10 \
3
  --batch_size 8 \
4
  --learning_rate 1e-4 \
5
  --warmup_ratio 0.1 \
6
  --grad_accum_steps 2 \
7
  --max_grad_norm 1.0 \
8
  --seed 42 \
9
- --log_every 50 \
10
  --save_every 2000 \
11
  --tensorboard_dir runs/oracle \
12
  --checkpoint_dir checkpoints \
@@ -15,8 +15,8 @@ accelerate launch train.py \
15
  --horizons_seconds 30 60 120 240 420 \
16
  --quantiles 0.1 0.5 0.9 \
17
  --ohlc_stats_path ./data/ohlc_stats.npz \
18
- --num_workers 4 \
19
  --pin_memory \
20
  --val_split 0.1 \
21
- --val_every 2000 \
22
  "$@"
 
1
  accelerate launch train.py \
2
+ --epochs 1 \
3
  --batch_size 8 \
4
  --learning_rate 1e-4 \
5
  --warmup_ratio 0.1 \
6
  --grad_accum_steps 2 \
7
  --max_grad_norm 1.0 \
8
  --seed 42 \
9
+ --log_every 3 \
10
  --save_every 2000 \
11
  --tensorboard_dir runs/oracle \
12
  --checkpoint_dir checkpoints \
 
15
  --horizons_seconds 30 60 120 240 420 \
16
  --quantiles 0.1 0.5 0.9 \
17
  --ohlc_stats_path ./data/ohlc_stats.npz \
18
+ --num_workers 0 \
19
  --pin_memory \
20
  --val_split 0.1 \
21
+ --val_every 50 \
22
  "$@"