zirobtc committed
Commit 858826c · 1 Parent(s): 3596954

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+log.log filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,12 @@
+# Ignore the __pycache__ directory anywhere in the repository
+__pycache__/
+
+# Ignore all .txt files anywhere in the repository
+*.txt
+
+# Ignore the 'runs' directory anywhere in the repository, regardless of nesting
+runs/
+
+data/pump_fun
+
+.env
FullCryptoGuide.md ADDED
The diff for this file is too large to render. See raw diff
 
README.md ADDED
@@ -0,0 +1,776 @@
# =========================================
# Entity Encoders
# =========================================
# These are generated offline/streaming and are the "vocabulary" for the model.

<WalletEmbedding> # Embedding of a wallet's relationships, behavior, and history.
<WalletEmbedding> = [
  // Data from the 'wallet_profiles' table (wallet-level lifetime and daily/weekly stats)
  wallet_profiles_row: [
    // Core Info & Timestamps
    age,            // Wallet age (no contextual normalization)
    wallet_address, // Primary wallet identifier

    // Deployed Token Aggregates
    deployed_tokens_count,              // Total tokens created
    deployed_tokens_migrated_pct,       // % that migrated
    deployed_tokens_avg_lifetime_sec,   // Avg duration before the dev starts selling
    deployed_tokens_avg_peak_mc_usd,    // Avg peak market cap
    deployed_tokens_median_peak_mc_usd, // Median peak market cap

    // Metadata & Balances
    balance, // Current SOL balance

    // Lifetime Transaction Counts (total history)
    transfers_in_count,      // Total native transfers received
    transfers_out_count,     // Total native transfers sent
    spl_transfers_in_count,  // Total SPL token transfers received
    spl_transfers_out_count, // Total SPL token transfers sent

    // Lifetime Trading Stats (total history)
    total_buys_count,  // Total buys across all tokens
    total_sells_count, // Total sells across all tokens
    total_winrate,     // Overall trading winrate

    // 1-Day Stats (realized P&L, counts, averages, volume, fees, winrate)
    stats_1d_realized_profit_sol,
    stats_1d_realized_profit_pnl,
    stats_1d_buy_count,
    stats_1d_sell_count,
    stats_1d_transfer_in_count,
    stats_1d_transfer_out_count,
    stats_1d_avg_holding_period,
    stats_1d_total_bought_cost_sol,
    stats_1d_total_sold_income_sol,
    stats_1d_total_fee,
    stats_1d_winrate,
    stats_1d_tokens_traded,

    // 7-Day Stats (realized P&L, counts, averages, volume, fees, winrate)
    stats_7d_realized_profit_sol,
    stats_7d_realized_profit_pnl,
    stats_7d_buy_count,
    stats_7d_sell_count,
    stats_7d_transfer_in_count,
    stats_7d_transfer_out_count,
    stats_7d_avg_holding_period,
    stats_7d_total_bought_cost_sol,
    stats_7d_total_sold_income_sol,
    stats_7d_total_fee,
    stats_7d_winrate,
    stats_7d_tokens_traded,

    // 30-day stats are omitted: too stale to be useful in this context
  ],

  // Data from the 'wallet_socials' table (social media and profile info)
  wallet_socials_row: [
    has_pf_profile,
    has_twitter,
    has_telegram,
    is_exchange_wallet,
    username,
  ],

  // Data from the 'wallet_holdings' table (token-level statistics for held tokens)
  wallet_holdings_pool: [
    <TokenVibeEmbedding>,
    holding_time, // How long the wallet held the token (only currently-held or recently-traded tokens are checked)

    balance_pct_to_supply, // Current quantity of the token held, as % of supply

    // History (Amounts & Costs)
    history_bought_amount_sol,               // Total SOL spent buying the token
    bought_amount_sol_pct_to_native_balance, // How large the buys were relative to the wallet's balance

    // History (Counts)
    history_total_buys,  // Total number of buy transactions
    history_total_sells, // Total number of sell transactions

    // Profit and Loss
    realized_profit_pnl, // Realized P&L as a percentage
    realized_profit_sol,

    // Transfers (non-trade movements)
    history_transfer_in,
    history_transfer_out,

    avarage_trade_gap_seconds,
    total_priority_fees, // Total tips + priority fees
  ]
]
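
The `wallet_profiles_row` fields above ultimately have to become a fixed-order numeric vector. A minimal sketch of that flattening step, assuming log-compression for heavy-tailed counts and balances (the field subset and scaling choice here are illustrative assumptions, not the repo's actual code):

```python
import math

# Hypothetical subset of the wallet_profiles_row fields listed above.
PROFILE_FIELDS = [
    "age", "balance",
    "transfers_in_count", "transfers_out_count",
    "total_buys_count", "total_sells_count", "total_winrate",
    "stats_1d_realized_profit_sol", "stats_7d_realized_profit_sol",
]

def profile_to_features(row: dict) -> list:
    """Flatten a wallet_profiles row into a fixed-order numeric vector.

    Counts and balances are log-compressed; missing fields default to 0.
    """
    feats = []
    for name in PROFILE_FIELDS:
        v = float(row.get(name, 0.0))
        if name.endswith("_count") or name in ("age", "balance"):
            # compress heavy-tailed magnitudes while preserving sign
            v = math.copysign(math.log1p(abs(v)), v)
        feats.append(v)
    return feats
```

Ratios such as `total_winrate` are already bounded in [0, 1], so they pass through unscaled.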

<TokenVibeEmbedding> # Multimodal embedding of a token's identity.
<TokenVibeEmbedding> = [<TokenAddressEmbedding>, <NameEmbedding>, <SymbolEmbedding>, <ImageEmbedding>, protocol_id]

<TextEmbedding>  # Text embedding from the multimodal processor.
<MediaEmbedding> # Multimodal ViT encoder.

# -----------------------------------------
# 1. TradeEncoder
# -----------------------------------------

# Captures large-size trades from any wallet.
[timestamp, 'LargeTrade', relative_ts, <WalletEmbedding>, trade_direction, sol_amount, dex_platform_id, priority_fee, mev_protection, token_amount_pct_of_holding, quote_amount_pct_of_holding, slippage, price_impact, success, is_bundle, total_usd]

# Captures the high-signal "Dev Sold or Bought" event.
[timestamp, 'Deployer_Trade', relative_ts, <CreatorWalletEmbedding>, trade_direction, sol_amount, dex_platform_id, priority_fee, mev_protection, token_amount_pct_of_holding, quote_amount_pct_of_holding, slippage, price_impact, success, is_bundle, total_usd]

# Captures *all* trades from pre-defined high-P&L/high-winrate, KOL, and otherwise known wallets.
[timestamp, 'SmartWallet_Trade', relative_ts, <TraderWalletEmbedding>, trade_direction, sol_amount, dex_platform_id, priority_fee, mev_protection, token_amount_pct_of_holding, quote_amount_pct_of_holding, slippage, price_impact, success, is_bundle, total_usd]

# Raw trades. Loaded in the H/B/H Prefix (first ~10k) and Suffix (last ~5k).
[timestamp, 'Trade', relative_ts, <TraderWalletEmbedding>, trade_direction, sol_amount, dex_platform_id, priority_fee, mev_protection, token_amount_pct_of_holding, quote_amount_pct_of_holding, slippage, price_impact, success, is_bundle, total_usd]
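
The four trade variants above can be seen as a routing decision over raw trades. A hedged sketch of that routing, where the size threshold and the smart-wallet set are illustrative assumptions (the doc does not define what counts as "large"):

```python
LARGE_TRADE_MIN_SOL = 5.0  # assumed threshold for 'LargeTrade'

def trade_event_type(trade: dict, creator: str, smart_wallets: set) -> str:
    """Route a raw trade to one of the TradeEncoder event variants."""
    wallet = trade["wallet_address"]
    if wallet == creator:
        return "Deployer_Trade"    # highest-signal: the dev bought or sold
    if wallet in smart_wallets:
        return "SmartWallet_Trade" # pre-defined high-P&L / KOL wallets
    if trade["sol_amount"] >= LARGE_TRADE_MIN_SOL:
        return "LargeTrade"
    return "Trade"                 # raw trade, kept in the prefix/suffix windows
```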

# -----------------------------------------
# 2. TransferEncoder
# -----------------------------------------

# Raw transfers. Loaded in the H/B/H Prefix (all within the first ~10k-trade window) and Suffix (all within the last ~5k-trade window).
[timestamp, 'Transfer', relative_ts, <SourceWalletEmbedding>, <DestinationWalletEmbedding>, token_amount, transfer_pct_of_total_supply, transfer_pct_of_holding, priority_fee]

# Captures scarce, large transfers *after* the initial launch window.
[timestamp, 'LargeTransfer', relative_ts, <FromWalletEmbedding>, <ToWalletEmbedding>, token_amount, transfer_pct_of_total_supply, transfer_pct_of_holding, priority_fee]

# -----------------------------------------
# 3. LifecycleEncoder
# -----------------------------------------

# The T0 event.
[timestamp, 'Mint', 0, <CreatorWalletEmbedding>, <TokenVibeEmbedding>]

# -----------------------------------------
# 4. PoolEncoder
# -----------------------------------------

# Signals migration from launchpad to a real pool.
[timestamp, 'PoolCreated', relative_ts, <ProviderWalletEmbedding>, protocol_id, <QuoteTokenVibeEmbedding>, base_amount, quote_amount, quote_pct_to_main_pool_balance, base_pct_to_main_pool_balance]

# Signals LP addition or removal.
[timestamp, 'LiquidityChange', relative_ts, <ProviderWalletEmbedding>, <QuoteTokenVibeEmbedding>, change_type_id, quote_amount, quote_pct_to_current_pool_balance]

# Signals the creator/dev collecting platform fees.
[timestamp, 'FeeCollected', relative_ts, <RecipientWalletEmbedding>, sol_amount, token_amount]

# -----------------------------------------
# SupplyEncoder
# -----------------------------------------

# Signals a supply reduction.
[timestamp, 'TokenBurn', relative_ts, <BurnerWalletEmbedding>, amount_pct_of_total_supply, amount_tokens_burned]

# Signals locked supply, e.g., for team/marketing.
[timestamp, 'SupplyLock', relative_ts, <LockerWalletEmbedding>, amount_pct_of_total_supply, lock_duration]

# -----------------------------------------
# ChartEncoder
# -----------------------------------------

# (The "Sliding Window") The chart event.
[timestamp, 'Chart_Segment', relative_ts, OHLC_segment, chart_interval_id]

# -----------------------------------------
# PulseEncoder
# -----------------------------------------

# A low-frequency event (dynamic interval: 5 min, 15 min, or 1 hr based on token age).
[timestamp, 'OnChain_Snapshot', relative_ts, total_holders, smart_traders, kols, holder_growth_rate, top_10_holder_pct, sniper_holding_pct, rat_wallets_holding_pct, bundle_holding_pct, current_market_cap, liquidity, volume, buy_count, sell_count, total_txns, global_fees_paid]

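The dynamic snapshot interval mentioned above could be sketched as a simple age-tiered lookup. The three interval tiers come from the doc; the age cut-offs below are pure assumptions, since the doc only names the tiers:

```python
def pulse_interval_sec(token_age_sec: int) -> int:
    """Pick the 'OnChain_Snapshot' cadence from the token's age."""
    if token_age_sec < 6 * 3600:   # fresh token (assumed first six hours): dense sampling
        return 5 * 60
    if token_age_sec < 48 * 3600:  # mid-life (assumed up to two days): medium cadence
        return 15 * 60
    return 60 * 60                 # mature token: hourly
```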
# -----------------------------------------
# HoldersListEncoder
# -----------------------------------------

<HolderDistributionEmbedding> # Transformer-based embedding of the top holders (WalletEmbeddings + holding pct).

# Token-specific holder analysis.
[timestamp, 'HolderSnapshot', relative_ts, <HolderDistributionEmbedding>]

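A minimal stand-in for `<HolderDistributionEmbedding>`: pooling the top holders' wallet embeddings, weighted by holding percentage. The real version is transformer-based per the spec; weighted mean pooling is only a sketch of the same interface:

```python
def holder_distribution_embedding(holders):
    """Pool (wallet_embedding, holding_pct) pairs into one vector.

    Weighted mean over holders; weights are normalized holding percentages.
    """
    if not holders:
        return []
    dim = len(holders[0][0])
    total = sum(pct for _, pct in holders) or 1.0
    pooled = [0.0] * dim
    for emb, pct in holders:
        for i in range(dim):
            pooled[i] += emb[i] * (pct / total)
    return pooled
```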
# -----------------------------------------
# ChainSnapshotEncoder
# -----------------------------------------

# Broad chain-level market conditions.
[timestamp, 'ChainSnapshot', relative_ts, native_token_price_usd, gas_fee]

# Launchpad market regime (using absolute, log-normalized values).
[timestamp, 'Lighthouse_Snapshot', relative_ts, protocol_id, timeframe_id, total_volume, total_transactions, total_traders, total_tokens_created, total_migrations]

# -----------------------------------------
# TokenTrendingListEncoder
# -----------------------------------------

# Fires *per token* on a trending list. The high-attention "meta" signal.
[timestamp, 'TrendingToken', relative_ts, <TokenVibeEmbedding_of_trending_token>, list_source_id, timeframe_id, rank]

# Fires *per token* on the boosted list.
[timestamp, 'BoostedToken', relative_ts, <TokenVibeEmbedding_of_boosted_token>, total_boost_amount, rank]

# -----------------------------------------
# LaunchpadThreadEncoder
# -----------------------------------------

# On-platform social signal (Pump.fun comments).
[timestamp, 'PumpReply', relative_ts, <UserWalletEmbedding>, <ReplyTextEmbedding>]

# -----------------------------------------
# CTEncoder
# -----------------------------------------

# Off-platform social signal (Twitter).
[timestamp, 'XPost', relative_ts, <AuthorWalletEmbedding>, <PostTextEmbedding>, <MediaEmbedding>]
[timestamp, 'XRetweet', relative_ts, <RetweeterWalletEmbedding>, <OriginalAuthorWalletEmbedding>, <OriginalPostTextEmbedding>, <OriginalPostMediaEmbedding>]
[timestamp, 'XReply', relative_ts, <AuthorWalletEmbedding>, <PostTextEmbedding>, <MediaEmbedding>, <MainTweetEmbedding>]
[timestamp, 'XQuoteTweet', relative_ts, <QuoterWalletEmbedding>, <QuoterTextEmbedding>, <OriginalAuthorWalletEmbedding>, <OriginalPostTextEmbedding>, <OriginalPostMediaEmbedding>]

# -----------------------------------------
# GlobalTrendingEncoder
# -----------------------------------------

# Broader cultural trend signal (TikTok).
[timestamp, 'TikTok_Trending_Hashtag', relative_ts, <HashtagNameEmbedding>, rank]

# Broader cultural trend signal (Twitter).
[timestamp, 'XTrending_Hashtag', relative_ts, <HashtagNameEmbedding>, rank]

# -----------------------------------------
# TrackerEncoder
# -----------------------------------------

# Retail marketing signal (paid groups).
[timestamp, 'AlphaGroup_Call', relative_ts, group_id]

[timestamp, 'Call_Channel', relative_ts, channel_id]

# High-impact catalyst event.
[timestamp, 'CexListing', relative_ts, exchange_id]

# High-impact catalyst event.
[timestamp, 'Migrated', relative_ts, protocol_id]

# -----------------------------------------
# DexEncoder
# -----------------------------------------

[timestamp, 'DexBoost_Paid', relative_ts, amount, total_amount_on_token]

[timestamp, 'DexProfile_Updated', relative_ts, has_changed_website_flag, has_changed_twitter_flag, has_changed_telegram_flag, has_changed_description_flag, <WebsiteEmbedding>, <TwitterLinkEmbedding>, <NewDescriptionEmbedding>]

### **Global Context Injection**

<PRELAUNCH> <LAUNCH> <MIDDLE> <RECENT>

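The four context markers can be spliced into the event sequence as special events, so the model can distinguish prefix, middle, and suffix regions. A sketch under the assumption that the boundary indices are computed upstream from the H/B/H windows:

```python
def inject_context_markers(events, launch_idx, middle_idx, recent_idx):
    """Insert PRELAUNCH/LAUNCH/MIDDLE/RECENT marker events into a sequence."""
    out = list(events)
    # insert from the back so earlier indices stay valid
    for idx, marker in sorted(
        [(launch_idx, "LAUNCH"), (middle_idx, "MIDDLE"), (recent_idx, "RECENT")],
        reverse=True,
    ):
        out.insert(idx, {"event_type": marker})
    out.insert(0, {"event_type": "PRELAUNCH"})
    return out
```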
### **Token Role Embedding**

<TokenVibeEmbedding_of_Token_A> + Subject_Token_Role

<TokenVibeEmbedding_of_Token_B> + Trending_Token_Role

<QuoteTokenVibeEmbedding_of_USDC> + Quote_Token_Role

# **Links**

### `TransferLink`

```
['signature', 'source', 'destination', 'mint', 'timestamp']
```

-----

### `BundleTradeLink`

```
['signatures', 'wallet_a', 'wallet_b', 'mint', 'slot', 'timestamp']
```

-----

### `CopiedTradeLink`

```
['leader_buy_sig', 'leader_sell_sig', 'follower_buy_sig', 'follower_sell_sig', 'follower', 'leader', 'mint', 'time_gap_on_buy_sec', 'time_gap_on_sell_sec', 'leader_pnl', 'follower_pnl', 'leader_buy_total', 'leader_sell_total', 'follower_buy_total', 'follower_sell_total', 'follower_buy_slippage', 'follower_sell_slippage']
```

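A hedged sketch of how `CopiedTradeLink` rows could be detected on the buy side: a follower buying the same mint shortly after a leader. The window size is an assumption, and only the buy-side fields of the link schema are populated here:

```python
COPY_WINDOW_SEC = 30  # assumed max follower delay to count as a copy

def find_copied_buys(leader_buys, follower_buys):
    """Both inputs: lists of (signature, mint, timestamp) tuples.

    Returns partial CopiedTradeLink rows for same-mint buys inside the window.
    """
    links = []
    for lsig, lmint, lts in leader_buys:
        for fsig, fmint, fts in follower_buys:
            gap = fts - lts
            if fmint == lmint and 0 < gap <= COPY_WINDOW_SEC:
                links.append({
                    "leader_buy_sig": lsig,
                    "follower_buy_sig": fsig,
                    "mint": lmint,
                    "time_gap_on_buy_sec": gap,
                })
    return links
```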
-----

### `CoordinatedActivityLink`

```
['leader_first_sig', 'leader_second_sig', 'follower_first_sig', 'follower_second_sig', 'follower', 'leader', 'mint', 'time_gap_on_first_sec', 'time_gap_on_second_sec']
```

-----

### `MintedLink`

```
['signature', 'timestamp', 'buy_amount']
```

-----

### `SnipedLink`

```
['signature', 'rank', 'sniped_amount']
```

-----

### `LockedSupplyLink`

```
['signature', 'amount', 'unlock_timestamp']
```

-----

### `BurnedLink`

```
['signature', 'amount', 'timestamp']
```

-----

### `ProvidedLiquidityLink`

```
['signature', 'wallet', 'token', 'pool_address', 'amount_base', 'amount_quote', 'timestamp']
```

-----

### `WhaleOfLink`

```
['wallet', 'token', 'holding_pct_at_creation', 'ath_usd_at_creation']
```

-----

### `TopTraderOfLink`

```
['wallet', 'token', 'pnl_at_creation', 'ath_usd_at_creation']
```

-----

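The mock generator below relies on an `EmbeddingPooler` with a `get_idx` method. This is a hypothetical stand-in that deduplicates payloads and hands out stable indices; the real class presumably also runs the text/image encoders and caches the resulting vectors:

```python
class EmbeddingPooler:
    """Hypothetical minimal pooler: maps payloads to stable embedding indices."""

    def __init__(self):
        self._idx_by_key = {}

    def get_idx(self, payload) -> int:
        """Return a stable index for a text string, image, or None."""
        key = repr(payload)  # crude content key; images would need real content hashing
        if key not in self._idx_by_key:
            self._idx_by_key[key] = len(self._idx_by_key)
        return self._idx_by_key[key]
```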

def __gettestitem__(self, idx: int) -> Dict[str, Any]:
    """
    Generates a single complex data item, structured for the MemecoinCollator.
    NOTE: This currently returns the same mock data regardless of `idx`.
    """
    # --- 1. Set up the pooler ---
    pooler = EmbeddingPooler()

    # --- 2. Create mock raw batch data ---
    print("Creating mock raw batch...")

    # Wallet profiles, socials, and holdings definitions
    profile1 = {
        'wallet_address': 'addrW1', 'age': 1.5e7, 'balance': 10.5,
        'deployed_tokens_count': 2, 'deployed_tokens_migrated_pct': 0.5, 'deployed_tokens_avg_lifetime_sec': 36000.0, 'deployed_tokens_avg_peak_mc_usd': 100000.0, 'deployed_tokens_median_peak_mc_usd': 50000.0,
        'transfers_in_count': 10, 'transfers_out_count': 5, 'spl_transfers_in_count': 20, 'spl_transfers_out_count': 15,
        'total_buys_count': 50, 'total_sells_count': 40, 'total_winrate': 0.6,
        'stats_1d_realized_profit_sol': 1.2, 'stats_1d_realized_profit_pnl': 0.1, 'stats_1d_buy_count': 5, 'stats_1d_sell_count': 3, 'stats_1d_transfer_in_count': 2, 'stats_1d_transfer_out_count': 1, 'stats_1d_avg_holding_period': 3600, 'stats_1d_total_bought_cost_sol': 10.0, 'stats_1d_total_sold_income_sol': 11.2, 'stats_1d_total_fee': 0.1, 'stats_1d_winrate': 0.7, 'stats_1d_tokens_traded': 4,
        'stats_7d_realized_profit_sol': 5.0, 'stats_7d_realized_profit_pnl': 0.2, 'stats_7d_buy_count': 20, 'stats_7d_sell_count': 15, 'stats_7d_transfer_in_count': 8, 'stats_7d_transfer_out_count': 4, 'stats_7d_avg_holding_period': 7200, 'stats_7d_total_bought_cost_sol': 40.0, 'stats_7d_total_sold_income_sol': 45.0, 'stats_7d_total_fee': 0.5, 'stats_7d_winrate': 0.65, 'stats_7d_tokens_traded': 10,
    }
    social1 = {'has_pf_profile': True, 'has_twitter': True, 'has_telegram': False, 'is_exchange_wallet': False, 'username': 'trader_one'}
    holdings1 = [
        {'mint_address': 'tknA', 'holding_time': 3600.0, 'realized_profit_sol': 5.2, 'total_priority_fees': 0.05, 'balance_pct_to_supply': 0.01, 'history_bought_amount_sol': 10, 'bought_amount_sol_pct_to_native_balance': 0.5, 'history_total_buys': 5, 'history_total_sells': 2, 'realized_profit_pnl': 0.52, 'history_transfer_in': 1, 'history_transfer_out': 0, 'avarage_trade_gap_seconds': 300},
    ]
    profile2 = {
        'wallet_address': 'addrW2', 'age': 1e6, 'balance': 1.0,
        'deployed_tokens_count': 0, 'deployed_tokens_migrated_pct': 0.0, 'deployed_tokens_avg_lifetime_sec': 0.0, 'deployed_tokens_avg_peak_mc_usd': 0.0, 'deployed_tokens_median_peak_mc_usd': 0.0,
        'transfers_in_count': 1, 'transfers_out_count': 0, 'spl_transfers_in_count': 0, 'spl_transfers_out_count': 0,
        'total_buys_count': 0, 'total_sells_count': 0, 'total_winrate': 0.0,
        'stats_1d_realized_profit_sol': 0.0, 'stats_1d_realized_profit_pnl': 0.0, 'stats_1d_buy_count': 0, 'stats_1d_sell_count': 0, 'stats_1d_transfer_in_count': 0, 'stats_1d_transfer_out_count': 0, 'stats_1d_avg_holding_period': 0, 'stats_1d_total_bought_cost_sol': 0.0, 'stats_1d_total_sold_income_sol': 0.0, 'stats_1d_total_fee': 0.0, 'stats_1d_winrate': 0.0, 'stats_1d_tokens_traded': 0,
        'stats_7d_realized_profit_sol': 0.0, 'stats_7d_realized_profit_pnl': 0.0, 'stats_7d_buy_count': 0, 'stats_7d_sell_count': 0, 'stats_7d_transfer_in_count': 0, 'stats_7d_transfer_out_count': 0, 'stats_7d_avg_holding_period': 0, 'stats_7d_total_bought_cost_sol': 0.0, 'stats_7d_total_sold_income_sol': 0.0, 'stats_7d_total_fee': 0.0, 'stats_7d_winrate': 0.0, 'stats_7d_tokens_traded': 0,
    }
    social2 = {'has_pf_profile': False, 'has_twitter': False, 'has_telegram': False, 'is_exchange_wallet': True, 'username': 'cex_wallet'}
    holdings2 = []

    # Define raw data and get their indices
    tokenA_data = {
        'address_emb_idx': pooler.get_idx('tknA'),
        'name_emb_idx': pooler.get_idx('Token A'),
        'symbol_emb_idx': pooler.get_idx('TKA'),
        'image_emb_idx': pooler.get_idx(Image.new('RGB', (256, 256), color='blue')),
        'protocol': 1
    }
    # Add wallet usernames to the pool
    wallet1_user_idx = pooler.get_idx(social1['username'])
    wallet2_user_idx = pooler.get_idx(social2['username'])
    social1['username_emb_idx'] = wallet1_user_idx
    social2['username_emb_idx'] = wallet2_user_idx
    # Add a third wallet for social tests
    social3 = {'has_pf_profile': False, 'has_twitter': True, 'has_telegram': True, 'is_exchange_wallet': False, 'username': 'social_butterfly'}
    wallet3_user_idx = pooler.get_idx(social3['username'])
    social3['username_emb_idx'] = wallet3_user_idx

    # Create the final pre-computed data structures
    # NOTE: tokens B-D reuse token A's mock values ('tknA', 'Token A', 'TKA')
    tokenB_data = {
        'address_emb_idx': pooler.get_idx('tknA'),
        'name_emb_idx': pooler.get_idx('Token A'),
        'symbol_emb_idx': pooler.get_idx('TKA'),
        'image_emb_idx': pooler.get_idx(Image.new('RGB', (256, 256), color='blue')),
        'protocol': 1
    }

    tokenC_data = {
        'address_emb_idx': pooler.get_idx('tknA'),
        'name_emb_idx': pooler.get_idx('Token A'),
        'symbol_emb_idx': pooler.get_idx('TKA'),
        'image_emb_idx': pooler.get_idx(Image.new('RGB', (256, 256), color='blue')),
        'protocol': 1
    }

    tokenD_data = {
        'address_emb_idx': pooler.get_idx('tknA'),
        'name_emb_idx': pooler.get_idx('Token A'),
        'symbol_emb_idx': pooler.get_idx('TKA'),
        'image_emb_idx': pooler.get_idx(Image.new('RGB', (256, 256), color='blue')),
        'protocol': 1
    }

    item = {
        'event_sequence': [
            {'event_type': 'XPost',
             'timestamp': 1729711350,
             'relative_ts': -25,
             'wallet_address': 'addrW1',  # Author
             'text_emb_idx': pooler.get_idx('This is the main tweet about $TKA'),
             'media_emb_idx': pooler.get_idx(Image.new('RGB', (100, 100), color='cyan'))
             },
            {'event_type': 'XReply',
             'timestamp': 1729711360,
             'relative_ts': -35,
             'wallet_address': 'addrW2',  # Replier
             'text_emb_idx': pooler.get_idx('This is a reply to the main tweet'),
             'media_emb_idx': pooler.get_idx(None),  # No media in the reply
             'main_tweet_text_emb_idx': pooler.get_idx('This is the main tweet about $TKA')
             },
            {'event_type': 'XRetweet',
             'timestamp': 1729711370,
             'relative_ts': -40,
             'wallet_address': 'addrW3',  # The retweeter
             'original_author_wallet_address': 'addrW1',  # The original author
             'original_post_text_emb_idx': pooler.get_idx('This is the main tweet about $TKA'),
             'original_post_media_emb_idx': pooler.get_idx(Image.new('RGB', (100, 100), color='cyan'))
             },
            # A pre-launch event with negative relative_ts
            {'event_type': 'Transfer',
             'timestamp': 1729711180,
             'relative_ts': -10,  # Negative relative_ts indicates pre-launch
             'wallet_address': 'addrW2',
             'destination_wallet_address': 'addrW1',
             'token_address': 'tknA',
             'token_amount': 1000.0, 'transfer_pct_of_total_supply': 0.0, 'transfer_pct_of_holding': 0.0, 'priority_fee': 0.0
             },
            {'event_type': 'Mint', 'timestamp': 1729711190, 'relative_ts': 0, 'wallet_address': 'addrW1', 'token_address': 'tknA'},
            {'event_type': 'Chart_Segment', 'timestamp': 1729711200, 'relative_ts': 60, 'opens': [1.0]*OHLC_SEQ_LEN, 'closes': [1.1]*OHLC_SEQ_LEN, 'i': '1s'},  # High-def (segment 0) by default
            {'event_type': 'Chart_Segment', 'timestamp': 1729711260, 'relative_ts': 120, 'opens': [1.0]*OHLC_SEQ_LEN, 'closes': [1.1]*OHLC_SEQ_LEN, 'i': '1s'},  # Can be marked as blurry
            {'event_type': 'Transfer',
             'timestamp': 1729711210,
             'relative_ts': 20,
             'wallet_address': 'addrW1',  # Source
             'destination_wallet_address': 'addrW2',  # Destination
             'token_address': 'tknA',  # Token for context? (optional, depends on design)
             'token_amount': 500.0,
             'transfer_pct_of_total_supply': 0.005,
             'transfer_pct_of_holding': 0.1,
             'priority_fee': 0.0001
             },
            {'event_type': 'Trade',
             'timestamp': 1729711220,
             'relative_ts': 30,
             'wallet_address': 'addrW1',
             'token_address': 'tknA',
             'trade_direction': 0,
             'sol_amount': 0.5,
             'dex_platform_id': vocab.DEX_TO_ID['Axiom'],  # pass the integer ID directly
             'priority_fee': 0.0002,
             'mev_protection': False,
             'token_amount_pct_of_holding': 0.05, 'quote_amount_pct_of_holding': 0.02,
             'slippage': 0.01, 'price_impact': 0.005, 'success': True, 'is_bundle': False, 'total_usd': 75.0
             },
            {'event_type': 'Deployer_Trade',  # a trade variant
             'timestamp': 1729711230,
             'relative_ts': 40,
             'wallet_address': 'addrW1',  # The creator wallet
             'token_address': 'tknA',
             'trade_direction': 1, 'sol_amount': 0.2,
             'dex_platform_id': vocab.DEX_TO_ID['Trojan'],
             'priority_fee': 0.0005,
             'mev_protection': True,
             'token_amount_pct_of_holding': 0.1, 'quote_amount_pct_of_holding': 0.0,
             'slippage': 0.02, 'price_impact': 0.01, 'success': True, 'is_bundle': False, 'total_usd': 30.0
             },
            {'event_type': 'SmartWallet_Trade',
             'timestamp': 1729711240,
             'relative_ts': 50,
             'wallet_address': 'addrW1',  # A known smart wallet
             'token_address': 'tknA',
             'trade_direction': 0, 'sol_amount': 1.5,
             'dex_platform_id': vocab.DEX_TO_ID['Axiom'],
             'priority_fee': 0.001,
             'mev_protection': True,
             'token_amount_pct_of_holding': 0.2, 'quote_amount_pct_of_holding': 0.1,
             'slippage': 0.01, 'price_impact': 0.008, 'success': True, 'is_bundle': False, 'total_usd': 225.0
             },
            {'event_type': 'LargeTrade',
             'timestamp': 1729711250,
             'relative_ts': 60,
             'wallet_address': 'addrW2',  # Some other wallet
             'token_address': 'tknA',
             'trade_direction': 0, 'sol_amount': 10.0,
             'dex_platform_id': vocab.DEX_TO_ID['OXK'],
             'priority_fee': 0.002,
             'mev_protection': False,
             'token_amount_pct_of_holding': 0.8, 'quote_amount_pct_of_holding': 0.5,
             'slippage': 0.03, 'price_impact': 0.05, 'success': True, 'is_bundle': False, 'total_usd': 1500.0
             },
            {'event_type': 'Chart_Segment', 'timestamp': 1729711260, 'relative_ts': 70, 'opens': [1.0]*OHLC_SEQ_LEN, 'closes': [1.1]*OHLC_SEQ_LEN, 'i': '1s'},
            {'event_type': 'PoolCreated',
             'timestamp': 1729711270,
             'relative_ts': 80,
             'wallet_address': 'addrW1',
             'protocol_id': vocab.PROTOCOL_TO_ID['Raydium CPMM'],
             'quote_token_address': 'tknB',
             'base_amount': 1000000.0,
             'quote_amount': 10.0
             },
            {'event_type': 'LiquidityChange',
             'timestamp': 1729711280,
             'relative_ts': 90,
             'wallet_address': 'addrW2',
             'quote_token_address': 'tknB',
             'change_type_id': 0,  # 0 for 'add'
             'quote_amount': 2.0
             },
            {'event_type': 'FeeCollected',
             'timestamp': 1729711290,
             'relative_ts': 100,
             'wallet_address': 'addrW1',  # The recipient (e.g., dev wallet)
             'sol_amount': 0.1
             },
            {'event_type': 'TokenBurn',
             'timestamp': 1729711300,
             'relative_ts': 110,
             'wallet_address': 'addrW2',  # The burner wallet
             'amount_pct_of_total_supply': 0.01,  # 1% of supply
             'amount_tokens_burned': 10000000.0
             },
            {'event_type': 'SupplyLock',
             'timestamp': 1729711310,
             'relative_ts': 120,
             'wallet_address': 'addrW1',  # The locker wallet
             'amount_pct_of_total_supply': 0.10,  # 10% of supply
             'lock_duration': 2592000  # 30 days in seconds
             },
            {'event_type': 'HolderSnapshot',
             'timestamp': 1729711320,
             'relative_ts': 130,
             # Raw holder data; in a real system this would be an index into
             # the pre-computed <HolderDistributionEmbedding>
             'holders': [
                 {'wallet': 'addrW1', 'holding_pct': 0.15},
                 {'wallet': 'addrW2', 'holding_pct': 0.05},
                 # Add more mock holders if needed
             ]
             },
            {'event_type': 'OnChain_Snapshot',
             'timestamp': 1729711320,
             'relative_ts': 130,
             'total_holders': 500,
             'smart_traders': 25,
             'kols': 3,
             'holder_growth_rate': 0.15,
             'top_10_holder_pct': 0.22,
             'sniper_holding_pct': 0.05,
             'rat_wallets_holding_pct': 0.02,
             'bundle_holding_pct': 0.01,
             'current_market_cap': 150000.0,
             'volume': 50000.0,
             'buy_count': 120,
             'sell_count': 80,
             'total_txns': 200,
             'global_fees_paid': 1.5
             },
            {'event_type': 'TrendingToken',
             'timestamp': 1729711330,
             'relative_ts': 140,
             'token_address': 'tknC',  # The token that is trending
             'list_source_id': vocab.TRENDING_LIST_SOURCE_TO_ID['Phantom'],
             'timeframe_id': vocab.TRENDING_LIST_TIMEFRAME_TO_ID['1h'],
             'rank': 3
             },
            {'event_type': 'BoostedToken',
             'timestamp': 1729711340,
             'relative_ts': 150,
             'token_address': 'tknD',  # The token that is boosted
             'total_boost_amount': 5000.0,
             'rank': 1
             },
            {'event_type': 'XQuoteTweet',
             'timestamp': 1729711380,
             'relative_ts': 190,
             'wallet_address': 'addrW3',  # The quoter
             'quoter_text_emb_idx': pooler.get_idx('Wow, look at this! $TKA'),
             'original_author_wallet_address': 'addrW1',  # The original author
             'original_post_text_emb_idx': pooler.get_idx('This is the main tweet about $TKA'),
             'original_post_media_emb_idx': pooler.get_idx(Image.new('RGB', (100, 100), color='cyan'))
             },
            # Special context tokens
            {'event_type': 'MIDDLE', 'timestamp': 1729711500, 'relative_ts': 195},
            {'event_type': 'PumpReply',
             'timestamp': 1729711390,
             'relative_ts': 200,
             'wallet_address': 'addrW2',  # The user who replied
             'reply_text_emb_idx': pooler.get_idx('to the moon!')
             },
            {'event_type': 'DexBoost_Paid',
             'timestamp': 1729711400,
             'relative_ts': 210,
             'amount': 5.0,  # e.g., 5 Boost
             'total_amount_on_token': 25.0  # 25 Boost Points
             },
            {'event_type': 'DexProfile_Updated',
             'timestamp': 1729711410,
             'relative_ts': 220,
             'has_changed_website_flag': True,
             'has_changed_twitter_flag': False,
             'has_changed_telegram_flag': True,
             'has_changed_description_flag': True,
             # Pre-computed text embeddings
             'website_emb_idx': pooler.get_idx('new-token-website.com'),
             'twitter_link_emb_idx': pooler.get_idx('old_handle'),  # No change, so old link
             'telegram_link_emb_idx': pooler.get_idx('new_tg_group'),
             'description_emb_idx': pooler.get_idx('This is the new and improved token description.')
             },
            {'event_type': 'AlphaGroup_Call',
             'timestamp': 1729711420,
             'relative_ts': 230,
             'group_id': vocab.ALPHA_GROUPS_TO_ID['Potion']
             },
            {'event_type': 'Call_Channel',  # matches the TrackerEncoder schema above
             'timestamp': 1729711430,
             'relative_ts': 240,
             'channel_id': vocab.CALL_CHANNELS_TO_ID['MarcosCalls']
             },
            {'event_type': 'RECENT', 'timestamp': 1729711510, 'relative_ts': 245},
            {'event_type': 'CexListing',
             'timestamp': 1729711440,
             'relative_ts': 250,
             'exchange_id': vocab.EXCHANGES_TO_ID['mexc']
             },
            {'event_type': 'TikTok_Trending_Hashtag',
             'timestamp': 1729711450,
             'relative_ts': 260,
             'hashtag_name_emb_idx': pooler.get_idx('CryptoTok'),
             'rank': 5
             },
            {'event_type': 'XTrending_Hashtag',
             'timestamp': 1729711460,
             'relative_ts': 270,
             'hashtag_name_emb_idx': pooler.get_idx('SolanaMemes'),
             'rank': 2
             },
            {'event_type': 'ChainSnapshot',
             'timestamp': 1729711470,
             'relative_ts': 280,
             'native_token_price_usd': 150.75,
             'gas_fee': 0.00015  # Example gas fee
             },
            {'event_type': 'Lighthouse_Snapshot',
             'timestamp': 1729711480,
             'relative_ts': 290,
             'protocol_id': vocab.PROTOCOL_TO_ID['Pump V1'],
             'timeframe_id': vocab.LIGHTHOUSE_TIMEFRAME_TO_ID['1h'],
             'total_volume': 1.2e6,
             'total_transactions': 5000,
             'total_traders': 1200,
             'total_tokens_created': 85,
             'total_migrations': 70
             },
            {'event_type': 'Migrated',
             'timestamp': 1729711490,
             'relative_ts': 300,
             'protocol_id': vocab.PROTOCOL_TO_ID['Raydium CPMM']
             },
        ],
        'wallets': {
            'addrW1': {'profile': profile1, 'socials': social1, 'holdings': holdings1},
            'addrW2': {'profile': profile2, 'socials': social2, 'holdings': holdings2},
719
+ # --- NEW: Add wallet 3 data ---
720
+ 'addrW3': {
721
+ 'profile': {**profile2, 'wallet_address': 'addrW3'}, # Reuse profile2 but change address
722
+ 'socials': social3,
723
+ 'holdings': []
724
+ }
725
+ },
726
+ 'tokens': {
727
+ 'tknA': tokenA_data, # Main token
728
+ 'tknB': tokenB_data, # Quote token
729
+ 'tknC': tokenC_data, # Trending token
730
+ 'tknD': tokenD_data # Boosted token
731
+ },
732
+ # --- NEW: The pre-computed embedding pool is generated after collecting all items
733
+ 'embedding_pooler': pooler, # Pass the pooler to generate the tensor later
734
+
735
+ # --- NEW: Expanded graph_links to test all encoders ---
736
+ # --- FIXED: Removed useless logging fields as per user request ---
737
+ 'graph_links': {
738
+ 'TransferLink': {'links': [{'timestamp': 1729711205}], 'edges': [('addrW1', 'addrW2')]}, # Keep timestamp
739
+ 'BundleTradeLink': {'links': [{'timestamp': 1729711215}], 'edges': [('addrW1', 'addrW2')]}, # Keep timestamp
740
+ 'CopiedTradeLink': {'links': [
741
+ {'time_gap_on_buy_sec': 10, 'time_gap_on_sell_sec': 120, 'leader_pnl': 5.0, 'follower_pnl': 4.0, 'follower_buy_total': 100, 'follower_sell_total': 120}
742
+ ], 'edges': [('addrW1', 'addrW2')]},
743
+ 'CoordinatedActivityLink': {'links': [
744
+ {'time_gap_on_first_sec': 5, 'time_gap_on_second_sec': 8}
745
+ ], 'edges': [('addrW1', 'addrW2')]},
746
+ 'MintedLink': {'links': [
747
+ {'timestamp': 1729711200, 'buy_amount': 1e9}
748
+ ], 'edges': [('addrW1', 'tknA')]},
749
+ 'SnipedLink': {'links': [
750
+ {'rank': 1, 'sniped_amount': 5e8}
751
+ ], 'edges': [('addrW1', 'tknA')]},
752
+ 'LockedSupplyLink': {'links': [
753
+ {'amount': 1e10} # Only amount is needed
754
+ ], 'edges': [('addrW1', 'tknA')]},
755
+ 'BurnedLink': {'links': [
756
+ {'timestamp': 1729711300} # Only timestamp is needed
757
+ ], 'edges': [('addrW2', 'tknA')]},
758
+ 'ProvidedLiquidityLink': {'links': [
759
+ {'timestamp': 1729711250} # Only timestamp is needed
760
+ ], 'edges': [('addrW1', 'tknA')]},
761
+ 'WhaleOfLink': {'links': [
762
+ {} # Just the existence of the link is the feature
763
+ ], 'edges': [('addrW1', 'tknA')]},
764
+ 'TopTraderOfLink': {'links': [
765
+ {'pnl_at_creation': 50000.0} # Only PnL is needed
766
+ ], 'edges': [('addrW2', 'tknA')]}
767
+ },
768
+
769
+ # --- FIXED: Removed chart_segments dictionary ---
770
+ 'labels': torch.randn(self.num_outputs) if self.num_outputs > 0 else torch.zeros(0),
771
+ 'labels_mask': torch.ones(self.num_outputs) if self.num_outputs > 0 else torch.zeros(0)
772
+ }
773
+
774
+ print("Mock raw batch created.")
775
+
776
+ return item
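The contract shared between this mock item and the collator in the next file is the `EmbeddingPooler` interface: raw texts and images are registered via `pooler.get_idx(...)`, and the collator later re-pools them batch-wide, assuming `get_idx` hands out stable 1-based indices (0 stays reserved for padding) and `get_all_items` yields `{'idx', 'item'}` dicts. A minimal sketch of that assumed interface follows — `EmbeddingPoolerSketch` is a hypothetical stand-in for illustration, not the real `data.data_loader.EmbeddingPooler`:

```python
from typing import Any, Dict, List

class EmbeddingPoolerSketch:
    """Hypothetical stand-in for data.data_loader.EmbeddingPooler.

    Deduplicates embeddable items and hands out 1-based indices,
    so index 0 can stay reserved as the padding slot.
    """

    def __init__(self) -> None:
        self._items: List[Any] = []       # insertion-ordered unique items
        self._index: Dict[Any, int] = {}  # item -> 1-based index

    def get_idx(self, item: Any) -> int:
        # Return the existing index or register the item under a new one.
        # (The real pooler would also need a content key for PIL images,
        # which are not hashable by pixel content.)
        if item not in self._index:
            self._items.append(item)
            self._index[item] = len(self._items)  # 1-based
        return self._index[item]

    def get_all_items(self) -> List[Dict[str, Any]]:
        # Same shape the collator iterates: [{'idx': ..., 'item': ...}]
        return [{'idx': i + 1, 'item': it} for i, it in enumerate(self._items)]
```

This 1-based convention is why the collator below subtracts 1 when building the 0-based `batch_embedding_pool` tensor.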
data/data_collator.py ADDED
@@ -0,0 +1,708 @@
1
+ # memecoin_collator.py (CORRECTED ORDER OF OPERATIONS)
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch.nn.utils.rnn import pad_sequence
6
+ from typing import List, Dict, Any, Tuple, Optional, Union
7
+ from collections import defaultdict
8
+ from PIL import Image
9
+ from models.multi_modal_processor import MultiModalEncoder
10
+
11
+ # Encoders are NO LONGER imported here
12
+ import models.vocabulary as vocab # For IDs, config sizes
13
+ from data.data_loader import EmbeddingPooler # Import for type hinting and instantiation
14
+
15
+ NATIVE_MINT = "So11111111111111111111111111111111111111112"
16
+ QUOTE_MINTS = {
17
+ NATIVE_MINT, # SOL
18
+ "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v", # USDC
19
+ "Es9vMFrzaCERmJfrF4H2FYD4KCoNkY11McCe8BenwNYB", # USDT
20
+ "USD1ttGY1N17NEEHLmELoaybftRBUSErhqYiQzvEmuB", # USD1
21
+ }
22
+
23
+ class MemecoinCollator:
24
+ """
25
+ Callable class for PyTorch DataLoader's collate_fn.
26
+ ... (rest of docstring) ...
27
+ """
+ def __init__(self,
+ event_type_to_id: Dict[str, int],
+ device: torch.device,
+ multi_modal_encoder: MultiModalEncoder,
+ dtype: torch.dtype,
+ ohlc_seq_len: int = 300,
+ max_seq_len: Optional[int] = None
+ ):
+ self.event_type_to_id = event_type_to_id
+ self.pad_token_id = event_type_to_id.get('__PAD__', 0)
+ self.multi_modal_encoder = multi_modal_encoder
+ self.entity_pad_idx = 0
+
+ self.device = device
+ self.dtype = dtype
+ self.ohlc_seq_len = ohlc_seq_len
+ self.max_seq_len = max_seq_len
+
+ def _collate_features_for_encoder(self, entities: List[Dict], feature_keys: List[str], device: torch.device, entity_type: str) -> Dict[str, Any]:
+ """Collate raw token or wallet feature dicts into encoder inputs."""
+ collated = defaultdict(list)
+ if not entities:
+ # --- FIXED: Return a default empty structure for BOTH tokens and wallets ---
+ if entity_type == "token":
+ return {
+ 'name_embed_indices': torch.tensor([], device=device, dtype=torch.long),
+ 'symbol_embed_indices': torch.tensor([], device=device, dtype=torch.long),
+ 'image_embed_indices': torch.tensor([], device=device, dtype=torch.long),
+ 'protocol_ids': torch.tensor([], device=device, dtype=torch.long),
+ 'is_vanity_flags': torch.tensor([], device=device, dtype=torch.bool),
+ '_addresses_for_lookup': []
+ }
+ elif entity_type == "wallet":
+ return {
+ 'username_embed_indices': torch.tensor([], device=device, dtype=torch.long),
+ 'profile_rows': [], 'social_rows': [], 'holdings_batch': []
+ }
+ return {} # Should not happen
+
+ # NEW: We now gather indices to pre-computed embeddings
+ if entity_type == "token":
+ # This indicates a Token entity
+ # Helper key for WalletEncoder to find token vibes
+ collated['_addresses_for_lookup'] = [e.get('address', '') for e in entities]
+ collated['name_embed_indices'] = torch.tensor([e.get('name_emb_idx', 0) for e in entities], device=device, dtype=torch.long)
+ collated['symbol_embed_indices'] = torch.tensor([e.get('symbol_emb_idx', 0) for e in entities], device=device, dtype=torch.long)
+ collated['image_embed_indices'] = torch.tensor([e.get('image_emb_idx', 0) for e in entities], device=device, dtype=torch.long)
+ collated['protocol_ids'] = torch.tensor([e.get('protocol', 0) for e in entities], device=device, dtype=torch.long)
+ collated['is_vanity_flags'] = torch.tensor([e.get('is_vanity', False) for e in entities], device=device, dtype=torch.bool)
+ elif entity_type == "wallet":
+ # NEW: Gather username indices for WalletEncoder
+ collated['username_embed_indices'] = torch.tensor([e.get('socials', {}).get('username_emb_idx', 0) for e in entities], device=device, dtype=torch.long)
+ collated['profile_rows'] = [e.get('profile', {}) for e in entities]
+ collated['social_rows'] = [e.get('socials', {}) for e in entities]
+ collated['holdings_batch'] = [e.get('holdings', []) for e in entities]
+ return dict(collated)
+
+ def _collate_ohlc_inputs(self, chart_events: List[Dict]) -> Dict[str, torch.Tensor]:
+ """Pad and stack open/close price segments into OHLC tensors with interval ids."""
+ if not chart_events:
+ return {
+ 'price_tensor': torch.empty(0, 2, self.ohlc_seq_len, device=self.device, dtype=self.dtype),
+ 'interval_ids': torch.empty(0, device=self.device, dtype=torch.long)
+ }
+ ohlc_tensors = []
+ interval_ids_list = []
+ seq_len = self.ohlc_seq_len
+ unknown_id = vocab.INTERVAL_TO_ID.get("Unknown", 0)
+ for segment_data in chart_events:
+ opens = segment_data.get('opens', [])
+ closes = segment_data.get('closes', [])
+ interval_str = segment_data.get('i', "Unknown")
+ pad_open = opens[-1] if opens else 0
+ pad_close = closes[-1] if closes else 0
+ o = torch.tensor(opens[:seq_len] + [pad_open]*(seq_len-len(opens)), dtype=self.dtype)
+ c = torch.tensor(closes[:seq_len] + [pad_close]*(seq_len-len(closes)), dtype=self.dtype)
+ ohlc_tensors.append(torch.stack([o, c]))
+ interval_id = vocab.INTERVAL_TO_ID.get(interval_str, unknown_id)
+ interval_ids_list.append(interval_id)
+ return {
+ 'price_tensor': torch.stack(ohlc_tensors).to(self.device),
+ 'interval_ids': torch.tensor(interval_ids_list, device=self.device, dtype=torch.long)
+ }
+
+ def _collate_graph_links(self,
+ batch_items: List[Dict],
+ wallet_addr_to_batch_idx: Dict[str, int],
+ token_addr_to_batch_idx: Dict[str, int]) -> Dict[str, Any]:
+ """Aggregate per-item graph links into batch-wide edge_index tensors."""
+ aggregated_links = defaultdict(lambda: {'edge_index_list': [], 'links_list': []})
+ for item in batch_items:
+ item_wallets = item.get('wallets', {})
+ item_tokens = item.get('tokens', {})
+ item_wallet_addr_to_global_idx = {addr: wallet_addr_to_batch_idx.get(addr, self.entity_pad_idx) for addr in item_wallets.keys()}
+ item_token_addr_to_global_idx = {addr: token_addr_to_batch_idx.get(addr, self.entity_pad_idx) for addr in item_tokens.keys()}
+ for link_name, data in item.get('graph_links', {}).items():
+ aggregated_links[link_name]['links_list'].extend(data.get('links', []))
+ triplet = vocab.LINK_NAME_TO_TRIPLET.get(link_name)
+ if not triplet: continue
+ src_type, _, dst_type = triplet
+ edges = data.get('edges')
+ if not edges: continue
+ src_map = item_wallet_addr_to_global_idx if src_type == 'wallet' else item_token_addr_to_global_idx
+ dst_map = item_wallet_addr_to_global_idx if dst_type == 'wallet' else item_token_addr_to_global_idx
+ remapped_edge_list = []
+ for src_addr, dst_addr in edges:
+ src_idx_global = src_map.get(src_addr, self.entity_pad_idx)
+ dst_idx_global = dst_map.get(dst_addr, self.entity_pad_idx)
+ if src_idx_global != self.entity_pad_idx and dst_idx_global != self.entity_pad_idx:
+ remapped_edge_list.append([src_idx_global, dst_idx_global])
+ if remapped_edge_list:
+ remapped_edge_tensor = torch.tensor(remapped_edge_list, device=self.device, dtype=torch.long).t()
+ aggregated_links[link_name]['edge_index_list'].append(remapped_edge_tensor)
+ if link_name == "TransferLink":
+ link_props = data.get('links', [])
+ derived_edges = []
+ derived_props = []
+ for (src_addr, dst_addr), props in zip(edges, link_props):
+ mint_addr = props.get('mint')
+ if not mint_addr or mint_addr in QUOTE_MINTS:
+ continue
+ token_idx_global = item_token_addr_to_global_idx.get(mint_addr, self.entity_pad_idx)
+ if token_idx_global == self.entity_pad_idx:
+ continue
+ for wallet_addr in (src_addr, dst_addr):
+ wallet_idx_global = item_wallet_addr_to_global_idx.get(wallet_addr, self.entity_pad_idx)
+ if wallet_idx_global == self.entity_pad_idx:
+ continue
+ derived_edges.append([wallet_idx_global, token_idx_global])
+ derived_props.append(props)
+ if derived_edges:
+ derived_tensor = torch.tensor(derived_edges, device=self.device, dtype=torch.long).t()
+ aggregated_links["TransferLinkToken"]['edge_index_list'].append(derived_tensor)
+ aggregated_links["TransferLinkToken"]['links_list'].extend(derived_props)
+ final_links_dict = {}
+ for link_name, data in aggregated_links.items():
+ if data['edge_index_list']:
+ final_links_dict[link_name] = {
+ 'links': data['links_list'],
+ 'edge_index': torch.cat(data['edge_index_list'], dim=1)
+ }
+ return final_links_dict
+
+ def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
+ """
+ Processes a batch of raw data items into tensors for the model.
+ """
+ # --- NEW ARCHITECTURE ---
+ # 1. Aggregate all unique embeddable items from the entire batch.
+ # 2. Create a single embedding pool tensor for the whole batch.
+ # 3. Create a mapping from original (per-item) indices to the new batch-wide indices.
+ # 4. Remap all `_emb_idx` fields in the batch data using this new mapping.
+
+ batch_size = len(batch)
+ if batch_size == 0:
+ return {}
+
+ # --- 1. Aggregate all unique items and create index mappings ---
+ batch_wide_pooler = EmbeddingPooler()
+ # Map to translate from an item's original pooler to the new batch-wide indices
+ # Format: { batch_item_index: { original_idx: new_batch_idx } }
+ idx_remap = defaultdict(dict)
+
+ for i, item in enumerate(batch):
+ pooler = item.get('embedding_pooler')
+ if not pooler: continue
+
+ for pool_item_data in pooler.get_all_items():
+ original_idx = pool_item_data['idx']
+ raw_item = pool_item_data['item']
+ # get_idx will either return an existing index or create a new one
+ # --- FIX: Convert 1-based pooler index to 0-based tensor index ---
+ new_batch_idx_1_based = batch_wide_pooler.get_idx(raw_item)
+ new_batch_idx_0_based = new_batch_idx_1_based - 1
+ idx_remap[i][original_idx] = new_batch_idx_0_based
+
+ # --- 2. Create the single, batch-wide embedding pool tensor ---
+ all_items_sorted = batch_wide_pooler.get_all_items()
+ texts_to_encode = [d['item'] for d in all_items_sorted if isinstance(d['item'], str)]
+ images_to_encode = [d['item'] for d in all_items_sorted if isinstance(d['item'], Image.Image)]
+
+ text_embeds = self.multi_modal_encoder(texts_to_encode) if texts_to_encode else torch.empty(0)
+ image_embeds = self.multi_modal_encoder(images_to_encode) if images_to_encode else torch.empty(0)
+
+ # Create the final lookup tensor and fill it based on original item type
+ batch_embedding_pool = torch.zeros(len(all_items_sorted), self.multi_modal_encoder.embedding_dim, device=self.device, dtype=self.dtype)
+ text_cursor, image_cursor = 0, 0
+ for i, item_data in enumerate(all_items_sorted):
+ if isinstance(item_data['item'], str):
+ if text_embeds.numel() > 0:
+ batch_embedding_pool[i] = text_embeds[text_cursor]
+ text_cursor += 1
+ elif isinstance(item_data['item'], Image.Image):
+ if image_embeds.numel() > 0:
+ batch_embedding_pool[i] = image_embeds[image_cursor]
+ image_cursor += 1
+
+ # --- 3. Remap all indices in the batch data ---
+ for i, item in enumerate(batch):
+ remap_dict = idx_remap.get(i, {})
+ if not remap_dict: continue
+
+ # Remap tokens
+ for token_data in item.get('tokens', {}).values():
+ for key in ['name_emb_idx', 'symbol_emb_idx', 'image_emb_idx']:
+ if token_data.get(key, 0) > 0: # Check if it has a valid 1-based index
+ token_data[key] = remap_dict.get(token_data[key], -1) # Remap to 0-based, default to -1 if not found
+ # Remap wallets
+ for wallet_data in item.get('wallets', {}).values():
+ socials = wallet_data.get('socials', {})
+ if socials.get('username_emb_idx', 0) > 0:
+ socials['username_emb_idx'] = remap_dict.get(socials['username_emb_idx'], -1)
+ # Remap events
+ for event in item.get('event_sequence', []):
+ for key in event:
+ if key.endswith('_emb_idx') and event.get(key, 0) > 0:
+ event[key] = remap_dict.get(event[key], 0)
+
+ # --- 4. Standard Collation (Now that indices are correct) ---
+ unique_wallets_data = {}
+ unique_tokens_data = {}
+ all_event_sequences = []
+ max_len = 0
+
+ for item in batch:
+ seq = item.get('event_sequence', [])
+ if self.max_seq_len is not None and len(seq) > self.max_seq_len:
+ seq = seq[:self.max_seq_len]
+ all_event_sequences.append(seq)
+ max_len = max(max_len, len(seq))
+ unique_wallets_data.update(item.get('wallets', {}))
+ unique_tokens_data.update(item.get('tokens', {}))
+
+ # Create mappings needed for indexing
+ wallet_list_data = list(unique_wallets_data.values())
+ token_list_data = list(unique_tokens_data.values())
+ wallet_addr_to_batch_idx = {feat.get('profile', {}).get('wallet_address', f'__error_{i}'): i+1 for i, feat in enumerate(wallet_list_data)}
+ token_addr_to_batch_idx = {feat.get('address', f'__error_{i}'): i+1 for i, feat in enumerate(token_list_data)}
+
+ # Collate Static Raw Features (Tokens, Wallets, Graph)
+ token_encoder_inputs = self._collate_features_for_encoder(token_list_data, ['name'], self.device, "token")
+ wallet_encoder_inputs = self._collate_features_for_encoder(wallet_list_data, ['profile'], self.device, "wallet")
+ graph_updater_links = self._collate_graph_links(batch, wallet_addr_to_batch_idx, token_addr_to_batch_idx)
+
+ # --- Logging ---
+ pool_contents = batch_wide_pooler.get_all_items()
+ print(f"\n[DataCollator: Final Embedding Pool] ({len(pool_contents)} items):")
+ if pool_contents:
+ for item_data in pool_contents:
+ sample_item = item_data['item']
+ sample_type = "Image" if isinstance(sample_item, Image.Image) else "Text"
+ content_preview = str(sample_item)
+ if sample_type == "Text" and len(content_preview) > 100:
+ content_preview = content_preview[:97] + "..."
+ print(f" - Item (Original Idx {item_data['idx']}): Type='{sample_type}', Content='{content_preview}'")
+
+ # --- 5. Prepare Sequence Tensors & Collect Dynamic Data (OHLC) ---
+ B = batch_size
+ L = max_len
+ PAD_IDX_SEQ = self.pad_token_id
+ PAD_IDX_ENT = self.entity_pad_idx
+
+ # Initialize sequence tensors
+ event_type_ids = torch.full((B, L), PAD_IDX_SEQ, dtype=torch.long, device=self.device)
+ timestamps_float = torch.zeros((B, L), dtype=torch.float32, device=self.device)
+ # Store relative_ts in float32 for stability; model will scale/log/normalize
+ relative_ts = torch.zeros((B, L, 1), dtype=torch.float32, device=self.device)
+ attention_mask = torch.zeros((B, L), dtype=torch.long, device=self.device)
+ wallet_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
+ token_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
+ ohlc_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
+ quote_token_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device) # NEW
+
+ # --- NEW: Tensors for Transfer/LargeTransfer ---
+ dest_wallet_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
+ # --- NEW: Separate tensor for social media original authors ---
+ original_author_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
+ # 4 numerical features for transfers
+ transfer_numerical_features = torch.zeros((B, L, 4), dtype=self.dtype, device=self.device)
+
+ # --- NEW: Tensors for Trade ---
+ # --- FIXED: Size reduced from 10 to 8 ---
+ trade_numerical_features = torch.zeros((B, L, 8), dtype=self.dtype, device=self.device)
+ deployer_trade_numerical_features = torch.zeros((B, L, 8), dtype=self.dtype, device=self.device)
+ smart_wallet_trade_numerical_features = torch.zeros((B, L, 8), dtype=self.dtype, device=self.device)
+ # --- NEW: Dedicated tensor for categorical dex_platform_id ---
+ trade_dex_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
+ # --- NEW: Dedicated tensor for categorical trade_direction ---
+ trade_direction_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
+ # --- NEW: Dedicated tensor for categorical mev_protection ---
+ trade_mev_protection_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
+ # --- NEW: Dedicated tensor for categorical is_bundle ---
+ trade_is_bundle_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
+
+ # --- NEW: Tensors for PoolCreated ---
+ # --- UPDATED: Capture raw base/quote deposit amounts only ---
+ pool_created_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device)
+ # --- NEW: Dedicated tensor for categorical protocol_id ---
+ pool_created_protocol_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
+
+ # --- NEW: Tensors for LiquidityChange ---
+ # --- UPDATED: Keep only the raw quote amount deposit/withdraw ---
+ liquidity_change_numerical_features = torch.zeros((B, L, 1), dtype=self.dtype, device=self.device)
+ # --- NEW: Dedicated tensor for categorical change_type_id ---
+ liquidity_change_type_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
+
+ # --- NEW: Tensors for FeeCollected ---
+ fee_collected_numerical_features = torch.zeros((B, L, 1), dtype=self.dtype, device=self.device) # sol_amount only
+ # --- NEW: Tensors for TokenBurn ---
+ token_burn_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device) # amount_pct, amount_tokens
+
+ # --- NEW: Tensors for SupplyLock ---
+ supply_lock_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device) # amount_pct, lock_duration
+
+ # --- NEW: Tensors for OnChain_Snapshot ---
+ onchain_snapshot_numerical_features = torch.zeros((B, L, 14), dtype=self.dtype, device=self.device)
+
+ # --- NEW: Tensors for TrendingToken ---
+ trending_token_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
+ # --- FIXED: Size reduced from 3 to 1 after removing IDs ---
+ trending_token_numerical_features = torch.zeros((B, L, 1), dtype=self.dtype, device=self.device) # rank
+ trending_token_source_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
+ trending_token_timeframe_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
+
+ # --- NEW: Tensors for BoostedToken ---
+ boosted_token_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
+ boosted_token_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device) # total_boost_amount, rank
+
+ # --- NEW: Tensors for DexBoost_Paid ---
+ dexboost_paid_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device) # amount, total_amount_on_token
+
+ # --- NEW: Tensors for DexProfile_Updated ---
+ dexprofile_updated_flags = torch.zeros((B, L, 4), dtype=torch.float32, device=self.device) # Using float for easier projection
+
+ # --- NEW: Tensors for Tracker Events ---
+ alpha_group_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
+ channel_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
+ exchange_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
+
+ # --- NEW: Tensors for GlobalTrending Events ---
+ global_trending_numerical_features = torch.zeros((B, L, 1), dtype=self.dtype, device=self.device) # rank
+
+ # --- NEW: Tensors for ChainSnapshot ---
+ chainsnapshot_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device) # native_token_price_usd, gas_fee
+
+ # --- NEW: Tensors for Lighthouse_Snapshot ---
+ # --- FIXED: Size reduced from 7 to 5 after removing IDs ---
+ lighthousesnapshot_numerical_features = torch.zeros((B, L, 5), dtype=self.dtype, device=self.device)
+ lighthousesnapshot_protocol_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
+ lighthousesnapshot_timeframe_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
+
+ # --- NEW: Tensors for Migrated event ---
+ migrated_protocol_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
+
+ # --- NEW: Tensors for HolderSnapshot ---
+ # This will store the raw holder data for the Oracle to process
+ holder_snapshot_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
+ holder_snapshot_raw_data_list = [] # List of lists of dicts
+
+ # --- RENAMED: Generic tensors for any event with text/image features ---
+ textual_event_data_list = [] # List of dicts with text/media indices
+ textual_event_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
+ # --- NEW: Pointers for pre-encoded images ---
+ image_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
+ original_post_image_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
+
+ # --- CORRECTED: Initialize chart event collection here ---
+ batch_chart_events = []
+ chart_event_counter = 0
+
+ # Loop through sequences to populate tensors and collect chart events
+ for i, seq in enumerate(all_event_sequences):
+ seq_len = len(seq)
+ if seq_len == 0: continue
+ attention_mask[i, :seq_len] = 1
+
+ for j, event in enumerate(seq):
+ # Populate basic sequence info
+ event_type = event.get('event_type', '__PAD__')
+ type_id = self.event_type_to_id.get(event_type, PAD_IDX_SEQ)
+ event_type_ids[i, j] = type_id
+ timestamps_float[i, j] = event.get('timestamp', 0)
+ relative_ts[i, j, 0] = event.get('relative_ts', 0.0)
+
+ # Populate pointer indices
+ w_addr = event.get('wallet_address')
+ if w_addr:
+ wallet_indices[i, j] = wallet_addr_to_batch_idx.get(w_addr, PAD_IDX_ENT)
+ t_addr = event.get('token_address')
+ if t_addr:
+ token_indices[i, j] = token_addr_to_batch_idx.get(t_addr, PAD_IDX_ENT)
+
+ # If it's a chart event, collect it and record its index
+ if event_type == 'Chart_Segment':
+ batch_chart_events.append(event)
+ ohlc_indices[i, j] = chart_event_counter + 1 # Use 1-based index
+ chart_event_counter += 1
+
+ elif event_type in ['Transfer', 'LargeTransfer']: # ADDED LargeTransfer
+ # Get destination wallet index
+ dest_w_addr = event.get('destination_wallet_address') # Assuming this key exists
+ if dest_w_addr:
+ dest_wallet_indices[i, j] = wallet_addr_to_batch_idx.get(dest_w_addr, PAD_IDX_ENT)
+
+ # Get numerical features (use .get with default 0)
+ num_feats = [
+ event.get('token_amount', 0.0),
+ event.get('transfer_pct_of_total_supply', 0.0),
+ event.get('transfer_pct_of_holding', 0.0),
+ event.get('priority_fee', 0.0)
+ ]
+ transfer_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
+
+ elif event_type in ['Trade', 'LargeTrade']:
+ # Get numerical and categorical features for the trade
+ trade_dex_ids[i, j] = event.get('dex_platform_id', 0)
+ trade_direction_ids[i, j] = event.get('trade_direction', 0) # 0=buy, 1=sell
+ trade_mev_protection_ids[i, j] = event.get('mev_protection', 0) # 0, 1, 2...
+ trade_is_bundle_ids[i, j] = 1 if event.get('is_bundle') else 0 # 0=false, 1=true
+
+ num_feats = [
+ event.get('sol_amount', 0.0),
+ event.get('priority_fee', 0.0),
+ event.get('token_amount_pct_of_holding', 0.0),
+ event.get('quote_amount_pct_of_holding', 0.0),
+ event.get('slippage', 0.0),
+ event.get('token_amount_pct_to_total_supply', 0.0), # REPLACED price_impact
+ 1.0 if event.get('success') else 0.0,
+ event.get('total_usd', 0.0)
+ ]
+ trade_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
+
+ elif event_type == 'Deployer_Trade':
+ # Use the dedicated tensor for deployer trades
+ trade_dex_ids[i, j] = event.get('dex_platform_id', 0)
+ trade_direction_ids[i, j] = event.get('trade_direction', 0) # 0=buy, 1=sell
+ trade_mev_protection_ids[i, j] = event.get('mev_protection', 0) # 0, 1, 2...
+ trade_is_bundle_ids[i, j] = 1 if event.get('is_bundle') else 0 # 0=false, 1=true
+ num_feats = [
+ event.get('sol_amount', 0.0),
+ event.get('priority_fee', 0.0),
+ event.get('token_amount_pct_of_holding', 0.0),
+ event.get('quote_amount_pct_of_holding', 0.0),
+ event.get('slippage', 0.0),
+ event.get('token_amount_pct_to_total_supply', 0.0), # REPLACED price_impact
+ 1.0 if event.get('success') else 0.0,
+ event.get('total_usd', 0.0)
+ ]
+ deployer_trade_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
+
+ elif event_type == 'SmartWallet_Trade':
+ # Use the dedicated tensor for smart wallet trades
+ trade_dex_ids[i, j] = event.get('dex_platform_id', 0)
+ trade_direction_ids[i, j] = event.get('trade_direction', 0) # 0=buy, 1=sell
+ trade_mev_protection_ids[i, j] = event.get('mev_protection', 0) # 0, 1, 2...
+ trade_is_bundle_ids[i, j] = 1 if event.get('is_bundle') else 0 # 0=false, 1=true
+ num_feats = [
+ event.get('sol_amount', 0.0),
+ event.get('priority_fee', 0.0),
+ event.get('token_amount_pct_of_holding', 0.0),
+ event.get('quote_amount_pct_of_holding', 0.0),
+ event.get('slippage', 0.0),
+ event.get('token_amount_pct_to_total_supply', 0.0), # REPLACED price_impact
+ 1.0 if event.get('success') else 0.0,
+ event.get('total_usd', 0.0)
+ ]
+ smart_wallet_trade_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
+
+ elif event_type == 'PoolCreated':
+ # Get the quote token index
+ quote_t_addr = event.get('quote_token_address')
+ if quote_t_addr:
+ quote_token_indices[i, j] = token_addr_to_batch_idx.get(quote_t_addr, PAD_IDX_ENT)
+
+ pool_created_protocol_ids[i, j] = event.get('protocol_id', 0)
+ # Get numerical features
+ num_feats = [
+ event.get('base_amount', 0.0),
+ event.get('quote_amount', 0.0)
+ ]
+ pool_created_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
+
+ elif event_type == 'LiquidityChange':
+ # Get the quote token index
+ quote_t_addr = event.get('quote_token_address')
+ if quote_t_addr:
+ quote_token_indices[i, j] = token_addr_to_batch_idx.get(quote_t_addr, PAD_IDX_ENT)
+
+ liquidity_change_type_ids[i, j] = event.get('change_type_id', 0)
+ # Get numerical features
+ num_feats = [event.get('quote_amount', 0.0)]
+ liquidity_change_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
+
+ elif event_type == 'FeeCollected':
+ # This event has the recipient wallet plus a single numerical feature (SOL amount).
+ num_feats = [
+ event.get('sol_amount', 0.0)
+ ]
+ fee_collected_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
+
+ elif event_type == 'TokenBurn':
+ # This event has a wallet (handled by wallet_indices) and two numerical features.
+ num_feats = [
+ event.get('amount_pct_of_total_supply', 0.0),
+ event.get('amount_tokens_burned', 0.0)
+ ]
+ token_burn_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
+
+ elif event_type == 'SupplyLock':
+ # This event has a wallet and two numerical features.
+ num_feats = [
+ event.get('amount_pct_of_total_supply', 0.0),
+ event.get('lock_duration', 0.0)
543
+ ]
544
+ supply_lock_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
545
+
546
+ elif event_type == 'OnChain_Snapshot':
547
+ # This event is a global snapshot with 14 numerical features.
548
+ num_feats = [
549
+ event.get('total_holders', 0.0),
550
+ event.get('smart_traders', 0.0),
551
+ event.get('kols', 0.0),
552
+ event.get('holder_growth_rate', 0.0),
553
+ event.get('top_10_holder_pct', 0.0),
554
+ event.get('sniper_holding_pct', 0.0),
555
+ event.get('rat_wallets_holding_pct', 0.0),
556
+ event.get('bundle_holding_pct', 0.0),
557
+ event.get('current_market_cap', 0.0),
558
+ event.get('volume', 0.0),
559
+ event.get('buy_count', 0.0),
560
+ event.get('sell_count', 0.0),
561
+ event.get('total_txns', 0.0),
562
+ event.get('global_fees_paid', 0.0)
563
+ ]
564
+ onchain_snapshot_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
565
+
566
+ elif event_type == 'TrendingToken':
567
+ # Get the trending token index
568
+ trending_t_addr = event.get('token_address')
569
+ if trending_t_addr:
570
+ trending_token_indices[i, j] = token_addr_to_batch_idx.get(trending_t_addr, PAD_IDX_ENT)
571
+
572
+ trending_token_source_ids[i, j] = event.get('list_source_id', 0)
573
+ trending_token_timeframe_ids[i, j] = event.get('timeframe_id', 0)
574
+ # --- FIXED: Invert rank so that 1 is the highest value ---
575
+ # Get numerical/categorical features
576
+ num_feats = [
577
+ 1.0 / (event.get('rank') or 1e9)  # 'or' treats a missing, None, or 0 rank as very large, so the inverted value is ~0 instead of raising ZeroDivisionError
578
+ ]
579
+ trending_token_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
580
+
581
+ elif event_type == 'BoostedToken':
582
+ # Get the boosted token index
583
+ boosted_t_addr = event.get('token_address')
584
+ if boosted_t_addr:
585
+ boosted_token_indices[i, j] = token_addr_to_batch_idx.get(boosted_t_addr, PAD_IDX_ENT)
586
+
587
+ # --- FIXED: Invert rank so that 1 is the highest value ---
588
+ # Get numerical features
589
+ num_feats = [
590
+ event.get('total_boost_amount', 0.0),
591
+ 1.0 / (event.get('rank') or 1e9)  # missing, None, or 0 rank maps to ~0, avoiding ZeroDivisionError
592
+ ]
593
+ boosted_token_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
594
+
595
+ elif event_type == 'HolderSnapshot':
596
+ # --- FIXED: Store raw holder data, not an index ---
597
+ raw_holders = event.get('holders', [])
598
+ holder_snapshot_raw_data_list.append(raw_holders)
599
+ holder_snapshot_indices[i, j] = len(holder_snapshot_raw_data_list) # 1-based index to the list
600
+
601
+ elif event_type == 'Lighthouse_Snapshot':
602
+ lighthousesnapshot_protocol_ids[i, j] = event.get('protocol_id', 0)
603
+ lighthousesnapshot_timeframe_ids[i, j] = event.get('timeframe_id', 0)
604
+ num_feats = [
605
+ event.get('total_volume', 0.0),
606
+ event.get('total_transactions', 0.0),
607
+ event.get('total_traders', 0.0),
608
+ event.get('total_tokens_created', 0.0),
609
+ event.get('total_migrations', 0.0)
610
+ ]
611
+ lighthousesnapshot_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
612
+
613
+
614
+ # --- UPDATED: Group all events that contain pre-computed text/image indices ---
615
+ elif event_type in ['XPost', 'XReply', 'XRetweet', 'XQuoteTweet', 'PumpReply', 'DexProfile_Updated', 'TikTok_Trending_Hashtag', 'XTrending_Hashtag']:
616
+ # Store raw event data to look up text/image indices later
617
+ # 1. Store raw text/media data
618
+ textual_event_data_list.append(event)
619
+ textual_event_indices[i, j] = len(textual_event_data_list) # 1-based index
620
+ # --- FIXED: Handle rank for trending hashtags ---
621
+ if event_type in ['TikTok_Trending_Hashtag', 'XTrending_Hashtag']:
622
+ global_trending_numerical_features[i, j, 0] = 1.0 / (event.get('rank') or 1e9)  # guard against rank 0 / missing
623
+
624
+ # 2. Populate wallet pointer tensors based on the event type
625
+ # The main 'wallet_address' is already handled above.
626
+ # Here we handle the *other* wallets involved.
627
+ if event_type == 'XRetweet' or event_type == 'XQuoteTweet':
628
+ orig_author_addr = event.get('original_author_wallet_address')
629
+ if orig_author_addr:
630
+ # --- FIXED: Use the dedicated tensor for original authors ---
631
+ original_author_indices[i, j] = wallet_addr_to_batch_idx.get(orig_author_addr, PAD_IDX_ENT)
632
+
633
+ # The pre-computed embedding indices are already in the event dict.
634
+ # No need to populate image_indices here anymore.
635
+ # For XReply, the main tweet is a text/media embedding, not a wallet.
636
+ # For XPost, there's only one wallet, already handled.
637
+
638
+ # --- 4. Collate Dynamic Features (OHLC) AFTER collecting them ---
639
+ ohlc_inputs_dict = self._collate_ohlc_inputs(batch_chart_events)
640
+
641
+ # --- 6. Prepare final output dictionary ---
642
+ collated_batch = {
643
+ # Sequence Tensors
644
+ 'event_type_ids': event_type_ids,
645
+ 'timestamps_float': timestamps_float,
646
+ 'relative_ts': relative_ts,
647
+ 'attention_mask': attention_mask,
648
+ # Pointer Tensors
649
+ 'wallet_indices': wallet_indices,
650
+ 'token_indices': token_indices,
651
+ 'quote_token_indices': quote_token_indices, # NEW
652
+ 'trending_token_indices': trending_token_indices, # NEW
653
+ 'boosted_token_indices': boosted_token_indices, # NEW
654
+ 'holder_snapshot_indices': holder_snapshot_indices, # This now points to the generated embeddings
655
+ 'textual_event_indices': textual_event_indices, # RENAMED
656
+ 'ohlc_indices': ohlc_indices,
657
+ # Raw Data for Encoders
658
+ 'embedding_pool': batch_embedding_pool, # NEW
659
+ 'token_encoder_inputs': token_encoder_inputs,
660
+ 'wallet_encoder_inputs': wallet_encoder_inputs, # ADDED BACK
661
+ 'ohlc_price_tensors': ohlc_inputs_dict['price_tensor'],
662
+ 'ohlc_interval_ids': ohlc_inputs_dict['interval_ids'],
663
+ 'graph_updater_links': graph_updater_links,
664
+ 'wallet_addr_to_batch_idx': wallet_addr_to_batch_idx, # NEW: Pass the mapping
665
+
666
+ 'dest_wallet_indices': dest_wallet_indices, # ADDED THIS LINE
667
+ 'original_author_indices': original_author_indices, # NEW
668
+ # --- NEW: Numerical Features for Events ---
669
+ 'transfer_numerical_features': transfer_numerical_features,
670
+ 'trade_numerical_features': trade_numerical_features,
671
+ 'trade_dex_ids': trade_dex_ids,
672
+ 'deployer_trade_numerical_features': deployer_trade_numerical_features,
673
+ 'trade_direction_ids': trade_direction_ids, # NEW
674
+ 'trade_mev_protection_ids': trade_mev_protection_ids, # NEW
675
+ 'smart_wallet_trade_numerical_features': smart_wallet_trade_numerical_features,
676
+ 'trade_is_bundle_ids': trade_is_bundle_ids, # NEW
677
+ 'pool_created_numerical_features': pool_created_numerical_features,
678
+ 'pool_created_protocol_ids': pool_created_protocol_ids, # NEW
679
+ 'liquidity_change_numerical_features': liquidity_change_numerical_features,
680
+ 'liquidity_change_type_ids': liquidity_change_type_ids, # NEW
681
+ 'fee_collected_numerical_features': fee_collected_numerical_features, # NEW
682
+ 'token_burn_numerical_features': token_burn_numerical_features, # NEW
683
+ 'supply_lock_numerical_features': supply_lock_numerical_features, # NEW
684
+ 'onchain_snapshot_numerical_features': onchain_snapshot_numerical_features, # NEW
685
+ 'boosted_token_numerical_features': boosted_token_numerical_features,
686
+ 'trending_token_numerical_features': trending_token_numerical_features,
687
+ 'trending_token_source_ids': trending_token_source_ids, # NEW
688
+ 'trending_token_timeframe_ids': trending_token_timeframe_ids, # NEW
689
+ 'dexboost_paid_numerical_features': dexboost_paid_numerical_features, # NEW
690
+ 'dexprofile_updated_flags': dexprofile_updated_flags, # NEW,
691
+ 'global_trending_numerical_features': global_trending_numerical_features, # NEW
692
+ 'chainsnapshot_numerical_features': chainsnapshot_numerical_features, # NEW
693
+ 'lighthousesnapshot_numerical_features': lighthousesnapshot_numerical_features,
694
+ 'lighthousesnapshot_protocol_ids': lighthousesnapshot_protocol_ids, # NEW
695
+ 'lighthousesnapshot_timeframe_ids': lighthousesnapshot_timeframe_ids, # NEW
696
+ 'migrated_protocol_ids': migrated_protocol_ids, # NEW
697
+ 'alpha_group_ids': alpha_group_ids, # NEW
698
+ 'channel_ids': channel_ids, # NEW
699
+ 'exchange_ids': exchange_ids, # NEW
700
+ 'holder_snapshot_raw_data': holder_snapshot_raw_data_list, # NEW: Raw data for end-to-end processing
701
+ 'textual_event_data': textual_event_data_list, # RENAMED
702
+ # Labels
703
+ 'labels': torch.stack([item['labels'] for item in batch]) if batch and 'labels' in batch[0] else None,
704
+ 'labels_mask': torch.stack([item['labels_mask'] for item in batch]) if batch and 'labels_mask' in batch[0] else None
705
+ }
706
+
707
+ # Filter out None values (e.g., if no labels provided)
708
+ return {k: v for k, v in collated_batch.items() if v is not None}
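The collator above inverts trending/boosted `rank` values via `1.0 / event.get('rank', ...)`, where `dict.get` only guards against a *missing* key; a rank explicitly stored as `0` would divide by zero. A minimal sketch of a safer helper (the function name is illustrative, not part of the original code):

```python
def inverted_rank(event: dict, large: float = 1e9) -> float:
    """Map rank 1 -> 1.0 and larger ranks -> smaller values.

    Treats a missing, None, or zero rank as `large`, so the inverted
    value is effectively 0 instead of raising ZeroDivisionError.
    """
    rank = event.get('rank') or large  # `or` catches missing, None, and 0
    return 1.0 / rank
```

With this guard, rank 1 still produces the highest value (1.0) and an absent or zero rank collapses to roughly zero rather than crashing the batch.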
data/data_fetcher.py ADDED
@@ -0,0 +1,1009 @@
1
+ # data_fetcher.py
2
+
3
+ from typing import List, Dict, Any, Tuple, Set, Optional
4
+ from collections import defaultdict
5
+ import datetime, time
6
+
7
+ # We need the vocabulary for mapping IDs
8
+ import models.vocabulary as vocab
9
+
10
+ class DataFetcher:
11
+ """
12
+ A dedicated class to handle all database queries for ClickHouse and Neo4j.
13
+ This keeps data fetching logic separate from the dataset and model.
14
+ """
15
+
16
+ # --- Explicit column definitions for wallet profile & social fetches ---
17
+ PROFILE_BASE_COLUMNS = [
18
+ 'wallet_address',
19
+ 'updated_at',
20
+ 'first_seen_ts',
21
+ 'last_seen_ts',
22
+ 'tags',
23
+ 'deployed_tokens',
24
+ 'funded_from',
25
+ 'funded_timestamp',
26
+ 'funded_signature',
27
+ 'funded_amount'
28
+ ]
29
+
30
+ PROFILE_METRIC_COLUMNS = [
31
+ 'balance',
32
+ 'transfers_in_count',
33
+ 'transfers_out_count',
34
+ 'spl_transfers_in_count',
35
+ 'spl_transfers_out_count',
36
+ 'total_buys_count',
37
+ 'total_sells_count',
38
+ 'total_winrate',
39
+ 'stats_1d_realized_profit_sol',
40
+ 'stats_1d_realized_profit_usd',
41
+ 'stats_1d_realized_profit_pnl',
42
+ 'stats_1d_buy_count',
43
+ 'stats_1d_sell_count',
44
+ 'stats_1d_transfer_in_count',
45
+ 'stats_1d_transfer_out_count',
46
+ 'stats_1d_avg_holding_period',
47
+ 'stats_1d_total_bought_cost_sol',
48
+ 'stats_1d_total_bought_cost_usd',
49
+ 'stats_1d_total_sold_income_sol',
50
+ 'stats_1d_total_sold_income_usd',
51
+ 'stats_1d_total_fee',
52
+ 'stats_1d_winrate',
53
+ 'stats_1d_tokens_traded',
54
+ 'stats_7d_realized_profit_sol',
55
+ 'stats_7d_realized_profit_usd',
56
+ 'stats_7d_realized_profit_pnl',
57
+ 'stats_7d_buy_count',
58
+ 'stats_7d_sell_count',
59
+ 'stats_7d_transfer_in_count',
60
+ 'stats_7d_transfer_out_count',
61
+ 'stats_7d_avg_holding_period',
62
+ 'stats_7d_total_bought_cost_sol',
63
+ 'stats_7d_total_bought_cost_usd',
64
+ 'stats_7d_total_sold_income_sol',
65
+ 'stats_7d_total_sold_income_usd',
66
+ 'stats_7d_total_fee',
67
+ 'stats_7d_winrate',
68
+ 'stats_7d_tokens_traded',
69
+ 'stats_30d_realized_profit_sol',
70
+ 'stats_30d_realized_profit_usd',
71
+ 'stats_30d_realized_profit_pnl',
72
+ 'stats_30d_buy_count',
73
+ 'stats_30d_sell_count',
74
+ 'stats_30d_transfer_in_count',
75
+ 'stats_30d_transfer_out_count',
76
+ 'stats_30d_avg_holding_period',
77
+ 'stats_30d_total_bought_cost_sol',
78
+ 'stats_30d_total_bought_cost_usd',
79
+ 'stats_30d_total_sold_income_sol',
80
+ 'stats_30d_total_sold_income_usd',
81
+ 'stats_30d_total_fee',
82
+ 'stats_30d_winrate',
83
+ 'stats_30d_tokens_traded'
84
+ ]
85
+
86
+ PROFILE_COLUMNS_FOR_QUERY = PROFILE_BASE_COLUMNS + PROFILE_METRIC_COLUMNS
87
+
88
+ SOCIAL_COLUMNS_FOR_QUERY = [
89
+ 'wallet_address',
90
+ 'pumpfun_username',
91
+ 'twitter_username',
92
+ 'telegram_channel',
93
+ 'kolscan_name',
94
+ 'cabalspy_name',
95
+ 'axiom_kol_name'
96
+ ]
97
+ def __init__(self, clickhouse_client: Any, neo4j_driver: Any):
98
+ self.db_client = clickhouse_client
99
+ self.graph_client = neo4j_driver
100
+ print("DataFetcher instantiated.")
101
+
102
+ def get_all_mints(self, start_date: Optional[datetime.datetime] = None) -> List[Dict[str, Any]]:
103
+ """
104
+ Fetches a list of all mint events to serve as dataset samples.
105
+ Can be filtered to only include mints on or after a given start_date.
106
+ """
107
+ query = "SELECT mint_address, timestamp, creator_address, protocol, token_name, token_symbol, token_uri, total_supply, token_decimals FROM mints"
108
+ params = {}
109
+ where_clauses = []
110
+
111
+ if start_date:
112
+ where_clauses.append("timestamp >= %(start_date)s")
113
+ params['start_date'] = start_date
114
+
115
+ if where_clauses:
116
+ query += " WHERE " + " AND ".join(where_clauses)
117
+
118
+ print(f"INFO: Executing query to get all mints: `{query}` with params: {params}")
119
+ try:
120
+ rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
121
+ if not rows:
122
+ return []
123
+ columns = [col[0] for col in columns_info]
124
+ result = [dict(zip(columns, row)) for row in rows]
125
+ if not result:
126
+ return []
127
+ return result
128
+ except Exception as e:
129
+ print(f"ERROR: Failed to fetch token addresses from ClickHouse: {e}")
130
+ print("INFO: Falling back to mock token addresses for development.")
131
+ return [{'mint_address': 'tknA_real', 'timestamp': datetime.datetime.now(datetime.timezone.utc), 'creator_address': 'addr_Creator_Real', 'protocol': 0}]
132
+
133
+
134
+ def fetch_mint_record(self, token_address: str) -> Dict[str, Any]:
135
+ """
136
+ Fetches the raw mint record for a token from the 'mints' table.
137
+ """
138
+ query = "SELECT timestamp, creator_address, mint_address, protocol FROM mints WHERE mint_address = %(token_address)s ORDER BY timestamp ASC LIMIT 1"
140
+ print(f"INFO: Executing query to fetch mint record: `{query}`")
141
+
142
+ # The client returns a list of tuples; each row is zipped with the column names below
143
+ # Using column names from your schema
144
+ columns = ['timestamp', 'creator_address', 'mint_address', 'protocol']
145
+ try:
146
+ result = self.db_client.execute(query, {'token_address': token_address})
146
+
147
+ if not result or not result[0]:
148
+ raise ValueError(f"No mint event found for token {token_address}")
149
+
150
+ # Convert the tuple result into a dictionary
151
+ record = dict(zip(columns, result[0]))
152
+ return record
153
+ except Exception as e:
154
+ print(f"ERROR: Failed to fetch mint record for {token_address}: {e}")
155
+ print("INFO: Falling back to mock mint record for development.")
156
+ # Fallback for development if DB connection fails
157
+ return {
158
+ 'timestamp': datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(days=1),
159
+ 'creator_address': 'addr_Creator_Real',
160
+ 'mint_address': token_address,
161
+ 'protocol': vocab.PROTOCOL_TO_ID.get("Pump V1", 0)
162
+ }
163
+
164
+ def fetch_wallet_profiles(self, wallet_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, Dict[str, Any]]:
165
+ """
166
+ Convenience wrapper around fetch_wallet_profiles_and_socials for profile-only data.
167
+ """
168
+ profiles, _ = self.fetch_wallet_profiles_and_socials(wallet_addresses, T_cutoff)
169
+ return profiles
170
+
171
+ def fetch_wallet_socials(self, wallet_addresses: List[str]) -> Dict[str, Dict[str, Any]]:
172
+ """
173
+ Fetches wallet social records for a list of wallet addresses.
174
+ Returns a dictionary mapping wallet_address to its social data.
175
+ """
176
+ if not wallet_addresses:
177
+ return {}
178
+
179
+ query = "SELECT * FROM wallet_socials WHERE wallet_address IN %(addresses)s"
180
+ params = {'addresses': wallet_addresses}
181
+ print(f"INFO: Executing query to fetch wallet socials for {len(wallet_addresses)} wallets.")
182
+
183
+ try:
184
+ rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
185
+ if not rows:
186
+ return {}
187
+
188
+ columns = [col[0] for col in columns_info]
189
+ socials = {}
190
+ for row in rows:
191
+ social_dict = dict(zip(columns, row))
192
+ wallet_addr = social_dict.get('wallet_address')
193
+ if wallet_addr:
194
+ socials[wallet_addr] = social_dict
195
+ return socials
196
+
197
+ except Exception as e:
198
+ print(f"ERROR: Failed to fetch wallet socials: {e}")
199
+ print("INFO: Returning empty dictionary for wallet socials.")
200
+ return {}
201
+
202
+ def fetch_wallet_profiles_and_socials(self,
203
+ wallet_addresses: List[str],
204
+ T_cutoff: datetime.datetime) -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]]:
205
+ """
206
+ Fetches wallet profiles (time-aware) and socials for all requested wallets in a single query.
207
+ Returns two dictionaries: profiles, socials.
208
+ """
209
+ if not wallet_addresses:
210
+ return {}, {}
211
+
212
+ social_columns = self.SOCIAL_COLUMNS_FOR_QUERY
213
+
214
+ profile_base_cols = self.PROFILE_BASE_COLUMNS
215
+ profile_metric_cols = self.PROFILE_METRIC_COLUMNS
216
+
217
+ profile_base_str = ",\n ".join(profile_base_cols)
218
+ metric_projection_cols = ['wallet_address', 'updated_at'] + profile_metric_cols
219
+ profile_metric_str = ",\n ".join(metric_projection_cols)
220
+
221
+ profile_base_select_cols = [col for col in profile_base_cols if col != 'wallet_address']
222
+ profile_metric_select_cols = [
223
+ col for col in profile_metric_cols if col not in ('wallet_address',)
224
+ ]
225
+ social_select_cols = [col for col in social_columns if col != 'wallet_address']
226
+
227
+ select_expressions = []
228
+ for col in profile_base_select_cols:
229
+ select_expressions.append(f"lp.{col} AS profile__{col}")
230
+ for col in profile_metric_select_cols:
231
+ select_expressions.append(f"lm.{col} AS profile__{col}")
232
+ for col in social_select_cols:
233
+ select_expressions.append(f"ws.{col} AS social__{col}")
234
+ select_clause = ""
235
+ if select_expressions:
236
+ select_clause = ",\n " + ",\n ".join(select_expressions)
237
+
238
+ query = f"""
239
+ WITH ranked_profiles AS (
240
+ SELECT
241
+ {profile_base_str},
242
+ ROW_NUMBER() OVER (PARTITION BY wallet_address ORDER BY updated_at DESC) AS rn
243
+ FROM wallet_profiles
244
+ WHERE wallet_address IN %(addresses)s
245
+ ),
246
+ latest_profiles AS (
247
+ SELECT
248
+ {profile_base_str}
249
+ FROM ranked_profiles
250
+ WHERE rn = 1
251
+ ),
252
+ ranked_metrics AS (
253
+ SELECT
254
+ {profile_metric_str},
255
+ ROW_NUMBER() OVER (PARTITION BY wallet_address ORDER BY updated_at DESC) AS rn
256
+ FROM wallet_profile_metrics
257
+ WHERE
258
+ wallet_address IN %(addresses)s
259
+ AND updated_at <= %(T_cutoff)s
260
+ ),
261
+ latest_metrics AS (
262
+ SELECT
263
+ {profile_metric_str}
264
+ FROM ranked_metrics
265
+ WHERE rn = 1
266
+ ),
267
+ requested_wallets AS (
268
+ SELECT DISTINCT wallet_address
269
+ FROM (SELECT arrayJoin(%(addresses)s) AS wallet_address)
270
+ )
271
+ SELECT
272
+ rw.wallet_address AS wallet_address
273
+ {select_clause}
274
+ FROM requested_wallets AS rw
275
+ LEFT JOIN latest_profiles AS lp ON rw.wallet_address = lp.wallet_address
276
+ LEFT JOIN latest_metrics AS lm ON rw.wallet_address = lm.wallet_address
277
+ LEFT JOIN wallet_socials AS ws ON rw.wallet_address = ws.wallet_address;
278
+ """
279
+
280
+ params = {'addresses': wallet_addresses, 'T_cutoff': T_cutoff}
281
+ print(f"INFO: Executing combined query for profiles+socials on {len(wallet_addresses)} wallets.")
282
+
283
+ try:
284
+ rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
285
+ if not rows:
286
+ return {}, {}
287
+
288
+ columns = [col[0] for col in columns_info]
289
+ profiles: Dict[str, Dict[str, Any]] = {}
290
+ socials: Dict[str, Dict[str, Any]] = {}
291
+
292
+ profile_keys = [f"profile__{col}" for col in (profile_base_select_cols + profile_metric_select_cols)]
293
+ social_keys = [f"social__{col}" for col in social_select_cols]
294
+
295
+ for row in rows:
296
+ row_dict = dict(zip(columns, row))
297
+ wallet_addr = row_dict.get('wallet_address')
298
+ if not wallet_addr:
299
+ continue
300
+
301
+ profile_data = {}
302
+ if profile_keys:
303
+ for pref_key in profile_keys:
304
+ if pref_key in row_dict:
305
+ value = row_dict[pref_key]
306
+ profile_data[pref_key.replace('profile__', '')] = value
307
+
308
+ if profile_data and any(value is not None for value in profile_data.values()):
309
+ profile_data['wallet_address'] = wallet_addr
310
+ profiles[wallet_addr] = profile_data
311
+
312
+ social_data = {}
313
+ if social_keys:
314
+ for pref_key in social_keys:
315
+ if pref_key in row_dict:
316
+ value = row_dict[pref_key]
317
+ social_data[pref_key.replace('social__', '')] = value
318
+
319
+ if social_data and any(value is not None for value in social_data.values()):
320
+ social_data['wallet_address'] = wallet_addr
321
+ socials[wallet_addr] = social_data
322
+
323
+ return profiles, socials
324
+
325
+ except Exception as e:
326
+ print(f"ERROR: Combined profile/social query failed: {e}")
327
+ print("INFO: Falling back to separate queries.")
328
+ profiles = self.fetch_wallet_profiles(wallet_addresses, T_cutoff)
329
+ socials = self.fetch_wallet_socials(wallet_addresses)
330
+ return profiles, socials
331
+
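The CTEs above rely on the `ROW_NUMBER() OVER (PARTITION BY wallet_address ORDER BY updated_at DESC)` pattern to pick, for each wallet, the newest row at or before `T_cutoff`. The same point-in-time selection can be sketched in plain Python (an illustrative helper, not part of the module):

```python
def latest_per_key(rows, key='wallet_address', ts='updated_at', cutoff=None):
    """Return {key: row} keeping the newest row per key whose timestamp
    is <= cutoff (mirrors the `rn = 1` filter in the SQL above)."""
    best = {}
    for row in rows:
        if cutoff is not None and row[ts] > cutoff:
            continue  # ignore rows from after the cutoff (no look-ahead)
        k = row[key]
        if k not in best or row[ts] > best[k][ts]:
            best[k] = row  # keep the most recent surviving row per key
    return best
```

Filtering before ranking is what makes the query time-aware: a wallet's state as of `T_cutoff` never leaks information from later updates.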
332
+ def fetch_wallet_holdings(self, wallet_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, List[Dict[str, Any]]]:
333
+ """
334
+ Fetches top 3 wallet holding records for a list of wallet addresses that were active at T_cutoff.
335
+ Returns a dictionary mapping wallet_address to a LIST of its holding data.
336
+ """
337
+ if not wallet_addresses:
338
+ return {}
339
+
340
+ # --- NEW: Time-aware query based on user's superior logic ---
341
+ # 1. For each holding, find the latest state at or before T_cutoff.
342
+ # 2. Filter for holdings where the balance was greater than 0.
343
+ # 3. Rank these active holdings by total traded volume (bought cost + sold income, SOL-denominated despite the `_usd` alias) and take the top 3 per wallet.
344
+ query = """
345
+ WITH point_in_time_holdings AS (
346
+ SELECT
347
+ *,
348
+ COALESCE(history_bought_cost_sol, 0) + COALESCE(history_sold_income_sol, 0) AS total_volume_usd,
349
+ ROW_NUMBER() OVER(PARTITION BY wallet_address, mint_address ORDER BY updated_at DESC) as rn_per_holding
350
+ FROM wallet_holdings
351
+ WHERE
352
+ wallet_address IN %(addresses)s
353
+ AND updated_at <= %(T_cutoff)s
354
+ ),
355
+ ranked_active_holdings AS (
356
+ SELECT *,
357
+ ROW_NUMBER() OVER(PARTITION BY wallet_address ORDER BY total_volume_usd DESC) as rn_per_wallet
358
+ FROM point_in_time_holdings
359
+ WHERE rn_per_holding = 1 AND current_balance > 0
360
+ )
361
+ SELECT *
362
+ FROM ranked_active_holdings
363
+ WHERE rn_per_wallet <= 3;
364
+ """
365
+ params = {'addresses': wallet_addresses, 'T_cutoff': T_cutoff}
366
+ print(f"INFO: Executing query to fetch wallet holdings for {len(wallet_addresses)} wallets.")
367
+
368
+ try:
369
+ rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
370
+ if not rows:
371
+ return {}
372
+
373
+ columns = [col[0] for col in columns_info]
374
+ holdings = defaultdict(list)
375
+ for row in rows:
376
+ holding_dict = dict(zip(columns, row))
377
+ wallet_addr = holding_dict.get('wallet_address')
378
+ if wallet_addr:
379
+ holdings[wallet_addr].append(holding_dict)
380
+ return dict(holdings)
381
+
382
+ except Exception as e:
383
+ print(f"ERROR: Failed to fetch wallet holdings: {e}")
384
+ print("INFO: Returning empty dictionary for wallet holdings.")
385
+ return {}
386
+
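`fetch_wallet_holdings` layers a second ranking on top of the point-in-time filter: surviving active holdings are ordered by volume and capped at three per wallet. That top-N-per-group step can be sketched in plain Python (an illustrative helper, not part of the module):

```python
def top_n_per_group(rows, group='wallet_address', score='total_volume_usd', n=3):
    """Keep the n highest-scoring rows per group, mirroring the
    `rn_per_wallet <= 3` filter in the SQL above."""
    by_group = {}
    for row in rows:
        by_group.setdefault(row[group], []).append(row)
    out = {}
    for key, group_rows in by_group.items():
        group_rows.sort(key=lambda r: r[score], reverse=True)  # highest volume first
        out[key] = group_rows[:n]
    return out
```

Capping per wallet keeps the encoder input bounded regardless of how many tokens a whale happens to hold.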
387
+ def fetch_graph_links(self,
388
+ initial_addresses: List[str],
389
+ T_cutoff: datetime.datetime,
390
+ max_degrees: int = 2) -> Tuple[Dict[str, str], Dict[str, Dict[str, Any]]]:
391
+ """
392
+ Fetches graph links from Neo4j, traversing up to a max degree of separation.
393
+
394
+ Args:
395
+ initial_addresses: A list of starting wallet or token addresses.
396
+ max_degrees: The maximum number of hops to traverse in the graph.
397
+
398
+ Returns:
399
+ A tuple containing:
400
+ - A dictionary mapping entity addresses to their type ('Wallet' or 'Token').
401
+ - A dictionary of aggregated links, structured for the GraphUpdater.
402
+ """
403
+ if not initial_addresses:
404
+ return {}, {}  # match the documented (entity-type dict, links dict) return shape
405
+
406
+ cutoff_ts = int(T_cutoff.timestamp())
407
+
408
+ print(f"INFO: Fetching graph links up to {max_degrees} degrees for {len(initial_addresses)} initial entities...")
409
+ try:
410
+ with self.graph_client.session() as session:
411
+ all_entities = {addr: 'Token' for addr in initial_addresses} # Assume initial are tokens
412
+ newly_found_entities = set(initial_addresses)
413
+ aggregated_links = defaultdict(lambda: {'links': [], 'edges': []})
414
+
415
+ for i in range(max_degrees):
416
+ if not newly_found_entities:
417
+ break
418
+
419
+ print(f" - Degree {i+1}: Traversing from {len(newly_found_entities)} new entities...")
420
+
421
+ # Cypher query to find direct neighbors of the current frontier
422
+ query = """
423
+ MATCH (a)-[r]-(b)
424
+ WHERE a.address IN $addresses
425
+ RETURN a.address AS source_address, type(r) AS link_type, properties(r) AS link_props, b.address AS dest_address, labels(b)[0] AS dest_type
426
+ """
427
+ params = {'addresses': list(newly_found_entities)}
428
+ result = session.run(query, params)
429
+
430
+ current_degree_new_entities = set()
431
+ for record in result:
432
+ link_type = record['link_type']
433
+ link_props = dict(record['link_props'])
434
+ link_ts_raw = link_props.get('timestamp')
435
+ try:
436
+ link_ts = int(link_ts_raw)
437
+ except (TypeError, ValueError):
438
+ continue
439
+ if link_ts > cutoff_ts:
440
+ continue
441
+ source_addr = record['source_address']
442
+ dest_addr = record['dest_address']
443
+ dest_type = record['dest_type']
444
+
445
+ # Add the link and edge data
446
+ aggregated_links[link_type]['links'].append(link_props)
447
+ aggregated_links[link_type]['edges'].append((source_addr, dest_addr))
448
+
449
+ # If we found a new entity, add it to the set for the next iteration
450
+ if dest_addr not in all_entities.keys():
451
+ current_degree_new_entities.add(dest_addr)
452
+ all_entities[dest_addr] = dest_type
453
+
454
+ newly_found_entities = current_degree_new_entities
455
+
456
+ return all_entities, dict(aggregated_links)
457
+ except Exception as e:
458
+ print(f"ERROR: Failed to fetch graph links from Neo4j: {e}")
459
+ return {addr: 'Token' for addr in initial_addresses}, {}
460
+
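The degree-limited traversal in `fetch_graph_links` is a breadth-first frontier expansion: each round queries the neighbors of the current frontier, and only addresses never seen before seed the next round. A self-contained sketch over an in-memory adjacency map (a hypothetical stand-in for the Neo4j session):

```python
def expand_frontier(adjacency, initial, max_degrees=2):
    """Collect all entities reachable within max_degrees hops.

    adjacency: {address: [neighbor_address, ...]} -- stand-in for the
    Cypher neighbor query in fetch_graph_links.
    """
    seen = set(initial)
    frontier = set(initial)
    for _ in range(max_degrees):
        if not frontier:
            break  # nothing new last round; traversal is complete
        next_frontier = set()
        for addr in frontier:
            for neighbor in adjacency.get(addr, []):
                if neighbor not in seen:  # only unseen nodes extend the search
                    seen.add(neighbor)
                    next_frontier.add(neighbor)
        frontier = next_frontier
    return seen
```

Tracking `seen` is what bounds the work: a wallet reachable by several paths is expanded at most once, and the loop terminates as soon as a round discovers nothing new.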
461
+ def fetch_token_data(self, token_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, Dict[str, Any]]:
462
+ """
463
+ Fetches the latest token data for each address at or before T_cutoff.
464
+ Returns a dictionary mapping token_address to its data.
465
+ """
466
+ if not token_addresses:
467
+ return {}
468
+
469
+ # --- NEW: Time-aware query for historical token data ---
470
+ query = """
471
+ WITH ranked_tokens AS (
472
+ SELECT
473
+ *,
474
+ ROW_NUMBER() OVER (PARTITION BY token_address ORDER BY updated_at DESC) as rn
475
+ FROM tokens
476
+ WHERE
477
+ token_address IN %(addresses)s
478
+ AND updated_at <= %(T_cutoff)s
479
+ )
480
+ SELECT token_address, name, symbol, token_uri, protocol, total_supply, decimals
481
+ FROM ranked_tokens
482
+ WHERE rn = 1;
483
+ """
484
+ params = {'addresses': token_addresses, 'T_cutoff': T_cutoff}
485
+ print(f"INFO: Executing query to fetch token data for {len(token_addresses)} tokens.")
486
+
487
+ try:
488
+ rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
489
+
490
+ if not rows:
491
+ return {}
492
+
493
+ # Get column names from the query result description
494
+ columns = [col[0] for col in columns_info]
495
+
496
+ tokens = {}
497
+ for row in rows:
498
+ token_dict = dict(zip(columns, row))
499
+ token_addr = token_dict.get('token_address')
500
+ if token_addr:
501
+ # The 'tokens' table in the schema has 'token_address' but the
502
+ # collator expects 'address'. We'll add it for compatibility.
503
+ token_dict['address'] = token_addr
504
+ tokens[token_addr] = token_dict
505
+ return tokens
506
+
507
+ except Exception as e:
508
+ print(f"ERROR: Failed to fetch token data: {e}")
509
+ print("INFO: Returning empty dictionary for token data.")
510
+ return {}
511
+
512
+     def fetch_deployed_token_details(self, token_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, Dict[str, Any]]:
+         """
+         Fetches historical details for deployed tokens at or before T_cutoff.
+         """
+         if not token_addresses:
+             return {}
+
+         # --- NEW: Time-aware query for historical deployed token details ---
+         query = """
+             WITH ranked_tokens AS (
+                 SELECT
+                     *,
+                     ROW_NUMBER() OVER (PARTITION BY token_address ORDER BY updated_at DESC) as rn
+                 FROM tokens
+                 WHERE
+                     token_address IN %(addresses)s
+                     AND updated_at <= %(T_cutoff)s
+             ),
+             ranked_token_metrics AS (
+                 SELECT
+                     token_address,
+                     ath_price_usd,
+                     ROW_NUMBER() OVER (PARTITION BY token_address ORDER BY updated_at DESC) as rn
+                 FROM token_metrics
+                 WHERE
+                     token_address IN %(addresses)s
+                     AND updated_at <= %(T_cutoff)s
+             ),
+             latest_tokens AS (
+                 SELECT *
+                 FROM ranked_tokens
+                 WHERE rn = 1
+             ),
+             latest_token_metrics AS (
+                 SELECT *
+                 FROM ranked_token_metrics
+                 WHERE rn = 1
+             )
+             SELECT
+                 lt.token_address,
+                 lt.created_at,
+                 lt.updated_at,
+                 ltm.ath_price_usd,
+                 lt.total_supply,
+                 lt.decimals,
+                 (lt.launchpad != lt.protocol) AS has_migrated
+             FROM latest_tokens AS lt
+             LEFT JOIN latest_token_metrics AS ltm
+                 ON lt.token_address = ltm.token_address;
+         """
+         params = {'addresses': token_addresses, 'T_cutoff': T_cutoff}
+         print(f"INFO: Executing query to fetch deployed token details for {len(token_addresses)} tokens.")
+
+         try:
+             rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
+             if not rows:
+                 return {}
+
+             columns = [col[0] for col in columns_info]
+             token_details = {row[0]: dict(zip(columns, row)) for row in rows}
+             return token_details
+         except Exception as e:
+             print(f"ERROR: Failed to fetch deployed token details: {e}")
+             return {}
+
+     def fetch_trades_for_token(self, token_address: str, T_cutoff: datetime.datetime, count_threshold: int, early_limit: int, recent_limit: int) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
+         """
+         Fetches trades for a token, using a 3-part H/B/H strategy if the total count exceeds a threshold.
+         Returns three lists: early_trades, middle_trades, recent_trades.
+         """
+         if not token_address:
+             return [], [], []
+
+         params = {'token_address': token_address, 'T_cutoff': T_cutoff}
+
+         # 1. Get the total count of trades for the token before the cutoff
+         count_query = "SELECT count() FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s"
+         try:
+             total_trades = self.db_client.execute(count_query, params)[0][0]
+             print(f"INFO: Found {total_trades} total trades for token {token_address} before {T_cutoff}.")
+         except Exception as e:
+             print(f"ERROR: Could not count trades for token {token_address}: {e}")
+             return [], [], []
+
+         # 2. Decide which query to use based on the count
+         if total_trades < count_threshold:
+             print("INFO: Fetching all trades (count is below H/B/H threshold).")
+             query = "SELECT * FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s ORDER BY timestamp ASC"
+             try:
+                 rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
+                 if not rows:
+                     return [], [], []
+                 columns = [col[0] for col in columns_info]
+                 all_trades = [dict(zip(columns, row)) for row in rows]
+                 # When not using HBH, all trades are considered "early"
+                 return all_trades, [], []
+             except Exception as e:
+                 print(f"ERROR: Failed to fetch all trades for token {token_address}: {e}")
+                 return [], [], []
+
+         # 3. Use the H/B/H strategy if the count is high
+         print("INFO: Fetching trades using 3-part High-Def/Blurry/High-Def strategy.")
+         try:
+             # Fetch Early (High-Def)
+             early_query = "SELECT * FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s ORDER BY timestamp ASC LIMIT %(limit)s"
+             early_rows, early_cols_info = self.db_client.execute(early_query, {'token_address': token_address, 'T_cutoff': T_cutoff, 'limit': early_limit}, with_column_types=True)
+             early_trades = [dict(zip([c[0] for c in early_cols_info], r)) for r in early_rows] if early_rows else []
+
+             # Fetch Recent (High-Def)
+             recent_query = "SELECT * FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s ORDER BY timestamp DESC LIMIT %(limit)s"
+             recent_rows, recent_cols_info = self.db_client.execute(recent_query, {'token_address': token_address, 'T_cutoff': T_cutoff, 'limit': recent_limit}, with_column_types=True)
+             recent_trades = [dict(zip([c[0] for c in recent_cols_info], r)) for r in recent_rows] if recent_rows else []
+             recent_trades.reverse()  # Order ASC
+
+             # Fetch Middle (Blurry - successful trades only)
+             middle_trades = []
+             if early_trades and recent_trades:
+                 start_middle_ts = early_trades[-1]['timestamp']
+                 end_middle_ts = recent_trades[0]['timestamp']
+                 if start_middle_ts < end_middle_ts:
+                     middle_query = """
+                         SELECT * FROM trades
+                         WHERE base_address = %(token_address)s
+                         AND success = true
+                         AND timestamp > %(start_ts)s
+                         AND timestamp < %(end_ts)s
+                         ORDER BY timestamp ASC
+                     """
+                     middle_params = {'token_address': token_address, 'start_ts': start_middle_ts, 'end_ts': end_middle_ts}
+                     middle_rows, middle_cols_info = self.db_client.execute(middle_query, middle_params, with_column_types=True)
+                     middle_trades = [dict(zip([c[0] for c in middle_cols_info], r)) for r in middle_rows] if middle_rows else []
+
+             return early_trades, middle_trades, recent_trades
+
+         except Exception as e:
+             print(f"ERROR: Failed to fetch H/B/H trades for token {token_address}: {e}")
+             return [], [], []
+
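The High-Def/Blurry/High-Def selection above can be sketched in plain Python on an in-memory trade list (assumed toy thresholds, not the production ones), which makes the boundary behavior easy to check:

```python
# Sketch of the H/B/H split: keep the first `early_limit` and last
# `recent_limit` trades in full detail, and only successful trades
# strictly between those two windows ("blurry" middle).
def hbh_split(trades, count_threshold, early_limit, recent_limit):
    """Return (early, middle, recent), each ordered by timestamp ascending."""
    trades = sorted(trades, key=lambda t: t['timestamp'])
    if len(trades) < count_threshold:
        return trades, [], []  # below threshold: everything counts as "early"
    early = trades[:early_limit]
    recent = trades[-recent_limit:]
    start_ts = early[-1]['timestamp']
    end_ts = recent[0]['timestamp']
    # Mirrors the SQL middle_query: success = true, exclusive bounds
    middle = [t for t in trades
              if t.get('success') and start_ts < t['timestamp'] < end_ts]
    return early, middle, recent

trades = [{'timestamp': i, 'success': i % 2 == 0} for i in range(10)]
early, middle, recent = hbh_split(trades, count_threshold=8, early_limit=3, recent_limit=3)
# early: ts 0..2, recent: ts 7..9, middle: successful trades at ts 4 and 6
```

Note that the exclusive bounds (`>` / `<`) prevent the middle window from duplicating the boundary trades already present in the early and recent slices.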
+     def fetch_future_trades_for_token(self,
+                                       token_address: str,
+                                       start_ts: datetime.datetime,
+                                       end_ts: datetime.datetime) -> List[Dict[str, Any]]:
+         """
+         Fetches successful trades for a token in the window (start_ts, end_ts].
+         Used for constructing label targets beyond the cutoff.
+         """
+         if not token_address or start_ts is None or end_ts is None or start_ts >= end_ts:
+             return []
+
+         query = """
+             SELECT *
+             FROM trades
+             WHERE base_address = %(token_address)s
+             AND success = true
+             AND timestamp > %(start_ts)s
+             AND timestamp <= %(end_ts)s
+             ORDER BY timestamp ASC
+         """
+         params = {
+             'token_address': token_address,
+             'start_ts': start_ts,
+             'end_ts': end_ts
+         }
+
+         try:
+             rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
+             if not rows:
+                 return []
+             columns = [col[0] for col in columns_info]
+             return [dict(zip(columns, row)) for row in rows]
+         except Exception as e:
+             print(f"ERROR: Failed to fetch future trades for token {token_address}: {e}")
+             return []
+
+     def fetch_transfers_for_token(self, token_address: str, T_cutoff: datetime.datetime, min_amount_threshold: float = 10_000_000) -> List[Dict[str, Any]]:
+         """
+         Fetches all transfers for a token before T_cutoff, filtering out small amounts.
+         """
+         if not token_address:
+             return []
+
+         query = """
+             SELECT * FROM transfers
+             WHERE mint_address = %(token_address)s
+             AND timestamp <= %(T_cutoff)s
+             AND amount_decimal >= %(min_amount)s
+             ORDER BY timestamp ASC
+         """
+         params = {'token_address': token_address, 'T_cutoff': T_cutoff, 'min_amount': min_amount_threshold}
+         print(f"INFO: Fetching significant transfers for {token_address} (amount >= {min_amount_threshold}).")
+
+         try:
+             # This query no longer uses H/B/H; it fetches all significant transfers
+             rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
+             if not rows:
+                 return []
+             columns = [col[0] for col in columns_info]
+             return [dict(zip(columns, row)) for row in rows]
+         except Exception as e:
+             print(f"ERROR: Failed to fetch transfers for token {token_address}: {e}")
+             return []
+
+     def fetch_pool_creations_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
+         """
+         Fetches pool creation records where the token is the base asset.
+         """
+         if not token_address:
+             return []
+
+         query = """
+             SELECT
+                 signature,
+                 timestamp,
+                 slot,
+                 success,
+                 error,
+                 priority_fee,
+                 protocol,
+                 creator_address,
+                 pool_address,
+                 base_address,
+                 quote_address,
+                 lp_token_address,
+                 initial_base_liquidity,
+                 initial_quote_liquidity,
+                 base_decimals,
+                 quote_decimals
+             FROM pool_creations
+             WHERE base_address = %(token_address)s
+             AND timestamp <= %(T_cutoff)s
+             ORDER BY timestamp ASC
+         """
+         params = {'token_address': token_address, 'T_cutoff': T_cutoff}
+         print(f"INFO: Fetching pool creation events for {token_address}.")
+
+         try:
+             rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
+             if not rows:
+                 return []
+             columns = [col[0] for col in columns_info]
+             return [dict(zip(columns, row)) for row in rows]
+         except Exception as e:
+             print(f"ERROR: Failed to fetch pool creations for token {token_address}: {e}")
+             return []
+
+     def fetch_liquidity_changes_for_pools(self, pool_addresses: List[str], T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
+         """
+         Fetches liquidity change records for the given pools up to T_cutoff.
+         """
+         if not pool_addresses:
+             return []
+
+         query = """
+             SELECT
+                 signature,
+                 timestamp,
+                 slot,
+                 success,
+                 error,
+                 priority_fee,
+                 protocol,
+                 change_type,
+                 lp_provider,
+                 pool_address,
+                 base_amount,
+                 quote_amount
+             FROM liquidity
+             WHERE pool_address IN %(pool_addresses)s
+             AND timestamp <= %(T_cutoff)s
+             ORDER BY timestamp ASC
+         """
+         params = {'pool_addresses': pool_addresses, 'T_cutoff': T_cutoff}
+         print(f"INFO: Fetching liquidity change events for {len(pool_addresses)} pool(s).")
+
+         try:
+             rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
+             if not rows:
+                 return []
+             columns = [col[0] for col in columns_info]
+             return [dict(zip(columns, row)) for row in rows]
+         except Exception as e:
+             print(f"ERROR: Failed to fetch liquidity changes for pools {pool_addresses}: {e}")
+             return []
+
+     def fetch_fee_collections_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
+         """
+         Fetches fee collection events where the token appears as either token_0 or token_1.
+         """
+         if not token_address:
+             return []
+
+         query = """
+             SELECT
+                 timestamp,
+                 signature,
+                 slot,
+                 success,
+                 error,
+                 priority_fee,
+                 protocol,
+                 recipient_address,
+                 token_0_mint_address,
+                 token_0_amount,
+                 token_1_mint_address,
+                 token_1_amount
+             FROM fee_collections
+             WHERE (token_0_mint_address = %(token)s OR token_1_mint_address = %(token)s)
+             AND timestamp <= %(T_cutoff)s
+             ORDER BY timestamp ASC
+         """
+         params = {'token': token_address, 'T_cutoff': T_cutoff}
+         print(f"INFO: Fetching fee collection events for {token_address}.")
+
+         try:
+             rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
+             if not rows:
+                 return []
+             columns = [col[0] for col in columns_info]
+             return [dict(zip(columns, row)) for row in rows]
+         except Exception as e:
+             print(f"ERROR: Failed to fetch fee collections for token {token_address}: {e}")
+             return []
+
+     def fetch_migrations_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
+         """
+         Fetches migration records for a given token up to T_cutoff.
+         """
+         if not token_address:
+             return []
+         query = """
+             SELECT
+                 timestamp,
+                 signature,
+                 slot,
+                 success,
+                 error,
+                 priority_fee,
+                 protocol,
+                 mint_address,
+                 virtual_pool_address,
+                 pool_address,
+                 migrated_base_liquidity,
+                 migrated_quote_liquidity
+             FROM migrations
+             WHERE mint_address = %(token)s
+             AND timestamp <= %(T_cutoff)s
+             ORDER BY timestamp ASC
+         """
+         params = {'token': token_address, 'T_cutoff': T_cutoff}
+         print(f"INFO: Fetching migrations for {token_address}.")
+         try:
+             rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
+             if not rows:
+                 return []
+             columns = [col[0] for col in columns_info]
+             return [dict(zip(columns, row)) for row in rows]
+         except Exception as e:
+             print(f"ERROR: Failed to fetch migrations for token {token_address}: {e}")
+             return []
+
+     def fetch_burns_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
+         """
+         Fetches burn events for a given token up to T_cutoff.
+         Schema: burns(timestamp, signature, slot, success, error, priority_fee, mint_address, source, amount, amount_decimal, source_balance)
+         """
+         if not token_address:
+             return []
+
+         query = """
+             SELECT
+                 timestamp,
+                 signature,
+                 slot,
+                 success,
+                 error,
+                 priority_fee,
+                 mint_address,
+                 source,
+                 amount,
+                 amount_decimal,
+                 source_balance
+             FROM burns
+             WHERE mint_address = %(token)s
+             AND timestamp <= %(T_cutoff)s
+             ORDER BY timestamp ASC
+         """
+         params = {'token': token_address, 'T_cutoff': T_cutoff}
+         print(f"INFO: Fetching burn events for {token_address}.")
+
+         try:
+             rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
+             if not rows:
+                 return []
+             columns = [col[0] for col in columns_info]
+             return [dict(zip(columns, row)) for row in rows]
+         except Exception as e:
+             print(f"ERROR: Failed to fetch burns for token {token_address}: {e}")
+             return []
+
+     def fetch_supply_locks_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
+         """
+         Fetches supply lock events for a given token up to T_cutoff.
+         Schema: supply_locks(timestamp, signature, slot, success, error, priority_fee, protocol, contract_address, sender, recipient, mint_address, total_locked_amount, final_unlock_timestamp)
+         """
+         if not token_address:
+             return []
+
+         query = """
+             SELECT
+                 timestamp,
+                 signature,
+                 slot,
+                 success,
+                 error,
+                 priority_fee,
+                 protocol,
+                 contract_address,
+                 sender,
+                 recipient,
+                 mint_address,
+                 total_locked_amount,
+                 final_unlock_timestamp
+             FROM supply_locks
+             WHERE mint_address = %(token)s
+             AND timestamp <= %(T_cutoff)s
+             ORDER BY timestamp ASC
+         """
+         params = {'token': token_address, 'T_cutoff': T_cutoff}
+         print(f"INFO: Fetching supply lock events for {token_address}.")
+
+         try:
+             rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
+             if not rows:
+                 return []
+             columns = [col[0] for col in columns_info]
+             return [dict(zip(columns, row)) for row in rows]
+         except Exception as e:
+             print(f"ERROR: Failed to fetch supply locks for token {token_address}: {e}")
+             return []
+
+     def fetch_token_holders_for_snapshot(self, token_address: str, T_cutoff: datetime.datetime, limit: int = 200) -> List[Dict[str, Any]]:
+         """
+         Fetch top holders for a token at or before T_cutoff for snapshot purposes.
+         Returns rows with wallet_address and current_balance (>0), ordered by balance desc.
+         """
+         if not token_address:
+             return []
+         query = """
+             WITH point_in_time_holdings AS (
+                 SELECT *,
+                        ROW_NUMBER() OVER(PARTITION BY wallet_address, mint_address ORDER BY updated_at DESC) as rn_per_holding
+                 FROM wallet_holdings
+                 WHERE mint_address = %(token)s AND updated_at <= %(T_cutoff)s
+             )
+             SELECT wallet_address, current_balance
+             FROM point_in_time_holdings
+             WHERE rn_per_holding = 1 AND current_balance > 0
+             ORDER BY current_balance DESC
+             LIMIT %(limit)s;
+         """
+         params = {'token': token_address, 'T_cutoff': T_cutoff, 'limit': int(limit)}
+         print(f"INFO: Fetching top holders for snapshot for {token_address} (limit {limit}).")
+         try:
+             rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
+             if not rows:
+                 return []
+             columns = [col[0] for col in columns_info]
+             return [dict(zip(columns, row)) for row in rows]
+         except Exception as e:
+             print(f"ERROR: Failed to fetch token holders for {token_address}: {e}")
+             return []
+
+     def fetch_total_holders_count_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> int:
+         """
+         Returns the total number of wallets holding the token (current_balance > 0)
+         at or before T_cutoff.
+         """
+         if not token_address:
+             return 0
+         query = """
+             WITH point_in_time_holdings AS (
+                 SELECT *,
+                        ROW_NUMBER() OVER(PARTITION BY wallet_address, mint_address ORDER BY updated_at DESC) as rn_per_holding
+                 FROM wallet_holdings
+                 WHERE mint_address = %(token)s AND updated_at <= %(T_cutoff)s
+             )
+             SELECT count()
+             FROM point_in_time_holdings
+             WHERE rn_per_holding = 1 AND current_balance > 0;
+         """
+         params = {'token': token_address, 'T_cutoff': T_cutoff}
+         print(f"INFO: Counting total holders for {token_address} at cutoff.")
+         try:
+             rows = self.db_client.execute(query, params)
+             if not rows:
+                 return 0
+             return int(rows[0][0])
+         except Exception as e:
+             print(f"ERROR: Failed to count total holders for token {token_address}: {e}")
+             return 0
data/data_loader.py ADDED
@@ -0,0 +1,1657 @@
+ import torch
+ from collections import defaultdict
+ import datetime
+ import requests
+ from io import BytesIO
+ from torch.utils.data import Dataset, IterableDataset
+ from PIL import Image
+ from typing import List, Dict, Any, Optional, Union, Tuple
+ from pathlib import Path
+ import numpy as np
+ from bisect import bisect_left, bisect_right
+
+ # We need the vocabulary for IDs and the processor for the pooler
+ import models.vocabulary as vocab
+ from models.multi_modal_processor import MultiModalEncoder
+ from data.data_fetcher import DataFetcher  # NEW: Import the DataFetcher
+
+ # --- NEW: Hardcoded decimals for common quote tokens ---
+ QUOTE_TOKEN_DECIMALS = {
+     'So11111111111111111111111111111111111111112': 9,   # SOL
+     'EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v': 6,  # USDC
+     'Es9vMFrzaCERmJfrF4H2FYD4KCoNkY11McCe8BenwNYB': 6,  # USDT
+ }
+
+ # --- NEW: Hyperparameters for trade event classification ---
+ LARGE_TRADE_USD_THRESHOLD = 100.0
+ LARGE_TRADE_SUPPLY_PCT_THRESHOLD = 0.03     # 3% of supply
+ LARGE_TRANSFER_SUPPLY_PCT_THRESHOLD = 0.03  # 3% of supply
+ SMART_WALLET_PNL_THRESHOLD = 3.0            # 300% PNL
+ SMART_WALLET_USD_THRESHOLD = 20000.0
+
+ # --- NEW: Hyperparameters for H/B/H Event Fetching ---
+ EVENT_COUNT_THRESHOLD_FOR_HBH = 30000  # If total events > this, use H/B/H
+ HBH_EARLY_EVENT_LIMIT = 10000
+ HBH_RECENT_EVENT_LIMIT = 15000
+
+ # --- NEW: OHLC Sequence Length Constant ---
+ OHLC_SEQ_LEN = 300  # 4 minutes of chart
+
+ MIN_AMOUNT_TRANSFER_SUPPLY = 0.0  # fraction of total supply; 0.0 disables the filter
+
+ # Interval for HolderSnapshot events (seconds)
+ HOLDER_SNAPSHOT_INTERVAL_SEC = 300
+ HOLDER_SNAPSHOT_TOP_K = 200
+
+
+ class EmbeddingPooler:
+     """
+     A helper class to manage the collection and encoding of unique text/image items
+     for a single data sample.
+     """
+     def __init__(self):
+         self.pool_map = {}
+         self.next_idx = 1  # 0 is padding
+
+     def get_idx(self, item: Any) -> int:
+         """
+         Returns a unique index for a given item (string or image).
+         - Returns 0 for None or empty strings.
+         - Deduplicates identical text and image objects.
+         """
+         if item is None:
+             return 0
+
+         # Handle text case
+         if isinstance(item, str):
+             if not item.strip():  # skip empty or whitespace-only strings
+                 return 0
+             key = item.strip()  # use normalized text key
+         elif isinstance(item, Image.Image):
+             key = id(item)  # unique memory address for images
+         else:
+             key = item  # fallback: use object itself if hashable
+
+         if key not in self.pool_map:
+             self.pool_map[key] = {'item': item, 'idx': self.next_idx}
+             self.next_idx += 1
+
+         return self.pool_map[key]['idx']
+
+     def get_all_items(self) -> List[Dict[str, Any]]:
+         """
+         Returns a list of all unique items, sorted by their assigned index.
+         """
+         if not self.pool_map:
+             return []
+         return sorted(self.pool_map.values(), key=lambda x: x['idx'])
+
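The pooler's indexing contract (index 0 reserved for padding, normalized duplicates sharing one slot) is easiest to see in a small usage sketch. `TinyPooler` below is a condensed stand-in for `EmbeddingPooler` restricted to strings, so the example stays self-contained:

```python
class TinyPooler:
    """Condensed stand-in for EmbeddingPooler (same string-indexing contract)."""
    def __init__(self):
        self.pool_map = {}
        self.next_idx = 1  # 0 is reserved for padding / missing items

    def get_idx(self, item):
        if item is None or (isinstance(item, str) and not item.strip()):
            return 0  # None, empty, and whitespace-only all map to padding
        key = item.strip() if isinstance(item, str) else item
        if key not in self.pool_map:
            self.pool_map[key] = self.next_idx  # first sighting gets a new slot
            self.next_idx += 1
        return self.pool_map[key]

p = TinyPooler()
a = p.get_idx("pepe")      # first unique string -> 1
b = p.get_idx("  pepe  ")  # whitespace-normalized duplicate -> same slot
c = p.get_idx("doge")      # second unique string -> 2
d = p.get_idx("")          # empty string -> padding index 0
```

Because indices start at 1, the downstream embedding table can keep row 0 as an all-zeros padding vector without any special-casing.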
+ class OracleDataset(Dataset):
+     """
+     Dataset class for the Oracle model. It fetches, processes, and structures
+     all on-chain and off-chain data for a given token to create a comprehensive
+     input sequence for the model.
+     """
+     def __init__(self,
+                  data_fetcher: DataFetcher,  # NEW: Pass the fetcher instance
+                  horizons_seconds: List[int] = [],
+                  quantiles: List[float] = [],
+                  max_samples: Optional[int] = None,
+                  ohlc_stats_path: Union[str, Path] = "./data/ohlc_stats.npz",  # NEW: Add stats path parameter
+                  token_allowlist: Optional[List[str]] = None,
+                  t_cutoff_seconds: int = 60,
+                  cache_dir: Optional[Union[str, Path]] = None,
+                  start_date: Optional[datetime.datetime] = None,
+                  min_trade_usd: float = 0.0):
+
+         # --- NEW: Create a persistent requests session for efficiency ---
+         self.http_session = requests.Session()
+
+         self.fetcher = data_fetcher
+         self.cache_dir = Path(cache_dir) if cache_dir else None
+
+         # If a fetcher is provided, we can determine the number of samples.
+         # Otherwise, we are likely in a test mode where __len__ might not be called
+         # or is used with a mock length.
+         self.t_cutoff_seconds = max(0, int(t_cutoff_seconds or 0))
+         self.token_allowlist = set(token_allowlist) if token_allowlist else None
+
+         if self.cache_dir and self.cache_dir.is_dir():
+             print(f"INFO: Initializing dataset in offline (cached) mode from: {self.cache_dir}")
+             # Scan for cached files to determine length
+             self.cached_files = sorted(self.cache_dir.glob("sample_*.pt"), key=lambda p: int(p.stem.split('_')[1]))
+             if not self.cached_files:
+                 raise RuntimeError(f"Cache directory '{self.cache_dir}' provided but contains no 'sample_*.pt' files.")
+
+             self.num_samples = len(self.cached_files)
+             if max_samples is not None:
+                 self.num_samples = min(max_samples, self.num_samples)
+                 self.cached_files = self.cached_files[:self.num_samples]
+             print(f"INFO: Found {self.num_samples} cached samples to use.")
+             self.sampled_mints = []  # Not needed in cached mode
+             self.available_mints = []
+
+         elif self.fetcher:
+             print("INFO: Initializing dataset in online (generation) mode...")
+             self.available_mints = self.fetcher.get_all_mints(start_date=start_date)
+             if not self.available_mints:
+                 raise RuntimeError("Dataset initialization failed: no mint records returned from data fetcher.")
+             if self.token_allowlist:
+                 filtered_mints = [
+                     mint for mint in self.available_mints
+                     if mint.get('mint_address') in self.token_allowlist
+                 ]
+                 if not filtered_mints:
+                     raise RuntimeError(f"No mint records matched the provided token allowlist: {token_allowlist}")
+                 self.available_mints = filtered_mints
+
+             total_mints = len(self.available_mints)
+             if max_samples is None:
+                 self.num_samples = total_mints
+                 self.sampled_mints = self.available_mints
+             else:
+                 self.num_samples = min(max_samples, total_mints)
+                 if self.num_samples < total_mints:
+                     print(f"INFO: Limiting dataset to first {self.num_samples} of {total_mints} available mints.")
+                 self.sampled_mints = self.available_mints[:self.num_samples]
+         else:
+             self.available_mints = []
+             self.sampled_mints = []
+             self.num_samples = 1 if max_samples is None else max_samples
+
+         self.horizons_seconds = sorted(set(horizons_seconds))
+         self.quantiles = quantiles
+         self.num_outputs = len(self.horizons_seconds) * len(self.quantiles)
+
+         # --- NEW: Load global OHLC normalization stats ---
+         stats_path = Path(ohlc_stats_path)
+         if not stats_path.exists():
+             raise FileNotFoundError(f"Required OHLC stats file not found: {stats_path}")
+         stats = np.load(stats_path)
+         self.ohlc_price_mean = float(stats.get('mean_price_usd', 0.0))
+         self.ohlc_price_std = float(stats.get('std_price_usd', 1.0)) or 1.0
+
+         self.min_trade_usd = min_trade_usd
+
+     def __len__(self) -> int:
+         return self.num_samples
+
+     def _normalize_price_series(self, values: List[float]) -> List[float]:
+         if not values:
+             return values
+         denom = self.ohlc_price_std if abs(self.ohlc_price_std) > 1e-9 else 1.0
+         return [(float(v) - self.ohlc_price_mean) / denom for v in values]
+
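`_normalize_price_series` is a plain global z-score with a guard against a degenerate standard deviation. A standalone sketch with assumed stats (the real mean/std come from `ohlc_stats.npz`):

```python
# Global z-score normalization with a divide-by-zero guard, mirroring
# _normalize_price_series (mean and std here are made-up values).
def normalize(values, mean, std):
    denom = std if abs(std) > 1e-9 else 1.0  # guard: never divide by ~0
    return [(float(v) - mean) / denom for v in values]

out = normalize([1.0, 2.0, 3.0], mean=2.0, std=1.0)  # -> [-1.0, 0.0, 1.0]
zero_std = normalize([5.0], mean=5.0, std=0.0)       # guard kicks in, divides by 1.0
```

Using one global mean/std (rather than per-series stats) keeps the price scale comparable across tokens, at the cost of leaving per-token level shifts in the input.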
+     def _compute_future_return_labels(self,
+                                       anchor_price: Optional[float],
+                                       anchor_timestamp: int,
+                                       price_series: List[Tuple[int, float]]) -> Tuple[torch.Tensor, torch.Tensor, List[Dict[str, Any]]]:
+         if self.num_outputs == 0:
+             return torch.zeros(0), torch.zeros(0), []
+
+         if anchor_price is None or abs(anchor_price) < 1e-9 or not price_series:
+             return torch.zeros(self.num_outputs), torch.zeros(self.num_outputs), []
+
+         ts_list = [int(entry[0]) for entry in price_series]
+         price_list = [float(entry[1]) for entry in price_series]
+         if not ts_list:
+             return torch.zeros(self.num_outputs), torch.zeros(self.num_outputs), []
+
+         last_ts = ts_list[-1]
+
+         label_values: List[float] = []
+         mask_values: List[float] = []
+         debug_entries: List[Dict[str, Any]] = []
+
+         for horizon in self.horizons_seconds:
+             target_ts = anchor_timestamp + horizon
+             if target_ts > last_ts:
+                 horizon_mask = 0.0
+                 horizon_return = 0.0
+                 future_price = None
+             else:
+                 idx = bisect_right(ts_list, target_ts) - 1
+                 if idx < 0:
+                     horizon_mask = 0.0
+                     horizon_return = 0.0
+                     future_price = None
+                 else:
+                     future_price = price_list[idx]
+                     horizon_return = (future_price - anchor_price) / anchor_price
+                     horizon_return = max(min(horizon_return, 10.0), -10.0)
+                     horizon_mask = 1.0
+
+             for _ in self.quantiles:
+                 label_values.append(horizon_return)
+                 mask_values.append(horizon_mask)
+                 debug_entries.append({
+                     'horizon': horizon,
+                     'target_ts': target_ts,
+                     'future_price': future_price,
+                     'return': horizon_return,
+                     'mask': horizon_mask
+                 })
+
+         return (torch.tensor(label_values, dtype=torch.float32),
+                 torch.tensor(mask_values, dtype=torch.float32),
+                 debug_entries)
+
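The label construction above reduces to: for each horizon, take the last observed price at or before `anchor_timestamp + horizon` via `bisect_right`, compute the clipped return, and mask out horizons that extend past the end of the observed series. A torch-free sketch of that lookup (toy timestamps and prices):

```python
from bisect import bisect_right

def horizon_labels(anchor_price, anchor_ts, series, horizons, clip=10.0):
    """Return (labels, mask) per horizon; series is [(timestamp, price), ...]
    sorted ascending. Mirrors the lookup in _compute_future_return_labels."""
    ts = [t for t, _ in series]
    px = [p for _, p in series]
    labels, mask = [], []
    for h in horizons:
        target = anchor_ts + h
        if not ts or target > ts[-1]:
            labels.append(0.0)
            mask.append(0.0)  # horizon not observable yet -> masked out
            continue
        i = bisect_right(ts, target) - 1  # last sample at or before target
        r = (px[i] - anchor_price) / anchor_price
        labels.append(max(min(r, clip), -clip))  # clip extreme returns
        mask.append(1.0)
    return labels, mask

series = [(10, 1.0), (20, 1.5), (30, 3.0)]
labels, mask = horizon_labels(1.0, 10, series, horizons=[10, 15, 100])
# horizons 10 and 15 both resolve to the price at ts=20; horizon 100 is masked
```

The mask, not a sentinel label value, is what tells the loss which horizon slots are valid, so the zero placeholders never contribute gradient.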
+     def _generate_onchain_snapshots(
+         self,
+         token_address: str,
+         t0_timestamp: int,
+         T_cutoff: datetime.datetime,
+         interval_sec: int,
+         trade_events: List[Dict[str, Any]],
+         transfer_events: List[Dict[str, Any]],
+         aggregation_trades: List[Dict[str, Any]],
+         wallet_data: Dict[str, Any],
+         total_supply_dec: float,
+         _register_event_fn
+     ) -> None:
+         # Prepare helper sets and maps (static sniper set based on earliest buyers)
+         all_buy_trades = sorted([e for e in trade_events if e.get('trade_direction') == 0 and e.get('success', False)], key=lambda x: x['timestamp'])
+         sniper_wallets = []
+         seen_buyers = set()
+         for e in all_buy_trades:
+             wa = e['wallet_address']
+             if wa not in seen_buyers:
+                 sniper_wallets.append(wa)
+                 seen_buyers.add(wa)
+             if len(sniper_wallets) >= 70:
+                 break
+         sniper_set = set(sniper_wallets)
+
+         KOL_NAME_KEYS = ['kolscan_name', 'cabalspy_name', 'axiom_kol_name']
+
+         # Build time arrays for price lookup
+         agg_ts = [int(t['timestamp']) for t in aggregation_trades] if aggregation_trades else []
+         agg_price = [float(t.get('price_usd', 0.0) or 0.0) for t in aggregation_trades] if aggregation_trades else []
+
+         start_ts = t0_timestamp
+         end_ts = int(self._timestamp_to_order_value(T_cutoff)) if hasattr(self, '_timestamp_to_order_value') else int(T_cutoff.timestamp())
+         if end_ts - start_ts < interval_sec:
+             oc_snapshot_times = [end_ts]
+         else:
+             steps = (end_ts - start_ts) // interval_sec
+             oc_snapshot_times = [start_ts + i * interval_sec for i in range(1, steps + 1)]
+
+         buyers_seen_global = set()
+         prev_holders_count = 0
+         for ts_value in oc_snapshot_times:
+             window_start = ts_value - interval_sec
+             trades_win = [e for e in trade_events if e.get('success', False) and window_start < e['timestamp'] <= ts_value]
+             xfers_win = [e for e in transfer_events if window_start < e['timestamp'] <= ts_value]
+
+             # Per-snapshot holder distribution at ts_value
+             cutoff_dt_ts = datetime.datetime.fromtimestamp(ts_value, tz=datetime.timezone.utc)
+             holder_records_ts = self.fetcher.fetch_token_holders_for_snapshot(token_address, cutoff_dt_ts, limit=HOLDER_SNAPSHOT_TOP_K)
+             holder_entries_ts = []
+             for rec in holder_records_ts:
+                 addr = rec.get('wallet_address')
+                 try:
+                     bal = float(rec.get('current_balance', 0.0) or 0.0)
+                 except (TypeError, ValueError):
+                     bal = 0.0
+                 pct = (bal / total_supply_dec) if total_supply_dec and total_supply_dec > 0 else 0.0
+                 if addr and pct > 0.0:
+                     holder_entries_ts.append({'wallet': addr, 'holding_pct': pct})
+             holder_entries_ts.sort(key=lambda d: d['holding_pct'], reverse=True)
+
+             # Emit HolderSnapshot for this ts_value
+             hs_event = {
+                 'event_type': 'HolderSnapshot',
+                 'timestamp': int(ts_value),
+                 'relative_ts': ts_value - t0_timestamp,
+                 'holders': holder_entries_ts
+             }
+             _register_event_fn(hs_event, self._event_execution_sort_key(ts_value, signature='HolderSnapshot') if hasattr(self, '_event_execution_sort_key') else (ts_value, 0, 0, 0, 'HolderSnapshot'))
310
+
311
+ holder_pct_map_ts = {d['wallet']: d['holding_pct'] for d in holder_entries_ts}
312
+ top10_holder_pct = sum(d['holding_pct'] for d in holder_entries_ts[:10]) if holder_entries_ts else 0.0
313
+
314
+ # Cumulative sets up to ts_value
315
+ rat_set_ts = set(ev['destination_wallet_address'] for ev in transfer_events if ev['timestamp'] <= ts_value)
316
+ bundle_buyer_set_ts = set(e['wallet_address'] for e in trade_events if e.get('is_bundle') and e.get('trade_direction') == 0 and e.get('success', False) and e['timestamp'] <= ts_value)
317
+
318
+ buy_count = sum(1 for e in trades_win if e.get('trade_direction') == 0)
319
+ sell_count = sum(1 for e in trades_win if e.get('trade_direction') == 1)
320
+ volume = sum(float(e.get('total_usd', 0.0) or 0.0) for e in trades_win)
321
+ total_txns = len(trades_win) + len(xfers_win)
322
+ global_fees_paid = sum(float(e.get('priority_fee', 0.0) or 0.0) for e in trades_win) + \
323
+ sum(float(e.get('priority_fee', 0.0) or 0.0) for e in xfers_win)
324
+
325
+ smart_trader_addrs = set(e['wallet_address'] for e in trades_win if e.get('event_type') == 'SmartWallet_Trade')
326
+ smart_traders = len(smart_trader_addrs)
327
+
328
+ kol_addrs = set()
329
+ for e in trades_win:
330
+ wa = e['wallet_address']
331
+ soc = wallet_data.get(wa, {}).get('socials', {})
332
+ if any(soc.get(k) for k in KOL_NAME_KEYS if soc):
333
+ kol_addrs.add(wa)
334
+ kols = len(kol_addrs)
335
+
336
+ new_buyers = [e['wallet_address'] for e in trades_win if e.get('trade_direction') == 0 and e['wallet_address'] not in buyers_seen_global]
337
+ for wa in new_buyers:
338
+ buyers_seen_global.add(wa)
339
+
340
+ # Compute growth against previous snapshot endpoint.
341
+ end_dt = datetime.datetime.fromtimestamp(ts_value, tz=datetime.timezone.utc)
342
+ holders_end = self.fetcher.fetch_total_holders_count_for_token(token_address, end_dt)
343
+ total_holders = float(holders_end)
344
+ delta_holders = holders_end - prev_holders_count
345
+ holder_growth_rate = float(delta_holders)
346
+ prev_holders_count = holders_end
347
+
348
+ # Market cap from last price at or before ts
349
+ last_price_usd = 0.0
350
+ if agg_ts:
351
+ for i in range(len(agg_ts) - 1, -1, -1):
352
+ if agg_ts[i] <= ts_value:
353
+ last_price_usd = agg_price[i]
354
+ break
355
+ current_market_cap = float(last_price_usd) * float(total_supply_dec)
356
+
357
+ oc_event = {
358
+ 'event_type': 'OnChain_Snapshot',
359
+ 'timestamp': int(ts_value),
360
+ 'relative_ts': ts_value - t0_timestamp,
361
+ 'total_holders': total_holders,
362
+ 'smart_traders': float(smart_traders),
363
+ 'kols': float(kols),
364
+ 'holder_growth_rate': float(holder_growth_rate),
365
+ 'top_10_holder_pct': float(top10_holder_pct),
366
+ 'sniper_holding_pct': float(sum(holder_pct_map_ts.get(wa, 0.0) for wa in sniper_set)),
367
+ 'rat_wallets_holding_pct': float(sum(holder_pct_map_ts.get(wa, 0.0) for wa in rat_set_ts)),
368
+ 'bundle_holding_pct': float(sum(holder_pct_map_ts.get(wa, 0.0) for wa in bundle_buyer_set_ts)),
369
+ 'current_market_cap': float(current_market_cap),
370
+ 'volume': float(volume),
371
+ 'buy_count': float(buy_count),
372
+ 'sell_count': float(sell_count),
373
+ 'total_txns': float(total_txns),
374
+ 'global_fees_paid': float(global_fees_paid)
375
+ }
376
+ _register_event_fn(oc_event, self._event_execution_sort_key(ts_value, signature='OnChain_Snapshot') if hasattr(self, '_event_execution_sort_key') else (ts_value, 0, 0, 0, 'OnChain_Snapshot'))
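The snapshot loop walks a fixed time grid derived from `t0_timestamp`, `T_cutoff`, and `interval_sec`, collapsing to a single snapshot at the cutoff when the window is shorter than one interval. A small standalone sketch of that grid computation (the helper name is ours):

```python
def snapshot_times(start_ts, end_ts, interval_sec):
    """Snapshot timestamps on a fixed grid after start_ts, clamped to end_ts."""
    if end_ts - start_ts < interval_sec:
        return [end_ts]  # window shorter than one interval: one snapshot at the cutoff
    steps = (end_ts - start_ts) // interval_sec
    return [start_ts + i * interval_sec for i in range(1, steps + 1)]

print(snapshot_times(0, 100, 30))  # [30, 60, 90]
print(snapshot_times(0, 20, 30))   # [20]
```

Note that integer floor division drops any partial trailing interval, so the grid never overshoots `end_ts`.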
+
+    def _calculate_deployed_token_stats(self, profiles: Dict[str, Dict[str, Any]], T_cutoff: datetime.datetime):
+        """
+        Calculates aggregate statistics for wallets based on the tokens they've deployed.
+        This method modifies the `profiles` dictionary in-place.
+        """
+        if not profiles:
+            return
+
+        for addr, profile in profiles.items():
+            deployed_tokens = profile.get('deployed_tokens', [])
+
+            # 1. Deployed tokens count
+            count = len(deployed_tokens)
+            profile['deployed_tokens_count'] = float(count)
+
+            if count == 0:
+                profile['deployed_tokens_migrated_pct'] = 0.0
+                profile['deployed_tokens_avg_lifetime_sec'] = 0.0
+                profile['deployed_tokens_avg_peak_mc_usd'] = 0.0
+                profile['deployed_tokens_median_peak_mc_usd'] = 0.0
+                continue
+
+            # --- NEW: Fetch deployed token details with point-in-time logic ---
+            deployed_token_details = self.fetcher.fetch_deployed_token_details(deployed_tokens, T_cutoff)
+
+            # Collect stats for all deployed tokens of this wallet
+            lifetimes = []
+            peak_mcs = []
+            migrated_count = 0
+            for token_addr in deployed_tokens:
+                details = deployed_token_details.get(token_addr)
+                if not details:
+                    continue
+
+                if details.get('has_migrated'):
+                    migrated_count += 1
+
+                lifetimes.append((details['updated_at'] - details['created_at']).total_seconds())
+                # Simplified market cap: ATH price * decimal-adjusted supply
+                peak_mcs.append(details.get('ath_price_usd', 0.0) * details.get('total_supply', 0.0) / (10**details.get('decimals', 9)))
+
+            # 2. Migrated pct (count is guaranteed > 0 here)
+            profile['deployed_tokens_migrated_pct'] = migrated_count / count
+            # 3. Avg lifetime
+            profile['deployed_tokens_avg_lifetime_sec'] = torch.mean(torch.tensor(lifetimes)).item() if lifetimes else 0.0
+            # 4. Avg & median peak MC
+            profile['deployed_tokens_avg_peak_mc_usd'] = torch.mean(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0
+            profile['deployed_tokens_median_peak_mc_usd'] = torch.median(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0
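Keeping both the average and the median peak market cap is useful because a single outlier "runner" dominates the mean while leaving the median intact. A quick illustration with hypothetical values; note that `torch.median` returns the lower of the two middle values for even-length input, while `statistics.median` averages them, so the two can differ on even-sized samples:

```python
import statistics

# Hypothetical peak market caps for three deployed tokens; one runner skews the mean.
peak_mcs = [1_000.0, 5_000.0, 250_000.0]
avg_peak = statistics.mean(peak_mcs)
median_peak = statistics.median(peak_mcs)
print(avg_peak)     # 85333.33...
print(median_peak)  # 5000.0
```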
+
+    def _process_wallet_data(self, wallet_addresses: List[str], token_data: Dict[str, Any], pooler: EmbeddingPooler, T_cutoff: datetime.datetime) -> tuple[Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]]:
+        """
+        Fetches and processes profile, social, and holdings data for a list of wallets.
+        Uses a T_cutoff to ensure data is point-in-time accurate.
+        """
+        if not wallet_addresses:
+            return {}, token_data
+
+        print(f"INFO: Processing wallet data for {len(wallet_addresses)} unique wallets...")
+        # Bulk fetch all data
+        profiles, socials = self.fetcher.fetch_wallet_profiles_and_socials(wallet_addresses, T_cutoff)
+        holdings = self.fetcher.fetch_wallet_holdings(wallet_addresses, T_cutoff)
+
+        valid_wallets = [addr for addr in wallet_addresses if addr in profiles]
+        dropped_wallets = set(wallet_addresses) - set(valid_wallets)
+        if dropped_wallets:
+            print(f"INFO: Skipping {len(dropped_wallets)} wallets with no profile before cutoff.")
+        if not valid_wallets:
+            print("INFO: All wallets were graph-only or appeared after cutoff; skipping wallet processing for this token.")
+            return {}, token_data
+        wallet_addresses = valid_wallets
+
+        # --- NEW: Collect all unique mints from holdings to fetch their data ---
+        all_holding_mints = set()
+        for wallet_addr in wallet_addresses:
+            for holding_item in holdings.get(wallet_addr, []):
+                if 'mint_address' in holding_item:
+                    all_holding_mints.add(holding_item['mint_address'])
+
+        # --- NEW: Process all discovered tokens with point-in-time logic ---
+        # 1. Fetch raw data for all newly found tokens from holdings.
+        # 2. Process this raw data to get embedding indices and add to the pooler.
+        # Note: _process_token_data is designed to take a list and return a dict.
+        # We pass the addresses and let it handle the fetching and processing internally.
+        processed_new_tokens = self._process_token_data(list(all_holding_mints), pooler, T_cutoff)
+        # 3. Merge the fully processed new tokens with the existing main token data.
+        all_token_data = {**token_data, **(processed_new_tokens or {})}
+
+        # --- NEW: Calculate deployed token stats using point-in-time logic ---
+        self._calculate_deployed_token_stats(profiles, T_cutoff)
+
+        # --- Assemble the final wallet dictionary ---
+        # This structure is exactly what the WalletEncoder expects.
+        final_wallets = {}
+        for addr in wallet_addresses:
+
+            # --- Define all expected numerical keys for a profile ---
+            # This prevents KeyErrors if the DB returns a partial profile.
+            expected_profile_keys = [
+                'age', 'deployed_tokens_count', 'deployed_tokens_migrated_pct',
+                'deployed_tokens_avg_lifetime_sec', 'deployed_tokens_avg_peak_mc_usd',
+                'deployed_tokens_median_peak_mc_usd', 'balance', 'transfers_in_count',
+                'transfers_out_count', 'spl_transfers_in_count', 'spl_transfers_out_count',
+                'total_buys_count', 'total_sells_count', 'total_winrate',
+                'stats_1d_realized_profit_sol', 'stats_1d_realized_profit_pnl', 'stats_1d_buy_count',
+                'stats_1d_sell_count', 'stats_1d_transfer_in_count', 'stats_1d_transfer_out_count',
+                'stats_1d_avg_holding_period', 'stats_1d_total_bought_cost_sol', 'stats_1d_total_sold_income_sol',
+                'stats_1d_total_fee', 'stats_1d_winrate', 'stats_1d_tokens_traded',
+                'stats_7d_realized_profit_sol', 'stats_7d_realized_profit_pnl', 'stats_7d_buy_count',
+                'stats_7d_sell_count', 'stats_7d_transfer_in_count', 'stats_7d_transfer_out_count',
+                'stats_7d_avg_holding_period', 'stats_7d_total_bought_cost_sol', 'stats_7d_total_sold_income_sol',
+                'stats_7d_total_fee', 'stats_7d_winrate', 'stats_7d_tokens_traded'
+            ]
+            # --- NEW: If a wallet profile doesn't exist in the DB, skip it entirely. ---
+            # This removes the old logic that created a placeholder profile with zeroed-out features.
+            # "If it doesn't exist, it doesn't exist."
+            profile_data = profiles.get(addr)
+            if not profile_data:
+                print(f"INFO: Wallet {addr} found in graph but has no profile in DB. Skipping this wallet.")
+                continue
+
+            # --- NEW: Ensure all expected keys exist in the fetched profile ---
+            for key in expected_profile_keys:
+                profile_data.setdefault(key, 0.0)  # Use 0.0 as a safe default for any missing numerical key
+
+            social_data = socials.get(addr, {})
+
+            # --- NEW: Derive boolean social flags based on schema ---
+            social_data['has_pf_profile'] = bool(social_data.get('pumpfun_username'))
+            social_data['has_twitter'] = bool(social_data.get('twitter_username'))
+            social_data['has_telegram'] = bool(social_data.get('telegram_channel'))
+            # 'is_exchange_wallet' is not in the schema, so for now we derive it from the
+            # profile's 'tags' column; a dedicated tagging service would be a better source.
+            social_data['is_exchange_wallet'] = 'exchange_wallet' in profile_data.get('tags', [])
+
+            # --- NEW: Calculate 'age' at the cutoff so the feature stays point-in-time accurate ---
+            funded_ts = profile_data.get('funded_timestamp', 0)
+            if funded_ts and funded_ts > 0:
+                # Age in seconds from the funding timestamp up to T_cutoff (not wall-clock
+                # time, which would leak information from after the cutoff).
+                age_seconds = int(T_cutoff.timestamp()) - funded_ts
+            else:
+                # Fallback for wallets older than our DB window:
+                # 5 months * 30 days/month * 24 hours/day * 3600 seconds/hour
+                age_seconds = 12_960_000
+
+            # Add the calculated age to the profile data that the WalletEncoder will receive
+            profile_data['age'] = float(age_seconds)
+
+            # Get the username and add it to the embedding pooler
+            username = social_data.get('pumpfun_username') or social_data.get('twitter_username') or social_data.get('kolscan_name')
+
+            if isinstance(username, str) and username.strip():
+                social_data['username_emb_idx'] = pooler.get_idx(username.strip())
+            else:
+                social_data['username_emb_idx'] = 0  # means "no embedding"
+
+            # --- NEW: Filter holdings and calculate derived features ---
+            # We create a new list `valid_wallet_holdings` to ensure that if a holding's
+            # token is invalid (filtered out by _process_token_data), the entire holding
+            # row is removed and not passed to the WalletEncoder.
+            original_holdings = holdings.get(addr, [])
+            valid_wallet_holdings = []
+            for holding_item in original_holdings:
+                # 1. Calculate holding_time
+                start_ts = holding_item.get('start_holding_at')
+                mint_addr = holding_item.get('mint_address')
+                token_info = all_token_data.get(mint_addr)
+
+                if not token_info:
+                    print(f"INFO: Skipping holding for token {mint_addr} in wallet {addr} because token data is invalid/missing.")
+                    continue
+
+                end_ts = holding_item.get('end_holding_at')
+                if not start_ts:
+                    holding_item['holding_time'] = 0.0
+                else:
+                    # Open holdings are measured up to T_cutoff rather than wall-clock
+                    # time, to keep the feature point-in-time accurate.
+                    end_ts = end_ts or T_cutoff
+                    holding_item['holding_time'] = (end_ts - start_ts).total_seconds()
+
+                # 2. Calculate balance_pct_to_supply
+                if token_info.get('total_supply', 0) > 0:
+                    total_supply = token_info['total_supply'] / (10**token_info.get('decimals', 9))
+                    current_balance = holding_item.get('current_balance', 0.0)
+                    holding_item['balance_pct_to_supply'] = (current_balance / total_supply) if total_supply > 0 else 0.0
+                else:
+                    holding_item['balance_pct_to_supply'] = 0.0
+
+                # 3. --- NEW: Calculate bought_amount_sol_pct_to_native_balance ---
+                # This uses the historically accurate native balance from the profile.
+                wallet_native_balance = profile_data.get('balance', 0.0)
+                bought_cost_sol = holding_item.get('history_bought_cost_sol', 0.0)
+                if wallet_native_balance > 1e-9:  # Use a small epsilon to avoid division by zero
+                    holding_item['bought_amount_sol_pct_to_native_balance'] = bought_cost_sol / wallet_native_balance
+                else:
+                    holding_item['bought_amount_sol_pct_to_native_balance'] = 0.0
+
+                valid_wallet_holdings.append(holding_item)
+
+            final_wallets[addr] = {
+                'profile': profile_data,
+                'socials': social_data,
+                'holdings': valid_wallet_holdings
+            }
+
+        return final_wallets, all_token_data
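The `balance_pct_to_supply` feature above divides a holding's balance by the decimal-adjusted total supply. A standalone sketch of that computation (the function name is ours):

```python
def balance_pct_to_supply(current_balance, raw_total_supply, decimals=9):
    """Fraction of the decimal-adjusted supply held in one wallet."""
    total_supply = raw_total_supply / (10 ** decimals) if decimals > 0 else raw_total_supply
    return current_balance / total_supply if total_supply > 0 else 0.0

# 5M tokens out of a 1B-token supply stored on-chain with 9 decimals
print(balance_pct_to_supply(5_000_000.0, 1_000_000_000 * 10**9))  # 0.005
```

The guard on `total_supply > 0` mirrors the code's behavior of defaulting the feature to `0.0` when supply information is missing.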
+
+    def _process_token_data(self, token_addresses: List[str], pooler: EmbeddingPooler, T_cutoff: datetime.datetime, token_data: Optional[Dict] = None) -> Optional[Dict[str, Dict[str, Any]]]:
+        """
+        Fetches and processes static data for a list of tokens.
+        Returns None when a single main token fails validation.
+        """
+        if not token_addresses:
+            return {}
+
+        if token_data is None:
+            print(f"INFO: Processing token data for {len(token_addresses)} unique tokens...")
+            token_data = self.fetcher.fetch_token_data(token_addresses, T_cutoff)
+
+        # --- NEW: Print the raw fetched token data as requested ---
+        print("\n--- RAW TOKEN DATA FROM DATABASE ---")
+        print(token_data)
+
+        # Add pre-computed embedding indices to the token data.
+        # --- CRITICAL FIX: This function now returns None if the main token is invalid ---
+        valid_token_data = {}
+        for addr, data in token_data.items():
+            # --- FIXED: Only add to pooler if data is valid ---
+            image = None
+            token_uri = data.get('token_uri')
+
+            # --- NEW: Use multiple IPFS gateways for reliability ---
+            if token_uri and isinstance(token_uri, str) and token_uri.strip():
+
+                ipfs_gateways = [
+                    "https://pump.mypinata.cloud/ipfs/",
+                    "https://dweb.link/ipfs/",
+                    "https://cloudflare-ipfs.com/ipfs/",
+                ]
+
+                try:
+                    # 1. Fetch metadata JSON from token_uri
+                    if 'ipfs/' in token_uri:
+                        metadata_hash = token_uri.split('ipfs/')[-1]
+                        # Try fetching from multiple gateways
+                        for gateway in ipfs_gateways:
+                            try:
+                                metadata_resp = self.http_session.get(f"{gateway}{metadata_hash}", timeout=5)
+                                metadata_resp.raise_for_status()
+                                metadata = metadata_resp.json()
+                                break  # Success, exit loop
+                            except requests.RequestException:
+                                continue  # Try next gateway
+                        else:  # If all gateways fail
+                            raise requests.RequestException("All IPFS gateways failed for metadata.")
+                    else:  # Handle regular HTTP URIs
+                        metadata_resp = self.http_session.get(token_uri, timeout=5)
+                        metadata_resp.raise_for_status()
+                        metadata = metadata_resp.json()
+
+                    # 2. Fetch the token image referenced by the metadata
+                    image_url = metadata.get('image', '')
+
+                    # --- FIXED: Apply the same multi-gateway logic to image fetching ---
+                    if image_url:
+                        # Handle IPFS URIs for the image
+                        if 'ipfs/' in image_url:
+                            image_hash = image_url.split('ipfs/')[-1]
+                            # Try fetching image from multiple gateways
+                            for gateway in ipfs_gateways:
+                                try:
+                                    image_resp = self.http_session.get(f"{gateway}{image_hash}", timeout=10)
+                                    image_resp.raise_for_status()
+                                    image = Image.open(BytesIO(image_resp.content))
+                                    break  # Success, exit loop
+                                except requests.RequestException:
+                                    continue  # Try next gateway
+                            else:  # If all gateways fail for the image
+                                raise requests.RequestException("All IPFS gateways failed for image.")
+                        else:  # Handle regular HTTP image URLs
+                            image_resp = self.http_session.get(image_url, timeout=10)
+                            image_resp.raise_for_status()
+                            image = Image.open(BytesIO(image_resp.content))
+                except (requests.RequestException, ValueError, IOError) as e:
+                    print(f"WARN: Could not fetch or process image for token {addr} from URI {token_uri}. Reason: {e}")
+                    image = None  # Ensure image is None on failure
+
+            # --- FIXED: Check for valid metadata before adding to pooler ---
+            token_name = data.get('name') if data.get('name') and data.get('name').strip() else None
+            token_symbol = data.get('symbol') if data.get('symbol') and data.get('symbol').strip() else None
+
+            # --- An image is strictly required for a valid token ---
+            # --- FIXED: Correctly handle invalid secondary tokens without aborting the whole process ---
+            if not token_name or not token_symbol or not image:
+                if not token_name:
+                    reason = "name"
+                elif not token_symbol:
+                    reason = "symbol"
+                else:
+                    reason = "image (fetch failed)"
+
+                print(f"WARN: Token {addr} is missing essential metadata ('{reason}'). This token will be skipped.")
+
+                # If this function was called with only one token, it's the main token.
+                # If the main token is invalid, the whole sample is invalid, so return None.
+                if len(token_addresses) == 1:
+                    return None
+                # Otherwise, it's a secondary token. Skip it and continue with the others.
+                continue
+
+            # --- NEW: Add is_vanity feature based on the token address ---
+            data['is_vanity'] = addr.lower().endswith("pump")
+
+            data['image_emb_idx'] = pooler.get_idx(image)
+            data['name_emb_idx'] = pooler.get_idx(token_name)
+            data['symbol_emb_idx'] = pooler.get_idx(token_symbol)
+
+            # --- FIX: Validate the protocol ID ---
+            # The DB might return an ID that is out of bounds for our nn.Embedding layer.
+            # We must ensure the ID is valid or map it to a default 'Unknown' ID.
+            raw_protocol_id = data.get('protocol')
+            if raw_protocol_id is not None and 0 <= raw_protocol_id < vocab.NUM_PROTOCOLS:
+                data['protocol'] = raw_protocol_id
+            else:
+                data['protocol'] = vocab.PROTOCOL_TO_ID.get('Unknown', 0)
+
+            valid_token_data[addr] = data
+
+        return valid_token_data
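The gateway loops above rely on Python's `for`/`else`: the `else` branch runs only when the loop completes without hitting `break`, which here means every gateway failed. A self-contained sketch of the pattern with stand-in fetchers (names invented here; plain `IOError` stands in for `requests.RequestException`):

```python
def fetch_with_fallback(fetchers):
    """Try each fetcher in order; break on success, raise if all of them fail."""
    result = None
    for fetch in fetchers:
        try:
            result = fetch()
            break  # success, exit loop
        except IOError:
            continue  # try the next gateway
    else:  # loop finished without a break: every fetcher raised
        raise IOError("All gateways failed.")
    return result

def down():
    raise IOError("gateway down")

print(fetch_with_fallback([down, lambda: "metadata"]))  # metadata
```

The design choice here is that per-gateway failures are swallowed, but total failure is re-raised so the outer `try`/`except` can log one warning and fall back to `image = None`.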
+
+    def _generate_ohlc(self, aggregation_trades: List[Dict[str, Any]], T_cutoff: datetime.datetime, interval_seconds: int) -> List[tuple]:
+        """
+        Generates an OHLC series from a list of aggregated trades with a dynamic interval.
+        It forward-fills gaps and extends the series up to T_cutoff.
+        Returns a list of (timestamp, open, close) tuples.
+        """
+        if not aggregation_trades:
+            return []
+
+        trades_by_interval = defaultdict(list)
+        for trade in aggregation_trades:
+            # Group trades into interval buckets
+            interval_start_ts = (trade['timestamp'] // interval_seconds) * interval_seconds
+            trades_by_interval[interval_start_ts].append(trade['price_usd'])
+
+        sorted_intervals = sorted(trades_by_interval.keys())
+
+        if not sorted_intervals:
+            return []
+
+        full_ohlc = []
+        start_ts = sorted_intervals[0]
+        end_ts = int(T_cutoff.timestamp())
+        # Align end_ts to the interval grid
+        end_ts = (end_ts // interval_seconds) * interval_seconds
+        last_price = aggregation_trades[0]['price_usd']
+
+        # --- NEW: Debugging log for trades grouped by interval ---
+        print("\n[DEBUG] OHLC Generation: Trades grouped by interval bucket:")
+        print(dict(trades_by_interval))
+
+        for ts in range(start_ts, end_ts + 1, interval_seconds):
+            if ts in trades_by_interval:
+                prices = trades_by_interval[ts]
+                open_price = prices[0]
+                close_price = prices[-1]
+                full_ohlc.append((ts, open_price, close_price))
+                last_price = close_price
+            else:
+                full_ohlc.append((ts, last_price, last_price))
+        return full_ohlc
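The bucketing and forward-fill behavior of `_generate_ohlc` can be sketched in isolation (the helper name and sample trades are ours): trades are floored onto an interval grid, and empty buckets repeat the last observed close.

```python
from collections import defaultdict

def ohlc_forward_fill(trades, end_ts, interval):
    """Bucket (ts, price) trades into intervals; forward-fill gaps with the last close."""
    buckets = defaultdict(list)
    for ts, price in trades:
        buckets[(ts // interval) * interval].append(price)
    start = min(buckets)
    last = trades[0][1]
    out = []
    for ts in range(start, (end_ts // interval) * interval + 1, interval):
        prices = buckets.get(ts)
        if prices:
            out.append((ts, prices[0], prices[-1]))
            last = prices[-1]
        else:
            out.append((ts, last, last))  # gap: carry the last close forward
    return out

print(ohlc_forward_fill([(0, 1.0), (5, 1.5), (20, 2.0)], 30, 10))
# [(0, 1.0, 1.5), (10, 1.5, 1.5), (20, 2.0, 2.0), (30, 2.0, 2.0)]
```

Extending the grid to the cutoff (here `end_ts`) means the series keeps emitting flat candles after the last trade, which matches the method's behavior of padding quiet tokens up to `T_cutoff`.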
+
+    def __getitem__(self, idx: int) -> Optional[Dict[str, Any]]:
+        """
+        Loads a pre-processed data item from the cache, or generates it on-the-fly
+        if the dataset is in online mode.
+        """
+        if self.cache_dir:
+            if idx >= len(self.cached_files):
+                raise IndexError(f"Index {idx} out of range for {len(self.cached_files)} cached files.")
+            filepath = self.cached_files[idx]
+            try:
+                # Use map_location to avoid issues if cached on GPU and loading on CPU
+                return torch.load(filepath, map_location='cpu')
+            except Exception as e:
+                print(f"ERROR: Could not load or process cached item {filepath}: {e}")
+                return None  # DataLoader can be configured to skip None items
+
+        # Fall back to online generation if no cache_dir is set
+        return self.__cacheitem__(idx)
+
+    def __cacheitem__(self, idx: int) -> Optional[Dict[str, Any]]:
+        """
+        The main data loading method. For a given token, it fetches all
+        relevant on-chain and off-chain data, processes it, and returns
+        a structured dictionary for the collator.
+        """
+        if not self.sampled_mints:
+            raise RuntimeError("Dataset has no mint records loaded; ensure fetcher returned data during initialization.")
+        if idx >= len(self.sampled_mints):
+            raise IndexError(f"Requested sample index {idx} exceeds loaded mint count {len(self.sampled_mints)}.")
+        initial_mint_record = self.sampled_mints[idx]
+        t0 = initial_mint_record["timestamp"]
+        creator_address = initial_mint_record['creator_address']
+        token_address = initial_mint_record['mint_address']
+        print(f"\n--- Building dataset for token: {token_address} ---")
+
+        # The EmbeddingPooler is crucial for collecting unique text/images per sample
+        pooler = EmbeddingPooler()
+
+        def _safe_int(value: Any) -> int:
+            try:
+                return int(value)
+            except (TypeError, ValueError):
+                return 0
+
+        def _timestamp_to_order_value(ts_value: Any) -> float:
+            if isinstance(ts_value, datetime.datetime):
+                if ts_value.tzinfo is None:
+                    ts_value = ts_value.replace(tzinfo=datetime.timezone.utc)
+                return ts_value.timestamp()
+            try:
+                return float(ts_value)
+            except (TypeError, ValueError):
+                return 0.0
+
+        def _event_execution_sort_key(timestamp_value: Any,
+                                      slot: Any = 0,
+                                      transaction_index: Any = 0,
+                                      instruction_index: Any = 0,
+                                      signature: str = '') -> tuple:
+            return (
+                _timestamp_to_order_value(timestamp_value),
+                _safe_int(slot),
+                _safe_int(transaction_index),
+                _safe_int(instruction_index),
+                signature or ''
+            )
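The sort keys above exploit Python's lexicographic tuple comparison: the timestamp decides first, with slot, transaction index, instruction index, and signature as successive tie-breakers. A small illustration (the helper name is ours):

```python
def execution_sort_key(timestamp, slot=0, tx_index=0, ix_index=0, signature=''):
    """Deterministic ordering key: tuples compare element-by-element, left to right."""
    return (float(timestamp), int(slot), int(tx_index), int(ix_index), signature)

events = [
    execution_sort_key(100, slot=2, signature='b'),
    execution_sort_key(100, slot=1, signature='a'),
    execution_sort_key(99, slot=9, signature='c'),
]
print(sorted(events)[0])  # (99.0, 9, 0, 0, 'c'): the earliest timestamp sorts first
```

Including the signature as the final component makes the ordering fully deterministic even when two events share the same slot and indices, which keeps cached samples reproducible across runs.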
+
+        # 1. Fetch anchor Mint event to establish the timeline & initial entities
+        # --- SIMPLIFIED: Use the mint record we already have ---
+        mint_event = {
+            'event_type': 'Mint',
+            'timestamp': int(initial_mint_record['timestamp'].timestamp()),
+            'relative_ts': 0,
+            'wallet_address': initial_mint_record['creator_address'],
+            'token_address': token_address,
+            'protocol_id': initial_mint_record.get('protocol')
+        }
+
+        initial_entities = {mint_event['wallet_address']}
+        event_sequence_entries: List[Tuple[tuple, Dict[str, Any]]] = []
+
+        def _register_event(event: Dict[str, Any], sort_key: tuple):
+            event_sequence_entries.append((sort_key, event))
+
+        _register_event(mint_event, _event_execution_sort_key(mint_event['timestamp'], signature='Mint'))
+
+        # Determine the cutoff time for all historical data fetching.
+        # T_cutoff = datetime.datetime.fromtimestamp(event_sequence[-1]['timestamp'], tz=datetime.timezone.utc)
+        # --- MODIFIED: Set T_cutoff to mint timestamp + self.t_cutoff_seconds ---
+        T_cutoff = initial_mint_record['timestamp'] + datetime.timedelta(seconds=self.t_cutoff_seconds)
+        max_horizon_seconds = max(self.horizons_seconds) if self.horizons_seconds else 0
+        future_trades_for_labels: List[Dict[str, Any]] = []
+        if self.num_outputs > 0 and max_horizon_seconds > 0:
+            future_window_end = T_cutoff + datetime.timedelta(seconds=max_horizon_seconds)
+            future_trades_for_labels = self.fetcher.fetch_future_trades_for_token(
+                token_address, T_cutoff, future_window_end
+            )
+            if not future_trades_for_labels:
+                print(f"INFO: Skipping token {token_address} (no future trades beyond cutoff).")
+                return None
+
+        # --- NEW: Accumulate all wallets before hitting Neo4j to avoid duplicate queries ---
+        graph_seed_entities = set(initial_entities)
+        all_graph_entities: Dict[str, str] = {mint_event['wallet_address']: 'Wallet'}
+        all_graph_entity_addrs = set(all_graph_entities.keys())
+        graph_links: Dict[str, Any] = {}
+
+        # 3. Fetch trades and add traders to the entity set
+        # --- REFACTORED: Fetch trades using the new 3-part HBH system ---
+        early_trades, middle_trades, recent_trades = self.fetcher.fetch_trades_for_token(
+            token_address, T_cutoff, EVENT_COUNT_THRESHOLD_FOR_HBH, HBH_EARLY_EVENT_LIMIT, HBH_RECENT_EVENT_LIMIT
+        )
+
+        def _trade_execution_sort_key(trade: Dict[str, Any]) -> tuple:
+            return (
+                _timestamp_to_order_value(trade.get('timestamp')),
+                _safe_int(trade.get('slot')),
+                _safe_int(trade.get('transaction_index')),
+                _safe_int(trade.get('instruction_index')),
+                trade.get('signature', '')
+            )
+
+        early_trades = sorted(early_trades, key=_trade_execution_sort_key)
+        middle_trades = sorted(middle_trades, key=_trade_execution_sort_key)
+        recent_trades = sorted(recent_trades, key=_trade_execution_sort_key)
+
+        # --- NEW: Inject special context tokens to mark HBH boundaries ---
+        # 'Middle' marks the start of the blurry middle window
+        if middle_trades:
+            mid_ts_val = _timestamp_to_order_value(middle_trades[0].get('timestamp'))
+            middle_event = {
+                'event_type': 'Middle',
+                'timestamp': int(mid_ts_val),
+                'relative_ts': mid_ts_val - _timestamp_to_order_value(t0)
+            }
+            _register_event(middle_event, _event_execution_sort_key(mid_ts_val, signature='Middle'))
+
+        # 'RECENT' marks the start of the high-definition recent window
+        if recent_trades:
+            rec_ts_val = _timestamp_to_order_value(recent_trades[0].get('timestamp'))
+            recent_event = {
+                'event_type': 'RECENT',
+                'timestamp': int(rec_ts_val),
+                'relative_ts': rec_ts_val - _timestamp_to_order_value(t0)
+            }
+            _register_event(recent_event, _event_execution_sort_key(rec_ts_val, signature='RECENT'))
+
+        # For now, we only process the high-definition segments for event creation,
+        # deduplicated in case of overlap between the early/recent slices.
+        trade_records = []
+        seen_trade_keys = set()
+        for trade in early_trades + recent_trades:
+            dedupe_key = (
+                _safe_int(trade.get('slot')),
+                _safe_int(trade.get('transaction_index')),
+                _safe_int(trade.get('instruction_index')),
+                trade.get('signature', '')
+            )
+            if dedupe_key in seen_trade_keys:
+                continue
+            seen_trade_keys.add(dedupe_key)
+            trade_records.append(trade)
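The dedupe loop above keeps the first occurrence per `(slot, transaction_index, instruction_index, signature)` key while preserving execution order. The same idea in isolation (the helper name is ours):

```python
def dedupe_by_key(records, key_fn):
    """Keep the first occurrence per key, preserving input order."""
    seen = set()
    out = []
    for rec in records:
        k = key_fn(rec)
        if k in seen:
            continue
        seen.add(k)
        out.append(rec)
    return out

trades = [{'sig': 'a', 'slot': 1}, {'sig': 'a', 'slot': 1}, {'sig': 'b', 'slot': 1}]
print(dedupe_by_key(trades, lambda t: (t['slot'], t['sig'])))
# [{'sig': 'a', 'slot': 1}, {'sig': 'b', 'slot': 1}]
```

A set-based membership test keeps this O(n) overall, which matters when the early and recent HBH slices overlap heavily for short-lived tokens.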
905
+
906
+ for trade in trade_records:
907
+ trader_addr = trade['maker']
908
+ if trader_addr not in all_graph_entity_addrs:
909
+ all_graph_entity_addrs.add(trader_addr)
910
+ all_graph_entities[trader_addr] = 'Wallet' # Trades are always made by wallets
911
+ graph_seed_entities.add(trader_addr)
912
+
913
+ # --- REFACTORED: Fetch significant transfers, passing total supply for filtering ---
914
+ raw_total_supply = initial_mint_record.get('total_supply', 0)
915
+ base_decimals = initial_mint_record.get('token_decimals', 9)
916
+ total_supply_dec = (raw_total_supply / (10**base_decimals)) if base_decimals > 0 else raw_total_supply
917
+
918
+ # Calculate the minimum amount to be considered a significant transfer
919
+ total_supply_dec = total_supply_dec * MIN_AMOUNT_TRANSFER_SUPPLY # 0.01% of total supply
920
+
921
+ transfer_records = self.fetcher.fetch_transfers_for_token(token_address, T_cutoff, total_supply_dec)
922
+ for transfer in transfer_records:
923
+ src = transfer.get('source')
924
+ dst = transfer.get('destination')
925
+ if src:
926
+ all_graph_entities[src] = 'Wallet'
927
+ graph_seed_entities.add(src)
928
+ if dst:
929
+ all_graph_entities[dst] = 'Wallet'
930
+ graph_seed_entities.add(dst)
931
+
+ # --- NEW: Fetch pool creation events to enrich entity set and token list ---
+ pool_creation_records = self.fetcher.fetch_pool_creations_for_token(token_address, T_cutoff)
+ pool_quote_addresses = set()
+ pool_metadata_by_address: Dict[str, Dict[str, Any]] = {}
+ for pool_record in pool_creation_records:
+ creator_addr = pool_record.get('creator_address')
+ if creator_addr:
+ all_graph_entities[creator_addr] = 'Wallet'
+ graph_seed_entities.add(creator_addr)
+ quote_addr = pool_record.get('quote_address')
+ if quote_addr:
+ pool_quote_addresses.add(quote_addr)
+ # Mark discovered quote tokens so they can be fetched later if needed
+ all_graph_entities.setdefault(quote_addr, 'Token')
+ pool_addr = pool_record.get('pool_address')
+ if pool_addr:
+ pool_metadata_by_address[pool_addr] = {
+ 'quote_token_address': quote_addr,
+ 'quote_decimals': pool_record.get('quote_decimals'),
+ 'base_decimals': pool_record.get('base_decimals')
+ }
+
+ liquidity_change_records = self.fetcher.fetch_liquidity_changes_for_pools(list(pool_metadata_by_address.keys()), T_cutoff)
+ for liquidity_record in liquidity_change_records:
+ lp_provider = liquidity_record.get('lp_provider')
+ if lp_provider:
+ all_graph_entities[lp_provider] = 'Wallet'
+ graph_seed_entities.add(lp_provider)
+
+ fee_collection_records = self.fetcher.fetch_fee_collections_for_token(token_address, T_cutoff)
+ burn_records = self.fetcher.fetch_burns_for_token(token_address, T_cutoff)
+ supply_lock_records = self.fetcher.fetch_supply_locks_for_token(token_address, T_cutoff)
+ migration_records = self.fetcher.fetch_migrations_for_token(token_address, T_cutoff)
+ # NEW: Fetch top holders to include their wallets so we can embed them
+ holder_records = self.fetcher.fetch_token_holders_for_snapshot(token_address, T_cutoff, limit=HOLDER_SNAPSHOT_TOP_K)
+ fee_related_mints = set()
+ for fee_record in fee_collection_records:
+ recipient = fee_record.get('recipient_address')
+ if recipient:
+ all_graph_entities[recipient] = 'Wallet'
+ graph_seed_entities.add(recipient)
+ mint_addr = fee_record.get('token_0_mint_address')
+ if mint_addr and mint_addr not in (token_address, ''):
+ fee_related_mints.add(mint_addr)
+ # Include migration pool addresses as tokens/entities if present
+ for mig in migration_records:
+ vpool = mig.get('virtual_pool_address')
+ paddr = mig.get('pool_address')
+ if vpool:
+ all_graph_entities.setdefault(vpool, 'Token')
+ if paddr:
+ all_graph_entities.setdefault(paddr, 'Token')
+
+ # Include burner wallets in entity set
+ for burn in burn_records:
+ src = burn.get('source')
+ if src:
+ all_graph_entities[src] = 'Wallet'
+ graph_seed_entities.add(src)
+ # Include holder wallets in entity set for embedding availability
+ for rec in holder_records:
+ wa = rec.get('wallet_address')
+ if wa:
+ all_graph_entities[wa] = 'Wallet'
+ graph_seed_entities.add(wa)
+ # Include lockers in entity set
+ for lock in supply_lock_records:
+ sender = lock.get('sender')
+ recipient = lock.get('recipient')
+ if sender:
+ all_graph_entities[sender] = 'Wallet'
+ graph_seed_entities.add(sender)
+ if recipient:
+ all_graph_entities[recipient] = 'Wallet'
+ graph_seed_entities.add(recipient)
+
+ # --- NEW: Now that all wallets are known, fetch graph links once ---
+ graph_links = [] # Default so 'graph_links' is always defined for the returned item
+ if graph_seed_entities:
+ fetched_graph_entities, graph_links = self.fetcher.fetch_graph_links(
+ list(graph_seed_entities),
+ T_cutoff=T_cutoff,
+ max_degrees=2
+ )
+ for addr, entity_type in fetched_graph_entities.items():
+ all_graph_entities[addr] = entity_type
+ all_graph_entity_addrs = set(all_graph_entities.keys())
+
+ # 4. Fetch and process static data for the main token
+ tokens_to_fetch = [token_address]
+ for quote_addr in pool_quote_addresses:
+ if quote_addr and quote_addr not in tokens_to_fetch:
+ tokens_to_fetch.append(quote_addr)
+ for mint_addr in fee_related_mints:
+ if mint_addr and mint_addr not in tokens_to_fetch:
+ tokens_to_fetch.append(mint_addr)
+ main_metadata = {}
+ main_metadata[token_address] = {
+ 'name': initial_mint_record["token_name"],
+ 'symbol': initial_mint_record["token_symbol"],
+ 'token_uri': initial_mint_record["token_uri"],
+ 'protocol': initial_mint_record["protocol"],
+ 'total_supply': initial_mint_record["total_supply"],
+ 'decimals': initial_mint_record["token_decimals"],
+ 'address': token_address
+ }
+
+ main_token_data = self._process_token_data(tokens_to_fetch, pooler, T_cutoff, main_metadata)
+
+ # --- CRITICAL FIX: If the main token is invalid, skip this entire sample ---
+ if not main_token_data:
+ return None # The specific reason is already logged in _process_token_data
+
+ # 5. Fetch and process data for ALL wallets discovered (from mint, graph, trades, etc.)
+ # --- FIXED: Correctly identify wallets using their entity type from the graph ---
+ wallets_to_fetch = [addr for addr, entity_type in all_graph_entities.items() if entity_type == 'Wallet']
+ # Also include traders from trades, even if they weren't in the graph
+ wallets_to_fetch.extend([trade['maker'] for trade in trade_records if trade['maker'] not in wallets_to_fetch])
+ wallet_data, all_token_data = self._process_wallet_data(list(set(wallets_to_fetch)), main_token_data.copy(), pooler, T_cutoff)
+
+ # 6. Process trades into event format using the now-available wallet_data
+ trade_events = []
+
+ aggregation_trades = []
+ high_def_chart_trades = [] # Early + recent windows use 1s candles
+ middle_chart_trades = [] # Middle window uses 30s candles
+ # --- FIXED: Get main token decimals once before the loop ---
+ main_token_info = main_token_data[token_address]
+ base_decimals = main_token_info.get('decimals', 6)
+ # --- FIXED: Get total_supply directly from the initial mint record ---
+ raw_total_supply = initial_mint_record.get('total_supply', 0)
+ total_supply_dec = (raw_total_supply / (10**base_decimals)) if base_decimals > 0 else raw_total_supply
+
+ t0_timestamp = _timestamp_to_order_value(t0)
+
+ for trade in trade_records:
+ # --- NEW: Filter out trades with low USD value ---
+ # This applies to both event creation and chart aggregation.
+ if trade.get('total_usd', 0.0) < self.min_trade_usd:
+ continue
+
+ trade_sort_key = _trade_execution_sort_key(trade)
+ trade_timestamp = trade.get('timestamp')
+ trade_timestamp_value = _timestamp_to_order_value(trade_timestamp)
+ trade_timestamp_int = int(trade_timestamp_value)
+ # --- NEW: Determine event type with priority ---
+ trader_addr = trade['maker']
+ trader_wallet_data = wallet_data.get(trader_addr, {})
+ trader_profile = trader_wallet_data.get('profile', {})
+ trader_socials = trader_wallet_data.get('socials', {})
+
+ KOL_NAME_KEYS = ['kolscan_name', 'cabalspy_name', 'axiom_kol_name']
+ is_kol = any(trader_socials.get(key) for key in KOL_NAME_KEYS if trader_socials)
+ is_profitable = (trader_profile.get('stats_30d_realized_profit_pnl', 0.0) > SMART_WALLET_PNL_THRESHOLD and
+ trader_profile.get('stats_30d_realized_profit_usd', 0.0) > SMART_WALLET_USD_THRESHOLD)
+
+ base_amount_dec = trade.get('base_amount', 0) / (10**base_decimals)
+ is_large_amount = (total_supply_dec > 0 and (base_amount_dec / total_supply_dec) > LARGE_TRADE_SUPPLY_PCT_THRESHOLD)
+
+ if trader_addr == creator_address:
+ event_type = 'Deployer_Trade'
+ elif is_kol or is_profitable:
+ event_type = 'SmartWallet_Trade'
+ elif trade.get('total_usd', 0.0) > LARGE_TRADE_USD_THRESHOLD or is_large_amount:
+ event_type = 'LargeTrade'
+ else:
+ event_type = 'Trade'
+
+ # --- NEW: Get token decimals for accurate calculations ---
+ quote_address = trade.get('quote_address')
+ quote_decimals = QUOTE_TOKEN_DECIMALS.get(quote_address, 9) # Default to 9 for SOL
+
+ quote_amount_dec = trade.get('quote_amount', 0) / (10**quote_decimals)
+
+ # --- NEW: Correctly calculate pre-trade balances ---
+ is_sell = trade.get('trade_type') == 1
+
+ # If it's a sell, the pre-trade base balance was higher.
+ pre_trade_base_balance = (trade.get('base_balance', 0.0) + base_amount_dec) if is_sell else trade.get('base_balance', 0.0)
+ # If it's a buy, the pre-trade quote balance was higher.
+ pre_trade_quote_balance = (trade.get('quote_balance', 0.0) + quote_amount_dec) if not is_sell else trade.get('quote_balance', 0.0)
+
+ # --- NEW: Calculate percentage features with the corrected values ---
+ token_amount_pct = (base_amount_dec / pre_trade_base_balance) if pre_trade_base_balance > 1e-9 else 1.0
+ quote_amount_pct = (quote_amount_dec / pre_trade_quote_balance) if pre_trade_quote_balance > 1e-9 else 1.0
+ is_success = trade.get('success', False)
+ if is_success:
+ chart_entry = {
+ 'trade_direction': 1 if is_sell else 0, # 1 for sell, 0 for buy
+ 'price_usd': trade.get('price_usd', 0.0),
+ 'timestamp': trade_timestamp_int,
+ 'sort_key': trade_sort_key,
+ }
+ aggregation_trades.append(chart_entry)
+ high_def_chart_trades.append(chart_entry.copy())
+ # --- NEW: Calculate token amount as a percentage of total supply ---
+ token_amount_pct_of_supply = (base_amount_dec / total_supply_dec) if total_supply_dec > 0 else 0.0
+ trade_event = {
+ 'event_type': event_type,
+ 'timestamp': trade_timestamp_int,
+ 'relative_ts': trade_timestamp_value - t0_timestamp,
+ 'wallet_address': trade['maker'],
+ 'token_address': token_address,
+ 'trade_direction': 1 if is_sell else 0, # 1 for sell, 0 for buy
+ 'sol_amount': trade.get('total', 0.0), # Assuming 'total' is the SOL amount
+ 'dex_platform_id': trade.get('platform', 0),
+ 'priority_fee': trade.get('priority_fee', 0.0),
+ 'mev_protection': 1 if trade.get('mev_protection', 0) > 0 else 0, # Convert to binary: 0 for False, 1 for True
+ # --- FIXED: Use the new, correct percentage calculations ---
+ 'token_amount_pct_of_holding': token_amount_pct,
+ 'quote_amount_pct_of_holding': quote_amount_pct,
+ 'slippage': trade.get('slippage', 0.0),
+ 'token_amount_pct_to_total_supply': token_amount_pct_of_supply, # FIXED: Replaced price_impact
+ 'success': is_success,
+ 'is_bundle': False, # Default to False, will be updated below
+ 'total_usd': trade.get('total_usd', 0.0)
+ }
+ trade_events.append(trade_event)
+ _register_event(trade_event, trade_sort_key)
+
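The pre-trade balance logic above rolls post-trade balances back by the traded amount to recover what the wallet held before the fill, then uses that as the denominator for the percentage features. An isolated sketch (field names mirror the trade records; the helper itself is illustrative):

```python
# Sketch of pre-trade balance reconstruction: a sell reduced the base balance,
# a buy reduced the quote balance, so add the traded amount back before
# computing "fraction of holding moved". A near-zero pre-trade balance is
# treated as the whole holding moving (pct = 1.0).
def holding_pcts(trade, base_amount_dec, quote_amount_dec):
    is_sell = trade.get('trade_type') == 1
    pre_base = trade.get('base_balance', 0.0) + base_amount_dec if is_sell else trade.get('base_balance', 0.0)
    pre_quote = trade.get('quote_balance', 0.0) + quote_amount_dec if not is_sell else trade.get('quote_balance', 0.0)
    base_pct = base_amount_dec / pre_base if pre_base > 1e-9 else 1.0
    quote_pct = quote_amount_dec / pre_quote if pre_quote > 1e-9 else 1.0
    return base_pct, quote_pct
```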
+ for trade in middle_trades:
+ # --- NEW: Filter out trades with low USD value from chart aggregation ---
+ if trade.get('total_usd', 0.0) < self.min_trade_usd:
+ continue
+
+ is_sell = trade.get('trade_type') == 1
+
+ chart_entry = {
+ 'trade_direction': 1 if is_sell else 0, # 1 for sell, 0 for buy
+ 'price_usd': trade.get('price_usd', 0.0),
+ 'timestamp': int(_timestamp_to_order_value(trade.get('timestamp'))),
+ 'sort_key': _trade_execution_sort_key(trade),
+ }
+ aggregation_trades.append(chart_entry)
+ middle_chart_trades.append(chart_entry.copy())
+
+ def _finalize_chart_trade_list(trade_list: List[Dict[str, Any]]):
+ trade_list.sort(key=lambda x: x['sort_key'])
+ for entry in trade_list:
+ entry.pop('sort_key', None)
+
+ _finalize_chart_trade_list(aggregation_trades)
+ _finalize_chart_trade_list(high_def_chart_trades)
+ _finalize_chart_trade_list(middle_chart_trades)
+
+ # --- NEW: Debugging log for all trades used in chart generation ---
+ print(f"\n[DEBUG] Total aggregated trades for OHLC: {len(aggregation_trades)}")
+ if aggregation_trades:
+ print("[DEBUG] First 5 aggregated trades:", aggregation_trades[:5])
+
+ HIGH_DEF_INTERVAL = ("1s", 1)
+ MIDDLE_INTERVAL = ("30s", 30)
+
+ def _emit_chart_segments(trades: List[Dict[str, Any]], interval: tuple, signature_prefix: str):
+ if not trades:
+ return []
+ interval_label, interval_seconds = interval
+ ohlc_series = self._generate_ohlc(trades, T_cutoff, interval_seconds)
+ print(f"[DEBUG] Generated OHLC series ({interval_label}) with {len(ohlc_series)} candles. First 5: {ohlc_series[:5]}")
+ emitted_events = []
+ for idx in range(0, len(ohlc_series), OHLC_SEQ_LEN):
+ segment = ohlc_series[idx:idx + OHLC_SEQ_LEN]
+ if not segment:
+ continue
+ last_ts = segment[-1][0]
+ opens_raw = [s[1] for s in segment]
+ closes_raw = [s[2] for s in segment]
+ chart_event = {
+ 'event_type': 'Chart_Segment',
+ 'timestamp': last_ts,
+ 'relative_ts': last_ts - t0_timestamp,
+ 'opens': self._normalize_price_series(opens_raw),
+ 'closes': self._normalize_price_series(closes_raw),
+ 'i': interval_label
+ }
+ emitted_events.append(chart_event)
+ _register_event(chart_event, _event_execution_sort_key(last_ts, signature=f"{signature_prefix}-{idx}"))
+ return emitted_events
+
+ # --- NEW: Generate Chart_Segment events from aggregated trades ---
+ chart_events = []
+ chart_events.extend(_emit_chart_segments(high_def_chart_trades, HIGH_DEF_INTERVAL, "chart-hd"))
+ chart_events.extend(_emit_chart_segments(middle_chart_trades, MIDDLE_INTERVAL, "chart-mid"))
+
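`_emit_chart_segments` relies on `self._generate_ohlc` (defined elsewhere in the file) to bucket trades into fixed-interval candles as `(timestamp, open, close)` tuples. A minimal stand-in showing that bucketing idea under those assumptions — this is not the project's implementation:

```python
# Illustrative candle bucketing: floor each trade's timestamp to its interval
# bucket; the first trade in a bucket sets the open, the last one the close.
def simple_ohlc(trades, interval_seconds):
    candles = {}  # bucket_start_ts -> [open, close]
    for t in sorted(trades, key=lambda x: x['timestamp']):
        bucket = (t['timestamp'] // interval_seconds) * interval_seconds
        if bucket not in candles:
            candles[bucket] = [t['price_usd'], t['price_usd']]
        else:
            candles[bucket][1] = t['price_usd']  # later trade updates the close
    return [(ts, oc[0], oc[1]) for ts, oc in sorted(candles.items())]
```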
+ # --- NEW: Convert pool creation records into structured events ---
+ SOL_MINT_ADDRESS = 'So11111111111111111111111111111111111111112'
+
+ def _convert_amount_with_decimals(raw_amount: Any, mint_addr: Optional[str]) -> float:
+ if raw_amount is None:
+ return 0.0
+ try:
+ amount_float = float(raw_amount)
+ except (TypeError, ValueError):
+ return 0.0
+ decimals_value = None
+ if mint_addr == SOL_MINT_ADDRESS:
+ decimals_value = QUOTE_TOKEN_DECIMALS.get(SOL_MINT_ADDRESS, 9)
+ elif mint_addr:
+ token_info = all_token_data.get(mint_addr) or main_token_data.get(mint_addr)
+ if token_info:
+ decimals_value = token_info.get('decimals')
+ if decimals_value is None:
+ return amount_float
+ try:
+ decimals_int = max(int(decimals_value), 0)
+ except (TypeError, ValueError):
+ decimals_int = 0
+ if decimals_int <= 0:
+ return amount_float
+ if mint_addr == SOL_MINT_ADDRESS:
+ should_scale = abs(amount_float) >= 1e5
+ else:
+ should_scale = abs(amount_float) >= (10 ** decimals_int)
+ return amount_float / (10 ** decimals_int) if should_scale else amount_float
+
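The helper above guesses whether an amount is still raw (an integer in base units) or already decimal-scaled by comparing its magnitude against `10**decimals`, and only divides in the first case. The core heuristic, stripped of the SOL special case and decimal lookups (a sketch, not the exact function):

```python
# Heuristic: an amount at least as large as one whole raw unit (10**decimals)
# is assumed to be un-scaled and is divided down; smaller values are assumed
# to be decimal-scaled already and pass through unchanged.
def scale_if_raw(amount, decimals):
    threshold = 10 ** decimals
    return amount / threshold if abs(amount) >= threshold else amount
```

The trade-off is that genuinely huge decimal amounts would be mis-scaled, which is why the real helper also checks per-mint decimals and uses a lower cutoff for SOL.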
+ pool_created_events = []
+ for pool_record in pool_creation_records:
+ pool_ts_value = _timestamp_to_order_value(pool_record.get('timestamp'))
+ pool_timestamp_int = int(pool_ts_value)
+
+ quote_token_address = pool_record.get('quote_address')
+
+ base_liquidity_raw = pool_record.get('initial_base_liquidity')
+ base_decimals_override = pool_record.get('base_decimals')
+ if base_decimals_override is None:
+ base_decimals_override = main_token_info.get('decimals', base_decimals)
+ base_decimals_value = int(base_decimals_override) if base_decimals_override is not None else int(base_decimals)
+ base_amount_dec = _convert_amount_with_decimals(base_liquidity_raw, token_address)
+
+ quote_liquidity_raw = pool_record.get('initial_quote_liquidity')
+ quote_decimals_override = pool_record.get('quote_decimals')
+ if quote_decimals_override is None:
+ quote_token_info = main_token_data.get(quote_token_address, {})
+ quote_decimals_override = quote_token_info.get('decimals', QUOTE_TOKEN_DECIMALS.get(quote_token_address, 9))
+ if quote_decimals_override is None:
+ quote_decimals_override = 9
+ quote_decimals_value = int(quote_decimals_override)
+ quote_amount_dec = _convert_amount_with_decimals(quote_liquidity_raw, quote_token_address)
+
+ protocol_raw = pool_record.get('protocol')
+ protocol_id = protocol_raw if isinstance(protocol_raw, int) and 0 <= protocol_raw < vocab.NUM_PROTOCOLS else vocab.PROTOCOL_TO_ID.get('Unknown', 0)
+
+ pool_event = {
+ 'event_type': 'PoolCreated',
+ 'timestamp': pool_timestamp_int,
+ 'relative_ts': pool_ts_value - t0_timestamp,
+ 'wallet_address': pool_record.get('creator_address'),
+ 'token_address': token_address,
+ 'protocol_id': protocol_id,
+ 'quote_token_address': quote_token_address,
+ 'base_amount': base_amount_dec,
+ 'quote_amount': quote_amount_dec,
+ 'priority_fee': pool_record.get('priority_fee', 0.0),
+ }
+ pool_created_events.append(pool_event)
+ pool_sort_key = _event_execution_sort_key(
+ pool_ts_value,
+ slot=pool_record.get('slot'),
+ transaction_index=0,
+ instruction_index=0,
+ signature=pool_record.get('signature', '')
+ )
+ _register_event(pool_event, pool_sort_key)
+
+ # --- NEW: Convert liquidity change records into structured events ---
+ liquidity_change_events = []
+ for liquidity_record in liquidity_change_records:
+ pool_address = liquidity_record.get('pool_address')
+ pool_meta = pool_metadata_by_address.get(pool_address, {})
+ quote_token_address = pool_meta.get('quote_token_address')
+
+ quote_decimals_override = pool_meta.get('quote_decimals')
+ if quote_decimals_override is None:
+ quote_token_info = main_token_data.get(quote_token_address, {})
+ quote_decimals_override = quote_token_info.get('decimals', QUOTE_TOKEN_DECIMALS.get(quote_token_address, 9))
+ if quote_decimals_override is None:
+ quote_decimals_override = 9
+
+ quote_amount_raw = liquidity_record.get('quote_amount', 0)
+ quote_decimals_value = int(quote_decimals_override)
+ quote_amount_dec = _convert_amount_with_decimals(quote_amount_raw, quote_token_address)
+
+ liquidity_ts_value = _timestamp_to_order_value(liquidity_record.get('timestamp'))
+ liquidity_timestamp_int = int(liquidity_ts_value)
+
+ protocol_raw = liquidity_record.get('protocol')
+ protocol_id = protocol_raw if isinstance(protocol_raw, int) and 0 <= protocol_raw < vocab.NUM_PROTOCOLS else vocab.PROTOCOL_TO_ID.get('Unknown', 0)
+ change_type_id = int(liquidity_record.get('change_type', 0) or 0)
+
+ liquidity_event = {
+ 'event_type': 'LiquidityChange',
+ 'timestamp': liquidity_timestamp_int,
+ 'relative_ts': liquidity_ts_value - t0_timestamp,
+ 'wallet_address': liquidity_record.get('lp_provider'),
+ 'token_address': token_address,
+ 'protocol_id': protocol_id,
+ 'quote_token_address': quote_token_address,
+ 'change_type_id': change_type_id,
+ 'quote_amount': quote_amount_dec,
+ 'priority_fee': liquidity_record.get('priority_fee', 0.0),
+ 'success': liquidity_record.get('success', False)
+ }
+
+ if quote_token_address:
+ liquidity_change_events.append(liquidity_event)
+ liquidity_sort_key = _event_execution_sort_key(
+ liquidity_ts_value,
+ slot=liquidity_record.get('slot'),
+ transaction_index=0,
+ instruction_index=0,
+ signature=liquidity_record.get('signature', '')
+ )
+ _register_event(liquidity_event, liquidity_sort_key)
+
+ # --- NEW: Convert fee collection records into structured events ---
+ fee_collected_events = []
+ for fee_record in fee_collection_records:
+ fee_ts_value = _timestamp_to_order_value(fee_record.get('timestamp'))
+ fee_timestamp_int = int(fee_ts_value)
+
+ token0_mint = fee_record.get('token_0_mint_address')
+ token1_mint = fee_record.get('token_1_mint_address')
+ token0_amount_raw = fee_record.get('token_0_amount')
+ token1_amount_raw = fee_record.get('token_1_amount')
+
+ sol_amount = 0.0
+ if token0_mint == SOL_MINT_ADDRESS:
+ sol_amount = _convert_amount_with_decimals(token0_amount_raw, SOL_MINT_ADDRESS)
+ elif token1_mint == SOL_MINT_ADDRESS:
+ sol_amount = _convert_amount_with_decimals(token1_amount_raw, SOL_MINT_ADDRESS)
+
+ # Skip records that have no recipient wallet to attribute the fee to
+ recipient_addr = fee_record.get('recipient_address')
+ if not recipient_addr:
+ continue
+
+ fee_event = {
+ 'event_type': 'FeeCollected',
+ 'timestamp': fee_timestamp_int,
+ 'relative_ts': fee_ts_value - t0_timestamp,
+ 'wallet_address': recipient_addr,
+ 'token_address': token_address,
+ 'sol_amount': sol_amount,
+ 'priority_fee': fee_record.get('priority_fee', 0.0),
+ 'protocol_id': fee_record.get('protocol', 0),
+ 'success': fee_record.get('success', False),
+ }
+
+ fee_collected_events.append(fee_event)
+ fee_sort_key = _event_execution_sort_key(
+ fee_ts_value,
+ slot=fee_record.get('slot'),
+ transaction_index=0,
+ instruction_index=0,
+ signature=fee_record.get('signature', '')
+ )
+ _register_event(fee_event, fee_sort_key)
+
+ # --- NEW: Convert burn records into structured TokenBurn events ---
+ token_burn_events = []
+ for burn in burn_records:
+ burn_ts_value = _timestamp_to_order_value(burn.get('timestamp'))
+ burn_timestamp_int = int(burn_ts_value)
+
+ amount_dec = burn.get('amount_decimal')
+ if amount_dec is None:
+ raw_amount = burn.get('amount', 0)
+ try:
+ raw_amount = float(raw_amount)
+ except (TypeError, ValueError):
+ raw_amount = 0.0
+ amount_dec = raw_amount / (10**base_decimals) if base_decimals and base_decimals > 0 else raw_amount
+
+ pct_of_supply = (amount_dec / total_supply_dec) if total_supply_dec and total_supply_dec > 0 else 0.0
+
+ burn_event = {
+ 'event_type': 'TokenBurn',
+ 'timestamp': burn_timestamp_int,
+ 'relative_ts': burn_ts_value - t0_timestamp,
+ 'wallet_address': burn.get('source'),
+ 'token_address': token_address,
+ 'amount_pct_of_total_supply': pct_of_supply,
+ 'amount_tokens_burned': amount_dec,
+ 'priority_fee': burn.get('priority_fee', 0.0),
+ 'success': burn.get('success', False),
+ }
+ token_burn_events.append(burn_event)
+ burn_sort_key = _event_execution_sort_key(
+ burn_ts_value,
+ slot=burn.get('slot'),
+ transaction_index=0,
+ instruction_index=0,
+ signature=burn.get('signature', '')
+ )
+ _register_event(burn_event, burn_sort_key)
+
+ # --- NEW: Convert migrations into Migrated events ---
+ for mig in migration_records:
+ mig_ts_value = _timestamp_to_order_value(mig.get('timestamp'))
+ mig_timestamp_int = int(mig_ts_value)
+ prot_raw = mig.get('protocol', 0)
+ protocol_id = prot_raw if isinstance(prot_raw, int) and 0 <= prot_raw < vocab.NUM_PROTOCOLS else vocab.PROTOCOL_TO_ID.get('Unknown', 0)
+ mig_event = {
+ 'event_type': 'Migrated',
+ 'timestamp': mig_timestamp_int,
+ 'relative_ts': mig_ts_value - t0_timestamp,
+ 'protocol_id': protocol_id,
+ }
+ mig_sort_key = _event_execution_sort_key(
+ mig_ts_value,
+ slot=mig.get('slot'),
+ transaction_index=0,
+ instruction_index=0,
+ signature=mig.get('signature', '')
+ )
+ _register_event(mig_event, mig_sort_key)
+
+ # NOTE: HolderSnapshot events are generated per-snapshot time inside _generate_onchain_snapshots
+
+ # --- NEW: Convert supply lock records into structured SupplyLock events ---
+ supply_lock_events = []
+ for lock in supply_lock_records:
+ lock_ts_value = _timestamp_to_order_value(lock.get('timestamp'))
+ lock_timestamp_int = int(lock_ts_value)
+
+ # total_locked_amount is Float64, typically already decimal-scaled
+ raw_locked = lock.get('total_locked_amount', 0.0)
+ try:
+ locked_amount = float(raw_locked)
+ except (TypeError, ValueError):
+ locked_amount = 0.0
+
+ pct_of_supply = (locked_amount / total_supply_dec) if total_supply_dec and total_supply_dec > 0 else 0.0
+
+ final_unlock_ts = lock.get('final_unlock_timestamp') or 0
+ try:
+ final_unlock_ts = int(final_unlock_ts)
+ except (TypeError, ValueError):
+ final_unlock_ts = 0
+ lock_duration = max(0, final_unlock_ts - lock_timestamp_int)
+
+ lock_event = {
+ 'event_type': 'SupplyLock',
+ 'timestamp': lock_timestamp_int,
+ 'relative_ts': lock_ts_value - t0_timestamp,
+ 'wallet_address': lock.get('sender'),
+ 'token_address': token_address,
+ 'amount_pct_of_total_supply': pct_of_supply,
+ 'lock_duration': float(lock_duration),
+ 'priority_fee': lock.get('priority_fee', 0.0),
+ 'success': lock.get('success', False),
+ }
+ supply_lock_events.append(lock_event)
+ lock_sort_key = _event_execution_sort_key(
+ lock_ts_value,
+ slot=lock.get('slot'),
+ transaction_index=0,
+ instruction_index=0,
+ signature=lock.get('signature', '')
+ )
+ _register_event(lock_event, lock_sort_key)
+
+ # --- NEW: Process transfer events with strict validation ---
+ transfer_events = []
+ for transfer in transfer_records:
+ # --- VALIDATION: Ensure the destination wallet has a valid profile ---
+ if transfer['destination'] not in wallet_data:
+ print(f"INFO: Skipping transfer event {transfer['signature']} because destination wallet {transfer['destination']} has no profile.")
+ continue
+
+ # Calculate features
+ token_amount = transfer.get('amount_decimal', 0.0)
+ pct_of_supply = (token_amount / total_supply_dec) if total_supply_dec > 0 else 0.0
+
+ # Reconstruct pre-transfer balance of the source wallet
+ pre_transfer_source_balance = transfer.get('source_balance', 0.0) + token_amount
+ pct_of_holding = (token_amount / pre_transfer_source_balance) if pre_transfer_source_balance > 1e-9 else 1.0
+
+ # --- NEW: Classify LargeTransfer based on supply percentage ---
+ if pct_of_supply > LARGE_TRANSFER_SUPPLY_PCT_THRESHOLD:
+ event_type = 'LargeTransfer'
+ else:
+ event_type = 'Transfer'
+
+ transfer_ts_value = _timestamp_to_order_value(transfer.get('timestamp'))
+ transfer_event = {
+ 'event_type': event_type,
+ 'timestamp': int(transfer_ts_value),
+ 'relative_ts': transfer_ts_value - t0_timestamp,
+ 'wallet_address': transfer['source'],
+ 'destination_wallet_address': transfer['destination'],
+ 'token_address': token_address,
+ 'token_amount': token_amount,
+ 'transfer_pct_of_total_supply': pct_of_supply,
+ 'transfer_pct_of_holding': pct_of_holding,
+ 'priority_fee': transfer.get('priority_fee', 0.0)
+ }
+ transfer_events.append(transfer_event)
+ transfer_sort_key = _event_execution_sort_key(
+ transfer_ts_value,
+ slot=transfer.get('slot'),
+ transaction_index=transfer.get('transaction_index'),
+ instruction_index=transfer.get('instruction_index'),
+ signature=transfer.get('signature', '')
+ )
+ _register_event(transfer_event, transfer_sort_key)
+
+ # --- NEW: Correctly detect bundles with a single pass after event creation ---
+ # Trades below the min_trade_usd filter never produced events, so rebuild the
+ # filtered list here; it mirrors trade_events index-for-index. Entries are
+ # ordered by (timestamp, slot, transaction_index, instruction_index), so
+ # adjacent entries that share a slot belong to the same bundle.
+ bundle_candidate_records = [t for t in trade_records if t.get('total_usd', 0.0) >= self.min_trade_usd]
+ if len(bundle_candidate_records) > 1:
+ for i in range(1, len(bundle_candidate_records)):
+ if bundle_candidate_records[i]['slot'] == bundle_candidate_records[i-1]['slot']:
+ trade_events[i]['is_bundle'] = True
+ trade_events[i-1]['is_bundle'] = True
+
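The bundle pass is a plain adjacent-pair scan over a slot-ordered list: every run of trades sharing a slot gets every member flagged. Reduced to its essentials (a sketch over bare slot numbers, not the event dicts):

```python
# Flag every position that belongs to a run of equal adjacent slots.
# Both members of each equal pair are marked, so a run of length n
# ends up fully flagged after one left-to-right pass.
def flag_bundles(slots):
    flags = [False] * len(slots)
    for i in range(1, len(slots)):
        if slots[i] == slots[i - 1]:
            flags[i] = True
            flags[i - 1] = True
    return flags
```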
+ # Generate OnChain_Snapshot events using helper
+ self._generate_onchain_snapshots(
+ token_address=token_address,
+ t0_timestamp=t0_timestamp,
+ T_cutoff=T_cutoff,
+ interval_sec=HOLDER_SNAPSHOT_INTERVAL_SEC,
+ trade_events=trade_events,
+ transfer_events=transfer_events,
+ aggregation_trades=aggregation_trades,
+ wallet_data=wallet_data,
+ total_supply_dec=total_supply_dec,
+ _register_event_fn=_register_event
+ )
+
+ # 7. TODO: Fetch social events (tweets, replies, etc.) for all discovered wallets
+ # - Query tables like 'x_posts', 'pump_replies'.
+ # - Use the pooler to get indices for text and media.
+
+ # Sort the combined event sequence by precise execution order
+ event_sequence_entries.sort(key=lambda entry: entry[0])
+ event_sequence = [event for _, event in event_sequence_entries]
+
+ anchor_timestamp_int = int(_timestamp_to_order_value(T_cutoff))
+ anchor_price = None
+ if aggregation_trades:
+ for trade in reversed(aggregation_trades):
+ price_val = trade.get('price_usd')
+ if price_val is not None:
+ anchor_price = float(price_val)
+ break
+ if self.num_outputs > 0 and anchor_price is None:
+ print(f"INFO: Skipping token {token_address} (no pre-cutoff price for labeling).")
+ return None
+
+ future_price_series: List[Tuple[int, float]] = []
+ if (self.num_outputs > 0 and max_horizon_seconds > 0 and
+ anchor_price is not None):
+ timeline = [(anchor_timestamp_int, anchor_price)]
+ for trade in future_trades_for_labels:
+ price_val = trade.get('price_usd')
+ if price_val is None:
+ continue
+ ts_int = int(_timestamp_to_order_value(trade.get('timestamp')))
+ if ts_int <= timeline[-1][0]:
+ continue
+ timeline.append((ts_int, float(price_val)))
+ if len(timeline) > 1:
+ future_price_series = timeline
+
+ debug_label_entries: List[Dict[str, Any]] = []
+ if self.num_outputs > 0:
+ labels_tensor, labels_mask_tensor, debug_label_entries = self._compute_future_return_labels(
+ anchor_price, anchor_timestamp_int, future_price_series
+ )
+ if labels_mask_tensor.sum() == 0:
+ print(f"INFO: Skipping token {token_address} (no valid horizons in future).")
+ return None
+ print("\n[Label Debug]")
+ for entry in debug_label_entries:
+ print(f" Horizon {entry['horizon']}s -> target_ts={entry['target_ts']}, "
+ f"future_price={entry['future_price']}, return={entry['return']:.6f}, "
+ f"mask={int(entry['mask'])}")
+ else:
+ labels_tensor = torch.zeros(0)
+ labels_mask_tensor = torch.zeros(0)
+
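The labeling pass builds a strictly increasing `(timestamp, price)` timeline anchored at the cutoff, which `_compute_future_return_labels` (defined elsewhere) then samples per horizon. A sketch of one plausible lookup over that timeline — the exact sampling rule inside the real helper is an assumption here:

```python
# Look up the last observed price at or before anchor_ts + horizon_s on a
# timestamp-sorted timeline, and return the simple return vs the anchor price.
# Returns None when the timeline has no observation at or before the target.
def future_return(timeline, anchor_price, anchor_ts, horizon_s):
    target_ts = anchor_ts + horizon_s
    price = None
    for ts, p in timeline:
        if ts <= target_ts:
            price = p  # keep advancing to the latest in-range observation
        else:
            break
    if price is None:
        return None
    return (price - anchor_price) / anchor_price
```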
1617
+ # For now, we'll return the item with mint and trade events
1618
+ item = {
1619
+ 'event_sequence': event_sequence,
1620
+ 'wallets': wallet_data,
1621
+ 'tokens': all_token_data, # FIXED: Use the comprehensive token data
1622
+ 'graph_links': graph_links, # NEW: Add the fetched graph links
1623
+ 'embedding_pooler': pooler,
1624
+ 'labels': labels_tensor,
1625
+ 'labels_mask': labels_mask_tensor}
1626
+
1627
+ # --- NEW: Comprehensive logging before returning the item ---
1628
+ print("\n--- Dataset Item Generation Summary ---")
1629
+ print(f"Token Address: {token_address}"
1630
+ )
1631
+ print(f"\n[Event Sequence] ({len(item['event_sequence'])} events):")
1632
+ for i, event in enumerate(item['event_sequence']):
1633
+ print(f" - Event {i}: {event}")
1634
+
1635
+ print(f"\n[Wallets] ({len(item['wallets'])} wallets):")
1636
+ for i, (addr, data) in enumerate(item['wallets'].items()):
1637
+ print(f" - Wallet {addr}:")
1638
+ print(f" - Profile: {data.get('profile', {})}")
1639
+ print(f" - Socials: {data.get('socials', {})}")
1640
+
1641
+ print(f"\n[Tokens] ({len(item['tokens'])} tokens):")
1642
+ for addr, data in item['tokens'].items():
1643
+ print(f" - Token {addr}: {data}")
1644
+
1645
+ if self.num_outputs > 0:
1646
+ print(f"\n[Labels]")
1647
+ for h_idx, horizon in enumerate(self.horizons_seconds):
1648
+ offset = h_idx * len(self.quantiles)
1649
+ values = item['labels'][offset:offset + len(self.quantiles)]
1650
+ masks = item['labels_mask'][offset:offset + len(self.quantiles)]
1651
+ print(f" Horizon {horizon}s:")
1652
+ for q_idx, quantile in enumerate(self.quantiles):
1653
+ print(f" q={quantile:.2f}: value={values[q_idx]:.6f}, mask={masks[q_idx]:.0f}")
1654
+
1655
+ print("--- End Summary ---\n")
1656
+
1657
+ return item
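The labels tensor above is a flat vector laid out horizon-major: one contiguous block of `len(quantiles)` entries per horizon, indexed with `offset = h_idx * len(quantiles)`. A minimal numpy sketch of that indexing convention (horizons and quantiles taken from the test configuration in this repo; the label values here are just placeholder indices):

```python
import numpy as np

# Configuration matching the defaults used elsewhere in this diff
horizons_seconds = [30, 60, 120, 240, 420]
quantiles = [0.1, 0.5, 0.9]

# Flat labels vector: len(horizons) * len(quantiles) entries,
# filled with their own indices so the layout is visible
labels = np.arange(len(horizons_seconds) * len(quantiles), dtype=np.float64)

for h_idx, horizon in enumerate(horizons_seconds):
    offset = h_idx * len(quantiles)
    block = labels[offset:offset + len(quantiles)]
    # block[q_idx] is the label for (horizon, quantiles[q_idx])
    print(f"Horizon {horizon}s -> indices {block.tolist()}")
```

The same convention is what lets the inference script later reshape the flat model output with `view(B, len(horizons), len(quantiles))` without any reordering.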
data/ohlc_stats.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f39f15281440244b927a46d14a85537afd891163556d46ee3a79c80c25b6f36b
+ size 1660
data/preprocess_distribution.py ADDED
@@ -0,0 +1,164 @@
+ #!/usr/bin/env python3
+ """
+ Preprocess distribution statistics for OHLC normalization and token history coverage.
+
+ This script:
+ 1. Computes global mean/std figures for price/volume so downstream code can normalize.
+ 2. Prints descriptive stats about how much price history (in seconds) each token has,
+    helping decide which horizons are realistic.
+
+ All configuration is done via environment variables (see below).
+ """
+
+ import os
+ import pathlib
+ import sys
+ from typing import List
+
+ import numpy as np
+ import clickhouse_connect
+
+
+ # --- Configuration (override via env vars if needed) ---
+ CLICKHOUSE_HOST = os.getenv("CLICKHOUSE_HOST", "localhost")
+ CLICKHOUSE_PORT = int(os.getenv("CLICKHOUSE_PORT", "8123"))
+ CLICKHOUSE_USERNAME = os.getenv("CLICKHOUSE_USERNAME", "default")
+ CLICKHOUSE_PASSWORD = os.getenv("CLICKHOUSE_PASSWORD", "")
+ CLICKHOUSE_DATABASE = os.getenv("CLICKHOUSE_DATABASE", "default")
+
+ OUTPUT_PATH = pathlib.Path(os.getenv("OHLC_STATS_PATH", "ohlc_stats.npz"))
+ MIN_PRICE_USD = float(os.getenv("OHLC_MIN_PRICE_USD", "0.0"))
+ MIN_VOLUME_USD = float(os.getenv("OHLC_MIN_VOLUME_USD", "0.0"))
+
+ TOKEN_ADDRESSES_ENV = os.getenv("OHLC_TOKEN_ADDRESSES", "")
+ TOKEN_ADDRESSES = tuple(addr.strip() for addr in TOKEN_ADDRESSES_ENV.split(",") if addr.strip()) or None
+
+
+ def build_where_clause() -> List[str]:
+     clauses = ["t.price_usd > %(min_price)s", "t.total_usd > %(min_vol)s"]
+     if TOKEN_ADDRESSES:
+         clauses.append("t.base_address IN %(token_addresses)s")
+     return clauses
+
+
+ def build_stats_query(where_sql: str) -> str:
+     return f"""
+         SELECT
+             AVG(t.price_usd) AS mean_price_usd,
+             stddevPop(t.price_usd) AS std_price_usd,
+             AVG(t.price) AS mean_price_native,
+             stddevPop(t.price) AS std_price_native,
+             AVG(t.total_usd) AS mean_trade_value_usd,
+             stddevPop(t.total_usd) AS std_trade_value_usd
+         FROM trades AS t
+         INNER JOIN mints AS m
+             ON m.mint_address = t.base_address
+         WHERE {where_sql}
+     """
+
+
+ def build_history_query(where_sql: str) -> str:
+     return f"""
+         SELECT
+             t.base_address AS token_address,
+             toUnixTimestamp(min(t.timestamp)) AS first_ts,
+             toUnixTimestamp(max(t.timestamp)) AS last_ts,
+             toUnixTimestamp(max(t.timestamp)) - toUnixTimestamp(min(t.timestamp)) AS history_seconds
+         FROM trades AS t
+         INNER JOIN mints AS m
+             ON m.mint_address = t.base_address
+         WHERE {where_sql}
+         GROUP BY token_address
+     """
+
+
+ def summarize_histories(histories: np.ndarray) -> None:
+     if histories.size == 0:
+         print("No token history stats available (no qualifying trades).")
+         return
+
+     stats = {
+         "count": histories.size,
+         "min": histories.min(),
+         "median": float(np.median(histories)),
+         "mean": histories.mean(),
+         "p90": float(np.percentile(histories, 90)),
+         "max": histories.max(),
+     }
+
+     def format_seconds(sec: float) -> str:
+         hours = sec / 3600.0
+         days = hours / 24.0
+         return f"{sec:.0f}s ({hours:.2f}h / {days:.2f}d)"
+
+     print("\nToken history coverage (seconds):")
+     print(f"  Tokens analyzed: {int(stats['count'])}")
+     print(f"  Min history: {format_seconds(stats['min'])}")
+     print(f"  Median history: {format_seconds(stats['median'])}")
+     print(f"  Mean history: {format_seconds(stats['mean'])}")
+     print(f"  90th percentile: {format_seconds(stats['p90'])}")
+     print(f"  Max history: {format_seconds(stats['max'])}")
+
+
+ def main() -> int:
+     where_clauses = build_where_clause()
+     where_sql = " AND ".join(where_clauses) if where_clauses else "1"
+     params: dict[str, object] = {
+         "min_price": max(MIN_PRICE_USD, 0.0),
+         "min_vol": max(MIN_VOLUME_USD, 0.0),
+     }
+     if TOKEN_ADDRESSES:
+         params["token_addresses"] = TOKEN_ADDRESSES
+
+     client = clickhouse_connect.get_client(
+         host=CLICKHOUSE_HOST,
+         port=CLICKHOUSE_PORT,
+         username=CLICKHOUSE_USERNAME,
+         password=CLICKHOUSE_PASSWORD,
+         database=CLICKHOUSE_DATABASE,
+     )
+
+     # --- Price/volume stats ---
+     stats_query = build_stats_query(where_sql)
+     stats_result = client.query(stats_query, parameters=params)
+     if not stats_result.result_rows:
+         print("ERROR: Stats query returned no rows. Check filters / connectivity.", file=sys.stderr)
+         return 1
+     (
+         mean_price_usd,
+         std_price_usd,
+         mean_price_native,
+         std_price_native,
+         mean_trade_value_usd,
+         std_trade_value_usd,
+     ) = stats_result.result_rows[0]
+
+     stats = {
+         "mean_price_usd": float(mean_price_usd or 0.0),
+         "std_price_usd": float(std_price_usd or 1.0),
+         "mean_price_native": float(mean_price_native or 0.0),
+         "std_price_native": float(std_price_native or 1.0),
+         "mean_trade_value_usd": float(mean_trade_value_usd or 0.0),
+         "std_trade_value_usd": float(std_trade_value_usd or 1.0),
+     }
+
+     OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
+     np.savez(OUTPUT_PATH, **stats)
+
+     print(f"Saved stats to {OUTPUT_PATH.resolve()}:")
+     for key, value in stats.items():
+         print(f"  {key}: {value:.6f}")
+
+     # --- Token history coverage ---
+     history_query = build_history_query(where_sql)
+     history_result = client.query(history_query, parameters=params)
+     history_seconds = np.array(
+         [float(row[3]) for row in history_result.result_rows if row[3] is not None],
+         dtype=np.float64
+     )
+     summarize_histories(history_seconds)
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
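The `.npz` file written above is meant to be consumed downstream for z-score normalization. A minimal sketch of that consumption (the key names match what the script saves via `np.savez`; the `zscore` helper and the in-memory buffer are illustrative, not part of the repo):

```python
import io
import numpy as np

# Stand-in for the stats dict the script computes; same keys as np.savez above
stats = {
    "mean_price_usd": 0.00015,
    "std_price_usd": 0.00042,
}

# Round-trip through an in-memory npz, mimicking writing/reading ohlc_stats.npz
buf = io.BytesIO()
np.savez(buf, **stats)
buf.seek(0)
loaded = np.load(buf)

def zscore(value: float, mean: float, std: float) -> float:
    # Guard against a degenerate std, mirroring the `or 1.0` fallback in the script
    return (value - mean) / (std if std > 0 else 1.0)

z = zscore(0.00057, float(loaded["mean_price_usd"]), float(loaded["std_price_usd"]))
print(f"normalized price: {z:.4f}")
```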
graph_schema.rs ADDED
@@ -0,0 +1,115 @@
+ /// Tracks direct capital flow and identifies funding chains.
+ pub struct TransferLink {
+     pub signature: String,
+     pub source: String,
+     pub destination: String,
+     pub mint: String,
+     pub timestamp: i64,
+ }
+
+ /// Identifies wallets trading the same token in the same slot.
+ pub struct BundleTradeLink {
+     pub signatures: Vec<String>,
+     pub wallet_a: String,
+     pub wallet_b: String,
+     pub mint: String,
+     pub slot: i64,
+     pub timestamp: i64,
+ }
+
+ /// Reveals a behavioral pattern of one wallet mirroring another's successful trade.
+ pub struct CopiedTradeLink {
+     pub timestamp: i64,
+     pub leader_buy_sig: String,
+     pub leader_sell_sig: String,
+     pub follower_buy_sig: String,
+     pub follower_sell_sig: String,
+     pub follower: String,
+     pub leader: String,
+     pub mint: String,
+     pub time_gap_on_buy_sec: i64,
+     pub time_gap_on_sell_sec: i64,
+     pub leader_pnl: f64,
+     pub follower_pnl: f64,
+
+     pub leader_buy_total: f64,
+     pub leader_sell_total: f64,
+
+     pub follower_buy_total: f64,
+     pub follower_sell_total: f64,
+     pub follower_buy_slippage: f32,
+     pub follower_sell_slippage: f32,
+ }
+
+ /// Represents a link where a group of wallets re-engage with a token in a coordinated manner.
+ pub struct CoordinatedActivityLink {
+     pub timestamp: i64,
+     pub leader_first_sig: String,
+     pub leader_second_sig: String,
+     pub follower_first_sig: String,
+     pub follower_second_sig: String,
+     pub follower: String,
+     pub leader: String,
+     pub mint: String,
+     pub time_gap_on_first_sec: i64,
+     pub time_gap_on_second_sec: i64,
+ }
+
+ /// Links a token to its original creator.
+ pub struct MintedLink {
+     pub signature: String,
+     pub timestamp: i64,
+     pub buy_amount: f64,
+ }
+
+ /// Connects a token to its successful first-movers.
+ pub struct SnipedLink {
+     pub timestamp: i64,
+     pub signature: String,
+     pub rank: i64,
+     pub sniped_amount: f64,
+ }
+
+ /// Represents a connection to the wallet that locked supply.
+ pub struct LockedSupplyLink {
+     pub timestamp: i64,
+     pub signature: String,
+     pub amount: f64,
+     pub unlock_timestamp: u64,
+ }
+
+ /// Links a token to the wallet that burned tokens.
+ pub struct BurnedLink {
+     pub signature: String,
+     pub amount: f64,
+     pub timestamp: i64,
+ }
+
+ /// Identifies wallets that provided liquidity, signaling high conviction.
+ pub struct ProvidedLiquidityLink {
+     pub signature: String,
+     pub wallet: String,
+     pub token: String,
+     pub pool_address: String,
+     pub amount_base: f64,
+     pub amount_quote: f64,
+     pub timestamp: i64,
+ }
+
+ /// A derived link connecting a token to its largest holders.
+ pub struct WhaleOfLink {
+     pub timestamp: i64,
+     pub wallet: String,
+     pub token: String,
+     pub holding_pct_at_creation: f32, // Holding % when the link was made
+     pub ath_usd_at_creation: f64,     // Token's ATH when the link was made
+ }
+
+ /// A derived link connecting a token to its most profitable traders.
+ pub struct TopTraderOfLink {
+     pub timestamp: i64,
+     pub wallet: String,
+     pub token: String,
+     pub pnl_at_creation: f64,     // The PNL that first triggered the link
+     pub ath_usd_at_creation: f64, // Token's ATH when the link was made
+ }
inference.py ADDED
@@ -0,0 +1,271 @@
+ # inference.py
+
+ import torch
+ import traceback
+ import time
+
+ # Import all the necessary components from our project
+ from models.model import Oracle
+ from data.data_collator import MemecoinCollator
+ from models.multi_modal_processor import MultiModalEncoder
+ from data.data_loader import OracleDataset
+ from data.data_fetcher import DataFetcher
+ from models.helper_encoders import ContextualTimeEncoder
+ from models.token_encoder import TokenEncoder
+ from models.wallet_encoder import WalletEncoder
+ from models.graph_updater import GraphUpdater
+ from models.ohlc_embedder import OHLCEmbedder
+ import models.vocabulary as vocab
+
+ # --- NEW: Import database clients ---
+ from clickhouse_driver import Client as ClickHouseClient
+ from neo4j import GraphDatabase
+
+ # =============================================================================
+ # Inference/Test Script for the Oracle Model
+ # This script replicates the test logic previously in model.py
+ # =============================================================================
+ if __name__ == "__main__":
+     print("--- Oracle Inference Script (Full Pipeline Test) ---")
+
+     # --- 1. Define Configs ---
+     OHLC_SEQ_LEN = 60
+     print(f"Using {vocab.NUM_EVENT_TYPES} event types from vocabulary.")
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
+     if device.type == 'cpu':
+         dtype = torch.float32
+     print(f"Using device: {device}, dtype: {dtype}")
+
+     _test_quantiles = [0.1, 0.5, 0.9]
+     _test_horizons = [30, 60, 120, 240, 420]
+     _test_num_outputs = len(_test_quantiles) * len(_test_horizons)
+
+     # --- 2. Instantiate ALL Encoders ---
+     print("Instantiating encoders (using defaults)...")
+     try:
+         multi_modal_encoder = MultiModalEncoder(dtype=dtype)
+         real_time_enc = ContextualTimeEncoder(dtype=dtype)
+
+         real_token_enc = TokenEncoder(
+             multi_dim=multi_modal_encoder.embedding_dim,
+             dtype=dtype
+         )
+         real_wallet_enc = WalletEncoder(encoder=multi_modal_encoder, dtype=dtype)
+         real_graph_upd = GraphUpdater(time_encoder=real_time_enc, dtype=dtype)
+
+         real_ohlc_emb = OHLCEmbedder(
+             num_intervals=vocab.NUM_OHLC_INTERVALS,
+             sequence_length=OHLC_SEQ_LEN,
+             dtype=dtype
+         )
+
+         print(f"TokenEncoder default output_dim: {real_token_enc.output_dim}")
+         print(f"WalletEncoder default d_model: {real_wallet_enc.d_model}")
+         print(f"OHLCEmbedder default output_dim: {real_ohlc_emb.output_dim}")
+
+         print("Encoders instantiated.")
+     except Exception as e:
+         print(f"Failed to instantiate encoders: {e}")
+         traceback.print_exc()
+         exit()
+
+     # --- 3. Instantiate the Collator ---
+     collator = MemecoinCollator(
+         event_type_to_id=vocab.EVENT_TO_ID,
+         device=device,
+         multi_modal_encoder=multi_modal_encoder,
+         dtype=dtype,
+         ohlc_seq_len=OHLC_SEQ_LEN,
+         max_seq_len=50
+     )
+     print("MemecoinCollator (fast batcher) instantiated.")
+
+     # --- 4. Instantiate the Oracle Model ---
+     print("Instantiating Oracle (full pipeline)...")
+     model = Oracle(
+         token_encoder=real_token_enc,
+         wallet_encoder=real_wallet_enc,
+         graph_updater=real_graph_upd,
+         time_encoder=real_time_enc,
+         multi_modal_dim=multi_modal_encoder.embedding_dim,
+         num_event_types=vocab.NUM_EVENT_TYPES,
+         event_pad_id=vocab.EVENT_TO_ID['__PAD__'],
+         event_type_to_id=vocab.EVENT_TO_ID,
+         model_config_name="Qwen/Qwen3-0.6B",
+         quantiles=_test_quantiles,
+         horizons_seconds=_test_horizons,
+         dtype=dtype,
+         ohlc_embedder=real_ohlc_emb
+     ).to(device)
+     model.eval()
+     print(f"Oracle d_model: {model.d_model}")
+
+     # --- 5. Create Dataset and run pre-collation step ---
+     print("Creating Dataset...")
+
+     # --- NEW: Initialize real database clients and DataFetcher ---
+     try:
+         print("Connecting to databases...")
+         # ClickHouse running locally on the native protocol port 9000 with no auth
+         clickhouse_client = ClickHouseClient(host='localhost', port=9000)
+         # Neo4j running locally on port 7687 with no auth
+         neo4j_driver = GraphDatabase.driver("bolt://localhost:7687", auth=None)
+
+         data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
+         print("Database clients and DataFetcher initialized.")
+
+         # --- Fetch mints to get the first token for processing ---
+         all_mints = data_fetcher.get_all_mints()
+         if not all_mints:
+             print("\n❌ No mints found in the database. Exiting test.")
+             exit()
+
+         # --- FIXED: Instantiate the dataset in REAL mode, removing is_test flag ---
+         dataset = OracleDataset(
+             data_fetcher=data_fetcher,
+             horizons_seconds=_test_horizons,
+             quantiles=_test_quantiles,
+             max_samples=57)
+
+     except Exception as e:
+         print(f"FATAL: Could not initialize database connections or dataset: {e}")
+         traceback.print_exc()
+         exit()
+
+     # --- PRODUCTION-READY: Process a full batch of items from the dataset ---
+     print(f"\n--- Processing a batch of up to {len(dataset)} items from the dataset ---")
+     batch_items = []
+     for i in range(len(dataset)):
+         token_addr = dataset.sampled_mints[i].get('mint_address', 'unknown')
+         print(f"  - Attempting to process sample {i+1}/{len(dataset)} ({token_addr})...")
+         fetch_start = time.time()
+         sample = dataset[i]
+         fetch_elapsed = time.time() - fetch_start
+         print(f"    ... fetch completed in {fetch_elapsed:.2f}s")
+         if sample is not None:
+             batch_items.append(sample)
+             print("    ... Success! Sample added to batch.")
+
+     if not batch_items:
+         print("\n❌ No valid samples could be generated from the dataset. Exiting.")
+         exit()
+
+     # --- 6. Run Collator AND Model ---
+     print("\n--- Testing Pipeline (Collator + Model.forward) ---")
+     try:
+         # 1. Collator
+         collate_start = time.time()
+         collated_batch = collator(batch_items)
+         collate_elapsed = time.time() - collate_start
+         print("Collation successful!")
+         print(f"Collation time for batch of {len(batch_items)} tokens: {collate_elapsed:.2f}s")
+
+         # --- Check collator output ---
+         B = len(batch_items)
+         L = collated_batch['attention_mask'].shape[1]
+         assert 'ohlc_price_tensors' in collated_batch
+         ohlc_price_tensors = collated_batch['ohlc_price_tensors']
+         assert ohlc_price_tensors.dim() == 3, f"Expected 3D OHLC tensor, got shape {tuple(ohlc_price_tensors.shape)}"
+         assert ohlc_price_tensors.shape[1] == 2, f"Expected OHLC tensor with 2 rows (open/close), got {ohlc_price_tensors.shape[1]}"
+         assert ohlc_price_tensors.shape[2] == OHLC_SEQ_LEN, f"Expected OHLC seq len {OHLC_SEQ_LEN}, got {ohlc_price_tensors.shape[2]}"
+         assert collated_batch['ohlc_interval_ids'].shape[0] == ohlc_price_tensors.shape[0], "Interval ids must align with OHLC segments"
+         assert ohlc_price_tensors.dtype == dtype, f"OHLC tensor dtype {ohlc_price_tensors.dtype} != expected {dtype}"
+         print(f"Collator produced {ohlc_price_tensors.shape[0]} OHLC segment(s).")
+
+         # --- FIXED: Update assertions for event-specific data which is mostly empty for now ---
+         assert collated_batch['dest_wallet_indices'].shape == (B, L)
+         assert collated_batch['transfer_numerical_features'].shape == (B, L, 4)
+         assert collated_batch['trade_numerical_features'].shape == (B, L, 8)  # Corrected from 10
+         assert collated_batch['deployer_trade_numerical_features'].shape == (B, L, 8)  # Corrected from 10
+         assert collated_batch['smart_wallet_trade_numerical_features'].shape == (B, L, 8)  # Corrected from 10
+         assert collated_batch['pool_created_numerical_features'].shape == (B, L, 2)
+         assert collated_batch['liquidity_change_numerical_features'].shape == (B, L, 1)
+         assert collated_batch['fee_collected_numerical_features'].shape == (B, L, 1)
+         assert collated_batch['token_burn_numerical_features'].shape == (B, L, 2)
+         assert collated_batch['supply_lock_numerical_features'].shape == (B, L, 2)
+         assert collated_batch['onchain_snapshot_numerical_features'].shape == (B, L, 14)
+         assert collated_batch['trending_token_numerical_features'].shape == (B, L, 1)
+         assert collated_batch['boosted_token_numerical_features'].shape == (B, L, 2)
+         # assert len(collated_batch['holder_snapshot_raw_data']) == 1  # No holder snapshots yet
+         # assert len(collated_batch['textual_event_data']) == 8  # No textual events yet
+         assert collated_batch['dexboost_paid_numerical_features'].shape == (B, L, 2)
+         print("Collator correctly processed all event-specific numerical data into their respective tensors.")
+
+         # --- NEW: Comprehensive Debugging Output ---
+         print("\n--- Collated Batch Debug Output ---")
+         print(f"Batch Size: {B}, Max Sequence Length: {L}")
+
+         # Print shapes of key tensors
+         print("\n[Core Tensors]")
+         print(f"  event_type_ids: {collated_batch['event_type_ids'].shape}")
+         print(f"  attention_mask: {collated_batch['attention_mask'].shape}")
+         print(f"  timestamps_float: {collated_batch['timestamps_float'].shape}")
+
+         print("\n[Pointer Tensors]")
+         print(f"  wallet_indices: {collated_batch['wallet_indices'].shape}")
+         print(f"  token_indices: {collated_batch['token_indices'].shape}")
+
+         print("\n[Encoder Inputs]")
+         print(f"  embedding_pool: {collated_batch['embedding_pool'].shape}")
+         # --- FIXED: Check for a key that still exists after removing address embeddings ---
+         if collated_batch['token_encoder_inputs']['name_embed_indices'].numel() > 0:
+             print(f"  token_encoder_inputs contains {collated_batch['token_encoder_inputs']['name_embed_indices'].shape[0]} tokens.")
+         else:
+             print("  token_encoder_inputs is empty.")
+         if collated_batch['wallet_encoder_inputs']['profile_rows']:
+             print(f"  wallet_encoder_inputs contains {len(collated_batch['wallet_encoder_inputs']['profile_rows'])} wallets.")
+         else:
+             print("  wallet_encoder_inputs is empty.")
+
+         print("\n[Graph Links]")
+         if collated_batch['graph_updater_links']:
+             for link_name, data in collated_batch['graph_updater_links'].items():
+                 print(f"  - {link_name}: {data['edge_index'].shape[1]} edges")
+         else:
+             print("  No graph links in this batch.")
+         print("--- End Debug Output ---\n")
+
+         print("Embedding pool size:", collated_batch["embedding_pool"].shape[0])
+         print("Max name_emb_idx:", collated_batch["token_encoder_inputs"]["name_embed_indices"].max().item())
+
+         # 2. Model Forward Pass
+         with torch.no_grad():
+             model_outputs = model(collated_batch)
+         quantile_logits = model_outputs["quantile_logits"]
+         hidden_states = model_outputs["hidden_states"]
+         attention_mask = model_outputs["attention_mask"]
+         pooled_states = model_outputs["pooled_states"]
+         print("Model forward pass successful!")
+
+         # --- 7. Verify Output ---
+         print("\n--- Test Results ---")
+         D_MODEL = model.d_model
+
+         print(f"Final hidden_states shape: {hidden_states.shape}")
+         print(f"Final attention_mask shape: {attention_mask.shape}")
+
+         assert hidden_states.shape == (B, L, D_MODEL)
+         assert attention_mask.shape == (B, L)
+         assert hidden_states.dtype == dtype
+
+         print(f"Output mean (sanity check): {hidden_states.mean().item()}")
+         print(f"Pooled state shape: {pooled_states.shape}")
+         print(f"Quantile logits shape: {quantile_logits.shape}")
+
+         quantile_grid = quantile_logits.view(B, len(_test_horizons), len(_test_quantiles))
+         print("\n[Quantile Predictions]")
+         for b_idx in range(B):
+             print(f"  Sample {b_idx}:")
+             for h_idx, horizon in enumerate(_test_horizons):
+                 row = quantile_grid[b_idx, h_idx]
+                 print(f"    Horizon {horizon}s -> " + ", ".join(
+                     f"q={q:.2f}: {row[q_idx].item():.6f}"
+                     for q_idx, q in enumerate(_test_quantiles)
+                 ))
+
+         print("\n✅ **Test Passed!** Full ENCODING pipeline is working.")
+
+     except Exception as e:
+         print(f"\n❌ Error during pipeline test: {e}")
+         traceback.print_exc()
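The `quantile_logits.view(B, H, Q)` reshape in the script relies on the model emitting a horizon-major flat vector per sample, i.e. `flat[b, h * Q + q] == grid[b, h, q]`. A small numpy sketch of that identity (dimensions taken from the test configuration; the index arithmetic is the point, the values are placeholders):

```python
import numpy as np

# Test configuration used by the script above
B = 2
horizons = [30, 60, 120, 240, 420]
quantiles = [0.1, 0.5, 0.9]
H, Q = len(horizons), len(quantiles)

# Flat model output of shape (B, H * Q), filled with its own indices
flat = np.arange(B * H * Q, dtype=np.float64).reshape(B, H * Q)

# Horizon-major reshape: grid[b, h, q] == flat[b, h * Q + q]
grid = flat.reshape(B, H, Q)
assert grid[1, 2, 0] == flat[1, 2 * Q + 0]
print("reshape preserves horizon-major order")
```

Because the dataset writes its labels with the same `offset = h_idx * len(quantiles)` layout, the flat predictions and flat labels line up element-for-element without any permutation.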
link_graph.rs ADDED
@@ -0,0 +1,2275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use crate::aggregator::graph_schema::{
2
+ BundleTradeLink, BurnedLink, CoordinatedActivityLink, CopiedTradeLink, LockedSupplyLink,
3
+ MintedLink, ProvidedLiquidityLink, SnipedLink, TopTraderOfLink, TransferLink, WhaleOfLink,
4
+ };
5
+ use crate::handlers::constants::{
6
+ NATIVE_MINT, PROTOCOL_PUMPFUN_LAUNCHPAD, USD1_MINT, USDC_MINT, USDT_MINT,
7
+ };
8
+ use crate::types::{
9
+ BurnRow, EventPayload, EventType, LiquidityRow, MintRow, SupplyLockRow, TradeRow, TransferRow,
10
+ };
11
+ use anyhow::{Result, anyhow};
12
+ use chrono::Utc;
13
+ use clickhouse::{Client, Row};
14
+ use futures::stream::{self, StreamExt};
15
+ use itertools::Itertools;
16
+ use neo4rs::{BoltType, Graph, query};
17
+ use once_cell::sync::Lazy;
18
+ use serde::Deserialize;
19
+ use solana_sdk::native_token::LAMPORTS_PER_SOL;
20
+ use std::collections::{HashMap, HashSet, VecDeque};
21
+ use std::future::Future;
22
+ use std::str::FromStr;
23
+ use std::sync::Arc;
24
+ use std::sync::atomic::{AtomicUsize, Ordering};
25
+ use std::time::Duration;
26
+ use tokio::sync::{Mutex, mpsc};
27
+ use tokio::time::sleep;
28
+ use tokio::time::{Instant, MissedTickBehavior, interval};
29
+ use tokio::try_join;
30
+
31
+ fn decimals_for_quote(mint: &str) -> u8 {
32
+ if mint == NATIVE_MINT {
33
+ 9
34
+ } else if mint == USDC_MINT || mint == USDT_MINT || mint == USD1_MINT {
35
+ 6
36
+ } else {
37
+ 9 // default assumption if unknown
38
+ }
39
+ }
40
+
41
+ #[derive(Debug)]
42
+ struct LinkGraphConfig {
43
+ time_window_seconds: u32,
44
+ copied_trade_window_seconds: i64,
45
+ sniper_rank_threshold: u64,
46
+ whale_rank_threshold: u64,
47
+ min_top_trader_pnl: f32,
+ min_trade_total_usd: f64,
+ ath_price_threshold_usd: f64,
+ window_max_wait_ms: u64,
+ late_slack_ms: u64,
+ chunk_size_large: usize,
+ chunk_size_historical: usize,
+ chunk_size_mint_small: usize,
+ chunk_size_mint_large: usize,
+ chunk_size_token: usize,
+ trade_cache_max_entries: usize,
+ trade_cache_ttl_secs: u32,
+ trade_cache_max_recent: usize,
+ writer_channel_capacity: usize,
+ writer_max_batch_rows: usize,
+ writer_retry_attempts: u32,
+ writer_retry_backoff_ms: u64,
+ ath_fetch_chunk_size: usize,
+ ch_retry_attempts: u32,
+ ch_retry_backoff_ms: u64,
+ ch_fail_fast: bool,
+ }
+
+ static LINK_GRAPH_CONFIG: Lazy<LinkGraphConfig> = Lazy::new(|| LinkGraphConfig {
+ time_window_seconds: env_parse("LINK_GRAPH_TIME_WINDOW_SECONDS", 120_u32),
+ copied_trade_window_seconds: env_parse("LINK_GRAPH_COPIED_TRADE_WINDOW_SECONDS", 60_i64),
+ sniper_rank_threshold: env_parse("LINK_GRAPH_SNIPER_RANK_THRESHOLD", 45_u64),
+ whale_rank_threshold: env_parse("LINK_GRAPH_WHALE_RANK_THRESHOLD", 5_u64),
+ min_top_trader_pnl: env_parse("LINK_GRAPH_MIN_TOP_TRADER_PNL", 1.0_f32),
+ min_trade_total_usd: env_parse("LINK_GRAPH_MIN_TRADE_TOTAL_USD", 20.0_f64),
+ ath_price_threshold_usd: env_parse("LINK_GRAPH_ATH_PRICE_THRESHOLD_USD", 0.0002000_f64),
+ window_max_wait_ms: env_parse("LINK_GRAPH_WINDOW_MAX_WAIT_MS", 250_u64),
+ late_slack_ms: env_parse("LINK_GRAPH_LATE_SLACK_MS", 2000_u64),
+ chunk_size_large: env_parse("LINK_GRAPH_CHUNK_SIZE_LARGE", 3000_usize),
+ chunk_size_historical: env_parse("LINK_GRAPH_CHUNK_SIZE_HISTORICAL", 1000_usize),
+ chunk_size_mint_small: env_parse("LINK_GRAPH_CHUNK_SIZE_MINT_SMALL", 1500_usize),
+ chunk_size_mint_large: env_parse("LINK_GRAPH_CHUNK_SIZE_MINT_LARGE", 3000_usize),
+ chunk_size_token: env_parse("LINK_GRAPH_CHUNK_SIZE_TOKEN", 3000_usize),
+ trade_cache_max_entries: env_parse("LINK_GRAPH_TRADE_CACHE_MAX_ENTRIES", 1_000_000_usize),
+ trade_cache_ttl_secs: env_parse("LINK_GRAPH_TRADE_CACHE_TTL_SECS", 600_u32),
+ trade_cache_max_recent: env_parse("LINK_GRAPH_TRADE_CACHE_MAX_RECENT", 16_usize),
+ writer_channel_capacity: env_parse("LINK_GRAPH_WRITER_CHANNEL_CAPACITY", 5000_usize),
+ writer_max_batch_rows: env_parse("LINK_GRAPH_WRITER_MAX_BATCH_ROWS", 1000_usize),
+ writer_retry_attempts: env_parse("LINK_GRAPH_WRITER_RETRY_ATTEMPTS", 3_u32),
+ writer_retry_backoff_ms: env_parse("LINK_GRAPH_WRITER_RETRY_BACKOFF_MS", 250_u64),
+ ath_fetch_chunk_size: env_parse("LINK_GRAPH_ATH_FETCH_CHUNK_SIZE", 500_usize),
+ ch_retry_attempts: env_parse("LINK_GRAPH_CH_RETRY_ATTEMPTS", 3_u32),
+ ch_retry_backoff_ms: env_parse("LINK_GRAPH_CH_RETRY_BACKOFF_MS", 500_u64),
+ ch_fail_fast: env_parse("LINK_GRAPH_CH_FAIL_FAST", true),
+ });
+
+ fn env_parse<T: FromStr>(key: &str, default: T) -> T {
+ std::env::var(key)
+ .ok()
+ .and_then(|v| v.parse::<T>().ok())
+ .unwrap_or(default)
+ }
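
The `env_parse` helper above gives every `LINK_GRAPH_*` knob an environment-variable override with a hard-coded fallback: unset or malformed values silently fall back to the default. A minimal standalone sketch of the same pattern (the `DEMO_*` variable name is illustrative, not from the service):

```rust
use std::str::FromStr;

// Parse an environment variable into any `FromStr` type,
// falling back to `default` when the variable is unset or malformed.
fn env_parse<T: FromStr>(key: &str, default: T) -> T {
    std::env::var(key)
        .ok()
        .and_then(|v| v.parse::<T>().ok())
        .unwrap_or(default)
}

fn main() {
    // Unset variable: the default wins.
    std::env::remove_var("DEMO_WINDOW_MAX_WAIT_MS");
    assert_eq!(env_parse("DEMO_WINDOW_MAX_WAIT_MS", 250_u64), 250);

    // Valid override: the parsed value wins.
    std::env::set_var("DEMO_WINDOW_MAX_WAIT_MS", "1000");
    assert_eq!(env_parse("DEMO_WINDOW_MAX_WAIT_MS", 250_u64), 1000);

    // Malformed override: falls back instead of panicking.
    std::env::set_var("DEMO_WINDOW_MAX_WAIT_MS", "oops");
    assert_eq!(env_parse("DEMO_WINDOW_MAX_WAIT_MS", 250_u64), 250);
}
```

One consequence of this design: a typo in a variable's value is indistinguishable from leaving it unset, so misconfiguration is silent rather than fatal.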
+
+ #[derive(Row, Deserialize, Clone)]
+ struct FullHistTrade {
+ maker: String,
+ base_address: String,
+ timestamp: u32,
+ signature: String,
+ trade_type: u8,
+ total_usd: f64,
+ slippage: f32,
+ }
+
+ enum FollowerLink {
+ Copied(CopiedTradeLink),
+ Coordinated(CoordinatedActivityLink),
+ }
+
+ pub struct LinkGraph {
+ db_client: Client,
+ neo4j_client: Arc<Graph>,
+ rx: mpsc::Receiver<EventPayload>,
+ link_graph_depth: Arc<AtomicUsize>,
+ write_lock: Mutex<()>,
+ trade_cache: Arc<Mutex<HashMap<(String, String), CachedPairState>>>,
+ write_sender: mpsc::Sender<WriteJob>,
+ writer_depth: Arc<AtomicUsize>,
+ }
+
+ // Global Neo4j write lock to serialize batches across workers and avoid deadlocks.
+ static NEO4J_WRITE_LOCK: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));
+
+ #[derive(Row, Deserialize, Debug)]
+ struct Ping {
+ alive: u8,
+ }
+ #[derive(Row, Deserialize, Debug)]
+ struct CountResult {
+ count: u64,
+ }
+
+ #[derive(Clone, Debug)]
+ struct CachedTrade {
+ maker: String,
+ base_address: String,
+ timestamp: u32,
+ signature: String,
+ trade_type: u8,
+ total_usd: f64,
+ slippage: f32,
+ }
+
+ #[derive(Debug)]
+ struct CachedPairState {
+ first_buy: Option<CachedTrade>,
+ first_sell: Option<CachedTrade>,
+ recent: VecDeque<CachedTrade>,
+ last_seen: u32,
+ }
+
+ #[derive(Debug)]
+ pub struct WriteJob {
+ query: String,
+ params: Vec<HashMap<String, BoltType>>,
+ }
+
+ impl LinkGraph {
+ pub async fn new(
+ db_client: Client,
+ neo4j_client: Arc<Graph>,
+ rx: mpsc::Receiver<EventPayload>,
+ link_graph_depth: Arc<AtomicUsize>,
+ write_sender: mpsc::Sender<WriteJob>,
+ writer_depth: Arc<AtomicUsize>,
+ ) -> Result<Self> {
+ let _: Ping = db_client.query("SELECT 1 as alive").fetch_one().await?;
+ neo4j_client.run(query("MATCH (n) RETURN count(n)")).await?;
+ println!("[LinkGraph] ✔️ Connected to ClickHouse, Neo4j. Listening on channel.");
+ Ok(Self {
+ db_client,
+ neo4j_client,
+ rx,
+ link_graph_depth,
+ write_lock: Mutex::new(()),
+ trade_cache: Arc::new(Mutex::new(HashMap::new())),
+ write_sender,
+ writer_depth,
+ })
+ }
+
+ async fn with_ch_retry<T, F, Fut>(&self, mut op: F, label: &str) -> Result<T>
+ where
+ F: FnMut() -> Fut,
+ Fut: Future<Output = Result<T>>,
+ {
+ let cfg = &*LINK_GRAPH_CONFIG;
+ let mut attempts = 0;
+ loop {
+ attempts += 1;
+ match op().await {
+ Ok(res) => return Ok(res),
+ Err(e) => {
+ if attempts >= cfg.ch_retry_attempts {
+ return Err(anyhow!(
+ "[LinkGraph] {} failed after {} attempts: {}",
+ label,
+ attempts,
+ e
+ ));
+ }
+ let backoff = cfg.ch_retry_backoff_ms * attempts as u64;
+ eprintln!(
+ "[LinkGraph] ⚠️ {} retry {}/{} after {}ms: {}",
+ label, attempts, cfg.ch_retry_attempts, backoff, e
+ );
+ sleep(Duration::from_millis(backoff)).await;
+ }
+ }
+ }
+ }
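
`with_ch_retry` backs off linearly: the n-th failed attempt sleeps `ch_retry_backoff_ms * n`, and the final attempt returns the error instead of sleeping. With the defaults (500 ms base, 3 attempts) the delays are 500 ms then 1000 ms. A sketch of just the schedule arithmetic, separated from the async machinery:

```rust
// Linear backoff as used by `with_ch_retry`: the n-th failed attempt
// sleeps `base_ms * n`; the last attempt errors out without sleeping,
// so only attempts 1..max_attempts produce a delay.
fn backoff_schedule(base_ms: u64, max_attempts: u32) -> Vec<u64> {
    (1..max_attempts as u64).map(|n| base_ms * n).collect()
}

fn main() {
    // Defaults: LINK_GRAPH_CH_RETRY_BACKOFF_MS = 500, attempts = 3.
    assert_eq!(backoff_schedule(500, 3), vec![500, 1000]);
    println!("backoff delays: {:?}", backoff_schedule(500, 3));
}
```

Linear (rather than exponential) growth keeps worst-case stall time bounded and predictable for the small retry counts used here.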
+
+ pub async fn run(&mut self) -> Result<()> {
+ let cfg = &*LINK_GRAPH_CONFIG;
+ let mut message_buffer: Vec<EventPayload> = Vec::new();
+ let mut current_window_start: Option<u32> = None;
+ let mut window_opened_at: Option<Instant> = None;
+ let mut flush_check = interval(Duration::from_millis(cfg.window_max_wait_ms.max(50)));
+ flush_check.set_missed_tick_behavior(MissedTickBehavior::Delay);
+ let late_slack_secs: u32 = (cfg.late_slack_ms / 1000) as u32;
+
+ loop {
+ tokio::select! {
+ maybe_payload = self.rx.recv() => {
+ match maybe_payload {
+ Some(payload) => {
+ // one item left the channel
+ self.link_graph_depth.fetch_sub(1, Ordering::Relaxed);
+ if current_window_start.is_none() {
+ current_window_start = Some(payload.timestamp);
+ window_opened_at = Some(Instant::now());
+ }
+
+ let window_end = current_window_start.unwrap() + cfg.time_window_seconds;
+ if payload.timestamp <= window_end + late_slack_secs {
+ message_buffer.push(payload);
+ } else {
+ if !message_buffer.is_empty() {
+ message_buffer.sort_by_key(|p| p.timestamp);
+ let batch = std::mem::take(&mut message_buffer);
+ if let Err(e) = self.process_batch_with_retry(batch).await {
+ eprintln!("[LinkGraph] 🔴 Fatal processing window: {}", e);
+ std::process::exit(1);
+ }
+ }
+ current_window_start = Some(payload.timestamp);
+ window_opened_at = Some(Instant::now());
+ message_buffer.push(payload);
+ }
+ }
+ None => {
+ eprintln!("[LinkGraph] 🔴 Input channel closed. Exiting.");
+ if !message_buffer.is_empty() {
+ message_buffer.sort_by_key(|p| p.timestamp);
+ let batch = std::mem::take(&mut message_buffer);
+ if let Err(e) = self.process_batch_with_retry(batch).await {
+ eprintln!("[LinkGraph] 🔴 Fatal processing final window: {}", e);
+ }
+ }
+ // Fatal: the producer is gone. Exit so it's obvious.
+ std::process::exit(1);
+ }
+ }
+ }
+ _ = flush_check.tick() => {
+ if !message_buffer.is_empty() {
+ if let Some(opened) = window_opened_at {
+ if opened.elapsed() >= Duration::from_millis(cfg.window_max_wait_ms) {
+ message_buffer.sort_by_key(|p| p.timestamp);
+ let batch = std::mem::take(&mut message_buffer);
+ if let Err(e) = self.process_batch_with_retry(batch).await {
+ eprintln!("[LinkGraph] 🔴 Fatal processing timed window: {}", e);
+ std::process::exit(1);
+ }
+ current_window_start = None;
+ window_opened_at = None;
+ }
+ }
+ }
+ }
+ }
+ }
+
+ Ok(())
+ }
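
The windowing rule in `run` admits an event into the current batch as long as its timestamp is no later than `window_start + time_window_seconds + late_slack`; anything later flushes the batch and opens a new window at that event's timestamp. A sketch of just the admission predicate (values below are the defaults: 120 s window, 2000 ms slack truncated to 2 s):

```rust
// Window admission rule from `run`: an event joins the current window
// if its timestamp is at most window_start + window + slack seconds.
fn admits(window_start: u32, window_secs: u32, slack_secs: u32, ts: u32) -> bool {
    ts <= window_start + window_secs + slack_secs
}

fn main() {
    // window_end = 1000 + 120 = 1120; slack extends admission to 1122.
    assert!(admits(1_000, 120, 2, 1_122)); // inside window + slack
    assert!(!admits(1_000, 120, 2, 1_123)); // late: triggers a flush
}
```

Note that the integer division `cfg.late_slack_ms / 1000` truncates, so a slack of 2500 ms behaves as 2 s.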
+
+ async fn process_time_window(&self, payloads: &[EventPayload]) -> Result<()> {
+ let cfg = &*LINK_GRAPH_CONFIG;
+ let mut unique_wallets = HashSet::new();
+ let mut unique_tokens = HashSet::new();
+ let mut trades = Vec::new();
+ let mut transfers = Vec::new();
+ let mut mints = Vec::new();
+ let mut supply_locks = Vec::new();
+ let mut burns = Vec::new();
+ let mut liquidity_events = Vec::new();
+
+ for payload in payloads {
+ match &payload.event {
+ EventType::Trade(trade) => {
+ // Skip dust trades to reduce noise in downstream links/datasets
+ if trade.total_usd >= cfg.min_trade_total_usd {
+ unique_wallets.insert(trade.maker.clone());
+ unique_tokens.insert(trade.base_address.clone());
+ trades.push(trade.clone());
+ }
+ }
+ EventType::Transfer(transfer) => {
+ unique_wallets.insert(transfer.source.clone());
+ unique_wallets.insert(transfer.destination.clone());
+ transfers.push(transfer.clone());
+ }
+ EventType::Mint(mint) => {
+ unique_wallets.insert(mint.creator_address.clone());
+ unique_tokens.insert(mint.mint_address.clone());
+ mints.push(mint.clone());
+ }
+ EventType::SupplyLock(lock) => {
+ unique_wallets.insert(lock.sender.clone());
+ unique_wallets.insert(lock.recipient.clone());
+ unique_tokens.insert(lock.mint_address.clone());
+ supply_locks.push(lock.clone());
+ }
+ EventType::Burn(burn) => {
+ unique_wallets.insert(burn.source.clone());
+ unique_tokens.insert(burn.mint_address.clone());
+ burns.push(burn.clone());
+ }
+ EventType::Liquidity(liquidity) => {
+ if liquidity.change_type == 0 {
+ // 0 = Add Liquidity
+ unique_wallets.insert(liquidity.lp_provider.clone());
+ liquidity_events.push(liquidity.clone());
+ }
+ }
+ _ => {}
+ }
+ }
+
+ // Run link detection in parallel; writes remain serialized by the global Neo4j lock.
+ let parallel_start = Instant::now();
+ try_join!(
+ self.process_mints(&mints, &trades),
+ self.process_transfers_and_funding(&transfers),
+ self.process_supply_locks(&supply_locks),
+ self.process_burns(&burns),
+ self.process_liquidity_events(&liquidity_events),
+ self.process_trade_patterns(&trades, &mints),
+ )?;
+ println!(
+ "[LinkGraph] [TimeWindow] Parallel link processing finished in: {:?}",
+ parallel_start.elapsed()
+ );
+ Ok(())
+ }
+
+ async fn process_batch(&self, mut payloads: Vec<EventPayload>) -> Result<()> {
+ if payloads.is_empty() {
+ return Ok(());
+ }
+
+ // Payloads are already a complete time-window. We just need to sort them.
+ payloads.sort_by_key(|p| p.timestamp);
+
+ // Process the entire batch as a single logical unit with a per-worker write lock.
+ let _guard = self.write_lock.lock().await;
+ self.process_time_window(&payloads).await?;
+
+ println!(
+ "[LinkGraph] Finished processing batch of {} events.",
+ payloads.len()
+ );
+ Ok(())
+ }
+
+ async fn process_batch_with_retry(&self, payloads: Vec<EventPayload>) -> Result<()> {
+ // Serialize across all workers to avoid Neo4j deadlocks.
+ let _global_lock = NEO4J_WRITE_LOCK.lock().await;
+ let mut attempts = 0;
+ let max_retries = 3;
+ loop {
+ match self.process_batch(payloads.clone()).await {
+ Ok(_) => return Ok(()),
+ Err(e) => {
+ let err_str = e.to_string();
+ if err_str.contains("DeadlockDetected") && attempts < max_retries {
+ attempts += 1;
+ let backoff_ms = 200 * attempts;
+ eprintln!(
+ "[LinkGraph] ⚠️ Deadlock detected, retrying {}/{} after {}ms",
+ attempts, max_retries, backoff_ms
+ );
+ sleep(Duration::from_millis(backoff_ms as u64)).await;
+ continue;
+ } else {
+ return Err(e);
+ }
+ }
+ }
+ }
+ }
+
+ // --- Main Logic for Pattern Detection ---
+ fn cached_trade_from_trade(trade: &TradeRow) -> CachedTrade {
+ CachedTrade {
+ maker: trade.maker.clone(),
+ base_address: trade.base_address.clone(),
+ timestamp: trade.timestamp,
+ signature: trade.signature.clone(),
+ trade_type: trade.trade_type,
+ total_usd: trade.total_usd,
+ slippage: trade.slippage,
+ }
+ }
+
+ async fn update_trade_cache(&self, trades: &[&TradeRow]) -> Result<()> {
+ if trades.is_empty() {
+ return Ok(());
+ }
+ let cfg = &*LINK_GRAPH_CONFIG;
+ let now_ts = trades.iter().map(|t| t.timestamp).max().unwrap_or(0);
+ let cutoff = now_ts.saturating_sub(cfg.trade_cache_ttl_secs);
+
+ let mut cache = self.trade_cache.lock().await;
+ cache.retain(|_, state| state.last_seen >= cutoff);
+
+ for trade in trades {
+ let key = (trade.maker.clone(), trade.base_address.clone());
+ let entry = cache.entry(key).or_insert_with(|| CachedPairState {
+ first_buy: None,
+ first_sell: None,
+ recent: VecDeque::new(),
+ last_seen: 0,
+ });
+
+ entry.last_seen = entry.last_seen.max(trade.timestamp);
+
+ let ct = Self::cached_trade_from_trade(trade);
+ if trade.trade_type == 0 {
+ if entry
+ .first_buy
+ .as_ref()
+ .map_or(true, |b| ct.timestamp < b.timestamp)
+ {
+ entry.first_buy = Some(ct.clone());
+ }
+ } else if trade.trade_type == 1 {
+ if entry
+ .first_sell
+ .as_ref()
+ .map_or(true, |s| ct.timestamp < s.timestamp)
+ {
+ entry.first_sell = Some(ct.clone());
+ }
+ }
+
+ entry.recent.push_back(ct);
+ while entry.recent.len() > cfg.trade_cache_max_recent {
+ entry.recent.pop_front();
+ }
+ while let Some(front) = entry.recent.front() {
+ if front.timestamp + cfg.trade_cache_ttl_secs < now_ts {
+ entry.recent.pop_front();
+ } else {
+ break;
+ }
+ }
+ }
+
+ if cache.len() > cfg.trade_cache_max_entries {
+ let mut entries: Vec<_> = cache
+ .iter()
+ .map(|(k, v)| (k.clone(), v.last_seen))
+ .collect();
+ entries.sort_by_key(|(_, ts)| *ts);
+ let to_drop = entries.len().saturating_sub(cfg.trade_cache_max_entries);
+ for (key, _) in entries.into_iter().take(to_drop) {
+ cache.remove(&key);
+ }
+ }
+ Ok(())
+ }
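
When the trade cache exceeds `trade_cache_max_entries`, `update_trade_cache` evicts the oldest pairs by `last_seen` until it is back under the cap. A sketch of that eviction step in isolation, with the pair state simplified to just a `last_seen` timestamp (the real value is a `CachedPairState`):

```rust
use std::collections::HashMap;

// Eviction used by `update_trade_cache` when the map outgrows its cap:
// sort keys by `last_seen` ascending and drop the oldest until under cap.
fn evict_oldest(cache: &mut HashMap<String, u32>, max_entries: usize) {
    if cache.len() <= max_entries {
        return;
    }
    let mut entries: Vec<_> = cache
        .iter()
        .map(|(k, ts)| (k.clone(), *ts))
        .collect();
    entries.sort_by_key(|(_, ts)| *ts);
    let to_drop = entries.len().saturating_sub(max_entries);
    for (key, _) in entries.into_iter().take(to_drop) {
        cache.remove(&key);
    }
}

fn main() {
    let mut cache = HashMap::new();
    cache.insert("stale".to_string(), 100_u32);
    cache.insert("warm".to_string(), 200);
    cache.insert("fresh".to_string(), 300);
    evict_oldest(&mut cache, 2);
    assert!(!cache.contains_key("stale"));
    assert_eq!(cache.len(), 2);
}
```

This is an O(n log n) sweep over the whole map; it only runs once the cap (1,000,000 entries by default) is exceeded, so the cost is amortized across many windows.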
+
+ async fn build_histories_from_cache(
+ &self,
+ pairs: &[(String, String)],
+ ) -> Result<HashMap<(String, String), Vec<FullHistTrade>>> {
+ let mut map = HashMap::new();
+ let cache = self.trade_cache.lock().await;
+ for pair in pairs {
+ if let Some(state) = cache.get(pair) {
+ let mut collected = Vec::new();
+ if let Some(b) = &state.first_buy {
+ collected.push(Self::cached_to_full(b));
+ }
+ if let Some(s) = &state.first_sell {
+ collected.push(Self::cached_to_full(s));
+ }
+ for t in state.recent.iter() {
+ collected.push(Self::cached_to_full(t));
+ }
+
+ if !collected.is_empty() {
+ collected.sort_by_key(|t| t.timestamp);
+ collected.dedup_by(|a, b| a.signature == b.signature);
+ map.insert(pair.clone(), collected);
+ }
+ }
+ }
+ Ok(map)
+ }
+
+ fn cached_to_full(ct: &CachedTrade) -> FullHistTrade {
+ FullHistTrade {
+ maker: ct.maker.clone(),
+ base_address: ct.base_address.clone(),
+ timestamp: ct.timestamp,
+ signature: ct.signature.clone(),
+ trade_type: ct.trade_type,
+ total_usd: ct.total_usd,
+ slippage: ct.slippage,
+ }
+ }
+
+ pub async fn writer_task(
+ mut rx: mpsc::Receiver<WriteJob>,
+ neo4j_client: Arc<Graph>,
+ writer_depth: Arc<AtomicUsize>,
+ ) {
+ let cfg = &*LINK_GRAPH_CONFIG;
+ while let Some(job) = rx.recv().await {
+ writer_depth.fetch_sub(1, Ordering::Relaxed);
+ let batches = job
+ .params
+ .chunks(cfg.writer_max_batch_rows.max(1))
+ .map(|chunk| chunk.to_vec())
+ .collect::<Vec<_>>();
+
+ for (idx, params) in batches.iter().enumerate() {
+ let q = query(&job.query).param("x", params.clone());
+ let mut attempts = 0;
+ loop {
+ let start = Instant::now();
+ match neo4j_client.run(q.clone()).await {
+ Ok(_) => {
+ println!(
+ "[LinkGraph] [Writer] ✅ wrote {} rows (chunk {}/{}) in {:?}",
+ params.len(),
+ idx + 1,
+ batches.len(),
+ start.elapsed()
+ );
+ break;
+ }
+ Err(e) => {
+ let msg = e.to_string();
+ attempts += 1;
+ if msg.contains("DeadlockDetected")
+ && attempts <= cfg.writer_retry_attempts
+ {
+ let backoff = cfg.writer_retry_backoff_ms * attempts as u64;
+ eprintln!(
+ "[LinkGraph] [Writer] ⚠️ deadlock, retry {}/{} after {}ms: {}",
+ attempts, cfg.writer_retry_attempts, backoff, msg
+ );
+ sleep(Duration::from_millis(backoff)).await;
+ continue;
+ } else {
+ eprintln!(
+ "[LinkGraph] 🔴 Writer fatal after {} attempts: {}",
+ attempts, msg
+ );
+ std::process::exit(1);
+ }
+ }
+ }
+ }
+ }
+ }
+ eprintln!("[LinkGraph] 🔴 Writer channel closed.");
+ std::process::exit(1);
+ }
+
+ async fn enqueue_write(
+ &self,
+ cypher: &str,
+ params: Vec<HashMap<String, BoltType>>,
+ ) -> Result<()> {
+ let job = WriteJob {
+ query: cypher.to_string(),
+ params,
+ };
+ self.write_sender
+ .send(job)
+ .await
+ .map_err(|e| anyhow!("[LinkGraph] Failed to enqueue write: {}", e))?;
+ self.writer_depth.fetch_add(1, Ordering::Relaxed);
+ Ok(())
+ }
+
+ async fn process_mints(
+ &self,
+ mints: &[MintRow],
+ all_trades_in_batch: &[TradeRow],
+ ) -> Result<()> {
+ let start = Instant::now();
+ if mints.is_empty() {
+ return Ok(());
+ }
+ let mut links = Vec::new();
+
+ for mint in mints {
+ let dev_buy = all_trades_in_batch.iter().find(
+ |t| {
+ t.maker == mint.creator_address
+ && t.base_address == mint.mint_address
+ && t.trade_type == 0
+ }, // 0 = Buy
+ );
+ let buy_amount_decimals = dev_buy.map_or(0.0, |t| {
+ let quote_decimals = decimals_for_quote(&t.quote_address);
+ t.quote_amount as f64 / 10f64.powi(quote_decimals as i32)
+ });
+ links.push(MintedLink {
+ signature: mint.signature.clone(),
+ timestamp: mint.timestamp as i64,
+ buy_amount: buy_amount_decimals,
+ });
+ }
+ self.write_minted_links(&links, mints).await?;
+ println!(
+ "[LinkGraph] [Profile] process_mints: {} mints in {:?}",
+ mints.len(),
+ start.elapsed()
+ );
+ Ok(())
+ }
+
+ async fn process_supply_locks(&self, locks: &[SupplyLockRow]) -> Result<()> {
+ let start = Instant::now();
+ if locks.is_empty() {
+ return Ok(());
+ }
+ let links: Vec<_> = locks
+ .iter()
+ .map(|l| LockedSupplyLink {
+ signature: l.signature.clone(),
+ amount: l.total_locked_amount as f64,
+ timestamp: l.timestamp as i64,
+ unlock_timestamp: l.final_unlock_timestamp,
+ })
+ .collect();
+ self.write_locked_supply_links(&links, locks).await?;
+ println!(
+ "[LinkGraph] [Profile] process_supply_locks: {} locks in {:?}",
+ locks.len(),
+ start.elapsed()
+ );
+ Ok(())
+ }
+
+ async fn process_burns(&self, burns: &[BurnRow]) -> Result<()> {
+ let start = Instant::now();
+ if burns.is_empty() {
+ return Ok(());
+ }
+ let links: Vec<_> = burns
+ .iter()
+ .map(|b| BurnedLink {
+ signature: b.signature.clone(),
+ amount: b.amount_decimal,
+ timestamp: b.timestamp as i64,
+ })
+ .collect();
+ self.write_burned_links(&links, burns).await?;
+ println!(
+ "[LinkGraph] [Profile] process_burns: {} burns in {:?}",
+ burns.len(),
+ start.elapsed()
+ );
+ Ok(())
+ }
+
+ async fn process_transfers_and_funding(&self, transfers: &[TransferRow]) -> Result<()> {
+ let start = Instant::now();
+ if transfers.is_empty() {
+ return Ok(());
+ }
+
+ // Directly map every TransferRow to a TransferLink without any extra logic.
+ let transfer_links: Vec<TransferLink> = transfers
+ .iter()
+ .map(|transfer| TransferLink {
+ source: transfer.source.clone(),
+ destination: transfer.destination.clone(),
+ signature: transfer.signature.clone(),
+ mint: transfer.mint_address.clone(),
+ timestamp: transfer.timestamp as i64,
+ amount: transfer.amount_decimal,
+ })
+ .collect();
+
+ self.write_transfer_links(&transfer_links).await?;
+ println!(
+ "[LinkGraph] [Profile] process_transfers: {} transfers in {:?}",
+ transfers.len(),
+ start.elapsed()
+ );
+ Ok(())
+ }
+
+ async fn process_trade_patterns(
+ &self,
+ trades: &[TradeRow],
+ mints_in_batch: &[MintRow],
+ ) -> Result<()> {
+ let start = Instant::now();
+ if trades.is_empty() {
+ return Ok(());
+ }
+
+ let creator_map: HashMap<String, String> = mints_in_batch
+ .iter()
+ .map(|m| (m.mint_address.clone(), m.creator_address.clone()))
+ .collect();
+
+ let mut processed_pairs = HashSet::new();
+
+ let bundle_links = self.detect_bundle_trades(trades, &mut processed_pairs);
+ if !bundle_links.is_empty() {
+ self.write_bundle_trade_links(&bundle_links).await?;
+ }
+
+ let follower_links = self
+ .detect_follower_activity(trades, &mut processed_pairs)
+ .await?;
+ if !follower_links.is_empty() {
+ let mut copied_links = Vec::new();
+ let mut coordinated_links = Vec::new();
+ for link in follower_links {
+ match link {
+ FollowerLink::Copied(l) => copied_links.push(l),
+ FollowerLink::Coordinated(l) => coordinated_links.push(l),
+ }
+ }
+ if !copied_links.is_empty() {
+ self.write_copied_trade_links(&copied_links).await?;
+ }
+ if !coordinated_links.is_empty() {
+ self.write_coordinated_activity_links(&coordinated_links)
+ .await?;
+ }
+ }
+
+ self.detect_and_write_snipes(trades, creator_map).await?;
+ self.detect_and_write_whale_links(trades).await?;
+ self.detect_and_write_top_trader_links(trades).await?;
+
+ println!(
+ "[LinkGraph] [Profile] process_trade_patterns: {} trades in {:?}",
+ trades.len(),
+ start.elapsed()
+ );
+ Ok(())
+ }
+
+ async fn detect_and_write_snipes(
+ &self,
+ trades: &[TradeRow],
+ creator_map: HashMap<String, String>,
+ ) -> Result<()> {
+ let start = Instant::now();
+ let cfg = &*LINK_GRAPH_CONFIG;
+ let mut links: Vec<SnipedLink> = Vec::new();
+ let mut snipers_map: HashMap<String, (String, String)> = HashMap::new();
+ // Limit sniper detection to Pump.fun launchpad trades only.
+ let pump_trades: Vec<&TradeRow> = trades
+ .iter()
+ .filter(|t| t.protocol == PROTOCOL_PUMPFUN_LAUNCHPAD)
+ .collect();
+ if pump_trades.is_empty() {
+ return Ok(());
+ }
+
+ let unique_mints: HashSet<String> =
+ pump_trades.iter().map(|t| t.base_address.clone()).collect();
+ if unique_mints.is_empty() {
+ return Ok(());
+ }
+
+ // Pre-flight check: fetch current holder counts for the candidate mints.
+ #[derive(Row, Deserialize, Debug)]
+ struct TokenHolderInfo {
+ token_address: String,
+ unique_holders: u32,
+ }
+
+ let holder_check_query = "
+ SELECT token_address, unique_holders
+ FROM token_metrics_latest
+ WHERE token_address IN ?
+ ORDER BY token_address, updated_at DESC
+ LIMIT 1 BY token_address
+ ";
+ let mut holder_infos: Vec<TokenHolderInfo> = Vec::new();
+ let unique_mints_vec: Vec<_> = unique_mints.iter().cloned().collect();
+
+ for chunk in unique_mints_vec.chunks(cfg.chunk_size_large) {
+ let mut chunk_results = self
+ .with_ch_retry(
+ || async {
+ self.db_client
+ .query(holder_check_query)
+ .bind(chunk)
+ .fetch_all()
+ .await
+ .map_err(anyhow::Error::from)
+ },
+ "Snipes-HolderCheck chunk",
+ )
+ .await?;
+ holder_infos.append(&mut chunk_results);
+ }
+
+ let token_holder_map: HashMap<String, u32> = holder_infos
+ .into_iter()
+ .map(|t| (t.token_address, t.unique_holders))
+ .collect();
+
+ #[derive(Row, Deserialize, Clone, Debug)]
+ struct SniperInfo {
+ maker: String,
+ first_sig: String,
+ first_total: f64,
+ first_ts: u32,
+ }
+
+ #[derive(Row, Deserialize, Debug)]
+ struct TokenCreator {
+ creator_address: String,
+ }
+
+ // OPTIMIZATION: Parallelize the database queries for each mint.
+ let query_futures = unique_mints
+ .into_iter()
+ .filter(|mint| {
+ // Pre-filter mints that are too established
+ let holder_count = token_holder_map.get(mint).cloned().unwrap_or(0);
+ holder_count <= cfg.sniper_rank_threshold as u32
+ })
+ .map(|mint| {
+ let db_client = self.db_client.clone();
+ let creator_map_clone = creator_map.clone();
+ // Create an async block (a future) for each query
+ async move {
+ let snipers_query = "
+ SELECT maker,
+ argMin(signature, timestamp) as first_sig,
+ argMin(total, timestamp) as first_total,
+ min(toUInt32(timestamp)) as first_ts
+ FROM trades WHERE base_address = ? AND trade_type = 0
+ GROUP BY maker ORDER BY min(timestamp) ASC LIMIT ?
+ ";
+
+ let result = db_client
+ .query(snipers_query)
+ .bind(mint.clone()) // binds base_address
+ .bind(cfg.sniper_rank_threshold) // binds LIMIT
+ .fetch_all::<SniperInfo>()
+ .await
+ .map_err(|e| {
+ anyhow!(
+ "[SNIPER_FAIL]: Sniper fetch for mint '{}' failed. Error: {}",
+ mint,
+ e
+ )
+ });
+
+ (mint, result)
+ }
+ });
+
+ // Execute the futures concurrently with a limit of 20 at a time.
+ let results = stream::iter(query_futures)
+ .buffer_unordered(20) // CONCURRENCY LIMIT
+ .collect::<Vec<_>>()
+ .await;
+
+ // Process the results after they have all completed
+ for (mint, result) in results {
+ match result {
+ Ok(sniper_candidates) => {
+ for (i, sniper) in sniper_candidates.iter().enumerate() {
+ links.push(SnipedLink {
+ timestamp: sniper.first_ts as i64,
+ signature: sniper.first_sig.clone(),
+ rank: (i + 1) as i64,
+ sniped_amount: sniper.first_total,
+ });
+ snipers_map.insert(
+ sniper.first_sig.clone(),
+ (sniper.maker.clone(), mint.clone()),
+ );
+ }
+ }
+ Err(e) => eprintln!("[Snipers] Error processing mint {}: {}", mint, e),
+ }
+ }
+
+ if !links.is_empty() {
+ self.write_sniped_links(&links, &snipers_map).await?;
+ }
+ println!(
+ "[LinkGraph] [Profile] detect_and_write_snipes: {} links in {:?}",
+ links.len(),
+ start.elapsed()
+ );
+ Ok(())
+ }
+
+ fn detect_bundle_trades(
+ &self,
+ trades: &[TradeRow],
+ processed_pairs: &mut HashSet<(String, String)>,
+ ) -> Vec<BundleTradeLink> {
+ let mut links = Vec::new();
+ let trades_by_slot_mint = trades
+ .iter()
+ .into_group_map_by(|t| (t.slot, t.base_address.clone()));
+
+ for ((slot, mint), trades_in_bundle) in trades_by_slot_mint {
+ let unique_makers: Vec<_> =
+ trades_in_bundle.iter().map(|t| &t.maker).unique().collect();
+ if unique_makers.len() <= 1 {
+ continue;
+ }
+
+ // Leader Election: Find the trade with the max `quote_amount`.
+ // Includes a deterministic tie-breaker using the wallet address.
+ let leader_trade = match trades_in_bundle.iter().max_by(|a, b| {
+ match a.quote_amount.cmp(&b.quote_amount) {
+ std::cmp::Ordering::Equal => b.maker.cmp(&a.maker),
+ other => other,
+ }
+ }) {
+ Some(trade) => trade,
+ None => continue,
+ };
+ let leader_wallet = &leader_trade.maker;
+
+ let all_bundle_signatures: Vec<String> = trades_in_bundle
+ .iter()
+ .map(|t| t.signature.clone())
+ .collect();
+
+ for follower_trade in trades_in_bundle
+ .iter()
+ .filter(|t| &t.maker != leader_wallet)
+ {
+ let follower_wallet = &follower_trade.maker;
+
+ let mut combo_sorted = vec![leader_wallet.clone(), follower_wallet.clone()];
+ combo_sorted.sort();
+ let pair_key = (combo_sorted[0].clone(), combo_sorted[1].clone());
+
+ // Populate the processed_pairs set and create the link.
+ if processed_pairs.insert(pair_key) {
+ links.push(BundleTradeLink {
+ signatures: all_bundle_signatures.clone(),
+ wallet_a: leader_wallet.clone(),
+ wallet_b: follower_wallet.clone(),
+ mint: mint.clone(),
+ slot: slot as i64,
+ timestamp: leader_trade.timestamp as i64,
+ });
+ }
+ }
+ }
+ links
+ }
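
The leader election in `detect_bundle_trades` picks the trade with the largest `quote_amount`; on a tie, `b.maker.cmp(&a.maker)` inverts the wallet comparison so the lexicographically smaller maker wins deterministically. A sketch with trades reduced to `(quote_amount, maker)` pairs (the tuple shape is illustrative):

```rust
// Leader election from `detect_bundle_trades`: largest quote_amount
// wins; ties break deterministically toward the lexicographically
// smaller wallet, via the inverted comparison on the maker.
fn elect_leader<'a>(trades: &[(u64, &'a str)]) -> Option<(u64, &'a str)> {
    trades
        .iter()
        .max_by(|a, b| match a.0.cmp(&b.0) {
            std::cmp::Ordering::Equal => b.1.cmp(a.1),
            other => other,
        })
        .copied()
}

fn main() {
    // Distinct amounts: the bigger buy leads.
    assert_eq!(
        elect_leader(&[(10, "walletB"), (50, "walletA")]),
        Some((50, "walletA"))
    );
    // Tied amounts: the lexicographically smaller wallet leads.
    assert_eq!(
        elect_leader(&[(50, "walletB"), (50, "walletA")]),
        Some((50, "walletA"))
    );
}
```

Determinism matters here because the leader/follower orientation decides which wallet becomes `wallet_a` in the emitted link; without the tie-breaker, equal-sized trades could flip orientation between runs.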
+
+ async fn detect_follower_activity(
+ &self,
+ trades: &[TradeRow],
+ processed_pairs: &mut HashSet<(String, String)>,
+ ) -> Result<Vec<FollowerLink>> {
+ let cfg = &*LINK_GRAPH_CONFIG;
+ let mut links = Vec::new();
+ let min_usd_value = cfg.min_trade_total_usd;
+
+ let significant_trades: Vec<&TradeRow> = trades
+ .iter()
+ .filter(|t| t.total_usd >= min_usd_value)
+ .collect();
+
+ if significant_trades.len() < 2 {
+ return Ok(links);
+ }
+
+ let unique_pairs: Vec<(String, String)> = significant_trades
+ .iter()
+ .map(|t| (t.maker.clone(), t.base_address.clone()))
+ .unique()
+ .collect();
+ // Update and read from the bounded in-memory cache; fallback to CH only on misses.
+ self.update_trade_cache(&significant_trades).await?;
+ let mut historical_trades_map = self.build_histories_from_cache(&unique_pairs).await?;
+
+ let missing_pairs: Vec<(String, String)> = unique_pairs
+ .iter()
+ .filter(|k| !historical_trades_map.contains_key(*k))
+ .cloned()
+ .collect();
+ if !missing_pairs.is_empty() {
+ let historical_query = "
+ SELECT maker, base_address, toUnixTimestamp(timestamp) as timestamp, signature, trade_type, total_usd, slippage
+ FROM trades
+ WHERE (maker, base_address) IN ?
+ ";
+ for chunk in missing_pairs.chunks(cfg.chunk_size_historical) {
+ let chunk_results: Vec<FullHistTrade> = self
+ .db_client
+ .query(historical_query)
+ .bind(chunk)
+ .fetch_all()
+ .await
+ .map_err(|e| {
+ anyhow!(
+ "[FOLLOWER_FAIL]: Historical trade fetch failed. Error: {}",
+ e
+ )
+ })?;
+
+ for trade in chunk_results {
+ historical_trades_map
+ .entry((trade.maker.clone(), trade.base_address.clone()))
+ .or_default()
+ .push(trade);
+ }
+ }
+ }
+
+ let trades_by_mint = significant_trades
+ .into_iter()
+ .into_group_map_by(|t| t.base_address.clone());
+
+ for (mint, trades_in_batch) in trades_by_mint {
+ if trades_in_batch.len() < 2 {
+ continue;
+ }
+
+ let Some(leader_trade) = trades_in_batch.iter().min_by_key(|t| t.timestamp) else {
+ continue;
+ };
+ let leader_wallet = &leader_trade.maker;
+
+ for follower_trade in trades_in_batch.iter().filter(|t| &t.maker != leader_wallet) {
+ let follower_wallet = &follower_trade.maker;
+
+ let mut pair_key_vec = vec![leader_wallet.to_string(), follower_wallet.to_string()];
+ pair_key_vec.sort();
+ let pair_key = (pair_key_vec[0].clone(), pair_key_vec[1].clone());
+ if processed_pairs.contains(&pair_key) {
+ continue;
+ }
+
+ if let (Some(leader_hist_ref), Some(follower_hist_ref)) = (
+ historical_trades_map.get(&(leader_wallet.clone(), mint.clone())),
+ historical_trades_map.get(&(follower_wallet.clone(), mint.clone())),
+ ) {
+ let mut leader_hist = leader_hist_ref.clone();
+ let mut follower_hist = follower_hist_ref.clone();
+ leader_hist.sort_by_key(|t| t.timestamp);
+ follower_hist.sort_by_key(|t| t.timestamp);
+
+ let leader_first_trade = leader_hist.get(0);
+ let follower_first_trade = follower_hist.get(0);
+
+ // Base the classification on the very first interaction of each wallet.
+ if let (Some(l1), Some(f1)) = (leader_first_trade, follower_first_trade) {
+ let first_gap = (f1.timestamp as i64 - l1.timestamp as i64).abs();
+
+ if first_gap > 0 && first_gap <= cfg.copied_trade_window_seconds {
+ processed_pairs.insert(pair_key); // Process this pair only once
+
+ // A) If the FIRST trades are BOTH BUYS, it's a COPIED_TRADE.
+ if l1.trade_type == 0 && f1.trade_type == 0 {
+ let l_buy = l1; // Already have the first buy
+ let f_buy = f1; // Already have the first buy
+
+ let leader_sells: Vec<_> =
+ leader_hist.iter().filter(|t| t.trade_type == 1).collect();
+ let follower_sells: Vec<_> =
+ follower_hist.iter().filter(|t| t.trade_type == 1).collect();
+ let leader_sell_total: f64 =
+ leader_sells.iter().map(|t| t.total_usd).sum();
+ let follower_sell_total: f64 =
+ follower_sells.iter().map(|t| t.total_usd).sum();
+ let leader_pnl = if l_buy.total_usd > 0.0 {
+ (leader_sell_total - l_buy.total_usd) / l_buy.total_usd
+ } else {
+ 0.0
+ };
+ let follower_pnl = if f_buy.total_usd > 0.0 {
+ (follower_sell_total - f_buy.total_usd) / f_buy.total_usd
+ } else {
+ 0.0
+ };
+ let leader_first_sell =
+ leader_sells.iter().min_by_key(|t| t.timestamp);
+ let follower_first_sell =
+ follower_sells.iter().min_by_key(|t| t.timestamp);
+
+ let (sell_gap, l_sell_sig, f_sell_sig, f_sell_slip) =
+ if let (Some(l_sell), Some(f_sell)) =
+ (leader_first_sell, follower_first_sell)
+ {
+ (
+ (f_sell.timestamp as i64 - l_sell.timestamp as i64)
+ .abs(),
+ l_sell.signature.clone(),
+ f_sell.signature.clone(),
+ f_sell.slippage,
+ )
+ } else {
+ (0, "".to_string(), "".to_string(), 0.0)
+ };
+
+ links.push(FollowerLink::Copied(CopiedTradeLink {
+ timestamp: f_buy.timestamp as i64,
+ follower: follower_wallet.clone(),
+ leader: leader_wallet.clone(),
+ mint: mint.clone(),
+ time_gap_on_buy_sec: first_gap, // Use the already calculated gap
+ time_gap_on_sell_sec: sell_gap,
+ leader_pnl,
+ follower_pnl,
+ leader_buy_sig: l_buy.signature.clone(),
+ leader_sell_sig: l_sell_sig,
+ follower_buy_sig: f_buy.signature.clone(),
+ follower_sell_sig: f_sell_sig,
+ leader_buy_total: l_buy.total_usd,
+ leader_sell_total,
+ follower_buy_total: f_buy.total_usd,
1157
+ follower_sell_total,
1158
+ follower_buy_slippage: f_buy.slippage,
1159
+ follower_sell_slippage: f_sell_slip,
1160
+ }));
1161
+ }
1162
+ // B) ELSE, if the first trades are not both buys, it's a COORDINATED_ACTIVITY.
1163
+ else {
1164
+ let leader_second_trade = leader_hist.get(1);
1165
+ let follower_second_trade = follower_hist.get(1);
1166
+
1167
+ let (l2_sig, f2_sig, second_gap) = if let (Some(l2), Some(f2)) =
1168
+ (leader_second_trade, follower_second_trade)
1169
+ {
1170
+ (
1171
+ l2.signature.clone(),
1172
+ f2.signature.clone(),
1173
+ (f2.timestamp as i64 - l2.timestamp as i64).abs(),
1174
+ )
1175
+ } else {
1176
+ ("".to_string(), "".to_string(), 0)
1177
+ };
1178
+
1179
+ links.push(FollowerLink::Coordinated(CoordinatedActivityLink {
1180
+ timestamp: l1.timestamp as i64,
1181
+ leader: leader_wallet.clone(),
1182
+ follower: follower_wallet.clone(),
1183
+ mint: mint.clone(),
1184
+ leader_first_sig: l1.signature.clone(),
1185
+ follower_first_sig: f1.signature.clone(),
1186
+ time_gap_on_first_sec: first_gap,
1187
+ leader_second_sig: l2_sig,
1188
+ follower_second_sig: f2_sig,
1189
+ time_gap_on_second_sec: second_gap,
1190
+ }));
1191
+ }
1192
+ }
1193
+ }
1194
+ }
1195
+ }
1196
+ }
1197
+ Ok(links)
1198
+ }
1199
+
1200
+ async fn detect_and_write_top_trader_links(&self, trades: &[TradeRow]) -> Result<()> {
+ let start = Instant::now();
+ let cfg = &*LINK_GRAPH_CONFIG;
+ let active_trader_pairs: Vec<(String, String)> = trades
+ .iter()
+ .map(|t| (t.maker.clone(), t.base_address.clone()))
+ .unique()
+ .collect();
+
+ if active_trader_pairs.is_empty() {
+ return Ok(());
+ }
+
+ // --- NEW: CONFIDENCE FILTER ---
+ // 1. Get all unique mints from the active pairs.
+ let unique_mints: Vec<String> = active_trader_pairs
+ .iter()
+ .map(|(_, mint)| mint.clone())
+ .unique()
+ .collect();
+
+ #[derive(Row, Deserialize, Debug)]
+ struct MintCheck {
+ mint_address: String,
+ }
+ let mint_query = "SELECT DISTINCT mint_address FROM mints WHERE mint_address IN ?";
+
+ let mut fully_tracked_mints = HashSet::new();
+ let mint_chunk_small = cfg.chunk_size_mint_small;
+
+ for chunk in unique_mints.chunks(mint_chunk_small) {
+ let chunk_rows: Vec<MintCheck> = self
+ .with_ch_retry(
+ || async {
+ self.db_client
+ .query(mint_query)
+ .bind(chunk)
+ .fetch_all()
+ .await
+ .map_err(anyhow::Error::from)
+ },
+ "TopTrader mint check chunk",
+ )
+ .await?;
+ for mint_row in chunk_rows {
+ fully_tracked_mints.insert(mint_row.mint_address);
+ }
+ }
+
+ // 2. Filter the active pairs to only include those for fully tracked tokens.
+ let confident_trader_pairs: Vec<(String, String)> = active_trader_pairs
+ .into_iter()
+ .filter(|(_, mint)| fully_tracked_mints.contains(mint))
+ .collect();
+
+ if confident_trader_pairs.is_empty() {
+ return Ok(());
+ }
+ // --- END CONFIDENCE FILTER ---
+
+ let mints_to_query: Vec<String> = fully_tracked_mints.iter().cloned().collect();
+ if mints_to_query.is_empty() {
+ return Ok(());
+ }
+
+ let ath_map = self
+ .fetch_latest_ath_map_with_retry(&mints_to_query)
+ .await?;
+ if ath_map.is_empty() {
+ return Ok(());
+ }
+
+ #[derive(Row, Deserialize, Debug)]
+ struct TraderContextInfo {
+ wallet_address: String,
+ mint_address: String,
+ realized_profit_pnl: f32,
+ }
+
+ let pnl_query = "
+ SELECT
+ wh.wallet_address, wh.mint_address, wh.realized_profit_pnl
+ FROM wallet_holdings_latest AS wh
+ WHERE wh.mint_address IN ?
+ AND wh.realized_profit_pnl > ?
+ QUALIFY ROW_NUMBER() OVER (PARTITION BY wh.mint_address ORDER BY wh.realized_profit_pnl DESC) = 1
+ ";
+
+ let mut top_traders: Vec<TraderContextInfo> = Vec::new();
+
+ for chunk in mints_to_query.chunks(cfg.chunk_size_mint_large) {
+ let chunk_results = self
+ .db_client
+ .query(pnl_query)
+ .bind(chunk)
+ .bind(cfg.min_top_trader_pnl)
+ .fetch_all()
+ .await
+ .map_err(|e| anyhow!("[TOPTRADER_FAIL]: Top-1 PNL fetch failed. Error: {}", e))?;
+ top_traders.extend(chunk_results);
+ }
+
+ let links: Vec<TopTraderOfLink> = top_traders
+ .into_iter()
+ .filter_map(|trader| {
+ ath_map
+ .get(&trader.mint_address)
+ .filter(|ath| **ath >= cfg.ath_price_threshold_usd)
+ .map(|ath| TopTraderOfLink {
+ timestamp: Utc::now().timestamp(),
+ wallet: trader.wallet_address,
+ token: trader.mint_address,
+ pnl_at_creation: trader.realized_profit_pnl as f64,
+ ath_usd_at_creation: *ath,
+ })
+ })
+ .collect();
+
+ if !links.is_empty() {
+ self.write_top_trader_of_links(&links).await?;
+ }
+
+ println!(
+ "[LinkGraph] [Profile] detect_and_write_top_trader_links: {} links in {:?}",
+ links.len(),
+ start.elapsed()
+ );
+ Ok(())
+ }
+
+ async fn process_liquidity_events(&self, liquidity_adds: &[LiquidityRow]) -> Result<()> {
+ let cfg = &*LINK_GRAPH_CONFIG;
+ if liquidity_adds.is_empty() {
+ return Ok(());
+ }
+ let unique_pools: HashSet<String> = liquidity_adds
+ .iter()
+ .map(|l| l.pool_address.clone())
+ .collect();
+ if unique_pools.is_empty() {
+ return Ok(());
+ }
+
+ #[derive(Row, Deserialize, Debug)]
+ struct PoolInfo {
+ pool_address: String,
+ base_address: String,
+ base_decimals: Option<u8>,
+ quote_decimals: Option<u8>,
+ }
+
+ let pool_query = "SELECT pool_address, base_address, base_decimals, quote_decimals FROM pool_creations WHERE pool_address IN ?";
+ let mut pools_info: Vec<PoolInfo> = Vec::new();
+ let unique_pools_vec: Vec<_> = unique_pools.iter().cloned().collect();
+
+ for chunk in unique_pools_vec.chunks(cfg.chunk_size_large) {
+ let mut chunk_results = self
+ .db_client
+ .query(pool_query)
+ .bind(chunk)
+ .fetch_all()
+ .await
+ .map_err(|e| anyhow!("[LIQUIDITY_FAIL]: PoolQuery chunk failed. Error: {}", e))?;
+ pools_info.append(&mut chunk_results);
+ }
+
+ let pool_to_token_map: HashMap<String, (String, Option<u8>, Option<u8>)> = pools_info
+ .into_iter()
+ .map(|p| {
+ (
+ p.pool_address,
+ (p.base_address, p.base_decimals, p.quote_decimals),
+ )
+ })
+ .collect();
+
+ let links: Vec<_> = liquidity_adds
+ .iter()
+ .filter_map(|l| {
+ pool_to_token_map.get(&l.pool_address).map(
+ |(token_address, base_decimals, quote_decimals)| {
+ let base_scale = 10f64.powi(base_decimals.unwrap_or(0) as i32);
+ let quote_scale = 10f64.powi(quote_decimals.unwrap_or(0) as i32);
+ ProvidedLiquidityLink {
+ signature: l.signature.clone(),
+ wallet: l.lp_provider.clone(),
+ token: token_address.clone(),
+ pool_address: l.pool_address.clone(),
+ amount_base: l.base_amount as f64 / base_scale,
+ amount_quote: l.quote_amount as f64 / quote_scale,
+ timestamp: l.timestamp as i64,
+ }
+ },
+ )
+ })
+ .collect();
+
+ if !links.is_empty() {
+ self.write_provided_liquidity_links(&links).await?;
+ }
+ Ok(())
+ }
+
+ async fn detect_and_write_whale_links(&self, trades: &[TradeRow]) -> Result<()> {
+ let start = Instant::now();
+ let cfg = &*LINK_GRAPH_CONFIG;
+ let unique_mints_in_batch: Vec<String> = trades
+ .iter()
+ .map(|t| t.base_address.clone())
+ .unique()
+ .collect();
+ if unique_mints_in_batch.is_empty() {
+ return Ok(());
+ }
+
+ // --- NEW: CONFIDENCE FILTER ---
+ // 1. Check which of the mints in the batch have a creation event in our DB.
+ #[derive(Row, Deserialize, Debug)]
+ struct MintCheck {
+ mint_address: String,
+ }
+ let mint_query = "SELECT DISTINCT mint_address FROM mints WHERE mint_address IN ?";
+
+ let mut fully_tracked_mints = HashSet::new();
+ for chunk in unique_mints_in_batch.chunks(cfg.chunk_size_mint_large) {
+ let chunk_rows: Vec<MintCheck> = self
+ .with_ch_retry(
+ || async {
+ self.db_client
+ .query(mint_query)
+ .bind(chunk)
+ .fetch_all()
+ .await
+ .map_err(anyhow::Error::from)
+ },
+ "Whale mint check chunk",
+ )
+ .await?;
+ for mint_row in chunk_rows {
+ fully_tracked_mints.insert(mint_row.mint_address);
+ }
+ }
+
+ if fully_tracked_mints.is_empty() {
+ return Ok(());
+ }
+ let confident_mints: Vec<String> = fully_tracked_mints.iter().cloned().collect();
+ let ath_map = self
+ .fetch_latest_ath_map_with_retry(&confident_mints)
+ .await?;
+ if ath_map.is_empty() {
+ return Ok(());
+ }
+ // --- END CONFIDENCE FILTER ---
+
+ #[derive(Row, Deserialize, Debug)]
+ struct TokenInfo {
+ token_address: String,
+ total_supply: u64,
+ decimals: u8,
+ }
+
+ let token_query = "SELECT token_address, total_supply, decimals FROM tokens FINAL WHERE token_address IN ?";
+
+ // --- RE-INTRODUCED CHUNKING for the token pre-filter ---
+ let mut context_map: HashMap<String, (u64, f64, u8)> = HashMap::new();
+
+ for chunk in confident_mints.chunks(cfg.chunk_size_token) {
+ let mut attempts = 0;
+ loop {
+ attempts += 1;
+ let result: Result<Vec<TokenInfo>> = self
+ .db_client
+ .query(token_query)
+ .bind(chunk)
+ .fetch_all()
+ .await
+ .map_err(anyhow::Error::from);
+
+ match result {
+ Ok(chunk_results) => {
+ for token in chunk_results {
+ if let Some(ath) = ath_map.get(&token.token_address) {
+ if *ath >= cfg.ath_price_threshold_usd {
+ context_map.insert(
+ token.token_address,
+ (token.total_supply, *ath, token.decimals),
+ );
+ }
+ }
+ }
+ break;
+ }
+ Err(e) => {
+ if attempts >= cfg.ch_retry_attempts {
+ return Err(anyhow!(
+ "[WHALE_FAIL]: Token pre-filter chunk failed after {} attempts: {}",
+ attempts,
+ e
+ ));
+ }
+ let backoff = cfg.ch_retry_backoff_ms * attempts as u64;
+ eprintln!(
+ "[LinkGraph] ⚠️ Whale token pre-filter retry {}/{} after {}ms: {}",
+ attempts, cfg.ch_retry_attempts, backoff, e
+ );
+ sleep(Duration::from_millis(backoff)).await;
+ }
+ }
+ }
+ }
+ // --- END CHUNKING ---
+
+ if context_map.is_empty() {
+ return Ok(());
+ }
+
+ let tokens_to_query: Vec<String> = context_map.keys().cloned().collect();
+
+ #[derive(Row, Deserialize, Debug)]
+ struct WhaleInfo {
+ wallet_address: String,
+ mint_address: String,
+ current_balance: f64,
+ }
+
+ let whales_query = "
+ SELECT wallet_address, mint_address, current_balance
+ FROM wallet_holdings_latest
+ WHERE mint_address IN ? AND current_balance > 0
+ QUALIFY ROW_NUMBER() OVER (PARTITION BY mint_address ORDER BY current_balance DESC) <= ?
+ ";
+
+ // --- RE-INTRODUCED CHUNKING for the main whale query ---
+ let mut top_holders: Vec<WhaleInfo> = Vec::new();
+ for chunk in tokens_to_query.chunks(cfg.chunk_size_token) {
+ let mut attempts = 0;
+ loop {
+ attempts += 1;
+ let result: Result<Vec<WhaleInfo>> = self
+ .db_client
+ .query(whales_query)
+ .bind(chunk)
+ .bind(cfg.whale_rank_threshold)
+ .fetch_all()
+ .await
+ .map_err(anyhow::Error::from);
+
+ match result {
+ Ok(chunk_results) => {
+ top_holders.extend(chunk_results);
+ break;
+ }
+ Err(e) => {
+ if attempts >= cfg.ch_retry_attempts {
+ return Err(anyhow!(
+ "[WHALE_FAIL]: Holder query chunk failed after {} attempts: {}",
+ attempts,
+ e
+ ));
+ }
+ let backoff = cfg.ch_retry_backoff_ms * attempts as u64;
+ eprintln!(
+ "[LinkGraph] ⚠️ Whale holder chunk retry {}/{} after {}ms: {}",
+ attempts, cfg.ch_retry_attempts, backoff, e
+ );
+ sleep(Duration::from_millis(backoff)).await;
+ }
+ }
+ }
+ }
+ // --- END CHUNKING ---
+
+ let mut links = Vec::new();
+ for holder in top_holders {
+ if let Some((raw_total_supply, ath_usd, decimals)) =
+ context_map.get(&holder.mint_address)
+ {
+ if *raw_total_supply == 0 {
+ continue;
+ }
+
+ // --- THE FIX ---
+ // Adjust the total supply to be human-readable before dividing.
+ let human_total_supply = *raw_total_supply as f64 / 10f64.powi(*decimals as i32);
+ if human_total_supply == 0.0 {
+ continue;
+ }
+ // --- END FIX ---
+
+ let holding_pct = (holder.current_balance / human_total_supply) as f32;
+
+ links.push(WhaleOfLink {
+ timestamp: Utc::now().timestamp(),
+ wallet: holder.wallet_address.clone(),
+ token: holder.mint_address.clone(),
+ holding_pct_at_creation: holding_pct,
+ ath_usd_at_creation: *ath_usd,
+ });
+ }
+ }
+
+ if !links.is_empty() {
+ self.write_whale_of_links(&links).await?;
+ }
+
+ println!(
+ "[LinkGraph] [Profile] detect_and_write_whale_links: {} links in {:?}",
+ links.len(),
+ start.elapsed()
+ );
+ Ok(())
+ }
+
+ async fn create_wallet_nodes(&self, wallets: &HashSet<String>) -> Result<()> {
+ if wallets.is_empty() {
+ return Ok(());
+ }
+ let cfg = &*LINK_GRAPH_CONFIG;
+
+ // Convert the HashSet to a Vec to be able to create chunks
+ let wallet_vec: Vec<_> = wallets.iter().cloned().collect();
+
+ // Process the wallets in smaller, manageable chunks
+ for chunk in wallet_vec.chunks(cfg.chunk_size_large) {
+ let params: Vec<_> = chunk
+ .iter()
+ .map(|addr| HashMap::from([("address".to_string(), BoltType::from(addr.clone()))]))
+ .collect();
+
+ // Bind under $x to match the parameter name enqueue_write uses
+ // for every other writer in this file.
+ let cypher = "
+ UNWIND $x as wallet
+ MERGE (w:Wallet {address: wallet.address})
+ ";
+
+ self.enqueue_write(cypher, params).await?;
+ }
+ Ok(())
+ }
+
+ async fn create_token_nodes(&self, tokens: &HashSet<String>) -> Result<()> {
+ if tokens.is_empty() {
+ return Ok(());
+ }
+ let cfg = &*LINK_GRAPH_CONFIG;
+
+ // Convert the HashSet to a Vec to be able to create chunks
+ let token_vec: Vec<_> = tokens.iter().cloned().collect();
+
+ // Process the tokens in smaller, manageable chunks
+ for chunk in token_vec.chunks(cfg.chunk_size_large) {
+ let params: Vec<_> = chunk
+ .iter()
+ .map(|addr| HashMap::from([("address".to_string(), BoltType::from(addr.clone()))]))
+ .collect();
+
+ // Bind under $x to match the parameter name enqueue_write uses
+ // for every other writer. The params only carry the address, so
+ // no created_ts is set here; the MINTED link records the timestamp.
+ let cypher = "
+ UNWIND $x as token
+ MERGE (t:Token {address: token.address})
+ ";
+
+ self.enqueue_write(cypher, params).await?;
+ }
+ Ok(())
+ }
+
+ async fn write_bundle_trade_links(&self, links: &[BundleTradeLink]) -> Result<()> {
+ if links.is_empty() {
+ return Ok(());
+ }
+ let params: Vec<_> = links
+ .iter()
+ .map(|l| {
+ HashMap::from([
+ ("wa".to_string(), BoltType::from(l.wallet_a.clone())),
+ ("wb".to_string(), BoltType::from(l.wallet_b.clone())),
+ ("mint".to_string(), BoltType::from(l.mint.clone())),
+ ("slot".to_string(), BoltType::from(l.slot)),
+ ("timestamp".to_string(), BoltType::from(l.timestamp)),
+ (
+ "signatures".to_string(),
+ BoltType::from(l.signatures.clone()),
+ ),
+ ])
+ })
+ .collect();
+ // Corrected relationship name to BUNDLE_TRADE for consistency
+ let cypher = "
+ UNWIND $x as t
+ MERGE (a:Wallet {address: t.wa})
+ MERGE (b:Wallet {address: t.wb})
+ MERGE (a)-[r:BUNDLE_TRADE {mint: t.mint, slot: t.slot}]->(b)
+ ON CREATE SET r.timestamp = t.timestamp, r.signatures = t.signatures
+ ON MATCH SET r.timestamp = CASE WHEN t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END
+ ";
+ self.enqueue_write(cypher, params).await
+ }
+
+ async fn write_transfer_links(&self, links: &[TransferLink]) -> Result<()> {
+ if links.is_empty() {
+ return Ok(());
+ }
+
+ // --- THE FIX ---
+ // Use `unique_by` to get the *entire first link object* for each unique path.
+ // This preserves the signature and timestamp from the first event we see.
+ let unique_links = links
+ .iter()
+ .unique_by(|l| (&l.source, &l.destination, &l.mint))
+ .collect::<Vec<_>>();
+
+ // Now build the parameters with the full data from the unique links.
+ let params: Vec<_> = unique_links
+ .iter()
+ .map(|l| {
+ HashMap::from([
+ ("source".to_string(), BoltType::from(l.source.clone())),
+ (
+ "destination".to_string(),
+ BoltType::from(l.destination.clone()),
+ ),
+ ("mint".to_string(), BoltType::from(l.mint.clone())),
+ ("signature".to_string(), BoltType::from(l.signature.clone())), // Include the signature
+ ("timestamp".to_string(), BoltType::from(l.timestamp)), // Include the on-chain timestamp
+ ("amount".to_string(), BoltType::from(l.amount)),
+ ])
+ })
+ .collect();
+
+ // --- UPDATED CYPHER QUERY ---
+ // The query now sets the signature and on-chain timestamp on the link when it's first created.
+ let cypher = "
+ UNWIND $x as t
+ MERGE (s:Wallet {address: t.source})
+ MERGE (d:Wallet {address: t.destination})
+ MERGE (s)-[r:TRANSFERRED_TO {mint: t.mint}]->(d)
+ ON CREATE SET
+ r.signature = t.signature,
+ r.timestamp = t.timestamp,
+ r.amount = t.amount
+ ON MATCH SET r.timestamp = CASE WHEN t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END
+ ";
+
+ self.enqueue_write(cypher, params).await
+ }
+
+ async fn write_coordinated_activity_links(
+ &self,
+ links: &[CoordinatedActivityLink],
+ ) -> Result<()> {
+ if links.is_empty() {
+ return Ok(());
+ }
+
+ let params: Vec<_> = links
+ .iter()
+ .map(|l| {
+ HashMap::from([
+ ("leader".to_string(), BoltType::from(l.leader.clone())),
+ ("follower".to_string(), BoltType::from(l.follower.clone())),
+ ("mint".to_string(), BoltType::from(l.mint.clone())),
+ ("timestamp".to_string(), BoltType::from(l.timestamp)),
+ // Use the new, correct field names
+ (
+ "l_sig_1".to_string(),
+ BoltType::from(l.leader_first_sig.clone()),
+ ),
+ (
+ "l_sig_2".to_string(),
+ BoltType::from(l.leader_second_sig.clone()),
+ ),
+ (
+ "f_sig_1".to_string(),
+ BoltType::from(l.follower_first_sig.clone()),
+ ),
+ (
+ "f_sig_2".to_string(),
+ BoltType::from(l.follower_second_sig.clone()),
+ ),
+ ("gap_1".to_string(), BoltType::from(l.time_gap_on_first_sec)),
+ (
+ "gap_2".to_string(),
+ BoltType::from(l.time_gap_on_second_sec),
+ ),
+ ])
+ })
+ .collect();
+
+ // This query now creates a single, comprehensive link per pair/mint
+ let cypher = "
+ UNWIND $x as t
+ MERGE (l:Wallet {address: t.leader})
+ MERGE (f:Wallet {address: t.follower})
+ MERGE (f)-[r:COORDINATED_ACTIVITY {mint: t.mint}]->(l)
+ ON CREATE SET
+ r.timestamp = t.timestamp,
+ r.leader_first_sig = t.l_sig_1,
+ r.leader_second_sig = t.l_sig_2,
+ r.follower_first_sig = t.f_sig_1,
+ r.follower_second_sig = t.f_sig_2,
+ r.time_gap_on_first_sec = t.gap_1,
+ r.time_gap_on_second_sec = t.gap_2
+ ON MATCH SET r.timestamp = CASE WHEN t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END
+ ";
+
+ self.enqueue_write(cypher, params).await
+ }
+
+ async fn write_copied_trade_links(&self, links: &[CopiedTradeLink]) -> Result<()> {
+ if links.is_empty() {
+ return Ok(());
+ }
+ // This uses the latest struct definition provided in the prompt.
+ let params: Vec<_> = links
+ .iter()
+ .map(|l| {
+ HashMap::from([
+ ("follower".to_string(), BoltType::from(l.follower.clone())),
+ ("leader".to_string(), BoltType::from(l.leader.clone())),
+ ("mint".to_string(), BoltType::from(l.mint.clone())),
+ ("buy_gap".to_string(), BoltType::from(l.time_gap_on_buy_sec)),
+ (
+ "sell_gap".to_string(),
+ BoltType::from(l.time_gap_on_sell_sec),
+ ),
+ ("leader_pnl".to_string(), BoltType::from(l.leader_pnl)),
+ ("follower_pnl".to_string(), BoltType::from(l.follower_pnl)),
+ (
+ "l_buy_sig".to_string(),
+ BoltType::from(l.leader_buy_sig.clone()),
+ ),
+ (
+ "l_sell_sig".to_string(),
+ BoltType::from(l.leader_sell_sig.clone()),
+ ),
+ (
+ "f_buy_sig".to_string(),
+ BoltType::from(l.follower_buy_sig.clone()),
+ ),
+ (
+ "f_sell_sig".to_string(),
+ BoltType::from(l.follower_sell_sig.clone()),
+ ),
+ (
+ "l_buy_total".to_string(),
+ BoltType::from(l.leader_buy_total),
+ ),
+ (
+ "l_sell_total".to_string(),
+ BoltType::from(l.leader_sell_total),
+ ),
+ (
+ "f_buy_total".to_string(),
+ BoltType::from(l.follower_buy_total),
+ ),
+ (
+ "f_sell_total".to_string(),
+ BoltType::from(l.follower_sell_total),
+ ),
+ (
+ "f_buy_slip".to_string(),
+ BoltType::from(l.follower_buy_slippage),
+ ),
+ (
+ "f_sell_slip".to_string(),
+ BoltType::from(l.follower_sell_slippage),
+ ),
+ ("timestamp".to_string(), BoltType::from(l.timestamp)),
+ ])
+ })
+ .collect();
+ let cypher = "
+ UNWIND $x as t
+ MERGE (f:Wallet {address: t.follower})
+ MERGE (l:Wallet {address: t.leader})
+ MERGE (f)-[r:COPIED_TRADE {mint: t.mint}]->(l)
+ ON CREATE SET
+ r.timestamp = t.timestamp,
+ r.follower = t.follower,
+ r.leader = t.leader,
+ r.mint = t.mint,
+ r.buy_gap = t.buy_gap,
+ r.sell_gap = t.sell_gap,
+ r.leader_pnl = t.leader_pnl,
+ r.follower_pnl = t.follower_pnl,
+ r.l_buy_sig = t.l_buy_sig,
+ r.l_sell_sig = t.l_sell_sig,
+ r.f_buy_sig = t.f_buy_sig,
+ r.f_sell_sig = t.f_sell_sig,
+ r.l_buy_total = t.l_buy_total,
+ r.l_sell_total = t.l_sell_total,
+ r.f_buy_total = t.f_buy_total,
+ r.f_sell_total = t.f_sell_total,
+ r.f_buy_slip = t.f_buy_slip,
+ r.f_sell_slip = t.f_sell_slip
+ ON MATCH SET r.timestamp = CASE WHEN t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END
+ ";
+ self.enqueue_write(cypher, params).await
+ }
+
+ async fn write_minted_links(&self, links: &[MintedLink], mints: &[MintRow]) -> Result<()> {
+ if links.is_empty() {
+ return Ok(());
+ }
+ let mint_map: HashMap<_, _> = mints.iter().map(|m| (m.signature.clone(), m)).collect();
+
+ let params: Vec<_> = links
+ .iter()
+ .filter_map(|l| {
+ mint_map.get(&l.signature).map(|m| {
+ HashMap::from([
+ (
+ "creator".to_string(),
+ BoltType::from(m.creator_address.clone()),
+ ),
+ ("token".to_string(), BoltType::from(m.mint_address.clone())),
+ ("signature".to_string(), BoltType::from(l.signature.clone())),
+ ("timestamp".to_string(), BoltType::from(l.timestamp)),
+ ("buy_amount".to_string(), BoltType::from(l.buy_amount)),
+ ])
+ })
+ })
+ .collect();
+
+ if params.is_empty() {
+ return Ok(());
+ }
+ // --- MODIFIED: MERGE on the signature for idempotency ---
+ let cypher = "
+ UNWIND $x as t
+ MERGE (c:Wallet {address: t.creator})
+ MERGE (k:Token {address: t.token})
+ MERGE (c)-[r:MINTED {signature: t.signature}]->(k)
+ ON CREATE SET r.timestamp = t.timestamp, r.buy_amount = t.buy_amount
+ ON MATCH SET r.timestamp = CASE WHEN t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END
+ ";
+ self.enqueue_write(cypher, params).await
+ }
+
+ async fn write_sniped_links(
+ &self,
+ links: &[SnipedLink],
+ snipers: &HashMap<String, (String, String)>,
+ ) -> Result<()> {
+ if links.is_empty() {
+ return Ok(());
+ }
+
+ let params: Vec<_> = links
+ .iter()
+ .filter_map(|l| {
+ snipers.get(&l.signature).map(|(wallet, token)| {
+ HashMap::from([
+ ("wallet".to_string(), BoltType::from(wallet.clone())),
+ ("token".to_string(), BoltType::from(token.clone())),
+ ("signature".to_string(), BoltType::from(l.signature.clone())),
+ ("rank".to_string(), BoltType::from(l.rank)),
+ ("sniped_amount".to_string(), BoltType::from(l.sniped_amount)),
+ ("timestamp".to_string(), BoltType::from(l.timestamp)),
+ ])
+ })
+ })
+ .collect();
+
+ if params.is_empty() {
+ return Ok(());
+ }
+
+ // --- MODIFIED: MERGE on signature ---
+ let cypher = "
+ UNWIND $x as t
+ MERGE (w:Wallet {address: t.wallet})
+ MERGE (k:Token {address: t.token})
+ MERGE (w)-[r:SNIPED {signature: t.signature}]->(k)
+ ON CREATE SET r.rank = t.rank, r.sniped_amount = t.sniped_amount, r.timestamp = t.timestamp
+ ON MATCH SET r.timestamp = CASE WHEN t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END
+ ";
+ self.enqueue_write(cypher, params).await
+ }
+
+ async fn write_locked_supply_links(
+ &self,
+ links: &[LockedSupplyLink],
+ locks: &[SupplyLockRow],
+ ) -> Result<()> {
+ if links.is_empty() {
+ return Ok(());
+ }
+ let lock_map: HashMap<_, _> = locks.iter().map(|l| (l.signature.clone(), l)).collect();
+
+ let params: Vec<_> = links
+ .iter()
+ .filter_map(|l| {
+ lock_map.get(&l.signature).map(|lock_row| {
+ HashMap::from([
+ (
+ "sender".to_string(),
+ BoltType::from(lock_row.sender.clone()),
+ ),
+ (
+ "recipient".to_string(),
+ BoltType::from(lock_row.recipient.clone()),
+ ),
+ (
+ "mint".to_string(),
+ BoltType::from(lock_row.mint_address.clone()),
+ ),
+ ("signature".to_string(), BoltType::from(l.signature.clone())),
+ ("amount".to_string(), BoltType::from(l.amount)),
+ (
+ "unlock_ts".to_string(),
+ BoltType::from(l.unlock_timestamp as i64),
+ ),
+ ("timestamp".to_string(), BoltType::from(l.timestamp)),
+ ])
+ })
+ })
+ .collect();
+
+ if params.is_empty() {
+ return Ok(());
+ }
+
+ // --- THE CRITICAL FIX ---
+ let cypher = "
+ UNWIND $x as t
+ MERGE (s:Wallet {address: t.sender})
+ MERGE (k:Token {address: t.mint})
+ MERGE (s)-[r:LOCKED_SUPPLY {signature: t.signature}]->(k)
+ ON CREATE SET r.amount = t.amount, r.unlock_timestamp = t.unlock_ts, r.recipient = t.recipient, r.timestamp = t.timestamp
+ ON MATCH SET r.timestamp = CASE WHEN t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END
+ ";
+ self.enqueue_write(cypher, params).await
+ }
+
+ async fn write_burned_links(&self, links: &[BurnedLink], burns: &[BurnRow]) -> Result<()> {
+ if links.is_empty() {
+ return Ok(());
+ }
+ let burn_map: HashMap<_, _> = burns.iter().map(|b| (b.signature.clone(), b)).collect();
+
+ let params: Vec<_> = links
+ .iter()
+ .filter_map(|l| {
+ burn_map.get(&l.signature).map(|burn_row| {
+ HashMap::from([
+ (
+ "wallet".to_string(),
+ BoltType::from(burn_row.source.clone()),
+ ),
+ (
+ "token".to_string(),
+ BoltType::from(burn_row.mint_address.clone()),
+ ),
+ ("signature".to_string(), BoltType::from(l.signature.clone())),
+ ("amount".to_string(), BoltType::from(l.amount)),
+ ("timestamp".to_string(), BoltType::from(l.timestamp)),
+ ])
+ })
+ })
+ .collect();
+
+ if params.is_empty() {
+ return Ok(());
+ }
+ // --- MODIFIED: MERGE on signature ---
+ let cypher = "
+ UNWIND $x as t
+ MATCH (w:Wallet {address: t.wallet}), (k:Token {address: t.token})
+ MERGE (w)-[r:BURNED {signature: t.signature}]->(k)
+ ON CREATE SET r.amount = t.amount, r.timestamp = t.timestamp
+ ON MATCH SET r.timestamp = CASE WHEN t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END
+ ";
+ self.enqueue_write(cypher, params).await
+ }
+
+ async fn write_provided_liquidity_links(&self, links: &[ProvidedLiquidityLink]) -> Result<()> {
+ if links.is_empty() {
+ return Ok(());
+ }
+ let params: Vec<_> = links
+ .iter()
+ .map(|l| {
+ HashMap::from([
+ ("wallet".to_string(), BoltType::from(l.wallet.clone())),
+ ("token".to_string(), BoltType::from(l.token.clone())),
+ ("signature".to_string(), BoltType::from(l.signature.clone())),
+ (
+ "pool_address".to_string(),
+ BoltType::from(l.pool_address.clone()),
+ ),
+ ("amount_base".to_string(), BoltType::from(l.amount_base)),
+ ("amount_quote".to_string(), BoltType::from(l.amount_quote)),
+ ("timestamp".to_string(), BoltType::from(l.timestamp)),
+ ])
+ })
+ .collect();
+
+ // --- MODIFIED: MERGE on signature ---
+ let cypher = "
+ UNWIND $x as t
+ MERGE (w:Wallet {address: t.wallet})
+ MERGE (k:Token {address: t.token})
+ MERGE (w)-[r:PROVIDED_LIQUIDITY {signature: t.signature}]->(k)
+ ON CREATE SET r.pool_address = t.pool_address, r.amount_base = t.amount_base, r.amount_quote = t.amount_quote, r.timestamp = t.timestamp
+ ON MATCH SET r.timestamp = CASE WHEN t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END
+ ";
+ self.enqueue_write(cypher, params).await
+ }
+
+ async fn write_top_trader_of_links(&self, links: &[TopTraderOfLink]) -> Result<()> {
+ if links.is_empty() {
+ return Ok(());
+ }
+ let params: Vec<_> = links
+ .iter()
+ .map(|l| {
+ HashMap::from([
+ ("wallet".to_string(), BoltType::from(l.wallet.clone())),
+ ("token".to_string(), BoltType::from(l.token.clone())),
+ // Add new params
+ (
+ "pnl_at_creation".to_string(),
+ BoltType::from(l.pnl_at_creation),
+ ),
+ (
+ "ath_at_creation".to_string(),
+ BoltType::from(l.ath_usd_at_creation),
+ ),
+ ("timestamp".to_string(), BoltType::from(l.timestamp)),
+ ])
+ })
+ .collect();
+
+ // --- MODIFIED: The definitive Cypher query ---
+ let cypher = "
+ UNWIND $x as t
+ MERGE (w:Wallet {address: t.wallet})
+ MERGE (k:Token {address: t.token})
+ MERGE (w)-[r:TOP_TRADER_OF]->(k)
+ ON CREATE SET
+ r.pnl_at_creation = t.pnl_at_creation,
+ r.ath_usd_at_creation = t.ath_at_creation,
+ r.timestamp = t.timestamp
+ ON MATCH SET r.timestamp = CASE WHEN t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END
+ ";
+ self.enqueue_write(cypher, params).await
+ }
2149
+
2150
+ async fn write_whale_of_links(&self, links: &[WhaleOfLink]) -> Result<()> {
2151
+ if links.is_empty() {
2152
+ return Ok(());
2153
+ }
2154
+ let params: Vec<_> = links
2155
+ .iter()
2156
+ .map(|l| {
2157
+ HashMap::from([
2158
+ ("wallet".to_string(), BoltType::from(l.wallet.clone())),
2159
+ ("token".to_string(), BoltType::from(l.token.clone())),
2160
+ // Add new params
2161
+ (
2162
+ "pct_at_creation".to_string(),
2163
+ BoltType::from(l.holding_pct_at_creation),
2164
+ ),
2165
+ (
2166
+ "ath_at_creation".to_string(),
2167
+ BoltType::from(l.ath_usd_at_creation),
2168
+ ),
2169
+ ("timestamp".to_string(), BoltType::from(l.timestamp)),
2170
+ ])
2171
+ })
2172
+ .collect();
2173
+
2174
+ // --- MODIFIED: The definitive Cypher query ---
2175
+ let cypher = "
2176
+ UNWIND $x as t
2177
+ MERGE (w:Wallet {address: t.wallet})
2178
+ MERGE (k:Token {address: t.token})
2179
+ MERGE (w)-[r:WHALE_OF]->(k)
2180
+ ON CREATE SET
2181
+ r.holding_pct_at_creation = t.pct_at_creation,
2182
+ r.ath_usd_at_creation = t.ath_at_creation,
2183
+ r.timestamp = t.timestamp
2184
+ ON MATCH SET r.timestamp = CASE WHEN t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END
2185
+ ";
2186
+ self.enqueue_write(cypher, params).await
2187
+ }
2188
+
2189
+ async fn fetch_latest_ath_map_with_retry(
2190
+ &self,
2191
+ token_addresses: &[String],
2192
+ ) -> Result<HashMap<String, f64>> {
2193
+ let mut ath_map = HashMap::new();
2194
+ if token_addresses.is_empty() {
2195
+ return Ok(ath_map);
2196
+ }
2197
+ let cfg = &*LINK_GRAPH_CONFIG;
2198
+
2199
+ #[derive(Row, Deserialize, Debug)]
2200
+ struct AthInfo {
2201
+ token_address: String,
2202
+ ath_price_usd: f64,
2203
+ }
2204
+
2205
+ let query = "
2206
+ SELECT token_address, ath_price_usd
2207
+ FROM token_metrics_latest
2208
+ WHERE token_address IN ?
2209
+ ORDER BY token_address, updated_at DESC
2210
+ LIMIT 1 BY token_address
2211
+ ";
2212
+
2213
+ for chunk in token_addresses.chunks(cfg.ath_fetch_chunk_size.max(1)) {
2214
+ let mut attempts = 0;
2215
+ loop {
2216
+ attempts += 1;
2217
+ let result: Result<Vec<AthInfo>> = self
2218
+ .db_client
2219
+ .query(query)
2220
+ .bind(chunk)
2221
+ .fetch_all()
2222
+ .await
2223
+ .map_err(|e| anyhow!("[LinkGraph] ATH fetch failed: {}", e));
2224
+
2225
+ match result {
2226
+ Ok(mut chunk_rows) => {
2227
+ for row in chunk_rows.drain(..) {
2228
+ ath_map.insert(row.token_address, row.ath_price_usd);
2229
+ }
2230
+ break;
2231
+ }
2232
+ Err(e) => {
2233
+ if attempts >= cfg.ch_retry_attempts {
2234
+ eprintln!(
2235
+ "[LinkGraph] 🔴 ATH fetch failed after {} attempts: {}",
2236
+ attempts, e
2237
+ );
2238
+ std::process::exit(1);
2239
+ }
2240
+ let backoff = cfg.ch_retry_backoff_ms * attempts as u64;
2241
+ eprintln!(
2242
+ "[LinkGraph] ⚠️ ATH fetch retry {}/{} after {}ms: {}",
2243
+ attempts, cfg.ch_retry_attempts, backoff, e
2244
+ );
2245
+ sleep(Duration::from_millis(backoff)).await;
2246
+ }
2247
+ }
2248
+ }
2249
+ }
2250
+
2251
+ Ok(ath_map)
2252
+ }
2253
+
2254
+     async fn fetch_pnl(&self, wallet_address: &str, mint_address: &str) -> Result<f64> {
+         // Bind the parameters rather than formatting them into the SQL string,
+         // which avoids quoting bugs and SQL injection.
+         let q_str = "SELECT realized_profit_pnl FROM wallet_holdings_latest \
+                      WHERE wallet_address = ? AND mint_address = ?";
+         // Fetch the pre-calculated f32 value
+         let pnl_f32 = self
+             .with_ch_retry(
+                 || async {
+                     self.db_client
+                         .query(q_str)
+                         .bind(wallet_address)
+                         .bind(mint_address)
+                         .fetch_one::<f32>()
+                         .await
+                         .map_err(anyhow::Error::from)
+                 },
+                 "Fetch PNL",
+             )
+             .await?;
+         // Cast to f64 for the return type
+         Ok(pnl_f32 as f64)
+     }
+ }
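`fetch_latest_ath_map_with_retry` above walks the address list in chunks and retries each chunk with a backoff that grows linearly with the attempt count before giving up. A minimal Python sketch of that control flow (the function and parameter names here are illustrative, not from the repo):

```python
import time

def fetch_with_retry(addresses, fetch_chunk, chunk_size=2,
                     attempts_max=3, backoff_ms=100, sleep=time.sleep):
    """Chunked fetch with linear backoff, mirroring fetch_latest_ath_map_with_retry."""
    result = {}
    step = max(chunk_size, 1)  # guard against a zero/negative chunk size, like .max(1)
    for i in range(0, len(addresses), step):
        chunk = addresses[i:i + step]
        attempts = 0
        while True:
            attempts += 1
            try:
                result.update(fetch_chunk(chunk))
                break
            except Exception:
                if attempts >= attempts_max:
                    raise  # the Rust version exits the process here
                sleep(backoff_ms * attempts / 1000.0)  # backoff grows linearly with attempts
    return result
```

Unlike the Rust version, which hard-exits on exhaustion, this sketch re-raises so the caller decides.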
log.log ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:47b5b03f090da19eba850d54ea4cab1a97ebfdb7712ef4842cfc43804ec411b8
+ size 10517118
models/HoldersEncoder.py ADDED
@@ -0,0 +1,81 @@
+ import torch
+ import torch.nn as nn
+ from typing import List, Dict, Any
+
+ class HolderDistributionEncoder(nn.Module):
+     """
+     Encodes a list of top holders (wallet embeddings + holding percentages)
+     into a single fixed-size embedding representing the holder distribution.
+     It uses a Transformer Encoder to capture patterns and relationships.
+     """
+     def __init__(self,
+                  wallet_embedding_dim: int,
+                  output_dim: int,
+                  nhead: int = 4,
+                  num_layers: int = 2,
+                  dtype: torch.dtype = torch.float16):
+         super().__init__()
+         self.wallet_embedding_dim = wallet_embedding_dim
+         self.output_dim = output_dim
+         self.dtype = dtype
+
+         # 1. MLP to project holding percentage to the wallet embedding dimension
+         self.pct_proj = nn.Sequential(
+             nn.Linear(1, wallet_embedding_dim // 4),
+             nn.GELU(),
+             nn.Linear(wallet_embedding_dim // 4, wallet_embedding_dim)
+         ).to(dtype)
+
+         # 2. Transformer Encoder to process the sequence of holders
+         encoder_layer = nn.TransformerEncoderLayer(
+             d_model=wallet_embedding_dim,
+             nhead=nhead,
+             dim_feedforward=wallet_embedding_dim * 4,
+             dropout=0.1,
+             activation='gelu',
+             batch_first=True,
+             dtype=dtype
+         )
+         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+
+         # 3. A learnable [CLS] token to aggregate the sequence information
+         self.cls_token = nn.Parameter(torch.randn(1, 1, wallet_embedding_dim, dtype=dtype))
+
+         # 4. Final projection layer to get the desired output dimension
+         self.final_proj = nn.Linear(wallet_embedding_dim, output_dim).to(dtype)
+
+     def forward(self, holder_data: List[Dict[str, Any]]) -> torch.Tensor:
+         """
+         Args:
+             holder_data: A list of dictionaries, where each dict contains:
+                 'wallet_embedding': A tensor of shape [wallet_embedding_dim]
+                 'pct': The holding percentage as a float.
+
+         Returns:
+             A tensor of shape [1, output_dim] representing the entire distribution.
+         """
+         if not holder_data:
+             # Return a zero tensor if there are no holders
+             return torch.zeros(1, self.output_dim, device=self.cls_token.device, dtype=self.dtype)
+
+         # Prepare inputs for the transformer
+         wallet_embeds = torch.stack([d['wallet_embedding'] for d in holder_data])
+         holder_pcts = torch.tensor([[d['pct']] for d in holder_data], device=wallet_embeds.device, dtype=self.dtype)
+
+         # Project percentages and add to wallet embeddings to create holder features
+         pct_embeds = self.pct_proj(holder_pcts)
+         holder_inputs = (wallet_embeds + pct_embeds).unsqueeze(0)  # Add batch dimension
+
+         # Prepend the [CLS] token
+         batch_size = holder_inputs.size(0)
+         cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+         transformer_input = torch.cat((cls_tokens, holder_inputs), dim=1)
+
+         # Pass through the transformer
+         transformer_output = self.transformer_encoder(transformer_input)
+
+         # Get the embedding of the [CLS] token (the first token)
+         cls_embedding = transformer_output[:, 0, :]
+
+         # Project to the final output dimension
+         return self.final_proj(cls_embedding)
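Before the Transformer pass, the encoder above lifts each scalar holding percentage into the wallet-embedding space with a small MLP and adds it to the wallet embedding. A self-contained sketch of just that fusion step (dimensions and values are made up for illustration; `float32` is used so it runs on CPU):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d = 32  # assumed wallet embedding dimension

# Scalar pct -> d-dimensional vector, as HolderDistributionEncoder.pct_proj does
pct_proj = nn.Sequential(nn.Linear(1, d // 4), nn.GELU(), nn.Linear(d // 4, d))

wallet_embeds = torch.randn(5, d)  # 5 top holders
pcts = torch.tensor([[0.20], [0.10], [0.05], [0.03], [0.01]])

# Fuse by addition and add a batch dimension, ready for a batch_first Transformer
holder_inputs = (wallet_embeds + pct_proj(pcts)).unsqueeze(0)  # shape [1, 5, 32]
```

Addition (rather than concatenation) keeps the sequence width at `d_model`, so the Transformer sees one token per holder.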
models/SocialEncoders.py ADDED
@@ -0,0 +1,245 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Optional, Dict, Any
+ import models.vocabulary as vocab  # For event type IDs
+
+ class XPostEncoder(nn.Module):
+     """ Encodes: <AuthorWalletEmbedding>, <PostTextEmbedding>, <MediaEmbedding> """
+     def __init__(self, d_model: int, dtype: torch.dtype):
+         super().__init__()
+         # Input: Wallet (d_model) + Text (d_model) + Media (d_model)
+         self.mlp = nn.Sequential(
+             nn.Linear(d_model * 3, d_model * 2),
+             nn.GELU(),
+             nn.LayerNorm(d_model * 2),
+             nn.Linear(d_model * 2, d_model)
+         ).to(dtype)
+
+     def forward(self, author_emb: torch.Tensor, text_emb: torch.Tensor, media_emb: torch.Tensor) -> torch.Tensor:
+         combined = torch.cat([author_emb, text_emb, media_emb], dim=-1)
+         return self.mlp(combined)
+
+ class XRetweetEncoder(nn.Module):
+     """ Encodes: <RetweeterWalletEmbedding>, <OriginalAuthorWalletEmbedding>, <OriginalPostTextEmbedding>, <OriginalPostMediaEmbedding> """
+     def __init__(self, d_model: int, dtype: torch.dtype):
+         super().__init__()
+         # Input: Retweeter (d_model) + Original Author (d_model) + Original Text (d_model) + Original Media (d_model)
+         self.mlp = nn.Sequential(
+             nn.Linear(d_model * 4, d_model * 2),
+             nn.GELU(),
+             nn.LayerNorm(d_model * 2),
+             nn.Linear(d_model * 2, d_model)
+         ).to(dtype)
+
+     def forward(self,
+                 retweeter_emb: torch.Tensor,
+                 orig_author_emb: torch.Tensor,
+                 orig_text_emb: torch.Tensor,
+                 orig_media_emb: torch.Tensor) -> torch.Tensor:
+         combined = torch.cat([retweeter_emb, orig_author_emb, orig_text_emb, orig_media_emb], dim=-1)
+         return self.mlp(combined)
+
+ class XReplyEncoder(nn.Module):
+     """ Encodes: <AuthorWalletEmbedding>, <PostTextEmbedding>, <MediaEmbedding>, <MainTweetEmbedding> """
+     def __init__(self, d_model: int, dtype: torch.dtype):
+         super().__init__()
+         # Input: Author (d_model) + Reply Text (d_model) + Reply Media (d_model) + Main Tweet Text (d_model)
+         self.mlp = nn.Sequential(
+             nn.Linear(d_model * 4, d_model * 2),
+             nn.GELU(),
+             nn.LayerNorm(d_model * 2),
+             nn.Linear(d_model * 2, d_model)
+         ).to(dtype)
+
+     def forward(self,
+                 author_emb: torch.Tensor,
+                 text_emb: torch.Tensor,
+                 media_emb: torch.Tensor,
+                 main_tweet_emb: torch.Tensor) -> torch.Tensor:
+         combined = torch.cat([author_emb, text_emb, media_emb, main_tweet_emb], dim=-1)
+         return self.mlp(combined)
+
+ class XQuoteTweetEncoder(nn.Module):
+     """ Encodes: <QuoterWalletEmbedding>, <QuoterTextEmbedding>, <OriginalAuthorWalletEmbedding>, <OriginalPostTextEmbedding>, <OriginalPostMediaEmbedding> """
+     def __init__(self, d_model: int, dtype: torch.dtype):
+         super().__init__()
+         # Input: Quoter Wallet (d_model) + Quoter Text (d_model) + Orig Author (d_model) + Orig Text (d_model) + Orig Media (d_model)
+         self.mlp = nn.Sequential(
+             nn.Linear(d_model * 5, d_model * 2),
+             nn.GELU(),
+             nn.LayerNorm(d_model * 2),
+             nn.Linear(d_model * 2, d_model)
+         ).to(dtype)
+
+     def forward(self,
+                 quoter_wallet_emb: torch.Tensor,
+                 quoter_text_emb: torch.Tensor,
+                 orig_author_emb: torch.Tensor,
+                 orig_text_emb: torch.Tensor,
+                 orig_media_emb: torch.Tensor) -> torch.Tensor:
+         combined = torch.cat([quoter_wallet_emb, quoter_text_emb, orig_author_emb, orig_text_emb, orig_media_emb], dim=-1)
+         return self.mlp(combined)
+
+ class PumpReplyEncoder(nn.Module):
+     """ Encodes: <UserWalletEmbedding>, <ReplyTextEmbedding> """
+     def __init__(self, d_model: int, dtype: torch.dtype):
+         super().__init__()
+         # Input: User Wallet (d_model) + Reply Text (d_model)
+         self.mlp = nn.Sequential(
+             nn.Linear(d_model * 2, d_model * 2),
+             nn.GELU(),
+             nn.LayerNorm(d_model * 2),
+             nn.Linear(d_model * 2, d_model)
+         ).to(dtype)
+
+     def forward(self, user_emb: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
+         combined = torch.cat([user_emb, text_emb], dim=-1)
+         return self.mlp(combined)
+
+ # --- Encoders for other text-based events ---
+ class DexProfileUpdatedEncoder(nn.Module):
+     """ Encodes: <website_emb>, <twitter_emb>, <telegram_emb>, <description_emb> (the profile flags are projected separately in the main model) """
+     def __init__(self, d_model: int, dtype: torch.dtype):
+         super().__init__()
+         # Input: 4x text embeds (d_model each); the flags are handled separately
+         self.mlp = nn.Sequential(
+             nn.Linear(d_model * 4, d_model * 2),
+             nn.GELU(),
+             nn.LayerNorm(d_model * 2),
+             nn.Linear(d_model * 2, d_model)
+         ).to(dtype)
+
+     def forward(self, website_emb: torch.Tensor, twitter_emb: torch.Tensor, telegram_emb: torch.Tensor, description_emb: torch.Tensor) -> torch.Tensor:
+         combined = torch.cat([website_emb, twitter_emb, telegram_emb, description_emb], dim=-1)
+         return self.mlp(combined)
+
+ class GlobalTrendingEncoder(nn.Module):
+     """ Encodes: <hashtag_emb> """
+     def __init__(self, d_model: int, dtype: torch.dtype):
+         super().__init__()
+         # Input: hashtag_emb (d_model)
+         self.mlp = nn.Sequential(
+             nn.Linear(d_model, d_model),
+             nn.GELU(),
+             nn.Linear(d_model, d_model)
+         ).to(dtype)
+
+     def forward(self, hashtag_emb: torch.Tensor) -> torch.Tensor:
+         return self.mlp(hashtag_emb)
+
+ class SocialEncoder(nn.Module):
+     """
+     A single module to house all social event encoders.
+     This simplifies instantiation in the main Oracle model.
+     """
+     def __init__(self, d_model: int, dtype: torch.dtype):
+         super().__init__()
+         self.x_post_encoder = XPostEncoder(d_model, dtype)
+         self.x_retweet_encoder = XRetweetEncoder(d_model, dtype)
+         self.x_reply_encoder = XReplyEncoder(d_model, dtype)
+         self.x_quote_tweet_encoder = XQuoteTweetEncoder(d_model, dtype)
+         self.pump_reply_encoder = PumpReplyEncoder(d_model, dtype)
+         self.dex_profile_encoder = DexProfileUpdatedEncoder(d_model, dtype)
+         self.global_trending_encoder = GlobalTrendingEncoder(d_model, dtype)
+
+         # Store for convenience
+         self.d_model = d_model
+         self.dtype = dtype
+
+     def forward(self, batch: Dict[str, Any], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
+         """
+         Processes all text-based events for the entire batch in a vectorized way,
+         replacing the per-event loops previously used in the main Oracle model.
+         """
+         device = gathered_embeds['wallet'].device
+         B, L, D = gathered_embeds['wallet'].shape
+         final_embeds = torch.zeros(B, L, D, device=device, dtype=self.dtype)
+
+         textual_event_indices = batch['textual_event_indices']
+         textual_event_data = batch.get('textual_event_data', [])
+         precomputed_lookup = gathered_embeds['precomputed']
+
+         # --- Create masks for each event type ---
+         event_type_ids = batch['event_type_ids']
+         event_masks = {
+             'XPost': (event_type_ids == vocab.EVENT_TO_ID.get('XPost', -1)),
+             'XReply': (event_type_ids == vocab.EVENT_TO_ID.get('XReply', -1)),
+             'XRetweet': (event_type_ids == vocab.EVENT_TO_ID.get('XRetweet', -1)),
+             'XQuoteTweet': (event_type_ids == vocab.EVENT_TO_ID.get('XQuoteTweet', -1)),
+             'PumpReply': (event_type_ids == vocab.EVENT_TO_ID.get('PumpReply', -1)),
+             'DexProfile_Updated': (event_type_ids == vocab.EVENT_TO_ID.get('DexProfile_Updated', -1)),
+             'TikTok_Trending_Hashtag': (event_type_ids == vocab.EVENT_TO_ID.get('TikTok_Trending_Hashtag', -1)),
+             'XTrending_Hashtag': (event_type_ids == vocab.EVENT_TO_ID.get('XTrending_Hashtag', -1)),
+         }
+
+         # --- Gather all necessary pre-computed embeddings in one go ---
+         # Flatten indices for efficient lookup, then reshape
+         flat_indices = textual_event_indices.flatten()
+         # Create a default event structure for padding indices (idx=0)
+         default_event = {'event_type': 'PAD'}
+         # Indices from the collator are 1-based, so textual_event_data[idx - 1]
+         raw_events_flat = [textual_event_data[idx - 1] if idx > 0 else default_event for idx in flat_indices.tolist()]
+
+         # Helper to gather embeddings for a specific key
+         def gather_precomputed(key: str) -> torch.Tensor:
+             indices = torch.tensor([e.get(key, 0) for e in raw_events_flat], device=device, dtype=torch.long)
+             return F.embedding(indices, precomputed_lookup).view(B, L, -1)
+
+         # --- Process each event type ---
+
+         # XPost
+         if event_masks['XPost'].any():
+             text_emb = gather_precomputed('text_emb_idx')
+             media_emb = gather_precomputed('media_emb_idx')
+             post_embeds = self.x_post_encoder(gathered_embeds['wallet'], text_emb, media_emb)
+             final_embeds += post_embeds * event_masks['XPost'].unsqueeze(-1)
+
+         # XReply
+         if event_masks['XReply'].any():
+             text_emb = gather_precomputed('text_emb_idx')
+             media_emb = gather_precomputed('media_emb_idx')
+             main_tweet_emb = gather_precomputed('main_tweet_text_emb_idx')
+             reply_embeds = self.x_reply_encoder(gathered_embeds['wallet'], text_emb, media_emb, main_tweet_emb)
+             final_embeds += reply_embeds * event_masks['XReply'].unsqueeze(-1)
+
+         # XRetweet
+         if event_masks['XRetweet'].any():
+             orig_text_emb = gather_precomputed('original_post_text_emb_idx')
+             orig_media_emb = gather_precomputed('original_post_media_emb_idx')
+             retweet_embeds = self.x_retweet_encoder(gathered_embeds['wallet'], gathered_embeds['original_author'], orig_text_emb, orig_media_emb)
+             final_embeds += retweet_embeds * event_masks['XRetweet'].unsqueeze(-1)
+
+         # XQuoteTweet
+         if event_masks['XQuoteTweet'].any():
+             quoter_text_emb = gather_precomputed('quoter_text_emb_idx')
+             orig_text_emb = gather_precomputed('original_post_text_emb_idx')
+             orig_media_emb = gather_precomputed('original_post_media_emb_idx')
+             quote_embeds = self.x_quote_tweet_encoder(gathered_embeds['wallet'], quoter_text_emb, gathered_embeds['original_author'], orig_text_emb, orig_media_emb)
+             final_embeds += quote_embeds * event_masks['XQuoteTweet'].unsqueeze(-1)
+
+         # PumpReply
+         if event_masks['PumpReply'].any():
+             text_emb = gather_precomputed('reply_text_emb_idx')
+             pump_reply_embeds = self.pump_reply_encoder(gathered_embeds['wallet'], text_emb)
+             final_embeds += pump_reply_embeds * event_masks['PumpReply'].unsqueeze(-1)
+
+         # DexProfile_Updated
+         if event_masks['DexProfile_Updated'].any():
+             website_emb = gather_precomputed('website_emb_idx')
+             twitter_emb = gather_precomputed('twitter_link_emb_idx')
+             telegram_emb = gather_precomputed('telegram_link_emb_idx')
+             description_emb = gather_precomputed('description_emb_idx')
+             profile_embeds = self.dex_profile_encoder(website_emb, twitter_emb, telegram_emb, description_emb)
+             # The flags are handled separately in the main model, so only the text embeds are added here
+             final_embeds += profile_embeds * event_masks['DexProfile_Updated'].unsqueeze(-1)
+
+         # Global Trending Hashtags
+         trending_mask = event_masks['TikTok_Trending_Hashtag'] | event_masks['XTrending_Hashtag']
+         if trending_mask.any():
+             hashtag_emb = gather_precomputed('hashtag_name_emb_idx')
+             trending_embeds = self.global_trending_encoder(hashtag_emb)
+             final_embeds += trending_embeds * trending_mask.unsqueeze(-1)
+
+         return final_embeds
models/__init__.py ADDED
File without changes
models/graph_updater.py ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ # We still use GATv2Conv, just not the to_hetero wrapper
4
+ from torch_geometric.nn import GATv2Conv
5
+ from torch_geometric.data import HeteroData
6
+ from typing import Dict, List, Any
7
+ from collections import defaultdict # For easy aggregation
8
+ from PIL import Image
9
+
10
+ from models.helper_encoders import ContextualTimeEncoder # Type hint for constructor compatibility
11
+ # Import the actual ID_TO_LINK_TYPE mapping
12
+ from models.vocabulary import ID_TO_LINK_TYPE
13
+ # Import other modules needed for the test block
14
+ import models.vocabulary
15
+ from models.wallet_encoder import WalletEncoder
16
+ from models.token_encoder import TokenEncoder
17
+ from models.multi_modal_processor import MultiModalEncoder
18
+
19
+
20
+ class _TransferLinkEncoder(nn.Module):
21
+ """Encodes: transfer amount only (timestamps removed)."""
22
+ def __init__(self, out_dim: int, dtype: torch.dtype):
23
+ super().__init__()
24
+ self.proj = nn.Sequential(
25
+ nn.Linear(1, out_dim),
26
+ nn.GELU(),
27
+ nn.Linear(out_dim, out_dim)
28
+ )
29
+ self.dtype = dtype
30
+
31
+ def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
32
+ return torch.sign(x) * torch.log1p(torch.abs(x))
33
+
34
+ def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
35
+ amounts = torch.tensor([[l.get('amount', 0.0)] for l in links], device=device, dtype=self.dtype)
36
+ features = self._safe_signed_log(amounts)
37
+
38
+ return self.proj(features)
39
+
40
+ class _BundleTradeLinkEncoder(nn.Module):
41
+ """Encodes: total_amount across bundle (timestamps removed)."""
42
+ def __init__(self, out_dim: int, dtype: torch.dtype):
43
+ super().__init__()
44
+ self.proj = nn.Sequential(
45
+ nn.Linear(1, out_dim),
46
+ nn.GELU(),
47
+ nn.Linear(out_dim, out_dim)
48
+ )
49
+ self.dtype = dtype
50
+
51
+ def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
52
+ return torch.sign(x) * torch.log1p(torch.abs(x))
53
+
54
+ def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
55
+ totals = torch.tensor([[l.get('total_amount', 0.0)] for l in links], device=device, dtype=self.dtype)
56
+ total_embeds = self._safe_signed_log(totals)
57
+
58
+ return self.proj(total_embeds)
59
+
60
+ class _CopiedTradeLinkEncoder(nn.Module):
61
+ """ Encodes: 10 numerical features """
62
+ def __init__(self, in_features: int, out_dim: int, dtype: torch.dtype): # Added dtype
63
+ super().__init__()
64
+ self.in_features = in_features
65
+ self.norm = nn.LayerNorm(in_features)
66
+ self.mlp = nn.Sequential(
67
+ nn.Linear(in_features, out_dim * 2), nn.GELU(),
68
+ nn.Linear(out_dim * 2, out_dim)
69
+ )
70
+ self.dtype = dtype # Store dtype
71
+
72
+ def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
73
+ return torch.sign(x) * torch.log1p(torch.abs(x))
74
+
75
+ def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
76
+ num_data = []
77
+ for l in links:
78
+ # --- FIXED: Only use the 6 essential features ---
79
+ num_data.append([
80
+ l.get('time_gap_on_buy_sec', 0), l.get('time_gap_on_sell_sec', 0),
81
+ l.get('leader_pnl', 0), l.get('follower_pnl', 0),
82
+ l.get('follower_buy_total', 0), l.get('follower_sell_total', 0)
83
+ ])
84
+ # Create tensor with correct dtype
85
+ x = torch.tensor(num_data, device=device, dtype=self.dtype)
86
+ # Input to norm must match norm's dtype
87
+ x_norm = self.norm(self._safe_signed_log(x))
88
+ return self.mlp(x_norm)
89
+
90
+ class _CoordinatedActivityLinkEncoder(nn.Module):
91
+ """ Encodes: 2 numerical features """
92
+ def __init__(self, in_features: int, out_dim: int, dtype: torch.dtype): # Added dtype
93
+ super().__init__()
94
+ self.in_features = in_features
95
+ self.norm = nn.LayerNorm(in_features)
96
+ self.mlp = nn.Sequential(
97
+ nn.Linear(in_features, out_dim), nn.GELU(),
98
+ nn.Linear(out_dim, out_dim)
99
+ )
100
+ self.dtype = dtype # Store dtype
101
+
102
+ def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
103
+ return torch.sign(x) * torch.log1p(torch.abs(x))
104
+
105
+ def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
106
+ num_data = []
107
+ for l in links:
108
+ num_data.append([
109
+ l.get('time_gap_on_first_sec', 0), l.get('time_gap_on_second_sec', 0)
110
+ ])
111
+ # Create tensor with correct dtype
112
+ x = torch.tensor(num_data, device=device, dtype=self.dtype)
113
+ x_norm = self.norm(self._safe_signed_log(x))
114
+ return self.mlp(x_norm)
115
+
116
+ class _MintedLinkEncoder(nn.Module):
117
+ """Encodes: buy_amount only (timestamps removed)."""
118
+ def __init__(self, out_dim: int, dtype: torch.dtype):
119
+ super().__init__()
120
+ self.proj = nn.Sequential(
121
+ nn.Linear(1, out_dim),
122
+ nn.GELU(),
123
+ nn.Linear(out_dim, out_dim)
124
+ )
125
+ self.dtype = dtype # Store dtype
126
+
127
+ def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
128
+ return torch.sign(x) * torch.log1p(torch.abs(x))
129
+
130
+ def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
131
+ nums = torch.tensor([[l['buy_amount']] for l in links], device=device, dtype=self.dtype)
132
+
133
+ num_embeds = self._safe_signed_log(nums)
134
+
135
+ return self.proj(num_embeds)
136
+
137
+ class _SnipedLinkEncoder(nn.Module):
138
+ """ Encodes: rank, sniped_amount """
139
+ def __init__(self, in_features: int, out_dim: int, dtype: torch.dtype): # Added dtype
140
+ super().__init__()
141
+ self.norm = nn.LayerNorm(in_features)
142
+ self.mlp = nn.Sequential(nn.Linear(in_features, out_dim), nn.GELU(), nn.Linear(out_dim, out_dim))
143
+ self.dtype = dtype # Store dtype
144
+
145
+ def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
146
+ num_data = [[l.get('rank', 0), l.get('sniped_amount', 0)] for l in links]
147
+ # Create tensor with correct dtype
148
+ x = torch.tensor(num_data, device=device, dtype=self.dtype)
149
+
150
+ # --- FIXED: Selectively log-scale features ---
151
+ # Invert rank so 1 is highest, treat as linear. Log-scale sniped_amount.
152
+ x[:, 0] = 1.0 / torch.clamp(x[:, 0], min=1.0) # Invert rank, clamp to avoid division by zero
153
+ x[:, 1] = torch.sign(x[:, 1]) * torch.log1p(torch.abs(x[:, 1])) # Log-scale amount
154
+
155
+ x_norm = self.norm(x)
156
+ return self.mlp(x_norm)
157
+
158
+ class _LockedSupplyLinkEncoder(nn.Module):
159
+ """ Encodes: amount """
160
+ def __init__(self, out_dim: int, dtype: torch.dtype): # Removed time_encoder
161
+ super().__init__()
162
+ self.proj = nn.Sequential(
163
+ nn.Linear(1, out_dim),
164
+ nn.GELU(),
165
+ nn.Linear(out_dim, out_dim)
166
+ )
167
+ self.dtype = dtype # Store dtype
168
+
169
+ def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
170
+ return torch.sign(x) * torch.log1p(torch.abs(x))
171
+
172
+ def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
173
+ nums = torch.tensor([[l['amount']] for l in links], device=device, dtype=self.dtype)
174
+ num_embeds = self._safe_signed_log(nums)
175
+ return self.proj(num_embeds)
176
+
177
+ class _BurnedLinkEncoder(nn.Module):
178
+ """Encodes: burned amount (timestamps removed)."""
179
+ def __init__(self, out_dim: int, dtype: torch.dtype):
180
+ super().__init__()
181
+ self.proj = nn.Sequential(
182
+ nn.Linear(1, out_dim),
183
+ nn.GELU(),
184
+ nn.Linear(out_dim, out_dim)
185
+ )
186
+ self.dtype = dtype
187
+
188
+ def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
189
+        return torch.sign(x) * torch.log1p(torch.abs(x))
+
+    def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
+        amounts = torch.tensor([[l.get('amount', 0.0)] for l in links], device=device, dtype=self.dtype)
+        amount_embeds = self._safe_signed_log(amounts)
+
+        return self.proj(amount_embeds)
+
+class _ProvidedLiquidityLinkEncoder(nn.Module):
+    """Encodes: quote amount (timestamps removed)."""
+    def __init__(self, out_dim: int, dtype: torch.dtype):
+        super().__init__()
+        self.proj = nn.Sequential(
+            nn.Linear(1, out_dim),
+            nn.GELU(),
+            nn.Linear(out_dim, out_dim)
+        )
+        self.dtype = dtype
+
+    def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.sign(x) * torch.log1p(torch.abs(x))
+
+    def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
+        quote_amounts = torch.tensor([[l.get('amount_quote', 0.0)] for l in links], device=device, dtype=self.dtype)
+        quote_embeds = self._safe_signed_log(quote_amounts)
+
+        return self.proj(quote_embeds)
+
+class _WhaleOfLinkEncoder(nn.Module):
+    """Encodes: holding_pct_at_creation."""
+    def __init__(self, out_dim: int, dtype: torch.dtype):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            nn.Linear(1, out_dim),
+            nn.GELU(),
+            nn.Linear(out_dim, out_dim)
+        )
+        self.dtype = dtype
+
+    def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
+        vals = torch.tensor([[l.get('holding_pct_at_creation', 0.0)] for l in links], device=device, dtype=self.dtype)
+        vals_log = torch.sign(vals) * torch.log1p(torch.abs(vals))
+        return self.mlp(vals_log)
+
+class _TopTraderOfLinkEncoder(nn.Module):
+    """Encodes: pnl_at_creation."""
+    def __init__(self, out_dim: int, dtype: torch.dtype):  # Removed in_features
+        super().__init__()
+        self.mlp = nn.Sequential(nn.Linear(1, out_dim), nn.GELU(), nn.Linear(out_dim, out_dim))
+        self.dtype = dtype
+
+    def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
+        num_data = [[l.get('pnl_at_creation', 0)] for l in links]
+        x = torch.tensor(num_data, device=device, dtype=self.dtype)
+        log_scaled_x = torch.sign(x) * torch.log1p(torch.abs(x))
+        return self.mlp(log_scaled_x)
+
+
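All of the link encoders above rely on the same signed-log transform, `sign(x) * log1p(|x|)`, to compress heavy-tailed on-chain amounts while preserving sign and keeping zero at zero. A minimal scalar sketch of the same transform (plain Python rather than torch; the function name `signed_log` is mine):

```python
import math

def signed_log(x: float) -> float:
    # sign(x) * log1p(|x|): an odd function, zero-preserving,
    # roughly linear near 0 and logarithmic for large |x|
    return math.copysign(math.log1p(abs(x)), x)
```

This keeps a 10x difference in raw amount roughly a constant additive step in feature space, which behaves far better under LayerNorm than raw token amounts spanning many orders of magnitude.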
+class RelationalGATBlock(nn.Module):
+    """
+    Shared GATv2Conv that remains relation-aware by concatenating a learned
+    relation embedding to every edge attribute before message passing.
+    """
+
+    def __init__(
+        self,
+        node_dim: int,
+        edge_attr_dim: int,
+        n_heads: int,
+        relations: List[str],
+        dtype: torch.dtype,
+    ):
+        super().__init__()
+        self.rel_to_id = {name: idx for idx, name in enumerate(relations)}
+        self.edge_attr_dim = edge_attr_dim
+        self.rel_emb = nn.Embedding(len(relations), edge_attr_dim)
+        self.conv = GATv2Conv(
+            in_channels=node_dim,
+            out_channels=node_dim,
+            heads=n_heads,
+            concat=False,
+            dropout=0.1,
+            add_self_loops=False,
+            edge_dim=edge_attr_dim * 2,  # concat of edge attr + relation emb
+        ).to(dtype)
+
+    def forward(
+        self,
+        x_src: torch.Tensor,
+        x_dst: torch.Tensor,
+        edge_index: torch.Tensor,
+        edge_attr: torch.Tensor,
+        rel_type: str,
+    ) -> torch.Tensor:
+        num_edges = edge_index.size(1)
+        device = edge_index.device
+
+        if edge_attr is None:
+            edge_attr = torch.zeros(
+                num_edges,
+                self.edge_attr_dim,
+                device=device,
+                dtype=x_src.dtype,
+            )
+
+        rel_id = self.rel_to_id.get(rel_type)
+        if rel_id is None:
+            raise KeyError(f"Relation '{rel_type}' not registered in RelationalGATBlock.")
+
+        rel_feat = self.rel_emb.weight[rel_id].to(edge_attr.dtype)
+        rel_feat = rel_feat.expand(num_edges, -1)
+        augmented_attr = torch.cat([edge_attr, rel_feat], dim=-1)
+
+        return self.conv((x_src, x_dst), edge_index, edge_attr=augmented_attr)
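`RelationalGATBlock` keeps a single shared GATv2Conv per (src, dst) node-type pair and injects relation identity by concatenating a learned per-relation vector onto every edge attribute row, which is why the conv's `edge_dim` is `edge_attr_dim * 2`. The augmentation step can be sketched without torch (plain lists stand in for tensors; names here are mine):

```python
def augment_edge_attrs(edge_attrs, rel_emb_table, rel_id):
    """Concatenate the relation's learned vector onto each edge attribute row."""
    rel_vec = rel_emb_table[rel_id]
    # list concatenation plays the role of torch.cat([...], dim=-1)
    return [row + rel_vec for row in edge_attrs]

rel_table = {0: [0.1, 0.2], 1: [0.9, 0.8]}   # stand-in 2-dim relation embeddings
attrs = [[1.0, 2.0], [3.0, 4.0]]             # 2 edges, edge_attr_dim = 2
augmented = augment_edge_attrs(attrs, rel_table, 1)
```

Because the relation vector is learned, the shared attention weights can specialize per relation without a separate conv (and its parameters) for every edge type.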
+
+# =============================================================================
+# 2. The Main GraphUpdater (GNN) - MANUAL HETEROGENEOUS IMPLEMENTATION
+# =============================================================================
+
+class GraphUpdater(nn.Module):
+    """
+    FIXED: Manually implements heterogeneous GNN logic using separate GATv2Conv
+    layers for each edge type, bypassing the problematic `to_hetero` wrapper.
+    """
+
+    def __init__(self, time_encoder: ContextualTimeEncoder, edge_attr_dim: int = 64,
+                 n_heads: int = 4, num_layers: int = 2, node_dim: int = 2048, dtype: torch.dtype = torch.float16):
+        super().__init__()
+        self.node_dim = node_dim
+        self.edge_attr_dim = edge_attr_dim
+        self.num_layers = num_layers
+        self.dtype = dtype
+
+        # --- Instantiate all 12 Link Feature Encoders --- (Unchanged)
+        self.edge_encoders = nn.ModuleDict({
+            'TransferLink': _TransferLinkEncoder(edge_attr_dim, dtype=dtype),
+            'TransferLinkToken': _TransferLinkEncoder(edge_attr_dim, dtype=dtype),
+            'BundleTradeLink': _BundleTradeLinkEncoder(edge_attr_dim, dtype=dtype),
+            'CopiedTradeLink': _CopiedTradeLinkEncoder(6, edge_attr_dim, dtype=dtype),  # FIXED: in_features=6
+            'CoordinatedActivityLink': _CoordinatedActivityLinkEncoder(2, edge_attr_dim, dtype=dtype),
+            'MintedLink': _MintedLinkEncoder(edge_attr_dim, dtype=dtype),
+            'SnipedLink': _SnipedLinkEncoder(2, edge_attr_dim, dtype=dtype),
+            'LockedSupplyLink': _LockedSupplyLinkEncoder(edge_attr_dim, dtype=dtype),  # FIXED: No time_encoder
+            'BurnedLink': _BurnedLinkEncoder(edge_attr_dim, dtype=dtype),
+            'ProvidedLiquidityLink': _ProvidedLiquidityLinkEncoder(edge_attr_dim, dtype=dtype),
+            'WhaleOfLink': _WhaleOfLinkEncoder(edge_attr_dim, dtype=dtype),  # FIXED: No in_features
+            'TopTraderOfLink': _TopTraderOfLinkEncoder(edge_attr_dim, dtype=dtype),  # FIXED: No in_features
+        }).to(dtype)
+
+        # --- Define shared relational GNN blocks per meta edge direction ---
+        self.edge_groups = self._build_edge_groups()
+        self.conv_layers = nn.ModuleList()
+        for _ in range(num_layers):
+            conv_dict = nn.ModuleDict()
+            for (src_type, dst_type), relations in self.edge_groups.items():
+                conv_dict[f"{src_type}__{dst_type}"] = RelationalGATBlock(
+                    node_dim=node_dim,
+                    edge_attr_dim=edge_attr_dim,
+                    n_heads=n_heads,
+                    relations=relations,
+                    dtype=dtype,
+                )
+            self.conv_layers.append(conv_dict)
+
+        self.norm = nn.LayerNorm(node_dim)
+        self.to(dtype)  # Move norm layer and ModuleList container
+
+    def _build_edge_groups(self) -> Dict[tuple, List[str]]:
+        """Group relations by (src_type, dst_type) so conv weights can be shared."""
+        groups: Dict[tuple, List[str]] = defaultdict(list)
+
+        wallet_wallet_links = ['TransferLink', 'BundleTradeLink', 'CopiedTradeLink', 'CoordinatedActivityLink']
+        wallet_token_links = [
+            'TransferLinkToken', 'MintedLink', 'SnipedLink', 'LockedSupplyLink',
+            'BurnedLink', 'ProvidedLiquidityLink', 'WhaleOfLink', 'TopTraderOfLink'
+        ]
+
+        for link in wallet_wallet_links:
+            groups[('wallet', 'wallet')].append(link)
+            groups[('wallet', 'wallet')].append(f"rev_{link}")
+
+        for link in wallet_token_links:
+            groups[('wallet', 'token')].append(link)
+            groups[('token', 'wallet')].append(f"rev_{link}")
+
+        return groups
+
+    def forward(
+        self,
+        x_dict: Dict[str, torch.Tensor],
+        edge_data_dict: Dict[str, Dict[str, Any]]
+    ) -> Dict[str, torch.Tensor]:
+        device = x_dict['wallet'].device
+
+        # --- 1. Encode Edge Attributes ---
+        edge_index_dict = {}
+        edge_attr_dict = {}
+
+        for link_name, data in edge_data_dict.items():
+            edge_index = data.get('edge_index')
+            links = data.get('links', [])
+
+            # Check that edge_index is valid before proceeding
+            if edge_index is None or edge_index.numel() == 0 or not links:
+                continue  # Skip if no links or index of this type
+
+            edge_index = edge_index.to(device)
+
+            # Use the vocabulary to get the triplet (src, rel, dst).
+            # Make sure LINK_NAME_TO_TRIPLET is correctly populated.
+            if link_name not in vocabulary.LINK_NAME_TO_TRIPLET:
+                print(f"Warning: Link name '{link_name}' not found in vocabulary.LINK_NAME_TO_TRIPLET. Skipping.")
+                continue
+            src_type, rel_type, dst_type = vocabulary.LINK_NAME_TO_TRIPLET[link_name]
+
+            # Check that an encoder exists for this link name
+            if link_name not in self.edge_encoders:
+                print(f"Warning: No edge encoder found for link type '{link_name}'. Skipping edge attributes.")
+                edge_attr = None  # Or handle differently if attributes are essential
+            else:
+                edge_attr = self.edge_encoders[link_name](links, device).to(self.dtype)
+
+            # Forward link
+            fwd_key = (src_type, rel_type, dst_type)
+            edge_index_dict[fwd_key] = edge_index
+            if edge_attr is not None:
+                edge_attr_dict[fwd_key] = edge_attr
+
+            # Reverse link
+            # Ensure edge_index has the right shape for flipping
+            if edge_index.shape[0] == 2:
+                rev_edge_index = edge_index[[1, 0]]
+                rev_rel_type = f'rev_{rel_type}'
+                rev_key = (dst_type, rev_rel_type, src_type)
+                edge_index_dict[rev_key] = rev_edge_index
+                if edge_attr is not None:
+                    # Re-use the same attributes for the reverse edge
+                    edge_attr_dict[rev_key] = edge_attr
+            else:
+                print(f"Warning: Edge index for {link_name} has unexpected shape {edge_index.shape}. Cannot create reverse edge.")
+
+        # --- 2. Run GNN Layers MANUALLY ---
+        x_out = x_dict
+        for i in range(self.num_layers):
+            # Initialize aggregation tensors for each node type that exists in the input
+            msg_aggregates = {
+                node_type: torch.zeros_like(x_node)
+                for node_type, x_node in x_out.items()
+            }
+
+            # --- Message Passing ---
+            for edge_type_tuple in edge_index_dict.keys():  # Iterate through edges PRESENT in the batch
+                src_type, rel_type, dst_type = edge_type_tuple
+                edge_index = edge_index_dict[edge_type_tuple]
+                edge_attr = edge_attr_dict.get(edge_type_tuple)  # Use .get() in case attr is None
+
+                x_src = x_out.get(src_type)
+                x_dst = x_out.get(dst_type)
+                if x_src is None or x_dst is None:
+                    print(f"Warning: Missing node embeddings for types {src_type}->{dst_type}. Skipping.")
+                    continue
+
+                block_key = f"{src_type}__{dst_type}"
+                if block_key not in self.conv_layers[i]:
+                    print(f"Warning: Relational block for {block_key} not found in layer {i}. Skipping.")
+                    continue
+                block = self.conv_layers[i][block_key]
+
+                try:
+                    messages = block(x_src, x_dst, edge_index, edge_attr, rel_type)
+                except KeyError:
+                    print(f"Warning: Relation '{rel_type}' missing in block {block_key}. Skipping.")
+                    continue
+
+                # *** THE FIX ***
+                # GATv2Conv with a bipartite (x_src, x_dst) input already returns one
+                # aggregated row per destination node, so the per-relation outputs can
+                # simply be summed. This correctly handles multiple edge types pointing
+                # to the same node type.
+                msg_aggregates[dst_type] += messages
+
+            # --- Aggregation & Update (Residual Connection) ---
+            x_next = {}
+            for node_type, x_original in x_out.items():
+                # Check that messages were computed and stored correctly
+                if node_type in msg_aggregates and msg_aggregates[node_type].shape[0] > 0:
+                    aggregated_msgs = msg_aggregates[node_type]
+                    # Ensure dimensions match before adding
+                    if x_original.shape == aggregated_msgs.shape:
+                        x_next[node_type] = self.norm(x_original + aggregated_msgs)
+                    else:
+                        print(f"Warning: Shape mismatch for node type {node_type} during update. Original: {x_original.shape}, Aggregated: {aggregated_msgs.shape}. Skipping residual connection.")
+                        x_next[node_type] = x_original  # Fallback
+                else:
+                    x_next[node_type] = x_original
+
+            x_out = x_next
+
+        return x_out
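The manual heterogeneous loop in `GraphUpdater.forward` merges outputs from many edge types into a single accumulator per node type before the residual LayerNorm update. When messages are produced per edge rather than per node, this merge is a scatter-add: each message row is summed into the row of its destination node. A dependency-free sketch of that index-based accumulation (plain lists stand in for tensors; names are mine):

```python
def scatter_add(messages, dst_index, num_nodes, dim):
    """Sum each message row into the row of its destination node
    (the pattern torch's scatter_add_ implements for dim=0)."""
    out = [[0.0] * dim for _ in range(num_nodes)]
    for msg, dst in zip(messages, dst_index):
        for j in range(dim):
            out[dst][j] += msg[j]
    return out

msgs = [[1.0, 1.0], [2.0, 0.0], [0.5, 0.5]]  # 3 edges, 2-dim messages
dst = [0, 2, 0]                               # edges point at nodes 0, 2, 0
agg = scatter_add(msgs, dst, num_nodes=3, dim=2)
```

Nodes that receive no edges keep a zero row, so the residual update above reduces to `norm(x_original)` for them only when messages were computed, and to `x_original` otherwise.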
models/helper_encoders.py ADDED
@@ -0,0 +1,87 @@
+import torch
+import torch.nn as nn
+import math
+import datetime
+from typing import Dict, List, Any, Optional
+
+class ContextualTimeEncoder(nn.Module):
+    def __init__(self, output_dim: int = 128, dtype: torch.dtype = torch.float32):
+        """
+        Encodes a Unix timestamp with support for mixed precision.
+
+        Args:
+            output_dim (int): The final dimension of the output embedding.
+            dtype (torch.dtype): The data type for the model's parameters (e.g., torch.float16).
+        """
+        super().__init__()
+        self.dtype = dtype
+        if output_dim < 12:
+            raise ValueError(f"output_dim must be at least 12, but got {output_dim}")
+
+        ts_dim = output_dim // 2
+        hour_dim = output_dim // 4
+        day_dim = output_dim - ts_dim - hour_dim
+
+        # Round each sub-dimension up to an even number so the sin/cos halves match
+        self.ts_dim = ts_dim + (ts_dim % 2)
+        self.hour_dim = hour_dim + (hour_dim % 2)
+        self.day_dim = day_dim + (day_dim % 2)
+
+        total_internal_dim = self.ts_dim + self.hour_dim + self.day_dim
+
+        self.projection = nn.Linear(total_internal_dim, output_dim)
+
+        # Cast the entire module to the specified dtype
+        self.to(dtype)
+
+    def _sinusoidal_encode(self, values: torch.Tensor, d_model: int) -> torch.Tensor:
+        device = values.device
+        half_dim = d_model // 2
+
+        # Calculations for sinusoidal encoding are more stable in float32
+        div_term = torch.exp(torch.arange(0, half_dim, device=device).float() * -(math.log(10000.0) / half_dim))
+        args = values.float().unsqueeze(-1) * div_term
+
+        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
+
+    def _cyclical_encode(self, values: torch.Tensor, d_model: int, max_val: float) -> torch.Tensor:
+        norm_values = (values.float() / max_val) * 2 * math.pi
+
+        half_dim = d_model // 2
+        args = norm_values.unsqueeze(-1).repeat(1, half_dim)
+
+        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
+
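The cyclical encoding above maps a bounded periodic quantity (hour of day, day of week) onto the unit circle via sin/cos, so hour 23 ends up close to hour 0 instead of maximally far, as it would with a raw scalar feature. A scalar sketch of the same idea (plain `math`, no torch; names are mine):

```python
import math

def cyclical_encode(value: float, max_val: float) -> tuple:
    """Map a periodic value in [0, max_val) onto the unit circle."""
    theta = (value / max_val) * 2 * math.pi
    return (math.sin(theta), math.cos(theta))

def dist(a, b):
    # Euclidean distance between two encoded points
    return math.hypot(a[0] - b[0], a[1] - b[1])
```

On the circle, 23:00 is a small step from midnight while 12:00 is on the opposite side, which matches the intuition a model should learn about time of day.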
+    def forward(self, timestamps: torch.Tensor) -> torch.Tensor:
+        device = self.projection.weight.device
+
+        # 1. Store original shape (e.g., [B, L]) and flatten
+        original_shape = timestamps.shape
+        timestamps_flat = timestamps.flatten().float()  # Shape [N_total]
+
+        # 2. Sinusoidal encoding (vectorized)
+        ts_encoding = self._sinusoidal_encode(timestamps_flat, self.ts_dim)
+
+        # 3. Extract hour-of-day and day-of-week. This Python loop over the
+        #    flat 1D tensor is the only non-vectorized step.
+        hours = torch.tensor([datetime.datetime.fromtimestamp(ts.item(), tz=datetime.timezone.utc).hour for ts in timestamps_flat], device=device, dtype=torch.float32)
+        days = torch.tensor([datetime.datetime.fromtimestamp(ts.item(), tz=datetime.timezone.utc).weekday() for ts in timestamps_flat], device=device, dtype=torch.float32)
+
+        # 4. Cyclical encoding (vectorized)
+        hour_encoding = self._cyclical_encode(hours, self.hour_dim, max_val=24.0)
+        day_encoding = self._cyclical_encode(days, self.day_dim, max_val=7.0)
+
+        # 5. Combine and project
+        combined_encoding = torch.cat([ts_encoding, hour_encoding, day_encoding], dim=1)
+        projected = self.projection(combined_encoding.to(self.dtype))  # Shape [N_total, output_dim]
+
+        # 6. Reshape to match the original shape (e.g., [B, L, output_dim])
+        output_shape = original_shape + (self.projection.out_features,)
+        return projected.view(output_shape)
+
+def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+    """Mask-aware mean pooling over the sequence dimension."""
+    mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
+    summed = torch.sum(last_hidden_state * mask, 1)
+    denom = torch.clamp(mask.sum(1), min=1e-9)
+    return summed / denom
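`mean_pool` averages the hidden states of only the non-padded positions: the attention mask zeroes out padding before the sum, and the denominator counts real tokens, clamped so an all-padding row cannot divide by zero. The same logic on plain lists (a sketch, not the torch version):

```python
def mean_pool(hidden, mask):
    """Average hidden-state rows where mask == 1, ignoring padded positions."""
    dim = len(hidden[0])
    summed = [0.0] * dim
    count = 0
    for row, m in zip(hidden, mask):
        if m:
            count += 1
            for j in range(dim):
                summed[j] += row[j]
    count = max(count, 1)  # mirrors torch.clamp(mask.sum(1), min=1e-9)
    return [s / count for s in summed]

# Third row is padding (mask 0) and must not influence the average
pooled = mean_pool([[2.0, 4.0], [6.0, 8.0], [99.0, 99.0]], [1, 1, 0])
```

Without the mask, the padded `[99.0, 99.0]` row would drag the pooled vector far from the real content; with it, the result is the mean of the two real rows.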
models/model.py ADDED
@@ -0,0 +1,1009 @@
+# model.py (REFACTORED AND FIXED)
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import AutoConfig, AutoModel
+from typing import List, Dict, Any, Optional, Tuple
+
+# --- Import all the encoders ---
+from models.helper_encoders import ContextualTimeEncoder
+from models.token_encoder import TokenEncoder
+from models.wallet_encoder import WalletEncoder
+from models.graph_updater import GraphUpdater
+from models.ohlc_embedder import OHLCEmbedder
+from models.HoldersEncoder import HolderDistributionEncoder  # NEW
+from models.SocialEncoders import SocialEncoder  # NEW
+import models.vocabulary as vocab  # For vocab sizes
+
+class Oracle(nn.Module):
+    """
+    Top-level model: fuses the token, wallet, graph, OHLC, holder and social
+    encoders into a single event sequence, runs it through a Qwen3 backbone,
+    and predicts quantiles over multiple time horizons.
+    """
+    def __init__(self,
+                 token_encoder: TokenEncoder,
+                 wallet_encoder: WalletEncoder,
+                 graph_updater: GraphUpdater,
+                 ohlc_embedder: OHLCEmbedder,  # NEW
+                 time_encoder: ContextualTimeEncoder,
+                 num_event_types: int,
+                 multi_modal_dim: int,
+                 event_pad_id: int,
+                 event_type_to_id: Dict[str, int],
+                 model_config_name: str = "Qwen/Qwen3-0.6B",
+                 quantiles: List[float] = [0.1, 0.5, 0.9],
+                 horizons_seconds: List[int] = [30, 60, 120, 240, 420],
+                 dtype: torch.dtype = torch.bfloat16):
+
+        super().__init__()
+
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = torch.device(device)
+        self.dtype = dtype
+        self.multi_modal_dim = multi_modal_dim
+
+        self.quantiles = quantiles
+        self.horizons_seconds = horizons_seconds
+        self.num_outputs = len(quantiles) * len(horizons_seconds)
+
+        # --- 2. Load Qwen3 Configuration (architecture only; training from scratch) ---
+        model_config = AutoConfig.from_pretrained(model_config_name, trust_remote_code=True)
+        self.d_model = model_config.hidden_size
+        self.model = AutoModel.from_config(model_config, trust_remote_code=True)
+        self.model.to(self.device, dtype=self.dtype)
+
+        # Quantile prediction head (maps pooled hidden state -> flattened horizon/quantile grid)
+        self.quantile_head = nn.Sequential(
+            nn.Linear(self.d_model, self.d_model),
+            nn.GELU(),
+            nn.Linear(self.d_model, self.num_outputs)
+        )
+
+        self.event_type_to_id = event_type_to_id
+
+        # --- 1. Store All Encoders ---
+        # Define token roles before using them
+        self.token_roles = {'main': 0, 'quote': 1, 'trending': 2}  # 'trending' reserved for future use
+        self.main_token_role_id = self.token_roles['main']
+        self.quote_token_role_id = self.token_roles['quote']
+        self.trending_token_role_id = self.token_roles['trending']
+
+        self.token_encoder = token_encoder
+        self.wallet_encoder = wallet_encoder
+        self.graph_updater = graph_updater
+        self.ohlc_embedder = ohlc_embedder
+        self.time_encoder = time_encoder
+
+        self.social_encoder = SocialEncoder(d_model=self.d_model, dtype=self.dtype)  # self.d_model is defined above
+
+        # --- 4. Define Sequence Feature Embeddings ---
+        self.event_type_embedding = nn.Embedding(num_event_types, self.d_model, padding_idx=event_pad_id)
+
+        # --- NEW: Token Role Embeddings ---
+        self.token_role_embedding = nn.Embedding(len(self.token_roles), self.d_model)
+
+        # --- 5. Define Entity Padding (Learnable) ---
+        self.pad_wallet_emb = nn.Parameter(torch.zeros(1, self.wallet_encoder.d_model))
+        self.pad_token_emb = nn.Parameter(torch.zeros(1, self.token_encoder.output_dim))
+        self.pad_ohlc_emb = nn.Parameter(torch.zeros(1, self.ohlc_embedder.output_dim))
+        self.pad_precomputed_emb = nn.Parameter(torch.zeros(1, self.multi_modal_dim))  # NEW: For text/images
+
+        # --- NEW: Instantiate HolderDistributionEncoder internally ---
+        self.holder_dist_encoder = HolderDistributionEncoder(
+            wallet_embedding_dim=self.wallet_encoder.d_model,
+            output_dim=self.d_model,
+            dtype=self.dtype  # Pass the correct dtype
+        )
+        self.pad_holder_snapshot_emb = nn.Parameter(torch.zeros(1, self.d_model))  # holder_dist_encoder outputs d_model
+
+        # --- 6. Define Projection MLPs ---
+        self.time_proj = nn.Linear(self.time_encoder.projection.out_features, self.d_model)
+        self.rel_ts_proj = nn.Linear(1, self.d_model)
+        self.rel_ts_norm = nn.LayerNorm(1)
+        self.wallet_proj = nn.Linear(self.wallet_encoder.d_model, self.d_model)
+        self.token_proj = nn.Linear(self.token_encoder.output_dim, self.d_model)
+        self.ohlc_proj = nn.Linear(self.ohlc_embedder.output_dim, self.d_model)
+        # self.holder_snapshot_proj is no longer needed: HolderDistributionEncoder outputs d_model directly
+
+        # --- NEW: Layers for Transfer Numerical Features ---
+        self.transfer_num_norm = nn.LayerNorm(4)  # Normalize the 4 features
+        self.transfer_num_proj = nn.Linear(4, self.d_model)  # Project to d_model
+
+        # --- NEW: Layers for Trade Numerical Features ---
+        # --- FIXED: Size reduced from 10 to 8 ---
+        self.trade_num_norm = nn.LayerNorm(8)
+        self.trade_num_proj = nn.Linear(8, self.d_model)
+        # --- NEW: Embedding for categorical dex_platform_id ---
+        self.dex_platform_embedding = nn.Embedding(vocab.NUM_DEX_PLATFORMS, self.d_model)
+        # --- NEW: Embedding for categorical trade_direction ---
+        self.trade_direction_embedding = nn.Embedding(2, self.d_model)  # 0 for buy, 1 for sell
+        # --- FIXED: Embedding for categorical mev_protection is now binary ---
+        self.mev_protection_embedding = nn.Embedding(2, self.d_model)  # 0 for false, 1 for true
+        # --- NEW: Embedding for categorical is_bundle ---
+        self.is_bundle_embedding = nn.Embedding(2, self.d_model)  # 0 for false, 1 for true
+
+        # --- NEW: Separate Layers for Deployer Trade Numerical Features ---
+        # --- FIXED: Size reduced from 10 to 8 ---
+        self.deployer_trade_num_norm = nn.LayerNorm(8)
+        self.deployer_trade_num_proj = nn.Linear(8, self.d_model)
+
+        # --- NEW: Separate Layers for Smart Wallet Trade Numerical Features ---
+        # --- FIXED: Size reduced from 10 to 8 ---
+        self.smart_wallet_trade_num_norm = nn.LayerNorm(8)
+        self.smart_wallet_trade_num_proj = nn.Linear(8, self.d_model)
+
+        # --- NEW: Layers for PoolCreated Numerical Features ---
+        # --- FIXED: Size reduced (now 2) ---
+        self.pool_created_num_norm = nn.LayerNorm(2)
+        self.pool_created_num_proj = nn.Linear(2, self.d_model)
+
+        # --- NEW: Layers for LiquidityChange Numerical Features ---
+        # --- FIXED: Size reduced (now 1) ---
+        self.liquidity_change_num_norm = nn.LayerNorm(1)
+        self.liquidity_change_num_proj = nn.Linear(1, self.d_model)
+        # --- NEW: Embedding for categorical change_type_id ---
+        # --- FIXED: Hardcoded the number of types (add/remove) ---
+        self.liquidity_change_type_embedding = nn.Embedding(2, self.d_model)
+
+        # --- NEW: Layers for FeeCollected Numerical Features ---
+        self.fee_collected_num_norm = nn.LayerNorm(1)  # sol_amount only
+        self.fee_collected_num_proj = nn.Linear(1, self.d_model)
+
+        # --- NEW: Layers for TokenBurn Numerical Features ---
+        self.token_burn_num_norm = nn.LayerNorm(2)  # amount_pct, amount_tokens
+        self.token_burn_num_proj = nn.Linear(2, self.d_model)
+
+        # --- NEW: Layers for SupplyLock Numerical Features ---
+        self.supply_lock_num_norm = nn.LayerNorm(2)  # amount_pct, lock_duration
+        self.supply_lock_num_proj = nn.Linear(2, self.d_model)
+
+        # --- NEW: Layers for OnChain_Snapshot Numerical Features ---
+        self.onchain_snapshot_num_norm = nn.LayerNorm(14)
+        self.onchain_snapshot_num_proj = nn.Linear(14, self.d_model)
+
+        # --- NEW: Layers for TrendingToken Numerical Features ---
+        # --- FIXED: Size reduced from 3 to 1 (rank only) ---
+        self.trending_token_num_norm = nn.LayerNorm(1)
+        self.trending_token_num_proj = nn.Linear(1, self.d_model)
+        # --- NEW: Embeddings for categorical IDs ---
+        self.trending_list_source_embedding = nn.Embedding(vocab.NUM_TRENDING_LIST_SOURCES, self.d_model)
+        self.trending_timeframe_embedding = nn.Embedding(vocab.NUM_TRENDING_LIST_TIMEFRAMES, self.d_model)
+
+        # --- NEW: Layers for BoostedToken Numerical Features ---
+        self.boosted_token_num_norm = nn.LayerNorm(2)  # total_boost_amount, rank
+        self.boosted_token_num_proj = nn.Linear(2, self.d_model)
+
+        # --- NEW: Layers for DexBoost_Paid Numerical Features ---
+        self.dexboost_paid_num_norm = nn.LayerNorm(2)  # amount, total_amount_on_token
+        self.dexboost_paid_num_proj = nn.Linear(2, self.d_model)
+
+        # --- NEW: Layers for DexProfile_Updated Features ---
+        self.dexprofile_updated_flags_proj = nn.Linear(4, self.d_model)  # Project the 4 boolean flags
+
+        # --- NEW: Projection for all pre-computed embeddings (text/images) ---
+        self.precomputed_proj = nn.Linear(self.multi_modal_dim, self.d_model)
+
+        # --- NEW: Embedding for Protocol IDs (used in Migrated event) ---
+        self.protocol_embedding = nn.Embedding(vocab.NUM_PROTOCOLS, self.d_model)
+
+        # --- NEW: Embeddings for TrackerEncoder Events ---
+        # Note: NUM_CALL_CHANNELS may need to be large and managed as the vocab grows.
+        self.alpha_group_embedding = nn.Embedding(vocab.NUM_ALPHA_GROUPS, self.d_model)
+        self.call_channel_embedding = nn.Embedding(vocab.NUM_CALL_CHANNELS, self.d_model)
+        self.cex_listing_embedding = nn.Embedding(vocab.NUM_EXCHANGES, self.d_model)
+
+        # --- NEW: Layers for GlobalTrendingEncoder Events ---
+        self.global_trending_num_norm = nn.LayerNorm(1)  # rank
+        self.global_trending_num_proj = nn.Linear(1, self.d_model)
+
+        # --- NEW: Layers for ChainSnapshot Events ---
+        self.chainsnapshot_num_norm = nn.LayerNorm(2)  # native_token_price_usd, gas_fee
+        self.chainsnapshot_num_proj = nn.Linear(2, self.d_model)
+
+        # --- NEW: Layers for Lighthouse_Snapshot Events ---
+        # --- FIXED: Size reduced from 7 to 5 ---
+        self.lighthousesnapshot_num_norm = nn.LayerNorm(5)
+        self.lighthousesnapshot_num_proj = nn.Linear(5, self.d_model)
+        # --- NEW: Embedding for the lighthouse timeframe ID ---
+        self.lighthouse_timeframe_embedding = nn.Embedding(vocab.NUM_LIGHTHOUSE_TIMEFRAMES, self.d_model)
+
+        # --- NEW: Embeddings for Special Context Tokens ---
+        self.special_context_tokens = {'Middle': 0, 'RECENT': 1}
+        self.special_context_embedding = nn.Embedding(len(self.special_context_tokens), self.d_model)
+
+        # --- 7. Prediction Head --- (replaced by self.quantile_head above)
+        # self.prediction_head = nn.Linear(self.d_model, self.num_outputs)
+
+        # --- 8. Move all new modules to the correct dtype ---
+        self.to(dtype)
+        print("Oracle model (full pipeline) initialized.")
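The quantile head emits `len(quantiles) * len(horizons_seconds)` values per sample as one flat vector. How that flat vector maps back onto a (horizon, quantile) grid is not shown in this part of the file; a plausible row-major layout (an assumption on my part, not confirmed by the source) can be sketched as:

```python
quantiles = [0.1, 0.5, 0.9]
horizons_seconds = [30, 60, 120, 240, 420]

def to_grid(flat):
    """Reshape the flat head output into {horizon: {quantile: value}}.
    Assumes row-major layout: all quantiles for horizon 0 first, and so on."""
    nq = len(quantiles)
    assert len(flat) == nq * len(horizons_seconds)
    return {
        h: dict(zip(quantiles, flat[i * nq:(i + 1) * nq]))
        for i, h in enumerate(horizons_seconds)
    }

grid = to_grid(list(range(15)))  # 15 = 3 quantiles * 5 horizons
```

Whichever layout the training code actually uses, head output and loss must agree on it; the sketch just makes the index arithmetic explicit.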
227
+
228
+ def _normalize_and_project(self,
229
+ features: torch.Tensor,
230
+ norm_layer: nn.LayerNorm,
231
+ proj_layer: nn.Linear,
232
+ log_indices: Optional[List[int]] = None) -> torch.Tensor:
233
+ """
234
+ A helper function to selectively apply log scaling, then normalize and project.
235
+ """
236
+ # Make a copy to avoid in-place modification issues
237
+ processed_features = features.clone()
238
+
239
+ # Apply log scaling only to specified indices
240
+ if log_indices:
241
+ # Ensure log_indices are valid
242
+ valid_indices = [i for i in log_indices if i < processed_features.shape[-1]]
243
+ if valid_indices:
244
+ log_features = processed_features[:, :, valid_indices].to(torch.float32)
245
+ log_scaled = torch.sign(log_features) * torch.log1p(torch.abs(log_features))
246
+ processed_features[:, :, valid_indices] = log_scaled.to(processed_features.dtype)
247
+
248
+ # Normalize and project the entire feature set
249
+ norm_dtype = norm_layer.weight.dtype
250
+ proj_dtype = proj_layer.weight.dtype
251
+ normed_features = norm_layer(processed_features.to(norm_dtype))
252
+ return proj_layer(normed_features.to(proj_dtype))
253
+
254
+ def _run_snapshot_encoders(self,
255
+ batch: Dict[str, Any],
256
+ final_wallet_embeddings_raw: torch.Tensor,
257
+ wallet_addr_to_batch_idx: Dict[str, int]) -> Dict[str, torch.Tensor]:
258
+ """
259
+ Runs snapshot-style encoders that process raw data into embeddings.
260
+ This is now truly end-to-end.
261
+ """
262
+ device = self.device
263
+ all_holder_snapshot_embeds = []
264
+
265
+ # Iterate through each HolderSnapshot event's raw data
266
+ for raw_holder_list in batch['holder_snapshot_raw_data']:
267
+ processed_holder_data = []
268
+ for holder in raw_holder_list:
269
+ wallet_addr = holder['wallet']
270
+ # Get the graph-updated wallet embedding using its index
271
+ wallet_idx = wallet_addr_to_batch_idx.get(wallet_addr, 0) # 0 is padding
272
+ if wallet_idx > 0: # If it's a valid wallet
273
+ wallet_embedding = final_wallet_embeddings_raw[wallet_idx - 1] # Adjust for 1-based indexing
274
+ processed_holder_data.append({
275
+ 'wallet_embedding': wallet_embedding,
276
+ 'pct': holder['holding_pct']
277
+ })
278
+ # Pass the processed data to the HolderDistributionEncoder
279
+ all_holder_snapshot_embeds.append(self.holder_dist_encoder(processed_holder_data))
280
+
281
+ return {"holder_snapshot": torch.cat(all_holder_snapshot_embeds, dim=0) if all_holder_snapshot_embeds else torch.empty(0, self.d_model, device=device, dtype=self.dtype)}
282
+
283
+
284
+ def _run_dynamic_encoders(self, batch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
285
+ """
286
+ Runs all dynamic encoders and returns a dictionary of raw, unprojected embeddings.
287
+ """
288
+ device = self.device
289
+ # --- NEW: Get pre-computed embedding indices ---
290
+ token_encoder_inputs = batch['token_encoder_inputs']
291
+ wallet_encoder_inputs = batch['wallet_encoder_inputs']
292
+ # The pre-computed embedding pool for the whole batch
293
+ embedding_pool = batch['embedding_pool']
294
+
295
+ ohlc_price_tensors = batch['ohlc_price_tensors'].to(device, self.dtype)
296
+ ohlc_interval_ids = batch['ohlc_interval_ids'].to(device)
297
+ graph_updater_links = batch['graph_updater_links']
298
+
299
+ # 1a. Encode Tokens
300
+ # --- FIXED: Check for a key that still exists ---
301
+ if token_encoder_inputs['name_embed_indices'].numel() > 0:
302
+ # --- AGGRESSIVE LOGGING ---
303
+ print("\n--- [Oracle DynamicEncoder LOG] ---")
304
+ print(f"[Oracle LOG] embedding_pool shape: {embedding_pool.shape}")
305
+ print(f"[Oracle LOG] name_embed_indices (shape {token_encoder_inputs['name_embed_indices'].shape}):\n{token_encoder_inputs['name_embed_indices']}")
306
+ print(f"[Oracle LOG] symbol_embed_indices (shape {token_encoder_inputs['symbol_embed_indices'].shape}):\n{token_encoder_inputs['symbol_embed_indices']}")
307
+ print(f"[Oracle LOG] image_embed_indices (shape {token_encoder_inputs['image_embed_indices'].shape}):\n{token_encoder_inputs['image_embed_indices']}")
308
+ print("--- [Oracle LOG] Calling F.embedding and TokenEncoder... ---")
309
+ # --- END LOGGING ---
310
+ # --- NEW: Gather pre-computed embeddings and pass to encoder ---
311
+ # --- CRITICAL FIX: Remove keys that are not part of the TokenEncoder's signature ---
312
+ encoder_args = token_encoder_inputs.copy()
313
+ encoder_args.pop('_addresses_for_lookup', None) # This key is for the WalletEncoder
314
+ encoder_args.pop('name_embed_indices', None)
315
+ encoder_args.pop('symbol_embed_indices', None)
316
+ encoder_args.pop('image_embed_indices', None)
317
+
318
+ # --- SAFETY: Create a padded view of the embedding pool and map missing indices (-1) to pad ---
319
+ if embedding_pool.numel() > 0:
320
+ pad_row = torch.zeros(1, embedding_pool.size(1), device=device, dtype=embedding_pool.dtype)
321
+ pool_padded = torch.cat([pad_row, embedding_pool], dim=0)
322
+ def pad_and_lookup(idx_tensor: torch.Tensor) -> torch.Tensor:
323
+ # Shift valid indices (>= 0) by +1; map invalid (< 0) to index 0 (the pad row)
324
+ shifted = torch.where(idx_tensor >= 0, idx_tensor + 1, torch.zeros_like(idx_tensor))
325
+ return F.embedding(shifted, pool_padded)
326
+ name_embeds = pad_and_lookup(token_encoder_inputs['name_embed_indices'])
327
+ symbol_embeds = pad_and_lookup(token_encoder_inputs['symbol_embed_indices'])
328
+ image_embeds = pad_and_lookup(token_encoder_inputs['image_embed_indices'])
329
+ else:
330
+ # Empty pool: provide zeros with correct shapes
331
+ n = token_encoder_inputs['name_embed_indices'].shape[0]
332
+ d = self.multi_modal_dim
333
+ zeros = torch.zeros(n, d, device=device, dtype=self.dtype)
334
+ name_embeds = zeros
335
+ symbol_embeds = zeros
336
+ image_embeds = zeros
337
+
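The shift-by-one safety pattern above — prepend a zero pad row, shift valid indices by one, and send missing indices (-1) to the pad row — can be sketched in isolation. This is a minimal standalone version; the function and variable names here are illustrative, not taken from the model:

```python
import torch
import torch.nn.functional as F

def pad_and_lookup(indices: torch.Tensor, pool: torch.Tensor) -> torch.Tensor:
    """Gather rows from `pool`, mapping invalid indices (< 0) to a zero pad row.

    A zero row is prepended at index 0 and valid indices are shifted by +1,
    so the gather never goes out of bounds and missing entries come back
    as all-zero vectors.
    """
    pad_row = torch.zeros(1, pool.size(1), dtype=pool.dtype)
    pool_padded = torch.cat([pad_row, pool], dim=0)
    shifted = torch.where(indices >= 0, indices + 1, torch.zeros_like(indices))
    return F.embedding(shifted, pool_padded)

pool = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
idx = torch.tensor([1, -1, 0])      # -1 marks a missing pre-computed embedding
out = pad_and_lookup(idx, pool)
# rows: pool[1], zeros (for the -1), pool[0]
```

The same padded pool can back several lookups (name, symbol, image indices here), so the pad row is built once per batch rather than per field.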
338
+ batch_token_embeddings_unupd = self.token_encoder(
339
+ name_embeds=name_embeds,
340
+ symbol_embeds=symbol_embeds,
341
+ image_embeds=image_embeds,
342
+ # Pass all other keys like protocol_ids, is_vanity_flags, etc.
343
+ **encoder_args
344
+ )
345
+ else:
346
+ batch_token_embeddings_unupd = torch.empty(0, self.token_encoder.output_dim, device=device, dtype=self.dtype)
347
+
348
+ # 1b. Encode Wallets
349
+ if wallet_encoder_inputs['profile_rows']:
350
+ temp_token_lookup = {
351
+ addr: batch_token_embeddings_unupd[i]
352
+ for i, addr in enumerate(batch['token_encoder_inputs']['_addresses_for_lookup']) # Use helper key
353
+ }
354
+ initial_wallet_embeddings = self.wallet_encoder(
355
+ **wallet_encoder_inputs,
356
+ token_vibe_lookup=temp_token_lookup,
357
+ embedding_pool=embedding_pool
358
+ )
359
+ else:
360
+ initial_wallet_embeddings = torch.empty(0, self.wallet_encoder.d_model, device=device, dtype=self.dtype)
361
+
362
+ # 1c. Encode OHLC
363
+ if ohlc_price_tensors.shape[0] > 0:
364
+ batch_ohlc_embeddings_raw = self.ohlc_embedder(ohlc_price_tensors, ohlc_interval_ids)
365
+ else:
366
+ batch_ohlc_embeddings_raw = torch.empty(0, self.ohlc_embedder.output_dim, device=device, dtype=self.dtype)
367
+
368
+ # 1d. Run Graph Updater
369
+ pad_wallet_raw = self.pad_wallet_emb.to(self.dtype)
370
+ pad_token_raw = self.pad_token_emb.to(self.dtype)
371
+ padded_wallet_tensor = torch.cat([pad_wallet_raw, initial_wallet_embeddings], dim=0)
372
+ padded_token_tensor = torch.cat([pad_token_raw, batch_token_embeddings_unupd], dim=0)
373
+
374
+ x_dict_initial = {}
375
+ if padded_wallet_tensor.shape[0] > 1: x_dict_initial['wallet'] = padded_wallet_tensor
376
+ if padded_token_tensor.shape[0] > 1: x_dict_initial['token'] = padded_token_tensor
377
+
378
+ if x_dict_initial and graph_updater_links:
379
+ final_entity_embeddings_dict = self.graph_updater(x_dict_initial, graph_updater_links)
380
+ final_padded_wallet_embs = final_entity_embeddings_dict.get('wallet', padded_wallet_tensor)
381
+ final_padded_token_embs = final_entity_embeddings_dict.get('token', padded_token_tensor)
382
+ else:
383
+ final_padded_wallet_embs = padded_wallet_tensor
384
+ final_padded_token_embs = padded_token_tensor
385
+
386
+ # Strip padding before returning
387
+ final_wallet_embeddings_raw = final_padded_wallet_embs[1:]
388
+ final_token_embeddings_raw = final_padded_token_embs[1:]
389
+
390
+ return {
391
+ "wallet": final_wallet_embeddings_raw,
392
+ "token": final_token_embeddings_raw,
393
+ "ohlc": batch_ohlc_embeddings_raw
394
+ }
395
+
396
+ def _project_and_gather_embeddings(self, raw_embeds: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
397
+ """
398
+ Projects raw embeddings to d_model and gathers them into sequence-aligned tensors.
399
+ """
400
+ # Project raw embeddings to d_model
401
+ final_wallet_proj = self.wallet_proj(raw_embeds['wallet'])
402
+ final_token_proj = self.token_proj(raw_embeds['token'])
403
+ final_ohlc_proj = self.ohlc_proj(raw_embeds['ohlc'])
404
+
405
+ # Project padding embeddings to d_model
406
+ pad_wallet = self.wallet_proj(self.pad_wallet_emb.to(self.dtype))
407
+ pad_token = self.token_proj(self.pad_token_emb.to(self.dtype))
408
+ pad_ohlc = self.ohlc_proj(self.pad_ohlc_emb.to(self.dtype))
409
+ pad_holder_snapshot = self.pad_holder_snapshot_emb.to(self.dtype) # Already d_model
410
+
411
+ # --- NEW: Project pre-computed embeddings and create lookup ---
412
+ final_precomputed_proj = self.precomputed_proj(batch['embedding_pool'])
413
+ pad_precomputed = self.precomputed_proj(self.pad_precomputed_emb.to(self.dtype))
414
+ final_precomputed_lookup = torch.cat([pad_precomputed, final_precomputed_proj], dim=0)
415
+
416
+ # Create final lookup tables with padding at index 0
417
+ final_wallet_lookup = torch.cat([pad_wallet, final_wallet_proj], dim=0)
418
+ final_token_lookup = torch.cat([pad_token, final_token_proj], dim=0)
419
+ final_ohlc_lookup = torch.cat([pad_ohlc, final_ohlc_proj], dim=0)
420
+
421
+
422
+ # --- NEW: Add Role Embeddings ---
423
+ main_role_emb = self.token_role_embedding(torch.tensor(self.main_token_role_id, device=self.device))
424
+ quote_role_emb = self.token_role_embedding(torch.tensor(self.quote_token_role_id, device=self.device))
425
+ trending_role_emb = self.token_role_embedding(torch.tensor(self.trending_token_role_id, device=self.device))
426
+
427
+ # Gather base embeddings
428
+ gathered_main_token_embs = F.embedding(batch['token_indices'], final_token_lookup)
429
+ gathered_quote_token_embs = F.embedding(batch['quote_token_indices'], final_token_lookup)
430
+ gathered_trending_token_embs = F.embedding(batch['trending_token_indices'], final_token_lookup)
431
+ gathered_boosted_token_embs = F.embedding(batch['boosted_token_indices'], final_token_lookup)
432
+
433
+ # --- NEW: Handle HolderSnapshot ---
434
+ final_holder_snapshot_lookup = torch.cat([pad_holder_snapshot, raw_embeds['holder_snapshot']], dim=0)
435
+
436
+ # Gather embeddings for each event in the sequence
437
+ return {
438
+ "wallet": F.embedding(batch['wallet_indices'], final_wallet_lookup),
439
+ "token": gathered_main_token_embs, # This is the baseline, no role needed
440
+ "ohlc": F.embedding(batch['ohlc_indices'], final_ohlc_lookup),
441
+ "original_author": F.embedding(batch['original_author_indices'], final_wallet_lookup), # NEW
442
+ "dest_wallet": F.embedding(batch['dest_wallet_indices'], final_wallet_lookup), # Also gather dest wallet
443
+ "quote_token": gathered_quote_token_embs + quote_role_emb,
444
+ "trending_token": gathered_trending_token_embs + trending_role_emb,
445
+ "boosted_token": gathered_boosted_token_embs + trending_role_emb, # Same role as trending
446
+ "holder_snapshot": F.embedding(batch['holder_snapshot_indices'], final_holder_snapshot_lookup), # NEW
447
+ "precomputed": final_precomputed_lookup # NEW: Pass the full lookup table
448
+ }
449
+
450
+ def _get_transfer_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
451
+ """
452
+ Calculates the special embeddings for Transfer/LargeTransfer events.
453
+ """
454
+ device = self.device
455
+ transfer_numerical_features = batch['transfer_numerical_features']
456
+ event_type_ids = batch['event_type_ids']
457
+
458
+ # --- FIXED: Selectively log-scale features ---
459
+ # Log scale: token_amount (idx 0), priority_fee (idx 3)
460
+ # Linear scale: transfer_pct_of_total_supply (idx 1), transfer_pct_of_holding (idx 2)
461
+ projected_transfer_features = self._normalize_and_project(
462
+ transfer_numerical_features, self.transfer_num_norm, self.transfer_num_proj, log_indices=[0, 3]
463
+ )
464
+ # Create a mask for Transfer/LargeTransfer events
465
+ transfer_event_ids = [self.event_type_to_id.get('Transfer', -1), self.event_type_to_id.get('LargeTransfer', -1)] # ADDED LargeTransfer
466
+ transfer_mask = torch.isin(event_type_ids, torch.tensor(transfer_event_ids, device=device)).unsqueeze(-1)
467
+
468
+ # Combine destination wallet and numerical features, then apply mask
469
+ return (gathered_embeds['dest_wallet'] + projected_transfer_features) * transfer_mask
470
+
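The masking pattern shared by this helper and the ones below — build event-specific features for every position, then zero them everywhere the event type does not match — can be reduced to a few lines. The ID mapping and tensors here are illustrative stand-ins:

```python
import torch

# Hypothetical event-type vocabulary for illustration only.
event_type_to_id = {'Transfer': 0, 'LargeTransfer': 1, 'Trade': 2}

event_type_ids = torch.tensor([[0, 2, 1, 2]])   # (B, L)
feature_embeds = torch.ones(1, 4, 8)            # (B, L, d_model) stand-in features

transfer_ids = torch.tensor([event_type_to_id['Transfer'],
                             event_type_to_id['LargeTransfer']])

# (B, L, 1) mask broadcasts over d_model, zeroing non-transfer positions.
mask = torch.isin(event_type_ids, transfer_ids).unsqueeze(-1)
masked = feature_embeds * mask
```

Computing features densely and multiplying by a boolean mask avoids data-dependent control flow, which keeps the forward pass batch-friendly.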
471
+ def _get_trade_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
472
+ """
473
+ Calculates the special embeddings for Trade events.
474
+ """
475
+ device = self.device
476
+ trade_numerical_features = batch['trade_numerical_features']
477
+ trade_dex_ids = batch['trade_dex_ids'] # NEW
478
+ trade_direction_ids = batch['trade_direction_ids']
479
+ trade_mev_protection_ids = batch['trade_mev_protection_ids'] # NEW
480
+ trade_is_bundle_ids = batch['trade_is_bundle_ids'] # NEW
481
+ event_type_ids = batch['event_type_ids']
482
+
483
+ # --- FIXED: Selectively log-scale features ---
484
+ # Log scale: sol_amount (idx 0), priority_fee (idx 1), total_usd (idx 7)
485
+ # Linear scale: pcts, slippage, price_impact, success flags
486
+ projected_trade_features = self._normalize_and_project(
487
+ trade_numerical_features, self.trade_num_norm, self.trade_num_proj, log_indices=[0, 1, 7]
488
+ )
489
+
490
+ # --- CORRECTED: This layer now handles both generic and large trades ---
491
+ trade_event_names = ['Trade', 'LargeTrade']
492
+ trade_event_ids = [self.event_type_to_id.get(name, -1) for name in trade_event_names]
493
+
494
+ # Create mask where event_type_id is one of the trade event ids
495
+ trade_mask = torch.isin(event_type_ids, torch.tensor(trade_event_ids, device=device)).unsqueeze(-1)
496
+
497
+ # --- NEW: Get embedding for the categorical dex_id ---
498
+ dex_id_embeds = self.dex_platform_embedding(trade_dex_ids)
499
+ direction_embeds = self.trade_direction_embedding(trade_direction_ids)
500
+ mev_embeds = self.mev_protection_embedding(trade_mev_protection_ids) # NEW
501
+ bundle_embeds = self.is_bundle_embedding(trade_is_bundle_ids) # NEW
502
+
503
+ return (projected_trade_features + dex_id_embeds + direction_embeds + mev_embeds + bundle_embeds) * trade_mask
504
+
505
+ def _get_deployer_trade_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
506
+ """
507
+ Calculates the special embeddings for Deployer_Trade events using its own layers.
508
+ """
509
+ device = self.device
510
+ deployer_trade_numerical_features = batch['deployer_trade_numerical_features']
511
+ trade_dex_ids = batch['trade_dex_ids'] # NEW: Re-use the same ID tensor
512
+ trade_direction_ids = batch['trade_direction_ids']
513
+ trade_mev_protection_ids = batch['trade_mev_protection_ids'] # NEW
514
+ trade_is_bundle_ids = batch['trade_is_bundle_ids'] # NEW
515
+ event_type_ids = batch['event_type_ids']
516
+
517
+ # --- FIXED: Selectively log-scale features ---
518
+ # Log scale: sol_amount (idx 0), priority_fee (idx 1), total_usd (idx 7)
519
+ projected_deployer_trade_features = self._normalize_and_project(
520
+ deployer_trade_numerical_features, self.deployer_trade_num_norm, self.deployer_trade_num_proj, log_indices=[0, 1, 7]
521
+ )
522
+
523
+ dex_id_embeds = self.dex_platform_embedding(trade_dex_ids)
524
+ direction_embeds = self.trade_direction_embedding(trade_direction_ids)
525
+ mev_embeds = self.mev_protection_embedding(trade_mev_protection_ids) # NEW
526
+ bundle_embeds = self.is_bundle_embedding(trade_is_bundle_ids) # NEW
527
+
528
+ deployer_trade_mask = (event_type_ids == self.event_type_to_id.get('Deployer_Trade', -1)).unsqueeze(-1)
529
+ return (projected_deployer_trade_features + dex_id_embeds + direction_embeds + mev_embeds + bundle_embeds) * deployer_trade_mask
530
+
531
+ def _get_smart_wallet_trade_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
532
+ """
533
+ Calculates the special embeddings for SmartWallet_Trade events using its own layers.
534
+ """
535
+ device = self.device
536
+ smart_wallet_trade_numerical_features = batch['smart_wallet_trade_numerical_features']
537
+ trade_dex_ids = batch['trade_dex_ids'] # NEW: Re-use the same ID tensor
538
+ trade_direction_ids = batch['trade_direction_ids']
539
+ trade_mev_protection_ids = batch['trade_mev_protection_ids'] # NEW
540
+ trade_is_bundle_ids = batch['trade_is_bundle_ids'] # NEW
541
+ event_type_ids = batch['event_type_ids']
542
+
543
+ # --- FIXED: Selectively log-scale features ---
544
+ # Log scale: sol_amount (idx 0), priority_fee (idx 1), total_usd (idx 7)
545
+ projected_features = self._normalize_and_project(
546
+ smart_wallet_trade_numerical_features, self.smart_wallet_trade_num_norm, self.smart_wallet_trade_num_proj, log_indices=[0, 1, 7]
547
+ )
548
+
549
+ dex_id_embeds = self.dex_platform_embedding(trade_dex_ids)
550
+ direction_embeds = self.trade_direction_embedding(trade_direction_ids)
551
+ mev_embeds = self.mev_protection_embedding(trade_mev_protection_ids) # NEW
552
+ bundle_embeds = self.is_bundle_embedding(trade_is_bundle_ids) # NEW
553
+
554
+ mask = (event_type_ids == self.event_type_to_id.get('SmartWallet_Trade', -1)).unsqueeze(-1)
555
+ return (projected_features + dex_id_embeds + direction_embeds + mev_embeds + bundle_embeds) * mask
556
+
557
+ def _get_pool_created_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
558
+ """
559
+ Calculates the special embeddings for PoolCreated events.
560
+ """
561
+ device = self.device
562
+ pool_created_numerical_features = batch['pool_created_numerical_features']
563
+ pool_created_protocol_ids = batch['pool_created_protocol_ids'] # NEW
564
+ event_type_ids = batch['event_type_ids']
565
+
566
+ # --- FIXED: Selectively log-scale features ---
567
+ # Log scale: base_amount (idx 0), quote_amount (idx 1)
568
+ # Linear scale: pcts (idx 2, 3)
569
+ projected_features = self._normalize_and_project(
570
+ pool_created_numerical_features, self.pool_created_num_norm, self.pool_created_num_proj, log_indices=[0, 1]
571
+ )
572
+ # --- NEW: Get embedding for the categorical protocol_id ---
573
+ protocol_id_embeds = self.protocol_embedding(pool_created_protocol_ids)
574
+
575
+ # Create mask for the event
576
+ mask = (event_type_ids == self.event_type_to_id.get('PoolCreated', -1)).unsqueeze(-1)
577
+
578
+ # Combine Quote Token embedding with projected numericals
579
+ return (gathered_embeds['quote_token'] + projected_features + protocol_id_embeds) * mask
580
+
581
+ def _get_liquidity_change_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
582
+ """
583
+ Calculates the special embeddings for LiquidityChange events.
584
+ """
585
+ device = self.device
586
+ liquidity_change_numerical_features = batch['liquidity_change_numerical_features']
587
+ liquidity_change_type_ids = batch['liquidity_change_type_ids'] # NEW
588
+ event_type_ids = batch['event_type_ids']
589
+
590
+ # --- FIXED: Selectively log-scale features ---
591
+ # Log scale: quote_amount (idx 0)
592
+ projected_features = self._normalize_and_project(
593
+ liquidity_change_numerical_features, self.liquidity_change_num_norm, self.liquidity_change_num_proj, log_indices=[0]
594
+ )
595
+ # --- NEW: Get embedding for the categorical change_type_id ---
596
+ change_type_embeds = self.liquidity_change_type_embedding(liquidity_change_type_ids)
597
+
598
+ # Create mask for the event
599
+ mask = (event_type_ids == self.event_type_to_id.get('LiquidityChange', -1)).unsqueeze(-1)
600
+
601
+ # Combine Quote Token embedding with projected numericals
602
+ return (gathered_embeds['quote_token'] + projected_features + change_type_embeds) * mask
603
+
604
+ def _get_fee_collected_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
605
+ """
606
+ Calculates the special embeddings for FeeCollected events.
607
+ """
608
+ device = self.device
609
+ fee_collected_numerical_features = batch['fee_collected_numerical_features']
610
+ event_type_ids = batch['event_type_ids']
611
+
612
+ # --- FIXED: Single amount, log-scale ---
613
+ projected_features = self._normalize_and_project(
614
+ fee_collected_numerical_features, self.fee_collected_num_norm, self.fee_collected_num_proj, log_indices=[0]
615
+ )
616
+
617
+ # Create mask for the event
618
+ mask = (event_type_ids == self.event_type_to_id.get('FeeCollected', -1)).unsqueeze(-1)
619
+
620
+ return projected_features * mask
621
+
622
+ def _get_token_burn_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
623
+ """
624
+ Calculates the special embeddings for TokenBurn events.
625
+ """
626
+ device = self.device
627
+ token_burn_numerical_features = batch['token_burn_numerical_features']
628
+ event_type_ids = batch['event_type_ids']
629
+
630
+ # --- FIXED: Selectively log-scale features ---
631
+ # Log scale: amount_tokens_burned (idx 1)
632
+ # Linear scale: amount_pct_of_total_supply (idx 0)
633
+ projected_features = self._normalize_and_project(
634
+ token_burn_numerical_features, self.token_burn_num_norm, self.token_burn_num_proj, log_indices=[1]
635
+ )
636
+ # Create mask for the event
637
+ mask = (event_type_ids == self.event_type_to_id.get('TokenBurn', -1)).unsqueeze(-1)
638
+
639
+ return projected_features * mask
640
+
641
+ def _get_supply_lock_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
642
+ """
643
+ Calculates the special embeddings for SupplyLock events.
644
+ """
645
+ device = self.device
646
+ supply_lock_numerical_features = batch['supply_lock_numerical_features']
647
+ event_type_ids = batch['event_type_ids']
648
+
649
+ # --- FIXED: Selectively log-scale features ---
650
+ # Log scale: lock_duration (idx 1)
651
+ # Linear scale: amount_pct_of_total_supply (idx 0)
652
+ projected_features = self._normalize_and_project(
653
+ supply_lock_numerical_features, self.supply_lock_num_norm, self.supply_lock_num_proj, log_indices=[1]
654
+ )
655
+ # Create mask for the event
656
+ mask = (event_type_ids == self.event_type_to_id.get('SupplyLock', -1)).unsqueeze(-1)
657
+
658
+ return projected_features * mask
659
+
660
+ def _get_onchain_snapshot_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
661
+ """
662
+ Calculates the special embeddings for OnChain_Snapshot events.
663
+ """
664
+ device = self.device
665
+ onchain_snapshot_numerical_features = batch['onchain_snapshot_numerical_features']
666
+ event_type_ids = batch['event_type_ids']
667
+
668
+ # --- FIXED: Selectively log-scale features ---
669
+ # Log scale: counts, market_cap, liquidity, volume, fees (almost all)
670
+ # Linear scale: growth_rate, holder_pcts (indices 3, 4, 5, 6, 7)
671
+ projected_features = self._normalize_and_project(
672
+ onchain_snapshot_numerical_features, self.onchain_snapshot_num_norm, self.onchain_snapshot_num_proj, log_indices=[0, 1, 2, 8, 9, 10, 11, 12, 13]
673
+ )
674
+ # Create mask for the event
675
+ mask = (event_type_ids == self.event_type_to_id.get('OnChain_Snapshot', -1)).unsqueeze(-1)
676
+
677
+ return projected_features * mask
678
+
679
+ def _get_trending_token_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
680
+ """
681
+ Calculates the special embeddings for TrendingToken events.
682
+ """
683
+ device = self.device
684
+ trending_token_numerical_features = batch['trending_token_numerical_features']
685
+ trending_token_source_ids = batch['trending_token_source_ids'] # NEW
686
+ trending_token_timeframe_ids = batch['trending_token_timeframe_ids'] # NEW
687
+ event_type_ids = batch['event_type_ids']
688
+
689
+ # --- FIXED: Rank is already inverted (0-1), so treat as linear ---
690
+ projected_features = self._normalize_and_project(
691
+ trending_token_numerical_features, self.trending_token_num_norm, self.trending_token_num_proj, log_indices=None
692
+ )
693
+
694
+ # --- NEW: Get embeddings for categorical IDs ---
695
+ source_embeds = self.trending_list_source_embedding(trending_token_source_ids)
696
+ timeframe_embeds = self.trending_timeframe_embedding(trending_token_timeframe_ids)
697
+
698
+ # Create mask for the event
699
+ mask = (event_type_ids == self.event_type_to_id.get('TrendingToken', -1)).unsqueeze(-1)
700
+
701
+ # Combine Trending Token embedding with its projected numericals
702
+ return (gathered_embeds['trending_token'] + projected_features + source_embeds + timeframe_embeds) * mask
703
+
704
+ def _get_boosted_token_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
705
+ """
706
+ Calculates the special embeddings for BoostedToken events.
707
+ """
708
+ device = self.device
709
+ boosted_token_numerical_features = batch['boosted_token_numerical_features']
710
+ event_type_ids = batch['event_type_ids']
711
+
712
+ # --- FIXED: Selectively log-scale features ---
713
+ # Log scale: total_boost_amount (idx 0)
714
+ # Linear scale: inverted rank (idx 1)
715
+ projected_features = self._normalize_and_project(
716
+ boosted_token_numerical_features, self.boosted_token_num_norm, self.boosted_token_num_proj, log_indices=[0]
717
+ )
718
+ # Create mask for the event
719
+ mask = (event_type_ids == self.event_type_to_id.get('BoostedToken', -1)).unsqueeze(-1)
720
+
721
+ # Combine Boosted Token embedding with its projected numericals
722
+ return (gathered_embeds['boosted_token'] + projected_features) * mask
723
+
724
+ def _get_dexboost_paid_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
725
+ """
726
+ Calculates the special embeddings for DexBoost_Paid events.
727
+ """
728
+ device = self.device
729
+ dexboost_paid_numerical_features = batch['dexboost_paid_numerical_features']
730
+ event_type_ids = batch['event_type_ids']
731
+
732
+ # --- FIXED: All features are amounts, so log-scale all ---
733
+ projected_features = self._normalize_and_project(
734
+ dexboost_paid_numerical_features, self.dexboost_paid_num_norm, self.dexboost_paid_num_proj, log_indices=[0, 1]
735
+ )
736
+ # Create mask for the event
737
+ mask = (event_type_ids == self.event_type_to_id.get('DexBoost_Paid', -1)).unsqueeze(-1)
738
+
739
+ return projected_features * mask
740
+
741
+ def _get_alphagroup_call_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
742
+ """
743
+ Handles AlphaGroup_Call events by looking up the group_id embedding.
744
+ """
745
+ device = self.device
746
+ group_ids = batch['alpha_group_ids']
747
+ event_type_ids = batch['event_type_ids']
748
+
749
+ # Look up the embedding for the group ID
750
+ group_embeds = self.alpha_group_embedding(group_ids)
751
+
752
+ # Create mask for the event
753
+ mask = (event_type_ids == self.event_type_to_id.get('AlphaGroup_Call', -1)).unsqueeze(-1)
754
+ return group_embeds * mask
755
+
756
+ def _get_channel_call_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
757
+ """
758
+ Handles Channel_Call events by looking up the channel_id embedding.
759
+ """
760
+ device = self.device
761
+ channel_ids = batch['channel_ids']
762
+ event_type_ids = batch['event_type_ids']
763
+
764
+ channel_embeds = self.call_channel_embedding(channel_ids)
765
+ mask = (event_type_ids == self.event_type_to_id.get('Channel_Call', -1)).unsqueeze(-1)
766
+ return channel_embeds * mask
767
+
768
+ def _get_cexlisting_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
769
+ """
770
+ Handles CexListing events by looking up the exchange_id embedding.
771
+ """
772
+ device = self.device
773
+ exchange_ids = batch['exchange_ids']
774
+ event_type_ids = batch['event_type_ids']
775
+
776
+ exchange_embeds = self.cex_listing_embedding(exchange_ids)
777
+ mask = (event_type_ids == self.event_type_to_id.get('CexListing', -1)).unsqueeze(-1)
778
+ return exchange_embeds * mask
779
+
780
+ def _get_chainsnapshot_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
781
+ """
782
+ Handles ChainSnapshot events.
783
+ """
784
+ device = self.device
785
+ numerical_features = batch['chainsnapshot_numerical_features']
786
+ event_type_ids = batch['event_type_ids']
787
+
788
+ # --- FIXED: All features are amounts/prices, so log-scale all ---
789
+ projected_features = self._normalize_and_project(
790
+ numerical_features, self.chainsnapshot_num_norm, self.chainsnapshot_num_proj, log_indices=[0, 1]
791
+ )
792
+ mask = (event_type_ids == self.event_type_to_id.get('ChainSnapshot', -1)).unsqueeze(-1)
793
+ return projected_features * mask
794
+
795
+ def _get_lighthousesnapshot_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
796
+ """
797
+ Handles Lighthouse_Snapshot events.
798
+ """
799
+ device = self.device
800
+ numerical_features = batch['lighthousesnapshot_numerical_features']
801
+ protocol_ids = batch['lighthousesnapshot_protocol_ids'] # NEW
802
+ timeframe_ids = batch['lighthousesnapshot_timeframe_ids'] # NEW
803
+ event_type_ids = batch['event_type_ids']
804
+
805
+ # --- FIXED: All features are counts/volumes, so log-scale all ---
806
+ projected_features = self._normalize_and_project(
807
+ numerical_features, self.lighthousesnapshot_num_norm, self.lighthousesnapshot_num_proj, log_indices=[0, 1, 2, 3, 4]
808
+ )
809
+ # --- NEW: Get embeddings for categorical IDs ---
810
+ # Re-use the main protocol embedding layer
811
+ protocol_embeds = self.protocol_embedding(protocol_ids)
812
+ timeframe_embeds = self.lighthouse_timeframe_embedding(timeframe_ids)
813
+
814
+ mask = (event_type_ids == self.event_type_to_id.get('Lighthouse_Snapshot', -1)).unsqueeze(-1)
815
+ return (projected_features + protocol_embeds + timeframe_embeds) * mask
816
+
817
+ def _get_migrated_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
818
+ """
819
+ Handles Migrated events by looking up the protocol_id embedding.
820
+ """
821
+ device = self.device
822
+ protocol_ids = batch['migrated_protocol_ids']
823
+ event_type_ids = batch['event_type_ids']
824
+
825
+ # Look up the embedding for the protocol ID
826
+ protocol_embeds = self.protocol_embedding(protocol_ids)
827
+
828
+ # Create mask for the event
829
+ mask = (event_type_ids == self.event_type_to_id.get('Migrated', -1)).unsqueeze(-1)
830
+ return protocol_embeds * mask
831
+
832
+ def _get_special_context_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
833
+ """
834
+ Handles special context tokens like 'Middle' and 'RECENT' by adding their unique learnable embeddings.
835
+ """
836
+ device = self.device
837
+ event_type_ids = batch['event_type_ids']
838
+ B, L = event_type_ids.shape
839
+
840
+ middle_id = self.event_type_to_id.get('Middle', -1)
841
+ recent_id = self.event_type_to_id.get('RECENT', -1)
842
+
843
+ middle_mask = (event_type_ids == middle_id)
844
+ recent_mask = (event_type_ids == recent_id)
845
+
846
+ middle_emb = self.special_context_embedding(torch.tensor(self.special_context_tokens['Middle'], device=device))
847
+ recent_emb = self.special_context_embedding(torch.tensor(self.special_context_tokens['RECENT'], device=device))
848
+
849
+ # Add the embeddings at the correct locations
850
+ return middle_mask.unsqueeze(-1) * middle_emb + recent_mask.unsqueeze(-1) * recent_emb
851
+
852
+ def _pool_hidden_states(self,
853
+ hidden_states: torch.Tensor,
854
+ attention_mask: torch.Tensor) -> torch.Tensor:
855
+ """
856
+ Pools variable-length hidden states into a single embedding per sequence by
857
+ selecting the last non-masked token for each batch element.
858
+ """
859
+ if hidden_states.size(0) == 0:
860
+ return torch.empty(0, self.d_model, device=hidden_states.device, dtype=hidden_states.dtype)
861
+
862
+ seq_lengths = attention_mask.long().sum(dim=1)
863
+ last_indices = torch.clamp(seq_lengths - 1, min=0)
864
+ batch_indices = torch.arange(hidden_states.size(0), device=hidden_states.device)
865
+ return hidden_states[batch_indices, last_indices]
866
+
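The last-non-masked-token selection performed by `_pool_hidden_states` can be demonstrated with a tiny self-contained example (the function name here is illustrative):

```python
import torch

def pool_last_token(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Select the hidden state of the last non-masked position per sequence."""
    seq_lengths = attention_mask.long().sum(dim=1)          # valid length per row
    last_indices = torch.clamp(seq_lengths - 1, min=0)      # guard all-masked rows
    batch_indices = torch.arange(hidden_states.size(0))
    return hidden_states[batch_indices, last_indices]       # (B, D)

h = torch.arange(12.0).reshape(2, 3, 2)       # (B=2, L=3, D=2)
mask = torch.tensor([[1, 1, 0], [1, 1, 1]])   # first sequence has length 2
pooled = pool_last_token(h, mask)
# pooled[0] == h[0, 1]; pooled[1] == h[1, 2]
```

The `clamp(..., min=0)` keeps an all-zero mask from producing an index of -1, which would silently wrap around to the last position.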
867
+ def forward(self, batch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
868
+ device = self.device
869
+
870
+ # Unpack core sequence tensors
871
+ event_type_ids = batch['event_type_ids'].to(device)
872
+ timestamps_float = batch['timestamps_float'].to(device)
873
+ relative_ts = batch['relative_ts'].to(device, self.dtype)
874
+ attention_mask = batch['attention_mask'].to(device)
875
+
876
+ B, L = event_type_ids.shape
877
+ if B == 0 or L == 0:
878
+ print("Warning: Received empty batch in Oracle forward.")
879
+ empty_hidden = torch.empty(0, L, self.d_model, device=device, dtype=self.dtype)
880
+ empty_mask = torch.empty(0, L, device=device, dtype=torch.long)
881
+ empty_quantiles = torch.empty(0, self.num_outputs, device=device, dtype=self.dtype)
882
+ return {
883
+ 'quantile_logits': empty_quantiles,
884
+ 'pooled_states': torch.empty(0, self.d_model, device=device, dtype=self.dtype),
885
+ 'hidden_states': empty_hidden,
886
+ 'attention_mask': empty_mask
887
+ }
888
+
889
+ # === 1. Run Dynamic Encoders (produces graph-updated entity embeddings) ===
890
+ dynamic_raw_embeds = self._run_dynamic_encoders(batch)
891
+
892
+
893
+ # === 2. Run Snapshot Encoders (uses dynamic_raw_embeds) ===
894
+ wallet_addr_to_batch_idx = batch['wallet_addr_to_batch_idx']
895
+ snapshot_raw_embeds = self._run_snapshot_encoders(batch, dynamic_raw_embeds['wallet'], wallet_addr_to_batch_idx)
896
+
897
+ # === 3. Project Raw Embeddings and Gather for Sequence ===
898
+ raw_embeds = {**dynamic_raw_embeds, **snapshot_raw_embeds}
899
+ gathered_embeds = self._project_and_gather_embeddings(raw_embeds, batch)
900
+
901
+ # === 4. Assemble Final `inputs_embeds` ===
902
+ event_embeds = self.event_type_embedding(event_type_ids)
903
+ ts_embeds = self.time_proj(self.time_encoder(timestamps_float))
904
+ # Stabilize relative time: minutes scale + signed log1p + LayerNorm before projection
905
+ relative_ts_fp32 = batch['relative_ts'].to(device, torch.float32)
906
+ rel_ts_minutes = relative_ts_fp32 / 60.0
907
+ rel_ts_processed = torch.sign(rel_ts_minutes) * torch.log1p(torch.abs(rel_ts_minutes))
908
+ # Match LayerNorm parameter dtype, then match Linear parameter dtype
909
+ norm_dtype = self.rel_ts_norm.weight.dtype
910
+ proj_dtype = self.rel_ts_proj.weight.dtype
911
+ rel_ts_normed = self.rel_ts_norm(rel_ts_processed.to(norm_dtype))
912
+ rel_ts_embeds = self.rel_ts_proj(rel_ts_normed.to(proj_dtype))
913
+
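The signed-log1p compression applied to relative timestamps above behaves as follows — symmetric around zero, near-linear for small deltas, logarithmic for large ones. A minimal sketch (the function name is illustrative):

```python
import torch

def signed_log1p_minutes(rel_ts_seconds: torch.Tensor) -> torch.Tensor:
    """Compress relative timestamps: seconds -> minutes -> sign(x) * log1p(|x|)."""
    minutes = rel_ts_seconds / 60.0
    return torch.sign(minutes) * torch.log1p(torch.abs(minutes))

x = torch.tensor([-3600.0, 0.0, 60.0])   # one hour ago, now, one minute ahead
y = signed_log1p_minutes(x)
# y[0] is negative, y[1] is exactly 0, y[2] == log1p(1) ~= 0.6931
```

Rescaling to minutes before the log keeps typical event gaps in a range where `log1p` is well-conditioned; the subsequent LayerNorm then standardizes the result before projection.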
914
+ # Get special embeddings for Transfer events
915
+ transfer_specific_embeds = self._get_transfer_specific_embeddings(batch, gathered_embeds)
916
+
917
+ # Get special embeddings for Trade events
918
+ trade_specific_embeds = self._get_trade_specific_embeddings(batch)
919
+
920
+ # Get special embeddings for Deployer Trade events
921
+ deployer_trade_specific_embeds = self._get_deployer_trade_specific_embeddings(batch)
922
+
923
+ # Get special embeddings for Smart Wallet Trade events
924
+ smart_wallet_trade_specific_embeds = self._get_smart_wallet_trade_specific_embeddings(batch)
925
+
926
+ # Get special embeddings for PoolCreated events
927
+ pool_created_specific_embeds = self._get_pool_created_specific_embeddings(batch, gathered_embeds)
928
+
929
+ # Get special embeddings for LiquidityChange events
930
+ liquidity_change_specific_embeds = self._get_liquidity_change_specific_embeddings(batch, gathered_embeds)
931
+
932
+ # Get special embeddings for FeeCollected events
933
+ fee_collected_specific_embeds = self._get_fee_collected_specific_embeddings(batch)
934
+
935
+ # Get special embeddings for TokenBurn events
936
+ token_burn_specific_embeds = self._get_token_burn_specific_embeddings(batch)
937
+
938
+ # Get special embeddings for SupplyLock events
939
+ supply_lock_specific_embeds = self._get_supply_lock_specific_embeddings(batch)
940
+
941
+ # Get special embeddings for OnChain_Snapshot events
942
+ onchain_snapshot_specific_embeds = self._get_onchain_snapshot_specific_embeddings(batch)
943
+
944
+ # Get special embeddings for TrendingToken events
945
+ trending_token_specific_embeds = self._get_trending_token_specific_embeddings(batch, gathered_embeds)
946
+
947
+ # Get special embeddings for BoostedToken events
948
+ boosted_token_specific_embeds = self._get_boosted_token_specific_embeddings(batch, gathered_embeds)
949
+
950
+ # Get special embeddings for DexBoost_Paid events
951
+ dexboost_paid_specific_embeds = self._get_dexboost_paid_specific_embeddings(batch)
952
+
953
+ # --- NEW: Get embeddings for Tracker events ---
954
+ alphagroup_call_specific_embeds = self._get_alphagroup_call_specific_embeddings(batch)
955
+ channel_call_specific_embeds = self._get_channel_call_specific_embeddings(batch)
956
+ cexlisting_specific_embeds = self._get_cexlisting_specific_embeddings(batch)
957
+
958
+ # --- NEW: Get embeddings for Chain and Lighthouse Snapshots ---
959
+ chainsnapshot_specific_embeds = self._get_chainsnapshot_specific_embeddings(batch)
960
+ lighthousesnapshot_specific_embeds = self._get_lighthousesnapshot_specific_embeddings(batch)
961
+
962
+ migrated_specific_embeds = self._get_migrated_specific_embeddings(batch)
963
+
964
+ # --- NEW: Handle DexProfile_Updated flags separately ---
965
+ dexprofile_updated_flags = batch['dexprofile_updated_flags']
966
+ dexprofile_flags_embeds = self.dexprofile_updated_flags_proj(dexprofile_updated_flags.to(self.dtype))
967
+
968
+ # --- REFACTORED: All text-based events are handled by the SocialEncoder ---
969
+ # This single call will replace the inefficient loops for social, dexprofile, and global trending events.
970
+ # The SocialEncoder's forward pass will need to be updated to handle this.
971
+ textual_event_embeds = self.social_encoder(
972
+ batch=batch,
973
+ gathered_embeds=gathered_embeds
974
+ )
975
+
976
+ # --- NEW: Get embeddings for special context injection tokens ---
977
+ special_context_embeds = self._get_special_context_embeddings(batch)
978
+
979
+ # --- Combine all features ---
980
+ # Sum in float32 for numerical stability, then cast back to model dtype
981
+ components = [
982
+ event_embeds, ts_embeds, rel_ts_embeds,
983
+ gathered_embeds['wallet'], gathered_embeds['token'], gathered_embeds['original_author'], gathered_embeds['ohlc'],
984
+ transfer_specific_embeds, trade_specific_embeds, deployer_trade_specific_embeds, smart_wallet_trade_specific_embeds,
985
+ pool_created_specific_embeds, liquidity_change_specific_embeds, fee_collected_specific_embeds,
986
+ token_burn_specific_embeds, supply_lock_specific_embeds, onchain_snapshot_specific_embeds,
987
+ trending_token_specific_embeds, boosted_token_specific_embeds, dexboost_paid_specific_embeds,
988
+ alphagroup_call_specific_embeds, channel_call_specific_embeds, cexlisting_specific_embeds,
989
+ migrated_specific_embeds, special_context_embeds, gathered_embeds['holder_snapshot'], textual_event_embeds,
990
+ dexprofile_flags_embeds, chainsnapshot_specific_embeds, lighthousesnapshot_specific_embeds
991
+ ]
992
+ inputs_embeds = sum([t.float() for t in components]).to(self.dtype)
993
+
994
+ hf_attention_mask = attention_mask.to(device=device, dtype=torch.long)
995
+ outputs = self.model(
996
+ inputs_embeds=inputs_embeds,
997
+ attention_mask=hf_attention_mask,
998
+ return_dict=True
999
+ )
1000
+ sequence_hidden = outputs.last_hidden_state
1001
+ pooled_states = self._pool_hidden_states(sequence_hidden, hf_attention_mask)
1002
+ quantile_logits = self.quantile_head(pooled_states)
1003
+
1004
+ return {
1005
+ 'quantile_logits': quantile_logits,
1006
+ 'pooled_states': pooled_states,
1007
+ 'hidden_states': sequence_hidden,
1008
+ 'attention_mask': hf_attention_mask
1009
+ }
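The `components` list above is summed in float32 before casting back to the model dtype, since repeated half-precision additions can lose low-order bits. A minimal, self-contained sketch of that pattern (toy tensors standing in for the real embeddings):

```python
import torch

def sum_components(components, out_dtype=torch.float16):
    # Accumulate in float32 to avoid precision loss from many
    # half-precision adds, then cast back to the working dtype.
    total = torch.zeros_like(components[0], dtype=torch.float32)
    for t in components:
        total += t.float()
    return total.to(out_dtype)

parts = [torch.ones(2, 4, dtype=torch.float16) for _ in range(3)]
summed = sum_components(parts)
print(summed.dtype, summed.shape)  # torch.float16 torch.Size([2, 4])
```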
models/multi_modal_processor.py ADDED
@@ -0,0 +1,184 @@
1
+ # multi_modal_processor.py
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from transformers import AutoModel, AutoProcessor, AutoConfig
6
+ from typing import List, Union
7
+ from PIL import Image
8
+ import requests
9
+ import io
10
+ import os
11
+ import traceback
12
+ import numpy as np
13
+
14
+ # Suppress warnings
15
+ os.environ["TRANSFORMERS_VERBOSITY"] = "error"
16
+
17
+ class MultiModalEncoder:
18
+ """
19
+ Encodes text OR images into a shared, NORMALIZED embedding space
20
+ using google/siglip-so400m-patch16-256-i18n.
21
+ This class is intended for creating embeddings for vector search.
22
+ """
23
+
24
+ def __init__(self, model_id="google/siglip-so400m-patch16-256-i18n", dtype: torch.dtype = torch.bfloat16):
25
+ self.model_id = model_id
26
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
27
+
28
+ self.dtype = dtype
29
+
30
+
31
+ try:
32
+ # --- SigLIP Loading with Config Fix ---
33
+ self.processor = AutoProcessor.from_pretrained(
34
+ self.model_id,
35
+ use_fast=True
36
+ )
37
+
38
+ config = AutoConfig.from_pretrained(self.model_id)
39
+
40
+ if not hasattr(config, 'projection_dim'):
41
+ # print("❗ Config missing projection_dim, patching...")
42
+ config.projection_dim = config.text_config.hidden_size
43
+
44
+ self.model = AutoModel.from_pretrained(
45
+ self.model_id,
46
+ config=config,
47
+ dtype=self.dtype, # Newer transformers accepts dtype=; older releases name this torch_dtype=
48
+ trust_remote_code=False
49
+ ).to(self.device).eval()
50
+ # -----------------------------------------------
51
+
52
+ self.embedding_dim = config.projection_dim
53
+
54
+ except Exception as e:
55
+ print(f"❌ Failed to load SigLIP model or components: {e}")
56
+ traceback.print_exc()
57
+ raise
58
+
59
+ @torch.no_grad()
60
+ def __call__(self, x: Union[List[str], List[Image.Image]]) -> torch.Tensor:
61
+ """
62
+ Encode a batch of text or images into normalized [batch_size, embedding_dim] vectors.
63
+ This is correct for storing in a vector DB for cosine similarity.
64
+ """
65
+ if not x:
66
+ return torch.empty(0, self.embedding_dim).to(self.device)
67
+
68
+ is_text = isinstance(x[0], str)
69
+
70
+ autocast_dtype = self.dtype if self.dtype in [torch.float16, torch.bfloat16] else None
71
+
72
+ print(f"\n[MME LOG] ENTERING __call__ for {'TEXT' if is_text else 'IMAGE'} batch of size {len(x)}")
73
+ print(f"[MME LOG] Input data preview: {str(x[0])[:100] if is_text else x[0]}")
74
+
75
+ with torch.amp.autocast(device_type=self.device, enabled=(self.device == 'cuda' and autocast_dtype is not None), dtype=autocast_dtype):
76
+ try:
77
+ if is_text:
78
+ inputs = self.processor(
79
+ text=x,
80
+ return_tensors="pt",
81
+ padding="max_length",
82
+ truncation=True
83
+ ).to(self.device)
84
+ print(f"[MME LOG] Text processor output shape: {inputs['input_ids'].shape}")
85
+ embeddings = self.model.get_text_features(**inputs)
86
+ else:
87
+ rgb_images = [img.convert("RGB") if img.mode != 'RGB' else img for img in x]
88
+ inputs = self.processor(
89
+ images=rgb_images,
90
+ return_tensors="pt"
91
+ ).to(self.device)
92
+
93
+ if 'pixel_values' in inputs and inputs['pixel_values'].dtype != self.dtype:
94
+ inputs['pixel_values'] = inputs['pixel_values'].to(self.dtype)
95
+
96
+ embeddings = self.model.get_image_features(**inputs)
97
+
98
+ print(f"[MME LOG] Raw model output embeddings shape: {embeddings.shape}, dtype: {embeddings.dtype}")
99
+
100
+ # Normalization is required so downstream cosine similarity reduces to a dot product.
101
+ # Normalize in float32 for numerical stability
102
+ embeddings = F.normalize(embeddings.float(), p=2, dim=-1)
103
+ print(f"[MME LOG] Normalized embeddings shape: {embeddings.shape}, dtype: {embeddings.dtype}")
104
+
105
+ final_embeddings = embeddings.to(self.dtype)
106
+ print(f"[MME LOG] Final embeddings shape: {final_embeddings.shape}, dtype: {final_embeddings.dtype}. EXITING __call__.")
107
+ return final_embeddings
108
+
109
+ except Exception as e:
110
+ print(f"❌ [MME LOG] FATAL ERROR during encoding {'text' if is_text else 'images'}: {e}")
111
+ traceback.print_exc()
112
+ return torch.empty(0, self.embedding_dim).to(self.device)
113
+
114
+ # --- Test block (SigLIP) ---
115
+ if __name__ == "__main__":
116
+ # This test exercises the encoder through its public __call__ interface.
117
+
118
+ MODEL_ID = "google/siglip-so400m-patch16-256-i18n"
119
+ print(f"\n--- MultiModalEncoder Test ({MODEL_ID}) ---")
120
+
121
+ texts = [
122
+ "Uranus", # Text 0
123
+ "Anus", # Text 1
124
+ "Ass", # Text 2
125
+ "Planet", # Text 3
126
+ "Dog" # Text 4
127
+ ]
128
+
129
+ try:
130
+ img_urls = [
131
+ "https://pbs.twimg.com/media/G3ra9C8W0AAGR8V.jpg", # Image 0: Uranus meme pic
132
+ ]
133
+ headers = {"User-Agent": "Mozilla/5.0"}
134
+ images = [
135
+ Image.open(io.BytesIO(requests.get(u, headers=headers).content))
136
+ for u in img_urls
137
+ ]
138
+
139
+ size = 256 # Model's expected size
140
+ images.append(Image.new("RGB", (size, size), color="green")) # Image 1: Green Square
141
+ print(f"✅ Downloaded test image and created green square (size {size}x{size}).")
142
+
143
+ except Exception as e:
144
+ print(f"❌ Failed to load images: {e}")
145
+ traceback.print_exc()
146
+ exit()
147
+
148
+ try:
149
+ # 1. Initialize your encoder
150
+ encoder = MultiModalEncoder(model_id=MODEL_ID)
151
+
152
+ print("\n--- Encoding Texts (Separately) ---")
153
+ text_embeddings = encoder(texts) # Uses __call__
154
+ print(f"Shape: {text_embeddings.shape}")
155
+
156
+ print("\n--- Encoding Images (Separately) ---")
157
+ image_embeddings = encoder(images) # Uses __call__
158
+ print(f"Shape: {image_embeddings.shape}")
159
+
160
+
161
+ print("\n--- Similarity Check (Your Goal) ---")
162
+
163
+ # 2. Calculate Cosine Similarity
164
+ # This is just a dot product because the encoder __call__ method
165
+ # already normalized the vectors.
166
+ similarity_matrix = torch.matmul(image_embeddings.cpu(), text_embeddings.cpu().T).numpy()
167
+
168
+ np.set_printoptions(precision=4, suppress=True)
169
+ print("\nCosine Similarity matrix (image × text):")
170
+ # Row: Images (0: Uranus Pic, 1: Green)
171
+ # Col: Texts (0: Uranus, 1: Anus, 2: Ass, 3: Planet, 4: Dog)
172
+ print(similarity_matrix)
173
+
174
+ print("\nSpecific Similarity Scores (Cosine Similarity, -1.0 to 1.0):")
175
+ print(f"Image 0 (Uranus pic) vs Text 0 (Uranus): {similarity_matrix[0][0]:.4f}")
176
+ print(f"Image 0 (Uranus pic) vs Text 1 (Anus): {similarity_matrix[0][1]:.4f}")
177
+ print(f"Image 0 (Uranus pic) vs Text 3 (Planet): {similarity_matrix[0][3]:.4f}")
178
+ print(f"Image 0 (Uranus pic) vs Text 4 (Dog): {similarity_matrix[0][4]:.4f}")
179
+ print(f"Image 1 (Green) vs Text 4 (Dog): {similarity_matrix[1][4]:.4f}")
180
+
181
+ except Exception as e:
182
+ print(f"\n--- An error occurred during the SigLIP test run ---")
183
+ print(f"Error: {e}")
184
+ traceback.print_exc()
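The similarity check above relies on the fact that, for L2-normalized vectors, cosine similarity is just a matrix product. A small standalone sketch with random stand-in vectors (not actual SigLIP outputs):

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-in embeddings; in the real class these come
# from get_image_features / get_text_features, then F.normalize.
img = F.normalize(torch.randn(2, 8), p=2, dim=-1)
txt = F.normalize(torch.randn(5, 8), p=2, dim=-1)

# Both sides are unit-norm, so a plain matmul yields cosine similarity.
sim = img @ txt.T
print(sim.shape)  # torch.Size([2, 5]); every entry lies in [-1, 1]
```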
models/ohlc_embedder.py ADDED
@@ -0,0 +1,114 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import List
5
+
6
+ # --- Import vocabulary for the test block ---
7
+ import models.vocabulary as vocab
8
+
9
+ class OHLCEmbedder(nn.Module):
10
+ """
11
+ Embeds a sequence of Open and Close prices AND its interval.
12
+
13
+ FIXED: Now takes interval_ids as input and combines an
14
+ interval embedding with the 1D-CNN chart pattern features.
15
+ """
16
+ def __init__(
17
+ self,
18
+ # --- NEW: Interval vocab size ---
19
+ num_intervals: int,
20
+ input_channels: int = 2, # Open, Close
21
+ sequence_length: int = 300,
22
+ cnn_channels: List[int] = [16, 32, 64],
23
+ kernel_sizes: List[int] = [3, 3, 3],
24
+ # --- NEW: Interval embedding dim ---
25
+ interval_embed_dim: int = 32,
26
+ output_dim: int = 4096,
27
+ dtype: torch.dtype = torch.float16
28
+ ):
29
+ super().__init__()
30
+ assert len(cnn_channels) == len(kernel_sizes), "cnn_channels and kernel_sizes must have the same length"
31
+
32
+ self.dtype = dtype
33
+ self.sequence_length = sequence_length
34
+ self.cnn_layers = nn.ModuleList()
35
+ self.output_dim = output_dim
36
+
37
+ in_channels = input_channels
38
+ current_seq_len = sequence_length
39
+
40
+ for i, (out_channels, k_size) in enumerate(zip(cnn_channels, kernel_sizes)):
41
+ conv = nn.Conv1d(
42
+ in_channels=in_channels,
43
+ out_channels=out_channels,
44
+ kernel_size=k_size,
45
+ padding='same'
46
+ )
47
+ self.cnn_layers.append(conv)
48
+ pool = nn.MaxPool1d(kernel_size=2, stride=2)
49
+ self.cnn_layers.append(pool)
50
+ current_seq_len = current_seq_len // 2
51
+ self.cnn_layers.append(nn.ReLU())
52
+ in_channels = out_channels
53
+
54
+ self.global_pool = nn.AdaptiveAvgPool1d(1)
55
+
56
+ final_cnn_channels = cnn_channels[-1]
57
+
58
+ # --- NEW: Interval Embedding Layer ---
59
+ self.interval_embedding = nn.Embedding(num_intervals, interval_embed_dim, padding_idx=0)
60
+
61
+ # --- NEW: MLP input dim is (CNN features + Interval features) ---
62
+ mlp_input_dim = final_cnn_channels + interval_embed_dim
63
+
64
+ self.mlp = nn.Sequential(
65
+ nn.Linear(mlp_input_dim, mlp_input_dim * 2),
66
+ nn.GELU(),
67
+ nn.LayerNorm(mlp_input_dim * 2),
68
+ nn.Linear(mlp_input_dim * 2, output_dim),
69
+ nn.LayerNorm(output_dim)
70
+ )
71
+
72
+ self.to(dtype)
73
+
74
+ def forward(self, x: torch.Tensor, interval_ids: torch.Tensor) -> torch.Tensor:
75
+ """
76
+ Args:
77
+ x (torch.Tensor): Batch of normalized OHLC sequences.
78
+ Shape: [batch_size, 2, sequence_length]
79
+ interval_ids (torch.Tensor): Batch of interval IDs.
80
+ Shape: [batch_size]
81
+ Returns:
82
+ torch.Tensor: Batch of OHLC embeddings.
83
+ Shape: [batch_size, output_dim]
84
+ """
85
+ if x.shape[1] != 2 or x.shape[2] != self.sequence_length:
86
+ raise ValueError(f"Input tensor shape mismatch. Expected [B, 2, {self.sequence_length}], got {x.shape}")
87
+
88
+ x = x.to(self.dtype)
89
+
90
+ # 1. Pass through CNN layers
91
+ for layer in self.cnn_layers:
92
+ x = layer(x)
93
+
94
+ # 2. Apply global average pooling
95
+ x = self.global_pool(x)
96
+
97
+ # 3. Flatten for MLP
98
+ x = x.squeeze(-1)
99
+ # Shape: [batch_size, final_cnn_channels]
100
+
101
+ # 4. --- NEW: Get interval embedding ---
102
+ interval_embed = self.interval_embedding(interval_ids)
103
+ # Shape: [batch_size, interval_embed_dim]
104
+
105
+ # 5. --- NEW: Combine features ---
106
+ combined = torch.cat([x, interval_embed], dim=1)
107
+ # Shape: [batch_size, final_cnn_channels + interval_embed_dim]
108
+
109
+ # 6. Pass through final MLP
110
+ x = self.mlp(combined)
111
+ # Shape: [batch_size, output_dim]
112
+
113
+ return x
114
+
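The shape flow through `OHLCEmbedder`'s CNN stack can be checked in isolation: each `MaxPool1d(2, 2)` halves the sequence length, and `AdaptiveAvgPool1d(1)` collapses whatever length remains. A minimal sketch mirroring the default hyperparameters above (illustrative only, not the full module):

```python
import torch
import torch.nn as nn

# conv -> pool -> relu stack with the default channel sizes
layers = []
in_ch, seq = 2, 300
for out_ch in [16, 32, 64]:
    layers += [nn.Conv1d(in_ch, out_ch, kernel_size=3, padding='same'),
               nn.MaxPool1d(2, 2),
               nn.ReLU()]
    in_ch, seq = out_ch, seq // 2  # each pool halves the sequence

net = nn.Sequential(*layers, nn.AdaptiveAvgPool1d(1))

x = torch.randn(4, 2, 300)  # [batch, (open, close), sequence_length]
out = net(x)
print(out.shape)  # torch.Size([4, 64, 1]); seq went 300 -> 150 -> 75 -> 37 -> 1
```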
models/token_encoder.py ADDED
@@ -0,0 +1,182 @@
1
+ # token_encoder.py (FIXED)
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from typing import List, Any
6
+ from PIL import Image
7
+
8
+ from models.multi_modal_processor import MultiModalEncoder
9
+ from models.wallet_set_encoder import WalletSetEncoder # Using your set encoder
10
+ from models.vocabulary import NUM_PROTOCOLS
11
+
12
+ class TokenEncoder(nn.Module):
13
+ """
14
+ Encodes a token's core identity into a single <TokenVibeEmbedding>.
15
+
16
+ FIXED: This version uses a robust fusion architecture and provides
17
+ a dynamic, smaller output dimension (e.g., 2048) suitable for
18
+ being an input to a larger model.
19
+ """
20
+ def __init__(
21
+ self,
22
+ multi_dim: int, # NEW: Pass the dimension directly
23
+ output_dim: int = 2048,
24
+ internal_dim: int = 1024, # INCREASED: Better balance between bottleneck and capacity
25
+ protocol_embed_dim: int = 64,
26
+ vanity_embed_dim: int = 32, # NEW: Small embedding for the vanity flag
27
+ nhead: int = 4,
28
+ num_layers: int = 1,
29
+ dtype: torch.dtype = torch.float16
30
+ ):
31
+ """
32
+ Initializes the TokenEncoder.
33
+
34
+ Args:
35
+ multi_dim (int): The embedding dimension of the multimodal encoder (e.g., 1152).
36
+ output_dim (int):
37
+ The final dimension of the <TokenVibeEmbedding> (e.g., 2048).
38
+ internal_dim (int):
39
+ The shared dimension for the internal fusion transformer (e.g., 1024).
40
+ protocol_embed_dim (int):
41
+ Small dimension for the protocol ID (e.g., 64).
42
+ vanity_embed_dim (int):
43
+ Small dimension for the is_vanity boolean flag.
44
+ nhead (int):
45
+ Attention heads for the fusion transformer.
46
+ num_layers (int):
47
+ Layers for the fusion transformer.
48
+ dtype (torch.dtype):
49
+ The data type (e.g., torch.float16).
50
+ """
51
+ super().__init__()
52
+ self.output_dim = output_dim
53
+ self.internal_dim = internal_dim
54
+ self.dtype = dtype
55
+
56
+ # Store SigLIP's fixed output dim (e.g., 1152)
57
+ self.multi_dim = multi_dim
58
+
59
+ # --- 1. Projection Layers ---
60
+ # Project all features to the *internal_dim*
61
+ self.name_proj = nn.Linear(self.multi_dim, internal_dim)
62
+ self.symbol_proj = nn.Linear(self.multi_dim, internal_dim)
63
+ self.image_proj = nn.Linear(self.multi_dim, internal_dim)
64
+
65
+ # --- 2. Categorical & Boolean Feature Embeddings ---
66
+
67
+ # Use small vocab size and small embed dim
68
+ self.protocol_embedding = nn.Embedding(NUM_PROTOCOLS, protocol_embed_dim)
69
+ # Project from small dim (64) up to internal_dim (1024)
70
+ self.protocol_proj = nn.Linear(protocol_embed_dim, internal_dim)
71
+
72
+ # NEW: Embedding for the is_vanity boolean flag
73
+ self.vanity_embedding = nn.Embedding(2, vanity_embed_dim) # 2 classes: True/False
74
+ self.vanity_proj = nn.Linear(vanity_embed_dim, internal_dim)
75
+
76
+ # --- 3. Fusion Encoder ---
77
+ # Re-use WalletSetEncoder to fuse the sequence of 5 features
78
+ self.fusion_transformer = WalletSetEncoder(
79
+ d_model=internal_dim,
80
+ nhead=nhead,
81
+ num_layers=num_layers,
82
+ dim_feedforward=internal_dim * 4, # Standard 4x
83
+ dtype=dtype
84
+ )
85
+
86
+ # --- 4. Final Output Projection ---
87
+ # Project from the transformer's output (internal_dim)
88
+ # to the final, dynamic output_dim.
89
+ self.final_projection = nn.Sequential(
90
+ nn.Linear(internal_dim, internal_dim * 2),
91
+ nn.GELU(),
92
+ nn.LayerNorm(internal_dim * 2),
93
+ nn.Linear(internal_dim * 2, output_dim),
94
+ nn.LayerNorm(output_dim)
95
+ )
96
+
97
+ # Cast new layers to the correct dtype and device
98
+ device = "cuda" if torch.cuda.is_available() else "cpu"
99
+ self.to(device=device, dtype=dtype)
100
+
101
+ def forward(
102
+ self,
103
+ name_embeds: torch.Tensor,
104
+ symbol_embeds: torch.Tensor,
105
+ image_embeds: torch.Tensor,
106
+ protocol_ids: torch.Tensor,
107
+ is_vanity_flags: torch.Tensor,
108
+ ) -> torch.Tensor:
109
+ """
110
+ Processes a batch of token data to create a batch of embeddings.
111
+
112
+ Args:
113
+ name_embeds (torch.Tensor): Pre-computed embeddings for token names. Shape: [B, siglip_dim]
114
+ symbol_embeds (torch.Tensor): Pre-computed embeddings for token symbols. Shape: [B, siglip_dim]
115
+ image_embeds (torch.Tensor): Pre-computed embeddings for token images. Shape: [B, siglip_dim]
116
+ protocol_ids (torch.Tensor): Batch of protocol IDs. Shape: [B]
117
+ is_vanity_flags (torch.Tensor): Batch of boolean flags for vanity addresses. Shape: [B]
118
+
119
+ Returns:
120
+ torch.Tensor: The final <TokenVibeEmbedding> batch.
121
+ Shape: [batch_size, output_dim]
122
+ """
123
+ device = name_embeds.device
124
+ batch_size = name_embeds.shape[0]
125
+
126
+ # 2. Get Protocol embedding (small)
127
+ print(f"\n--- [TokenEncoder LOG] ENTERING FORWARD PASS (Batch Size: {batch_size}) ---")
128
+ print(f"[TokenEncoder LOG] Input protocol_ids (shape {protocol_ids.shape}):\n{protocol_ids}")
129
+ print(f"[TokenEncoder LOG] Protocol Embedding Vocab Size: {self.protocol_embedding.num_embeddings}")
130
+
131
+ protocol_ids_long = protocol_ids.to(device, dtype=torch.long)
132
+ protocol_emb_raw = self.protocol_embedding(protocol_ids_long) # [B, 64]
133
+ print(f"[TokenEncoder LOG] Raw protocol embeddings shape: {protocol_emb_raw.shape}")
134
+
135
+ # NEW: Get vanity embedding
136
+ vanity_ids_long = is_vanity_flags.to(device, dtype=torch.long)
137
+ vanity_emb_raw = self.vanity_embedding(vanity_ids_long) # [B, 32]
138
+
139
+ # 3. Project all features to internal_dim (e.g., 1024)
140
+ print(f"[TokenEncoder LOG] Projecting features to internal_dim: {self.internal_dim}")
141
+ name_emb = self.name_proj(name_embeds)
142
+ symbol_emb = self.symbol_proj(symbol_embeds)
143
+ image_emb = self.image_proj(image_embeds)
144
+ protocol_emb = self.protocol_proj(protocol_emb_raw)
145
+ vanity_emb = self.vanity_proj(vanity_emb_raw) # NEW
146
+
147
+ # 4. Stack all projected features into a sequence
148
+ feature_sequence = torch.stack([
149
+ name_emb,
150
+ symbol_emb,
151
+ image_emb,
152
+ protocol_emb,
153
+ vanity_emb, # NEW: Add the vanity embedding to the sequence
154
+ ], dim=1)
155
+
156
+ print(f"[TokenEncoder LOG] Stacked feature_sequence shape: {feature_sequence.shape}")
157
+ print(f" - name_emb shape: {name_emb.shape}")
158
+ print(f" - symbol_emb shape: {symbol_emb.shape}")
159
+ print(f" - image_emb shape: {image_emb.shape}")
160
+ print(f" - protocol_emb shape: {protocol_emb.shape}")
161
+ print(f" - vanity_emb shape: {vanity_emb.shape}") # ADDED: Log the new vanity embedding shape
162
+
163
+ # 5. Create the padding mask (all False, since we have a fixed number of features for all)
164
+ padding_mask = torch.zeros(batch_size, feature_sequence.shape[1], device=device, dtype=torch.bool)
165
+ print(f"[TokenEncoder LOG] Created padding_mask of shape: {padding_mask.shape}")
166
+
167
+ # 6. Fuse the sequence with the Transformer Encoder
168
+ # This returns the [CLS] token output.
169
+ # Shape: [B, internal_dim]
170
+ fused_embedding = self.fusion_transformer(
171
+ item_embeds=feature_sequence,
172
+ src_key_padding_mask=padding_mask
173
+ )
174
+ print(f"[TokenEncoder LOG] Fused embedding shape after transformer: {fused_embedding.shape}")
175
+
176
+ # 7. Project to the final output dimension
177
+ # Shape: [B, output_dim]
178
+ token_vibe_embedding = self.final_projection(fused_embedding)
179
+ print(f"[TokenEncoder LOG] Final token_vibe_embedding shape: {token_vibe_embedding.shape}")
180
+ print(f"--- [TokenEncoder LOG] EXITING FORWARD PASS ---\n")
181
+
182
+ return token_vibe_embedding
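The fusion step in `TokenEncoder` stacks five projected feature vectors into a length-5 sequence and lets attention mix them. A hedged sketch of that idea, using `nn.TransformerEncoder` plus mean pooling as a stand-in for the custom `WalletSetEncoder` and its CLS readout:

```python
import torch
import torch.nn as nn

B, d = 3, 64  # toy batch size and internal_dim
# name, symbol, image, protocol, vanity (already projected to d)
feats = [torch.randn(B, d) for _ in range(5)]
seq = torch.stack(feats, dim=1)  # [B, 5, d]

layer = nn.TransformerEncoderLayer(d_model=d, nhead=4, batch_first=True)
fusion = nn.TransformerEncoder(layer, num_layers=1)

fused = fusion(seq)          # [B, 5, d]; each feature attends to the others
pooled = fused.mean(dim=1)   # simple pooling stand-in for the CLS readout
print(pooled.shape)  # torch.Size([3, 64])
```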
models/vocabulary.py ADDED
@@ -0,0 +1,188 @@
1
+ # vocabulary.py
2
+ """
3
+ Defines the vocabulary and mappings for categorical features.
4
+ """
5
+
6
+ # --- Event Type Mappings ---
7
+ EVENT_NAMES = [
8
+ '__PAD__', 'Chart_Segment', 'Mint',
9
+ 'Transfer', 'LargeTransfer',
10
+ 'Trade',
11
+ 'Deployer_Trade',
12
+ 'SmartWallet_Trade',
13
+ 'LargeTrade',
14
+ 'PoolCreated',
15
+ 'LiquidityChange',
16
+ 'FeeCollected',
17
+ 'TokenBurn',
18
+ 'SupplyLock',
19
+ 'OnChain_Snapshot',
20
+ 'HolderSnapshot',
21
+ 'TrendingToken',
22
+ 'BoostedToken',
23
+ 'XPost',
24
+ 'XRetweet',
25
+ 'XReply',
26
+ 'XQuoteTweet',
27
+ 'PumpReply',
28
+ 'DexBoost_Paid',
29
+ 'DexProfile_Updated',
30
+ 'AlphaGroup_Call',
31
+ 'Channel_Call',
32
+ 'CexListing',
33
+ 'TikTok_Trending_Hashtag',
34
+ 'XTrending_Hashtag',
35
+ 'ChainSnapshot',
36
+ 'Lighthouse_Snapshot',
37
+ 'Migrated',
38
+ 'MIDDLE',
39
+ 'RECENT'
40
+ ]
41
+ EVENT_TO_ID = {name: i for i, name in enumerate(EVENT_NAMES)}
42
+ ID_TO_EVENT = {i: name for i, name in enumerate(EVENT_NAMES)}
43
+ NUM_EVENT_TYPES = len(EVENT_NAMES)
44
+
45
+ # --- Protocol Mappings ---
46
+
47
+ # The canonical list of protocol names
48
+ PROTOCOL_NAMES = [
49
+ "Unknown",
50
+ "Pump V1",
51
+ "Pump AMM",
52
+ "Bonk",
53
+ "Raydium CPMM"
54
+ ]
55
+
56
+ PROTOCOL_TO_ID = {name: i for i, name in enumerate(PROTOCOL_NAMES)}
57
+ ID_TO_PROTOCOL = {i: name for i, name in enumerate(PROTOCOL_NAMES)}
58
+ NUM_PROTOCOLS = len(PROTOCOL_NAMES)
59
+
60
+
61
+ # --- Neo4J Link Type Mappings ---
62
+ # UPDATED: Added link types from your Neo4j schema
63
+ LINK_TYPES = [
64
+ "TransferLink",
65
+ "TransferLinkToken",
66
+ "BundleTradeLink",
67
+ "CopiedTradeLink",
68
+ "CoordinatedActivityLink",
69
+ "MintedLink",
70
+ "SnipedLink",
71
+ "LockedSupplyLink",
72
+ "BurnedLink",
73
+ "ProvidedLiquidityLink",
74
+ "WhaleOfLink",
75
+ "TopTraderOfLink",
76
+ ]
77
+
78
+ LINK_TYPE_TO_ID = {name: i for i, name in enumerate(LINK_TYPES)}
79
+ ID_TO_LINK_TYPE = {i: name for i, name in enumerate(LINK_TYPES)}
80
+ NUM_LINK_TYPES = len(LINK_TYPES)
81
+
82
+ LINK_NAME_TO_TRIPLET = {
83
+ # Wallet <-> Wallet Links
84
+ "TransferLink": ('wallet', 'TransferLink', 'wallet'),
85
+ "BundleTradeLink": ('wallet', 'BundleTradeLink', 'wallet'),
86
+ "CopiedTradeLink": ('wallet', 'CopiedTradeLink', 'wallet'),
87
+ "CoordinatedActivityLink": ('wallet', 'CoordinatedActivityLink', 'wallet'),
88
+
89
+ # Wallet -> Token Links
90
+ "TransferLinkToken": ('wallet', 'TransferLinkToken', 'token'),
91
+ "MintedLink": ('wallet', 'MintedLink', 'token'),
92
+ "SnipedLink": ('wallet', 'SnipedLink', 'token'),
93
+ "LockedSupplyLink": ('wallet', 'LockedSupplyLink', 'token'),
94
+ "BurnedLink": ('wallet', 'BurnedLink', 'token'),
95
+ "ProvidedLiquidityLink": ('wallet', 'ProvidedLiquidityLink', 'token'),
96
+ "WhaleOfLink": ('wallet', 'WhaleOfLink', 'token'),
97
+ "TopTraderOfLink": ('wallet', 'TopTraderOfLink', 'token'),
98
+ }
99
+
100
+
101
+ # --- NEW: OHLC Interval Mappings ---
102
+ OHLC_INTERVALS = [
103
+ "Unknown", # ID 0
104
+ "1s", # ID 1
105
+ "30s", # ID 2
106
+ ]
107
+
108
+ INTERVAL_TO_ID = {name: i for i, name in enumerate(OHLC_INTERVALS)}
109
+ ID_TO_INTERVAL = {i: name for i, name in enumerate(OHLC_INTERVALS)}
110
+ NUM_OHLC_INTERVALS = len(OHLC_INTERVALS)
111
+
112
+ DEX_NAMES = [
113
+ "Unknown",
114
+ "Axiom",
115
+ "Bullx",
116
+ "OXK",
117
+ "Trojan",
118
+ "Jupyter"
119
+ ]
120
+
121
+ DEX_TO_ID = {name: i for i, name in enumerate(DEX_NAMES)}
122
+ ID_TO_DEX = {i: name for i, name in enumerate(DEX_NAMES)}
123
+ NUM_DEX_PLATFORMS = len(DEX_NAMES)
124
+
125
+ # --- NEW: Trending List Source Mappings ---
126
+ TRENDING_LIST_SOURCES = [
127
+ "Unknown",
128
+ "Phantom",
129
+ "Dexscreener"
130
+ ]
131
+
132
+ TRENDING_LIST_SOURCE_TO_ID = {name: i for i, name in enumerate(TRENDING_LIST_SOURCES)}
133
+ ID_TO_TRENDING_LIST_SOURCE = {i: name for i, name in enumerate(TRENDING_LIST_SOURCES)}
134
+ NUM_TRENDING_LIST_SOURCES = len(TRENDING_LIST_SOURCES)
135
+
136
+ # --- NEW: Trending List Timeframe Mappings ---
137
+ TRENDING_LIST_TIMEFRAMES = [
138
+ "Unknown",
139
+ "5m",
140
+ "1h",
141
+ "24h"
142
+ ]
143
+ TRENDING_LIST_TIMEFRAME_TO_ID = {name: i for i, name in enumerate(TRENDING_LIST_TIMEFRAMES)}
144
+ ID_TO_TRENDING_LIST_TIMEFRAME = {i: name for i, name in enumerate(TRENDING_LIST_TIMEFRAMES)}
145
+ NUM_TRENDING_LIST_TIMEFRAMES = len(TRENDING_LIST_TIMEFRAMES)
146
+
147
+ # --- NEW: Lighthouse Snapshot Timeframe Mappings ---
148
+ LIGHTHOUSE_TIMEFRAMES = [
149
+ "Unknown",
150
+ "5m",
151
+ "1h",
152
+ "6h",
153
+ "24h"
154
+ ]
155
+ LIGHTHOUSE_TIMEFRAME_TO_ID = {name: i for i, name in enumerate(LIGHTHOUSE_TIMEFRAMES)}
156
+ NUM_LIGHTHOUSE_TIMEFRAMES = len(LIGHTHOUSE_TIMEFRAMES)
157
+
158
+ # --- NEW: TrackerEncoder Vocabularies ---
159
+
160
+ # Alpha Groups (Discord)
161
+ ALPHA_GROUPS = [
162
+ "unknown",
163
+ "Potion",
164
+ "Serenity",
165
+ "Digi World"
166
+ ]
167
+ ALPHA_GROUPS_TO_ID = {name: i for i, name in enumerate(ALPHA_GROUPS)}
168
+ ID_TO_ALPHA_GROUPS = {i: name for i, name in enumerate(ALPHA_GROUPS)}
169
+ NUM_ALPHA_GROUPS = len(ALPHA_GROUPS)
170
+
171
+ # Call Channels (Telegram)
172
+ CALL_CHANNELS = [
173
+ "unknown",
174
+ "MarcosCalls",
175
+ "kobecalls",
176
+ "DEGEMSCALLS"
177
+ ]
178
+ CALL_CHANNELS_TO_ID = {name: i for i, name in enumerate(CALL_CHANNELS)}
179
+ ID_TO_CALL_CHANNELS = {i: name for i, name in enumerate(CALL_CHANNELS)}
180
+ NUM_CALL_CHANNELS = len(CALL_CHANNELS)
181
+
182
+ # CEX Exchanges
183
+ EXCHANGES = [
184
+ "unknown", "mexc", "weex", "binance", "kraken"
185
+ ]
186
+ EXCHANGES_TO_ID = {name: i for i, name in enumerate(EXCHANGES)}
187
+ ID_TO_EXCHANGES = {i: name for i, name in enumerate(EXCHANGES)}
188
+ NUM_EXCHANGES = len(EXCHANGES)
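Every vocabulary above follows the same enumerate-based pattern, with index 0 reserved for the unknown/padding entry so lookups can safely default to it for unseen names. Illustrated with the protocol list:

```python
# Canonical name list; index 0 is the "Unknown" fallback.
PROTOCOL_NAMES = ["Unknown", "Pump V1", "Pump AMM", "Bonk", "Raydium CPMM"]
PROTOCOL_TO_ID = {name: i for i, name in enumerate(PROTOCOL_NAMES)}
ID_TO_PROTOCOL = {i: name for i, name in enumerate(PROTOCOL_NAMES)}

def protocol_id(name: str) -> int:
    # Unseen names map to "Unknown" (ID 0) instead of raising.
    return PROTOCOL_TO_ID.get(name, PROTOCOL_TO_ID["Unknown"])

print(protocol_id("Bonk"), protocol_id("SomeNewDex"))  # 3 0
```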
models/wallet_encoder.py ADDED
@@ -0,0 +1,262 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import List, Dict, Any, Optional
5
+ from PIL import Image
6
+
7
+ # We assume these helper modules are in the same directory
8
+ from models.multi_modal_processor import MultiModalEncoder
9
+ from models.wallet_set_encoder import WalletSetEncoder
10
+
11
+
12
+ class WalletEncoder(nn.Module):
13
+ """
14
+ Encodes a wallet's full identity into a single <WalletEmbedding>.
15
+ UPDATED: Aligned with the final feature spec.
16
+ """
17
+
18
+ def __init__(
19
+ self,
20
+ encoder: MultiModalEncoder ,
21
+ d_model: int = 2048, # Standardized to d_model
22
+ token_vibe_dim: int = 2048, # Expects token vibe of d_model
23
+ set_encoder_nhead: int = 8,
24
+ set_encoder_nlayers: int = 2,
25
+ dtype: torch.dtype = torch.float16
26
+ ):
27
+ """
28
+ Initializes the WalletEncoder.
29
+
30
+ Args:
31
+ d_model (int): The final output dimension (e.g., 2048).
32
+ token_vibe_dim (int): The dimension of the pre-computed
33
+ <TokenVibeEmbedding> (e.g., 2048).
34
+ encoder (MultiModalEncoder): Instantiated SigLIP encoder.
36
+ set_encoder_nhead (int): Attention heads for set encoders.
37
+ set_encoder_nlayers (int): Transformer layers for set encoders.
38
+ dtype (torch.dtype): Data type.
39
+ """
40
+ super().__init__()
41
+ self.d_model = d_model
42
+ self.dtype = dtype
43
+ self.encoder = encoder
44
+
45
+ # --- Dimensions ---
46
+ self.token_vibe_dim = token_vibe_dim
47
+ self.mmp_dim = self.encoder.embedding_dim # 1152
48
+
49
+ # === 1. Profile Encoder (FIXED) ===
50
+ # 1 age + 5 deployer_stats + 1 balance + 4 lifetime_counts +
51
+ # 3 lifetime_trading + 12 1d_stats + 12 7d_stats = 38
52
+ self.profile_numerical_features = 38
53
+ self.profile_num_norm = nn.LayerNorm(self.profile_numerical_features)
54
+
55
+
56
+ # FIXED: Input dim no longer has bool embed or deployed tokens embed
57
+ profile_mlp_in_dim = self.profile_numerical_features # 38
58
+ self.profile_encoder_mlp = self._build_mlp(profile_mlp_in_dim, d_model)
59
+
60
+
61
+
62
+ # === 2. Social Encoder (FIXED) ===
63
+ # 4 booleans: has_pf, has_twitter, has_telegram, is_exchange_wallet
64
+ self.social_bool_embed = nn.Embedding(2, 16)
65
+ # FIXED: Input dim is (4 * 16) + mmp_dim
66
+ social_mlp_in_dim = (16 * 4) + self.mmp_dim # username embed
67
+ self.social_encoder_mlp = self._build_mlp(social_mlp_in_dim, d_model)
68
+
69
+
70
+ # === 3. Holdings Encoder (FIXED) ===
71
+ # 11 original stats + 1 holding_time = 12
72
+ self.holding_numerical_features = 12
73
+ self.holding_num_norm = nn.LayerNorm(self.holding_numerical_features)
74
+
75
+ # FIXED: Input dim no longer uses time_encoder
76
+ holding_row_in_dim = (
77
+ self.token_vibe_dim + # <TokenVibeEmbedding>
78
+ self.holding_numerical_features # 12
79
+ )
80
+ self.holding_row_encoder_mlp = self._build_mlp(holding_row_in_dim, d_model)
81
+
82
+ self.holdings_set_encoder = WalletSetEncoder(
83
+ d_model, set_encoder_nhead, set_encoder_nlayers, dtype=dtype
84
+ )
85
+
86
+
87
+ # === 5. Final Fusion Encoder (Unchanged) ===
88
+ # Still fuses 4 components: Profile, Social, Holdings, Graph
89
+ self.fusion_mlp = nn.Sequential(
90
+ nn.Linear(d_model * 3, d_model * 2), # Input is d_model * 3
91
+ nn.GELU(),
92
+ nn.LayerNorm(d_model * 2),
93
+ nn.Linear(d_model * 2, d_model),
94
+ nn.LayerNorm(d_model)
95
+ )
96
+ self.to(dtype)
97
+
98
+ def _build_mlp(self, in_dim, out_dim):
99
+ return nn.Sequential(
100
+ nn.Linear(in_dim, out_dim * 2),
101
+ nn.GELU(),
102
+ nn.LayerNorm(out_dim * 2),
103
+ nn.Linear(out_dim * 2, out_dim),
104
+ ).to(self.dtype)
105
+
106
+ def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
107
+ # Log-normalizes numerical features (like age, stats, etc.)
108
+ return torch.sign(x) * torch.log1p(torch.abs(x))
109
+
110
+ def _get_device(self) -> torch.device:
111
+ return self.encoder.device
112
+
113
+ def forward(
114
+ self,
115
+ profile_rows: List[Dict[str, Any]],
116
+ social_rows: List[Dict[str, Any]],
117
+ holdings_batch: List[List[Dict[str, Any]]],
118
+ token_vibe_lookup: Dict[str, torch.Tensor],
119
+ embedding_pool: torch.Tensor,
120
+ username_embed_indices: torch.Tensor
121
+ ) -> torch.Tensor:
122
+ device = self._get_device()
123
+
124
+ profile_embed = self._encode_profile_batch(profile_rows, device)
125
+ social_embed = self._encode_social_batch(social_rows, embedding_pool, username_embed_indices, device)
126
+ holdings_embed = self._encode_holdings_batch(holdings_batch, token_vibe_lookup, device)
127
+
128
+ fused = torch.cat([profile_embed, social_embed, holdings_embed], dim=1)
129
+ return self.fusion_mlp(fused)
130
+
131
+ def _encode_profile_batch(self, profile_rows, device):
132
+ batch_size = len(profile_rows)
133
+ # FIXED: 38 numerical features
134
+ num_tensor = torch.zeros(batch_size, self.profile_numerical_features, device=device, dtype=self.dtype)
135
+ # bool_tensor removed
136
+ # time_tensor removed
137
+
138
+ for i, row in enumerate(profile_rows):
139
+ # A: Numerical (FIXED: 38 features, MUST be present)
140
+ num_data = [
141
+ # 1. Age
142
+ row.get('age', 0.0),
143
+ # 2. Deployed Token Aggregates (5)
144
+ row.get('deployed_tokens_count', 0.0),
145
+ row.get('deployed_tokens_migrated_pct', 0.0),
146
+ row.get('deployed_tokens_avg_lifetime_sec', 0.0),
147
+ row.get('deployed_tokens_avg_peak_mc_usd', 0.0),
148
+ row.get('deployed_tokens_median_peak_mc_usd', 0.0),
149
+ # 3. Balance (1)
150
+ row.get('balance', 0.0),
151
+ # 4. Lifetime Transaction Counts (4)
152
+ row.get('transfers_in_count', 0.0), row.get('transfers_out_count', 0.0),
153
+ row.get('spl_transfers_in_count', 0.0), row.get('spl_transfers_out_count', 0.0),
154
+ # 5. Lifetime Trading Stats (3)
155
+ row.get('total_buys_count', 0.0), row.get('total_sells_count', 0.0),
156
+ row.get('total_winrate', 0.0),
157
+ # 6. 1-Day Stats (12)
158
+ row.get('stats_1d_realized_profit_sol', 0.0), row.get('stats_1d_realized_profit_pnl', 0.0),
159
+ row.get('stats_1d_buy_count', 0.0), row.get('stats_1d_sell_count', 0.0),
160
+ row.get('stats_1d_transfer_in_count', 0.0), row.get('stats_1d_transfer_out_count', 0.0),
161
+ row.get('stats_1d_avg_holding_period', 0.0), row.get('stats_1d_total_bought_cost_sol', 0.0),
162
+ row.get('stats_1d_total_sold_income_sol', 0.0), row.get('stats_1d_total_fee', 0.0),
163
+ row.get('stats_1d_winrate', 0.0), row.get('stats_1d_tokens_traded', 0.0),
164
+ # 7. 7-Day Stats (12)
165
+ row.get('stats_7d_realized_profit_sol', 0.0), row.get('stats_7d_realized_profit_pnl', 0.0),
166
+ row.get('stats_7d_buy_count', 0.0), row.get('stats_7d_sell_count', 0.0),
167
+ row.get('stats_7d_transfer_in_count', 0.0), row.get('stats_7d_transfer_out_count', 0.0),
168
+ row.get('stats_7d_avg_holding_period', 0.0), row.get('stats_7d_total_bought_cost_sol', 0.0),
169
+ row.get('stats_7d_total_sold_income_sol', 0.0), row.get('stats_7d_total_fee', 0.0),
170
+ row.get('stats_7d_winrate', 0.0), row.get('stats_7d_tokens_traded', 0.0),
171
+ ]
172
+ num_tensor[i] = torch.tensor(num_data, dtype=self.dtype)
173
+
174
+ # C: Booleans and deployed_tokens lists are GONE
175
+
176
+ # Log-normalize all numerical features (age, stats, etc.)
177
+ num_embed = self.profile_num_norm(self._safe_signed_log(num_tensor))
178
+
179
+ # The profile fused tensor is now just the numerical embeddings
180
+ profile_fused = num_embed
181
+ return self.profile_encoder_mlp(profile_fused)
182
+
183
+ def _encode_social_batch(self, social_rows, embedding_pool, username_embed_indices, device):
184
+ batch_size = len(social_rows)
185
+ # FIXED: 4 boolean features
186
+ bool_tensor = torch.zeros(batch_size, 4, device=device, dtype=torch.long)
187
+ for i, row in enumerate(social_rows):
188
+ # All features MUST be present
189
+ bool_tensor[i, 0] = 1 if row['has_pf_profile'] else 0
190
+ bool_tensor[i, 1] = 1 if row['has_twitter'] else 0
191
+ bool_tensor[i, 2] = 1 if row['has_telegram'] else 0
192
+ # FIXED: Added is_exchange_wallet
193
+ bool_tensor[i, 3] = 1 if row['is_exchange_wallet'] else 0
194
+
195
+ bool_embeds = self.social_bool_embed(bool_tensor).view(batch_size, -1) # [B, 64]
196
+ # --- NEW: Look up pre-computed username embeddings ---
197
+ # --- FIXED: Handle case where embedding_pool is empty ---
198
+ if embedding_pool.numel() > 0:
199
+ # SAFETY: build a padded view so missing indices (-1) map to a zero vector
200
+ pad_row = torch.zeros(1, embedding_pool.size(1), device=device, dtype=embedding_pool.dtype)
201
+ pool_padded = torch.cat([pad_row, embedding_pool], dim=0)
202
+ shifted_idx = torch.where(username_embed_indices >= 0, username_embed_indices + 1, torch.zeros_like(username_embed_indices))
203
+ username_embed = F.embedding(shifted_idx, pool_padded)
204
+ else:
205
+ # If there are no embeddings, create a zero tensor of the correct shape
206
+ username_embed = torch.zeros(batch_size, self.mmp_dim, device=device, dtype=self.dtype)
207
+ social_fused = torch.cat([bool_embeds, username_embed], dim=1)
208
+ return self.social_encoder_mlp(social_fused)
209
+
210
+ def _encode_holdings_batch(self, holdings_batch, token_vibe_lookup, device):
211
+ batch_size = len(holdings_batch)
212
+ max_len = max(len(h) for h in holdings_batch) if any(holdings_batch) else 1
213
+ seq_embeds = torch.zeros(batch_size, max_len, self.d_model, device=device, dtype=self.dtype)
214
+ mask = torch.ones(batch_size, max_len, device=device, dtype=torch.bool)
215
+ default_vibe = torch.zeros(self.token_vibe_dim, device=device, dtype=self.dtype)
216
+
217
+ for i, holdings in enumerate(holdings_batch):
218
+ if not holdings: continue
219
+ h_len = min(len(holdings), max_len)
220
+ holdings = holdings[:h_len]
221
+
222
+ # --- FIXED: Safely get vibes, using default if mint_address is missing or not in lookup ---
223
+ vibes = [token_vibe_lookup.get(row['mint_address'], default_vibe) for row in holdings if 'mint_address' in row]
224
+ if not vibes: continue # Skip if no valid holdings with vibes
225
+ vibe_tensor = torch.stack(vibes)
226
+
227
+ # time_tensor removed
228
+
229
+ num_data_list = []
230
+ for row in holdings:
231
+ # FIXED: All 12 numerical features MUST be present
232
+ num_data = [
233
+ # Use .get() with a 0.0 default for safety
234
+ row.get('holding_time', 0.0),
235
+ row.get('balance_pct_to_supply', 0.0),
236
+ row.get('history_bought_cost_sol', 0.0), # Corrected key from schema
237
+ row.get('bought_amount_sol_pct_to_native_balance', 0.0), # This key is not in schema, will default to 0
238
+ row.get('history_total_buys', 0.0),
239
+ row.get('history_total_sells', 0.0),
240
+ row.get('realized_profit_pnl', 0.0),
241
+ row.get('realized_profit_sol', 0.0),
242
+ row.get('history_transfer_in', 0.0),
243
+ row.get('history_transfer_out', 0.0),
244
+ row.get('avarage_trade_gap_seconds', 0.0),
245
+ row.get('total_fees', 0.0) # Corrected key from schema
246
+ ]
247
+ num_data_list.append(num_data)
248
+
249
+ num_tensor = torch.tensor(num_data_list, device=device, dtype=self.dtype)
250
+
251
+ # Log-normalize all numerical features (holding_time, stats, etc.)
252
+ num_embed = self.holding_num_norm(self._safe_signed_log(num_tensor))
253
+
254
+ # time_embed removed
255
+
256
+ # FIXED: Fused tensor no longer has time_embed
257
+ fused_rows = torch.cat([vibe_tensor, num_embed], dim=1)
258
+ encoded_rows = self.holding_row_encoder_mlp(fused_rows)
259
+ seq_embeds[i, :h_len] = encoded_rows
260
+ mask[i, :h_len] = False
261
+
262
+ return self.holdings_set_encoder(seq_embeds, mask)
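The `_safe_signed_log` transform above compresses heavy-tailed wallet statistics (balances, PnL, trade counts) onto a comparable scale while preserving sign. A minimal stdlib sketch of the same formula, independent of torch:

```python
import math

def safe_signed_log(x: float) -> float:
    """sign(x) * log1p(|x|): a symmetric log compression that is 0 at 0,
    monotonic, and keeps the sign of profits vs. losses."""
    return math.copysign(math.log1p(abs(x)), x)

# Heavy-tailed raw features end up on a comparable scale:
features = [0.0, 9.0, -9.0, 1_000_000.0]
compressed = [safe_signed_log(v) for v in features]
```

Unlike a plain `log(x)`, this mapping is defined for zero and negative values, which matters for features like realized PnL that can take either sign.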
models/wallet_set_encoder.py ADDED
@@ -0,0 +1,99 @@
+import torch
+import torch.nn as nn
+
+class WalletSetEncoder(nn.Module):
+    """
+    Encodes a variable-length set of embeddings into a single fixed-size vector
+    using a Transformer encoder and a [CLS] token.
+
+    This is used to pool:
+    1. A wallet's `wallet_holdings` (a set of [holding_embeds]).
+    2. A wallet's `Neo4J links` (a set of [link_embeds]).
+    3. A wallet's `deployed_tokens` (a set of [token_name_embeds]).
+    """
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        num_layers: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        dtype: torch.dtype = torch.float16
+    ):
+        """
+        Initializes the Set Encoder.
+
+        Args:
+            d_model (int): The input/output dimension of the embeddings.
+            nhead (int): Number of attention heads.
+            num_layers (int): Number of transformer layers.
+            dim_feedforward (int): Hidden dimension of the feedforward network.
+            dropout (float): Dropout rate.
+            dtype (torch.dtype): Data type.
+        """
+        super().__init__()
+        self.d_model = d_model
+        self.dtype = dtype
+
+        # The learnable [CLS] token, which aggregates the set representation
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
+        nn.init.normal_(self.cls_token, std=0.02)
+
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=d_model,
+            nhead=nhead,
+            dim_feedforward=dim_feedforward,
+            dropout=dropout,
+            batch_first=True
+        )
+        self.transformer_encoder = nn.TransformerEncoder(
+            encoder_layer,
+            num_layers=num_layers
+        )
+        self.output_norm = nn.LayerNorm(d_model)
+
+        self.to(dtype)
+
+    def forward(
+        self,
+        item_embeds: torch.Tensor,
+        src_key_padding_mask: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Forward pass.
+
+        Args:
+            item_embeds (torch.Tensor):
+                The batch of item embeddings.
+                Shape: [batch_size, seq_len, d_model]
+            src_key_padding_mask (torch.Tensor):
+                The boolean padding mask for the items, where True indicates
+                a padded position that should be ignored.
+                Shape: [batch_size, seq_len]
+
+        Returns:
+            torch.Tensor: The pooled set embedding.
+                Shape: [batch_size, d_model]
+        """
+        batch_size = item_embeds.shape[0]
+
+        # 1. Create [CLS] token batch and concatenate with item embeddings
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1).to(self.dtype)
+        x = torch.cat([cls_tokens, item_embeds], dim=1)
+
+        # 2. Create the mask for the [CLS] token (it is never masked)
+        cls_mask = torch.zeros(batch_size, 1, device=src_key_padding_mask.device, dtype=torch.bool)
+
+        # 3. Concatenate the [CLS] mask with the item mask
+        full_padding_mask = torch.cat([cls_mask, src_key_padding_mask], dim=1)
+
+        # 4. Pass through Transformer
+        transformer_output = self.transformer_encoder(
+            x,
+            src_key_padding_mask=full_padding_mask
+        )
+
+        # 5. Extract the output of the [CLS] token (the first token in the sequence)
+        cls_output = transformer_output[:, 0, :]
+
+        return self.output_norm(cls_output)
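The forward pass above prepends a never-masked [CLS] position to the key-padding mask before calling the transformer. The mask bookkeeping can be sketched without torch (True marks a padded position to ignore), assuming right-padded sets as produced by `_encode_holdings_batch`:

```python
def build_cls_padding_mask(set_sizes, max_len):
    """One mask row per set: a leading False for the [CLS] slot, then False
    for each real item and True for each padded position."""
    rows = []
    for n in set_sizes:
        item_mask = [j >= n for j in range(max_len)]  # True marks padding
        rows.append([False] + item_mask)              # [CLS] is never masked
    return rows

# Three wallets holding 2, 0, and 3 tokens, padded to length 3:
mask = build_cls_padding_mask([2, 0, 3], max_len=3)
```

Keeping the [CLS] position unmasked guarantees every row has at least one attendable key, even for a wallet with an empty set, so the pooled output stays well-defined.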
neo4j.rs ADDED
@@ -0,0 +1,121 @@
+// Nodes
+
+pub struct Token {
+    pub address: String,
+}
+
+pub struct Wallet {
+    pub address: String,
+}
+
+// Links
+
+/// Tracks direct capital flow and identifies funding chains.
+pub struct TransferLink {
+    pub signature: String,
+    pub source: String,
+    pub destination: String,
+    pub mint: String,
+    pub timestamp: i64,
+}
+
+/// Identifies wallets trading the same token in the same slot.
+pub struct BundleTradeLink {
+    pub signatures: Vec<String>,
+    pub wallet_a: String,
+    pub wallet_b: String,
+    pub mint: String,
+    pub slot: i64,
+    pub timestamp: i64,
+}
+
+/// Reveals a behavioral pattern of one wallet mirroring another's successful trade.
+pub struct CopiedTradeLink {
+    pub leader_buy_sig: String,
+    pub leader_sell_sig: String,
+    pub follower_buy_sig: String,
+    pub follower_sell_sig: String,
+    pub follower: String,
+    pub leader: String,
+    pub mint: String,
+    pub time_gap_on_buy_sec: i64,
+    pub time_gap_on_sell_sec: i64,
+    pub leader_pnl: f64,
+    pub follower_pnl: f64,
+
+    pub leader_buy_total: f64,
+    pub leader_sell_total: f64,
+
+    pub follower_buy_total: f64,
+    pub follower_sell_total: f64,
+    pub follower_buy_slippage: f32,
+    pub follower_sell_slippage: f32,
+}
+
+/// Represents a link where a group of wallets re-engage with a token in a coordinated manner.
+pub struct CoordinatedActivityLink {
+    pub leader_first_sig: String,
+    pub leader_second_sig: String,
+    pub follower_first_sig: String,
+    pub follower_second_sig: String,
+    pub follower: String,
+    pub leader: String,
+    pub mint: String,
+    pub time_gap_on_first_sec: i64,
+    pub time_gap_on_second_sec: i64,
+}
+
+/// Links a token to its original creator.
+pub struct MintedLink {
+    pub signature: String,
+    pub timestamp: i64,
+    pub buy_amount: f64,
+}
+
+/// Connects a token to its successful first-movers.
+pub struct SnipedLink {
+    pub signature: String,
+    pub rank: i64,
+    pub sniped_amount: f64,
+}
+
+/// Connects a wallet to the token supply it locked.
+pub struct LockedSupplyLink {
+    pub signature: String,
+    pub amount: f64,
+    pub unlock_timestamp: u64,
+}
+
+/// Links a wallet to the tokens it burned.
+pub struct BurnedLink {
+    pub signature: String,
+    pub amount: f64,
+    pub timestamp: i64,
+}
+
+/// Identifies wallets that provided liquidity, signaling high conviction.
+pub struct ProvidedLiquidityLink {
+    pub signature: String,
+    pub wallet: String,
+    pub token: String,
+    pub pool_address: String,
+    pub amount_base: f64,
+    pub amount_quote: f64,
+    pub timestamp: i64,
+}
+
+/// A derived link connecting a token to its largest holders.
+pub struct WhaleOfLink {
+    pub wallet: String,
+    pub token: String,
+    pub holding_pct_at_creation: f32, // Holding % when the link was made
+    pub ath_usd_at_creation: f64,     // Token's ATH when the link was made
+}
+
+/// A derived link connecting a token to its most profitable traders.
+pub struct TopTraderOfLink {
+    pub wallet: String,
+    pub token: String,
+    pub pnl_at_creation: f64,     // The PNL that first triggered the link
+    pub ath_usd_at_creation: f64, // Token's ATH when the link was made
+}
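`CopiedTradeLink` records a follower entering a position shortly after a leader (`time_gap_on_buy_sec`). One way to derive candidate pairs for such links, sketched in Python under the assumption that buys arrive as `(wallet, unix_ts)` tuples for a single mint, sorted by time, and that a fixed window defines "copying" (the 30-second threshold is illustrative, not from the source):

```python
def candidate_copied_buys(buys, leader, max_gap_sec=30):
    """Pair the leader's buys of one mint with follower buys that occur
    within max_gap_sec afterwards. `buys` is a time-sorted list of
    (wallet, unix_ts) tuples for a single mint."""
    leader_times = [ts for w, ts in buys if w == leader]
    pairs = []
    for wallet, ts in buys:
        if wallet == leader:
            continue
        for lt in leader_times:
            gap = ts - lt
            if 0 < gap <= max_gap_sec:
                # gap maps onto time_gap_on_buy_sec in CopiedTradeLink
                pairs.append((leader, wallet, gap))
                break
    return pairs

buys = [("L", 100), ("A", 110), ("B", 150), ("L", 200), ("A", 205)]
links = candidate_copied_buys(buys, leader="L")
```

A production detector would additionally match sell legs and compute the PnL and slippage fields, but the time-gap pairing above is the core join.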
ohlc_stats.npz ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6f56037cf2ad8502213ee2c8470c314eef83a4cd93063290581ef45fadea5d48
+size 1660
onchain.sql ADDED
@@ -0,0 +1,599 @@
+CREATE TABLE trades
+(
+    timestamp DateTime('UTC'),
+    signature String,
+
+    slot UInt64,
+    transaction_index UInt32,
+    instruction_index UInt16,
+    success Boolean,
+    error Nullable(String),
+
+    -- Fee Structure
+    priority_fee Float64,
+    bribe_fee Float64,
+    coin_creator_fee Float64,
+    mev_protection UInt8,
+
+    -- Parties
+    maker String,
+
+    -- Balances (Pre & Post)
+    base_balance Float64,
+    quote_balance Float64,
+
+    -- Trade Semantics
+    trade_type UInt8,
+    protocol UInt8,
+    platform UInt8,
+
+    -- Asset Info
+    pool_address String,
+    base_address String,
+    quote_address String,
+
+    -- Trade Details
+    slippage Float32,
+    price_impact Float32,
+
+    base_amount UInt64,
+    quote_amount UInt64,
+
+    price Float64,
+    price_usd Float64,
+
+    total Float64,
+    total_usd Float64
+)
+ENGINE = MergeTree()
+ORDER BY (base_address, timestamp, maker, signature);
+
+--- mint
+CREATE TABLE mints
+(
+    -- === Transaction Details ===
+    -- Solana signatures are usually 88 characters, but we use String for flexibility.
+    signature String,
+    -- Converted to DateTime for easier time-based operations in ClickHouse.
+    timestamp DateTime('UTC'),
+    slot UInt64,
+    success Boolean,
+    error Nullable(String),
+    priority_fee Float64,
+
+    -- === Protocol & Platform ===
+    -- Protocol codes: 0=Unknown, 1=PumpFunLaunchpad, 2=RaydiumLaunchpad,
+    -- 3=PumpFunAMM, 4=RaydiumCPMM, 5=MeteoraBonding
+    protocol UInt8,
+
+    -- === Mint & Pool Details ===
+    mint_address String,
+    creator_address String,
+    pool_address String,
+
+    -- === Liquidity Details ===
+    initial_base_liquidity UInt64,
+    initial_quote_liquidity UInt64,
+
+    -- === Token Metadata ===
+    token_name Nullable(String),
+    token_symbol Nullable(String),
+    token_uri Nullable(String),
+    token_decimals UInt8,
+    total_supply UInt64,
+
+    is_mutable Boolean,
+    update_authority Nullable(String),
+    mint_authority Nullable(String),
+    freeze_authority Nullable(String)
+)
+ENGINE = MergeTree()
+ORDER BY (timestamp, creator_address, mint_address);
+
+CREATE TABLE migrations
+(
+    -- Transaction Details
+    timestamp DateTime('UTC'),
+
+    signature String,
+    slot UInt64,
+    success Boolean,
+    error Nullable(String),
+    priority_fee Float64,
+
+    -- Protocol & Platform
+    protocol UInt8,
+
+    -- Migration Details
+    mint_address String,
+    virtual_pool_address String,
+    pool_address String,
+
+    -- Liquidity Details
+    migrated_base_liquidity Nullable(UInt64),
+    migrated_quote_liquidity Nullable(UInt64)
+)
+ENGINE = MergeTree()
+ORDER BY (mint_address, virtual_pool_address, pool_address, timestamp);
+
+CREATE TABLE fee_collections
+(
+    -- Transaction Details
+    timestamp DateTime('UTC'),
+
+    signature String,
+    slot UInt64,
+    success Boolean,
+    error Nullable(String),
+    priority_fee Float64,
+
+    -- Protocol & Platform
+    protocol UInt8,
+
+    -- Fee Details
+    vault_address String,
+    recipient_address String,
+
+    -- Collected Amounts
+    token_0_mint_address String,
+    token_0_amount Float64,
+    token_1_mint_address Nullable(String),
+    token_1_amount Nullable(Float64)
+)
+ENGINE = MergeTree()
+ORDER BY (vault_address, recipient_address, timestamp);
+
+CREATE TABLE liquidity
+(
+    -- Transaction Details --
+    signature String,
+    timestamp DateTime('UTC'),
+    slot UInt64,
+    success Boolean,
+    error Nullable(String),
+    priority_fee Float64,
+
+    -- Protocol Info --
+    protocol UInt8,
+
+    -- LP Action Details --
+    change_type UInt8,
+    lp_provider String,
+    pool_address String,
+
+    -- Token Amounts --
+    base_amount UInt64,
+    quote_amount UInt64
+)
+ENGINE = MergeTree()
+ORDER BY (timestamp, pool_address, lp_provider);
+
+CREATE TABLE pool_creations
+(
+    -- Transaction Details --
+    signature String,
+    timestamp DateTime('UTC'),
+    slot UInt64,
+    success Boolean,
+    error Nullable(String),
+    priority_fee Float64,
+
+    -- Protocol Info --
+    protocol UInt8,
+
+    -- Pool & Token Details --
+    creator_address String,
+    pool_address String,
+    base_address String,
+    quote_address String,
+    lp_token_address String,
+
+    -- Optional Initial State --
+    initial_base_liquidity Nullable(UInt64),
+    initial_quote_liquidity Nullable(UInt64),
+    base_decimals Nullable(UInt8),
+    quote_decimals Nullable(UInt8)
+)
+ENGINE = MergeTree()
+ORDER BY (base_address, creator_address);
+
+CREATE TABLE transfers
+(
+    -- Transaction Details
+    timestamp DateTime('UTC'),
+    signature String,
+    slot UInt64,
+    success Boolean,
+    error Nullable(String),
+    priority_fee Float64,
+
+    -- Transfer Details
+    source String,
+    destination String,
+
+    -- Amount & Mint Details
+    mint_address String,
+    amount UInt64,
+    amount_decimal Float64,
+
+    -- Balance Context
+    source_balance Float64,
+    destination_balance Float64
+)
+ENGINE = MergeTree()
+ORDER BY (source, destination, mint_address, timestamp);
+
+CREATE TABLE supply_locks
+(
+    -- === Transaction Details ===
+    timestamp DateTime('UTC'),
+
+    signature String,
+    slot UInt64,
+    success Boolean,
+    error Nullable(String),
+    priority_fee Float64,
+
+    -- === Protocol Info ===
+    protocol UInt8,
+
+    -- === Vesting Details ===
+    contract_address String,
+    sender String,
+    recipient String,
+    mint_address String,
+    total_locked_amount Float64,
+    final_unlock_timestamp UInt64
+)
+ENGINE = MergeTree()
+ORDER BY (timestamp, mint_address, sender, recipient);
+
+CREATE TABLE supply_lock_actions
+(
+    -- === Transaction Details ===
+    signature String,
+    timestamp DateTime('UTC'),
+    slot UInt64,
+    success Boolean,
+    error Nullable(String),
+    priority_fee Float64,
+
+    -- === Protocol Info ===
+    protocol UInt8,
+
+    -- === Action Details ===
+    action_type UInt8, -- e.g., 0 for Withdraw, 1 for Topup
+    contract_address String,
+    user String,
+    mint_address String,
+    amount Float64
+)
+ENGINE = MergeTree()
+ORDER BY (timestamp, mint_address, user);
+
+CREATE TABLE burns
+(
+    -- Transaction Details
+    timestamp DateTime('UTC'),
+    signature String,
+    slot UInt64,
+    success Boolean,
+    error Nullable(String),
+    priority_fee Float64,
+
+    -- Burn Details
+    mint_address String,
+    source String,
+    amount UInt64,
+    amount_decimal Float64,
+
+    source_balance Float64
+)
+ENGINE = MergeTree()
+ORDER BY (mint_address, source, timestamp);
+
+-------- Wallet schema
+
+CREATE TABLE wallet_profiles
+(
+    updated_at DateTime('UTC'),
+    first_seen_ts DateTime('UTC'),
+    last_seen_ts DateTime('UTC'),
+
+    wallet_address String,
+    tags Array(String),
+    deployed_tokens Array(String),
+
+    funded_from String,
+    funded_timestamp UInt32,
+    funded_signature String,
+    funded_amount Float64
+)
+ENGINE = ReplacingMergeTree(updated_at)
+PRIMARY KEY (wallet_address)
+ORDER BY (wallet_address);
+
+CREATE TABLE wallet_profile_metrics
+(
+    updated_at DateTime('UTC'),
+    wallet_address String,
+    balance Float64,
+
+    transfers_in_count UInt32,
+    transfers_out_count UInt32,
+    spl_transfers_in_count UInt32,
+    spl_transfers_out_count UInt32,
+
+    total_buys_count UInt32,
+    total_sells_count UInt32,
+    total_winrate Float32,
+
+    stats_1d_realized_profit_sol Float64,
+    stats_1d_realized_profit_usd Float64,
+    stats_1d_realized_profit_pnl Float32,
+    stats_1d_buy_count UInt32,
+    stats_1d_sell_count UInt32,
+    stats_1d_transfer_in_count UInt32,
+    stats_1d_transfer_out_count UInt32,
+    stats_1d_avg_holding_period Float32,
+    stats_1d_total_bought_cost_sol Float64,
+    stats_1d_total_bought_cost_usd Float64,
+    stats_1d_total_sold_income_sol Float64,
+    stats_1d_total_sold_income_usd Float64,
+    stats_1d_total_fee Float64,
+    stats_1d_winrate Float32,
+    stats_1d_tokens_traded UInt32,
+
+    stats_7d_realized_profit_sol Float64,
+    stats_7d_realized_profit_usd Float64,
+    stats_7d_realized_profit_pnl Float32,
+    stats_7d_buy_count UInt32,
+    stats_7d_sell_count UInt32,
+    stats_7d_transfer_in_count UInt32,
+    stats_7d_transfer_out_count UInt32,
+    stats_7d_avg_holding_period Float32,
+    stats_7d_total_bought_cost_sol Float64,
+    stats_7d_total_bought_cost_usd Float64,
+    stats_7d_total_sold_income_sol Float64,
+    stats_7d_total_sold_income_usd Float64,
+    stats_7d_total_fee Float64,
+    stats_7d_winrate Float32,
+    stats_7d_tokens_traded UInt32,
+
+    stats_30d_realized_profit_sol Float64,
+    stats_30d_realized_profit_usd Float64,
+    stats_30d_realized_profit_pnl Float32,
+    stats_30d_buy_count UInt32,
+    stats_30d_sell_count UInt32,
+    stats_30d_transfer_in_count UInt32,
+    stats_30d_transfer_out_count UInt32,
+    stats_30d_avg_holding_period Float32,
+    stats_30d_total_bought_cost_sol Float64,
+    stats_30d_total_bought_cost_usd Float64,
+    stats_30d_total_sold_income_sol Float64,
+    stats_30d_total_sold_income_usd Float64,
+    stats_30d_total_fee Float64,
+    stats_30d_winrate Float32,
+    stats_30d_tokens_traded UInt32
+)
+ENGINE = MergeTree
+ORDER BY (wallet_address, updated_at);
+
+CREATE TABLE wallet_holdings
+(
+    updated_at DateTime('UTC'),
+    start_holding_at DateTime('UTC'),
+
+    wallet_address String,
+    mint_address String,
+    current_balance Float64,
+
+    realized_profit_pnl Float32,
+    realized_profit_sol Float64,
+    realized_profit_usd Float64,
+
+    history_transfer_in UInt32,
+    history_transfer_out UInt32,
+
+    history_bought_amount Float64,
+    history_bought_cost_sol Float64,
+    history_sold_amount Float64,
+    history_sold_income_sol Float64
+)
+ENGINE = MergeTree
+ORDER BY (wallet_address, mint_address, updated_at);
+
+CREATE TABLE tokens
+(
+    updated_at DateTime('UTC'),
+    created_at DateTime('UTC'),
+
+    -- Core Identifiers
+    token_address String,
+    name String,
+    symbol String,
+    token_uri String,
+
+    -- Token Metadata
+    decimals UInt8,
+    creator_address String,
+    pool_addresses Array(String), -- Map Vec<String> to Array(String)
+
+    -- Protocol/Launchpad
+    launchpad UInt8,
+    protocol UInt8,
+    total_supply UInt64,
+
+    -- Authorities/Flags
+    is_mutable Boolean, -- Alias for UInt8, but Boolean is clearer/modern
+    update_authority Nullable(String), -- Map Option<String> to Nullable(String)
+    mint_authority Nullable(String),
+    freeze_authority Nullable(String)
+)
+ENGINE = ReplacingMergeTree(updated_at)
+PRIMARY KEY (token_address)
+ORDER BY (token_address, updated_at);
+
+-- Latest tokens (one row per token_address)
+CREATE TABLE tokens_latest
+(
+    updated_at DateTime('UTC'),
+    created_at DateTime('UTC'),
+
+    token_address String,
+    name String,
+    symbol String,
+    token_uri String,
+
+    decimals UInt8,
+    creator_address String,
+    pool_addresses Array(String),
+
+    launchpad UInt8,
+    protocol UInt8,
+    total_supply UInt64,
+
+    is_mutable Boolean,
+    update_authority Nullable(String),
+    mint_authority Nullable(String),
+    freeze_authority Nullable(String)
+)
+ENGINE = ReplacingMergeTree(updated_at)
+ORDER BY (token_address);
+
+CREATE TABLE token_metrics
+(
+    updated_at DateTime('UTC'),
+    token_address String,
+    total_volume_usd Float64,
+    total_buys UInt32,
+    total_sells UInt32,
+    unique_holders UInt32,
+    ath_price_usd Float64
+)
+ENGINE = MergeTree
+ORDER BY (token_address, updated_at);
+
+-- ========= Latest snapshot helper tables =========
+-- Keep full history in the base tables above, but read fast from these ReplacingMergeTree snapshots.
+
+-- Latest wallet profile metrics (one row per wallet_address)
+CREATE TABLE wallet_profile_metrics_latest
+(
+    updated_at DateTime('UTC'),
+    wallet_address String,
+    balance Float64,
+
+    transfers_in_count UInt32,
+    transfers_out_count UInt32,
+    spl_transfers_in_count UInt32,
+    spl_transfers_out_count UInt32,
+
+    total_buys_count UInt32,
+    total_sells_count UInt32,
+    total_winrate Float32,
+
+    stats_1d_realized_profit_sol Float64,
+    stats_1d_realized_profit_usd Float64,
+    stats_1d_realized_profit_pnl Float32,
+    stats_1d_buy_count UInt32,
+    stats_1d_sell_count UInt32,
+    stats_1d_transfer_in_count UInt32,
+    stats_1d_transfer_out_count UInt32,
+    stats_1d_avg_holding_period Float32,
+    stats_1d_total_bought_cost_sol Float64,
+    stats_1d_total_bought_cost_usd Float64,
+    stats_1d_total_sold_income_sol Float64,
+    stats_1d_total_sold_income_usd Float64,
+    stats_1d_total_fee Float64,
+    stats_1d_winrate Float32,
+    stats_1d_tokens_traded UInt32,
+
+    stats_7d_realized_profit_sol Float64,
+    stats_7d_realized_profit_usd Float64,
+    stats_7d_realized_profit_pnl Float32,
+    stats_7d_buy_count UInt32,
+    stats_7d_sell_count UInt32,
+    stats_7d_transfer_in_count UInt32,
+    stats_7d_transfer_out_count UInt32,
+    stats_7d_avg_holding_period Float32,
+    stats_7d_total_bought_cost_sol Float64,
+    stats_7d_total_bought_cost_usd Float64,
+    stats_7d_total_sold_income_sol Float64,
+    stats_7d_total_sold_income_usd Float64,
+    stats_7d_total_fee Float64,
+    stats_7d_winrate Float32,
+    stats_7d_tokens_traded UInt32,
+
+    stats_30d_realized_profit_sol Float64,
+    stats_30d_realized_profit_usd Float64,
+    stats_30d_realized_profit_pnl Float32,
+    stats_30d_buy_count UInt32,
+    stats_30d_sell_count UInt32,
+    stats_30d_transfer_in_count UInt32,
+    stats_30d_transfer_out_count UInt32,
+    stats_30d_avg_holding_period Float32,
+    stats_30d_total_bought_cost_sol Float64,
+    stats_30d_total_bought_cost_usd Float64,
+    stats_30d_total_sold_income_sol Float64,
+    stats_30d_total_sold_income_usd Float64,
+    stats_30d_total_fee Float64,
+    stats_30d_winrate Float32,
+    stats_30d_tokens_traded UInt32
+)
+ENGINE = ReplacingMergeTree(updated_at)
+ORDER BY (wallet_address);
+
+-- Latest wallet holdings (one row per wallet_address + mint_address)
+CREATE TABLE wallet_holdings_latest
+(
+    updated_at DateTime('UTC'),
+    start_holding_at DateTime('UTC'),
+
+    wallet_address String,
+    mint_address String,
+    current_balance Float64,
+
+    realized_profit_pnl Float32,
+    realized_profit_sol Float64,
+    realized_profit_usd Float64,
+
+    history_transfer_in UInt32,
+    history_transfer_out UInt32,
+
+    history_bought_amount Float64,
+    history_bought_cost_sol Float64,
+    history_sold_amount Float64,
+    history_sold_income_sol Float64
+)
+ENGINE = ReplacingMergeTree(updated_at)
+ORDER BY (wallet_address, mint_address);
+
+-- Latest token metrics (one row per token_address)
+CREATE TABLE token_metrics_latest
+(
+    updated_at DateTime('UTC'),
+    token_address String,
+    total_volume_usd Float64,
+    total_buys UInt32,
+    total_sells UInt32,
+    unique_holders UInt32,
+    ath_price_usd Float64
+)
+ENGINE = ReplacingMergeTree(updated_at)
+ORDER BY (token_address);
+
+CREATE TABLE known_wallets
+(
+    `wallet_address` String,
+    `name` String, -- e.g., "Pump.fun Fee Vault", "Raydium CPMM Authority V4", "KOL - Ansem"
596
+     `tag` String -- e.g., "fee_vault", "dex_authority", "kol", "exchange"
597
+ )
598
+ ENGINE = ReplacingMergeTree()
599
+ ORDER BY (wallet_address);
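The `ReplacingMergeTree(updated_at)` engines above keep, after background merges, only the row with the greatest `updated_at` per `ORDER BY` key. A minimal Python sketch of that dedup rule (the rows and column names here are illustrative, not data from the schema):

```python
def latest_snapshot(rows, key_cols, version_col="updated_at"):
    """Emulate ReplacingMergeTree(version) semantics: keep, per key,
    the row with the greatest version value (later rows win ties)."""
    latest = {}
    for row in rows:
        key = tuple(row[c] for c in key_cols)
        kept = latest.get(key)
        if kept is None or row[version_col] >= kept[version_col]:
            latest[key] = row
    return list(latest.values())

rows = [
    {"wallet_address": "A", "updated_at": 1, "balance": 10.0},
    {"wallet_address": "A", "updated_at": 3, "balance": 12.5},
    {"wallet_address": "B", "updated_at": 2, "balance": 7.0},
]
snapshot = latest_snapshot(rows, key_cols=("wallet_address",))
```

Note that until parts actually merge, the `_latest` tables can still hold several versions of a key, so ad-hoc queries typically add `FINAL` (or use `argMax(..., updated_at)`) to force this same latest-row view.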
pre_cache.sh ADDED
@@ -0,0 +1,6 @@
1
+ # Flags below match scripts/cache_dataset.py's argparse options;
+ # ClickHouse/Neo4j connection settings are read from environment variables.
+ python scripts/cache_dataset.py \
+     --output-dir data/cache/epoch_851 \
+     --max-samples 100 \
+     --start-date 2024-01-01
scripts/cache_dataset.py ADDED
@@ -0,0 +1,148 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Script to pre-generate and cache dataset items from the OracleDataset.
4
+
5
+ This script connects to the databases, instantiates the data loader in 'online' mode,
6
+ and iterates through the requested number of samples, saving each processed item
7
+ to a file. This avoids costly data fetching and processing during training.
8
+
9
+ Example usage:
10
+ python scripts/cache_dataset.py --output-dir ./data/cached_dataset --max-samples 1000 --start-date 2024-05-01
11
+ """
12
+
13
+ import argparse
14
+ import datetime
15
+ import os
16
+ import sys
17
+ from pathlib import Path
18
+
19
+ import torch
20
+ import clickhouse_connect
21
+ from neo4j import GraphDatabase
22
+ from tqdm import tqdm
23
+
24
+ # Add apollo to path to import modules
25
+ sys.path.append(str(Path(__file__).resolve().parents[1]))
26
+
27
+ from data.data_loader import OracleDataset
28
+ from data.data_fetcher import DataFetcher
29
+
30
+ # --- Database Connection Details (can be overridden by env vars) ---
31
+ CLICKHOUSE_HOST = os.getenv("CLICKHOUSE_HOST", "localhost")
32
+ CLICKHOUSE_PORT = int(os.getenv("CLICKHOUSE_PORT", "8123"))
33
+ CLICKHOUSE_USER = os.getenv("CLICKHOUSE_USER", "default")
34
+ CLICKHOUSE_PASSWORD = os.getenv("CLICKHOUSE_PASSWORD", "")
35
+ CLICKHOUSE_DATABASE = os.getenv("CLICKHOUSE_DATABASE", "default")
36
+
37
+ NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
38
+ NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
39
+ NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "password")
40
+
41
+ def parse_args():
42
+ parser = argparse.ArgumentParser(description="Cache OracleDataset items to disk.")
43
+ parser.add_argument(
44
+ "--output-dir",
45
+ type=str,
46
+ required=True,
47
+ help="Directory to save the cached .pt files."
48
+ )
49
+ parser.add_argument(
50
+ "--max-samples",
51
+ type=int,
52
+ default=None,
53
+ help="Maximum number of samples to generate and cache. Defaults to all available."
54
+ )
55
+ parser.add_argument(
56
+ "--start-date",
57
+ type=str,
58
+ default=None,
59
+ help="Start date for fetching mints in YYYY-MM-DD format. Fetches all mints on or after this UTC date."
60
+ )
61
+ parser.add_argument(
62
+ "--t-cutoff-seconds",
63
+ type=int,
64
+ default=60,
65
+ help="Time in seconds after mint to set the data cutoff (T_cutoff)."
66
+ )
67
+ parser.add_argument(
68
+ "--ohlc-stats-path",
69
+ type=str,
70
+ default="./data/ohlc_stats.npz",
71
+ help="Path to the OHLC stats file for normalization."
72
+ )
73
+ parser.add_argument(
74
+ "--min-trade-usd",
75
+ type=float,
76
+ default=5.0,
77
+ help="Minimum USD value for a trade to be included in the event sequence. Defaults to 5.0."
78
+ )
79
+ return parser.parse_args()
80
+
81
+ def main():
82
+ args = parse_args()
83
+
84
+ output_dir = Path(args.output_dir)
85
+ output_dir.mkdir(parents=True, exist_ok=True)
86
+ print(f"INFO: Caching dataset to {output_dir.resolve()}")
87
+
88
+ start_date_dt = None
89
+ if args.start_date:
90
+ try:
91
+ start_date_dt = datetime.datetime.strptime(args.start_date, "%Y-%m-%d").replace(tzinfo=datetime.timezone.utc)
92
+ print(f"INFO: Filtering mints on or after {start_date_dt}")
93
+ except ValueError:
94
+ print(f"ERROR: Invalid start-date format. Please use YYYY-MM-DD.", file=sys.stderr)
95
+ sys.exit(1)
96
+
97
+ # --- 1. Set up database connections ---
98
+ try:
99
+ print("INFO: Connecting to ClickHouse...")
100
+ clickhouse_client = clickhouse_connect.get_client(host=CLICKHOUSE_HOST, port=CLICKHOUSE_PORT, user=CLICKHOUSE_USER, password=CLICKHOUSE_PASSWORD, database=CLICKHOUSE_DATABASE)
101
+ print("INFO: Connecting to Neo4j...")
102
+ neo4j_driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
103
+ except Exception as e:
104
+ print(f"ERROR: Failed to connect to databases: {e}", file=sys.stderr)
105
+ sys.exit(1)
106
+
107
+ # --- 2. Initialize DataFetcher and OracleDataset ---
108
+ data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
109
+
110
+ dataset = OracleDataset(
111
+ data_fetcher=data_fetcher,
112
+ max_samples=args.max_samples,
113
+ start_date=start_date_dt,
114
+ t_cutoff_seconds=args.t_cutoff_seconds,
115
+ ohlc_stats_path=args.ohlc_stats_path,
116
+ horizons_seconds=[60, 300, 900, 1800, 3600],
117
+ quantiles=[0.5],
118
+ min_trade_usd=args.min_trade_usd
119
+ )
120
+
121
+ if len(dataset) == 0:
122
+ print("WARNING: Dataset initialization resulted in 0 samples. Nothing to cache.")
123
+ return
124
+
125
+ # --- 3. Iterate and cache each item ---
126
+ print(f"INFO: Starting to generate and cache {len(dataset)} samples...")
127
+ skipped_count = 0
128
+ for i in tqdm(range(len(dataset)), desc="Caching samples"):
129
+ try:
130
+ item = dataset.__cacheitem__(i)
131
+ if item is None:
132
+ skipped_count += 1
133
+ continue
134
+ output_path = output_dir / f"sample_{i}.pt"
135
+ torch.save(item, output_path)
136
+ except Exception as e:
137
+ print(f"\nERROR: Failed to generate or save sample {i} for mint '{dataset.sampled_mints[i]['mint_address']}'. Error: {e}", file=sys.stderr)
138
+ skipped_count += 1
139
+ continue
140
+
141
+ print(f"\n--- Caching Complete ---\nSuccessfully cached: {len(dataset) - skipped_count} items.\nSkipped: {skipped_count} items.\nCache location: {output_dir.resolve()}")
142
+
143
+ # --- 4. Close connections ---
144
+ clickhouse_client.close()
145
+ neo4j_driver.close()
146
+
147
+ if __name__ == "__main__":
148
+ main()
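Since the loop writes one `sample_{i}.pt` per dataset index, a rerun could skip indices already on disk instead of regenerating everything. A minimal sketch of such a resume check (hypothetical helper, not part of the script):

```python
from pathlib import Path
import tempfile

def pending_indices(output_dir: Path, total: int) -> list[int]:
    """Return dataset indices that do not yet have a cached sample_{i}.pt."""
    cached = {p.name for p in output_dir.glob("sample_*.pt")}
    return [i for i in range(total) if f"sample_{i}.pt" not in cached]

# Demo: pretend samples 0 and 2 were cached by an earlier, interrupted run.
tmp = Path(tempfile.mkdtemp())
(tmp / "sample_0.pt").touch()
(tmp / "sample_2.pt").touch()
todo = pending_indices(tmp, 4)
```

Iterating over `todo` instead of `range(len(dataset))` would make the caching loop idempotent across restarts.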
scripts/download_epoch_artifacts.py ADDED
@@ -0,0 +1,96 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Download a specific epoch's parquet/Neo4j artifacts from Hugging Face.
4
+
5
+ Usage:
6
+ HF_TOKEN=your_token \
7
+ python scripts/download_epoch_artifacts.py --epoch 851
8
+ """
9
+
10
+ import argparse
11
+ import os
12
+ from pathlib import Path
13
+ from typing import List
14
+
15
+ from huggingface_hub import snapshot_download
16
+
17
+
18
+ REPO_ID = "zirobtc/pump-fun-dataset"
19
+ REPO_TYPE = "model" # dataset is not used here per user note
20
+ DEFAULT_DEST_DIR = "./data/pump_fun"
21
+
22
+ # File stems that are suffixed with `_epoch_{epoch}.parquet`
23
+ PARQUET_STEMS = [
24
+ "wallet_profiles",
25
+ "wallet_holdings",
26
+ "trades",
27
+ "transfers",
28
+ "burns",
29
+ "tokens",
30
+ "mints",
31
+ "liquidity",
32
+ "pool_creations",
33
+ "token_metrics",
34
+ "wallet_profile_metrics",
35
+ "migrations",
36
+ "fee_collections",
37
+ "supply_locks",
38
+ "supply_lock_actions",
39
+ "known_wallets",
40
+ ]
41
+
42
+ # Single Neo4j dump name
43
+ NEO4J_FILENAME = "neo4j_epoch_{epoch}.dump"
44
+
45
+
46
+ def build_patterns(epoch: int) -> List[str]:
47
+ epoch_str = str(epoch)
48
+ parquet_patterns = [f"{stem}_epoch_{epoch_str}.parquet" for stem in PARQUET_STEMS]
49
+ neo4j_pattern = NEO4J_FILENAME.format(epoch=epoch_str)
50
+ return parquet_patterns + [neo4j_pattern]
51
+
52
+
53
+ def parse_args() -> argparse.Namespace:
54
+ parser = argparse.ArgumentParser(description="Download epoch artifacts from Hugging Face.")
55
+ parser.add_argument("--epoch", type=int, required=False, help="Epoch number to download (e.g., 851)", default=851)
56
+ parser.add_argument(
57
+ "--token",
58
+ type=str,
59
+ default=None,
60
+ required=False,
61
+ help="Hugging Face token (or set HF_TOKEN env var)",
62
+ )
63
+
64
+ return parser.parse_args()
65
+
66
+
67
+ def main() -> None:
68
+ args = parse_args()
69
+ token = args.token or os.environ.get("HF_TOKEN")
70
+
71
+
72
+ patterns = build_patterns(args.epoch)
73
+ dest_root = Path(DEFAULT_DEST_DIR).expanduser()
74
+ dest_dir = dest_root / f"epoch_{args.epoch}"
75
+ dest_dir.mkdir(parents=True, exist_ok=True)
76
+
77
+ print(f"Downloading epoch {args.epoch} files from {REPO_ID} to {dest_dir}")
78
+ print("Files:")
79
+ for p in patterns:
80
+ print(f" - {p}")
81
+
82
+ snapshot_download(
83
+ repo_id=REPO_ID,
84
+ repo_type=REPO_TYPE,
85
+ local_dir=str(dest_dir),
86
+ local_dir_use_symlinks=False,
87
+ allow_patterns=patterns,
88
+ resume_download=True,
89
+ token=token,
90
+ )
91
+
92
+ print("Download complete.")
93
+
94
+
95
+ if __name__ == "__main__":
96
+ main()
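To see the file list a given epoch resolves to, the pattern-building step can be exercised standalone (constants copied from the script, with the stem list abbreviated for brevity):

```python
# Abbreviated stem list; the script enumerates sixteen stems in total.
PARQUET_STEMS = ["wallet_profiles", "wallet_holdings", "trades"]
NEO4J_FILENAME = "neo4j_epoch_{epoch}.dump"

def build_patterns(epoch: int) -> list[str]:
    """Each stem is suffixed with _epoch_{epoch}.parquet; the single
    Neo4j dump filename is appended at the end."""
    epoch_str = str(epoch)
    parquet_patterns = [f"{stem}_epoch_{epoch_str}.parquet" for stem in PARQUET_STEMS]
    return parquet_patterns + [NEO4J_FILENAME.format(epoch=epoch_str)]

patterns = build_patterns(851)
```

These exact filenames are what `snapshot_download(allow_patterns=...)` matches, so only the requested epoch's artifacts are fetched from the repo.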
scripts/ingest_epoch.py ADDED
@@ -0,0 +1,713 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ ETL Pipeline: Download epoch Parquet files, ingest into ClickHouse, and delete local files.
4
+
5
+ Usage:
6
+ python scripts/ingest_epoch.py --epoch 851
7
+
8
+ Environment Variables:
9
+ HF_TOKEN: Hugging Face token for downloading private datasets.
10
+ CLICKHOUSE_HOST, CLICKHOUSE_PORT, CLICKHOUSE_USER, CLICKHOUSE_PASSWORD, CLICKHOUSE_DATABASE
11
+ """
12
+
13
+ import argparse
14
+ import os
15
+ import sys
16
+ import time
17
+ from pathlib import Path
18
+
19
+ import clickhouse_connect
20
+ from huggingface_hub import snapshot_download
21
+ from tqdm import tqdm
22
+
23
+ # Hugging Face config
24
+ REPO_ID = "zirobtc/pump-fun-dataset"
25
+ REPO_TYPE = "model"
26
+ DEFAULT_DEST_DIR = "./data/pump_fun"
27
+ CLICKHOUSE_DOCKER_CONTAINER = "db-clickhouse"
28
+ CLICKHOUSE_INSERT_SETTINGS = "max_insert_threads=1,max_block_size=65536"
29
+ NEO4J_DOCKER_CONTAINER = "neo4j"
30
+ NEO4J_TARGET_DB = "neo4j"
31
+ NEO4J_TEMP_DB_PREFIX = "epoch"
32
+ NEO4J_MERGE_BATCH_SIZE = 2000
33
+ NEO4J_URI = "bolt://localhost:7687"
34
+ NEO4J_USER = None
35
+ NEO4J_PASSWORD = None
36
+
37
+ # Parquet file stems -> ClickHouse table names
38
+ # Maps the file stem to the target table. Usually they match.
39
+ PARQUET_TABLE_MAP = {
40
+ "wallet_profiles": "wallet_profiles",
41
+ "wallet_holdings": "wallet_holdings",
42
+ "trades": "trades",
43
+ "transfers": "transfers",
44
+ "burns": "burns",
45
+ "tokens": "tokens",
46
+ "mints": "mints",
47
+ "liquidity": "liquidity",
48
+ "pool_creations": "pool_creations",
49
+ "token_metrics": "token_metrics",
50
+ "wallet_profile_metrics": "wallet_profile_metrics",
51
+ "migrations": "migrations",
52
+ "fee_collections": "fee_collections",
53
+ "supply_locks": "supply_locks",
54
+ "supply_lock_actions": "supply_lock_actions",
55
+ "known_wallets": "known_wallets",
56
+ }
57
+
58
+ # Neo4j dump filename pattern
59
+ NEO4J_FILENAME = "neo4j_epoch_{epoch}.dump"
60
+
61
+ # ClickHouse connection defaults (can be overridden by env vars)
62
+ CH_HOST = os.getenv("CLICKHOUSE_HOST", "localhost")
63
+ CH_PORT = int(os.getenv("CLICKHOUSE_PORT", "8123"))
64
+ CH_USER = os.getenv("CLICKHOUSE_USER", "default")
65
+ CH_PASSWORD = os.getenv("CLICKHOUSE_PASSWORD", "")
66
+ CH_DATABASE = os.getenv("CLICKHOUSE_DATABASE", "default")
67
+
68
+
69
+ def build_patterns(epoch: int) -> list[str]:
70
+ """Build the list of file patterns to download for a given epoch."""
71
+ epoch_str = str(epoch)
72
+ parquet_patterns = [f"{stem}_epoch_{epoch_str}.parquet" for stem in PARQUET_TABLE_MAP.keys()]
73
+ neo4j_pattern = NEO4J_FILENAME.format(epoch=epoch_str)
74
+ return parquet_patterns + [neo4j_pattern]
75
+
76
+
77
+ def download_epoch(epoch: int, dest_dir: Path, token: str | None) -> None:
78
+ """Download epoch artifacts from Hugging Face."""
79
+ patterns = build_patterns(epoch)
80
+ dest_dir.mkdir(parents=True, exist_ok=True)
81
+
82
+ print(f"📥 Downloading epoch {epoch} from {REPO_ID}...")
83
+ snapshot_download(
84
+ repo_id=REPO_ID,
85
+ repo_type=REPO_TYPE,
86
+ local_dir=str(dest_dir),
87
+ local_dir_use_symlinks=False,
88
+ allow_patterns=patterns,
89
+ resume_download=True,
90
+ token=token,
91
+ )
92
+ print("✅ Download complete.")
93
+
94
+
95
+ def ingest_parquet(client, table_name: str, parquet_path: Path, dry_run: bool = False) -> bool:
96
+ """
97
+ Ingest a Parquet file into a ClickHouse table.
98
+ Returns True on success.
99
+ """
100
+ if dry_run:
101
+ print(f" [DRY-RUN] insert {parquet_path.name} -> {table_name}")
102
+ return True
103
+
104
+ try:
105
+ with parquet_path.open("rb") as fh:
106
+ magic = fh.read(4)
107
+ if magic != b"PAR1":
108
+ print(f" ⚠️ Skipping {parquet_path.name}: not a Parquet file.")
109
+ return False
110
+
111
+         # FROM INFILE is a feature of the clickhouse-client CLI, not of the
+         # HTTP client used elsewhere in this script, so shell out to
+         # clickhouse-client (or to the docker container below) rather than
+         # loading the whole Parquet file into memory.
+         import subprocess
115
+ infile_query = f"INSERT INTO {table_name} FROM INFILE '{parquet_path.resolve()}' FORMAT Parquet"
116
+ try:
117
+ cmd = [
118
+ "clickhouse-client",
119
+ "--host", CH_HOST,
120
+ "--port", str(CH_PORT),
121
+ "--user", CH_USER,
122
+ "--password", CH_PASSWORD,
123
+ "--database", CH_DATABASE,
124
+ "--query", infile_query,
125
+ ]
126
+ subprocess.run(cmd, check=True)
127
+ return True
128
+ except FileNotFoundError:
129
+ pass
130
+
131
+ # Docker fallback for ClickHouse container
132
+ ch_container = CLICKHOUSE_DOCKER_CONTAINER
133
+ try:
134
+ tmp_path = f"/tmp/{parquet_path.name}"
135
+ subprocess.run(
136
+ ["docker", "cp", str(parquet_path), f"{ch_container}:{tmp_path}"],
137
+ check=True,
138
+ )
139
+ docker_cmd = [
140
+ "docker", "exec", ch_container,
141
+ "clickhouse-client",
142
+ "--query", f"INSERT INTO {table_name} FROM INFILE '{tmp_path}' FORMAT Parquet",
143
+ ]
144
+ subprocess.run(docker_cmd, check=True)
145
+ subprocess.run(["docker", "exec", ch_container, "rm", "-f", tmp_path], check=True)
146
+ return True
147
+ except FileNotFoundError:
148
+ raise RuntimeError(
149
+ "clickhouse-client not found and docker is unavailable. Install clickhouse-client or use a ClickHouse container."
150
+ )
151
+ except Exception as e:
152
+ print(f" ❌ Failed to ingest {parquet_path.name}: {e}")
153
+ return False
154
+
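The `PAR1` sniff in `ingest_parquet` guards against ingesting a non-Parquet file, for example an HTML error page saved by a failed download. The check in isolation (file contents here are synthetic):

```python
from pathlib import Path
import tempfile

def looks_like_parquet(path: Path) -> bool:
    """Parquet files begin (and end) with the 4-byte magic b'PAR1';
    checking the header is a cheap sanity filter before ingestion."""
    with path.open("rb") as fh:
        return fh.read(4) == b"PAR1"

tmp = Path(tempfile.mkdtemp())
good = tmp / "good.parquet"
good.write_bytes(b"PAR1" + b"\x00" * 16 + b"PAR1")  # minimal fake framing
bad = tmp / "bad.parquet"
bad.write_bytes(b"<html>error</html>")
```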
155
+
156
+ def run_etl(epoch: int, dest_dir: Path, client, dry_run: bool = False, token: str | None = None, skip_neo4j: bool = False, skip_clickhouse: bool = False) -> None:
157
+ """
158
+     Full ETL pipeline:
+     1. Use local Parquet files (no download)
+     2. Ingest Parquet files into ClickHouse
+     3. Keep local files (no deletion)
+     4. Merge the epoch's Neo4j dump into the live database (if present)
162
+ """
163
+ if not dest_dir.exists():
164
+ raise FileNotFoundError(f"Epoch directory not found: {dest_dir}")
165
+
166
+ if not skip_clickhouse:
167
+ # Step 2: Ingest each Parquet file
168
+ print(f"\n📤 Ingesting Parquet files into ClickHouse...")
169
+ for stem, table_name in tqdm(PARQUET_TABLE_MAP.items(), desc="Ingesting"):
170
+ parquet_path = dest_dir / f"{stem}_epoch_{epoch}.parquet"
171
+ if not parquet_path.exists():
172
+ print(f" ⚠️ Skipping {stem}: file not found.")
173
+ continue
174
+
175
+ ingest_parquet(client, table_name, parquet_path, dry_run=dry_run)
176
+
177
+ print("\n✅ ClickHouse ingestion complete.")
178
+ else:
179
+ print("\nℹ️ ClickHouse ingestion skipped.")
180
+
181
+ # Step 4: Neo4j dump
182
+ neo4j_path = dest_dir / NEO4J_FILENAME.format(epoch=epoch)
183
+ if neo4j_path.exists() and not skip_neo4j:
184
+ merge_neo4j_epoch_dump(epoch, neo4j_path, dry_run=dry_run)
185
+ elif neo4j_path.exists() and skip_neo4j:
186
+ print(f"\nℹ️ Neo4j dump found but skipped: {neo4j_path}")
187
+
188
+ print("\n🎉 Full ETL pipeline complete.")
189
+
190
+
191
+ def ingest_neo4j_dump(dump_path: Path, database: str = "neo4j", dry_run: bool = False) -> bool:
192
+ """
193
+ Load a Neo4j dump file into the database.
194
+ Requires neo4j-admin CLI and the Neo4j service to be stopped.
195
+ Returns True on success.
196
+ """
197
+ import subprocess
198
+
199
+ if not dump_path.exists():
200
+ print(f" ⚠️ Neo4j dump not found: {dump_path}")
201
+ return False
202
+
203
+ import shutil
204
+
205
+ expected_dump_name = f"{database}.dump"
206
+ load_dir = dump_path.parent
207
+ temp_load_dir = None
208
+ if dump_path.name != expected_dump_name:
209
+ temp_load_dir = dump_path.parent / f"_neo4j_load_{database}"
210
+ temp_load_dir.mkdir(parents=True, exist_ok=True)
211
+ load_dump_path = temp_load_dir / expected_dump_name
212
+ shutil.copy2(dump_path, load_dump_path)
213
+ load_dir = temp_load_dir
214
+
215
+ # neo4j-admin database load requires a directory containing <database>.dump
216
+ # For Neo4j 5.x: neo4j-admin database load --from-path=<dir> <database>
217
+ # Note: User must clear the database before loading (no --overwrite flag)
218
+ cmd = [
219
+ "neo4j-admin", "database", "load",
220
+ f"--from-path={load_dir.resolve()}",
221
+ database,
222
+ ]
223
+
224
+ if dry_run:
225
+ print(f" [DRY-RUN] {' '.join(cmd)}")
226
+ return True
227
+
228
+ print(f"🔄 Loading Neo4j dump into database '{database}'...")
229
+ print(" ⚠️ Neo4j must be stopped for offline load.")
230
+
231
+     try:
+         subprocess.run(cmd, capture_output=True, text=True, check=True)
+         print(" ✅ Neo4j dump loaded successfully.")
+         if temp_load_dir:
+             shutil.rmtree(temp_load_dir, ignore_errors=True)
+         return True
235
+ except FileNotFoundError:
236
+ # Fall back to dockerized neo4j-admin if available
237
+ docker_container = NEO4J_DOCKER_CONTAINER
238
+ try:
239
+ docker_ps = subprocess.run(
240
+ ["docker", "ps", "-a", "--format", "{{.Names}}\t{{.Image}}"],
241
+ capture_output=True,
242
+ text=True,
243
+ check=True,
244
+ )
245
+ except FileNotFoundError:
246
+ print(" ❌ neo4j-admin not found and docker is unavailable.")
247
+ return False
248
+ except subprocess.CalledProcessError as e:
249
+ print(f" ❌ Failed to list docker containers: {e.stderr}")
250
+ return False
251
+
252
+ containers = [line.strip().split("\t") for line in docker_ps.stdout.splitlines() if line.strip()]
253
+ container_names = {name for name, _ in containers}
254
+ if docker_container not in container_names:
255
+ # Try to auto-detect a neo4j container if the default name isn't found.
256
+ neo4j_candidates = [name for name, image in containers if image.startswith("neo4j")]
257
+ if neo4j_candidates:
258
+ docker_container = neo4j_candidates[0]
259
+ print(f" ℹ️ Using detected Neo4j container '{docker_container}'.")
260
+ else:
261
+ print(f" ❌ neo4j-admin not found and docker container '{docker_container}' does not exist.")
262
+ return False
263
+
264
+ docker_running = subprocess.run(
265
+ ["docker", "ps", "--format", "{{.Names}}"],
266
+ capture_output=True,
267
+ text=True,
268
+ check=True,
269
+ )
270
+ running = set(line.strip() for line in docker_running.stdout.splitlines() if line.strip())
271
+ was_running = docker_container in running
272
+
273
+ if was_running:
274
+ print(f" 🛑 Stopping Neo4j container '{docker_container}' for offline load...")
275
+ if dry_run:
276
+ print(f" [DRY-RUN] docker stop {docker_container}")
277
+ else:
278
+ subprocess.run(["docker", "stop", docker_container], check=True)
279
+
280
+ dump_name = dump_path.name
281
+ docker_cmd = [
282
+ "docker", "run", "--rm",
283
+ "--volumes-from", docker_container,
284
+ "-v", f"{load_dir.resolve()}:/dump",
285
+ "neo4j:latest",
286
+ "neo4j-admin", "database", "load",
287
+ f"--from-path=/dump",
288
+ "--overwrite-destination",
289
+ database,
290
+ ]
291
+
292
+ if dry_run:
293
+ print(f" [DRY-RUN] {' '.join(docker_cmd)}")
294
+ else:
295
+ print(f" 🔄 Running neo4j-admin in docker for {dump_name}...")
296
+ subprocess.run(docker_cmd, check=True)
297
+ print(" ✅ Neo4j dump loaded successfully (docker).")
298
+
299
+ if was_running:
300
+ print(f" ▶️ Starting Neo4j container '{docker_container}'...")
301
+ if dry_run:
302
+ print(f" [DRY-RUN] docker start {docker_container}")
303
+ else:
304
+ subprocess.run(["docker", "start", docker_container], check=True)
305
+ _wait_for_bolt(NEO4J_URI)
306
+ if temp_load_dir and not dry_run:
307
+ shutil.rmtree(temp_load_dir, ignore_errors=True)
308
+ return True
309
+ except subprocess.CalledProcessError as e:
310
+ print(f" ❌ Failed to load Neo4j dump: {e.stderr}")
311
+ if temp_load_dir and not dry_run:
312
+ shutil.rmtree(temp_load_dir, ignore_errors=True)
313
+ return False
314
+
315
+
316
+ def _neo4j_driver():
317
+ from neo4j import GraphDatabase
318
+ if NEO4J_USER is None and NEO4J_PASSWORD is None:
319
+ return GraphDatabase.driver(NEO4J_URI, auth=None)
320
+ return GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
321
+
322
+
323
+ def _run_merge_batch(tx, query: str, rows: list[dict]) -> None:
324
+ tx.run(query, rows=rows)
325
+
326
+
327
+ def _stream_merge(temp_session, target_session, match_query: str, merge_query: str, label: str) -> None:
328
+ batch = []
329
+ result = temp_session.run(match_query, fetch_size=NEO4J_MERGE_BATCH_SIZE)
330
+ for record in result:
331
+ batch.append(record.data())
332
+ if len(batch) >= NEO4J_MERGE_BATCH_SIZE:
333
+ target_session.execute_write(_run_merge_batch, merge_query, batch)
334
+ batch.clear()
335
+ if batch:
336
+ target_session.execute_write(_run_merge_batch, merge_query, batch)
337
+
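`_stream_merge` flushes a write transaction every `NEO4J_MERGE_BATCH_SIZE` records, plus once more for the remainder. The same batching shape, with a plain list sink standing in for the Neo4j write (illustrative only):

```python
def flush_in_batches(records, batch_size, sink):
    """Accumulate records and call sink(batch) whenever batch_size is
    reached, then flush any remainder at the end."""
    batch = []
    for rec in records:
        batch.append(rec)
        if len(batch) >= batch_size:
            sink(list(batch))
            batch.clear()
    if batch:
        sink(list(batch))

flushed = []
flush_in_batches(range(5), 2, flushed.append)
```

Batching like this bounds the parameter size of each `UNWIND $rows` transaction, which keeps memory use and transaction duration predictable on large dumps.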
338
+
339
+ def _wait_for_bolt(uri: str, timeout_sec: int = 60) -> None:
340
+ from neo4j import GraphDatabase
341
+ start = time.time()
342
+ while True:
343
+ try:
344
+ temp_driver = GraphDatabase.driver(uri, auth=None)
345
+ with temp_driver.session(database="neo4j") as session:
346
+ session.run("RETURN 1").consume()
347
+ temp_driver.close()
348
+ return
349
+ except Exception:
350
+ if time.time() - start > timeout_sec:
351
+ raise RuntimeError(f"Timed out waiting for Neo4j at {uri}")
352
+ time.sleep(1)
353
+
354
+
355
+ def _start_temp_neo4j_from_dump(epoch: int, dump_path: Path) -> tuple[str, str, str, Path]:
356
+ import subprocess
357
+ import shutil
358
+
359
+ expected_dump_name = "neo4j.dump"
360
+ temp_load_dir = dump_path.parent / f"_neo4j_load_{epoch}"
361
+ temp_load_dir.mkdir(parents=True, exist_ok=True)
362
+ load_dump_path = temp_load_dir / expected_dump_name
363
+ shutil.copy2(dump_path, load_dump_path)
364
+
365
+ volume_name = f"neo4j_tmp_{epoch}"
366
+ subprocess.run(["docker", "volume", "create", volume_name], check=True)
367
+
368
+ subprocess.run(
369
+ [
370
+ "docker", "run", "--rm",
371
+ "-v", f"{volume_name}:/data",
372
+ "-v", f"{temp_load_dir.resolve()}:/dump",
373
+ "neo4j:latest",
374
+ "neo4j-admin", "database", "load",
375
+ "--from-path=/dump",
376
+ "--overwrite-destination",
377
+ "neo4j",
378
+ ],
379
+ check=True,
380
+ )
381
+
382
+ container_id = subprocess.check_output(
383
+ [
384
+ "docker", "run", "-d", "--rm",
385
+ "-e", "NEO4J_AUTH=none",
386
+ "-v", f"{volume_name}:/data",
387
+ "-p", "0:7687",
388
+ "neo4j:latest",
389
+ ],
390
+ text=True,
391
+ ).strip()
392
+
393
+ port_out = subprocess.check_output(
394
+ ["docker", "port", container_id, "7687/tcp"],
395
+ text=True,
396
+ ).strip()
397
+ host_port = port_out.split(":")[-1]
398
+ bolt_uri = f"bolt://localhost:{host_port}"
399
+ return container_id, bolt_uri, volume_name, temp_load_dir
400
+
401
+
402
+ def merge_neo4j_epoch_dump(epoch: int, dump_path: Path, dry_run: bool = False) -> None:
403
+ print(f"\n🧩 Merging Neo4j dump into '{NEO4J_TARGET_DB}' via temp container...")
404
+     if dry_run:
+         print(" [DRY-RUN] temp Neo4j container start and merge skipped.")
+         return
408
+
409
+ temp_container_id = None
410
+ temp_volume = None
411
+ temp_load_dir = None
412
+ temp_driver = None
413
+ temp_db_name = "neo4j"
414
+
415
+ temp_container_id, temp_bolt_uri, temp_volume, temp_load_dir = _start_temp_neo4j_from_dump(epoch, dump_path)
416
+ _wait_for_bolt(temp_bolt_uri)
417
+ from neo4j import GraphDatabase
418
+ temp_driver = GraphDatabase.driver(temp_bolt_uri, auth=None)
419
+
420
+ _wait_for_bolt(NEO4J_URI)
421
+ driver = _neo4j_driver()
422
+ try:
423
+ with temp_driver.session(database=temp_db_name) as temp_session, driver.session(database=NEO4J_TARGET_DB) as target_session:
424
+ # Wallet nodes
425
+ _stream_merge(
426
+ temp_session,
427
+ target_session,
428
+ "MATCH (w:Wallet) RETURN w.address AS address",
429
+ "UNWIND $rows AS t MERGE (w:Wallet {address: t.address})",
430
+ "wallets",
431
+ )
432
+
433
+ # Token nodes
434
+ _stream_merge(
435
+ temp_session,
436
+ target_session,
437
+ "MATCH (t:Token) RETURN t.address AS address, t.created_ts AS created_ts",
438
+ "UNWIND $rows AS t MERGE (k:Token {address: t.address}) "
439
+ "ON CREATE SET k.created_ts = t.created_ts "
440
+ "ON MATCH SET k.created_ts = CASE WHEN k.created_ts IS NULL OR "
441
+ "t.created_ts < k.created_ts THEN t.created_ts ELSE k.created_ts END",
442
+ "tokens",
443
+ )
444
+
445
+ # BUNDLE_TRADE
446
+ _stream_merge(
447
+ temp_session,
448
+ target_session,
449
+ "MATCH (a:Wallet)-[r:BUNDLE_TRADE]->(b:Wallet) "
450
+ "RETURN a.address AS wa, b.address AS wb, r.mint AS mint, r.slot AS slot, "
451
+ "r.timestamp AS timestamp, r.signatures AS signatures",
452
+ "UNWIND $rows AS t "
453
+ "MERGE (a:Wallet {address: t.wa}) "
454
+ "MERGE (b:Wallet {address: t.wb}) "
455
+ "MERGE (a)-[r:BUNDLE_TRADE {mint: t.mint, slot: t.slot}]->(b) "
456
+ "ON CREATE SET r.timestamp = t.timestamp, r.signatures = t.signatures "
457
+ "ON MATCH SET r.timestamp = CASE WHEN r.timestamp IS NULL OR "
458
+ "t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END",
459
+ "bundle_trade",
460
+ )
461
+
462
+ # TRANSFERRED_TO
463
+ _stream_merge(
464
+ temp_session,
465
+ target_session,
466
+ "MATCH (s:Wallet)-[r:TRANSFERRED_TO]->(d:Wallet) "
467
+ "RETURN s.address AS source, d.address AS destination, r.mint AS mint, "
468
+ "r.signature AS signature, r.timestamp AS timestamp, r.amount AS amount",
469
+ "UNWIND $rows AS t "
470
+ "MERGE (s:Wallet {address: t.source}) "
471
+ "MERGE (d:Wallet {address: t.destination}) "
472
+ "MERGE (s)-[r:TRANSFERRED_TO {mint: t.mint}]->(d) "
473
+ "ON CREATE SET r.signature = t.signature, r.timestamp = t.timestamp, r.amount = t.amount "
474
+ "ON MATCH SET r.timestamp = CASE WHEN r.timestamp IS NULL OR "
475
+ "t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END",
476
+ "transfer",
477
+ )
478
+
479
+ # COORDINATED_ACTIVITY
480
+ _stream_merge(
481
+ temp_session,
482
+ target_session,
483
+ "MATCH (f:Wallet)-[r:COORDINATED_ACTIVITY]->(l:Wallet) "
484
+ "RETURN f.address AS follower, l.address AS leader, r.mint AS mint, r.timestamp AS timestamp, "
485
+ "r.leader_first_sig AS leader_first_sig, r.leader_second_sig AS leader_second_sig, "
486
+ "r.follower_first_sig AS follower_first_sig, r.follower_second_sig AS follower_second_sig, "
487
+ "r.time_gap_on_first_sec AS gap_1, r.time_gap_on_second_sec AS gap_2",
488
+ "UNWIND $rows AS t "
489
+ "MERGE (l:Wallet {address: t.leader}) "
490
+ "MERGE (f:Wallet {address: t.follower}) "
491
+ "MERGE (f)-[r:COORDINATED_ACTIVITY {mint: t.mint}]->(l) "
492
+ "ON CREATE SET r.timestamp = t.timestamp, r.leader_first_sig = t.leader_first_sig, "
493
+ "r.leader_second_sig = t.leader_second_sig, r.follower_first_sig = t.follower_first_sig, "
494
+ "r.follower_second_sig = t.follower_second_sig, r.time_gap_on_first_sec = t.gap_1, "
495
+ "r.time_gap_on_second_sec = t.gap_2 "
496
+ "ON MATCH SET r.timestamp = CASE WHEN r.timestamp IS NULL OR "
497
+ "t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END",
498
+ "coordinated_activity",
499
+ )
500
+
501
+ # COPIED_TRADE
502
+ _stream_merge(
503
+ temp_session,
504
+ target_session,
505
+ "MATCH (f:Wallet)-[r:COPIED_TRADE]->(l:Wallet) "
506
+ "RETURN f.address AS follower, l.address AS leader, r.mint AS mint, r.timestamp AS timestamp, "
507
+ "r.buy_gap AS buy_gap, r.sell_gap AS sell_gap, r.leader_pnl AS leader_pnl, "
508
+ "r.follower_pnl AS follower_pnl, r.l_buy_sig AS l_buy_sig, r.l_sell_sig AS l_sell_sig, "
509
+ "r.f_buy_sig AS f_buy_sig, r.f_sell_sig AS f_sell_sig, r.l_buy_total AS l_buy_total, "
510
+ "r.l_sell_total AS l_sell_total, r.f_buy_total AS f_buy_total, r.f_sell_total AS f_sell_total, "
511
+ "r.f_buy_slip AS f_buy_slip, r.f_sell_slip AS f_sell_slip",
512
+ "UNWIND $rows AS t "
513
+ "MERGE (f:Wallet {address: t.follower}) "
514
+ "MERGE (l:Wallet {address: t.leader}) "
515
+ "MERGE (f)-[r:COPIED_TRADE {mint: t.mint}]->(l) "
516
+ "ON CREATE SET r.timestamp = t.timestamp, r.follower = t.follower, r.leader = t.leader, "
517
+ "r.mint = t.mint, r.buy_gap = t.buy_gap, r.sell_gap = t.sell_gap, r.leader_pnl = t.leader_pnl, "
518
+ "r.follower_pnl = t.follower_pnl, r.l_buy_sig = t.l_buy_sig, r.l_sell_sig = t.l_sell_sig, "
519
+ "r.f_buy_sig = t.f_buy_sig, r.f_sell_sig = t.f_sell_sig, r.l_buy_total = t.l_buy_total, "
520
+ "r.l_sell_total = t.l_sell_total, r.f_buy_total = t.f_buy_total, r.f_sell_total = t.f_sell_total, "
521
+ "r.f_buy_slip = t.f_buy_slip, r.f_sell_slip = t.f_sell_slip "
522
+ "ON MATCH SET r.timestamp = CASE WHEN r.timestamp IS NULL OR "
523
+ "t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END",
524
+ "copied_trade",
525
+ )
526
+
527
+ # MINTED
528
+ _stream_merge(
529
+ temp_session,
530
+ target_session,
531
+ "MATCH (c:Wallet)-[r:MINTED]->(k:Token) "
532
+ "RETURN c.address AS creator, k.address AS token, r.signature AS signature, "
533
+ "r.timestamp AS timestamp, r.buy_amount AS buy_amount",
534
+ "UNWIND $rows AS t "
535
+ "MERGE (c:Wallet {address: t.creator}) "
536
+ "MERGE (k:Token {address: t.token}) "
537
+ "MERGE (c)-[r:MINTED {signature: t.signature}]->(k) "
538
+ "ON CREATE SET r.timestamp = t.timestamp, r.buy_amount = t.buy_amount "
539
+ "ON MATCH SET r.timestamp = CASE WHEN r.timestamp IS NULL OR "
540
+ "t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END",
541
+ "minted",
542
+ )
543
+
544
+ # SNIPED
545
+ _stream_merge(
546
+ temp_session,
547
+ target_session,
548
+ "MATCH (w:Wallet)-[r:SNIPED]->(k:Token) "
549
+ "RETURN w.address AS wallet, k.address AS token, r.signature AS signature, "
550
+ "r.rank AS rank, r.sniped_amount AS sniped_amount, r.timestamp AS timestamp",
551
+ "UNWIND $rows AS t "
552
+ "MERGE (w:Wallet {address: t.wallet}) "
553
+ "MERGE (k:Token {address: t.token}) "
554
+ "MERGE (w)-[r:SNIPED {signature: t.signature}]->(k) "
555
+ "ON CREATE SET r.rank = t.rank, r.sniped_amount = t.sniped_amount, r.timestamp = t.timestamp "
556
+ "ON MATCH SET r.timestamp = CASE WHEN r.timestamp IS NULL OR "
557
+ "t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END",
558
+ "sniped",
559
+ )
560
+
561
+ # LOCKED_SUPPLY
562
+ _stream_merge(
563
+ temp_session,
564
+ target_session,
565
+ "MATCH (s:Wallet)-[r:LOCKED_SUPPLY]->(k:Token) "
566
+ "RETURN s.address AS sender, k.address AS mint, r.signature AS signature, "
567
+ "r.amount AS amount, r.unlock_timestamp AS unlock_ts, r.recipient AS recipient, "
568
+ "r.timestamp AS timestamp",
569
+ "UNWIND $rows AS t "
570
+ "MERGE (s:Wallet {address: t.sender}) "
571
+ "MERGE (k:Token {address: t.mint}) "
572
+ "MERGE (s)-[r:LOCKED_SUPPLY {signature: t.signature}]->(k) "
573
+ "ON CREATE SET r.amount = t.amount, r.unlock_timestamp = t.unlock_ts, "
574
+ "r.recipient = t.recipient, r.timestamp = t.timestamp "
575
+ "ON MATCH SET r.timestamp = CASE WHEN r.timestamp IS NULL OR "
576
+ "t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END",
577
+ "locked_supply",
578
+ )
579
+
580
+ # BURNED
581
+ _stream_merge(
582
+ temp_session,
583
+ target_session,
584
+ "MATCH (w:Wallet)-[r:BURNED]->(k:Token) "
585
+ "RETURN w.address AS wallet, k.address AS token, r.signature AS signature, "
586
+ "r.amount AS amount, r.timestamp AS timestamp",
587
+ "UNWIND $rows AS t "
588
+ "MERGE (w:Wallet {address: t.wallet}) "
589
+ "MERGE (k:Token {address: t.token}) "
590
+ "MERGE (w)-[r:BURNED {signature: t.signature}]->(k) "
591
+ "ON CREATE SET r.amount = t.amount, r.timestamp = t.timestamp "
592
+ "ON MATCH SET r.timestamp = CASE WHEN r.timestamp IS NULL OR "
593
+ "t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END",
594
+ "burned",
595
+ )
596
+
597
+ # PROVIDED_LIQUIDITY
598
+ _stream_merge(
599
+ temp_session,
600
+ target_session,
601
+ "MATCH (w:Wallet)-[r:PROVIDED_LIQUIDITY]->(k:Token) "
602
+ "RETURN w.address AS wallet, k.address AS token, r.signature AS signature, "
603
+ "r.pool_address AS pool_address, r.amount_base AS amount_base, "
604
+ "r.amount_quote AS amount_quote, r.timestamp AS timestamp",
605
+ "UNWIND $rows AS t "
606
+ "MERGE (w:Wallet {address: t.wallet}) "
607
+ "MERGE (k:Token {address: t.token}) "
608
+ "MERGE (w)-[r:PROVIDED_LIQUIDITY {signature: t.signature}]->(k) "
609
+ "ON CREATE SET r.pool_address = t.pool_address, r.amount_base = t.amount_base, "
610
+ "r.amount_quote = t.amount_quote, r.timestamp = t.timestamp "
611
+ "ON MATCH SET r.timestamp = CASE WHEN r.timestamp IS NULL OR "
612
+ "t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END",
613
+ "provided_liquidity",
614
+ )
615
+
616
+ # TOP_TRADER_OF
617
+ _stream_merge(
618
+ temp_session,
619
+ target_session,
620
+ "MATCH (w:Wallet)-[r:TOP_TRADER_OF]->(k:Token) "
621
+ "RETURN w.address AS wallet, k.address AS token, r.pnl_at_creation AS pnl_at_creation, "
622
+ "r.ath_usd_at_creation AS ath_at_creation, r.timestamp AS timestamp",
623
+ "UNWIND $rows AS t "
624
+ "MERGE (w:Wallet {address: t.wallet}) "
625
+ "MERGE (k:Token {address: t.token}) "
626
+ "MERGE (w)-[r:TOP_TRADER_OF]->(k) "
627
+ "ON CREATE SET r.pnl_at_creation = t.pnl_at_creation, r.ath_usd_at_creation = t.ath_at_creation, "
628
+ "r.timestamp = t.timestamp "
629
+ "ON MATCH SET r.timestamp = CASE WHEN r.timestamp IS NULL OR "
630
+ "t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END",
631
+ "top_trader_of",
632
+ )
633
+
634
+ # WHALE_OF
635
+ _stream_merge(
636
+ temp_session,
637
+ target_session,
638
+ "MATCH (w:Wallet)-[r:WHALE_OF]->(k:Token) "
639
+ "RETURN w.address AS wallet, k.address AS token, r.holding_pct_at_creation AS pct_at_creation, "
640
+ "r.ath_usd_at_creation AS ath_at_creation, r.timestamp AS timestamp",
641
+ "UNWIND $rows AS t "
642
+ "MERGE (w:Wallet {address: t.wallet}) "
643
+ "MERGE (k:Token {address: t.token}) "
644
+ "MERGE (w)-[r:WHALE_OF]->(k) "
645
+ "ON CREATE SET r.holding_pct_at_creation = t.pct_at_creation, "
646
+ "r.ath_usd_at_creation = t.ath_at_creation, r.timestamp = t.timestamp "
647
+ "ON MATCH SET r.timestamp = CASE WHEN r.timestamp IS NULL OR "
648
+ "t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END",
649
+ "whale_of",
650
+ )
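Every relationship merge above uses the same `ON CREATE` / `ON MATCH` idiom: store the row's timestamp on first sight, and on later sights keep the earliest value. A torch/Neo4j-free Python sketch of that upsert rule (the `store` dict and wallet/mint keys are hypothetical stand-ins for the graph):

```python
# Sketch of the ON CREATE / ON MATCH timestamp rule used by every merge above:
# on first sight store the row's timestamp, afterwards keep the minimum seen.
def upsert_earliest(store: dict, key: tuple, timestamp: int) -> None:
    existing = store.get(key)
    if existing is None or timestamp < existing:
        store[key] = timestamp

store = {}
upsert_earliest(store, ("walletA", "walletB", "mint1"), 1700000300)
upsert_earliest(store, ("walletA", "walletB", "mint1"), 1700000100)  # earlier wins
upsert_earliest(store, ("walletA", "walletB", "mint1"), 1700000500)  # later ignored
```

This is why re-running the ETL for an already-loaded epoch is idempotent: repeated rows can only move a relationship's timestamp backwards, never forwards.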
651
+ finally:
652
+ driver.close()
653
+
654
+ try:
655
+ if temp_driver:
656
+ temp_driver.close()
657
+ if temp_container_id:
658
+ import subprocess
659
+ subprocess.run(["docker", "stop", temp_container_id], check=True)
660
+ if temp_volume:
661
+ import subprocess
662
+ subprocess.run(["docker", "volume", "rm", "-f", temp_volume], check=True)
663
+ if temp_load_dir:
664
+ import shutil
665
+ shutil.rmtree(temp_load_dir, ignore_errors=True)
666
+ print(" 🧹 Dropped temp Neo4j container.")
667
+ except Exception as e:
668
+ print(f" ⚠️ Failed to clean temp Neo4j container: {e}")
669
+
670
+
671
+ def parse_args() -> argparse.Namespace:
672
+ parser = argparse.ArgumentParser(description="ETL: Download, Ingest, Delete epoch Parquet files.")
673
+ parser.add_argument("--epoch", type=int, required=True, help="Epoch number to process (e.g., 851)")
674
+ parser.add_argument("-c", "--skip-clickhouse", action="store_true", help="Skip ClickHouse ingestion")
675
+ parser.add_argument("--dry-run", action="store_true", help="Print queries without executing")
676
+ parser.add_argument("--skip-neo4j", action="store_true", help="Skip Neo4j dump loading")
677
+ parser.add_argument("--token", type=str, default=None, help="Hugging Face token (or set HF_TOKEN env var)")
678
+ return parser.parse_args()
679
+
680
+
681
+ def main() -> None:
682
+ args = parse_args()
683
+ token = args.token or os.environ.get("HF_TOKEN")
684
+
685
+ dest_dir = Path(DEFAULT_DEST_DIR).expanduser() / f"epoch_{args.epoch}"
686
+
687
+ # Connect to ClickHouse
688
+ print(f"🔌 Connecting to ClickHouse at {CH_HOST}:{CH_PORT}...")
689
+ try:
690
+ client = clickhouse_connect.get_client(
691
+ host=CH_HOST,
692
+ port=CH_PORT,
693
+ username=CH_USER,
694
+ password=CH_PASSWORD,
695
+ database=CH_DATABASE,
696
+ )
697
+ except Exception as e:
698
+ print(f"❌ Failed to connect to ClickHouse: {e}")
699
+ sys.exit(1)
700
+
701
+ run_etl(
702
+ epoch=args.epoch,
703
+ dest_dir=dest_dir,
704
+ client=client,
705
+ dry_run=args.dry_run,
706
+ token=token,
707
+ skip_neo4j=args.skip_neo4j,
708
+ skip_clickhouse=args.skip_clickhouse,
709
+ )
710
+
711
+
712
+ if __name__ == "__main__":
713
+ main()
train.py ADDED
@@ -0,0 +1,465 @@
1
+ import os
2
+ import argparse
3
+ import math
4
+ import logging
5
+ from pathlib import Path
6
+ from typing import Any, Dict, List, Optional, Tuple
7
+
8
+ # Ensure torch/dill have a writable tmp dir
9
+ _DEFAULT_TMP = Path(os.getenv("TMPDIR_OVERRIDE", "./.tmp"))
10
+ _DEFAULT_TMP.mkdir(parents=True, exist_ok=True)
11
+ resolved_tmp = str(_DEFAULT_TMP.resolve())
12
+ for key in ("TMPDIR", "TMP", "TEMP"):
13
+ os.environ.setdefault(key, resolved_tmp)
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch.utils.data import DataLoader
18
+ from torch.optim import AdamW
19
+
20
+ # --- Accelerate & Transformers ---
21
+ from accelerate import Accelerator
22
+ from accelerate.logging import get_logger
23
+ from accelerate.utils import ProjectConfiguration, set_seed
24
+ from transformers import get_linear_schedule_with_warmup
25
+
26
+ # Logging
27
+ from tqdm.auto import tqdm
28
+
29
+ # DB Clients
30
+ from clickhouse_driver import Client as ClickHouseClient
31
+ from neo4j import GraphDatabase
32
+
33
+ # Local Imports
34
+ from data.data_fetcher import DataFetcher
35
+ from data.data_loader import OracleDataset
36
+ from data.data_collator import MemecoinCollator
37
+ from models.multi_modal_processor import MultiModalEncoder
38
+ from models.helper_encoders import ContextualTimeEncoder
39
+ from models.token_encoder import TokenEncoder
40
+ from models.wallet_encoder import WalletEncoder
41
+ from models.graph_updater import GraphUpdater
42
+ from models.ohlc_embedder import OHLCEmbedder
43
+ from models.model import Oracle
44
+ import models.vocabulary as vocab
45
+
46
+ # Setup Logger
47
+ logger = get_logger(__name__)
48
+
49
+
50
+ def compute_gradient_stats(model: nn.Module) -> Tuple[Optional[Dict[str, float]], Dict[str, float]]:
51
+ """Return overall and per-module gradient statistics for logging."""
52
+ grad_norms: List[float] = []
53
+ max_abs = 0.0
54
+ module_l2_sums: Dict[str, float] = {}
55
+
56
+ for name, param in model.named_parameters():
57
+ if param.grad is None:
58
+ continue
59
+ grad = param.grad.detach()
60
+ grad_norm = grad.norm().item()
61
+ grad_norms.append(grad_norm)
62
+ max_abs = max(max_abs, grad.abs().max().item())
63
+
64
+ module_name = name.split(".", 1)[0]
65
+ grad_fp32 = grad.float()
66
+ module_l2_sums[module_name] = module_l2_sums.get(module_name, 0.0) + float(grad_fp32.pow(2).sum().item())
67
+
68
+ if not grad_norms:
69
+ return None, {}
70
+
71
+ module_grad_norms = {module: math.sqrt(total) for module, total in module_l2_sums.items()}
72
+
73
+ return {
74
+ "grad_norm_mean": float(sum(grad_norms) / len(grad_norms)),
75
+ "grad_norm_max": float(max(grad_norms)),
76
+ "grad_abs_max": float(max_abs),
77
+ }, module_grad_norms
78
+
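The per-module norms returned by `compute_gradient_stats` are just the L2 norm of all gradients sharing a top-level module prefix. A torch-free sketch of that aggregation, with hypothetical parameter names standing in for `model.named_parameters()`:

```python
import math

# Hypothetical (name, grad_values) pairs in place of model.named_parameters().
grads = [
    ("encoder.layer0.weight", [3.0, 4.0]),   # L2 norm 5.0
    ("encoder.layer1.bias",   [0.0]),
    ("head.weight",           [6.0, 8.0]),   # L2 norm 10.0
]

module_l2_sums = {}
for name, values in grads:
    module = name.split(".", 1)[0]  # top-level prefix, as in compute_gradient_stats
    module_l2_sums[module] = module_l2_sums.get(module, 0.0) + sum(v * v for v in values)

module_grad_norms = {m: math.sqrt(total) for m, total in module_l2_sums.items()}
```

Summing squared norms per module and taking one square root at the end is equivalent to concatenating all of a module's gradients into a single vector and taking its norm.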
79
+ def quantile_pinball_loss(preds: torch.Tensor,
80
+ targets: torch.Tensor,
81
+ mask: torch.Tensor,
82
+ quantiles: List[float]) -> torch.Tensor:
83
+ """
84
+ Calculates Pinball Loss for quantile regression.
85
+ """
86
+ if mask.sum() == 0:
87
+ return torch.tensor(0.0, device=preds.device, dtype=preds.dtype)
88
+
89
+ num_quantiles = len(quantiles)
90
+ losses = []
91
+ for idx, q in enumerate(quantiles):
92
+ # Preds shape: [B, Horizons * Quantiles]
93
+ # Logic assumes interleaved outputs or consistent flattening.
94
+ pred_slice = preds[:, idx::num_quantiles]
95
+ target_slice = targets[:, idx::num_quantiles]
96
+ mask_slice = mask[:, idx::num_quantiles]
97
+
98
+ diff = target_slice - pred_slice
99
+ pinball = torch.maximum((q - 1.0) * diff, q * diff)
100
+ losses.append((pinball * mask_slice).sum())
101
+
102
+ return sum(losses) / mask.sum().clamp_min(1.0)
103
+
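Pinball loss penalizes under-prediction with weight `q` and over-prediction with weight `1 - q`, which is what makes the minimizer the q-th quantile. A torch-free numeric sketch of the same per-element formula used above:

```python
def pinball(pred: float, target: float, q: float) -> float:
    # max((q - 1) * diff, q * diff) — same expression as the torch.maximum above
    diff = target - pred
    return max((q - 1.0) * diff, q * diff)

# Target above prediction: loss scales with q.
assert pinball(pred=0.0, target=1.0, q=0.9) == 0.9
# Target below prediction by the same amount: loss scales with 1 - q.
assert abs(pinball(pred=1.0, target=0.0, q=0.9) - 0.1) < 1e-9
```

For q = 0.9 the model is punished nine times harder for predicting below the target than above it, pushing predictions toward the 90th percentile.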
104
+
105
+ def filtered_collate(collator: MemecoinCollator,
106
+ batch: List[Optional[Dict[str, Any]]]) -> Optional[Dict[str, Any]]:
107
+ """Filter out None items from the dataset before collating."""
108
+ batch = [item for item in batch if item is not None]
109
+ if not batch:
110
+ return None
111
+ return collator(batch)
112
+
113
+
114
+ def parse_args() -> argparse.Namespace:
115
+ parser = argparse.ArgumentParser(description="Train the Oracle quantile model.")
116
+ parser.add_argument("--epochs", type=int, default=1)
117
+ parser.add_argument("--batch_size", type=int, default=1)
118
+ parser.add_argument("--learning_rate", type=float, default=5e-5)
119
+ parser.add_argument("--warmup_ratio", type=float, default=0.1)
120
+ parser.add_argument("--grad_accum_steps", type=int, default=1)
121
+ parser.add_argument("--max_grad_norm", type=float, default=1.0)
122
+ parser.add_argument("--seed", type=int, default=42)
123
+ parser.add_argument("--log_every", type=int, default=1)
124
+ parser.add_argument("--save_every", type=int, default=1000)
125
+ parser.add_argument("--tensorboard_dir", type=str, default="runs/oracle")
126
+ parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
127
+ parser.add_argument("--mixed_precision", type=str, default="bf16")
128
+ parser.add_argument("--max_seq_len", type=int, default=16000)
129
+ parser.add_argument("--ohlc_seq_len", type=int, default=60)
130
+ parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
131
+ parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
132
+ parser.add_argument("--max_samples", type=int, default=None)
133
+ parser.add_argument("--ohlc_stats_path", type=str, default="./data/ohlc_stats.npz")
134
+ parser.add_argument("--t_cutoff_seconds", type=int, default=60)
135
+ parser.add_argument("--shuffle", dest="shuffle", action="store_true", default=True)
136
+ parser.add_argument("--no-shuffle", dest="shuffle", action="store_false")
137
+ parser.add_argument("--num_workers", type=int, default=0)
138
+ parser.add_argument("--pin_memory", dest="pin_memory", action="store_true", default=False)
139
+ parser.add_argument("--no-pin_memory", dest="pin_memory", action="store_false")
140
+ parser.add_argument("--clickhouse_host", type=str, default="localhost")
141
+ parser.add_argument("--clickhouse_port", type=int, default=9000)
142
+ parser.add_argument("--neo4j_uri", type=str, default="bolt://localhost:7687")
143
+ parser.add_argument("--neo4j_user", type=str, default=None)
144
+ parser.add_argument("--neo4j_password", type=str, default=None)
145
+ return parser.parse_args()
146
+
147
+
148
+ def main() -> None:
149
+ args = parse_args()
150
+ epochs = args.epochs
151
+ batch_size = args.batch_size
152
+ learning_rate = args.learning_rate
153
+ warmup_ratio = args.warmup_ratio
154
+ grad_accum_steps = args.grad_accum_steps
155
+ max_grad_norm = args.max_grad_norm
156
+ seed = args.seed
157
+
158
+ log_every = args.log_every
159
+ save_every = args.save_every
160
+
161
+ tensorboard_dir = Path(args.tensorboard_dir).expanduser()
162
+ checkpoint_dir = Path(args.checkpoint_dir).expanduser()
163
+
164
+ # --- 1. Initialize Accelerator ---
165
+ project_config = ProjectConfiguration(project_dir=str(checkpoint_dir), logging_dir=str(tensorboard_dir))
166
+ accelerator = Accelerator(
167
+ gradient_accumulation_steps=grad_accum_steps,
168
+ log_with="tensorboard",
169
+ project_config=project_config,
170
+ mixed_precision=args.mixed_precision # Default to bf16 for stability
171
+ )
172
+
173
+ # Make one log on every process with the configuration for debugging.
174
+ logging.basicConfig(
175
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
176
+ datefmt="%m/%d/%Y %H:%M:%S",
177
+ level=logging.INFO,
178
+ )
179
+ logger.info(accelerator.state, main_process_only=False)
180
+
181
+ # Set seed for reproducibility
182
+ set_seed(seed)
183
+
184
+ if accelerator.is_main_process:
185
+ logger.info("Initialized with CLI arguments.")
186
+ tensorboard_dir.mkdir(parents=True, exist_ok=True)
187
+ checkpoint_dir.mkdir(parents=True, exist_ok=True)
188
+ accelerator.init_trackers("oracle_training")
189
+
190
+ device = accelerator.device
191
+
192
+ # Determine dtype for model initialization
193
+ init_dtype = torch.float32
194
+ if accelerator.mixed_precision == 'bf16':
195
+ init_dtype = torch.bfloat16
196
+ elif accelerator.mixed_precision == 'fp16':
197
+ init_dtype = torch.float16
198
+
199
+ # --- 2. Data Setup ---
200
+ horizons = args.horizons_seconds
201
+ quantiles = args.quantiles
202
+ max_seq_len = args.max_seq_len
203
+ ohlc_seq_len = args.ohlc_seq_len
204
+
205
+ logger.info(f"Initializing Encoders with dtype={init_dtype}...")
206
+
207
+ # Encoders
208
+ multi_modal_encoder = MultiModalEncoder(dtype=init_dtype)
209
+ time_encoder = ContextualTimeEncoder(dtype=init_dtype)
210
+ token_encoder = TokenEncoder(multi_dim=multi_modal_encoder.embedding_dim, dtype=init_dtype)
211
+ wallet_encoder = WalletEncoder(encoder=multi_modal_encoder, dtype=init_dtype)
212
+ graph_updater = GraphUpdater(time_encoder=time_encoder, dtype=init_dtype)
213
+ ohlc_embedder = OHLCEmbedder(
214
+ num_intervals=vocab.NUM_OHLC_INTERVALS,
215
+ sequence_length=ohlc_seq_len,
216
+ dtype=init_dtype
217
+ )
218
+
219
+ collator = MemecoinCollator(
220
+ event_type_to_id=vocab.EVENT_TO_ID,
221
+ device=device, # Note: Collator will handle basic moves, Accelerate handles the rest
222
+ multi_modal_encoder=multi_modal_encoder,
223
+ dtype=init_dtype,
224
+ ohlc_seq_len=ohlc_seq_len,
225
+ max_seq_len=max_seq_len
226
+ )
227
+
228
+ # DB Connections
229
+ clickhouse_client = ClickHouseClient(
230
+ host=args.clickhouse_host,
231
+ port=int(args.clickhouse_port)
232
+ )
233
+
234
+ neo4j_auth = None
235
+ if args.neo4j_user is not None:
236
+ neo4j_auth = (args.neo4j_user, args.neo4j_password or "")
237
+ neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=neo4j_auth)
238
+
239
+ data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
240
+
241
+ dataset = OracleDataset(
242
+ data_fetcher=data_fetcher,
243
+ horizons_seconds=horizons,
244
+ quantiles=quantiles,
245
+ max_samples=args.max_samples,
246
+ ohlc_stats_path=args.ohlc_stats_path,
247
+ t_cutoff_seconds=int(args.t_cutoff_seconds)
248
+ )
249
+
250
+ if len(dataset) == 0:
251
+ raise RuntimeError("Dataset is empty.")
252
+
253
+ dataloader = DataLoader(
254
+ dataset,
255
+ batch_size=batch_size,
256
+ shuffle=bool(args.shuffle),
257
+ num_workers=int(args.num_workers),
258
+ pin_memory=bool(args.pin_memory),
259
+ collate_fn=lambda batch: filtered_collate(collator, batch)
260
+ )
261
+
262
+ # --- 3. Model Init ---
263
+ logger.info("Initializing Oracle Model...")
264
+ model = Oracle(
265
+ token_encoder=token_encoder,
266
+ wallet_encoder=wallet_encoder,
267
+ graph_updater=graph_updater,
268
+ ohlc_embedder=ohlc_embedder,
269
+ time_encoder=time_encoder,
270
+ num_event_types=vocab.NUM_EVENT_TYPES,
271
+ multi_modal_dim=multi_modal_encoder.embedding_dim,
272
+ event_pad_id=vocab.EVENT_TO_ID["__PAD__"],
273
+ event_type_to_id=vocab.EVENT_TO_ID,
274
+ model_config_name="Qwen/Qwen3-0.6B",
275
+ quantiles=quantiles,
276
+ horizons_seconds=horizons,
277
+ dtype=init_dtype
278
+ )
279
+
280
+ # Memory Optimization: Delete unused embedding layer from Qwen backbone
281
+ if hasattr(model.model, 'embed_tokens'):
282
+ del model.model.embed_tokens
283
+ logger.info("Freed unused Qwen embedding layer memory.")
284
+
285
+ # --- 4. Optimizer & Scheduler ---
286
+ optimizer = AdamW(model.parameters(), lr=learning_rate)
287
+
288
+ # Calculate training steps
289
+ num_update_steps_per_epoch = math.ceil(len(dataloader) / grad_accum_steps)
290
+ max_train_steps = epochs * num_update_steps_per_epoch
291
+ num_warmup_steps = int(max_train_steps * warmup_ratio)
292
+
293
+ scheduler = get_linear_schedule_with_warmup(
294
+ optimizer,
295
+ num_warmup_steps=num_warmup_steps,
296
+ num_training_steps=max_train_steps
297
+ )
298
+
299
+ # --- 5. Accelerate Prepare ---
300
+ model, optimizer, dataloader, scheduler = accelerator.prepare(
301
+ model, optimizer, dataloader, scheduler
302
+ )
303
+
304
+ # --- 6. Resume Training Logic ---
305
+ # Load checkpoint if it exists
306
+ starting_epoch = 0
307
+ resume_step = 0
308
+
309
+ # Check for existing checkpoints
310
+ if checkpoint_dir.exists():
311
+ # Look for subfolders named 'checkpoint-X' or 'epoch_X'
312
+ # Accelerate saves to folders.
313
+ dirs = [d for d in checkpoint_dir.iterdir() if d.is_dir()]
314
+ if dirs:
315
+ # Sort by modification time or name to find latest
316
+ dirs.sort(key=lambda x: x.stat().st_mtime)
317
+ latest_checkpoint = dirs[-1]
318
+ logger.info(f"Found checkpoint: {latest_checkpoint}. Resuming training...")
319
+ accelerator.load_state(str(latest_checkpoint))
320
+
321
+ # Try to infer epoch/step from folder name or saved state if custom tracking
322
+ # Accelerate restores DataLoader state, so we mainly need to know where we are for logging
323
+ # Assuming standard naming or just relying on DataLoader restore.
324
+ # Simple approach: Just trust Accelerate/DataLoader to skip.
325
+ # If you need precise epoch/step recovery for logging display:
326
+ # You could save a metadata.json inside the checkpoint folder.
327
+
328
+ logger.info("Checkpoint loaded. DataLoader state restored.")
329
+ else:
330
+ logger.info("No checkpoint found. Starting fresh.")
331
+
332
+ # --- 7. Training Loop ---
333
+ total_steps = 0
334
+
335
+ logger.info("***** Running training *****")
336
+ logger.info(f" Num examples = {len(dataset)}")
337
+ logger.info(f" Num Epochs = {epochs}")
338
+ logger.info(f" Instantaneous batch size per device = {batch_size}")
339
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {batch_size * accelerator.num_processes * grad_accum_steps}")
340
+ logger.info(f" Gradient Accumulation steps = {grad_accum_steps}")
341
+ logger.info(f" Total optimization steps = {max_train_steps}")
342
+
343
+ for epoch in range(starting_epoch, epochs):
344
+ model.train()
345
+ epoch_loss = 0.0
346
+ valid_batches = 0
347
+
348
+ # Tqdm only on main process
349
+ progress_bar = tqdm(
350
+ dataloader,
351
+ desc=f"Epoch {epoch+1}/{epochs}",
352
+ disable=not accelerator.is_local_main_process,
353
+ initial=resume_step # If you calculate resume_step from checkpoint
354
+ )
355
+
356
+ for step, batch in enumerate(progress_bar):
357
+             # Skip steps if resuming. Accelerate's dataloader may handle this automatically
358
+             # if configured, but 'skip_first_batches' is often manual; for simplicity we
359
+             # assume load_state restored the dataloader iterator.
360
+
361
+ if batch is None:
362
+ continue
363
+
364
+ # Safety Patch for missing social data
365
+ if 'textual_event_indices' not in batch:
366
+ B, L = batch['event_type_ids'].shape
367
+ batch['textual_event_indices'] = torch.zeros((B, L), dtype=torch.long, device=accelerator.device)
368
+ if 'textual_event_data' not in batch:
369
+ batch['textual_event_data'] = []
370
+
371
+ grad_stats: Optional[Dict[str, float]] = None
372
+ module_grad_stats: Dict[str, float] = {}
373
+ with accelerator.accumulate(model):
374
+ outputs = model(batch)
375
+
376
+ preds = outputs["quantile_logits"]
377
+ labels = batch["labels"]
378
+ labels_mask = batch["labels_mask"]
379
+
380
+ if labels_mask.sum() == 0:
381
+ loss = torch.tensor(0.0, requires_grad=True, device=accelerator.device)
382
+ else:
383
+ loss = quantile_pinball_loss(preds, labels, labels_mask, quantiles)
384
+
385
+ accelerator.backward(loss)
386
+
387
+ if accelerator.sync_gradients:
388
+ accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
389
+ grad_stats, module_grad_stats = compute_gradient_stats(model)
390
+ if grad_stats and accelerator.is_main_process:
391
+ logger.info(
392
+ "Gradients - mean norm: %.4f | max norm: %.4f | max abs: %.4f",
393
+ grad_stats["grad_norm_mean"],
394
+ grad_stats["grad_norm_max"],
395
+ grad_stats["grad_abs_max"],
396
+ )
397
+ if module_grad_stats:
398
+ module_entries = " | ".join(
399
+ f"{name}: {norm:.4f}" for name, norm in sorted(module_grad_stats.items())
400
+ )
401
+ logger.info("Per-module grad norms: %s", module_entries)
402
+
403
+ optimizer.step()
404
+ scheduler.step()
405
+ optimizer.zero_grad()
406
+
407
+ # Logging
408
+ if accelerator.sync_gradients:
409
+ total_steps += 1
410
+ current_loss = loss.item()
411
+ epoch_loss += current_loss
412
+ valid_batches += 1
413
+
414
+ if total_steps % log_every == 0:
415
+ lr = scheduler.get_last_lr()[0]
416
+ log_payload = {
417
+ "train/loss": current_loss,
418
+ "train/learning_rate": lr,
419
+ "train/epoch": epoch + (step / len(dataloader))
420
+ }
421
+ if grad_stats:
422
+ log_payload.update({
423
+ "train/grad_norm_mean": grad_stats["grad_norm_mean"],
424
+ "train/grad_norm_max": grad_stats["grad_norm_max"],
425
+ "train/grad_abs_max": grad_stats["grad_abs_max"],
426
+ })
427
+ accelerator.log(log_payload, step=total_steps)
428
+
429
+ if accelerator.is_main_process:
430
+ progress_bar.set_postfix({"loss": f"{current_loss:.4f}", "lr": f"{lr:.2e}"})
431
+ if grad_stats:
432
+ logger.info(
433
+ "Step %d | loss %.4f | grad_norm %.4f",
434
+ total_steps,
435
+ current_loss,
436
+ grad_stats["grad_norm_mean"],
437
+ )
438
+
439
+ # Save Checkpoint periodically
440
+ if total_steps % save_every == 0:
441
+ if accelerator.is_main_process:
442
+ save_path = checkpoint_dir / f"checkpoint-{total_steps}"
443
+ accelerator.save_state(output_dir=str(save_path))
444
+ logger.info(f"Saved checkpoint to {save_path}")
445
+
446
+ # End of Epoch Handling
447
+ if valid_batches > 0:
448
+ avg_loss = epoch_loss / valid_batches
449
+ if accelerator.is_main_process:
450
+ logger.info(f"Epoch {epoch+1} complete. Avg loss: {avg_loss:.6f}")
451
+                 accelerator.log({"train/loss_epoch": avg_loss}, step=total_steps)
452
+
453
+ # Save Checkpoint at end of epoch
454
+ save_path = checkpoint_dir / f"epoch_{epoch+1}"
455
+ accelerator.save_state(output_dir=str(save_path))
456
+ logger.info(f"Saved checkpoint to {save_path}")
457
+ else:
458
+ if accelerator.is_main_process:
459
+ logger.warning(f"Epoch {epoch+1}: No valid batches processed.")
460
+
461
+ accelerator.end_training()
462
+ neo4j_driver.close()
463
+
464
+ if __name__ == "__main__":
465
+ main()
train.sh ADDED
@@ -0,0 +1,23 @@
1
+ accelerate launch train.py \
2
+ --epochs 1 \
3
+ --batch_size 1 \
4
+ --learning_rate 1e-4 \
5
+ --warmup_ratio 0.1 \
6
+ --grad_accum_steps 1 \
7
+ --max_grad_norm 1.0 \
8
+ --seed 42 \
9
+ --log_every 1 \
10
+ --save_every 1000 \
11
+ --tensorboard_dir runs/oracle \
12
+ --checkpoint_dir checkpoints \
13
+ --mixed_precision bf16 \
14
+ --max_seq_len 50 \
15
+ --ohlc_seq_len 300 \
16
+ --horizons_seconds 30 60 120 240 420 \
17
+ --quantiles 0.1 0.5 0.9 \
18
+ --ohlc_stats_path ./data/ohlc_stats.npz \
19
+ --t_cutoff_seconds 60 \
20
+ --num_workers 4 \
21
+ --clickhouse_host localhost \
22
+ --clickhouse_port 9000 \
23
+ --neo4j_uri bolt://localhost:7687
train.yaml ADDED
@@ -0,0 +1,30 @@
1
+ training:
2
+ epochs: 1
3
+ batch_size: 1
4
+ learning_rate: 5.0e-05
5
+ use_amp: true
6
+ log_every: 1
7
+ disable_tqdm: false
8
+ tensorboard_logdir: runs/oracle
9
+ checkpoint_path: checkpoints/oracle_checkpoint.pt
10
+
11
+ data:
12
+ max_samples: null
13
+ horizons_seconds: [30, 60, 120, 240, 420]
14
+ quantiles: [0.1, 0.5, 0.9]
15
+ max_seq_len: 50
16
+ ohlc_seq_len: 300
17
+ ohlc_stats_path: ./data/ohlc_stats.npz
18
+ t_cutoff_seconds: 60
19
+ shuffle: true
20
+ num_workers: 0
21
+ pin_memory: false
22
+
23
+ databases:
24
+ clickhouse:
25
+ host: localhost
26
+ port: 9000
27
+ neo4j:
28
+ uri: bolt://localhost:7687
29
+ user: null
30
+ password: null
utils.sql ADDED
@@ -0,0 +1,69 @@
1
+
2
+
3
+ OPTIMIZE TABLE wallet_profiles FINAL;
4
+ OPTIMIZE TABLE wallet_profile_metrics_latest FINAL;
5
+ OPTIMIZE TABLE wallet_holdings_latest FINAL;
6
+ OPTIMIZE TABLE tokens_latest FINAL;
7
+ OPTIMIZE TABLE token_metrics_latest FINAL;
8
+
9
+
10
+ TRUNCATE TABLE wallet_holdings;
11
+ TRUNCATE TABLE trades;
12
+ TRUNCATE TABLE transfers;
13
+ TRUNCATE TABLE burns;
14
+ TRUNCATE TABLE tokens;
15
+ TRUNCATE TABLE mints;
16
+ TRUNCATE TABLE liquidity;
17
+ TRUNCATE TABLE pool_creations;
18
+ TRUNCATE TABLE token_metrics;
19
+ TRUNCATE TABLE wallet_profile_metrics;
20
+ TRUNCATE TABLE migrations;
21
+ TRUNCATE TABLE fee_collections;
22
+ TRUNCATE TABLE supply_locks;
23
+ TRUNCATE TABLE supply_lock_actions;
24
+
25
+
26
+ TRUNCATE TABLE wallet_profile_metrics_latest;
27
+ TRUNCATE TABLE wallet_holdings_latest;
28
+ TRUNCATE TABLE token_metrics_latest;
29
+ TRUNCATE TABLE tokens_latest;
30
+ TRUNCATE TABLE wallet_profiles;
31
+
32
+
33
+ DROP TABLE IF EXISTS trades;
34
+ DROP TABLE IF EXISTS mints;
35
+ DROP TABLE IF EXISTS migrations;
36
+ DROP TABLE IF EXISTS fee_collections;
37
+ DROP TABLE IF EXISTS liquidity;
38
+ DROP TABLE IF EXISTS pool_creations;
39
+ DROP TABLE IF EXISTS transfers;
40
+ DROP TABLE IF EXISTS burns;
41
+ DROP TABLE IF EXISTS wallet_profiles;
42
+ DROP TABLE IF EXISTS wallet_holdings;
43
+ DROP TABLE IF EXISTS wallet_profile_metrics;
44
+ DROP TABLE IF EXISTS wallet_profile_metrics_latest;
45
+ DROP TABLE IF EXISTS tokens;
46
+ DROP TABLE IF EXISTS token_metrics;
47
+ DROP TABLE IF EXISTS token_metrics_latest;
48
+ DROP TABLE IF EXISTS supply_locks;
49
+ DROP TABLE IF EXISTS supply_lock_actions;
50
+ DROP TABLE IF EXISTS wallet_holdings_latest;
51
+ DROP TABLE IF EXISTS tokens_latest;
52
+
53
+
54
+ -- Backfilling Logic
55
+
56
+ CREATE TABLE IF NOT EXISTS tokens_backfill
57
+ (
58
+ token_address String,
59
+ name String,
60
+ symbol String,
61
+ token_uri String,
62
+ is_mutable UInt8,
63
+ update_authority Nullable(String),
64
+ mint_authority Nullable(String),
65
+ freeze_authority Nullable(String),
66
+ protocol UInt8
67
+ )
68
+ ENGINE = MergeTree
69
+ ORDER BY token_address;
validate.py ADDED
@@ -0,0 +1,210 @@
+ import argparse
+ import yaml
+ import torch
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+
+ from clickhouse_driver import Client as ClickHouseClient
+ from neo4j import GraphDatabase
+
+ from data.data_fetcher import DataFetcher
+ from data.data_loader import OracleDataset
+ from data.data_collator import MemecoinCollator
+ from models.multi_modal_processor import MultiModalEncoder
+ from models.helper_encoders import ContextualTimeEncoder
+ from models.token_encoder import TokenEncoder
+ from models.wallet_encoder import WalletEncoder
+ from models.graph_updater import GraphUpdater
+ from models.ohlc_embedder import OHLCEmbedder
+ from models.model import Oracle
+ import models.vocabulary as vocab
+
+
+ def quantile_pinball_loss(preds: torch.Tensor,
+                           targets: torch.Tensor,
+                           mask: torch.Tensor,
+                           quantiles: List[float]) -> torch.Tensor:
+     if mask.sum() == 0:
+         return torch.tensor(0.0, device=preds.device, dtype=preds.dtype)
+     num_q = len(quantiles)
+     losses = []
+     for idx, q in enumerate(quantiles):
+         pred_slice = preds[:, idx::num_q]
+         target_slice = targets[:, idx::num_q]
+         mask_slice = mask[:, idx::num_q]
+         diff = target_slice - pred_slice
+         pinball = torch.maximum((q - 1.0) * diff, q * diff)
+         losses.append((pinball * mask_slice).sum())
+     return sum(losses) / mask.sum().clamp_min(1.0)
+
+
+ def load_config(path: str) -> Dict[str, Any]:
+     cfg_path = Path(path)
+     if not cfg_path.exists():
+         raise FileNotFoundError(f"Config file not found: {cfg_path}")
+     with cfg_path.open("r") as handle:
+         return yaml.safe_load(handle) or {}
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="Validate Oracle checkpoint on a single token.")
+     parser.add_argument("--config", type=str, default="train.yaml", help="Path to training YAML config.")
+     parser.add_argument("--checkpoint", type=str, default=None, help="Checkpoint to load. Defaults to config training.checkpoint_path.")
+     parser.add_argument("--sample-idx", type=int, default=0, help="Dataset index to validate.")
+     parser.add_argument("--token-address", type=str, default=None, help="Optional mint address to pick instead of index.")
+     parser.add_argument("--t-cutoff-seconds", type=int, default=None, help="Override cutoff horizon (seconds after mint).")
+     return parser.parse_args()
+
+
+ def resolve_sample_index(dataset: OracleDataset,
+                          sample_idx: int,
+                          token_address: Optional[str]) -> int:
+     if token_address:
+         for idx, mint in enumerate(getattr(dataset, "sampled_mints", [])):
+             if mint.get("mint_address") == token_address:
+                 return idx
+         raise ValueError(f"Token {token_address} not found in loaded dataset.")
+     if sample_idx < 0 or sample_idx >= len(dataset):
+         raise ValueError(f"Sample index {sample_idx} out of range (len={len(dataset)}).")
+     return sample_idx
+
+
+ def move_to_device(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]:
+     for key, value in list(batch.items()):
+         if torch.is_tensor(value):
+             batch[key] = value.to(device)
+     return batch
+
+
+ def main() -> None:
+     args = parse_args()
+     config = load_config(args.config)
+
+     training_cfg = config.get("training", {})
+     data_cfg = config.get("data", {})
+     db_cfg = config.get("databases", {})
+
+     checkpoint_path = Path(args.checkpoint or training_cfg.get("checkpoint_path", "checkpoints/oracle_checkpoint.pt")).expanduser()
+     if not checkpoint_path.exists():
+         raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     dtype = torch.bfloat16 if device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float16
+     if device.type == "cpu":
+         dtype = torch.float32
+
+     quantiles = data_cfg.get("quantiles", [0.1, 0.5, 0.9])
+     horizons = data_cfg.get("horizons_seconds", [30, 60, 120, 240, 420])
+     max_samples = data_cfg.get("max_samples", None)
+     max_seq_len = data_cfg.get("max_seq_len", 50)
+     ohlc_seq_len = data_cfg.get("ohlc_seq_len", 60)
+     default_t_cutoff = int(data_cfg.get("t_cutoff_seconds", 60))
+     t_cutoff_seconds = int(args.t_cutoff_seconds) if args.t_cutoff_seconds is not None else default_t_cutoff
+     ohlc_stats_path = data_cfg.get("ohlc_stats_path", "./data/ohlc_stats.npz")
+
+     multi_modal_encoder = MultiModalEncoder(dtype=dtype)
+     time_encoder = ContextualTimeEncoder(dtype=dtype)
+     token_encoder = TokenEncoder(multi_dim=multi_modal_encoder.embedding_dim, dtype=dtype)
+     wallet_encoder = WalletEncoder(encoder=multi_modal_encoder, dtype=dtype)
+     graph_updater = GraphUpdater(time_encoder=time_encoder, dtype=dtype)
+     ohlc_embedder = OHLCEmbedder(
+         num_intervals=vocab.NUM_OHLC_INTERVALS,
+         sequence_length=ohlc_seq_len,
+         dtype=dtype
+     )
+
+     collator = MemecoinCollator(
+         event_type_to_id=vocab.EVENT_TO_ID,
+         device=device,
+         multi_modal_encoder=multi_modal_encoder,
+         dtype=dtype,
+         ohlc_seq_len=ohlc_seq_len,
+         max_seq_len=max_seq_len
+     )
+
+     clickhouse_cfg = db_cfg.get("clickhouse", {})
+     clickhouse_client = ClickHouseClient(
+         host=clickhouse_cfg.get("host", "localhost"),
+         port=int(clickhouse_cfg.get("port", 9000))
+     )
+
+     neo4j_cfg = db_cfg.get("neo4j", {})
+     neo4j_auth = None
+     if neo4j_cfg.get("user") is not None:
+         neo4j_auth = (neo4j_cfg.get("user"), neo4j_cfg.get("password") or "")
+     neo4j_driver = GraphDatabase.driver(neo4j_cfg.get("uri", "bolt://localhost:7687"), auth=neo4j_auth)
+
+     data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
+     dataset = OracleDataset(
+         data_fetcher=data_fetcher,
+         horizons_seconds=horizons,
+         quantiles=quantiles,
+         max_samples=max_samples,
+         ohlc_stats_path=ohlc_stats_path,
+         token_allowlist=[args.token_address] if args.token_address else None,
+         t_cutoff_seconds=t_cutoff_seconds
+     )
+     if len(dataset) == 0:
+         raise RuntimeError("Dataset is empty; cannot validate.")
+
+     sample_idx = resolve_sample_index(dataset, args.sample_idx, args.token_address)
+     sample = dataset[sample_idx]
+     if sample is None:
+         raise RuntimeError(f"Dataset returned None for sample index {sample_idx}.")
+
+     token_address = getattr(dataset, "sampled_mints", [{}])[sample_idx].get("mint_address", "Unknown")
+     print(f"Validating token {token_address} (dataset idx {sample_idx}) with T_cutoff {t_cutoff_seconds} second(s) after mint")
+
+     collated = collator([sample])
+     collated = move_to_device(collated, device)
+
+     model = Oracle(
+         token_encoder=token_encoder,
+         wallet_encoder=wallet_encoder,
+         graph_updater=graph_updater,
+         ohlc_embedder=ohlc_embedder,
+         time_encoder=time_encoder,
+         num_event_types=vocab.NUM_EVENT_TYPES,
+         multi_modal_dim=multi_modal_encoder.embedding_dim,
+         event_pad_id=vocab.EVENT_TO_ID["__PAD__"],
+         event_type_to_id=vocab.EVENT_TO_ID,
+         quantiles=quantiles,
+         horizons_seconds=horizons,
+         dtype=dtype
+     ).to(device)
+     checkpoint = torch.load(checkpoint_path, map_location=device)
+     model.load_state_dict(checkpoint["model_state_dict"])
+     model.eval()
+
+     with torch.no_grad():
+         outputs = model(collated)
+         preds = outputs["quantile_logits"]
+         labels = collated["labels"]
+         labels_mask = collated["labels_mask"]
+
+     loss = quantile_pinball_loss(preds, labels, labels_mask, quantiles).item()
+     print(f"Pinball loss (masked): {loss:.6f}")
+
+     B = preds.shape[0]
+     grid = preds.view(B, len(horizons), len(quantiles))
+     label_grid = labels.view(B, len(horizons), len(quantiles))
+     mask_grid = labels_mask.view(B, len(horizons), len(quantiles))
+
+     for b in range(B):
+         print(f"\nSample {b} predictions:")
+         for h_idx, horizon in enumerate(horizons):
+             pred_row = grid[b, h_idx]
+             label_row = label_grid[b, h_idx]
+             mask_row = mask_grid[b, h_idx]
+             row_str = ", ".join(
+                 f"q={quantiles[q_idx]:.2f}: pred={pred_row[q_idx].item():.6f}, "
+                 f"label={label_row[q_idx].item():.6f}, mask={int(mask_row[q_idx].item())}"
+                 for q_idx in range(len(quantiles))
+             )
+             print(f"    Horizon {horizon:>4}s -> {row_str}")
+
+     neo4j_driver.close()
+
+
+ if __name__ == "__main__":
+     main()
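The `quantile_pinball_loss` in validate.py assumes a flat layout in which quantiles cycle fastest: `[h0q0, h0q1, ..., h1q0, ...]`, which is why both the strided slice `preds[:, idx::num_q]` and the later `preds.view(B, len(horizons), len(quantiles))` decode the same grid. The dependency-free sketch below (hypothetical `pinball` helper, plain Python lists instead of tensors, not part of the repo) reproduces that masked computation so the layout and reduction are easy to inspect.

```python
def pinball(preds, targets, mask, quantiles):
    """Masked pinball (quantile) loss over flat rows laid out
    horizon-major with quantiles contiguous, as in validate.py."""
    num_q = len(quantiles)
    total, count = 0.0, 0
    for p_row, t_row, m_row in zip(preds, targets, mask):
        for j, (p, t, m) in enumerate(zip(p_row, t_row, m_row)):
            q = quantiles[j % num_q]  # quantile index cycles fastest
            diff = t - p
            # pinball loss: q * diff if under-predicting, (1 - q) * |diff| otherwise
            total += max((q - 1.0) * diff, q * diff) * m
            count += m
    return total / max(count, 1)  # mirrors mask.sum().clamp_min(1.0)

# One sample, one horizon, quantiles [0.1, 0.5, 0.9], all under-predicted by 1:
# per-entry losses are 0.1, 0.5, 0.9 -> mean 0.5
# pinball([[0.0, 0.0, 0.0]], [[1.0, 1.0, 1.0]], [[1, 1, 1]], [0.1, 0.5, 0.9]) -> 0.5
```

Note the asymmetry this buys: for the 0.9 quantile, under-prediction costs 0.9 per unit of error while over-prediction costs only 0.1, pushing the head toward an upper bound; a fully-masked row contributes zero, matching the early-return guard in the torch version.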
validate.sh ADDED
@@ -0,0 +1,6 @@
+ python validate.py \
+     --config train.yaml \
+     --checkpoint checkpoints/oracle_checkpoint.pt \
+     --t-cutoff-seconds 240 \
+     --token-address 'czaE9hrSWJ6g21bxS6qh9GbbczoRa5F5Lx2eo1apump'
+
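For reference, the `config.get` calls in validate.py imply the following train.yaml shape. The values shown are the script's own fallback defaults, so this is a sketch of the minimal expected structure, not the repository's actual config:

```yaml
training:
  checkpoint_path: checkpoints/oracle_checkpoint.pt

data:
  quantiles: [0.1, 0.5, 0.9]
  horizons_seconds: [30, 60, 120, 240, 420]
  max_samples: null
  max_seq_len: 50
  ohlc_seq_len: 60
  t_cutoff_seconds: 60
  ohlc_stats_path: ./data/ohlc_stats.npz

databases:
  clickhouse:
    host: localhost
    port: 9000
  neo4j:
    uri: bolt://localhost:7687
    user: null        # auth is skipped entirely when user is null
    password: null
```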