Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- .gitignore +12 -0
- FullCryptoGuide.md +0 -0
- README.md +776 -0
- data/data_collator.py +708 -0
- data/data_fetcher.py +1009 -0
- data/data_loader.py +1657 -0
- data/ohlc_stats.npz +3 -0
- data/preprocess_distribution.py +164 -0
- graph_schema.rs +115 -0
- inference.py +271 -0
- link_graph.rs +2275 -0
- log.log +3 -0
- models/HoldersEncoder.py +81 -0
- models/SocialEncoders.py +245 -0
- models/__init__.py +0 -0
- models/graph_updater.py +486 -0
- models/helper_encoders.py +87 -0
- models/model.py +1009 -0
- models/multi_modal_processor.py +184 -0
- models/ohlc_embedder.py +114 -0
- models/token_encoder.py +182 -0
- models/vocabulary.py +188 -0
- models/wallet_encoder.py +262 -0
- models/wallet_set_encoder.py +99 -0
- neo4j.rs +121 -0
- ohlc_stats.npz +3 -0
- onchain.sql +599 -0
- pre_cache.sh +6 -0
- scripts/cache_dataset.py +148 -0
- scripts/download_epoch_artifacts.py +96 -0
- scripts/ingest_epoch.py +713 -0
- train.py +465 -0
- train.sh +23 -0
- train.yaml +30 -0
- utils.sql +69 -0
- validate.py +210 -0
- validate.sh +6 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
log.log filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ignore the __pycache__ directory anywhere in the repository
|
| 2 |
+
__pycache__/
|
| 3 |
+
|
| 4 |
+
# Ignore all .txt files anywhere in the repository
|
| 5 |
+
*.txt
|
| 6 |
+
|
| 7 |
+
# Ignore the 'runs' directory anywhere in the repository, regardless of nesting
|
| 8 |
+
runs/
|
| 9 |
+
|
| 10 |
+
data/pump_fun
|
| 11 |
+
|
| 12 |
+
.env
|
FullCryptoGuide.md
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
README.md
ADDED
|
@@ -0,0 +1,776 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =========================================
|
| 2 |
+
# Entity Encoders
|
| 3 |
+
# =========================================
|
| 4 |
+
# These are generated offline/streaming and are the "vocabulary" for the model.
|
| 5 |
+
|
| 6 |
+
<WalletEmbedding> # Embedding of a wallet's relationships, behavior, and history.
|
| 7 |
+
<WalletEmbedding> = [
|
| 8 |
+
// Data from the 'wallet_profiles' table (Wallet-level lifetime and daily/weekly stats)
|
| 9 |
+
wallet_profiles_row: [
|
| 10 |
+
// Core Info & Timestamps
|
| 11 |
+
age, // No Contextual
|
| 12 |
+
wallet_address, // Primary wallet identifier
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
// 7. NEW: Deployed Token Aggregates (8 Features)
|
| 16 |
+
deployed_tokens_count, // Total tokens created
|
| 17 |
+
deployed_tokens_migrated_pct, // % that migrated
|
| 18 |
+
deployed_tokens_avg_lifetime_sec, // Avg duration before dev selling
|
| 19 |
+
deployed_tokens_avg_peak_mc_usd, // Avg peak marketcap
|
| 20 |
+
deployed_tokens_median_peak_mc_usd,
|
| 21 |
+
|
| 22 |
+
// Metadata & Balances
|
| 23 |
+
balance, // Current SOL balance
|
| 24 |
+
|
| 25 |
+
// Lifetime Transaction Counts (Total history)
|
| 26 |
+
transfers_in_count, // Total native transfers received
|
| 27 |
+
transfers_out_count, // Total native transfers sent
|
| 28 |
+
spl_transfers_in_count, // Total SPL token transfers received
|
| 29 |
+
spl_transfers_out_count,// Total SPL token transfers sent
|
| 30 |
+
|
| 31 |
+
// Lifetime Trading Stats (Total history)
|
| 32 |
+
total_buys_count, // Total buys across all tokens
|
| 33 |
+
total_sells_count, // Total sells across all tokens
|
| 34 |
+
total_winrate, // Overall trading winrate
|
| 35 |
+
|
| 36 |
+
// 1-Day Stats (Realized P&L, Counts, Averages, Volume, Fees, Winrate)
|
| 37 |
+
stats_1d_realized_profit_sol,
|
| 38 |
+
stats_1d_realized_profit_pnl,
|
| 39 |
+
stats_1d_buy_count,
|
| 40 |
+
stats_1d_sell_count,
|
| 41 |
+
stats_1d_transfer_in_count,
|
| 42 |
+
stats_1d_transfer_out_count,
|
| 43 |
+
stats_1d_avg_holding_period,
|
| 44 |
+
stats_1d_total_bought_cost_sol,
|
| 45 |
+
stats_1d_total_sold_income_sol,
|
| 46 |
+
stats_1d_total_fee,
|
| 47 |
+
stats_1d_winrate,
|
| 48 |
+
stats_1d_tokens_traded,
|
| 49 |
+
|
| 50 |
+
// 7-Day Stats (Realized P&L, Counts, Averages, Volume, Fees, Winrate)
|
| 51 |
+
stats_7d_realized_profit_sol,
|
| 52 |
+
stats_7d_realized_profit_pnl,
|
| 53 |
+
stats_7d_buy_count,
|
| 54 |
+
stats_7d_sell_count,
|
| 55 |
+
stats_7d_transfer_in_count,
|
| 56 |
+
stats_7d_transfer_out_count,
|
| 57 |
+
stats_7d_avg_holding_period,
|
| 58 |
+
stats_7d_total_bought_cost_sol,
|
| 59 |
+
stats_7d_total_sold_income_sol,
|
| 60 |
+
stats_7d_total_fee,
|
| 61 |
+
stats_7d_winrate,
|
| 62 |
+
stats_7d_tokens_traded,
|
| 63 |
+
|
| 64 |
+
        // 30-day stats are omitted: too stale to be useful in this context
|
| 65 |
+
],
|
| 66 |
+
|
| 67 |
+
// Data from the 'wallet_socials' table (Social media and profile info)
|
| 68 |
+
wallet_socials_row: [
|
| 69 |
+
has_pf_profile,
|
| 70 |
+
has_twitter,
|
| 71 |
+
has_telegram,
|
| 72 |
+
is_exchange_wallet,
|
| 73 |
+
username,
|
| 74 |
+
],
|
| 75 |
+
// Data from the 'wallet_holdings' table (Token-level statistics for held tokens)
|
| 76 |
+
wallet_holdings_pool: [
|
| 77 |
+
<TokenVibeEmbedding>,
|
| 78 |
+
        holding_time,            // How long the wallet has held the token (only tokens currently held or recently traded are checked)
|
| 79 |
+
|
| 80 |
+
balance_pct_to_supply, // Current quantity of the token held
|
| 81 |
+
|
| 82 |
+
// History (Amounts & Costs)
|
| 83 |
+
history_bought_amount_sol, // Total amount of token bought
|
| 84 |
+
        bought_amount_sol_pct_to_native_balance // Whether the wallet committed a large fraction of its balance to this token
|
| 85 |
+
|
| 86 |
+
// History (Counts)
|
| 87 |
+
history_total_buys, // Total number of buy transactions
|
| 88 |
+
history_total_sells, // Total number of sell transactions
|
| 89 |
+
|
| 90 |
+
// Profit and Loss
|
| 91 |
+
realized_profit_pnl, // Realized P&L as a percentage
|
| 92 |
+
realized_profit_sol,
|
| 93 |
+
|
| 94 |
+
// Transfers (Non-trade movements)
|
| 95 |
+
history_transfer_in,
|
| 96 |
+
history_transfer_out,
|
| 97 |
+
|
| 98 |
+
avarage_trade_gap_seconds,
|
| 99 |
+
total_priority_fees, // Total tips + Priority Fees
|
| 100 |
+
]
|
| 101 |
+
]
|
| 102 |
+
|
| 103 |
+
<TokenVibeEmbedding> # Multimodal embedding of a token's identity
|
| 104 |
+
<TokenVibeEmbedding> = [<TokenAddressEmbedding>, <NameEmbedding>, <SymbolEmbedding>, <ImageEmbedding>, protocol_id]
|
| 105 |
+
|
| 106 |
+
<TextEmbedding> # Text embedding MultiModal processor.
|
| 107 |
+
<MediaEmbedding> # Multimodal VIT encoder.
|
| 108 |
+
|
| 109 |
+
# -----------------------------------------
|
| 110 |
+
# 1. TradeEncoder
|
| 111 |
+
# -----------------------------------------
|
| 112 |
+
|
| 113 |
+
# Captures large-size trades from any wallet.
|
| 114 |
+
[timestamp, 'LargeTrade', relative_ts, <WalletEmbedding>, trade_direction, sol_amount, dex_platform_id, priority_fee, mev_protection, token_amount_pct_of_holding, quote_amount_pct_of_holding, slippage, price_impact, success, is_bundle, total_usd]
|
| 115 |
+
|
| 116 |
+
# Captures the high-signal "Dev Sold or Bought" event.
|
| 117 |
+
[timestamp, 'Deployer_Trade', relative_ts, <CreatorWalletEmbedding>, trade_direction, sol_amount, dex_platform_id, priority_fee, mev_protection, token_amount_pct_of_holding, quote_amount_pct_of_holding, slippage, price_impact, success, is_bundle, total_usd]
|
| 118 |
+
|
| 119 |
+
# Captures *all* trades from pre-defined high-P&L/win-rate, kol and known wallets.
|
| 120 |
+
[timestamp, 'SmartWallet_Trade', relative_ts, <TraderWalletEmbedding>, trade_direction, sol_amount, dex_platform_id, priority_fee, mev_protection, token_amount_pct_of_holding, quote_amount_pct_of_holding, slippage, price_impact, success, is_bundle, total_usd]
|
| 121 |
+
|
| 122 |
+
# Raw trades. Loaded in H/B/H Prefix (first ~10k) and Suffix (last ~5k).
|
| 123 |
+
[timestamp, 'Trade', relative_ts, <TraderWalletEmbedding>, trade_direction, sol_amount, dex_platform_id, priority_fee, mev_protection, token_amount_pct_of_holding, quote_amount_pct_of_holding, slippage, price_impact, success, is_bundle, total_usd]
|
| 124 |
+
|
| 125 |
+
# -----------------------------------------
|
| 126 |
+
# 2. TransferEncoder
|
| 127 |
+
# -----------------------------------------
|
| 128 |
+
|
| 129 |
+
# Raw transfers. Loaded in H/B/H Prefix (all in first ~10k trade window) and Suffix (all in last ~5k trade window).
|
| 130 |
+
[timestamp, 'Transfer', relative_ts, <SourceWalletEmbedding>, <DestinationWalletEmbedding>, token_amount, transfer_pct_of_total_supply, transfer_pct_of_holding, priority_fee]
|
| 131 |
+
|
| 132 |
+
# Captures scarce, large transfers *after* the initial launch window.
|
| 133 |
+
[timestamp, 'LargeTransfer', relative_ts, <FromWalletEmbedding>, <ToWalletEmbedding>, token_amount, transfer_pct_of_total_supply, transfer_pct_of_holding, priority_fee]
|
| 134 |
+
|
| 135 |
+
# -----------------------------------------
|
| 136 |
+
# 3. LifecycleEncoder
|
| 137 |
+
# -----------------------------------------
|
| 138 |
+
|
| 139 |
+
# The T0 event.
|
| 140 |
+
[timestamp, 'Mint', 0, <CreatorWalletEmbedding>, <TokenVibeEmbedding>]
|
| 141 |
+
|
| 142 |
+
# -----------------------------------------
|
| 143 |
+
# 4. PoolEncoder
|
| 144 |
+
# -----------------------------------------
|
| 145 |
+
|
| 146 |
+
# Signals migration from launchpad to a real pool.
|
| 147 |
+
[timestamp, 'PoolCreated', relative_ts, <ProviderWalletEmbedding>, protocol_id, <QuoteTokenVibeEmbedding>, base_amount, quote_amount, quote_pct_to_main_pool_balance, base_pct_to_main_pool_balance]
|
| 148 |
+
|
| 149 |
+
# Signals LP addition or removal.
|
| 150 |
+
[timestamp, 'LiquidityChange', relative_ts, <ProviderWalletEmbedding>, <QuoteTokenVibeEmbedding>, change_type_id, quote_amount, quote_pct_to_current_pool_balance]
|
| 151 |
+
|
| 152 |
+
# Signals creator/dev taking platform fees.
|
| 153 |
+
[timestamp, 'FeeCollected', relative_ts, <RecipientWalletEmbedding>, sol_amount, token_amount]
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# -----------------------------------------
|
| 157 |
+
# SupplyEncoder
|
| 158 |
+
# -----------------------------------------
|
| 159 |
+
|
| 160 |
+
# Signals a supply reduction.
|
| 161 |
+
[timestamp, 'TokenBurn', relative_ts, <BurnerWalletEmbedding>, amount_pct_of_total_supply, amount_tokens_burned]
|
| 162 |
+
|
| 163 |
+
# Signals locked supply, e.g., for team/marketing.
|
| 164 |
+
[timestamp, 'SupplyLock', relative_ts, <LockerWalletEmbedding>, amount_pct_of_total_supply, lock_duration]
|
| 165 |
+
|
| 166 |
+
# -----------------------------------------
|
| 167 |
+
# ChartEncoder
|
| 168 |
+
# -----------------------------------------
|
| 169 |
+
|
| 170 |
+
# (The "Sliding Window") This is the new chart event.
|
| 171 |
+
[timestamp, 'Chart_Segment', relative_ts, OHLC_segment, chart_interval_id]
|
| 172 |
+
|
| 173 |
+
# -----------------------------------------
|
| 174 |
+
# PulseEncoder
|
| 175 |
+
# -----------------------------------------
|
| 176 |
+
|
| 177 |
+
# It is a low-frequency event (Dynamic Interval: 5min, 15min, or 1hr based on token age).
|
| 178 |
+
[timestamp, 'OnChain_Snapshot', relative_ts, total_holders, smart_traders, kols, holder_growth_rate, top_10_holder_pct, sniper_holding_pct, rat_wallets_holding_pct, bundle_holding_pct, current_market_cap, liquidity, volume, buy_count, sell_count, total_txns, global_fees_paid]
|
| 179 |
+
|
| 180 |
+
# -----------------------------------------
|
| 181 |
+
# HoldersListEncoder
|
| 182 |
+
# -----------------------------------------
|
| 183 |
+
|
| 184 |
+
<HolderDistributionEmbedding> # Transformer-based embedding of the top holders (WalletEmbeddings + Pct).
|
| 185 |
+
|
| 186 |
+
# Token-specific holder analysis.
|
| 187 |
+
[timestamp, 'HolderSnapshot', relative_ts, <HolderDistributionEmbedding>]
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# -----------------------------------------
|
| 191 |
+
# ChainSnapshotEncoder
|
| 192 |
+
# -----------------------------------------
|
| 193 |
+
|
| 194 |
+
# Broad chain-level market conditions.
|
| 195 |
+
[timestamp, 'ChainSnapshot', relative_ts, native_token_price_usd, gas_fee]
|
| 196 |
+
|
| 197 |
+
# Launchpad market regime (using absolute, log-normalized values).
|
| 198 |
+
[timestamp, 'Lighthouse_Snapshot', relative_ts, protocol_id, timeframe_id, total_volume, total_transactions, total_traders, total_tokens_created, total_migrations]
|
| 199 |
+
|
| 200 |
+
# -----------------------------------------
|
| 201 |
+
# TokenTrendingListEncoder
|
| 202 |
+
# -----------------------------------------
|
| 203 |
+
|
| 204 |
+
# Fires *per token* on a trending list. The high-attention "meta" signal.
|
| 205 |
+
[timestamp, 'TrendingToken', relative_ts, <TokenVibeEmbedding_of_trending_token>, list_source_id, timeframe_id, rank]
|
| 206 |
+
|
| 207 |
+
# Fires *per token* on the boosted list.
|
| 208 |
+
[timestamp, 'BoostedToken', relative_ts, <TokenVibeEmbedding_of_boosted_token>, total_boost_amount, rank]
|
| 209 |
+
|
| 210 |
+
# -----------------------------------------
|
| 211 |
+
# LaunchpadThreadEncoder
|
| 212 |
+
# -----------------------------------------
|
| 213 |
+
|
| 214 |
+
# On-platform social signal (Pump.fun comments).
|
| 215 |
+
[timestamp, 'PumpReply', relative_ts, <UserWalletEmbedding>, <ReplyTextEmbedding>]
|
| 216 |
+
|
| 217 |
+
# -----------------------------------------
|
| 218 |
+
# CTEncoder
|
| 219 |
+
# -----------------------------------------
|
| 220 |
+
|
| 221 |
+
# Off-platform social signal (Twitter).
|
| 222 |
+
[timestamp, 'XPost', relative_ts, <AuthorWalletEmbedding>, <PostTextEmbedding>, <MediaEmbedding>]
|
| 223 |
+
[timestamp, 'XRetweet', relative_ts, <RetweeterWalletEmbedding>, <OriginalAuthorWalletEmbedding>, <OriginalPostTextEmbedding>, <OriginalPostMediaEmbedding>]
|
| 224 |
+
[timestamp, 'XReply', relative_ts, <AuthorWalletEmbedding>, <PostTextEmbedding>, <MediaEmbedding>, <MainTweetEmbedding>]
|
| 225 |
+
[timestamp, 'XQuoteTweet', relative_ts, <QuoterWalletEmbedding>, <QuoterTextEmbedding>, <OriginalAuthorWalletEmbedding>, <OriginalPostTextEmbedding>, <OriginalPostMediaEmbedding>]
|
| 226 |
+
|
| 227 |
+
# -----------------------------------------
|
| 228 |
+
# GlobalTrendingEncoder
|
| 229 |
+
# -----------------------------------------
|
| 230 |
+
|
| 231 |
+
# Broader cultural trend signal (TikTok).
|
| 232 |
+
[timestamp, 'TikTok_Trending_Hashtag', relative_ts, <HashtagNameEmbedding>, rank]
|
| 233 |
+
|
| 234 |
+
# Broader cultural trend signal (Twitter).
|
| 235 |
+
[timestamp, 'XTrending_Hashtag', relative_ts, <HashtagNameEmbedding>, rank]
|
| 236 |
+
|
| 237 |
+
# -----------------------------------------
|
| 238 |
+
# TrackerEncoder
|
| 239 |
+
# -----------------------------------------
|
| 240 |
+
|
| 241 |
+
# Retail marketing signal (Paid groups).
|
| 242 |
+
[timestamp, 'AlphaGroup_Call', relative_ts, group_id]
|
| 243 |
+
|
| 244 |
+
[timestamp, 'Call_Channel', relative_ts, channel_id]
|
| 245 |
+
|
| 246 |
+
# High-impact catalyst event.
|
| 247 |
+
[timestamp, 'CexListing', relative_ts, exchange_id]
|
| 248 |
+
|
| 249 |
+
# High-impact catalyst event.
|
| 250 |
+
[timestamp, 'Migrated', relative_ts, protocol_id]
|
| 251 |
+
|
| 252 |
+
# -----------------------------------------
|
| 253 |
+
# Dex Encoder
|
| 254 |
+
# -----------------------------------------
|
| 255 |
+
|
| 256 |
+
[timestamp, 'DexBoost_Paid', relative_ts, amount, total_amount_on_token]
|
| 257 |
+
|
| 258 |
+
[timestamp, 'DexProfile_Updated', relative_ts, has_changed_website_flag, has_changed_twitter_flag, has_changed_telegram_flag, has_changed_description_flag, <WebsiteEmbedding>, <TwitterLinkEmbedding>, <NewDescriptionEmbedding>]
|
| 259 |
+
|
| 260 |
+
### **Global Context Injection**
|
| 261 |
+
|
| 262 |
+
<PRELAUNCH> <LAUNCH> <Middle> <RECENT>
|
| 263 |
+
|
| 264 |
+
### **Token Role Embedding**
|
| 265 |
+
|
| 266 |
+
<TokenVibeEmbedding_of_Token_A> + Subject_Token_Role
|
| 267 |
+
|
| 268 |
+
<TokenVibeEmbedding_of_Token_B> + Trending_Token_Role
|
| 269 |
+
|
| 270 |
+
<QuoteTokenVibeEmbedding_of_USDC> + Quote_Token_Role
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
# **Links**
|
| 274 |
+
|
| 275 |
+
### `TransferLink`
|
| 276 |
+
|
| 277 |
+
```
|
| 278 |
+
['signature', 'source', 'destination', 'mint', 'timestamp']
|
| 279 |
+
```
|
| 280 |
+
|
| 281 |
+
-----
|
| 282 |
+
|
| 283 |
+
### `BundleTradeLink`
|
| 284 |
+
|
| 285 |
+
```
|
| 286 |
+
['signatures', 'wallet_a', 'wallet_b', 'mint', 'slot', 'timestamp']
|
| 287 |
+
```
|
| 288 |
+
|
| 289 |
+
-----
|
| 290 |
+
|
| 291 |
+
### `CopiedTradeLink`
|
| 292 |
+
|
| 293 |
+
```
|
| 294 |
+
['leader_buy_sig', 'leader_sell_sig', 'follower_buy_sig', 'follower_sell_sig', 'follower', 'leader', 'mint', 'time_gap_on_buy_sec', 'time_gap_on_sell_sec', 'leader_pnl', 'follower_pnl', 'leader_buy_total', 'leader_sell_total', 'follower_buy_total', 'follower_sell_total', 'follower_buy_slippage', 'follower_sell_slippage']
|
| 295 |
+
```
|
| 296 |
+
|
| 297 |
+
-----
|
| 298 |
+
|
| 299 |
+
### `CoordinatedActivityLink`
|
| 300 |
+
|
| 301 |
+
```
|
| 302 |
+
['leader_first_sig', 'leader_second_sig', 'follower_first_sig', 'follower_second_sig', 'follower', 'leader', 'mint', 'time_gap_on_first_sec', 'time_gap_on_second_sec']
|
| 303 |
+
```
|
| 304 |
+
|
| 305 |
+
-----
|
| 306 |
+
|
| 307 |
+
### `MintedLink`
|
| 308 |
+
|
| 309 |
+
```
|
| 310 |
+
['signature', 'timestamp', 'buy_amount']
|
| 311 |
+
```
|
| 312 |
+
|
| 313 |
+
-----
|
| 314 |
+
|
| 315 |
+
### `SnipedLink`
|
| 316 |
+
|
| 317 |
+
```
|
| 318 |
+
['signature', 'rank', 'sniped_amount']
|
| 319 |
+
```
|
| 320 |
+
|
| 321 |
+
-----
|
| 322 |
+
|
| 323 |
+
### `LockedSupplyLink`
|
| 324 |
+
|
| 325 |
+
```
|
| 326 |
+
['signature', 'amount', 'unlock_timestamp']
|
| 327 |
+
```
|
| 328 |
+
|
| 329 |
+
-----
|
| 330 |
+
|
| 331 |
+
### `BurnedLink`
|
| 332 |
+
|
| 333 |
+
```
|
| 334 |
+
['signature', 'amount', 'timestamp']
|
| 335 |
+
```
|
| 336 |
+
|
| 337 |
+
-----
|
| 338 |
+
|
| 339 |
+
### `ProvidedLiquidityLink`
|
| 340 |
+
|
| 341 |
+
```
|
| 342 |
+
['signature', 'wallet', 'token', 'pool_address', 'amount_base', 'amount_quote', 'timestamp']
|
| 343 |
+
```
|
| 344 |
+
|
| 345 |
+
-----
|
| 346 |
+
|
| 347 |
+
### `WhaleOfLink`
|
| 348 |
+
|
| 349 |
+
```
|
| 350 |
+
['wallet', 'token', 'holding_pct_at_creation', 'ath_usd_at_creation']
|
| 351 |
+
```
|
| 352 |
+
|
| 353 |
+
-----
|
| 354 |
+
|
| 355 |
+
### `TopTraderOfLink`
|
| 356 |
+
|
| 357 |
+
```
|
| 358 |
+
['wallet', 'token', 'pnl_at_creation', 'ath_usd_at_creation']
|
| 359 |
+
```
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
/////
|
| 365 |
+
|
| 366 |
+
def __gettestitem__(self, idx: int) -> Dict[str, Any]:
    """
    Generates a single complex data item, structured for the MemecoinCollator.

    NOTE: This currently returns the same mock data regardless of `idx` —
    it exists to exercise every event type and graph-link encoder path,
    not to load real data.

    Returns:
        dict with the keys the collator consumes: 'event_sequence',
        'wallets', 'tokens', 'embedding_pooler', 'graph_links',
        'labels', 'labels_mask'.
    """
    # --- 1. Setup Pooler and Define Raw Data ---
    pooler = EmbeddingPooler()

    # --- 5. Create Mock Raw Batch Data (FIXED) ---
    print("Creating mock raw batch...")

    # (Wallet profiles, socials, holdings definitions are unchanged)
    # profile1: an "active trader" style wallet with deploy + trade history.
    profile1 = {
        'wallet_address': 'addrW1', 'age': 1.5e7, 'balance': 10.5,
        'deployed_tokens_count': 2, 'deployed_tokens_migrated_pct': 0.5, 'deployed_tokens_avg_lifetime_sec': 36000.0, 'deployed_tokens_avg_peak_mc_usd': 100000.0, 'deployed_tokens_median_peak_mc_usd': 50000.0,
        'transfers_in_count': 10, 'transfers_out_count': 5, 'spl_transfers_in_count': 20, 'spl_transfers_out_count': 15,
        'total_buys_count': 50, 'total_sells_count': 40, 'total_winrate': 0.6,
        'stats_1d_realized_profit_sol': 1.2, 'stats_1d_realized_profit_pnl': 0.1, 'stats_1d_buy_count': 5, 'stats_1d_sell_count': 3, 'stats_1d_transfer_in_count': 2, 'stats_1d_transfer_out_count': 1, 'stats_1d_avg_holding_period': 3600, 'stats_1d_total_bought_cost_sol': 10.0, 'stats_1d_total_sold_income_sol': 11.2, 'stats_1d_total_fee': 0.1, 'stats_1d_winrate': 0.7, 'stats_1d_tokens_traded': 4,
        'stats_7d_realized_profit_sol': 5.0, 'stats_7d_realized_profit_pnl': 0.2, 'stats_7d_buy_count': 20, 'stats_7d_sell_count': 15, 'stats_7d_transfer_in_count': 8, 'stats_7d_transfer_out_count': 4, 'stats_7d_avg_holding_period': 7200, 'stats_7d_total_bought_cost_sol': 40.0, 'stats_7d_total_sold_income_sol': 45.0, 'stats_7d_total_fee': 0.5, 'stats_7d_winrate': 0.65, 'stats_7d_tokens_traded': 10,
    }
    social1 = {'has_pf_profile': True, 'has_twitter': True, 'has_telegram': False, 'is_exchange_wallet': False, 'username': 'trader_one'}
    holdings1 = [
        # NOTE: 'avarage_trade_gap_seconds' key spelling is intentional — it
        # must match the consumer's expected schema; do not "fix" it here.
        {'mint_address': 'tknA', 'holding_time': 3600.0, 'realized_profit_sol': 5.2, 'total_priority_fees': 0.05, 'balance_pct_to_supply': 0.01, 'history_bought_amount_sol': 10, 'bought_amount_sol_pct_to_native_balance': 0.5, 'history_total_buys': 5, 'history_total_sells': 2, 'realized_profit_pnl': 0.52, 'history_transfer_in': 1, 'history_transfer_out': 0, 'avarage_trade_gap_seconds': 300},
    ]
    # profile2: an "empty" exchange-style wallet (all stats zero).
    profile2 = {
        'wallet_address': 'addrW2', 'age': 1e6, 'balance': 1.0,
        'deployed_tokens_count': 0, 'deployed_tokens_migrated_pct': 0.0, 'deployed_tokens_avg_lifetime_sec': 0.0, 'deployed_tokens_avg_peak_mc_usd': 0.0, 'deployed_tokens_median_peak_mc_usd': 0.0,
        'transfers_in_count': 1, 'transfers_out_count': 0, 'spl_transfers_in_count': 0, 'spl_transfers_out_count': 0,
        'total_buys_count': 0, 'total_sells_count': 0, 'total_winrate': 0.0,
        'stats_1d_realized_profit_sol': 0.0, 'stats_1d_realized_profit_pnl': 0.0, 'stats_1d_buy_count': 0, 'stats_1d_sell_count': 0, 'stats_1d_transfer_in_count': 0, 'stats_1d_transfer_out_count': 0, 'stats_1d_avg_holding_period': 0, 'stats_1d_total_bought_cost_sol': 0.0, 'stats_1d_total_sold_income_sol': 0.0, 'stats_1d_total_fee': 0.0, 'stats_1d_winrate': 0.0, 'stats_1d_tokens_traded': 0,
        'stats_7d_realized_profit_sol': 0.0, 'stats_7d_realized_profit_pnl': 0.0, 'stats_7d_buy_count': 0, 'stats_7d_sell_count': 0, 'stats_7d_transfer_in_count': 0, 'stats_7d_transfer_out_count': 0, 'stats_7d_avg_holding_period': 0, 'stats_7d_total_bought_cost_sol': 0.0, 'stats_7d_total_sold_income_sol': 0.0, 'stats_7d_total_fee': 0.0, 'stats_7d_winrate': 0.0, 'stats_7d_tokens_traded': 0,
    }
    social2 = {'has_pf_profile': False, 'has_twitter': False, 'has_telegram': False, 'is_exchange_wallet': True, 'username': 'cex_wallet'}
    holdings2 = []

    # Define raw data and get their indices
    tokenA_data = {
        'address_emb_idx': pooler.get_idx('tknA'),
        'name_emb_idx': pooler.get_idx('Token A'),
        'symbol_emb_idx': pooler.get_idx('TKA'),
        'image_emb_idx': pooler.get_idx(Image.new('RGB', (256, 256), color='blue')),
        'protocol': 1
    }
    # Add wallet usernames to the pool
    wallet1_user_idx = pooler.get_idx(social1['username'])
    wallet2_user_idx = pooler.get_idx(social2['username'])
    social1['username_emb_idx'] = wallet1_user_idx
    social2['username_emb_idx'] = wallet2_user_idx
    # --- NEW: Add a third wallet for social tests ---
    social3 = {'has_pf_profile': False, 'has_twitter': True, 'has_telegram': True, 'is_exchange_wallet': False, 'username': 'social_butterfly'}
    wallet3_user_idx = pooler.get_idx(social3['username'])
    social3['username_emb_idx'] = wallet3_user_idx

    # Create the final pre-computed data structures.
    # NOTE(review): tokenB/C/D deliberately reuse tokenA's strings/image, so
    # their pooled indices collide with tokenA's — presumably to test
    # embedding-pool deduplication; confirm against EmbeddingPooler semantics.
    tokenB_data = {
        'address_emb_idx': pooler.get_idx('tknA'),
        'name_emb_idx': pooler.get_idx('Token A'),
        'symbol_emb_idx': pooler.get_idx('TKA'),
        'image_emb_idx': pooler.get_idx(Image.new('RGB', (256, 256), color='blue')),
        'protocol': 1
    }

    tokenC_data = {
        'address_emb_idx': pooler.get_idx('tknA'),
        'name_emb_idx': pooler.get_idx('Token A'),
        'symbol_emb_idx': pooler.get_idx('TKA'),
        'image_emb_idx': pooler.get_idx(Image.new('RGB', (256, 256), color='blue')),
        'protocol': 1
    }

    tokenD_data = {
        'address_emb_idx': pooler.get_idx('tknA'),
        'name_emb_idx': pooler.get_idx('Token A'),
        'symbol_emb_idx': pooler.get_idx('TKA'),
        'image_emb_idx': pooler.get_idx(Image.new('RGB', (256, 256), color='blue')),
        'protocol': 1
    }

    item = {
        # One of every supported event type, in roughly chronological order
        # (relative_ts < 0 means pre-launch).
        'event_sequence': [
            {'event_type': 'XPost',  # NEW
             'timestamp': 1729711350,
             'relative_ts': -25,
             'wallet_address': 'addrW1',  # Author
             'text_emb_idx': pooler.get_idx('This is the main tweet about $TKA'),
             'media_emb_idx': pooler.get_idx(Image.new('RGB', (100, 100), color='cyan'))
             },
            {'event_type': 'XReply',  # NEW
             'timestamp': 1729711360,
             'relative_ts': -35,
             'wallet_address': 'addrW2',  # Replier
             'text_emb_idx': pooler.get_idx('This is a reply to the main tweet'),
             'media_emb_idx': pooler.get_idx(None),  # No media in reply
             'main_tweet_text_emb_idx': pooler.get_idx('This is the main tweet about $TKA')
             },
            {'event_type': 'XRetweet',  # NEW
             'timestamp': 1729711370,
             'relative_ts': -40,
             'wallet_address': 'addrW3',  # The retweeter
             'original_author_wallet_address': 'addrW1',  # The original author
             'original_post_text_emb_idx': pooler.get_idx('This is the main tweet about $TKA'),
             'original_post_media_emb_idx': pooler.get_idx(Image.new('RGB', (100, 100), color='cyan'))
             },
            # --- CORRECTED: Test a pre-launch event with negative relative_ts ---
            {'event_type': 'Transfer',
             'timestamp': 1729711180,
             'relative_ts': -10,  # Negative relative_ts indicates pre-launch
             'wallet_address': 'addrW2',
             'destination_wallet_address': 'addrW1',
             'token_address': 'tknA',
             'token_amount': 1000.0, 'transfer_pct_of_total_supply': 0.0, 'transfer_pct_of_holding': 0.0, 'priority_fee': 0.0
             },
            {'event_type': 'Mint', 'timestamp': 1729711190, 'relative_ts': 0, 'wallet_address': 'addrW1', 'token_address': 'tknA'},
            {'event_type': 'Chart_Segment', 'timestamp': 1729711200, 'relative_ts': 60, 'opens': [1.0]*OHLC_SEQ_LEN, 'closes': [1.1]*OHLC_SEQ_LEN, 'i': '1s'},  # This is high-def (segment 0) by default
            {'event_type': 'Chart_Segment', 'timestamp': 1729711260, 'relative_ts': 120, 'opens': [1.0]*OHLC_SEQ_LEN, 'closes': [1.1]*OHLC_SEQ_LEN, 'i': '1s'},  # You can mark this as blurry
            {'event_type': 'Transfer',
             'timestamp': 1729711210,
             'relative_ts': 20,
             'wallet_address': 'addrW1',  # Source
             'destination_wallet_address': 'addrW2',  # Destination
             'token_address': 'tknA',  # Need token for context? (Optional, depends on design)
             'token_amount': 500.0,
             'transfer_pct_of_total_supply': 0.005,
             'transfer_pct_of_holding': 0.1,
             'priority_fee': 0.0001
             },
            {'event_type': 'Trade',
             'timestamp': 1729711220,
             'relative_ts': 30,
             'wallet_address': 'addrW1',
             'token_address': 'tknA',
             'trade_direction': 0,
             'sol_amount': 0.5,
             # --- FIXED: Pass the integer ID directly ---
             'dex_platform_id': vocab.DEX_TO_ID['Axiom'],
             'priority_fee': 0.0002,
             'mev_protection': False,
             'token_amount_pct_of_holding': 0.05, 'quote_amount_pct_of_holding': 0.02,
             'slippage': 0.01, 'price_impact': 0.005, 'success': True, 'is_bundle': False, 'total_usd': 75.0
             },
            {'event_type': 'Deployer_Trade',  # NEW: Testing a trade variant
             'timestamp': 1729711230,
             'relative_ts': 40,
             'wallet_address': 'addrW1',  # The creator wallet
             'token_address': 'tknA',
             'trade_direction': 1, 'sol_amount': 0.2,
             # --- FIXED: Pass the integer ID directly ---
             'dex_platform_id': vocab.DEX_TO_ID['Trojan'],
             'priority_fee': 0.0005,
             'mev_protection': True,
             'token_amount_pct_of_holding': 0.1, 'quote_amount_pct_of_holding': 0.0,
             'slippage': 0.02, 'price_impact': 0.01, 'success': True, 'is_bundle': False, 'total_usd': 30.0
             },
            {'event_type': 'SmartWallet_Trade',  # NEW
             'timestamp': 1729711240,
             'relative_ts': 50,
             'wallet_address': 'addrW1',  # A known smart wallet
             'token_address': 'tknA',
             'trade_direction': 0, 'sol_amount': 1.5,
             # --- FIXED: Pass the integer ID directly ---
             'dex_platform_id': vocab.DEX_TO_ID['Axiom'],
             'priority_fee': 0.001,
             'mev_protection': True,
             'token_amount_pct_of_holding': 0.2, 'quote_amount_pct_of_holding': 0.1,
             'slippage': 0.01, 'price_impact': 0.008, 'success': True, 'is_bundle': False, 'total_usd': 225.0
             },
            {'event_type': 'LargeTrade',  # NEW
             'timestamp': 1729711250,
             'relative_ts': 60,
             'wallet_address': 'addrW2',  # Some other wallet
             'token_address': 'tknA',
             'trade_direction': 0, 'sol_amount': 10.0,
             # --- FIXED: Pass the integer ID directly ---
             'dex_platform_id': vocab.DEX_TO_ID['OXK'],
             'priority_fee': 0.002,
             'mev_protection': False,
             'token_amount_pct_of_holding': 0.8, 'quote_amount_pct_of_holding': 0.5,
             'slippage': 0.03, 'price_impact': 0.05, 'success': True, 'is_bundle': False, 'total_usd': 1500.0
             },
            {'event_type': 'Chart_Segment', 'timestamp': 1729711260, 'relative_ts': 70, 'opens': [1.0]*OHLC_SEQ_LEN, 'closes': [1.1]*OHLC_SEQ_LEN, 'i': '1s'},
            {'event_type': 'PoolCreated',  # NEW
             'timestamp': 1729711270,
             'relative_ts': 80,
             'wallet_address': 'addrW1',
             'protocol_id': vocab.PROTOCOL_TO_ID['Raydium CPMM'],
             'quote_token_address': 'tknB',
             'base_amount': 1000000.0,
             'quote_amount': 10.0
             },
            {'event_type': 'LiquidityChange',  # NEW
             'timestamp': 1729711280,
             'relative_ts': 90,
             'wallet_address': 'addrW2',
             'quote_token_address': 'tknB',
             'change_type_id': 0,  # 0 for 'add'
             'quote_amount': 2.0
             },
            {'event_type': 'FeeCollected',  # NEW
             'timestamp': 1729711290,
             'relative_ts': 100,
             'wallet_address': 'addrW1',  # The recipient (e.g., dev wallet)
             'sol_amount': 0.1
             },
            {'event_type': 'TokenBurn',  # NEW
             'timestamp': 1729711300,
             'relative_ts': 110,
             'wallet_address': 'addrW2',  # The burner wallet
             'amount_pct_of_total_supply': 0.01,  # 1% of supply
             'amount_tokens_burned': 10000000.0
             },
            {'event_type': 'SupplyLock',  # NEW
             'timestamp': 1729711310,
             'relative_ts': 120,
             'wallet_address': 'addrW1',  # The locker wallet
             'amount_pct_of_total_supply': 0.10,  # 10% of supply
             'lock_duration': 2592000  # 30 days in seconds
             },
            {'event_type': 'HolderSnapshot',  # NEW
             'timestamp': 1729711320,
             'relative_ts': 130,
             # This is a pointer to the pre-computed embedding
             # In a real system, this would be the index of the embedding
             'holders': [  # Raw holder data
                 {'wallet': 'addrW1', 'holding_pct': 0.15},
                 {'wallet': 'addrW2', 'holding_pct': 0.05},
                 # Add more mock holders if needed
             ]
             },
            {'event_type': 'OnChain_Snapshot',  # NEW
             'timestamp': 1729711320,
             'relative_ts': 130,
             'total_holders': 500,
             'smart_traders': 25,
             'kols': 3,
             'holder_growth_rate': 0.15,
             'top_10_holder_pct': 0.22,
             'sniper_holding_pct': 0.05,
             'rat_wallets_holding_pct': 0.02,
             'bundle_holding_pct': 0.01,
             'current_market_cap': 150000.0,
             'volume': 50000.0,
             'buy_count': 120,
             'sell_count': 80,
             'total_txns': 200,
             'global_fees_paid': 1.5
             },
            {'event_type': 'TrendingToken',  # NEW
             'timestamp': 1729711330,
             'relative_ts': 140,
             'token_address': 'tknC',  # The token that is trending
             'list_source_id': vocab.TRENDING_LIST_SOURCE_TO_ID['Phantom'],
             'timeframe_id': vocab.TRENDING_LIST_TIMEFRAME_TO_ID['1h'],
             'rank': 3
             },
            {'event_type': 'BoostedToken',  # NEW
             'timestamp': 1729711340,
             'relative_ts': 150,
             'token_address': 'tknD',  # The token that is boosted
             'total_boost_amount': 5000.0,
             'rank': 1
             },
            {'event_type': 'XQuoteTweet',  # NEW
             'timestamp': 1729711380,
             'relative_ts': 190,
             'wallet_address': 'addrW3',  # The quoter
             'quoter_text_emb_idx': pooler.get_idx('Wow, look at this! $TKA'),
             'original_author_wallet_address': 'addrW1',  # The original author
             'original_post_text_emb_idx': pooler.get_idx('This is the main tweet about $TKA'),
             'original_post_media_emb_idx': pooler.get_idx(Image.new('RGB', (100, 100), color='cyan'))
             },
            # --- NEW: Add special context tokens ---
            {'event_type': 'MIDDLE', 'timestamp': 1729711500, 'relative_ts': 195},
            {'event_type': 'PumpReply',  # NEW
             'timestamp': 1729711390,
             'relative_ts': 200,
             'wallet_address': 'addrW2',  # The user who replied
             'reply_text_emb_idx': pooler.get_idx('to the moon!')
             },
            {'event_type': 'DexBoost_Paid',  # NEW
             'timestamp': 1729711400,
             'relative_ts': 210,
             'amount': 5.0,  # e.g., 5 Boost
             'total_amount_on_token': 25.0  # 25 Boost Points
             },
            {'event_type': 'DexProfile_Updated',  # NEW
             'timestamp': 1729711410,
             'relative_ts': 220,
             'has_changed_website_flag': True,
             'has_changed_twitter_flag': False,
             'has_changed_telegram_flag': True,
             'has_changed_description_flag': True,
             # Pre-computed text embeddings
             'website_emb_idx': pooler.get_idx('new-token-website.com'),
             'twitter_link_emb_idx': pooler.get_idx('old_handle'),  # No change, so old link
             'telegram_link_emb_idx': pooler.get_idx('new_tg_group'),
             'description_emb_idx': pooler.get_idx('This is the new and improved token description.')
             },
            {'event_type': 'AlphaGroup_Call',  # NEW
             'timestamp': 1729711420,
             'relative_ts': 230,
             'group_id': vocab.ALPHA_GROUPS_TO_ID['Potion']
             },
            {'event_type': 'Channel_Call',  # NEW
             'timestamp': 1729711430,
             'relative_ts': 240,
             'channel_id': vocab.CALL_CHANNELS_TO_ID['MarcosCalls']
             },
            {'event_type': 'RECENT', 'timestamp': 1729711510, 'relative_ts': 245},
            {'event_type': 'CexListing',  # NEW
             'timestamp': 1729711440,
             'relative_ts': 250,
             'exchange_id': vocab.EXCHANGES_TO_ID['mexc']
             },
            {'event_type': 'TikTok_Trending_Hashtag',  # NEW
             'timestamp': 1729711450,
             'relative_ts': 260,
             'hashtag_name_emb_idx': pooler.get_idx('CryptoTok'),
             'rank': 5
             },
            {'event_type': 'XTrending_Hashtag',  # NEW
             'timestamp': 1729711460,
             'relative_ts': 270,
             'hashtag_name_emb_idx': pooler.get_idx('SolanaMemes'),
             'rank': 2
             },
            {'event_type': 'ChainSnapshot',  # NEW
             'timestamp': 1729711470,
             'relative_ts': 280,
             'native_token_price_usd': 150.75,
             'gas_fee': 0.00015  # Example gas fee
             },
            {'event_type': 'Lighthouse_Snapshot',  # NEW
             'timestamp': 1729711480,
             'relative_ts': 290,
             'protocol_id': vocab.PROTOCOL_TO_ID['Pump V1'],
             'timeframe_id': vocab.LIGHTHOUSE_TIMEFRAME_TO_ID['1h'],
             'total_volume': 1.2e6,
             'total_transactions': 5000,
             'total_traders': 1200,
             'total_tokens_created': 85,
             'total_migrations': 70
             },
            {'event_type': 'Migrated',  # NEW
             'timestamp': 1729711490,
             'relative_ts': 300,
             'protocol_id': vocab.PROTOCOL_TO_ID['Raydium CPMM']
             },
        ],
        'wallets': {
            'addrW1': {'profile': profile1, 'socials': social1, 'holdings': holdings1},
            'addrW2': {'profile': profile2, 'socials': social2, 'holdings': holdings2},
            # --- NEW: Add wallet 3 data ---
            'addrW3': {
                'profile': {**profile2, 'wallet_address': 'addrW3'},  # Reuse profile2 but change address
                'socials': social3,
                'holdings': []
            }
        },
        'tokens': {
            'tknA': tokenA_data,  # Main token
            'tknB': tokenB_data,  # Quote token
            'tknC': tokenC_data,  # Trending token
            'tknD': tokenD_data   # Boosted token
        },
        # --- NEW: The pre-computed embedding pool is generated after collecting all items
        'embedding_pooler': pooler,  # Pass the pooler to generate the tensor later

        # --- NEW: Expanded graph_links to test all encoders ---
        # --- FIXED: Removed useless logging fields as per user request ---
        'graph_links': {
            'TransferLink': {'links': [{'timestamp': 1729711205}], 'edges': [('addrW1', 'addrW2')]},  # Keep timestamp
            'BundleTradeLink': {'links': [{'timestamp': 1729711215}], 'edges': [('addrW1', 'addrW2')]},  # Keep timestamp
            'CopiedTradeLink': {'links': [
                {'time_gap_on_buy_sec': 10, 'time_gap_on_sell_sec': 120, 'leader_pnl': 5.0, 'follower_pnl': 4.0, 'follower_buy_total': 100, 'follower_sell_total': 120}
            ], 'edges': [('addrW1', 'addrW2')]},
            'CoordinatedActivityLink': {'links': [
                {'time_gap_on_first_sec': 5, 'time_gap_on_second_sec': 8}
            ], 'edges': [('addrW1', 'addrW2')]},
            'MintedLink': {'links': [
                {'timestamp': 1729711200, 'buy_amount': 1e9}
            ], 'edges': [('addrW1', 'tknA')]},
            'SnipedLink': {'links': [
                {'rank': 1, 'sniped_amount': 5e8}
            ], 'edges': [('addrW1', 'tknA')]},
            'LockedSupplyLink': {'links': [
                {'amount': 1e10}  # Only amount is needed
            ], 'edges': [('addrW1', 'tknA')]},
            'BurnedLink': {'links': [
                {'timestamp': 1729711300}  # Only timestamp is needed
            ], 'edges': [('addrW2', 'tknA')]},
            'ProvidedLiquidityLink': {'links': [
                {'timestamp': 1729711250}  # Only timestamp is needed
            ], 'edges': [('addrW1', 'tknA')]},
            'WhaleOfLink': {'links': [
                {}  # Just the existence of the link is the feature
            ], 'edges': [('addrW1', 'tknA')]},
            'TopTraderOfLink': {'links': [
                {'pnl_at_creation': 50000.0}  # Only PnL is needed
            ], 'edges': [('addrW2', 'tknA')]}
        },

        # --- FIXED: Removed chart_segments dictionary ---
        # Random labels sized to the model head; empty tensors when the
        # dataset declares no outputs.
        'labels': torch.randn(self.num_outputs) if self.num_outputs > 0 else torch.zeros(0),
        'labels_mask': torch.ones(self.num_outputs) if self.num_outputs > 0 else torch.zeros(0)
    }

    print("Mock raw batch created.")

    return item
data/data_collator.py
ADDED
|
@@ -0,0 +1,708 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# memecoin_collator.py (CORRECTED ORDER OF OPERATIONS)
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 6 |
+
from typing import List, Dict, Any, Tuple, Optional, Union
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from models.multi_modal_processor import MultiModalEncoder
|
| 10 |
+
|
| 11 |
+
# Encoders are NO LONGER imported here
|
| 12 |
+
import models.vocabulary as vocab # For IDs, config sizes
|
| 13 |
+
from data.data_loader import EmbeddingPooler # Import for type hinting and instantiation
|
| 14 |
+
|
| 15 |
+
# Well-known Solana mint addresses used to tell "quote side" tokens apart
# from the memecoin under analysis.
NATIVE_MINT = "So11111111111111111111111111111111111111112"
# Mints that act as the quote leg of a pool (never the token being modeled).
QUOTE_MINTS = {
    NATIVE_MINT,  # SOL
    "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v",  # USDC
    "Es9vMFrzaCERmJfrF4H2FYD4KCoNkY11McCe8BenwNYB",  # USDT
    "USD1ttGY1N17NEEHLmELoaybftRBUSErhqYiQzvEmuB",  # USD1
}
|
| 22 |
+
|
| 23 |
+
class MemecoinCollator:
|
| 24 |
+
"""
|
| 25 |
+
Callable class for PyTorch DataLoader's collate_fn.
|
| 26 |
+
... (rest of docstring) ...
|
| 27 |
+
"""
|
| 28 |
+
def __init__(self,
             event_type_to_id: Dict[str, int],
             device: torch.device,
             multi_modal_encoder: MultiModalEncoder,
             dtype: torch.dtype,
             ohlc_seq_len: int = 300,
             max_seq_len: Optional[int] = None
             ):
    """
    Configure the collator.

    Args:
        event_type_to_id: vocabulary mapping event-type names to integer ids;
            the '__PAD__' entry (default 0 if absent) is used for padding.
        device: target device for every tensor the collator builds.
        multi_modal_encoder: encoder used downstream for text/image inputs.
        dtype: floating dtype for price/feature tensors.
        ohlc_seq_len: fixed length each chart segment is padded/truncated to.
        max_seq_len: optional cap on the collated event sequence length.
    """
    # Tensor placement / precision settings.
    self.device = device
    self.dtype = dtype

    # Event vocabulary and padding ids.
    self.event_type_to_id = event_type_to_id
    self.pad_token_id = event_type_to_id.get('__PAD__', 0)
    self.entity_pad_idx = 0  # index 0 is reserved for "missing entity"

    # Sequence shaping and shared encoder handle.
    self.ohlc_seq_len = ohlc_seq_len
    self.max_seq_len = max_seq_len
    self.multi_modal_encoder = multi_modal_encoder
| 46 |
+
def _collate_features_for_encoder(self, entities: List[Dict], feature_keys: List[str], device: torch.device, entity_type: str) -> Dict[str, Any]:
    """
    Collate a list of raw entity dicts into the input structure expected by
    the token or wallet encoder.

    Args:
        entities: raw entity dicts ('token' dicts carry *_emb_idx fields;
            'wallet' dicts carry 'profile'/'socials'/'holdings' sub-dicts).
        feature_keys: currently unused by this method — kept for interface
            compatibility with callers.  # NOTE(review): confirm before removal
        device: device on which the index tensors are created.
        entity_type: either "token" or "wallet"; selects the output schema.

    Returns:
        For "token": long tensors of embedding-pool indices + protocol ids
        + vanity flags + '_addresses_for_lookup'. For "wallet": username
        index tensor plus raw profile/social/holdings rows (encoded later).
        Empty structures of the same shape when `entities` is empty.
    """
    collated = defaultdict(list)
    if not entities:
        # --- FIXED: Return a default empty structure for BOTH tokens and wallets ---
        # Shapes must match the non-empty path so downstream code never branches.
        if entity_type == "token":
            return {
                'name_embed_indices': torch.tensor([], device=device, dtype=torch.long),
                'symbol_embed_indices': torch.tensor([], device=device, dtype=torch.long),
                'image_embed_indices': torch.tensor([], device=device, dtype=torch.long),
                'protocol_ids': torch.tensor([], device=device, dtype=torch.long),
                'is_vanity_flags': torch.tensor([], device=device, dtype=torch.bool),
                '_addresses_for_lookup': []
            }
        elif entity_type == "wallet":
            return {
                'username_embed_indices': torch.tensor([], device=device, dtype=torch.long),
                'profile_rows': [], 'social_rows': [], 'holdings_batch': []
            }
        return {}  # Should not happen

    # NEW: We now gather indices to pre-computed embeddings.
    # Missing fields fall back to index 0 (the pool's pad/unknown slot —
    # presumably; confirm against EmbeddingPooler).
    if entity_type == "token":
        # This indicates a Token entity
        # Helper key for WalletEncoder to find token vibes
        collated['_addresses_for_lookup'] = [e.get('address', '') for e in entities]
        collated['name_embed_indices'] = torch.tensor([e.get('name_emb_idx', 0) for e in entities], device=device, dtype=torch.long)
        collated['symbol_embed_indices'] = torch.tensor([e.get('symbol_emb_idx', 0) for e in entities], device=device, dtype=torch.long)
        collated['image_embed_indices'] = torch.tensor([e.get('image_emb_idx', 0) for e in entities], device=device, dtype=torch.long)
        collated['protocol_ids'] = torch.tensor([e.get('protocol', 0) for e in entities], device=device, dtype=torch.long)
        collated['is_vanity_flags'] = torch.tensor([e.get('is_vanity', False) for e in entities], device=device, dtype=torch.bool)
    elif entity_type == "wallet":
        # NEW: Gather username indices for WalletEncoder
        collated['username_embed_indices'] = torch.tensor([e.get('socials', {}).get('username_emb_idx', 0) for e in entities], device=device, dtype=torch.long)
        collated['profile_rows'] = [e.get('profile', {}) for e in entities]
        collated['social_rows'] = [e.get('socials', {}) for e in entities]
        collated['holdings_batch'] = [e.get('holdings', []) for e in entities]
    return dict(collated)
|
| 84 |
+
|
| 85 |
+
def _collate_ohlc_inputs(self, chart_events: List[Dict]) -> Dict[str, torch.Tensor]:
    """
    Collate raw chart-segment events into a fixed-length batched price tensor.

    Args:
        chart_events: dicts with 'opens'/'closes' lists and an interval
            string under 'i'.

    Returns:
        'price_tensor': (num_segments, 2, ohlc_seq_len) tensor of
        [opens; closes] rows, and 'interval_ids': (num_segments,) long
        tensor of interval vocabulary ids. Empty tensors when no events.
    """
    target_len = self.ohlc_seq_len

    if not chart_events:
        # Keep shapes/dtypes identical to the non-empty path.
        return {
            'price_tensor': torch.empty(0, 2, target_len, device=self.device, dtype=self.dtype),
            'interval_ids': torch.empty(0, device=self.device, dtype=torch.long),
        }

    fallback_id = vocab.INTERVAL_TO_ID.get("Unknown", 0)

    def _fit(series):
        # Truncate to target_len; right-pad short series with their last
        # observed value (0 if the series is empty).
        filler = series[-1] if series else 0
        return series[:target_len] + [filler] * (target_len - len(series))

    segment_tensors = []
    interval_ids = []
    for seg in chart_events:
        open_row = torch.tensor(_fit(seg.get('opens', [])), dtype=self.dtype)
        close_row = torch.tensor(_fit(seg.get('closes', [])), dtype=self.dtype)
        segment_tensors.append(torch.stack([open_row, close_row]))
        interval_ids.append(vocab.INTERVAL_TO_ID.get(seg.get('i', "Unknown"), fallback_id))

    return {
        'price_tensor': torch.stack(segment_tensors).to(self.device),
        'interval_ids': torch.tensor(interval_ids, device=self.device, dtype=torch.long),
    }
|
| 111 |
+
|
| 112 |
+
def _collate_graph_links(self,
|
| 113 |
+
batch_items: List[Dict],
|
| 114 |
+
wallet_addr_to_batch_idx: Dict[str, int],
|
| 115 |
+
token_addr_to_batch_idx: Dict[str, int]) -> Dict[str, Any]:
|
| 116 |
+
""" (Unchanged) """
|
| 117 |
+
aggregated_links = defaultdict(lambda: {'edge_index_list': [], 'links_list': []})
|
| 118 |
+
for item in batch_items:
|
| 119 |
+
item_wallets = item.get('wallets', {})
|
| 120 |
+
item_tokens = item.get('tokens', {})
|
| 121 |
+
item_wallet_addr_to_global_idx = {addr: wallet_addr_to_batch_idx.get(addr, self.entity_pad_idx) for addr in item_wallets.keys()}
|
| 122 |
+
item_token_addr_to_global_idx = {addr: token_addr_to_batch_idx.get(addr, self.entity_pad_idx) for addr in item_tokens.keys()}
|
| 123 |
+
for link_name, data in item.get('graph_links', {}).items():
|
| 124 |
+
aggregated_links[link_name]['links_list'].extend(data.get('links', []))
|
| 125 |
+
triplet = vocab.LINK_NAME_TO_TRIPLET.get(link_name)
|
| 126 |
+
if not triplet: continue
|
| 127 |
+
src_type, _, dst_type = triplet
|
| 128 |
+
edges = data.get('edges')
|
| 129 |
+
if not edges: continue
|
| 130 |
+
src_map = item_wallet_addr_to_global_idx if src_type == 'wallet' else item_token_addr_to_global_idx
|
| 131 |
+
dst_map = item_wallet_addr_to_global_idx if dst_type == 'wallet' else item_token_addr_to_global_idx
|
| 132 |
+
remapped_edge_list = []
|
| 133 |
+
for src_addr, dst_addr in edges:
|
| 134 |
+
src_idx_global = src_map.get(src_addr, self.entity_pad_idx)
|
| 135 |
+
dst_idx_global = dst_map.get(dst_addr, self.entity_pad_idx)
|
| 136 |
+
if src_idx_global != self.entity_pad_idx and dst_idx_global != self.entity_pad_idx:
|
| 137 |
+
remapped_edge_list.append([src_idx_global, dst_idx_global])
|
| 138 |
+
if remapped_edge_list:
|
| 139 |
+
remapped_edge_tensor = torch.tensor(remapped_edge_list, device=self.device, dtype=torch.long).t()
|
| 140 |
+
aggregated_links[link_name]['edge_index_list'].append(remapped_edge_tensor)
|
| 141 |
+
if link_name == "TransferLink":
|
| 142 |
+
link_props = data.get('links', [])
|
| 143 |
+
derived_edges = []
|
| 144 |
+
derived_props = []
|
| 145 |
+
for (src_addr, dst_addr), props in zip(edges, link_props):
|
| 146 |
+
mint_addr = props.get('mint')
|
| 147 |
+
if not mint_addr or mint_addr in QUOTE_MINTS:
|
| 148 |
+
continue
|
| 149 |
+
token_idx_global = item_token_addr_to_global_idx.get(mint_addr, self.entity_pad_idx)
|
| 150 |
+
if token_idx_global == self.entity_pad_idx:
|
| 151 |
+
continue
|
| 152 |
+
for wallet_addr in (src_addr, dst_addr):
|
| 153 |
+
wallet_idx_global = item_wallet_addr_to_global_idx.get(wallet_addr, self.entity_pad_idx)
|
| 154 |
+
if wallet_idx_global == self.entity_pad_idx:
|
| 155 |
+
continue
|
| 156 |
+
derived_edges.append([wallet_idx_global, token_idx_global])
|
| 157 |
+
derived_props.append(props)
|
| 158 |
+
if derived_edges:
|
| 159 |
+
derived_tensor = torch.tensor(derived_edges, device=self.device, dtype=torch.long).t()
|
| 160 |
+
aggregated_links["TransferLinkToken"]['edge_index_list'].append(derived_tensor)
|
| 161 |
+
aggregated_links["TransferLinkToken"]['links_list'].extend(derived_props)
|
| 162 |
+
final_links_dict = {}
|
| 163 |
+
for link_name, data in aggregated_links.items():
|
| 164 |
+
if data['edge_index_list']:
|
| 165 |
+
final_links_dict[link_name] = {
|
| 166 |
+
'links': data['links_list'],
|
| 167 |
+
'edge_index': torch.cat(data['edge_index_list'], dim=1)
|
| 168 |
+
}
|
| 169 |
+
return final_links_dict
|
| 170 |
+
|
| 171 |
+
def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 172 |
+
"""
|
| 173 |
+
Processes a batch of raw data items into tensors for the model.
|
| 174 |
+
"""
|
| 175 |
+
# --- NEW ARCHITECTURE ---
|
| 176 |
+
# 1. Aggregate all unique embeddable items from the entire batch.
|
| 177 |
+
# 2. Create a single embedding pool tensor for the whole batch.
|
| 178 |
+
# 3. Create a mapping from original (per-item) indices to the new batch-wide indices.
|
| 179 |
+
# 4. Remap all `_emb_idx` fields in the batch data using this new mapping.
|
| 180 |
+
|
| 181 |
+
batch_size = len(batch)
|
| 182 |
+
if batch_size == 0:
|
| 183 |
+
return {}
|
| 184 |
+
|
| 185 |
+
# --- 1. Aggregate all unique items and create index mappings ---
|
| 186 |
+
batch_wide_pooler = EmbeddingPooler()
|
| 187 |
+
# Map to translate from an item's original pooler to the new batch-wide indices
|
| 188 |
+
# Format: { batch_item_index: { original_idx: new_batch_idx } }
|
| 189 |
+
idx_remap = defaultdict(dict)
|
| 190 |
+
|
| 191 |
+
for i, item in enumerate(batch):
|
| 192 |
+
pooler = item.get('embedding_pooler')
|
| 193 |
+
if not pooler: continue
|
| 194 |
+
|
| 195 |
+
for pool_item_data in pooler.get_all_items():
|
| 196 |
+
original_idx = pool_item_data['idx']
|
| 197 |
+
raw_item = pool_item_data['item']
|
| 198 |
+
# get_idx will either return an existing index or create a new one
|
| 199 |
+
# --- FIX: Convert 1-based pooler index to 0-based tensor index ---
|
| 200 |
+
new_batch_idx_1_based = batch_wide_pooler.get_idx(raw_item)
|
| 201 |
+
new_batch_idx_0_based = new_batch_idx_1_based - 1
|
| 202 |
+
idx_remap[i][original_idx] = new_batch_idx_0_based
|
| 203 |
+
|
| 204 |
+
# --- 2. Create the single, batch-wide embedding pool tensor ---
|
| 205 |
+
all_items_sorted = batch_wide_pooler.get_all_items()
|
| 206 |
+
texts_to_encode = [d['item'] for d in all_items_sorted if isinstance(d['item'], str)]
|
| 207 |
+
images_to_encode = [d['item'] for d in all_items_sorted if isinstance(d['item'], Image.Image)]
|
| 208 |
+
|
| 209 |
+
text_embeds = self.multi_modal_encoder(texts_to_encode) if texts_to_encode else torch.empty(0)
|
| 210 |
+
image_embeds = self.multi_modal_encoder(images_to_encode) if images_to_encode else torch.empty(0)
|
| 211 |
+
|
| 212 |
+
# Create the final lookup tensor and fill it based on original item type
|
| 213 |
+
batch_embedding_pool = torch.zeros(len(all_items_sorted), self.multi_modal_encoder.embedding_dim, device=self.device, dtype=self.dtype)
|
| 214 |
+
text_cursor, image_cursor = 0, 0
|
| 215 |
+
for i, item_data in enumerate(all_items_sorted):
|
| 216 |
+
if isinstance(item_data['item'], str):
|
| 217 |
+
if text_embeds.numel() > 0:
|
| 218 |
+
batch_embedding_pool[i] = text_embeds[text_cursor]
|
| 219 |
+
text_cursor += 1
|
| 220 |
+
elif isinstance(item_data['item'], Image.Image):
|
| 221 |
+
if image_embeds.numel() > 0:
|
| 222 |
+
batch_embedding_pool[i] = image_embeds[image_cursor]
|
| 223 |
+
image_cursor += 1
|
| 224 |
+
|
| 225 |
+
# --- 3. Remap all indices in the batch data ---
|
| 226 |
+
for i, item in enumerate(batch):
|
| 227 |
+
remap_dict = idx_remap.get(i, {})
|
| 228 |
+
if not remap_dict: continue
|
| 229 |
+
|
| 230 |
+
# Remap tokens
|
| 231 |
+
for token_data in item.get('tokens', {}).values():
|
| 232 |
+
for key in ['name_emb_idx', 'symbol_emb_idx', 'image_emb_idx']:
|
| 233 |
+
if token_data.get(key, 0) > 0: # Check if it has a valid 1-based index
|
| 234 |
+
token_data[key] = remap_dict.get(token_data[key], -1) # Remap to 0-based, default to -1 if not found
|
| 235 |
+
# Remap wallets
|
| 236 |
+
for wallet_data in item.get('wallets', {}).values():
|
| 237 |
+
socials = wallet_data.get('socials', {})
|
| 238 |
+
if socials.get('username_emb_idx', 0) > 0:
|
| 239 |
+
socials['username_emb_idx'] = remap_dict.get(socials['username_emb_idx'], -1)
|
| 240 |
+
# Remap events
|
| 241 |
+
for event in item.get('event_sequence', []):
|
| 242 |
+
for key in event:
|
| 243 |
+
if key.endswith('_emb_idx') and event.get(key, 0) > 0:
|
| 244 |
+
event[key] = remap_dict.get(event[key], 0)
|
| 245 |
+
|
| 246 |
+
# --- 4. Standard Collation (Now that indices are correct) ---
|
| 247 |
+
unique_wallets_data = {}
|
| 248 |
+
unique_tokens_data = {}
|
| 249 |
+
all_event_sequences = []
|
| 250 |
+
max_len = 0
|
| 251 |
+
|
| 252 |
+
for item in batch:
|
| 253 |
+
seq = item.get('event_sequence', [])
|
| 254 |
+
if self.max_seq_len is not None and len(seq) > self.max_seq_len:
|
| 255 |
+
seq = seq[:self.max_seq_len]
|
| 256 |
+
all_event_sequences.append(seq)
|
| 257 |
+
max_len = max(max_len, len(seq))
|
| 258 |
+
unique_wallets_data.update(item.get('wallets', {}))
|
| 259 |
+
unique_tokens_data.update(item.get('tokens', {}))
|
| 260 |
+
|
| 261 |
+
# Create mappings needed for indexing
|
| 262 |
+
wallet_list_data = list(unique_wallets_data.values())
|
| 263 |
+
token_list_data = list(unique_tokens_data.values())
|
| 264 |
+
wallet_addr_to_batch_idx = {feat.get('profile', {}).get('wallet_address', f'__error_{i}'): i+1 for i, feat in enumerate(wallet_list_data)}
|
| 265 |
+
token_addr_to_batch_idx = {feat.get('address', f'__error_{i}'): i+1 for i, feat in enumerate(token_list_data)}
|
| 266 |
+
|
| 267 |
+
# Collate Static Raw Features (Tokens, Wallets, Graph)
|
| 268 |
+
token_encoder_inputs = self._collate_features_for_encoder(token_list_data, ['name'], self.device, "token")
|
| 269 |
+
wallet_encoder_inputs = self._collate_features_for_encoder(wallet_list_data, ['profile'], self.device, "wallet")
|
| 270 |
+
graph_updater_links = self._collate_graph_links(batch, wallet_addr_to_batch_idx, token_addr_to_batch_idx)
|
| 271 |
+
|
| 272 |
+
# --- Logging ---
|
| 273 |
+
pool_contents = batch_wide_pooler.get_all_items()
|
| 274 |
+
print(f"\n[DataCollator: Final Embedding Pool] ({len(pool_contents)} items):")
|
| 275 |
+
if pool_contents:
|
| 276 |
+
for item_data in pool_contents:
|
| 277 |
+
sample_item = item_data['item']
|
| 278 |
+
sample_type = "Image" if isinstance(sample_item, Image.Image) else "Text"
|
| 279 |
+
content_preview = str(sample_item)
|
| 280 |
+
if sample_type == "Text" and len(content_preview) > 100:
|
| 281 |
+
content_preview = content_preview[:97] + "..."
|
| 282 |
+
print(f" - Item (Original Idx {item_data['idx']}): Type='{sample_type}', Content='{content_preview}'")
|
| 283 |
+
|
| 284 |
+
# --- 5. Prepare Sequence Tensors & Collect Dynamic Data (OHLC) ---
|
| 285 |
+
B = batch_size
|
| 286 |
+
L = max_len
|
| 287 |
+
PAD_IDX_SEQ = self.pad_token_id
|
| 288 |
+
PAD_IDX_ENT = self.entity_pad_idx
|
| 289 |
+
|
| 290 |
+
# Initialize sequence tensors
|
| 291 |
+
event_type_ids = torch.full((B, L), PAD_IDX_SEQ, dtype=torch.long, device=self.device)
|
| 292 |
+
timestamps_float = torch.zeros((B, L), dtype=torch.float32, device=self.device)
|
| 293 |
+
# Store relative_ts in float32 for stability; model will scale/log/normalize
|
| 294 |
+
relative_ts = torch.zeros((B, L, 1), dtype=torch.float32, device=self.device)
|
| 295 |
+
attention_mask = torch.zeros((B, L), dtype=torch.long, device=self.device)
|
| 296 |
+
wallet_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
|
| 297 |
+
token_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
|
| 298 |
+
ohlc_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
|
| 299 |
+
quote_token_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device) # NEW
|
| 300 |
+
|
| 301 |
+
# --- NEW: Tensors for Transfer/LargeTransfer ---
|
| 302 |
+
dest_wallet_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
|
| 303 |
+
# --- NEW: Separate tensor for social media original authors ---
|
| 304 |
+
original_author_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
|
| 305 |
+
# 4 numerical features for transfers
|
| 306 |
+
transfer_numerical_features = torch.zeros((B, L, 4), dtype=self.dtype, device=self.device)
|
| 307 |
+
|
| 308 |
+
# --- NEW: Tensors for Trade ---
|
| 309 |
+
# --- FIXED: Size reduced from 10 to 8 ---
|
| 310 |
+
trade_numerical_features = torch.zeros((B, L, 8), dtype=self.dtype, device=self.device)
|
| 311 |
+
deployer_trade_numerical_features = torch.zeros((B, L, 8), dtype=self.dtype, device=self.device)
|
| 312 |
+
smart_wallet_trade_numerical_features = torch.zeros((B, L, 8), dtype=self.dtype, device=self.device)
|
| 313 |
+
# --- NEW: Dedicated tensor for categorical dex_platform_id ---
|
| 314 |
+
trade_dex_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
|
| 315 |
+
# --- NEW: Dedicated tensor for categorical trade_direction ---
|
| 316 |
+
trade_direction_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
|
| 317 |
+
# --- NEW: Dedicated tensor for categorical mev_protection ---
|
| 318 |
+
trade_mev_protection_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
|
| 319 |
+
# --- NEW: Dedicated tensor for categorical is_bundle ---
|
| 320 |
+
trade_is_bundle_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
|
| 321 |
+
|
| 322 |
+
# --- NEW: Tensors for PoolCreated ---
|
| 323 |
+
# --- UPDATED: Capture raw base/quote deposit amounts only ---
|
| 324 |
+
pool_created_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device)
|
| 325 |
+
# --- NEW: Dedicated tensor for categorical protocol_id ---
|
| 326 |
+
pool_created_protocol_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
|
| 327 |
+
|
| 328 |
+
# --- NEW: Tensors for LiquidityChange ---
|
| 329 |
+
# --- UPDATED: Keep only the raw quote amount deposit/withdraw ---
|
| 330 |
+
liquidity_change_numerical_features = torch.zeros((B, L, 1), dtype=self.dtype, device=self.device)
|
| 331 |
+
# --- NEW: Dedicated tensor for categorical change_type_id ---
|
| 332 |
+
liquidity_change_type_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
|
| 333 |
+
|
| 334 |
+
# --- NEW: Tensors for FeeCollected ---
|
| 335 |
+
fee_collected_numerical_features = torch.zeros((B, L, 1), dtype=self.dtype, device=self.device) # sol_amount only
|
| 336 |
+
# --- NEW: Tensors for TokenBurn ---
|
| 337 |
+
token_burn_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device) # amount_pct, amount_tokens
|
| 338 |
+
|
| 339 |
+
# --- NEW: Tensors for SupplyLock ---
|
| 340 |
+
supply_lock_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device) # amount_pct, lock_duration
|
| 341 |
+
|
| 342 |
+
# --- NEW: Tensors for OnChain_Snapshot ---
|
| 343 |
+
onchain_snapshot_numerical_features = torch.zeros((B, L, 14), dtype=self.dtype, device=self.device)
|
| 344 |
+
|
| 345 |
+
# --- NEW: Tensors for TrendingToken ---
|
| 346 |
+
trending_token_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
|
| 347 |
+
# --- FIXED: Size reduced from 3 to 1 after removing IDs ---
|
| 348 |
+
trending_token_numerical_features = torch.zeros((B, L, 1), dtype=self.dtype, device=self.device) # rank
|
| 349 |
+
trending_token_source_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
|
| 350 |
+
trending_token_timeframe_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
|
| 351 |
+
|
| 352 |
+
# --- NEW: Tensors for BoostedToken ---
|
| 353 |
+
boosted_token_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
|
| 354 |
+
boosted_token_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device) # total_boost_amount, rank
|
| 355 |
+
|
| 356 |
+
# --- NEW: Tensors for DexBoost_Paid ---
|
| 357 |
+
dexboost_paid_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device) # amount, total_amount_on_token
|
| 358 |
+
|
| 359 |
+
# --- NEW: Tensors for DexProfile_Updated ---
|
| 360 |
+
dexprofile_updated_flags = torch.zeros((B, L, 4), dtype=torch.float32, device=self.device) # Using float for easier projection
|
| 361 |
+
|
| 362 |
+
# --- NEW: Tensors for Tracker Events ---
|
| 363 |
+
alpha_group_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
|
| 364 |
+
channel_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
|
| 365 |
+
exchange_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
|
| 366 |
+
|
| 367 |
+
# --- NEW: Tensors for GlobalTrending Events ---
|
| 368 |
+
global_trending_numerical_features = torch.zeros((B, L, 1), dtype=self.dtype, device=self.device) # rank
|
| 369 |
+
|
| 370 |
+
# --- NEW: Tensors for ChainSnapshot ---
|
| 371 |
+
chainsnapshot_numerical_features = torch.zeros((B, L, 2), dtype=self.dtype, device=self.device) # native_token_price_usd, gas_fee
|
| 372 |
+
|
| 373 |
+
# --- NEW: Tensors for Lighthouse_Snapshot ---
|
| 374 |
+
# --- FIXED: Size reduced from 7 to 5 after removing IDs ---
|
| 375 |
+
lighthousesnapshot_numerical_features = torch.zeros((B, L, 5), dtype=self.dtype, device=self.device)
|
| 376 |
+
lighthousesnapshot_protocol_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
|
| 377 |
+
lighthousesnapshot_timeframe_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
|
| 378 |
+
|
| 379 |
+
# --- NEW: Tensors for Migrated event ---
|
| 380 |
+
migrated_protocol_ids = torch.full((B, L), 0, dtype=torch.long, device=self.device)
|
| 381 |
+
|
| 382 |
+
# --- NEW: Tensors for HolderSnapshot ---
|
| 383 |
+
# This will store the raw holder data for the Oracle to process
|
| 384 |
+
holder_snapshot_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
|
| 385 |
+
holder_snapshot_raw_data_list = [] # List of lists of dicts
|
| 386 |
+
|
| 387 |
+
# --- RENAMED: Generic tensors for any event with text/image features ---
|
| 388 |
+
textual_event_data_list = [] # List of dicts with text/media indices
|
| 389 |
+
textual_event_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
|
| 390 |
+
# --- NEW: Pointers for pre-encoded images ---
|
| 391 |
+
image_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
|
| 392 |
+
original_post_image_indices = torch.full((B, L), PAD_IDX_ENT, dtype=torch.long, device=self.device)
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
# --- CORRECTED: Initialize chart event collection here ---
|
| 397 |
+
batch_chart_events = []
|
| 398 |
+
chart_event_counter = 0
|
| 399 |
+
|
| 400 |
+
# Loop through sequences to populate tensors and collect chart events
|
| 401 |
+
for i, seq in enumerate(all_event_sequences):
|
| 402 |
+
seq_len = len(seq)
|
| 403 |
+
if seq_len == 0: continue
|
| 404 |
+
attention_mask[i, :seq_len] = 1
|
| 405 |
+
|
| 406 |
+
for j, event in enumerate(seq):
|
| 407 |
+
# Populate basic sequence info
|
| 408 |
+
event_type = event.get('event_type', '__PAD__')
|
| 409 |
+
type_id = self.event_type_to_id.get(event_type, PAD_IDX_SEQ)
|
| 410 |
+
event_type_ids[i, j] = type_id
|
| 411 |
+
timestamps_float[i, j] = event.get('timestamp', 0)
|
| 412 |
+
relative_ts[i, j, 0] = event.get('relative_ts', 0.0)
|
| 413 |
+
|
| 414 |
+
# Populate pointer indices
|
| 415 |
+
w_addr = event.get('wallet_address')
|
| 416 |
+
if w_addr:
|
| 417 |
+
wallet_indices[i, j] = wallet_addr_to_batch_idx.get(w_addr, PAD_IDX_ENT)
|
| 418 |
+
t_addr = event.get('token_address')
|
| 419 |
+
if t_addr:
|
| 420 |
+
token_indices[i, j] = token_addr_to_batch_idx.get(t_addr, PAD_IDX_ENT)
|
| 421 |
+
|
| 422 |
+
# If it's a chart event, collect it and record its index
|
| 423 |
+
if event_type == 'Chart_Segment':
|
| 424 |
+
batch_chart_events.append(event)
|
| 425 |
+
ohlc_indices[i, j] = chart_event_counter + 1 # Use 1-based index
|
| 426 |
+
chart_event_counter += 1
|
| 427 |
+
|
| 428 |
+
elif event_type in ['Transfer', 'LargeTransfer']: # ADDED LargeTransfer
|
| 429 |
+
# Get destination wallet index
|
| 430 |
+
dest_w_addr = event.get('destination_wallet_address') # Assuming this key exists
|
| 431 |
+
if dest_w_addr:
|
| 432 |
+
dest_wallet_indices[i, j] = wallet_addr_to_batch_idx.get(dest_w_addr, PAD_IDX_ENT)
|
| 433 |
+
|
| 434 |
+
# Get numerical features (use .get with default 0)
|
| 435 |
+
num_feats = [
|
| 436 |
+
event.get('token_amount', 0.0),
|
| 437 |
+
event.get('transfer_pct_of_total_supply', 0.0),
|
| 438 |
+
event.get('transfer_pct_of_holding', 0.0),
|
| 439 |
+
event.get('priority_fee', 0.0)
|
| 440 |
+
]
|
| 441 |
+
transfer_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
|
| 442 |
+
|
| 443 |
+
elif event_type in ['Trade', 'LargeTrade']:
|
| 444 |
+
# Get numerical and categorical features for the trade
|
| 445 |
+
trade_dex_ids[i, j] = event.get('dex_platform_id', 0)
|
| 446 |
+
trade_direction_ids[i, j] = event.get('trade_direction', 0) # 0=buy, 1=sell
|
| 447 |
+
trade_mev_protection_ids[i, j] = event.get('mev_protection', 0) # 0, 1, 2...
|
| 448 |
+
trade_is_bundle_ids[i, j] = 1 if event.get('is_bundle') else 0 # 0=false, 1=true
|
| 449 |
+
|
| 450 |
+
num_feats = [
|
| 451 |
+
event.get('sol_amount', 0.0),
|
| 452 |
+
event.get('priority_fee', 0.0),
|
| 453 |
+
event.get('token_amount_pct_of_holding', 0.0),
|
| 454 |
+
event.get('quote_amount_pct_of_holding', 0.0),
|
| 455 |
+
event.get('slippage', 0.0),
|
| 456 |
+
event.get('token_amount_pct_to_total_supply', 0.0), # REPLACED price_impact
|
| 457 |
+
1.0 if event.get('success') else 0.0,
|
| 458 |
+
event.get('total_usd', 0.0)
|
| 459 |
+
]
|
| 460 |
+
trade_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
|
| 461 |
+
|
| 462 |
+
elif event_type == 'Deployer_Trade':
|
| 463 |
+
# Use the dedicated tensor for deployer trades
|
| 464 |
+
trade_dex_ids[i, j] = event.get('dex_platform_id', 0)
|
| 465 |
+
trade_direction_ids[i, j] = event.get('trade_direction', 0) # 0=buy, 1=sell
|
| 466 |
+
trade_mev_protection_ids[i, j] = event.get('mev_protection', 0) # 0, 1, 2...
|
| 467 |
+
trade_is_bundle_ids[i, j] = 1 if event.get('is_bundle') else 0 # 0=false, 1=true
|
| 468 |
+
num_feats = [
|
| 469 |
+
event.get('sol_amount', 0.0),
|
| 470 |
+
event.get('priority_fee', 0.0),
|
| 471 |
+
event.get('token_amount_pct_of_holding', 0.0),
|
| 472 |
+
event.get('quote_amount_pct_of_holding', 0.0),
|
| 473 |
+
event.get('slippage', 0.0),
|
| 474 |
+
event.get('token_amount_pct_to_total_supply', 0.0), # REPLACED price_impact
|
| 475 |
+
1.0 if event.get('success') else 0.0,
|
| 476 |
+
event.get('total_usd', 0.0)
|
| 477 |
+
]
|
| 478 |
+
deployer_trade_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
|
| 479 |
+
|
| 480 |
+
elif event_type == 'SmartWallet_Trade':
|
| 481 |
+
# Use the dedicated tensor for smart wallet trades
|
| 482 |
+
trade_dex_ids[i, j] = event.get('dex_platform_id', 0)
|
| 483 |
+
trade_direction_ids[i, j] = event.get('trade_direction', 0) # 0=buy, 1=sell
|
| 484 |
+
trade_mev_protection_ids[i, j] = event.get('mev_protection', 0) # 0, 1, 2...
|
| 485 |
+
trade_is_bundle_ids[i, j] = 1 if event.get('is_bundle') else 0 # 0=false, 1=true
|
| 486 |
+
num_feats = [
|
| 487 |
+
event.get('sol_amount', 0.0),
|
| 488 |
+
event.get('priority_fee', 0.0),
|
| 489 |
+
event.get('token_amount_pct_of_holding', 0.0),
|
| 490 |
+
event.get('quote_amount_pct_of_holding', 0.0),
|
| 491 |
+
event.get('slippage', 0.0),
|
| 492 |
+
event.get('token_amount_pct_to_total_supply', 0.0), # REPLACED price_impact
|
| 493 |
+
1.0 if event.get('success') else 0.0,
|
| 494 |
+
event.get('total_usd', 0.0)
|
| 495 |
+
]
|
| 496 |
+
smart_wallet_trade_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
|
| 497 |
+
|
| 498 |
+
elif event_type == 'PoolCreated':
|
| 499 |
+
# Get the quote token index
|
| 500 |
+
quote_t_addr = event.get('quote_token_address')
|
| 501 |
+
if quote_t_addr:
|
| 502 |
+
quote_token_indices[i, j] = token_addr_to_batch_idx.get(quote_t_addr, PAD_IDX_ENT)
|
| 503 |
+
|
| 504 |
+
pool_created_protocol_ids[i, j] = event.get('protocol_id', 0)
|
| 505 |
+
# Get numerical features
|
| 506 |
+
num_feats = [
|
| 507 |
+
event.get('base_amount', 0.0),
|
| 508 |
+
event.get('quote_amount', 0.0)
|
| 509 |
+
]
|
| 510 |
+
pool_created_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
|
| 511 |
+
|
| 512 |
+
elif event_type == 'LiquidityChange':
|
| 513 |
+
# Get the quote token index
|
| 514 |
+
quote_t_addr = event.get('quote_token_address')
|
| 515 |
+
if quote_t_addr:
|
| 516 |
+
quote_token_indices[i, j] = token_addr_to_batch_idx.get(quote_t_addr, PAD_IDX_ENT)
|
| 517 |
+
|
| 518 |
+
liquidity_change_type_ids[i, j] = event.get('change_type_id', 0)
|
| 519 |
+
# Get numerical features
|
| 520 |
+
num_feats = [event.get('quote_amount', 0.0)]
|
| 521 |
+
liquidity_change_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
|
| 522 |
+
|
| 523 |
+
elif event_type == 'FeeCollected':
|
| 524 |
+
# This event has the recipient wallet plus a single numerical feature (SOL amount).
|
| 525 |
+
num_feats = [
|
| 526 |
+
event.get('sol_amount', 0.0)
|
| 527 |
+
]
|
| 528 |
+
fee_collected_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
|
| 529 |
+
|
| 530 |
+
elif event_type == 'TokenBurn':
|
| 531 |
+
# This event has a wallet (handled by wallet_indices) and two numerical features.
|
| 532 |
+
num_feats = [
|
| 533 |
+
event.get('amount_pct_of_total_supply', 0.0),
|
| 534 |
+
event.get('amount_tokens_burned', 0.0)
|
| 535 |
+
]
|
| 536 |
+
token_burn_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
|
| 537 |
+
|
| 538 |
+
elif event_type == 'SupplyLock':
|
| 539 |
+
# This event has a wallet and two numerical features.
|
| 540 |
+
num_feats = [
|
| 541 |
+
event.get('amount_pct_of_total_supply', 0.0),
|
| 542 |
+
event.get('lock_duration', 0.0)
|
| 543 |
+
]
|
| 544 |
+
supply_lock_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
|
| 545 |
+
|
| 546 |
+
elif event_type == 'OnChain_Snapshot':
|
| 547 |
+
# This event is a global snapshot with 14 numerical features.
|
| 548 |
+
num_feats = [
|
| 549 |
+
event.get('total_holders', 0.0),
|
| 550 |
+
event.get('smart_traders', 0.0),
|
| 551 |
+
event.get('kols', 0.0),
|
| 552 |
+
event.get('holder_growth_rate', 0.0),
|
| 553 |
+
event.get('top_10_holder_pct', 0.0),
|
| 554 |
+
event.get('sniper_holding_pct', 0.0),
|
| 555 |
+
event.get('rat_wallets_holding_pct', 0.0),
|
| 556 |
+
event.get('bundle_holding_pct', 0.0),
|
| 557 |
+
event.get('current_market_cap', 0.0),
|
| 558 |
+
event.get('volume', 0.0),
|
| 559 |
+
event.get('buy_count', 0.0),
|
| 560 |
+
event.get('sell_count', 0.0),
|
| 561 |
+
event.get('total_txns', 0.0),
|
| 562 |
+
event.get('global_fees_paid', 0.0)
|
| 563 |
+
]
|
| 564 |
+
onchain_snapshot_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
|
| 565 |
+
|
| 566 |
+
elif event_type == 'TrendingToken':
|
| 567 |
+
# Get the trending token index
|
| 568 |
+
trending_t_addr = event.get('token_address')
|
| 569 |
+
if trending_t_addr:
|
| 570 |
+
trending_token_indices[i, j] = token_addr_to_batch_idx.get(trending_t_addr, PAD_IDX_ENT)
|
| 571 |
+
|
| 572 |
+
trending_token_source_ids[i, j] = event.get('list_source_id', 0)
|
| 573 |
+
trending_token_timeframe_ids[i, j] = event.get('timeframe_id', 0)
|
| 574 |
+
# --- FIXED: Invert rank so that 1 is the highest value ---
|
| 575 |
+
# Get numerical/categorical features
|
| 576 |
+
num_feats = [
|
| 577 |
+
1.0 / event.get('rank', 1e9) # Use a large number for rank 0 or missing to make it small
|
| 578 |
+
]
|
| 579 |
+
trending_token_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
|
| 580 |
+
|
| 581 |
+
elif event_type == 'BoostedToken':
|
| 582 |
+
# Get the boosted token index
|
| 583 |
+
boosted_t_addr = event.get('token_address')
|
| 584 |
+
if boosted_t_addr:
|
| 585 |
+
boosted_token_indices[i, j] = token_addr_to_batch_idx.get(boosted_t_addr, PAD_IDX_ENT)
|
| 586 |
+
|
| 587 |
+
# --- FIXED: Invert rank so that 1 is the highest value ---
|
| 588 |
+
# Get numerical features
|
| 589 |
+
num_feats = [
|
| 590 |
+
event.get('total_boost_amount', 0.0),
|
| 591 |
+
1.0 / event.get('rank', 1e9)
|
| 592 |
+
]
|
| 593 |
+
boosted_token_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
|
| 594 |
+
|
| 595 |
+
elif event_type == 'HolderSnapshot':
|
| 596 |
+
# --- FIXED: Store raw holder data, not an index ---
|
| 597 |
+
raw_holders = event.get('holders', [])
|
| 598 |
+
holder_snapshot_raw_data_list.append(raw_holders)
|
| 599 |
+
holder_snapshot_indices[i, j] = len(holder_snapshot_raw_data_list) # 1-based index to the list
|
| 600 |
+
|
| 601 |
+
elif event_type == 'Lighthouse_Snapshot':
|
| 602 |
+
lighthousesnapshot_protocol_ids[i, j] = event.get('protocol_id', 0)
|
| 603 |
+
lighthousesnapshot_timeframe_ids[i, j] = event.get('timeframe_id', 0)
|
| 604 |
+
num_feats = [
|
| 605 |
+
event.get('total_volume', 0.0),
|
| 606 |
+
event.get('total_transactions', 0.0),
|
| 607 |
+
event.get('total_traders', 0.0),
|
| 608 |
+
event.get('total_tokens_created', 0.0),
|
| 609 |
+
event.get('total_migrations', 0.0)
|
| 610 |
+
]
|
| 611 |
+
lighthousesnapshot_numerical_features[i, j, :] = torch.tensor(num_feats, dtype=self.dtype)
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
# --- UPDATED: Group all events that contain pre-computed text/image indices ---
|
| 615 |
+
elif event_type in ['XPost', 'XReply', 'XRetweet', 'XQuoteTweet', 'PumpReply', 'DexProfile_Updated', 'TikTok_Trending_Hashtag', 'XTrending_Hashtag']:
|
| 616 |
+
# Store raw event data to look up text/image indices later
|
| 617 |
+
# 1. Store raw text/media data
|
| 618 |
+
textual_event_data_list.append(event)
|
| 619 |
+
textual_event_indices[i, j] = len(textual_event_data_list) # 1-based index
|
| 620 |
+
# --- FIXED: Handle rank for trending hashtags ---
|
| 621 |
+
if event_type in ['TikTok_Trending_Hashtag', 'XTrending_Hashtag']:
|
| 622 |
+
global_trending_numerical_features[i, j, 0] = 1.0 / event.get('rank', 1e9)
|
| 623 |
+
|
| 624 |
+
# 2. Populate wallet pointer tensors based on the event type
|
| 625 |
+
# The main 'wallet_address' is already handled above.
|
| 626 |
+
# Here we handle the *other* wallets involved.
|
| 627 |
+
if event_type == 'XRetweet' or event_type == 'XQuoteTweet':
|
| 628 |
+
orig_author_addr = event.get('original_author_wallet_address')
|
| 629 |
+
if orig_author_addr:
|
| 630 |
+
# --- FIXED: Use the dedicated tensor for original authors ---
|
| 631 |
+
original_author_indices[i, j] = wallet_addr_to_batch_idx.get(orig_author_addr, PAD_IDX_ENT)
|
| 632 |
+
|
| 633 |
+
# The pre-computed embedding indices are already in the event dict.
|
| 634 |
+
# No need to populate image_indices here anymore.
|
| 635 |
+
# For XReply, the main tweet is a text/media embedding, not a wallet.
|
| 636 |
+
# For XPost, there's only one wallet, already handled.
|
| 637 |
+
|
| 638 |
+
# --- 4. Collate Dynamic Features (OHLC) AFTER collecting them ---
|
| 639 |
+
ohlc_inputs_dict = self._collate_ohlc_inputs(batch_chart_events)
|
| 640 |
+
|
| 641 |
+
# --- 6. Prepare final output dictionary ---
|
| 642 |
+
collated_batch = {
|
| 643 |
+
# Sequence Tensors
|
| 644 |
+
'event_type_ids': event_type_ids,
|
| 645 |
+
'timestamps_float': timestamps_float,
|
| 646 |
+
'relative_ts': relative_ts,
|
| 647 |
+
'attention_mask': attention_mask,
|
| 648 |
+
# Pointer Tensors
|
| 649 |
+
'wallet_indices': wallet_indices,
|
| 650 |
+
'token_indices': token_indices,
|
| 651 |
+
'quote_token_indices': quote_token_indices, # NEW
|
| 652 |
+
'trending_token_indices': trending_token_indices, # NEW
|
| 653 |
+
'boosted_token_indices': boosted_token_indices, # NEW
|
| 654 |
+
'holder_snapshot_indices': holder_snapshot_indices, # This now points to the generated embeddings
|
| 655 |
+
'textual_event_indices': textual_event_indices, # RENAMED
|
| 656 |
+
'ohlc_indices': ohlc_indices,
|
| 657 |
+
# Raw Data for Encoders
|
| 658 |
+
'embedding_pool': batch_embedding_pool, # NEW
|
| 659 |
+
'token_encoder_inputs': token_encoder_inputs,
|
| 660 |
+
'wallet_encoder_inputs': wallet_encoder_inputs, # ADDED BACK
|
| 661 |
+
'ohlc_price_tensors': ohlc_inputs_dict['price_tensor'],
|
| 662 |
+
'ohlc_interval_ids': ohlc_inputs_dict['interval_ids'],
|
| 663 |
+
'graph_updater_links': graph_updater_links,
|
| 664 |
+
'wallet_addr_to_batch_idx': wallet_addr_to_batch_idx, # NEW: Pass the mapping
|
| 665 |
+
|
| 666 |
+
'dest_wallet_indices': dest_wallet_indices, # ADDED THIS LINE
|
| 667 |
+
'original_author_indices': original_author_indices, # NEW
|
| 668 |
+
# --- NEW: Numerical Features for Events ---
|
| 669 |
+
'transfer_numerical_features': transfer_numerical_features,
|
| 670 |
+
'trade_numerical_features': trade_numerical_features,
|
| 671 |
+
'trade_dex_ids': trade_dex_ids,
|
| 672 |
+
'deployer_trade_numerical_features': deployer_trade_numerical_features,
|
| 673 |
+
'trade_direction_ids': trade_direction_ids, # NEW
|
| 674 |
+
'trade_mev_protection_ids': trade_mev_protection_ids, # NEW
|
| 675 |
+
'smart_wallet_trade_numerical_features': smart_wallet_trade_numerical_features,
|
| 676 |
+
'trade_is_bundle_ids': trade_is_bundle_ids, # NEW
|
| 677 |
+
'pool_created_numerical_features': pool_created_numerical_features,
|
| 678 |
+
'pool_created_protocol_ids': pool_created_protocol_ids, # NEW
|
| 679 |
+
'liquidity_change_numerical_features': liquidity_change_numerical_features,
|
| 680 |
+
'liquidity_change_type_ids': liquidity_change_type_ids, # NEW
|
| 681 |
+
'fee_collected_numerical_features': fee_collected_numerical_features, # NEW
|
| 682 |
+
'token_burn_numerical_features': token_burn_numerical_features, # NEW
|
| 683 |
+
'supply_lock_numerical_features': supply_lock_numerical_features, # NEW
|
| 684 |
+
'onchain_snapshot_numerical_features': onchain_snapshot_numerical_features, # NEW
|
| 685 |
+
'boosted_token_numerical_features': boosted_token_numerical_features,
|
| 686 |
+
'trending_token_numerical_features': trending_token_numerical_features,
|
| 687 |
+
'trending_token_source_ids': trending_token_source_ids, # NEW
|
| 688 |
+
'trending_token_timeframe_ids': trending_token_timeframe_ids, # NEW
|
| 689 |
+
'dexboost_paid_numerical_features': dexboost_paid_numerical_features, # NEW
|
| 690 |
+
'dexprofile_updated_flags': dexprofile_updated_flags, # NEW,
|
| 691 |
+
'global_trending_numerical_features': global_trending_numerical_features, # NEW
|
| 692 |
+
'chainsnapshot_numerical_features': chainsnapshot_numerical_features, # NEW
|
| 693 |
+
'lighthousesnapshot_numerical_features': lighthousesnapshot_numerical_features,
|
| 694 |
+
'lighthousesnapshot_protocol_ids': lighthousesnapshot_protocol_ids, # NEW
|
| 695 |
+
'lighthousesnapshot_timeframe_ids': lighthousesnapshot_timeframe_ids, # NEW
|
| 696 |
+
'migrated_protocol_ids': migrated_protocol_ids, # NEW
|
| 697 |
+
'alpha_group_ids': alpha_group_ids, # NEW
|
| 698 |
+
'channel_ids': channel_ids, # NEW
|
| 699 |
+
'exchange_ids': exchange_ids, # NEW
|
| 700 |
+
'holder_snapshot_raw_data': holder_snapshot_raw_data_list, # NEW: Raw data for end-to-end processing
|
| 701 |
+
'textual_event_data': textual_event_data_list, # RENAMED
|
| 702 |
+
# Labels
|
| 703 |
+
'labels': torch.stack([item['labels'] for item in batch]) if batch and 'labels' in batch[0] else None,
|
| 704 |
+
'labels_mask': torch.stack([item['labels_mask'] for item in batch]) if batch and 'labels_mask' in batch[0] else None
|
| 705 |
+
}
|
| 706 |
+
|
| 707 |
+
# Filter out None values (e.g., if no labels provided)
|
| 708 |
+
return {k: v for k, v in collated_batch.items() if v is not None}
|
data/data_fetcher.py
ADDED
|
@@ -0,0 +1,1009 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# data_fetcher.py

import datetime
import time
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple

# We need the vocabulary for mapping IDs
import models.vocabulary as vocab
|
| 9 |
+
|
| 10 |
+
class DataFetcher:
|
| 11 |
+
"""
|
| 12 |
+
A dedicated class to handle all database queries for ClickHouse and Neo4j.
|
| 13 |
+
This keeps data fetching logic separate from the dataset and model.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
    # --- Explicit column definitions for wallet profile & social fetches ---
    # Slowly-changing identity columns read from the `wallet_profiles` table.
    PROFILE_BASE_COLUMNS = [
        'wallet_address',
        'updated_at',
        'first_seen_ts',
        'last_seen_ts',
        'tags',
        'deployed_tokens',
        'funded_from',
        'funded_timestamp',
        'funded_signature',
        'funded_amount'
    ]

    # Numeric activity/PnL columns read from `wallet_profile_metrics`.
    # Layout: global counters first, then the same stat group repeated for
    # the 1d / 7d / 30d windows.
    PROFILE_METRIC_COLUMNS = [
        'balance',
        'transfers_in_count',
        'transfers_out_count',
        'spl_transfers_in_count',
        'spl_transfers_out_count',
        'total_buys_count',
        'total_sells_count',
        'total_winrate',
        'stats_1d_realized_profit_sol',
        'stats_1d_realized_profit_usd',
        'stats_1d_realized_profit_pnl',
        'stats_1d_buy_count',
        'stats_1d_sell_count',
        'stats_1d_transfer_in_count',
        'stats_1d_transfer_out_count',
        'stats_1d_avg_holding_period',
        'stats_1d_total_bought_cost_sol',
        'stats_1d_total_bought_cost_usd',
        'stats_1d_total_sold_income_sol',
        'stats_1d_total_sold_income_usd',
        'stats_1d_total_fee',
        'stats_1d_winrate',
        'stats_1d_tokens_traded',
        'stats_7d_realized_profit_sol',
        'stats_7d_realized_profit_usd',
        'stats_7d_realized_profit_pnl',
        'stats_7d_buy_count',
        'stats_7d_sell_count',
        'stats_7d_transfer_in_count',
        'stats_7d_transfer_out_count',
        'stats_7d_avg_holding_period',
        'stats_7d_total_bought_cost_sol',
        'stats_7d_total_bought_cost_usd',
        'stats_7d_total_sold_income_sol',
        'stats_7d_total_sold_income_usd',
        'stats_7d_total_fee',
        'stats_7d_winrate',
        'stats_7d_tokens_traded',
        'stats_30d_realized_profit_sol',
        'stats_30d_realized_profit_usd',
        'stats_30d_realized_profit_pnl',
        'stats_30d_buy_count',
        'stats_30d_sell_count',
        'stats_30d_transfer_in_count',
        'stats_30d_transfer_out_count',
        'stats_30d_avg_holding_period',
        'stats_30d_total_bought_cost_sol',
        'stats_30d_total_bought_cost_usd',
        'stats_30d_total_sold_income_sol',
        'stats_30d_total_sold_income_usd',
        'stats_30d_total_fee',
        'stats_30d_winrate',
        'stats_30d_tokens_traded'
    ]

    # Full projection used when selecting a complete wallet profile.
    PROFILE_COLUMNS_FOR_QUERY = PROFILE_BASE_COLUMNS + PROFILE_METRIC_COLUMNS

    # Columns read from the `wallet_socials` table.
    SOCIAL_COLUMNS_FOR_QUERY = [
        'wallet_address',
        'pumpfun_username',
        'twitter_username',
        'telegram_channel',
        'kolscan_name',
        'cabalspy_name',
        'axiom_kol_name'
    ]
|
| 97 |
+
def __init__(self, clickhouse_client: Any, neo4j_driver: Any):
|
| 98 |
+
self.db_client = clickhouse_client
|
| 99 |
+
self.graph_client = neo4j_driver
|
| 100 |
+
print("DataFetcher instantiated.")
|
| 101 |
+
|
| 102 |
+
def get_all_mints(self, start_date: Optional[datetime.datetime] = None) -> List[Dict[str, Any]]:
|
| 103 |
+
"""
|
| 104 |
+
Fetches a list of all mint events to serve as dataset samples.
|
| 105 |
+
Can be filtered to only include mints on or after a given start_date.
|
| 106 |
+
"""
|
| 107 |
+
query = "SELECT mint_address, timestamp, creator_address, protocol, token_name, token_symbol, token_uri, total_supply, token_decimals FROM mints"
|
| 108 |
+
params = {}
|
| 109 |
+
where_clauses = []
|
| 110 |
+
|
| 111 |
+
if start_date:
|
| 112 |
+
where_clauses.append("timestamp >= %(start_date)s")
|
| 113 |
+
params['start_date'] = start_date
|
| 114 |
+
|
| 115 |
+
if where_clauses:
|
| 116 |
+
query += " WHERE " + " AND ".join(where_clauses)
|
| 117 |
+
|
| 118 |
+
print(f"INFO: Executing query to get all mints: `{query}` with params: {params}")
|
| 119 |
+
try:
|
| 120 |
+
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
| 121 |
+
if not rows:
|
| 122 |
+
return []
|
| 123 |
+
columns = [col[0] for col in columns_info]
|
| 124 |
+
result = [dict(zip(columns, row)) for row in rows]
|
| 125 |
+
if not result:
|
| 126 |
+
return []
|
| 127 |
+
return result
|
| 128 |
+
except Exception as e:
|
| 129 |
+
print(f"ERROR: Failed to fetch token addresses from ClickHouse: {e}")
|
| 130 |
+
print("INFO: Falling back to mock token addresses for development.")
|
| 131 |
+
return [{'mint_address': 'tknA_real', 'timestamp': datetime.datetime.now(datetime.timezone.utc), 'creator_address': 'addr_Creator_Real', 'protocol': 0}]
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def fetch_mint_record(self, token_address: str) -> Dict[str, Any]:
|
| 135 |
+
"""
|
| 136 |
+
Fetches the raw mint record for a token from the 'mints' table.
|
| 137 |
+
"""
|
| 138 |
+
query = f"SELECT timestamp, creator_address, mint_address, protocol FROM mints WHERE mint_address = '{token_address}' ORDER BY timestamp ASC LIMIT 1"
|
| 139 |
+
print(f"INFO: Executing query to fetch mint record: `{query}`")
|
| 140 |
+
|
| 141 |
+
# Assumes the client returns a list of dicts or can be converted
|
| 142 |
+
# Using column names from your schema
|
| 143 |
+
columns = ['timestamp', 'creator_address', 'mint_address', 'protocol']
|
| 144 |
+
try:
|
| 145 |
+
result = self.db_client.execute(query)
|
| 146 |
+
|
| 147 |
+
if not result or not result[0]:
|
| 148 |
+
raise ValueError(f"No mint event found for token {token_address}")
|
| 149 |
+
|
| 150 |
+
# Convert the tuple result into a dictionary
|
| 151 |
+
record = dict(zip(columns, result[0]))
|
| 152 |
+
return record
|
| 153 |
+
except Exception as e:
|
| 154 |
+
print(f"ERROR: Failed to fetch mint record for {token_address}: {e}")
|
| 155 |
+
print("INFO: Falling back to mock mint record for development.")
|
| 156 |
+
# Fallback for development if DB connection fails
|
| 157 |
+
return {
|
| 158 |
+
'timestamp': datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(days=1),
|
| 159 |
+
'creator_address': 'addr_Creator_Real',
|
| 160 |
+
'mint_address': token_address,
|
| 161 |
+
'protocol': vocab.PROTOCOL_TO_ID.get("Pump V1", 0)
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
def fetch_wallet_profiles(self, wallet_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, Dict[str, Any]]:
|
| 165 |
+
"""
|
| 166 |
+
Convenience wrapper around fetch_wallet_profiles_and_socials for profile-only data.
|
| 167 |
+
"""
|
| 168 |
+
profiles, _ = self.fetch_wallet_profiles_and_socials(wallet_addresses, T_cutoff)
|
| 169 |
+
return profiles
|
| 170 |
+
|
| 171 |
+
def fetch_wallet_socials(self, wallet_addresses: List[str]) -> Dict[str, Dict[str, Any]]:
|
| 172 |
+
"""
|
| 173 |
+
Fetches wallet social records for a list of wallet addresses.
|
| 174 |
+
Returns a dictionary mapping wallet_address to its social data.
|
| 175 |
+
"""
|
| 176 |
+
if not wallet_addresses:
|
| 177 |
+
return {}
|
| 178 |
+
|
| 179 |
+
query = "SELECT * FROM wallet_socials WHERE wallet_address IN %(addresses)s"
|
| 180 |
+
params = {'addresses': wallet_addresses}
|
| 181 |
+
print(f"INFO: Executing query to fetch wallet socials for {len(wallet_addresses)} wallets.")
|
| 182 |
+
|
| 183 |
+
try:
|
| 184 |
+
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
| 185 |
+
if not rows:
|
| 186 |
+
return {}
|
| 187 |
+
|
| 188 |
+
columns = [col[0] for col in columns_info]
|
| 189 |
+
socials = {}
|
| 190 |
+
for row in rows:
|
| 191 |
+
social_dict = dict(zip(columns, row))
|
| 192 |
+
wallet_addr = social_dict.get('wallet_address')
|
| 193 |
+
if wallet_addr:
|
| 194 |
+
socials[wallet_addr] = social_dict
|
| 195 |
+
return socials
|
| 196 |
+
|
| 197 |
+
except Exception as e:
|
| 198 |
+
print(f"ERROR: Failed to fetch wallet socials: {e}")
|
| 199 |
+
print("INFO: Returning empty dictionary for wallet socials.")
|
| 200 |
+
return {}
|
| 201 |
+
|
| 202 |
+
def fetch_wallet_profiles_and_socials(self,
|
| 203 |
+
wallet_addresses: List[str],
|
| 204 |
+
T_cutoff: datetime.datetime) -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]]:
|
| 205 |
+
"""
|
| 206 |
+
Fetches wallet profiles (time-aware) and socials for all requested wallets in a single query.
|
| 207 |
+
Returns two dictionaries: profiles, socials.
|
| 208 |
+
"""
|
| 209 |
+
if not wallet_addresses:
|
| 210 |
+
return {}, {}
|
| 211 |
+
|
| 212 |
+
social_columns = self.SOCIAL_COLUMNS_FOR_QUERY
|
| 213 |
+
|
| 214 |
+
profile_base_cols = self.PROFILE_BASE_COLUMNS
|
| 215 |
+
profile_metric_cols = self.PROFILE_METRIC_COLUMNS
|
| 216 |
+
|
| 217 |
+
profile_base_str = ",\n ".join(profile_base_cols)
|
| 218 |
+
metric_projection_cols = ['wallet_address', 'updated_at'] + profile_metric_cols
|
| 219 |
+
profile_metric_str = ",\n ".join(metric_projection_cols)
|
| 220 |
+
|
| 221 |
+
profile_base_select_cols = [col for col in profile_base_cols if col != 'wallet_address']
|
| 222 |
+
profile_metric_select_cols = [
|
| 223 |
+
col for col in profile_metric_cols if col not in ('wallet_address',)
|
| 224 |
+
]
|
| 225 |
+
social_select_cols = [col for col in social_columns if col != 'wallet_address']
|
| 226 |
+
|
| 227 |
+
select_expressions = []
|
| 228 |
+
for col in profile_base_select_cols:
|
| 229 |
+
select_expressions.append(f"lp.{col} AS profile__{col}")
|
| 230 |
+
for col in profile_metric_select_cols:
|
| 231 |
+
select_expressions.append(f"lm.{col} AS profile__{col}")
|
| 232 |
+
for col in social_select_cols:
|
| 233 |
+
select_expressions.append(f"ws.{col} AS social__{col}")
|
| 234 |
+
select_clause = ""
|
| 235 |
+
if select_expressions:
|
| 236 |
+
select_clause = ",\n " + ",\n ".join(select_expressions)
|
| 237 |
+
|
| 238 |
+
query = f"""
|
| 239 |
+
WITH ranked_profiles AS (
|
| 240 |
+
SELECT
|
| 241 |
+
{profile_base_str},
|
| 242 |
+
ROW_NUMBER() OVER (PARTITION BY wallet_address ORDER BY updated_at DESC) AS rn
|
| 243 |
+
FROM wallet_profiles
|
| 244 |
+
WHERE wallet_address IN %(addresses)s
|
| 245 |
+
),
|
| 246 |
+
latest_profiles AS (
|
| 247 |
+
SELECT
|
| 248 |
+
{profile_base_str}
|
| 249 |
+
FROM ranked_profiles
|
| 250 |
+
WHERE rn = 1
|
| 251 |
+
),
|
| 252 |
+
ranked_metrics AS (
|
| 253 |
+
SELECT
|
| 254 |
+
{profile_metric_str},
|
| 255 |
+
ROW_NUMBER() OVER (PARTITION BY wallet_address ORDER BY updated_at DESC) AS rn
|
| 256 |
+
FROM wallet_profile_metrics
|
| 257 |
+
WHERE
|
| 258 |
+
wallet_address IN %(addresses)s
|
| 259 |
+
AND updated_at <= %(T_cutoff)s
|
| 260 |
+
),
|
| 261 |
+
latest_metrics AS (
|
| 262 |
+
SELECT
|
| 263 |
+
{profile_metric_str}
|
| 264 |
+
FROM ranked_metrics
|
| 265 |
+
WHERE rn = 1
|
| 266 |
+
),
|
| 267 |
+
requested_wallets AS (
|
| 268 |
+
SELECT DISTINCT wallet_address
|
| 269 |
+
FROM (SELECT arrayJoin(%(addresses)s) AS wallet_address)
|
| 270 |
+
)
|
| 271 |
+
SELECT
|
| 272 |
+
rw.wallet_address AS wallet_address
|
| 273 |
+
{select_clause}
|
| 274 |
+
FROM requested_wallets AS rw
|
| 275 |
+
LEFT JOIN latest_profiles AS lp ON rw.wallet_address = lp.wallet_address
|
| 276 |
+
LEFT JOIN latest_metrics AS lm ON rw.wallet_address = lm.wallet_address
|
| 277 |
+
LEFT JOIN wallet_socials AS ws ON rw.wallet_address = ws.wallet_address;
|
| 278 |
+
"""
|
| 279 |
+
|
| 280 |
+
params = {'addresses': wallet_addresses, 'T_cutoff': T_cutoff}
|
| 281 |
+
print(f"INFO: Executing combined query for profiles+socials on {len(wallet_addresses)} wallets.")
|
| 282 |
+
|
| 283 |
+
try:
|
| 284 |
+
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
| 285 |
+
if not rows:
|
| 286 |
+
return {}, {}
|
| 287 |
+
|
| 288 |
+
columns = [col[0] for col in columns_info]
|
| 289 |
+
profiles: Dict[str, Dict[str, Any]] = {}
|
| 290 |
+
socials: Dict[str, Dict[str, Any]] = {}
|
| 291 |
+
|
| 292 |
+
profile_keys = [f"profile__{col}" for col in (profile_base_select_cols + profile_metric_select_cols)]
|
| 293 |
+
social_keys = [f"social__{col}" for col in social_select_cols]
|
| 294 |
+
|
| 295 |
+
for row in rows:
|
| 296 |
+
row_dict = dict(zip(columns, row))
|
| 297 |
+
wallet_addr = row_dict.get('wallet_address')
|
| 298 |
+
if not wallet_addr:
|
| 299 |
+
continue
|
| 300 |
+
|
| 301 |
+
profile_data = {}
|
| 302 |
+
if profile_keys:
|
| 303 |
+
for pref_key in profile_keys:
|
| 304 |
+
if pref_key in row_dict:
|
| 305 |
+
value = row_dict[pref_key]
|
| 306 |
+
profile_data[pref_key.replace('profile__', '')] = value
|
| 307 |
+
|
| 308 |
+
if profile_data and any(value is not None for value in profile_data.values()):
|
| 309 |
+
profile_data['wallet_address'] = wallet_addr
|
| 310 |
+
profiles[wallet_addr] = profile_data
|
| 311 |
+
|
| 312 |
+
social_data = {}
|
| 313 |
+
if social_keys:
|
| 314 |
+
for pref_key in social_keys:
|
| 315 |
+
if pref_key in row_dict:
|
| 316 |
+
value = row_dict[pref_key]
|
| 317 |
+
social_data[pref_key.replace('social__', '')] = value
|
| 318 |
+
|
| 319 |
+
if social_data and any(value is not None for value in social_data.values()):
|
| 320 |
+
social_data['wallet_address'] = wallet_addr
|
| 321 |
+
socials[wallet_addr] = social_data
|
| 322 |
+
|
| 323 |
+
return profiles, socials
|
| 324 |
+
|
| 325 |
+
except Exception as e:
|
| 326 |
+
print(f"ERROR: Combined profile/social query failed: {e}")
|
| 327 |
+
print("INFO: Falling back to separate queries.")
|
| 328 |
+
profiles = self.fetch_wallet_profiles(wallet_addresses, T_cutoff)
|
| 329 |
+
socials = self.fetch_wallet_socials(wallet_addresses)
|
| 330 |
+
return profiles, socials
|
| 331 |
+
|
| 332 |
+
def fetch_wallet_holdings(self, wallet_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, List[Dict[str, Any]]]:
|
| 333 |
+
"""
|
| 334 |
+
Fetches top 3 wallet holding records for a list of wallet addresses that were active at T_cutoff.
|
| 335 |
+
Returns a dictionary mapping wallet_address to a LIST of its holding data.
|
| 336 |
+
"""
|
| 337 |
+
if not wallet_addresses:
|
| 338 |
+
return {}
|
| 339 |
+
|
| 340 |
+
# --- NEW: Time-aware query based on user's superior logic ---
|
| 341 |
+
# 1. For each holding, find the latest state at or before T_cutoff.
|
| 342 |
+
# 2. Filter for holdings where the balance was greater than 0.
|
| 343 |
+
# 3. Rank these active holdings by USD volume and take the top 3 per wallet.
|
| 344 |
+
query = """
|
| 345 |
+
WITH point_in_time_holdings AS (
|
| 346 |
+
SELECT
|
| 347 |
+
*,
|
| 348 |
+
COALESCE(history_bought_cost_sol, 0) + COALESCE(history_sold_income_sol, 0) AS total_volume_usd,
|
| 349 |
+
ROW_NUMBER() OVER(PARTITION BY wallet_address, mint_address ORDER BY updated_at DESC) as rn_per_holding
|
| 350 |
+
FROM wallet_holdings
|
| 351 |
+
WHERE
|
| 352 |
+
wallet_address IN %(addresses)s
|
| 353 |
+
AND updated_at <= %(T_cutoff)s
|
| 354 |
+
),
|
| 355 |
+
ranked_active_holdings AS (
|
| 356 |
+
SELECT *,
|
| 357 |
+
ROW_NUMBER() OVER(PARTITION BY wallet_address ORDER BY total_volume_usd DESC) as rn_per_wallet
|
| 358 |
+
FROM point_in_time_holdings
|
| 359 |
+
WHERE rn_per_holding = 1 AND current_balance > 0
|
| 360 |
+
)
|
| 361 |
+
SELECT *
|
| 362 |
+
FROM ranked_active_holdings
|
| 363 |
+
WHERE rn_per_wallet <= 3;
|
| 364 |
+
"""
|
| 365 |
+
params = {'addresses': wallet_addresses, 'T_cutoff': T_cutoff}
|
| 366 |
+
print(f"INFO: Executing query to fetch wallet holdings for {len(wallet_addresses)} wallets.")
|
| 367 |
+
|
| 368 |
+
try:
|
| 369 |
+
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
| 370 |
+
if not rows:
|
| 371 |
+
return {}
|
| 372 |
+
|
| 373 |
+
columns = [col[0] for col in columns_info]
|
| 374 |
+
holdings = defaultdict(list)
|
| 375 |
+
for row in rows:
|
| 376 |
+
holding_dict = dict(zip(columns, row))
|
| 377 |
+
wallet_addr = holding_dict.get('wallet_address')
|
| 378 |
+
if wallet_addr:
|
| 379 |
+
holdings[wallet_addr].append(holding_dict)
|
| 380 |
+
return dict(holdings)
|
| 381 |
+
|
| 382 |
+
except Exception as e:
|
| 383 |
+
print(f"ERROR: Failed to fetch wallet holdings: {e}")
|
| 384 |
+
print("INFO: Returning empty dictionary for wallet holdings.")
|
| 385 |
+
return {}
|
| 386 |
+
|
| 387 |
+
def fetch_graph_links(self,
|
| 388 |
+
initial_addresses: List[str],
|
| 389 |
+
T_cutoff: datetime.datetime,
|
| 390 |
+
max_degrees: int = 2) -> Tuple[Dict[str, str], Dict[str, Dict[str, Any]]]:
|
| 391 |
+
"""
|
| 392 |
+
Fetches graph links from Neo4j, traversing up to a max degree of separation.
|
| 393 |
+
|
| 394 |
+
Args:
|
| 395 |
+
initial_addresses: A list of starting wallet or token addresses.
|
| 396 |
+
max_degrees: The maximum number of hops to traverse in the graph.
|
| 397 |
+
|
| 398 |
+
Returns:
|
| 399 |
+
A tuple containing:
|
| 400 |
+
- A dictionary mapping entity addresses to their type ('Wallet' or 'Token').
|
| 401 |
+
- A dictionary of aggregated links, structured for the GraphUpdater.
|
| 402 |
+
"""
|
| 403 |
+
if not initial_addresses:
|
| 404 |
+
return set(), {}
|
| 405 |
+
|
| 406 |
+
cutoff_ts = int(T_cutoff.timestamp())
|
| 407 |
+
|
| 408 |
+
print(f"INFO: Fetching graph links up to {max_degrees} degrees for {len(initial_addresses)} initial entities...")
|
| 409 |
+
try:
|
| 410 |
+
with self.graph_client.session() as session:
|
| 411 |
+
all_entities = {addr: 'Token' for addr in initial_addresses} # Assume initial are tokens
|
| 412 |
+
newly_found_entities = set(initial_addresses)
|
| 413 |
+
aggregated_links = defaultdict(lambda: {'links': [], 'edges': []})
|
| 414 |
+
|
| 415 |
+
for i in range(max_degrees):
|
| 416 |
+
if not newly_found_entities:
|
| 417 |
+
break
|
| 418 |
+
|
| 419 |
+
print(f" - Degree {i+1}: Traversing from {len(newly_found_entities)} new entities...")
|
| 420 |
+
|
| 421 |
+
# Cypher query to find direct neighbors of the current frontier
|
| 422 |
+
query = """
|
| 423 |
+
MATCH (a)-[r]-(b)
|
| 424 |
+
WHERE a.address IN $addresses
|
| 425 |
+
RETURN a.address AS source_address, type(r) AS link_type, properties(r) AS link_props, b.address AS dest_address, labels(b)[0] AS dest_type
|
| 426 |
+
"""
|
| 427 |
+
params = {'addresses': list(newly_found_entities)}
|
| 428 |
+
result = session.run(query, params)
|
| 429 |
+
|
| 430 |
+
current_degree_new_entities = set()
|
| 431 |
+
for record in result:
|
| 432 |
+
link_type = record['link_type']
|
| 433 |
+
link_props = dict(record['link_props'])
|
| 434 |
+
link_ts_raw = link_props.get('timestamp')
|
| 435 |
+
try:
|
| 436 |
+
link_ts = int(link_ts_raw)
|
| 437 |
+
except (TypeError, ValueError):
|
| 438 |
+
continue
|
| 439 |
+
if link_ts > cutoff_ts:
|
| 440 |
+
continue
|
| 441 |
+
source_addr = record['source_address']
|
| 442 |
+
dest_addr = record['dest_address']
|
| 443 |
+
dest_type = record['dest_type']
|
| 444 |
+
|
| 445 |
+
# Add the link and edge data
|
| 446 |
+
aggregated_links[link_type]['links'].append(link_props)
|
| 447 |
+
aggregated_links[link_type]['edges'].append((source_addr, dest_addr))
|
| 448 |
+
|
| 449 |
+
# If we found a new entity, add it to the set for the next iteration
|
| 450 |
+
if dest_addr not in all_entities.keys():
|
| 451 |
+
current_degree_new_entities.add(dest_addr)
|
| 452 |
+
all_entities[dest_addr] = dest_type
|
| 453 |
+
|
| 454 |
+
newly_found_entities = current_degree_new_entities
|
| 455 |
+
|
| 456 |
+
return all_entities, dict(aggregated_links)
|
| 457 |
+
except Exception as e:
|
| 458 |
+
print(f"ERROR: Failed to fetch graph links from Neo4j: {e}")
|
| 459 |
+
return {addr: 'Token' for addr in initial_addresses}, {}
|
| 460 |
+
|
| 461 |
+
def fetch_token_data(self, token_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, Dict[str, Any]]:
|
| 462 |
+
"""
|
| 463 |
+
Fetches the latest token data for each address at or before T_cutoff.
|
| 464 |
+
Returns a dictionary mapping token_address to its data.
|
| 465 |
+
"""
|
| 466 |
+
if not token_addresses:
|
| 467 |
+
return {}
|
| 468 |
+
|
| 469 |
+
# --- NEW: Time-aware query for historical token data ---
|
| 470 |
+
query = """
|
| 471 |
+
WITH ranked_tokens AS (
|
| 472 |
+
SELECT
|
| 473 |
+
*,
|
| 474 |
+
ROW_NUMBER() OVER (PARTITION BY token_address ORDER BY updated_at DESC) as rn
|
| 475 |
+
FROM tokens
|
| 476 |
+
WHERE
|
| 477 |
+
token_address IN %(addresses)s
|
| 478 |
+
AND updated_at <= %(T_cutoff)s
|
| 479 |
+
)
|
| 480 |
+
SELECT token_address, name, symbol, token_uri, protocol, total_supply, decimals
|
| 481 |
+
FROM ranked_tokens
|
| 482 |
+
WHERE rn = 1;
|
| 483 |
+
"""
|
| 484 |
+
params = {'addresses': token_addresses, 'T_cutoff': T_cutoff}
|
| 485 |
+
print(f"INFO: Executing query to fetch token data for {len(token_addresses)} tokens.")
|
| 486 |
+
|
| 487 |
+
try:
|
| 488 |
+
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
| 489 |
+
|
| 490 |
+
if not rows:
|
| 491 |
+
return {}
|
| 492 |
+
|
| 493 |
+
# Get column names from the query result description
|
| 494 |
+
columns = [col[0] for col in columns_info]
|
| 495 |
+
|
| 496 |
+
tokens = {}
|
| 497 |
+
for row in rows:
|
| 498 |
+
token_dict = dict(zip(columns, row))
|
| 499 |
+
token_addr = token_dict.get('token_address')
|
| 500 |
+
if token_addr:
|
| 501 |
+
# The 'tokens' table in the schema has 'token_address' but the
|
| 502 |
+
# collator expects 'address'. We'll add it for compatibility.
|
| 503 |
+
token_dict['address'] = token_addr
|
| 504 |
+
tokens[token_addr] = token_dict
|
| 505 |
+
return tokens
|
| 506 |
+
|
| 507 |
+
except Exception as e:
|
| 508 |
+
print(f"ERROR: Failed to fetch token data: {e}")
|
| 509 |
+
print("INFO: Returning empty dictionary for token data.")
|
| 510 |
+
return {}
|
| 511 |
+
|
| 512 |
+
def fetch_deployed_token_details(self, token_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, Dict[str, Any]]:
|
| 513 |
+
"""
|
| 514 |
+
Fetches historical details for deployed tokens at or before T_cutoff.
|
| 515 |
+
"""
|
| 516 |
+
if not token_addresses:
|
| 517 |
+
return {}
|
| 518 |
+
|
| 519 |
+
# --- NEW: Time-aware query for historical deployed token details ---
|
| 520 |
+
query = """
|
| 521 |
+
WITH ranked_tokens AS (
|
| 522 |
+
SELECT
|
| 523 |
+
*,
|
| 524 |
+
ROW_NUMBER() OVER (PARTITION BY token_address ORDER BY updated_at DESC) as rn
|
| 525 |
+
FROM tokens
|
| 526 |
+
WHERE
|
| 527 |
+
token_address IN %(addresses)s
|
| 528 |
+
AND updated_at <= %(T_cutoff)s
|
| 529 |
+
),
|
| 530 |
+
ranked_token_metrics AS (
|
| 531 |
+
SELECT
|
| 532 |
+
token_address,
|
| 533 |
+
ath_price_usd,
|
| 534 |
+
ROW_NUMBER() OVER (PARTITION BY token_address ORDER BY updated_at DESC) as rn
|
| 535 |
+
FROM token_metrics
|
| 536 |
+
WHERE
|
| 537 |
+
token_address IN %(addresses)s
|
| 538 |
+
AND updated_at <= %(T_cutoff)s
|
| 539 |
+
),
|
| 540 |
+
latest_tokens AS (
|
| 541 |
+
SELECT *
|
| 542 |
+
FROM ranked_tokens
|
| 543 |
+
WHERE rn = 1
|
| 544 |
+
),
|
| 545 |
+
latest_token_metrics AS (
|
| 546 |
+
SELECT *
|
| 547 |
+
FROM ranked_token_metrics
|
| 548 |
+
WHERE rn = 1
|
| 549 |
+
)
|
| 550 |
+
SELECT
|
| 551 |
+
lt.token_address,
|
| 552 |
+
lt.created_at,
|
| 553 |
+
lt.updated_at,
|
| 554 |
+
ltm.ath_price_usd,
|
| 555 |
+
lt.total_supply,
|
| 556 |
+
lt.decimals,
|
| 557 |
+
(lt.launchpad != lt.protocol) AS has_migrated
|
| 558 |
+
FROM latest_tokens AS lt
|
| 559 |
+
LEFT JOIN latest_token_metrics AS ltm
|
| 560 |
+
ON lt.token_address = ltm.token_address;
|
| 561 |
+
"""
|
| 562 |
+
params = {'addresses': token_addresses, 'T_cutoff': T_cutoff}
|
| 563 |
+
print(f"INFO: Executing query to fetch deployed token details for {len(token_addresses)} tokens.")
|
| 564 |
+
|
| 565 |
+
try:
|
| 566 |
+
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
| 567 |
+
if not rows:
|
| 568 |
+
return {}
|
| 569 |
+
|
| 570 |
+
columns = [col[0] for col in columns_info]
|
| 571 |
+
token_details = {row[0]: dict(zip(columns, row)) for row in rows}
|
| 572 |
+
return token_details
|
| 573 |
+
except Exception as e:
|
| 574 |
+
print(f"ERROR: Failed to fetch deployed token details: {e}")
|
| 575 |
+
return {}
|
| 576 |
+
|
| 577 |
+
def fetch_trades_for_token(self, token_address: str, T_cutoff: datetime.datetime, count_threshold: int, early_limit: int, recent_limit: int) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
|
| 578 |
+
"""
|
| 579 |
+
Fetches trades for a token, using a 3-part H/B/H strategy if the total count exceeds a threshold.
|
| 580 |
+
Returns three lists: early_trades, middle_trades, recent_trades.
|
| 581 |
+
"""
|
| 582 |
+
if not token_address:
|
| 583 |
+
return [], [], []
|
| 584 |
+
|
| 585 |
+
params = {'token_address': token_address, 'T_cutoff': T_cutoff}
|
| 586 |
+
|
| 587 |
+
# 1. Get the total count of trades for the token before the cutoff
|
| 588 |
+
count_query = "SELECT count() FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s"
|
| 589 |
+
try:
|
| 590 |
+
total_trades = self.db_client.execute(count_query, params)[0][0]
|
| 591 |
+
print(f"INFO: Found {total_trades} total trades for token {token_address} before {T_cutoff}.")
|
| 592 |
+
except Exception as e:
|
| 593 |
+
print(f"ERROR: Could not count trades for token {token_address}: {e}")
|
| 594 |
+
return [], [], []
|
| 595 |
+
|
| 596 |
+
# 2. Decide which query to use based on the count
|
| 597 |
+
if total_trades < count_threshold:
|
| 598 |
+
print("INFO: Fetching all trades (count is below H/B/H threshold).")
|
| 599 |
+
query = "SELECT * FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s ORDER BY timestamp ASC"
|
| 600 |
+
try:
|
| 601 |
+
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
| 602 |
+
if not rows: return [], [], []
|
| 603 |
+
columns = [col[0] for col in columns_info]
|
| 604 |
+
all_trades = [dict(zip(columns, row)) for row in rows]
|
| 605 |
+
# When not using HBH, all trades are considered "early"
|
| 606 |
+
return all_trades, [], []
|
| 607 |
+
except Exception as e:
|
| 608 |
+
print(f"ERROR: Failed to fetch all trades for token {token_address}: {e}")
|
| 609 |
+
return [], [], []
|
| 610 |
+
|
| 611 |
+
# 3. Use the H/B/H strategy if the count is high
|
| 612 |
+
print("INFO: Fetching trades using 3-part High-Def/Blurry/High-Def strategy.")
|
| 613 |
+
try:
|
| 614 |
+
# Fetch Early (High-Def)
|
| 615 |
+
early_query = "SELECT * FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s ORDER BY timestamp ASC LIMIT %(limit)s"
|
| 616 |
+
early_rows, early_cols_info = self.db_client.execute(early_query, {'token_address': token_address, 'T_cutoff': T_cutoff, 'limit': early_limit}, with_column_types=True)
|
| 617 |
+
early_trades = [dict(zip([c[0] for c in early_cols_info], r)) for r in early_rows] if early_rows else []
|
| 618 |
+
|
| 619 |
+
# Fetch Recent (High-Def)
|
| 620 |
+
recent_query = "SELECT * FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s ORDER BY timestamp DESC LIMIT %(limit)s"
|
| 621 |
+
recent_rows, recent_cols_info = self.db_client.execute(recent_query, {'token_address': token_address, 'T_cutoff': T_cutoff, 'limit': recent_limit}, with_column_types=True)
|
| 622 |
+
recent_trades = [dict(zip([c[0] for c in recent_cols_info], r)) for r in recent_rows] if recent_rows else []
|
| 623 |
+
recent_trades.reverse() # Order ASC
|
| 624 |
+
|
| 625 |
+
# Fetch Middle (Blurry - successful trades only)
|
| 626 |
+
middle_trades = []
|
| 627 |
+
if early_trades and recent_trades:
|
| 628 |
+
start_middle_ts = early_trades[-1]['timestamp']
|
| 629 |
+
end_middle_ts = recent_trades[0]['timestamp']
|
| 630 |
+
if start_middle_ts < end_middle_ts:
|
| 631 |
+
middle_query = """
|
| 632 |
+
SELECT * FROM trades
|
| 633 |
+
WHERE base_address = %(token_address)s
|
| 634 |
+
AND success = true
|
| 635 |
+
AND timestamp > %(start_ts)s
|
| 636 |
+
AND timestamp < %(end_ts)s
|
| 637 |
+
ORDER BY timestamp ASC
|
| 638 |
+
"""
|
| 639 |
+
middle_params = {'token_address': token_address, 'start_ts': start_middle_ts, 'end_ts': end_middle_ts}
|
| 640 |
+
middle_rows, middle_cols_info = self.db_client.execute(middle_query, middle_params, with_column_types=True)
|
| 641 |
+
middle_trades = [dict(zip([c[0] for c in middle_cols_info], r)) for r in middle_rows] if middle_rows else []
|
| 642 |
+
|
| 643 |
+
return early_trades, middle_trades, recent_trades
|
| 644 |
+
|
| 645 |
+
except Exception as e:
|
| 646 |
+
print(f"ERROR: Failed to fetch H/B/H trades for token {token_address}: {e}")
|
| 647 |
+
return [], [], []
|
| 648 |
+
|
| 649 |
+
def fetch_future_trades_for_token(self,
|
| 650 |
+
token_address: str,
|
| 651 |
+
start_ts: datetime.datetime,
|
| 652 |
+
end_ts: datetime.datetime) -> List[Dict[str, Any]]:
|
| 653 |
+
"""
|
| 654 |
+
Fetches successful trades for a token in the window (start_ts, end_ts].
|
| 655 |
+
Used for constructing label targets beyond the cutoff.
|
| 656 |
+
"""
|
| 657 |
+
if not token_address or start_ts is None or end_ts is None or start_ts >= end_ts:
|
| 658 |
+
return []
|
| 659 |
+
|
| 660 |
+
query = """
|
| 661 |
+
SELECT *
|
| 662 |
+
FROM trades
|
| 663 |
+
WHERE base_address = %(token_address)s
|
| 664 |
+
AND success = true
|
| 665 |
+
AND timestamp > %(start_ts)s
|
| 666 |
+
AND timestamp <= %(end_ts)s
|
| 667 |
+
ORDER BY timestamp ASC
|
| 668 |
+
"""
|
| 669 |
+
params = {
|
| 670 |
+
'token_address': token_address,
|
| 671 |
+
'start_ts': start_ts,
|
| 672 |
+
'end_ts': end_ts
|
| 673 |
+
}
|
| 674 |
+
|
| 675 |
+
try:
|
| 676 |
+
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
| 677 |
+
if not rows:
|
| 678 |
+
return []
|
| 679 |
+
columns = [col[0] for col in columns_info]
|
| 680 |
+
return [dict(zip(columns, row)) for row in rows]
|
| 681 |
+
except Exception as e:
|
| 682 |
+
print(f"ERROR: Failed to fetch future trades for token {token_address}: {e}")
|
| 683 |
+
return []
|
| 684 |
+
|
| 685 |
+
def fetch_transfers_for_token(self, token_address: str, T_cutoff: datetime.datetime, min_amount_threshold: float = 10_000_000) -> List[Dict[str, Any]]:
|
| 686 |
+
"""
|
| 687 |
+
Fetches all transfers for a token before T_cutoff, filtering out small amounts.
|
| 688 |
+
"""
|
| 689 |
+
if not token_address:
|
| 690 |
+
return []
|
| 691 |
+
|
| 692 |
+
query = """
|
| 693 |
+
SELECT * FROM transfers
|
| 694 |
+
WHERE mint_address = %(token_address)s
|
| 695 |
+
AND timestamp <= %(T_cutoff)s
|
| 696 |
+
AND amount_decimal >= %(min_amount)s
|
| 697 |
+
ORDER BY timestamp ASC
|
| 698 |
+
"""
|
| 699 |
+
params = {'token_address': token_address, 'T_cutoff': T_cutoff, 'min_amount': min_amount_threshold}
|
| 700 |
+
print(f"INFO: Fetching significant transfers for {token_address} (amount >= {min_amount_threshold}).")
|
| 701 |
+
|
| 702 |
+
try:
|
| 703 |
+
# This query no longer uses H/B/H, it fetches all significant transfers
|
| 704 |
+
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
| 705 |
+
if not rows: return []
|
| 706 |
+
columns = [col[0] for col in columns_info]
|
| 707 |
+
return [dict(zip(columns, row)) for row in rows]
|
| 708 |
+
except Exception as e:
|
| 709 |
+
print(f"ERROR: Failed to fetch transfers for token {token_address}: {e}")
|
| 710 |
+
return []
|
| 711 |
+
|
| 712 |
+
def fetch_pool_creations_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
|
| 713 |
+
"""
|
| 714 |
+
Fetches pool creation records where the token is the base asset.
|
| 715 |
+
"""
|
| 716 |
+
if not token_address:
|
| 717 |
+
return []
|
| 718 |
+
|
| 719 |
+
query = """
|
| 720 |
+
SELECT
|
| 721 |
+
signature,
|
| 722 |
+
timestamp,
|
| 723 |
+
slot,
|
| 724 |
+
success,
|
| 725 |
+
error,
|
| 726 |
+
priority_fee,
|
| 727 |
+
protocol,
|
| 728 |
+
creator_address,
|
| 729 |
+
pool_address,
|
| 730 |
+
base_address,
|
| 731 |
+
quote_address,
|
| 732 |
+
lp_token_address,
|
| 733 |
+
initial_base_liquidity,
|
| 734 |
+
initial_quote_liquidity,
|
| 735 |
+
base_decimals,
|
| 736 |
+
quote_decimals
|
| 737 |
+
FROM pool_creations
|
| 738 |
+
WHERE base_address = %(token_address)s
|
| 739 |
+
AND timestamp <= %(T_cutoff)s
|
| 740 |
+
ORDER BY timestamp ASC
|
| 741 |
+
"""
|
| 742 |
+
params = {'token_address': token_address, 'T_cutoff': T_cutoff}
|
| 743 |
+
print(f"INFO: Fetching pool creation events for {token_address}.")
|
| 744 |
+
|
| 745 |
+
try:
|
| 746 |
+
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
| 747 |
+
if not rows:
|
| 748 |
+
return []
|
| 749 |
+
columns = [col[0] for col in columns_info]
|
| 750 |
+
return [dict(zip(columns, row)) for row in rows]
|
| 751 |
+
except Exception as e:
|
| 752 |
+
print(f"ERROR: Failed to fetch pool creations for token {token_address}: {e}")
|
| 753 |
+
return []
|
| 754 |
+
|
| 755 |
+
def fetch_liquidity_changes_for_pools(self, pool_addresses: List[str], T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
|
| 756 |
+
"""
|
| 757 |
+
Fetches liquidity change records for the given pools up to T_cutoff.
|
| 758 |
+
"""
|
| 759 |
+
if not pool_addresses:
|
| 760 |
+
return []
|
| 761 |
+
|
| 762 |
+
query = """
|
| 763 |
+
SELECT
|
| 764 |
+
signature,
|
| 765 |
+
timestamp,
|
| 766 |
+
slot,
|
| 767 |
+
success,
|
| 768 |
+
error,
|
| 769 |
+
priority_fee,
|
| 770 |
+
protocol,
|
| 771 |
+
change_type,
|
| 772 |
+
lp_provider,
|
| 773 |
+
pool_address,
|
| 774 |
+
base_amount,
|
| 775 |
+
quote_amount
|
| 776 |
+
FROM liquidity
|
| 777 |
+
WHERE pool_address IN %(pool_addresses)s
|
| 778 |
+
AND timestamp <= %(T_cutoff)s
|
| 779 |
+
ORDER BY timestamp ASC
|
| 780 |
+
"""
|
| 781 |
+
params = {'pool_addresses': pool_addresses, 'T_cutoff': T_cutoff}
|
| 782 |
+
print(f"INFO: Fetching liquidity change events for {len(pool_addresses)} pool(s).")
|
| 783 |
+
|
| 784 |
+
try:
|
| 785 |
+
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
| 786 |
+
if not rows:
|
| 787 |
+
return []
|
| 788 |
+
columns = [col[0] for col in columns_info]
|
| 789 |
+
return [dict(zip(columns, row)) for row in rows]
|
| 790 |
+
except Exception as e:
|
| 791 |
+
print(f"ERROR: Failed to fetch liquidity changes for pools {pool_addresses}: {e}")
|
| 792 |
+
return []
|
| 793 |
+
|
| 794 |
+
def fetch_fee_collections_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
|
| 795 |
+
"""
|
| 796 |
+
Fetches fee collection events where the token appears as either token_0 or token_1.
|
| 797 |
+
"""
|
| 798 |
+
if not token_address:
|
| 799 |
+
return []
|
| 800 |
+
|
| 801 |
+
query = """
|
| 802 |
+
SELECT
|
| 803 |
+
timestamp,
|
| 804 |
+
signature,
|
| 805 |
+
slot,
|
| 806 |
+
success,
|
| 807 |
+
error,
|
| 808 |
+
priority_fee,
|
| 809 |
+
protocol,
|
| 810 |
+
recipient_address,
|
| 811 |
+
token_0_mint_address,
|
| 812 |
+
token_0_amount,
|
| 813 |
+
token_1_mint_address,
|
| 814 |
+
token_1_amount
|
| 815 |
+
FROM fee_collections
|
| 816 |
+
WHERE (token_0_mint_address = %(token)s OR token_1_mint_address = %(token)s)
|
| 817 |
+
AND timestamp <= %(T_cutoff)s
|
| 818 |
+
ORDER BY timestamp ASC
|
| 819 |
+
"""
|
| 820 |
+
params = {'token': token_address, 'T_cutoff': T_cutoff}
|
| 821 |
+
print(f"INFO: Fetching fee collection events for {token_address}.")
|
| 822 |
+
|
| 823 |
+
try:
|
| 824 |
+
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
| 825 |
+
if not rows:
|
| 826 |
+
return []
|
| 827 |
+
columns = [col[0] for col in columns_info]
|
| 828 |
+
return [dict(zip(columns, row)) for row in rows]
|
| 829 |
+
except Exception as e:
|
| 830 |
+
print(f"ERROR: Failed to fetch fee collections for token {token_address}: {e}")
|
| 831 |
+
return []
|
| 832 |
+
|
| 833 |
+
def fetch_migrations_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
|
| 834 |
+
"""
|
| 835 |
+
Fetches migration records for a given token up to T_cutoff.
|
| 836 |
+
"""
|
| 837 |
+
if not token_address:
|
| 838 |
+
return []
|
| 839 |
+
query = """
|
| 840 |
+
SELECT
|
| 841 |
+
timestamp,
|
| 842 |
+
signature,
|
| 843 |
+
slot,
|
| 844 |
+
success,
|
| 845 |
+
error,
|
| 846 |
+
priority_fee,
|
| 847 |
+
protocol,
|
| 848 |
+
mint_address,
|
| 849 |
+
virtual_pool_address,
|
| 850 |
+
pool_address,
|
| 851 |
+
migrated_base_liquidity,
|
| 852 |
+
migrated_quote_liquidity
|
| 853 |
+
FROM migrations
|
| 854 |
+
WHERE mint_address = %(token)s
|
| 855 |
+
AND timestamp <= %(T_cutoff)s
|
| 856 |
+
ORDER BY timestamp ASC
|
| 857 |
+
"""
|
| 858 |
+
params = {'token': token_address, 'T_cutoff': T_cutoff}
|
| 859 |
+
print(f"INFO: Fetching migrations for {token_address}.")
|
| 860 |
+
try:
|
| 861 |
+
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
| 862 |
+
if not rows:
|
| 863 |
+
return []
|
| 864 |
+
columns = [col[0] for col in columns_info]
|
| 865 |
+
return [dict(zip(columns, row)) for row in rows]
|
| 866 |
+
except Exception as e:
|
| 867 |
+
print(f"ERROR: Failed to fetch migrations for token {token_address}: {e}")
|
| 868 |
+
return []
|
| 869 |
+
|
| 870 |
+
def fetch_burns_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
|
| 871 |
+
"""
|
| 872 |
+
Fetches burn events for a given token up to T_cutoff.
|
| 873 |
+
Schema: burns(timestamp, signature, slot, success, error, priority_fee, mint_address, source, amount, amount_decimal, source_balance)
|
| 874 |
+
"""
|
| 875 |
+
if not token_address:
|
| 876 |
+
return []
|
| 877 |
+
|
| 878 |
+
query = """
|
| 879 |
+
SELECT
|
| 880 |
+
timestamp,
|
| 881 |
+
signature,
|
| 882 |
+
slot,
|
| 883 |
+
success,
|
| 884 |
+
error,
|
| 885 |
+
priority_fee,
|
| 886 |
+
mint_address,
|
| 887 |
+
source,
|
| 888 |
+
amount,
|
| 889 |
+
amount_decimal,
|
| 890 |
+
source_balance
|
| 891 |
+
FROM burns
|
| 892 |
+
WHERE mint_address = %(token)s
|
| 893 |
+
AND timestamp <= %(T_cutoff)s
|
| 894 |
+
ORDER BY timestamp ASC
|
| 895 |
+
"""
|
| 896 |
+
params = {'token': token_address, 'T_cutoff': T_cutoff}
|
| 897 |
+
print(f"INFO: Fetching burn events for {token_address}.")
|
| 898 |
+
|
| 899 |
+
try:
|
| 900 |
+
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
| 901 |
+
if not rows:
|
| 902 |
+
return []
|
| 903 |
+
columns = [col[0] for col in columns_info]
|
| 904 |
+
return [dict(zip(columns, row)) for row in rows]
|
| 905 |
+
except Exception as e:
|
| 906 |
+
print(f"ERROR: Failed to fetch burns for token {token_address}: {e}")
|
| 907 |
+
return []
|
| 908 |
+
|
| 909 |
+
def fetch_supply_locks_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
|
| 910 |
+
"""
|
| 911 |
+
Fetches supply lock events for a given token up to T_cutoff.
|
| 912 |
+
Schema: supply_locks(timestamp, signature, slot, success, error, priority_fee, protocol, contract_address, sender, recipient, mint_address, total_locked_amount, final_unlock_timestamp)
|
| 913 |
+
"""
|
| 914 |
+
if not token_address:
|
| 915 |
+
return []
|
| 916 |
+
|
| 917 |
+
query = """
|
| 918 |
+
SELECT
|
| 919 |
+
timestamp,
|
| 920 |
+
signature,
|
| 921 |
+
slot,
|
| 922 |
+
success,
|
| 923 |
+
error,
|
| 924 |
+
priority_fee,
|
| 925 |
+
protocol,
|
| 926 |
+
contract_address,
|
| 927 |
+
sender,
|
| 928 |
+
recipient,
|
| 929 |
+
mint_address,
|
| 930 |
+
total_locked_amount,
|
| 931 |
+
final_unlock_timestamp
|
| 932 |
+
FROM supply_locks
|
| 933 |
+
WHERE mint_address = %(token)s
|
| 934 |
+
AND timestamp <= %(T_cutoff)s
|
| 935 |
+
ORDER BY timestamp ASC
|
| 936 |
+
"""
|
| 937 |
+
params = {'token': token_address, 'T_cutoff': T_cutoff}
|
| 938 |
+
print(f"INFO: Fetching supply lock events for {token_address}.")
|
| 939 |
+
|
| 940 |
+
try:
|
| 941 |
+
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
| 942 |
+
if not rows:
|
| 943 |
+
return []
|
| 944 |
+
columns = [col[0] for col in columns_info]
|
| 945 |
+
return [dict(zip(columns, row)) for row in rows]
|
| 946 |
+
except Exception as e:
|
| 947 |
+
print(f"ERROR: Failed to fetch supply locks for token {token_address}: {e}")
|
| 948 |
+
return []
|
| 949 |
+
|
| 950 |
+
def fetch_token_holders_for_snapshot(self, token_address: str, T_cutoff: datetime.datetime, limit: int = 200) -> List[Dict[str, Any]]:
|
| 951 |
+
"""
|
| 952 |
+
Fetch top holders for a token at or before T_cutoff for snapshot purposes.
|
| 953 |
+
Returns rows with wallet_address and current_balance (>0), ordered by balance desc.
|
| 954 |
+
"""
|
| 955 |
+
if not token_address:
|
| 956 |
+
return []
|
| 957 |
+
query = """
|
| 958 |
+
WITH point_in_time_holdings AS (
|
| 959 |
+
SELECT *,
|
| 960 |
+
ROW_NUMBER() OVER(PARTITION BY wallet_address, mint_address ORDER BY updated_at DESC) as rn_per_holding
|
| 961 |
+
FROM wallet_holdings
|
| 962 |
+
WHERE mint_address = %(token)s AND updated_at <= %(T_cutoff)s
|
| 963 |
+
)
|
| 964 |
+
SELECT wallet_address, current_balance
|
| 965 |
+
FROM point_in_time_holdings
|
| 966 |
+
WHERE rn_per_holding = 1 AND current_balance > 0
|
| 967 |
+
ORDER BY current_balance DESC
|
| 968 |
+
LIMIT %(limit)s;
|
| 969 |
+
"""
|
| 970 |
+
params = {'token': token_address, 'T_cutoff': T_cutoff, 'limit': int(limit)}
|
| 971 |
+
print(f"INFO: Fetching top holders for snapshot for {token_address} (limit {limit}).")
|
| 972 |
+
try:
|
| 973 |
+
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
| 974 |
+
if not rows:
|
| 975 |
+
return []
|
| 976 |
+
columns = [col[0] for col in columns_info]
|
| 977 |
+
return [dict(zip(columns, row)) for row in rows]
|
| 978 |
+
except Exception as e:
|
| 979 |
+
print(f"ERROR: Failed to fetch token holders for {token_address}: {e}")
|
| 980 |
+
return []
|
| 981 |
+
|
| 982 |
+
def fetch_total_holders_count_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> int:
|
| 983 |
+
"""
|
| 984 |
+
Returns the total number of wallets holding the token (current_balance > 0)
|
| 985 |
+
at or before T_cutoff.
|
| 986 |
+
"""
|
| 987 |
+
if not token_address:
|
| 988 |
+
return 0
|
| 989 |
+
query = """
|
| 990 |
+
WITH point_in_time_holdings AS (
|
| 991 |
+
SELECT *,
|
| 992 |
+
ROW_NUMBER() OVER(PARTITION BY wallet_address, mint_address ORDER BY updated_at DESC) as rn_per_holding
|
| 993 |
+
FROM wallet_holdings
|
| 994 |
+
WHERE mint_address = %(token)s AND updated_at <= %(T_cutoff)s
|
| 995 |
+
)
|
| 996 |
+
SELECT count()
|
| 997 |
+
FROM point_in_time_holdings
|
| 998 |
+
WHERE rn_per_holding = 1 AND current_balance > 0;
|
| 999 |
+
"""
|
| 1000 |
+
params = {'token': token_address, 'T_cutoff': T_cutoff}
|
| 1001 |
+
print(f"INFO: Counting total holders for {token_address} at cutoff.")
|
| 1002 |
+
try:
|
| 1003 |
+
rows = self.db_client.execute(query, params)
|
| 1004 |
+
if not rows:
|
| 1005 |
+
return 0
|
| 1006 |
+
return int(rows[0][0])
|
| 1007 |
+
except Exception as e:
|
| 1008 |
+
print(f"ERROR: Failed to count total holders for token {token_address}: {e}")
|
| 1009 |
+
return 0
|
data/data_loader.py
ADDED
|
@@ -0,0 +1,1657 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
import datetime
|
| 4 |
+
import requests
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
from torch.utils.data import Dataset, IterableDataset
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from typing import List, Dict, Any, Optional, Union, Tuple
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import numpy as np
|
| 11 |
+
from bisect import bisect_left, bisect_right
|
| 12 |
+
|
| 13 |
+
# We need the vocabulary for IDs and the processor for the pooler
|
| 14 |
+
import models.vocabulary as vocab
|
| 15 |
+
from models.multi_modal_processor import MultiModalEncoder
|
| 16 |
+
from data.data_fetcher import DataFetcher # NEW: Import the DataFetcher
|
| 17 |
+
|
| 18 |
+
# --- NEW: Hardcoded decimals for common quote tokens ---
|
| 19 |
+
QUOTE_TOKEN_DECIMALS = {
|
| 20 |
+
'So11111111111111111111111111111111111111112': 9, # SOL
|
| 21 |
+
'EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v': 6, # USDC
|
| 22 |
+
'Es9vMFrzaCERmJfrF4H2FYD4KCoNkY11McCe8BenwNYB': 6, # USDT
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
# --- NEW: Hyperparameters for trade event classification ---
|
| 26 |
+
LARGE_TRADE_USD_THRESHOLD = 100.0
|
| 27 |
+
LARGE_TRADE_SUPPLY_PCT_THRESHOLD = 0.03 # 3% of supply
|
| 28 |
+
LARGE_TRANSFER_SUPPLY_PCT_THRESHOLD = 0.03 # 3% of supply
|
| 29 |
+
SMART_WALLET_PNL_THRESHOLD = 3.0 # 300% PNL
|
| 30 |
+
SMART_WALLET_USD_THRESHOLD = 20000.0
|
| 31 |
+
|
| 32 |
+
# --- NEW: Hyperparameters for H/B/H Event Fetching ---
|
| 33 |
+
EVENT_COUNT_THRESHOLD_FOR_HBH = 30000 # If total events > this, use H/B/H
|
| 34 |
+
HBH_EARLY_EVENT_LIMIT = 10000
|
| 35 |
+
HBH_RECENT_EVENT_LIMIT = 15000
|
| 36 |
+
|
| 37 |
+
# --- NEW: OHLC Sequence Length Constant ---
|
| 38 |
+
OHLC_SEQ_LEN = 300 # 4 minutes of chart
|
| 39 |
+
|
| 40 |
+
MIN_AMOUNT_TRANSFER_SUPPLY = 0.0 # 1.0% of total supply
|
| 41 |
+
|
| 42 |
+
# Interval for HolderSnapshot events (seconds)
|
| 43 |
+
HOLDER_SNAPSHOT_INTERVAL_SEC = 300
|
| 44 |
+
HOLDER_SNAPSHOT_TOP_K = 200
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class EmbeddingPooler:
|
| 48 |
+
"""
|
| 49 |
+
A helper class to manage the collection and encoding of unique text/image items
|
| 50 |
+
for a single data sample.
|
| 51 |
+
"""
|
| 52 |
+
def __init__(self):
|
| 53 |
+
self.pool_map = {}
|
| 54 |
+
self.next_idx = 1 # 0 is padding
|
| 55 |
+
|
| 56 |
+
def get_idx(self, item: Any) -> int:
|
| 57 |
+
"""
|
| 58 |
+
Returns a unique index for a given item (string or image).
|
| 59 |
+
- Returns 0 for None or empty strings.
|
| 60 |
+
- Deduplicates identical text and image objects.
|
| 61 |
+
"""
|
| 62 |
+
if item is None:
|
| 63 |
+
return 0
|
| 64 |
+
|
| 65 |
+
# Handle text case
|
| 66 |
+
if isinstance(item, str):
|
| 67 |
+
if not item.strip(): # skip empty or whitespace-only strings
|
| 68 |
+
return 0
|
| 69 |
+
key = item.strip() # use normalized text key
|
| 70 |
+
elif isinstance(item, Image.Image):
|
| 71 |
+
key = id(item) # unique memory address for images
|
| 72 |
+
else:
|
| 73 |
+
key = item # fallback: use object itself if hashable
|
| 74 |
+
|
| 75 |
+
if key not in self.pool_map:
|
| 76 |
+
self.pool_map[key] = {'item': item, 'idx': self.next_idx}
|
| 77 |
+
self.next_idx += 1
|
| 78 |
+
|
| 79 |
+
return self.pool_map[key]['idx']
|
| 80 |
+
|
| 81 |
+
def get_all_items(self) -> List[Dict[str, Any]]:
|
| 82 |
+
"""
|
| 83 |
+
Returns a list of all unique items, sorted by their assigned index.
|
| 84 |
+
"""
|
| 85 |
+
if not self.pool_map:
|
| 86 |
+
return []
|
| 87 |
+
return sorted(self.pool_map.values(), key=lambda x: x['idx'])
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class OracleDataset(Dataset):
|
| 91 |
+
"""
|
| 92 |
+
Dataset class for the Oracle model. It fetches, processes, and structures
|
| 93 |
+
all on-chain and off-chain data for a given token to create a comprehensive
|
| 94 |
+
input sequence for the model.
|
| 95 |
+
"""
|
| 96 |
+
def __init__(self,
|
| 97 |
+
data_fetcher: DataFetcher, # NEW: Pass the fetcher instance
|
| 98 |
+
horizons_seconds: List[int] = [],
|
| 99 |
+
quantiles: List[float] = [],
|
| 100 |
+
max_samples: Optional[int] = None,
|
| 101 |
+
ohlc_stats_path: Union[str, Path] = "./data/ohlc_stats.npz", # NEW: Add stats path parameter
|
| 102 |
+
token_allowlist: Optional[List[str]] = None,
|
| 103 |
+
t_cutoff_seconds: int = 60,
|
| 104 |
+
cache_dir: Optional[Union[str, Path]] = None,
|
| 105 |
+
start_date: Optional[datetime.datetime] = None,
|
| 106 |
+
min_trade_usd: float = 0.0):
|
| 107 |
+
|
| 108 |
+
# --- NEW: Create a persistent requests session for efficiency ---
|
| 109 |
+
self.http_session = requests.Session()
|
| 110 |
+
|
| 111 |
+
self.fetcher = data_fetcher
|
| 112 |
+
self.cache_dir = Path(cache_dir) if cache_dir else None
|
| 113 |
+
|
| 114 |
+
# If a fetcher is provided, we can determine the number of samples.
|
| 115 |
+
# Otherwise, we are likely in a test mode where __len__ might not be called
|
| 116 |
+
# or is used with a mock length.
|
| 117 |
+
self.t_cutoff_seconds = max(0, int(t_cutoff_seconds or 0))
|
| 118 |
+
self.token_allowlist = set(token_allowlist) if token_allowlist else None
|
| 119 |
+
|
| 120 |
+
if self.cache_dir and self.cache_dir.is_dir():
|
| 121 |
+
print(f"INFO: Initializing dataset in offline (cached) mode from: {self.cache_dir}")
|
| 122 |
+
# Scan for cached files to determine length
|
| 123 |
+
self.cached_files = sorted(self.cache_dir.glob("sample_*.pt"), key=lambda p: int(p.stem.split('_')[1]))
|
| 124 |
+
if not self.cached_files:
|
| 125 |
+
raise RuntimeError(f"Cache directory '{self.cache_dir}' provided but contains no 'sample_*.pt' files.")
|
| 126 |
+
|
| 127 |
+
self.num_samples = len(self.cached_files)
|
| 128 |
+
if max_samples is not None:
|
| 129 |
+
self.num_samples = min(max_samples, self.num_samples)
|
| 130 |
+
self.cached_files = self.cached_files[:self.num_samples]
|
| 131 |
+
print(f"INFO: Found {self.num_samples} cached samples to use.")
|
| 132 |
+
self.sampled_mints = [] # Not needed in cached mode
|
| 133 |
+
self.available_mints = []
|
| 134 |
+
|
| 135 |
+
elif self.fetcher:
|
| 136 |
+
print(f"INFO: Initializing dataset in online (generation) mode...")
|
| 137 |
+
self.available_mints = self.fetcher.get_all_mints(start_date=start_date)
|
| 138 |
+
if not self.available_mints:
|
| 139 |
+
raise RuntimeError("Dataset initialization failed: no mint records returned from data fetcher.")
|
| 140 |
+
if self.token_allowlist:
|
| 141 |
+
filtered_mints = [
|
| 142 |
+
mint for mint in self.available_mints
|
| 143 |
+
if mint.get('mint_address') in self.token_allowlist
|
| 144 |
+
]
|
| 145 |
+
if not filtered_mints:
|
| 146 |
+
raise RuntimeError(f"No mint records matched the provided token allowlist: {token_allowlist}")
|
| 147 |
+
self.available_mints = filtered_mints
|
| 148 |
+
|
| 149 |
+
total_mints = len(self.available_mints)
|
| 150 |
+
if max_samples is None:
|
| 151 |
+
self.num_samples = total_mints
|
| 152 |
+
self.sampled_mints = self.available_mints
|
| 153 |
+
else:
|
| 154 |
+
self.num_samples = min(max_samples, total_mints)
|
| 155 |
+
if self.num_samples < total_mints:
|
| 156 |
+
print(f"INFO: Limiting dataset to first {self.num_samples} of {total_mints} available mints.")
|
| 157 |
+
self.sampled_mints = self.available_mints[:self.num_samples]
|
| 158 |
+
else:
|
| 159 |
+
self.available_mints = []
|
| 160 |
+
self.sampled_mints = []
|
| 161 |
+
self.num_samples = 1 if max_samples is None else max_samples
|
| 162 |
+
|
| 163 |
+
self.horizons_seconds = sorted(set(horizons_seconds))
|
| 164 |
+
self.quantiles = quantiles
|
| 165 |
+
self.num_outputs = len(self.horizons_seconds) * len(self.quantiles)
|
| 166 |
+
|
| 167 |
+
# --- NEW: Load global OHLC normalization stats ---
|
| 168 |
+
stats_path = Path(ohlc_stats_path)
|
| 169 |
+
if not stats_path.exists():
|
| 170 |
+
raise FileNotFoundError(f"Required OHLC stats file not found: {stats_path}")
|
| 171 |
+
stats = np.load(stats_path)
|
| 172 |
+
self.ohlc_price_mean = float(stats.get('mean_price_usd', 0.0))
|
| 173 |
+
self.ohlc_price_std = float(stats.get('std_price_usd', 1.0)) or 1.0
|
| 174 |
+
|
| 175 |
+
self.min_trade_usd = min_trade_usd
|
| 176 |
+
|
| 177 |
+
def __len__(self) -> int:
|
| 178 |
+
return self.num_samples
|
| 179 |
+
|
| 180 |
+
def _normalize_price_series(self, values: List[float]) -> List[float]:
|
| 181 |
+
if not values:
|
| 182 |
+
return values
|
| 183 |
+
denom = self.ohlc_price_std if abs(self.ohlc_price_std) > 1e-9 else 1.0
|
| 184 |
+
return [(float(v) - self.ohlc_price_mean) / denom for v in values]
|
| 185 |
+
|
| 186 |
+
def _compute_future_return_labels(self,
|
| 187 |
+
anchor_price: Optional[float],
|
| 188 |
+
anchor_timestamp: int,
|
| 189 |
+
price_series: List[Tuple[int, float]]) -> Tuple[torch.Tensor, torch.Tensor, List[Dict[str, Any]]]:
|
| 190 |
+
if self.num_outputs == 0:
|
| 191 |
+
return torch.zeros(0), torch.zeros(0), []
|
| 192 |
+
|
| 193 |
+
if anchor_price is None or abs(anchor_price) < 1e-9 or not price_series:
|
| 194 |
+
return torch.zeros(self.num_outputs), torch.zeros(self.num_outputs), []
|
| 195 |
+
|
| 196 |
+
ts_list = [int(entry[0]) for entry in price_series]
|
| 197 |
+
price_list = [float(entry[1]) for entry in price_series]
|
| 198 |
+
if not ts_list:
|
| 199 |
+
return torch.zeros(self.num_outputs), torch.zeros(self.num_outputs), []
|
| 200 |
+
|
| 201 |
+
last_ts = ts_list[-1]
|
| 202 |
+
|
| 203 |
+
label_values: List[float] = []
|
| 204 |
+
mask_values: List[float] = []
|
| 205 |
+
debug_entries: List[Dict[str, Any]] = []
|
| 206 |
+
|
| 207 |
+
for horizon in self.horizons_seconds:
|
| 208 |
+
target_ts = anchor_timestamp + horizon
|
| 209 |
+
if target_ts > last_ts:
|
| 210 |
+
horizon_mask = 0.0
|
| 211 |
+
horizon_return = 0.0
|
| 212 |
+
future_price = None
|
| 213 |
+
else:
|
| 214 |
+
idx = bisect_right(ts_list, target_ts) - 1
|
| 215 |
+
if idx < 0:
|
| 216 |
+
horizon_mask = 0.0
|
| 217 |
+
horizon_return = 0.0
|
| 218 |
+
future_price = None
|
| 219 |
+
else:
|
| 220 |
+
future_price = price_list[idx]
|
| 221 |
+
horizon_return = (future_price - anchor_price) / anchor_price
|
| 222 |
+
horizon_return = max(min(horizon_return, 10.0), -10.0)
|
| 223 |
+
horizon_mask = 1.0
|
| 224 |
+
|
| 225 |
+
for _ in self.quantiles:
|
| 226 |
+
label_values.append(horizon_return)
|
| 227 |
+
mask_values.append(horizon_mask)
|
| 228 |
+
debug_entries.append({
|
| 229 |
+
'horizon': horizon,
|
| 230 |
+
'target_ts': target_ts,
|
| 231 |
+
'future_price': future_price,
|
| 232 |
+
'return': horizon_return,
|
| 233 |
+
'mask': horizon_mask
|
| 234 |
+
})
|
| 235 |
+
|
| 236 |
+
return (torch.tensor(label_values, dtype=torch.float32),
|
| 237 |
+
torch.tensor(mask_values, dtype=torch.float32),
|
| 238 |
+
debug_entries)
|
| 239 |
+
|
| 240 |
+
    def _generate_onchain_snapshots(
        self,
        token_address: str,
        t0_timestamp: int,
        T_cutoff: datetime.datetime,
        interval_sec: int,
        trade_events: List[Dict[str, Any]],
        transfer_events: List[Dict[str, Any]],
        aggregation_trades: List[Dict[str, Any]],
        wallet_data: Dict[str, Any],
        total_supply_dec: float,
        _register_event_fn
    ) -> None:
        """
        Emit periodic 'HolderSnapshot' and 'OnChain_Snapshot' events between
        t0_timestamp and T_cutoff at `interval_sec` spacing.

        For every snapshot time it (a) fetches the top holders from the fetcher
        and registers a HolderSnapshot, then (b) aggregates window statistics
        (buys/sells, volume, fees, smart traders, KOLs, holder growth, market
        cap, concentration percentages) into an OnChain_Snapshot. Both are
        passed to `_register_event_fn(event, sort_key)`; nothing is returned.

        NOTE(review): each snapshot performs two fetcher round-trips
        (holders + total holder count) — confirm this is acceptable per sample.
        """
        # Prepare helper sets and maps (static sniper set based on earliest buyers)
        # "Snipers" here = the first 70 distinct successful buyers, in time order.
        all_buy_trades = sorted([e for e in trade_events if e.get('trade_direction') == 0 and e.get('success', False)], key=lambda x: x['timestamp'])
        sniper_wallets = []
        seen_buyers = set()
        for e in all_buy_trades:
            wa = e['wallet_address']
            if wa not in seen_buyers:
                sniper_wallets.append(wa)
                seen_buyers.add(wa)
                if len(sniper_wallets) >= 70:
                    break
        sniper_set = set(sniper_wallets)

        # Social-profile keys whose presence marks a wallet as a KOL.
        KOL_NAME_KEYS = ['kolscan_name', 'cabalspy_name', 'axiom_kol_name']

        # Build time arrays for price lookup (parallel timestamp/price lists).
        agg_ts = [int(t['timestamp']) for t in aggregation_trades] if aggregation_trades else []
        agg_price = [float(t.get('price_usd', 0.0) or 0.0) for t in aggregation_trades] if aggregation_trades else []

        start_ts = t0_timestamp
        # Prefer the instance's timestamp-ordering helper when it exists so the
        # snapshot grid matches the event ordering used elsewhere.
        end_ts = int(self._timestamp_to_order_value(T_cutoff)) if hasattr(self, '_timestamp_to_order_value') else int(T_cutoff.timestamp())
        if end_ts - start_ts < interval_sec:
            # Window shorter than one interval: emit a single snapshot at the cutoff.
            oc_snapshot_times = [end_ts]
        else:
            steps = (end_ts - start_ts) // interval_sec
            oc_snapshot_times = [start_ts + i * interval_sec for i in range(1, steps + 1)]

        buyers_seen_global = set()   # distinct buyers seen across all windows so far
        prev_holders_count = 0       # previous snapshot's holder count, for growth delta
        for ts_value in oc_snapshot_times:
            # Events strictly inside the half-open window (window_start, ts_value].
            window_start = ts_value - interval_sec
            trades_win = [e for e in trade_events if e.get('success', False) and window_start < e['timestamp'] <= ts_value]
            xfers_win = [e for e in transfer_events if window_start < e['timestamp'] <= ts_value]

            # Per-snapshot holder distribution at ts_value
            cutoff_dt_ts = datetime.datetime.fromtimestamp(ts_value, tz=datetime.timezone.utc)
            holder_records_ts = self.fetcher.fetch_token_holders_for_snapshot(token_address, cutoff_dt_ts, limit=HOLDER_SNAPSHOT_TOP_K)
            holder_entries_ts = []
            for rec in holder_records_ts:
                addr = rec.get('wallet_address')
                try:
                    bal = float(rec.get('current_balance', 0.0) or 0.0)
                except (TypeError, ValueError):
                    bal = 0.0
                # Holding share of total supply; zero-supply tokens yield 0.
                pct = (bal / total_supply_dec) if total_supply_dec and total_supply_dec > 0 else 0.0
                if addr and pct > 0.0:
                    holder_entries_ts.append({'wallet': addr, 'holding_pct': pct})
            holder_entries_ts.sort(key=lambda d: d['holding_pct'], reverse=True)

            # Emit HolderSnapshot for this ts_value
            hs_event = {
                'event_type': 'HolderSnapshot',
                'timestamp': int(ts_value),
                'relative_ts': ts_value - t0_timestamp,
                'holders': holder_entries_ts
            }
            # Fall back to a plain tuple sort key when the helper is absent.
            _register_event_fn(hs_event, self._event_execution_sort_key(ts_value, signature='HolderSnapshot') if hasattr(self, '_event_execution_sort_key') else (ts_value, 0, 0, 0, 'HolderSnapshot'))

            holder_pct_map_ts = {d['wallet']: d['holding_pct'] for d in holder_entries_ts}
            top10_holder_pct = sum(d['holding_pct'] for d in holder_entries_ts[:10]) if holder_entries_ts else 0.0

            # Cumulative sets up to ts_value.
            # "Rat" wallets: destinations of any transfer so far (assumption carried
            # from the caller's terminology — confirm definition upstream).
            rat_set_ts = set(ev['destination_wallet_address'] for ev in transfer_events if ev['timestamp'] <= ts_value)
            bundle_buyer_set_ts = set(e['wallet_address'] for e in trade_events if e.get('is_bundle') and e.get('trade_direction') == 0 and e.get('success', False) and e['timestamp'] <= ts_value)

            # Window aggregates: direction 0 == buy, 1 == sell.
            buy_count = sum(1 for e in trades_win if e.get('trade_direction') == 0)
            sell_count = sum(1 for e in trades_win if e.get('trade_direction') == 1)
            volume = sum(float(e.get('total_usd', 0.0) or 0.0) for e in trades_win)
            total_txns = len(trades_win) + len(xfers_win)
            global_fees_paid = sum(float(e.get('priority_fee', 0.0) or 0.0) for e in trades_win) + \
                               sum(float(e.get('priority_fee', 0.0) or 0.0) for e in xfers_win)

            # Distinct smart-money traders active in this window.
            smart_trader_addrs = set(e['wallet_address'] for e in trades_win if e.get('event_type') == 'SmartWallet_Trade')
            smart_traders = len(smart_trader_addrs)

            # Distinct KOL wallets active in this window (any KOL name key set).
            kol_addrs = set()
            for e in trades_win:
                wa = e['wallet_address']
                soc = wallet_data.get(wa, {}).get('socials', {})
                if any(soc.get(k) for k in KOL_NAME_KEYS if soc):
                    kol_addrs.add(wa)
            kols = len(kol_addrs)

            # Track first-time buyers across the whole series.
            new_buyers = [e['wallet_address'] for e in trades_win if e.get('trade_direction') == 0 and e['wallet_address'] not in buyers_seen_global]
            for wa in new_buyers:
                buyers_seen_global.add(wa)

            # Compute growth against previous snapshot endpoint.
            end_dt = datetime.datetime.fromtimestamp(ts_value, tz=datetime.timezone.utc)
            holders_end = self.fetcher.fetch_total_holders_count_for_token(token_address, end_dt)
            total_holders = float(holders_end)
            delta_holders = holders_end - prev_holders_count
            holder_growth_rate = float(delta_holders)
            prev_holders_count = holders_end

            # Market cap from last price at or before ts (reverse linear scan).
            last_price_usd = 0.0
            if agg_ts:
                for i in range(len(agg_ts) - 1, -1, -1):
                    if agg_ts[i] <= ts_value:
                        last_price_usd = agg_price[i]
                        break
            current_market_cap = float(last_price_usd) * float(total_supply_dec)

            oc_event = {
                'event_type': 'OnChain_Snapshot',
                'timestamp': int(ts_value),
                'relative_ts': ts_value - t0_timestamp,
                'total_holders': total_holders,
                'smart_traders': float(smart_traders),
                'kols': float(kols),
                'holder_growth_rate': float(holder_growth_rate),
                'top_10_holder_pct': float(top10_holder_pct),
                'sniper_holding_pct': float(sum(holder_pct_map_ts.get(wa, 0.0) for wa in sniper_set)),
                'rat_wallets_holding_pct': float(sum(holder_pct_map_ts.get(wa, 0.0) for wa in rat_set_ts)),
                'bundle_holding_pct': float(sum(holder_pct_map_ts.get(wa, 0.0) for wa in bundle_buyer_set_ts)),
                'current_market_cap': float(current_market_cap),
                'volume': float(volume),
                'buy_count': float(buy_count),
                'sell_count': float(sell_count),
                'total_txns': float(total_txns),
                'global_fees_paid': float(global_fees_paid)
            }
            _register_event_fn(oc_event, self._event_execution_sort_key(ts_value, signature='OnChain_Snapshot') if hasattr(self, '_event_execution_sort_key') else (ts_value, 0, 0, 0, 'OnChain_Snapshot'))
|
| 377 |
+
|
| 378 |
+
    def _calculate_deployed_token_stats(self, profiles: Dict[str, Dict[str, Any]], T_cutoff: datetime.datetime):
        """
        Calculates aggregate statistics for wallets based on the tokens they've deployed.
        This method modifies the `profiles` dictionary in-place.

        For each wallet profile it sets:
          - 'deployed_tokens_count'
          - 'deployed_tokens_migrated_pct'
          - 'deployed_tokens_avg_lifetime_sec'
          - 'deployed_tokens_avg_peak_mc_usd'
          - 'deployed_tokens_median_peak_mc_usd'

        Details are fetched point-in-time via `fetch_deployed_token_details`
        so nothing after T_cutoff leaks in.
        """
        if not profiles: return

        for addr, profile in profiles.items():
            deployed_tokens = profile.get('deployed_tokens', [])

            # 1. Deployed Tokens Count
            count = len(deployed_tokens)
            profile['deployed_tokens_count'] = float(count)

            if count == 0:
                # No deployments: zero out every derived stat and move on.
                profile['deployed_tokens_migrated_pct'] = 0.0
                profile['deployed_tokens_avg_lifetime_sec'] = 0.0
                profile['deployed_tokens_avg_peak_mc_usd'] = 0.0
                profile['deployed_tokens_median_peak_mc_usd'] = 0.0
                continue

            # --- NEW: Fetch deployed token details with point-in-time logic ---
            deployed_token_details = self.fetcher.fetch_deployed_token_details(deployed_tokens, T_cutoff)

            # Collect stats for all deployed tokens of this wallet
            lifetimes = []
            peak_mcs = []
            migrated_count = 0
            for token_addr in deployed_tokens:
                details = deployed_token_details.get(token_addr)
                if not details: continue

                if details.get('has_migrated'):
                    migrated_count += 1

                # NOTE(review): direct indexing assumes 'updated_at'/'created_at'
                # are always present (datetime) in a details row — confirm the
                # fetcher guarantees this, otherwise this raises KeyError.
                lifetimes.append((details['updated_at'] - details['created_at']).total_seconds())
                peak_mcs.append(details.get('ath_price_usd', 0.0) * details.get('total_supply', 0.0) / (10**details.get('decimals', 9))) # Simplified MC

            # 2. Migrated Pct (fraction of ALL deployed tokens, including ones
            # with no details row).
            profile['deployed_tokens_migrated_pct'] = (migrated_count / count) if count > 0 else 0.0
            # 3. Avg Lifetime
            profile['deployed_tokens_avg_lifetime_sec'] = torch.mean(torch.tensor(lifetimes)).item() if lifetimes else 0.0
            # 4. Avg & Median Peak MC
            profile['deployed_tokens_avg_peak_mc_usd'] = torch.mean(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0
            # NOTE: torch.median returns the LOWER of the two middle values for
            # even-length inputs (unlike statistics.median, which averages them).
            profile['deployed_tokens_median_peak_mc_usd'] = torch.median(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0
|
| 423 |
+
|
| 424 |
+
    def _process_wallet_data(self, wallet_addresses: List[str], token_data: Dict[str, Any], pooler: EmbeddingPooler, T_cutoff: datetime.datetime) -> tuple[Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]]:
        """
        Fetches and processes profile, social, and holdings data for a list of wallets.
        Uses a T_cutoff to ensure data is point-in-time accurate.

        Returns a tuple ``(final_wallets, all_token_data)`` where
        ``final_wallets[addr] = {'profile', 'socials', 'holdings'}`` (the shape
        the WalletEncoder consumes) and ``all_token_data`` is the input
        ``token_data`` merged with any tokens discovered through holdings.
        Wallets without a pre-cutoff profile are dropped entirely.
        """
        if not wallet_addresses:
            return {}, token_data

        print(f"INFO: Processing wallet data for {len(wallet_addresses)} unique wallets...")
        # Bulk fetch all data
        profiles, socials = self.fetcher.fetch_wallet_profiles_and_socials(wallet_addresses, T_cutoff)
        holdings = self.fetcher.fetch_wallet_holdings(wallet_addresses, T_cutoff)

        # Keep only wallets that have a profile as of T_cutoff.
        valid_wallets = [addr for addr in wallet_addresses if addr in profiles]
        dropped_wallets = set(wallet_addresses) - set(valid_wallets)
        if dropped_wallets:
            print(f"INFO: Skipping {len(dropped_wallets)} wallets with no profile before cutoff.")
        if not valid_wallets:
            print("INFO: All wallets were graph-only or appeared after cutoff; skipping wallet processing for this token.")
            return {}, token_data
        wallet_addresses = valid_wallets

        # --- NEW: Collect all unique mints from holdings to fetch their data ---
        all_holding_mints = set()
        for wallet_addr in wallet_addresses:
            for holding_item in holdings.get(wallet_addr, []):
                if 'mint_address' in holding_item:
                    all_holding_mints.add(holding_item['mint_address'])

        # --- NEW: Process all discovered tokens with point-in-time logic ---
        # 1. Fetch raw data for all newly found tokens from holdings.
        # 2. Process this raw data to get embedding indices and add to the pooler.
        # Note: _process_token_data is designed to take a list and return a dict.
        # We pass the addresses and let it handle the fetching and processing internally.
        processed_new_tokens = self._process_token_data(list(all_holding_mints), pooler, T_cutoff)
        # 3. Merge the fully processed new tokens with the existing main token data.
        # (`or {}` guards against _process_token_data returning None.)
        all_token_data = {**token_data, **(processed_new_tokens or {})}

        # --- NEW: Calculate deployed token stats using point-in-time logic ---
        # Mutates `profiles` in place.
        self._calculate_deployed_token_stats(profiles, T_cutoff)

        # --- Assemble the final wallet dictionary ---
        # This structure is exactly what the WalletEncoder expects.
        final_wallets = {}
        for addr in wallet_addresses:

            # --- Define all expected numerical keys for a profile ---
            # This prevents KeyErrors if the DB returns a partial profile.
            expected_profile_keys = [
                'age', 'deployed_tokens_count', 'deployed_tokens_migrated_pct',
                'deployed_tokens_avg_lifetime_sec', 'deployed_tokens_avg_peak_mc_usd',
                'deployed_tokens_median_peak_mc_usd', 'balance', 'transfers_in_count',
                'transfers_out_count', 'spl_transfers_in_count', 'spl_transfers_out_count',
                'total_buys_count', 'total_sells_count', 'total_winrate',
                'stats_1d_realized_profit_sol', 'stats_1d_realized_profit_pnl', 'stats_1d_buy_count',
                'stats_1d_sell_count', 'stats_1d_transfer_in_count', 'stats_1d_transfer_out_count',
                'stats_1d_avg_holding_period', 'stats_1d_total_bought_cost_sol', 'stats_1d_total_sold_income_sol',
                'stats_1d_total_fee', 'stats_1d_winrate', 'stats_1d_tokens_traded',
                'stats_7d_realized_profit_sol', 'stats_7d_realized_profit_pnl', 'stats_7d_buy_count', 'stats_7d_sell_count', 'stats_7d_transfer_in_count', 'stats_7d_transfer_out_count', 'stats_7d_avg_holding_period', 'stats_7d_total_bought_cost_sol', 'stats_7d_total_sold_income_sol', 'stats_7d_total_fee', 'stats_7d_winrate', 'stats_7d_tokens_traded'
            ]
            # --- FIXED: Use .get() and provide a default empty dict if not found ---
            # --- NEW: If a wallet profile doesn't exist in the DB, skip it entirely. ---
            # This removes the old logic that created a placeholder profile with zeroed-out features.
            # "If it doesn't exist, it doesn't exist."
            profile_data = profiles.get(addr, None)
            if not profile_data:
                print(f"INFO: Wallet {addr} found in graph but has no profile in DB. Skipping this wallet.")
                continue

            # --- NEW: Ensure all expected keys exist in the fetched profile ---
            for key in expected_profile_keys:
                profile_data.setdefault(key, 0.0) # Use 0.0 as a safe default for any missing numerical key

            social_data = socials.get(addr, {})

            # --- NEW: Derive boolean social flags based on schema ---
            social_data['has_pf_profile'] = bool(social_data.get('pumpfun_username'))
            social_data['has_twitter'] = bool(social_data.get('twitter_username'))
            social_data['has_telegram'] = bool(social_data.get('telegram_channel'))
            # 'is_exchange_wallet' is not in the schema, so we'll default to False for now.
            # This is a feature that would likely come from a 'tags' column or a separate service.
            social_data['is_exchange_wallet'] = 'exchange_wallet' in profile_data.get('tags', [])

            # --- NEW: Calculate 'age' based on user's logic ---
            funded_ts = profile_data.get('funded_timestamp', 0)
            if funded_ts and funded_ts > 0:
                # Calculate age in seconds from the funding timestamp
                # NOTE(review): age is measured against *now*, not T_cutoff —
                # confirm this is intentional (it leaks wall-clock time into features).
                age_seconds = int(datetime.datetime.now(datetime.timezone.utc).timestamp()) - funded_ts
            else:
                # Fallback for wallets older than our DB window, as requested
                # 5 months * 30 days/month * 24 hours/day * 3600 seconds/hour
                age_seconds = 12_960_000

            # Add the calculated age to the profile data that the WalletEncoder will receive
            profile_data['age'] = float(age_seconds)

            # Get the username and add it to the embedding pooler
            username = social_data.get('pumpfun_username') or social_data.get('twitter_username') or social_data.get('kolscan_name')

            if isinstance(username, str) and username.strip():
                social_data['username_emb_idx'] = pooler.get_idx(username.strip())
            else:
                social_data['username_emb_idx'] = 0 # means "no embedding"

            # --- NEW: Filter holdings and calculate derived features ---
            # We create a new list `valid_wallet_holdings` to ensure that if a holding's
            # token is invalid (filtered out by _process_token_data), the entire holding
            # row is removed and not passed to the WalletEncoder.
            original_holdings = holdings.get(addr, [])
            valid_wallet_holdings = []
            now_ts = datetime.datetime.now(datetime.timezone.utc)
            for holding_item in original_holdings:
                # 1. Calculate holding_time
                start_ts = holding_item.get('start_holding_at')
                mint_addr = holding_item.get('mint_address')
                token_info = all_token_data.get(mint_addr)

                if not token_info:
                    print(f"INFO: Skipping holding for token {mint_addr} in wallet {addr} because token data is invalid/missing.")
                    continue

                end_ts = holding_item.get('end_holding_at')
                if not start_ts:
                    holding_item['holding_time'] = 0.0
                else:
                    # Open positions (no end timestamp) run until "now".
                    end_ts = end_ts or now_ts
                    holding_item['holding_time'] = (end_ts - start_ts).total_seconds()

                # 2. Calculate balance_pct_to_supply
                if token_info and token_info.get('total_supply', 0) > 0:
                    # Convert raw supply to decimal units before taking the ratio.
                    total_supply = token_info['total_supply'] / (10**token_info.get('decimals', 9))
                    current_balance = holding_item.get('current_balance', 0.0)
                    holding_item['balance_pct_to_supply'] = (current_balance / total_supply) if total_supply > 0 else 0.0
                else:
                    holding_item['balance_pct_to_supply'] = 0.0

                # 3. --- NEW: Calculate bought_amount_sol_pct_to_native_balance ---
                # This uses the historically accurate native balance from the profile.
                wallet_native_balance = profile_data.get('balance', 0.0)
                bought_cost_sol = holding_item.get('history_bought_cost_sol', 0.0)
                if wallet_native_balance > 1e-9: # Use a small epsilon to avoid division by zero
                    holding_item['bought_amount_sol_pct_to_native_balance'] = bought_cost_sol / wallet_native_balance
                else:
                    holding_item['bought_amount_sol_pct_to_native_balance'] = 0.0

                valid_wallet_holdings.append(holding_item)


            final_wallets[addr] = {
                'profile': profile_data,
                'socials': social_data,
                'holdings': valid_wallet_holdings
            }

        return final_wallets, all_token_data
|
| 579 |
+
|
| 580 |
+
    def _process_token_data(self, token_addresses: List[str], pooler: EmbeddingPooler, T_cutoff: datetime.datetime, token_data: Optional[Dict] = None) -> Dict[str, Dict[str, Any]]:
        """
        Fetches and processes static data for a list of tokens.

        Downloads each token's metadata JSON and image (with multi-gateway IPFS
        fallback), registers name/symbol/image in `pooler`, validates the
        protocol id, and returns only tokens that have a name, symbol AND image.

        Returns a dict of valid tokens keyed by address — or ``None`` when a
        single (i.e. main) token was requested and it failed validation.
        """
        if not token_addresses:
            return {}

        if token_data is None:
            print(f"INFO: Processing token data for {len(token_addresses)} unique tokens...")
            token_data = self.fetcher.fetch_token_data(token_addresses, T_cutoff)

        # --- NEW: Print the raw fetched token data as requested ---
        print("\n--- RAW TOKEN DATA FROM DATABASE ---")
        print(token_data)

        # Add pre-computed embedding indices to the token data.
        # --- CRITICAL FIX: This function now returns None if the main token is invalid ---
        valid_token_data = {}
        for addr, data in token_data.items():
            # --- FIXED: Only add to pooler if data is valid ---
            image = None
            token_uri = data.get('token_uri')

            # --- NEW: Use multiple IPFS gateways for reliability ---
            if token_uri and isinstance(token_uri, str) and token_uri.strip():

                ipfs_gateways = [
                    "https://pump.mypinata.cloud/ipfs/",
                    "https://dweb.link/ipfs/",
                    "https://cloudflare-ipfs.com/ipfs/",
                ]

                try:
                    # Handle IPFS URIs for metadata
                    if 'ipfs/' in token_uri:
                        metadata_hash = token_uri.split('ipfs/')[-1]
                        # Try fetching from multiple gateways; for/else raises
                        # only when every gateway failed.
                        for gateway in ipfs_gateways:
                            try:
                                metadata_resp = self.http_session.get(f"{gateway}{metadata_hash}", timeout=5)
                                metadata_resp.raise_for_status()
                                metadata = metadata_resp.json()
                                break # Success, exit loop
                            except requests.RequestException:
                                continue # Try next gateway
                        else: # If all gateways fail
                            raise requests.RequestException("All IPFS gateways failed for metadata.")
                    else: # Handle regular HTTP URIs
                        metadata_resp = self.http_session.get(token_uri, timeout=5)
                        metadata_resp.raise_for_status()
                        metadata = metadata_resp.json()

                    # Extract the image URL from the metadata JSON.
                    image_url = metadata.get('image', '')

                    # --- FIXED: Apply the same multi-gateway logic to image fetching ---
                    if image_url:
                        # Handle IPFS URIs for the image
                        if 'ipfs/' in image_url:
                            image_hash = image_url.split('ipfs/')[-1]
                            # Try fetching image from multiple gateways
                            for gateway in ipfs_gateways:
                                try:
                                    image_resp = self.http_session.get(f"{gateway}{image_hash}", timeout=10)
                                    image_resp.raise_for_status()
                                    image = Image.open(BytesIO(image_resp.content))
                                    break # Success, exit loop
                                except requests.RequestException:
                                    continue # Try next gateway
                            else: # If all gateways fail for the image
                                raise requests.RequestException("All IPFS gateways failed for image.")
                        else: # Handle regular HTTP image URLs
                            image_resp = self.http_session.get(image_url, timeout=10)
                            image_resp.raise_for_status()
                            image = Image.open(BytesIO(image_resp.content))
                except (requests.RequestException, ValueError, IOError) as e:
                    # Any metadata/image failure downgrades to "no image"; the
                    # validation below then decides whether to skip the token.
                    print(f"WARN: Could not fetch or process image for token {addr} from URI {token_uri}. Reason: {e}")
                    image = None # Ensure image is None on failure

            # --- FIXED: Check for valid metadata before adding to pooler ---
            token_name = data.get('name') if data.get('name') and data.get('name').strip() else None
            token_symbol = data.get('symbol') if data.get('symbol') and data.get('symbol').strip() else None

            # An image is mandatory for a token to be considered valid.
            # --- FIXED: Correctly handle invalid secondary tokens without aborting the whole process ---
            if not token_name or not token_symbol or not image:
                if not token_name: reason = "name"
                elif not token_symbol: reason = "symbol"
                else: reason = "image (fetch failed)"

                print(f"WARN: Token {addr} is missing essential metadata ('{reason}'). This token will be skipped.")

                # If this function was called with only one token, it's the main token.
                # If the main token is invalid, the whole sample is invalid, so return None.
                if len(token_addresses) == 1:
                    return None
                # Otherwise, it's a secondary token. Skip it and continue with the others.
                continue

            # --- NEW: Add is_vanity feature based on the token address ---
            data['is_vanity'] = addr.lower().endswith("pump")

            # Register media/text in the pooler; indices feed embedding layers.
            data['image_emb_idx'] = pooler.get_idx(image)
            data['name_emb_idx'] = pooler.get_idx(token_name)
            data['symbol_emb_idx'] = pooler.get_idx(token_symbol)

            # FIX: Validate the protocol ID ---
            # The DB might return an ID that is out of bounds for our nn.Embedding layer.
            # We must ensure the ID is valid or map it to a default 'Unknown' ID.
            raw_protocol_id = data.get('protocol')
            if raw_protocol_id is not None and 0 <= raw_protocol_id < vocab.NUM_PROTOCOLS:
                data['protocol'] = raw_protocol_id
            else:
                data['protocol'] = vocab.PROTOCOL_TO_ID.get('Unknown', 0)

            valid_token_data[addr] = data

        return valid_token_data
|
| 698 |
+
|
| 699 |
+
def _generate_ohlc(self, aggregation_trades: List[Dict[str, Any]], T_cutoff: datetime.datetime, interval_seconds: int) -> List[tuple]:
|
| 700 |
+
"""
|
| 701 |
+
Generates an OHLC series from a list of aggregated trades with a dynamic interval.
|
| 702 |
+
It forward-fills gaps and extends the series up to T_cutoff.
|
| 703 |
+
Returns a list of (timestamp, open, close) tuples.
|
| 704 |
+
"""
|
| 705 |
+
if not aggregation_trades:
|
| 706 |
+
return []
|
| 707 |
+
|
| 708 |
+
trades_by_interval = defaultdict(list)
|
| 709 |
+
for trade in aggregation_trades:
|
| 710 |
+
# Group trades into interval buckets
|
| 711 |
+
interval_start_ts = (trade['timestamp'] // interval_seconds) * interval_seconds
|
| 712 |
+
trades_by_interval[interval_start_ts].append(trade['price_usd'])
|
| 713 |
+
|
| 714 |
+
sorted_intervals = sorted(trades_by_interval.keys())
|
| 715 |
+
|
| 716 |
+
if not sorted_intervals:
|
| 717 |
+
return []
|
| 718 |
+
|
| 719 |
+
full_ohlc = []
|
| 720 |
+
start_ts = sorted_intervals[0]
|
| 721 |
+
end_ts = int(T_cutoff.timestamp())
|
| 722 |
+
# Align end_ts to the interval grid
|
| 723 |
+
end_ts = (end_ts // interval_seconds) * interval_seconds
|
| 724 |
+
last_price = aggregation_trades[0]['price_usd']
|
| 725 |
+
|
| 726 |
+
# --- NEW: Debugging log for trades grouped by interval ---
|
| 727 |
+
print(f"\n[DEBUG] OHLC Generation: Trades grouped by interval bucket:")
|
| 728 |
+
print(dict(trades_by_interval))
|
| 729 |
+
|
| 730 |
+
for ts in range(start_ts, end_ts + 1, interval_seconds):
|
| 731 |
+
if ts in trades_by_interval:
|
| 732 |
+
prices = trades_by_interval[ts]
|
| 733 |
+
open_price = prices[0]
|
| 734 |
+
close_price = prices[-1]
|
| 735 |
+
full_ohlc.append((ts, open_price, close_price))
|
| 736 |
+
last_price = close_price
|
| 737 |
+
else:
|
| 738 |
+
full_ohlc.append((ts, last_price, last_price))
|
| 739 |
+
return full_ohlc
|
| 740 |
+
|
| 741 |
+
def __getitem__(self, idx: int) -> Optional[Dict[str, Any]]:
|
| 742 |
+
"""
|
| 743 |
+
Loads a pre-processed data item from the cache, or generates it on-the-fly
|
| 744 |
+
if the dataset is in online mode.
|
| 745 |
+
"""
|
| 746 |
+
if self.cache_dir:
|
| 747 |
+
if idx >= len(self.cached_files):
|
| 748 |
+
raise IndexError(f"Index {idx} out of range for {len(self.cached_files)} cached files.")
|
| 749 |
+
filepath = self.cached_files[idx]
|
| 750 |
+
try:
|
| 751 |
+
# Use map_location to avoid issues if cached on GPU and loading on CPU
|
| 752 |
+
return torch.load(filepath, map_location='cpu')
|
| 753 |
+
except Exception as e:
|
| 754 |
+
print(f"ERROR: Could not load or process cached item {filepath}: {e}")
|
| 755 |
+
return None # DataLoader can be configured to skip None items
|
| 756 |
+
|
| 757 |
+
# Fallback to online generation if no cache_dir is set
|
| 758 |
+
return self.__cacheitem__(idx)
|
| 759 |
+
|
| 760 |
+
def __cacheitem__(self, idx: int) -> Optional[Dict[str, Any]]:
|
| 761 |
+
"""
|
| 762 |
+
The main data loading method. For a given token, it fetches all
|
| 763 |
+
relevant on-chain and off-chain data, processes it, and returns
|
| 764 |
+
a structured dictionary for the collator.
|
| 765 |
+
"""
|
| 766 |
+
|
| 767 |
+
if not self.sampled_mints:
|
| 768 |
+
raise RuntimeError("Dataset has no mint records loaded; ensure fetcher returned data during initialization.")
|
| 769 |
+
if idx >= len(self.sampled_mints):
|
| 770 |
+
raise IndexError(f"Requested sample index {idx} exceeds loaded mint count {len(self.sampled_mints)}.")
|
| 771 |
+
initial_mint_record = self.sampled_mints[idx]
|
| 772 |
+
t0 = initial_mint_record["timestamp"]
|
| 773 |
+
creator_address = initial_mint_record['creator_address']
|
| 774 |
+
token_address = initial_mint_record['mint_address']
|
| 775 |
+
print(f"\n--- Building dataset for token: {token_address} ---")
|
| 776 |
+
|
| 777 |
+
# The EmbeddingPooler is crucial for collecting unique text/images per sample
|
| 778 |
+
pooler = EmbeddingPooler()
|
| 779 |
+
|
| 780 |
+
def _safe_int(value: Any) -> int:
|
| 781 |
+
try:
|
| 782 |
+
return int(value)
|
| 783 |
+
except (TypeError, ValueError):
|
| 784 |
+
return 0
|
| 785 |
+
|
| 786 |
+
def _timestamp_to_order_value(ts_value: Any) -> float:
|
| 787 |
+
if isinstance(ts_value, datetime.datetime):
|
| 788 |
+
if ts_value.tzinfo is None:
|
| 789 |
+
ts_value = ts_value.replace(tzinfo=datetime.timezone.utc)
|
| 790 |
+
return ts_value.timestamp()
|
| 791 |
+
try:
|
| 792 |
+
return float(ts_value)
|
| 793 |
+
except (TypeError, ValueError):
|
| 794 |
+
return 0.0
|
| 795 |
+
|
| 796 |
+
def _event_execution_sort_key(timestamp_value: Any,
                              slot: Any = 0,
                              transaction_index: Any = 0,
                              instruction_index: Any = 0,
                              signature: str = '') -> tuple:
    """Build a deterministic execution-order key: (time, slot, tx idx, ix idx, sig)."""
    ordinal_parts = [_safe_int(part) for part in (slot, transaction_index, instruction_index)]
    return (_timestamp_to_order_value(timestamp_value), *ordinal_parts, signature or '')
|
| 808 |
+
|
| 809 |
+
|
| 810 |
+
|
| 811 |
+
# 1. Fetch anchor Mint event to establish the timeline & initial entities
|
| 812 |
+
# --- SIMPLIFIED: Use the mint record we already have ---
|
| 813 |
+
mint_event = {
|
| 814 |
+
'event_type': 'Mint',
|
| 815 |
+
'timestamp': int(initial_mint_record['timestamp'].timestamp()),
|
| 816 |
+
'relative_ts': 0,
|
| 817 |
+
'wallet_address': initial_mint_record['creator_address'],
|
| 818 |
+
'token_address': token_address,
|
| 819 |
+
'protocol_id': initial_mint_record.get('protocol')
|
| 820 |
+
}
|
| 821 |
+
|
| 822 |
+
initial_entities = {mint_event['wallet_address']}
|
| 823 |
+
event_sequence_entries: List[Tuple[tuple, Dict[str, Any]]] = []
|
| 824 |
+
|
| 825 |
+
def _register_event(event: Dict[str, Any], sort_key: tuple):
    # Queue (sort_key, event) on the enclosing scope's event_sequence_entries
    # list so all events can be globally ordered once collection is complete.
    event_sequence_entries.append((sort_key, event))
|
| 827 |
+
|
| 828 |
+
_register_event(mint_event, _event_execution_sort_key(mint_event['timestamp'], signature='Mint'))
|
| 829 |
+
|
| 830 |
+
# Determine the cutoff time for all historical data fetching
|
| 831 |
+
# T_cutoff = datetime.datetime.fromtimestamp(event_sequence[-1]['timestamp'], tz=datetime.timezone.utc)
|
| 832 |
+
# --- MODIFIED: Set T_cutoff to mint timestamp + 1 day ---
|
| 833 |
+
T_cutoff = initial_mint_record['timestamp'] + datetime.timedelta(seconds=self.t_cutoff_seconds)
|
| 834 |
+
max_horizon_seconds = max(self.horizons_seconds) if self.horizons_seconds else 0
|
| 835 |
+
future_trades_for_labels: List[Dict[str, Any]] = []
|
| 836 |
+
if self.num_outputs > 0 and max_horizon_seconds > 0:
|
| 837 |
+
future_window_end = T_cutoff + datetime.timedelta(seconds=max_horizon_seconds)
|
| 838 |
+
future_trades_for_labels = self.fetcher.fetch_future_trades_for_token(
|
| 839 |
+
token_address, T_cutoff, future_window_end
|
| 840 |
+
)
|
| 841 |
+
if not future_trades_for_labels:
|
| 842 |
+
print(f"INFO: Skipping token {token_address} (no future trades beyond cutoff).")
|
| 843 |
+
return None
|
| 844 |
+
|
| 845 |
+
# --- NEW: Accumulate all wallets before hitting Neo4j to avoid duplicate queries ---
|
| 846 |
+
graph_seed_entities = set(initial_entities)
|
| 847 |
+
all_graph_entities: Dict[str, str] = {mint_event['wallet_address']: 'Wallet'}
|
| 848 |
+
all_graph_entity_addrs = set(all_graph_entities.keys())
|
| 849 |
+
graph_links: Dict[str, Any] = {}
|
| 850 |
+
|
| 851 |
+
# 3. Fetch trades and add traders to the entity set
|
| 852 |
+
# --- REFACTORED: Fetch trades using the new 3-part HBH system ---
|
| 853 |
+
early_trades, middle_trades, recent_trades = self.fetcher.fetch_trades_for_token(
|
| 854 |
+
token_address, T_cutoff, EVENT_COUNT_THRESHOLD_FOR_HBH, HBH_EARLY_EVENT_LIMIT, HBH_RECENT_EVENT_LIMIT
|
| 855 |
+
)
|
| 856 |
+
def _trade_execution_sort_key(trade: Dict[str, Any]) -> tuple:
    """Execution-order key for a raw trade record: (time, slot, tx idx, ix idx, sig)."""
    time_part = _timestamp_to_order_value(trade.get('timestamp'))
    int_parts = tuple(
        _safe_int(trade.get(field))
        for field in ('slot', 'transaction_index', 'instruction_index')
    )
    return (time_part,) + int_parts + (trade.get('signature', ''),)
|
| 864 |
+
|
| 865 |
+
early_trades = sorted(early_trades, key=_trade_execution_sort_key)
|
| 866 |
+
middle_trades = sorted(middle_trades, key=_trade_execution_sort_key)
|
| 867 |
+
recent_trades = sorted(recent_trades, key=_trade_execution_sort_key)
|
| 868 |
+
|
| 869 |
+
# --- NEW: Inject special context tokens to mark HBH boundaries ---
|
| 870 |
+
# 'Middle' marks the start of the blurry middle window
|
| 871 |
+
if middle_trades:
|
| 872 |
+
mid_ts_val = _timestamp_to_order_value(middle_trades[0].get('timestamp'))
|
| 873 |
+
middle_event = {
|
| 874 |
+
'event_type': 'Middle',
|
| 875 |
+
'timestamp': int(mid_ts_val),
|
| 876 |
+
'relative_ts': mid_ts_val - _timestamp_to_order_value(t0)
|
| 877 |
+
}
|
| 878 |
+
_register_event(middle_event, _event_execution_sort_key(mid_ts_val, signature='Middle'))
|
| 879 |
+
|
| 880 |
+
# 'RECENT' marks the start of the high-definition recent window
|
| 881 |
+
if recent_trades:
|
| 882 |
+
rec_ts_val = _timestamp_to_order_value(recent_trades[0].get('timestamp'))
|
| 883 |
+
recent_event = {
|
| 884 |
+
'event_type': 'RECENT',
|
| 885 |
+
'timestamp': int(rec_ts_val),
|
| 886 |
+
'relative_ts': rec_ts_val - _timestamp_to_order_value(t0)
|
| 887 |
+
}
|
| 888 |
+
_register_event(recent_event, _event_execution_sort_key(rec_ts_val, signature='RECENT'))
|
| 889 |
+
|
| 890 |
+
# For now, we only process the high-definition segments for event creation,
|
| 891 |
+
# deduplicated in case of overlap between early/recent slices.
|
| 892 |
+
trade_records = []
|
| 893 |
+
seen_trade_keys = set()
|
| 894 |
+
for trade in early_trades + recent_trades:
|
| 895 |
+
dedupe_key = (
|
| 896 |
+
_safe_int(trade.get('slot')),
|
| 897 |
+
_safe_int(trade.get('transaction_index')),
|
| 898 |
+
_safe_int(trade.get('instruction_index')),
|
| 899 |
+
trade.get('signature', '')
|
| 900 |
+
)
|
| 901 |
+
if dedupe_key in seen_trade_keys:
|
| 902 |
+
continue
|
| 903 |
+
seen_trade_keys.add(dedupe_key)
|
| 904 |
+
trade_records.append(trade)
|
| 905 |
+
|
| 906 |
+
for trade in trade_records:
|
| 907 |
+
trader_addr = trade['maker']
|
| 908 |
+
if trader_addr not in all_graph_entity_addrs:
|
| 909 |
+
all_graph_entity_addrs.add(trader_addr)
|
| 910 |
+
all_graph_entities[trader_addr] = 'Wallet' # Trades are always made by wallets
|
| 911 |
+
graph_seed_entities.add(trader_addr)
|
| 912 |
+
|
| 913 |
+
# --- REFACTORED: Fetch significant transfers, passing total supply for filtering ---
|
| 914 |
+
raw_total_supply = initial_mint_record.get('total_supply', 0)
|
| 915 |
+
base_decimals = initial_mint_record.get('token_decimals', 9)
|
| 916 |
+
total_supply_dec = (raw_total_supply / (10**base_decimals)) if base_decimals > 0 else raw_total_supply
|
| 917 |
+
|
| 918 |
+
# Calculate the minimum amount to be considered a significant transfer
|
| 919 |
+
total_supply_dec = total_supply_dec * MIN_AMOUNT_TRANSFER_SUPPLY # 0.01% of total supply
|
| 920 |
+
|
| 921 |
+
transfer_records = self.fetcher.fetch_transfers_for_token(token_address, T_cutoff, total_supply_dec)
|
| 922 |
+
for transfer in transfer_records:
|
| 923 |
+
src = transfer.get('source')
|
| 924 |
+
dst = transfer.get('destination')
|
| 925 |
+
if src:
|
| 926 |
+
all_graph_entities[src] = 'Wallet'
|
| 927 |
+
graph_seed_entities.add(src)
|
| 928 |
+
if dst:
|
| 929 |
+
all_graph_entities[dst] = 'Wallet'
|
| 930 |
+
graph_seed_entities.add(dst)
|
| 931 |
+
|
| 932 |
+
# --- NEW: Fetch pool creation events to enrich entity set and token list ---
|
| 933 |
+
pool_creation_records = self.fetcher.fetch_pool_creations_for_token(token_address, T_cutoff)
|
| 934 |
+
pool_quote_addresses = set()
|
| 935 |
+
pool_metadata_by_address: Dict[str, Dict[str, Any]] = {}
|
| 936 |
+
for pool_record in pool_creation_records:
|
| 937 |
+
creator_addr = pool_record.get('creator_address')
|
| 938 |
+
if creator_addr:
|
| 939 |
+
all_graph_entities[creator_addr] = 'Wallet'
|
| 940 |
+
graph_seed_entities.add(creator_addr)
|
| 941 |
+
quote_addr = pool_record.get('quote_address')
|
| 942 |
+
if quote_addr:
|
| 943 |
+
pool_quote_addresses.add(quote_addr)
|
| 944 |
+
# Mark discovered quote tokens so they can be fetched later if needed
|
| 945 |
+
all_graph_entities.setdefault(quote_addr, 'Token')
|
| 946 |
+
pool_addr = pool_record.get('pool_address')
|
| 947 |
+
if pool_addr:
|
| 948 |
+
pool_metadata_by_address[pool_addr] = {
|
| 949 |
+
'quote_token_address': quote_addr,
|
| 950 |
+
'quote_decimals': pool_record.get('quote_decimals'),
|
| 951 |
+
'base_decimals': pool_record.get('base_decimals')
|
| 952 |
+
}
|
| 953 |
+
|
| 954 |
+
liquidity_change_records = self.fetcher.fetch_liquidity_changes_for_pools(list(pool_metadata_by_address.keys()), T_cutoff)
|
| 955 |
+
for liquidity_record in liquidity_change_records:
|
| 956 |
+
lp_provider = liquidity_record.get('lp_provider')
|
| 957 |
+
if lp_provider:
|
| 958 |
+
all_graph_entities[lp_provider] = 'Wallet'
|
| 959 |
+
graph_seed_entities.add(lp_provider)
|
| 960 |
+
|
| 961 |
+
fee_collection_records = self.fetcher.fetch_fee_collections_for_token(token_address, T_cutoff)
|
| 962 |
+
burn_records = self.fetcher.fetch_burns_for_token(token_address, T_cutoff)
|
| 963 |
+
supply_lock_records = self.fetcher.fetch_supply_locks_for_token(token_address, T_cutoff)
|
| 964 |
+
migration_records = self.fetcher.fetch_migrations_for_token(token_address, T_cutoff)
|
| 965 |
+
# NEW: Fetch top holders to include their wallets so we can embed them
|
| 966 |
+
holder_records = self.fetcher.fetch_token_holders_for_snapshot(token_address, T_cutoff, limit=HOLDER_SNAPSHOT_TOP_K)
|
| 967 |
+
fee_related_mints = set()
|
| 968 |
+
for fee_record in fee_collection_records:
|
| 969 |
+
recipient = fee_record.get('recipient_address')
|
| 970 |
+
if recipient:
|
| 971 |
+
all_graph_entities[recipient] = 'Wallet'
|
| 972 |
+
graph_seed_entities.add(recipient)
|
| 973 |
+
mint_addr = fee_record.get('token_0_mint_address')
|
| 974 |
+
if mint_addr and mint_addr not in (token_address, ''):
|
| 975 |
+
fee_related_mints.add(mint_addr)
|
| 976 |
+
# Include migration pool addresses as tokens/entities if present
|
| 977 |
+
for mig in migration_records:
|
| 978 |
+
vpool = mig.get('virtual_pool_address')
|
| 979 |
+
paddr = mig.get('pool_address')
|
| 980 |
+
if vpool:
|
| 981 |
+
all_graph_entities.setdefault(vpool, 'Token')
|
| 982 |
+
if paddr:
|
| 983 |
+
all_graph_entities.setdefault(paddr, 'Token')
|
| 984 |
+
|
| 985 |
+
# Include burner wallets in entity set
|
| 986 |
+
for burn in burn_records:
|
| 987 |
+
src = burn.get('source')
|
| 988 |
+
if src:
|
| 989 |
+
all_graph_entities[src] = 'Wallet'
|
| 990 |
+
graph_seed_entities.add(src)
|
| 991 |
+
# Include holder wallets in entity set for embedding availability
|
| 992 |
+
for rec in holder_records:
|
| 993 |
+
wa = rec.get('wallet_address')
|
| 994 |
+
if wa:
|
| 995 |
+
all_graph_entities[wa] = 'Wallet'
|
| 996 |
+
graph_seed_entities.add(wa)
|
| 997 |
+
# Include lockers in entity set
|
| 998 |
+
for lock in supply_lock_records:
|
| 999 |
+
sender = lock.get('sender')
|
| 1000 |
+
recipient = lock.get('recipient')
|
| 1001 |
+
if sender:
|
| 1002 |
+
all_graph_entities[sender] = 'Wallet'
|
| 1003 |
+
graph_seed_entities.add(sender)
|
| 1004 |
+
if recipient:
|
| 1005 |
+
all_graph_entities[recipient] = 'Wallet'
|
| 1006 |
+
graph_seed_entities.add(recipient)
|
| 1007 |
+
|
| 1008 |
+
# --- NEW: Now that all wallets are known, fetch graph links once ---
|
| 1009 |
+
if graph_seed_entities:
|
| 1010 |
+
fetched_graph_entities, graph_links = self.fetcher.fetch_graph_links(
|
| 1011 |
+
list(graph_seed_entities),
|
| 1012 |
+
T_cutoff=T_cutoff,
|
| 1013 |
+
max_degrees=2
|
| 1014 |
+
)
|
| 1015 |
+
for addr, entity_type in fetched_graph_entities.items():
|
| 1016 |
+
all_graph_entities[addr] = entity_type
|
| 1017 |
+
all_graph_entity_addrs = set(all_graph_entities.keys())
|
| 1018 |
+
|
| 1019 |
+
# 4. Fetch and process static data for the main token
|
| 1020 |
+
tokens_to_fetch = [token_address]
|
| 1021 |
+
for quote_addr in pool_quote_addresses:
|
| 1022 |
+
if quote_addr and quote_addr not in tokens_to_fetch:
|
| 1023 |
+
tokens_to_fetch.append(quote_addr)
|
| 1024 |
+
for mint_addr in fee_related_mints:
|
| 1025 |
+
if mint_addr and mint_addr not in tokens_to_fetch:
|
| 1026 |
+
tokens_to_fetch.append(mint_addr)
|
| 1027 |
+
main_metadata = {}
|
| 1028 |
+
main_metadata[token_address] = {
|
| 1029 |
+
'name': initial_mint_record["token_name"],
|
| 1030 |
+
'symbol': initial_mint_record["token_symbol"],
|
| 1031 |
+
'token_uri': initial_mint_record["token_uri"],
|
| 1032 |
+
'protocol': initial_mint_record["protocol"],
|
| 1033 |
+
'total_supply': initial_mint_record["total_supply"],
|
| 1034 |
+
'decimals': initial_mint_record["token_decimals"],
|
| 1035 |
+
'address': token_address
|
| 1036 |
+
}
|
| 1037 |
+
|
| 1038 |
+
main_token_data = self._process_token_data(tokens_to_fetch, pooler, T_cutoff, main_metadata)
|
| 1039 |
+
|
| 1040 |
+
# --- CRITICAL FIX: If the main token is invalid, skip this entire sample ---
|
| 1041 |
+
if not main_token_data:
|
| 1042 |
+
return None # The specific reason is already logged in _process_token_data
|
| 1043 |
+
|
| 1044 |
+
# 5. Fetch and process data for ALL wallets discovered (from mint, graph, trades, etc.)
|
| 1045 |
+
# --- FIXED: Correctly identify wallets using their entity type from the graph ---
|
| 1046 |
+
wallets_to_fetch = [addr for addr, type in all_graph_entities.items() if type == 'Wallet']
|
| 1047 |
+
# Also include traders from trades, even if they weren't in the graph
|
| 1048 |
+
wallets_to_fetch.extend([trade['maker'] for trade in trade_records if trade['maker'] not in wallets_to_fetch])
|
| 1049 |
+
wallet_data, all_token_data = self._process_wallet_data(list(set(wallets_to_fetch)), main_token_data.copy(), pooler, T_cutoff)
|
| 1050 |
+
|
| 1051 |
+
# 6. Process trades into event format using the now-available wallet_data
|
| 1052 |
+
trade_events = []
|
| 1053 |
+
|
| 1054 |
+
aggregation_trades = []
|
| 1055 |
+
high_def_chart_trades = [] # Early + recent windows use 1s candles
|
| 1056 |
+
middle_chart_trades = [] # Middle window uses 30s candles
|
| 1057 |
+
# --- FIXED: Get main token decimals once before the loop ---
|
| 1058 |
+
main_token_info = main_token_data[token_address]
|
| 1059 |
+
base_decimals = main_token_info.get('decimals', 6)
|
| 1060 |
+
# --- FIXED: Get total_supply directly from the initial mint record ---
|
| 1061 |
+
raw_total_supply = initial_mint_record.get('total_supply', 0)
|
| 1062 |
+
total_supply_dec = (raw_total_supply / (10**base_decimals)) if base_decimals > 0 else raw_total_supply
|
| 1063 |
+
print("SUPPLY", total_supply_dec)
|
| 1064 |
+
|
| 1065 |
+
t0_timestamp = _timestamp_to_order_value(t0)
|
| 1066 |
+
|
| 1067 |
+
for trade in trade_records:
|
| 1068 |
+
# --- NEW: Filter out trades with low USD value ---
|
| 1069 |
+
# This applies to both event creation and chart aggregation.
|
| 1070 |
+
if trade.get('total_usd', 0.0) < self.min_trade_usd:
|
| 1071 |
+
continue
|
| 1072 |
+
|
| 1073 |
+
trade_sort_key = _trade_execution_sort_key(trade)
|
| 1074 |
+
trade_timestamp = trade.get('timestamp')
|
| 1075 |
+
trade_timestamp_value = _timestamp_to_order_value(trade_timestamp)
|
| 1076 |
+
trade_timestamp_int = int(trade_timestamp_value)
|
| 1077 |
+
# --- NEW: Determine event type with priority ---
|
| 1078 |
+
trader_addr = trade['maker']
|
| 1079 |
+
trader_wallet_data = wallet_data.get(trader_addr, {})
|
| 1080 |
+
trader_profile = trader_wallet_data.get('profile', {})
|
| 1081 |
+
trader_socials = trader_wallet_data.get('socials', {})
|
| 1082 |
+
|
| 1083 |
+
KOL_NAME_KEYS = ['kolscan_name', 'cabalspy_name', 'axiom_kol_name']
|
| 1084 |
+
is_kol = any(trader_socials.get(key) for key in KOL_NAME_KEYS if trader_socials)
|
| 1085 |
+
is_profitable = (trader_profile.get('stats_30d_realized_profit_pnl', 0.0) > SMART_WALLET_PNL_THRESHOLD and
|
| 1086 |
+
trader_profile.get('stats_30d_realized_profit_usd', 0.0) > SMART_WALLET_USD_THRESHOLD)
|
| 1087 |
+
|
| 1088 |
+
base_amount_dec = trade.get('base_amount', 0) / (10**base_decimals)
|
| 1089 |
+
is_large_amount = (total_supply_dec > 0 and (base_amount_dec / total_supply_dec) > LARGE_TRADE_SUPPLY_PCT_THRESHOLD)
|
| 1090 |
+
|
| 1091 |
+
if trader_addr == creator_address:
|
| 1092 |
+
event_type = 'Deployer_Trade'
|
| 1093 |
+
elif is_kol or is_profitable:
|
| 1094 |
+
event_type = 'SmartWallet_Trade'
|
| 1095 |
+
elif trade.get('total_usd', 0.0) > LARGE_TRADE_USD_THRESHOLD or is_large_amount:
|
| 1096 |
+
event_type = 'LargeTrade'
|
| 1097 |
+
else:
|
| 1098 |
+
event_type = 'Trade'
|
| 1099 |
+
|
| 1100 |
+
# --- NEW: Get token decimals for accurate calculations ---
|
| 1101 |
+
quote_address = trade.get('quote_address')
|
| 1102 |
+
quote_decimals = QUOTE_TOKEN_DECIMALS.get(quote_address, 9) # Default to 9 for SOL
|
| 1103 |
+
|
| 1104 |
+
quote_amount_dec = trade.get('quote_amount', 0) / (10**quote_decimals)
|
| 1105 |
+
|
| 1106 |
+
# --- NEW: Correctly calculate pre-trade balances ---
|
| 1107 |
+
is_sell = trade.get('trade_type') == 1
|
| 1108 |
+
|
| 1109 |
+
# If it's a sell, the pre-trade base balance was higher.
|
| 1110 |
+
pre_trade_base_balance = (trade.get('base_balance', 0.0) + base_amount_dec) if is_sell else trade.get('base_balance', 0.0)
|
| 1111 |
+
# If it's a buy, the pre-trade quote balance was higher.
|
| 1112 |
+
pre_trade_quote_balance = (trade.get('quote_balance', 0.0) + quote_amount_dec) if not is_sell else trade.get('quote_balance', 0.0)
|
| 1113 |
+
|
| 1114 |
+
# --- NEW: Calculate percentage features with the corrected values ---
|
| 1115 |
+
token_amount_pct = (base_amount_dec / pre_trade_base_balance) if pre_trade_base_balance > 1e-9 else 1.0
|
| 1116 |
+
quote_amount_pct = (quote_amount_dec / pre_trade_quote_balance) if pre_trade_quote_balance > 1e-9 else 1.0
|
| 1117 |
+
is_success = trade.get('success', False)
|
| 1118 |
+
if is_success:
|
| 1119 |
+
chart_entry = {
|
| 1120 |
+
'trade_direction': 1 if is_sell else 0, # 1 for sell, 0 for buy,
|
| 1121 |
+
'price_usd': trade.get('price_usd', 0.0),
|
| 1122 |
+
'timestamp': trade_timestamp_int,
|
| 1123 |
+
'sort_key': trade_sort_key,
|
| 1124 |
+
}
|
| 1125 |
+
aggregation_trades.append(chart_entry)
|
| 1126 |
+
high_def_chart_trades.append(chart_entry.copy())
|
| 1127 |
+
# --- NEW: Calculate token amount as a percentage of total supply ---
|
| 1128 |
+
token_amount_pct_of_supply = (base_amount_dec / total_supply_dec) if total_supply_dec > 0 else 0.0
|
| 1129 |
+
trade_event = {
|
| 1130 |
+
'event_type': event_type,
|
| 1131 |
+
'timestamp': trade_timestamp_int,
|
| 1132 |
+
'relative_ts': trade_timestamp_value - t0_timestamp,
|
| 1133 |
+
'wallet_address': trade['maker'],
|
| 1134 |
+
'token_address': token_address,
|
| 1135 |
+
'trade_direction': 1 if is_sell else 0, # 1 for sell, 0 for buy
|
| 1136 |
+
'sol_amount': trade.get('total', 0.0), # Assuming 'total' is the SOL amount
|
| 1137 |
+
'dex_platform_id': trade.get('platform', 0),
|
| 1138 |
+
'priority_fee': trade.get('priority_fee', 0.0),
|
| 1139 |
+
'mev_protection': 1 if trade.get('mev_protection', 0) > 0 else 0, # Convert to binary: 0 for False, 1 for True
|
| 1140 |
+
# --- FIXED: Use the new, correct percentage calculations ---
|
| 1141 |
+
'token_amount_pct_of_holding': token_amount_pct,
|
| 1142 |
+
'quote_amount_pct_of_holding': quote_amount_pct,
|
| 1143 |
+
'slippage': trade.get('slippage', 0.0),
|
| 1144 |
+
'token_amount_pct_to_total_supply': token_amount_pct_of_supply, # FIXED: Replaced price_impact
|
| 1145 |
+
'success': is_success,
|
| 1146 |
+
'is_bundle': False, # Default to False, will be updated below
|
| 1147 |
+
'total_usd': trade.get('total_usd', 0.0)
|
| 1148 |
+
}
|
| 1149 |
+
trade_events.append(trade_event)
|
| 1150 |
+
_register_event(trade_event, trade_sort_key)
|
| 1151 |
+
|
| 1152 |
+
for trade in middle_trades:
|
| 1153 |
+
# --- NEW: Filter out trades with low USD value from chart aggregation ---
|
| 1154 |
+
if trade.get('total_usd', 0.0) < self.min_trade_usd:
|
| 1155 |
+
continue
|
| 1156 |
+
|
| 1157 |
+
# --- NEW: Correctly calculate pre-trade balances ---
|
| 1158 |
+
is_sell = trade.get('trade_type') == 1
|
| 1159 |
+
|
| 1160 |
+
chart_entry = {
|
| 1161 |
+
'trade_direction': 1 if is_sell else 0, # 1 for sell, 0 for buy,
|
| 1162 |
+
'price_usd': trade.get('price_usd', 0.0),
|
| 1163 |
+
'timestamp': int(_timestamp_to_order_value(trade.get('timestamp'))),
|
| 1164 |
+
'sort_key': _trade_execution_sort_key(trade),
|
| 1165 |
+
}
|
| 1166 |
+
aggregation_trades.append(chart_entry)
|
| 1167 |
+
middle_chart_trades.append(chart_entry.copy())
|
| 1168 |
+
|
| 1169 |
+
def _finalize_chart_trade_list(trade_list: List[Dict[str, Any]]):
|
| 1170 |
+
trade_list.sort(key=lambda x: x['sort_key'])
|
| 1171 |
+
for entry in trade_list:
|
| 1172 |
+
entry.pop('sort_key', None)
|
| 1173 |
+
|
| 1174 |
+
_finalize_chart_trade_list(aggregation_trades)
|
| 1175 |
+
_finalize_chart_trade_list(high_def_chart_trades)
|
| 1176 |
+
_finalize_chart_trade_list(middle_chart_trades)
|
| 1177 |
+
|
| 1178 |
+
# --- NEW: Debugging log for all trades used in chart generation ---
|
| 1179 |
+
print(f"\n[DEBUG] Total aggregated trades for OHLC: {len(aggregation_trades)}")
|
| 1180 |
+
if aggregation_trades:
|
| 1181 |
+
print("[DEBUG] First 5 aggregated trades:", aggregation_trades[:5])
|
| 1182 |
+
|
| 1183 |
+
HIGH_DEF_INTERVAL = ("1s", 1)
|
| 1184 |
+
MIDDLE_INTERVAL = ("30s", 30)
|
| 1185 |
+
|
| 1186 |
+
def _emit_chart_segments(trades: List[Dict[str, Any]], interval: tuple, signature_prefix: str):
    """Convert chart trades into 'Chart_Segment' events at a given candle interval.

    Builds an OHLC series via self._generate_ohlc, slices it into
    OHLC_SEQ_LEN-sized windows, and registers one Chart_Segment event per
    window, stamped with the window's last candle timestamp. Returns the list
    of emitted events ([] when there are no trades).
    """
    if not trades:
        return []
    # interval is a (label, seconds) pair, e.g. ("1s", 1) or ("30s", 30).
    interval_label, interval_seconds = interval
    ohlc_series = self._generate_ohlc(trades, T_cutoff, interval_seconds)
    print(f"[DEBUG] Generated OHLC series ({interval_label}) with {len(ohlc_series)} candles. First 5: {ohlc_series[:5]}")
    emitted_events = []
    # Walk the candle series in fixed-size, non-overlapping windows.
    for idx in range(0, len(ohlc_series), OHLC_SEQ_LEN):
        segment = ohlc_series[idx:idx + OHLC_SEQ_LEN]
        if not segment:
            continue
        # Each candle is a (timestamp, open, close) tuple from _generate_ohlc.
        last_ts = segment[-1][0]
        opens_raw = [s[1] for s in segment]
        closes_raw = [s[2] for s in segment]
        chart_event = {
            'event_type': 'Chart_Segment',
            'timestamp': last_ts,
            # Offset from t0 (the mint timestamp captured by the enclosing method).
            'relative_ts': last_ts - t0_timestamp,
            'opens': self._normalize_price_series(opens_raw),
            'closes': self._normalize_price_series(closes_raw),
            'i': interval_label
        }
        emitted_events.append(chart_event)
        # Per-window signature (prefix + slice offset) keeps sort keys unique
        # even when several windows share the same timestamp.
        _register_event(chart_event, _event_execution_sort_key(last_ts, signature=f"{signature_prefix}-{idx}"))
    return emitted_events
|
| 1211 |
+
|
| 1212 |
+
# --- NEW: Generate Chart_Segment events from aggregated trades ---
|
| 1213 |
+
chart_events = []
|
| 1214 |
+
chart_events.extend(_emit_chart_segments(high_def_chart_trades, HIGH_DEF_INTERVAL, "chart-hd"))
|
| 1215 |
+
chart_events.extend(_emit_chart_segments(middle_chart_trades, MIDDLE_INTERVAL, "chart-mid"))
|
| 1216 |
+
|
| 1217 |
+
# --- NEW: Convert pool creation records into structured events ---
|
| 1218 |
+
SOL_MINT_ADDRESS = 'So11111111111111111111111111111111111111112'
|
| 1219 |
+
|
| 1220 |
+
def _convert_amount_with_decimals(raw_amount: Any, mint_addr: Optional[str]) -> float:
    """Convert a possibly raw (integer, decimal-unscaled) token amount into decimal units.

    Closure over the enclosing scope: uses ``SOL_MINT_ADDRESS``,
    ``QUOTE_TOKEN_DECIMALS``, ``all_token_data`` and ``main_token_data`` to
    look up per-mint decimals.

    Args:
        raw_amount: Amount that may be either already decimal-scaled or a raw
            integer lamport/base-unit amount; non-numeric input yields 0.0.
        mint_addr: Mint whose decimals apply; falls back to no scaling when
            the mint (or its decimals) is unknown.

    Returns:
        The amount in decimal token units (best-effort; see the magnitude
        heuristic below).
    """
    if raw_amount is None:
        return 0.0
    try:
        amount_float = float(raw_amount)
    except (TypeError, ValueError):
        return 0.0
    decimals_value = None
    if mint_addr == SOL_MINT_ADDRESS:
        decimals_value = QUOTE_TOKEN_DECIMALS.get(SOL_MINT_ADDRESS, 9)
    elif mint_addr:
        # Prefer the comprehensive token map, then the main-token map.
        token_info = all_token_data.get(mint_addr) or main_token_data.get(mint_addr)
        if token_info:
            decimals_value = token_info.get('decimals')
    if decimals_value is None:
        # Unknown decimals: pass the value through unchanged.
        return amount_float
    try:
        decimals_int = max(int(decimals_value), 0)
    except (TypeError, ValueError):
        decimals_int = 0
    if decimals_int <= 0:
        return amount_float
    # Heuristic to avoid double-scaling values that are already decimal:
    # only divide when the magnitude suggests a raw base-unit amount.
    # For SOL a fixed 1e5 threshold is used; for other mints the threshold
    # is 10**decimals — NOTE(review): both thresholds are heuristics and can
    # misclassify genuinely large decimal amounts.
    if mint_addr == SOL_MINT_ADDRESS:
        should_scale = abs(amount_float) >= 1e5
    else:
        should_scale = abs(amount_float) >= (10 ** decimals_int)
    return amount_float / (10 ** decimals_int) if should_scale else amount_float
|
| 1247 |
+
|
| 1248 |
+
pool_created_events = []
|
| 1249 |
+
for pool_record in pool_creation_records:
|
| 1250 |
+
pool_ts_value = _timestamp_to_order_value(pool_record.get('timestamp'))
|
| 1251 |
+
pool_timestamp_int = int(pool_ts_value)
|
| 1252 |
+
|
| 1253 |
+
quote_token_address = pool_record.get('quote_address')
|
| 1254 |
+
|
| 1255 |
+
base_liquidity_raw = pool_record.get('initial_base_liquidity')
|
| 1256 |
+
base_decimals_override = pool_record.get('base_decimals')
|
| 1257 |
+
if base_decimals_override is None:
|
| 1258 |
+
base_decimals_override = main_token_info.get('decimals', base_decimals)
|
| 1259 |
+
base_decimals_value = int(base_decimals_override) if base_decimals_override is not None else int(base_decimals)
|
| 1260 |
+
base_amount_dec = _convert_amount_with_decimals(base_liquidity_raw, token_address)
|
| 1261 |
+
|
| 1262 |
+
quote_liquidity_raw = pool_record.get('initial_quote_liquidity')
|
| 1263 |
+
quote_decimals_override = pool_record.get('quote_decimals')
|
| 1264 |
+
if quote_decimals_override is None:
|
| 1265 |
+
quote_token_info = main_token_data.get(quote_token_address, {})
|
| 1266 |
+
quote_decimals_override = quote_token_info.get('decimals', QUOTE_TOKEN_DECIMALS.get(quote_token_address, 9))
|
| 1267 |
+
if quote_decimals_override is None:
|
| 1268 |
+
quote_decimals_override = 9
|
| 1269 |
+
quote_decimals_value = int(quote_decimals_override)
|
| 1270 |
+
quote_amount_dec = _convert_amount_with_decimals(quote_liquidity_raw, quote_token_address)
|
| 1271 |
+
|
| 1272 |
+
protocol_raw = pool_record.get('protocol')
|
| 1273 |
+
protocol_id = protocol_raw if isinstance(protocol_raw, int) and 0 <= protocol_raw < vocab.NUM_PROTOCOLS else vocab.PROTOCOL_TO_ID.get('Unknown', 0)
|
| 1274 |
+
|
| 1275 |
+
pool_event = {
|
| 1276 |
+
'event_type': 'PoolCreated',
|
| 1277 |
+
'timestamp': pool_timestamp_int,
|
| 1278 |
+
'relative_ts': pool_ts_value - t0_timestamp,
|
| 1279 |
+
'wallet_address': pool_record.get('creator_address'),
|
| 1280 |
+
'token_address': token_address,
|
| 1281 |
+
'protocol_id': protocol_id,
|
| 1282 |
+
'quote_token_address': quote_token_address,
|
| 1283 |
+
'base_amount': base_amount_dec,
|
| 1284 |
+
'quote_amount': quote_amount_dec,
|
| 1285 |
+
'priority_fee': pool_record.get('priority_fee', 0.0),
|
| 1286 |
+
}
|
| 1287 |
+
pool_created_events.append(pool_event)
|
| 1288 |
+
pool_sort_key = _event_execution_sort_key(
|
| 1289 |
+
pool_ts_value,
|
| 1290 |
+
slot=pool_record.get('slot'),
|
| 1291 |
+
transaction_index=0,
|
| 1292 |
+
instruction_index=0,
|
| 1293 |
+
signature=pool_record.get('signature', '')
|
| 1294 |
+
)
|
| 1295 |
+
_register_event(pool_event, pool_sort_key)
|
| 1296 |
+
|
| 1297 |
+
# --- NEW: Convert liquidity change records into structured events ---
|
| 1298 |
+
liquidity_change_events = []
|
| 1299 |
+
for liquidity_record in liquidity_change_records:
|
| 1300 |
+
pool_address = liquidity_record.get('pool_address')
|
| 1301 |
+
pool_meta = pool_metadata_by_address.get(pool_address, {})
|
| 1302 |
+
quote_token_address = pool_meta.get('quote_token_address')
|
| 1303 |
+
|
| 1304 |
+
quote_decimals_override = pool_meta.get('quote_decimals')
|
| 1305 |
+
if quote_decimals_override is None:
|
| 1306 |
+
quote_token_info = main_token_data.get(quote_token_address, {})
|
| 1307 |
+
quote_decimals_override = quote_token_info.get('decimals', QUOTE_TOKEN_DECIMALS.get(quote_token_address, 9))
|
| 1308 |
+
if quote_decimals_override is None:
|
| 1309 |
+
quote_decimals_override = 9
|
| 1310 |
+
|
| 1311 |
+
quote_amount_raw = liquidity_record.get('quote_amount', 0)
|
| 1312 |
+
quote_decimals_value = int(quote_decimals_override)
|
| 1313 |
+
quote_amount_dec = _convert_amount_with_decimals(quote_amount_raw, quote_token_address)
|
| 1314 |
+
|
| 1315 |
+
liquidity_ts_value = _timestamp_to_order_value(liquidity_record.get('timestamp'))
|
| 1316 |
+
liquidity_timestamp_int = int(liquidity_ts_value)
|
| 1317 |
+
|
| 1318 |
+
protocol_raw = liquidity_record.get('protocol')
|
| 1319 |
+
protocol_id = protocol_raw if isinstance(protocol_raw, int) and 0 <= protocol_raw < vocab.NUM_PROTOCOLS else vocab.PROTOCOL_TO_ID.get('Unknown', 0)
|
| 1320 |
+
change_type_id = int(liquidity_record.get('change_type', 0) or 0)
|
| 1321 |
+
|
| 1322 |
+
liquidity_event = {
|
| 1323 |
+
'event_type': 'LiquidityChange',
|
| 1324 |
+
'timestamp': liquidity_timestamp_int,
|
| 1325 |
+
'relative_ts': liquidity_ts_value - t0_timestamp,
|
| 1326 |
+
'wallet_address': liquidity_record.get('lp_provider'),
|
| 1327 |
+
'token_address': token_address,
|
| 1328 |
+
'protocol_id': protocol_id,
|
| 1329 |
+
'quote_token_address': quote_token_address,
|
| 1330 |
+
'change_type_id': change_type_id,
|
| 1331 |
+
'quote_amount': quote_amount_dec,
|
| 1332 |
+
'priority_fee': liquidity_record.get('priority_fee', 0.0),
|
| 1333 |
+
'success': liquidity_record.get('success', False)
|
| 1334 |
+
}
|
| 1335 |
+
|
| 1336 |
+
if quote_token_address:
|
| 1337 |
+
liquidity_change_events.append(liquidity_event)
|
| 1338 |
+
liquidity_sort_key = _event_execution_sort_key(
|
| 1339 |
+
liquidity_ts_value,
|
| 1340 |
+
slot=liquidity_record.get('slot'),
|
| 1341 |
+
transaction_index=0,
|
| 1342 |
+
instruction_index=0,
|
| 1343 |
+
signature=liquidity_record.get('signature', '')
|
| 1344 |
+
)
|
| 1345 |
+
_register_event(liquidity_event, liquidity_sort_key)
|
| 1346 |
+
|
| 1347 |
+
# --- NEW: Convert fee collection records into structured events ---
|
| 1348 |
+
fee_collected_events = []
|
| 1349 |
+
for fee_record in fee_collection_records:
|
| 1350 |
+
fee_ts_value = _timestamp_to_order_value(fee_record.get('timestamp'))
|
| 1351 |
+
fee_timestamp_int = int(fee_ts_value)
|
| 1352 |
+
|
| 1353 |
+
token0_mint = fee_record.get('token_0_mint_address')
|
| 1354 |
+
token1_mint = fee_record.get('token_1_mint_address')
|
| 1355 |
+
token0_amount_raw = fee_record.get('token_0_amount')
|
| 1356 |
+
token1_amount_raw = fee_record.get('token_1_amount')
|
| 1357 |
+
|
| 1358 |
+
sol_amount = 0.0
|
| 1359 |
+
if token0_mint == SOL_MINT_ADDRESS:
|
| 1360 |
+
sol_amount = _convert_amount_with_decimals(token0_amount_raw, SOL_MINT_ADDRESS)
|
| 1361 |
+
elif token1_mint == SOL_MINT_ADDRESS:
|
| 1362 |
+
sol_amount = _convert_amount_with_decimals(token1_amount_raw, SOL_MINT_ADDRESS)
|
| 1363 |
+
|
| 1364 |
+
# Skip if both amounts are zero and no meaningful wallet
|
| 1365 |
+
recipient_addr = fee_record.get('recipient_address')
|
| 1366 |
+
if not recipient_addr:
|
| 1367 |
+
continue
|
| 1368 |
+
|
| 1369 |
+
fee_event = {
|
| 1370 |
+
'event_type': 'FeeCollected',
|
| 1371 |
+
'timestamp': fee_timestamp_int,
|
| 1372 |
+
'relative_ts': fee_ts_value - t0_timestamp,
|
| 1373 |
+
'wallet_address': recipient_addr,
|
| 1374 |
+
'token_address': token_address,
|
| 1375 |
+
'sol_amount': sol_amount,
|
| 1376 |
+
'priority_fee': fee_record.get('priority_fee', 0.0),
|
| 1377 |
+
'protocol_id': fee_record.get('protocol', 0),
|
| 1378 |
+
'success': fee_record.get('success', False),
|
| 1379 |
+
}
|
| 1380 |
+
|
| 1381 |
+
fee_collected_events.append(fee_event)
|
| 1382 |
+
fee_sort_key = _event_execution_sort_key(
|
| 1383 |
+
fee_ts_value,
|
| 1384 |
+
slot=fee_record.get('slot'),
|
| 1385 |
+
transaction_index=0,
|
| 1386 |
+
instruction_index=0,
|
| 1387 |
+
signature=fee_record.get('signature', '')
|
| 1388 |
+
)
|
| 1389 |
+
_register_event(fee_event, fee_sort_key)
|
| 1390 |
+
|
| 1391 |
+
# --- NEW: Convert burn records into structured TokenBurn events ---
|
| 1392 |
+
token_burn_events = []
|
| 1393 |
+
for burn in burn_records:
|
| 1394 |
+
burn_ts_value = _timestamp_to_order_value(burn.get('timestamp'))
|
| 1395 |
+
burn_timestamp_int = int(burn_ts_value)
|
| 1396 |
+
|
| 1397 |
+
amount_dec = burn.get('amount_decimal')
|
| 1398 |
+
if amount_dec is None:
|
| 1399 |
+
raw_amount = burn.get('amount', 0)
|
| 1400 |
+
try:
|
| 1401 |
+
raw_amount = float(raw_amount)
|
| 1402 |
+
except (TypeError, ValueError):
|
| 1403 |
+
raw_amount = 0.0
|
| 1404 |
+
amount_dec = raw_amount / (10**base_decimals) if base_decimals and base_decimals > 0 else raw_amount
|
| 1405 |
+
|
| 1406 |
+
pct_of_supply = (amount_dec / total_supply_dec) if total_supply_dec and total_supply_dec > 0 else 0.0
|
| 1407 |
+
|
| 1408 |
+
burn_event = {
|
| 1409 |
+
'event_type': 'TokenBurn',
|
| 1410 |
+
'timestamp': burn_timestamp_int,
|
| 1411 |
+
'relative_ts': burn_ts_value - t0_timestamp,
|
| 1412 |
+
'wallet_address': burn.get('source'),
|
| 1413 |
+
'token_address': token_address,
|
| 1414 |
+
'amount_pct_of_total_supply': pct_of_supply,
|
| 1415 |
+
'amount_tokens_burned': amount_dec,
|
| 1416 |
+
'priority_fee': burn.get('priority_fee', 0.0),
|
| 1417 |
+
'success': burn.get('success', False),
|
| 1418 |
+
}
|
| 1419 |
+
token_burn_events.append(burn_event)
|
| 1420 |
+
burn_sort_key = _event_execution_sort_key(
|
| 1421 |
+
burn_ts_value,
|
| 1422 |
+
slot=burn.get('slot'),
|
| 1423 |
+
transaction_index=0,
|
| 1424 |
+
instruction_index=0,
|
| 1425 |
+
signature=burn.get('signature', '')
|
| 1426 |
+
)
|
| 1427 |
+
_register_event(burn_event, burn_sort_key)
|
| 1428 |
+
|
| 1429 |
+
# --- NEW: Convert migrations into Migrated events ---
|
| 1430 |
+
for mig in migration_records:
|
| 1431 |
+
mig_ts_value = _timestamp_to_order_value(mig.get('timestamp'))
|
| 1432 |
+
mig_timestamp_int = int(mig_ts_value)
|
| 1433 |
+
prot_raw = mig.get('protocol', 0)
|
| 1434 |
+
protocol_id = prot_raw if isinstance(prot_raw, int) and 0 <= prot_raw < vocab.NUM_PROTOCOLS else vocab.PROTOCOL_TO_ID.get('Unknown', 0)
|
| 1435 |
+
mig_event = {
|
| 1436 |
+
'event_type': 'Migrated',
|
| 1437 |
+
'timestamp': mig_timestamp_int,
|
| 1438 |
+
'relative_ts': mig_ts_value - t0_timestamp,
|
| 1439 |
+
'protocol_id': protocol_id,
|
| 1440 |
+
}
|
| 1441 |
+
mig_sort_key = _event_execution_sort_key(
|
| 1442 |
+
mig_ts_value,
|
| 1443 |
+
slot=mig.get('slot'),
|
| 1444 |
+
transaction_index=0,
|
| 1445 |
+
instruction_index=0,
|
| 1446 |
+
signature=mig.get('signature', '')
|
| 1447 |
+
)
|
| 1448 |
+
_register_event(mig_event, mig_sort_key)
|
| 1449 |
+
|
| 1450 |
+
# NOTE: HolderSnapshot events are generated per-snapshot time inside _generate_onchain_snapshots
|
| 1451 |
+
|
| 1452 |
+
# --- NEW: Convert supply lock records into structured SupplyLock events ---
|
| 1453 |
+
supply_lock_events = []
|
| 1454 |
+
for lock in supply_lock_records:
|
| 1455 |
+
lock_ts_value = _timestamp_to_order_value(lock.get('timestamp'))
|
| 1456 |
+
lock_timestamp_int = int(lock_ts_value)
|
| 1457 |
+
|
| 1458 |
+
# total_locked_amount is Float64, typically already decimal-scaled
|
| 1459 |
+
raw_locked = lock.get('total_locked_amount', 0.0)
|
| 1460 |
+
try:
|
| 1461 |
+
locked_amount = float(raw_locked)
|
| 1462 |
+
except (TypeError, ValueError):
|
| 1463 |
+
locked_amount = 0.0
|
| 1464 |
+
|
| 1465 |
+
pct_of_supply = (locked_amount / total_supply_dec) if total_supply_dec and total_supply_dec > 0 else 0.0
|
| 1466 |
+
|
| 1467 |
+
final_unlock_ts = lock.get('final_unlock_timestamp') or 0
|
| 1468 |
+
try:
|
| 1469 |
+
final_unlock_ts = int(final_unlock_ts)
|
| 1470 |
+
except (TypeError, ValueError):
|
| 1471 |
+
final_unlock_ts = 0
|
| 1472 |
+
lock_duration = max(0, final_unlock_ts - lock_timestamp_int)
|
| 1473 |
+
|
| 1474 |
+
lock_event = {
|
| 1475 |
+
'event_type': 'SupplyLock',
|
| 1476 |
+
'timestamp': lock_timestamp_int,
|
| 1477 |
+
'relative_ts': lock_ts_value - t0_timestamp,
|
| 1478 |
+
'wallet_address': lock.get('sender'),
|
| 1479 |
+
'token_address': token_address,
|
| 1480 |
+
'amount_pct_of_total_supply': pct_of_supply,
|
| 1481 |
+
'lock_duration': float(lock_duration),
|
| 1482 |
+
'priority_fee': lock.get('priority_fee', 0.0),
|
| 1483 |
+
'success': lock.get('success', False),
|
| 1484 |
+
}
|
| 1485 |
+
supply_lock_events.append(lock_event)
|
| 1486 |
+
lock_sort_key = _event_execution_sort_key(
|
| 1487 |
+
lock_ts_value,
|
| 1488 |
+
slot=lock.get('slot'),
|
| 1489 |
+
transaction_index=0,
|
| 1490 |
+
instruction_index=0,
|
| 1491 |
+
signature=lock.get('signature', '')
|
| 1492 |
+
)
|
| 1493 |
+
_register_event(lock_event, lock_sort_key)
|
| 1494 |
+
|
| 1495 |
+
# --- NEW: Process transfer events with strict validation ---
|
| 1496 |
+
transfer_events = []
|
| 1497 |
+
for transfer in transfer_records:
|
| 1498 |
+
print("BOMBOCLAT TRANSFER", transfer)
|
| 1499 |
+
# --- VALIDATION: Ensure the destination wallet has a valid profile ---
|
| 1500 |
+
if transfer['destination'] not in wallet_data:
|
| 1501 |
+
print(f"INFO: Skipping transfer event {transfer['signature']} because destination wallet {transfer['destination']} has no profile.")
|
| 1502 |
+
continue
|
| 1503 |
+
|
| 1504 |
+
# Calculate features
|
| 1505 |
+
token_amount = transfer.get('amount_decimal', 0.0)
|
| 1506 |
+
pct_of_supply = (token_amount / total_supply_dec) if total_supply_dec > 0 else 0.0
|
| 1507 |
+
|
| 1508 |
+
# Reconstruct pre-transfer balance of the source wallet
|
| 1509 |
+
pre_transfer_source_balance = transfer.get('source_balance', 0.0) + token_amount
|
| 1510 |
+
pct_of_holding = (token_amount / pre_transfer_source_balance) if pre_transfer_source_balance > 1e-9 else 1.0
|
| 1511 |
+
|
| 1512 |
+
# --- NEW: Classify LargeTransfer based on supply percentage ---
|
| 1513 |
+
if pct_of_supply > LARGE_TRANSFER_SUPPLY_PCT_THRESHOLD:
|
| 1514 |
+
event_type = 'LargeTransfer'
|
| 1515 |
+
else:
|
| 1516 |
+
event_type = 'Transfer'
|
| 1517 |
+
|
| 1518 |
+
transfer_ts_value = _timestamp_to_order_value(transfer.get('timestamp'))
|
| 1519 |
+
transfer_event = {
|
| 1520 |
+
'event_type': event_type,
|
| 1521 |
+
'timestamp': int(transfer_ts_value),
|
| 1522 |
+
'relative_ts': transfer_ts_value - t0_timestamp,
|
| 1523 |
+
'wallet_address': transfer['source'],
|
| 1524 |
+
'destination_wallet_address': transfer['destination'],
|
| 1525 |
+
'token_address': token_address,
|
| 1526 |
+
'token_amount': token_amount,
|
| 1527 |
+
'transfer_pct_of_total_supply': pct_of_supply,
|
| 1528 |
+
'transfer_pct_of_holding': pct_of_holding,
|
| 1529 |
+
'priority_fee': transfer.get('priority_fee', 0.0)
|
| 1530 |
+
}
|
| 1531 |
+
transfer_events.append(transfer_event)
|
| 1532 |
+
transfer_sort_key = _event_execution_sort_key(
|
| 1533 |
+
transfer_ts_value,
|
| 1534 |
+
slot=transfer.get('slot'),
|
| 1535 |
+
transaction_index=transfer.get('transaction_index'),
|
| 1536 |
+
instruction_index=transfer.get('instruction_index'),
|
| 1537 |
+
signature=transfer.get('signature', '')
|
| 1538 |
+
)
|
| 1539 |
+
_register_event(transfer_event, transfer_sort_key)
|
| 1540 |
+
|
| 1541 |
+
# --- NEW: Correctly detect bundles with a single pass after event creation ---
|
| 1542 |
+
# trade_records are ordered by (timestamp, slot, transaction_index, instruction_index),
|
| 1543 |
+
# so adjacent entries that share a slot belong to the same bundle.
|
| 1544 |
+
if len(trade_records) > 1:
|
| 1545 |
+
for i in range(1, len(trade_records)):
|
| 1546 |
+
if trade_records[i]['slot'] == trade_records[i-1]['slot']:
|
| 1547 |
+
# The corresponding events are at the same indices in trade_events
|
| 1548 |
+
trade_events[i]['is_bundle'] = True
|
| 1549 |
+
trade_events[i-1]['is_bundle'] = True
|
| 1550 |
+
|
| 1551 |
+
# Generate OnChain_Snapshot events using helper
|
| 1552 |
+
self._generate_onchain_snapshots(
|
| 1553 |
+
token_address=token_address,
|
| 1554 |
+
t0_timestamp=t0_timestamp,
|
| 1555 |
+
T_cutoff=T_cutoff,
|
| 1556 |
+
interval_sec=HOLDER_SNAPSHOT_INTERVAL_SEC,
|
| 1557 |
+
trade_events=trade_events,
|
| 1558 |
+
transfer_events=transfer_events,
|
| 1559 |
+
aggregation_trades=aggregation_trades,
|
| 1560 |
+
wallet_data=wallet_data,
|
| 1561 |
+
total_supply_dec=total_supply_dec,
|
| 1562 |
+
_register_event_fn=_register_event
|
| 1563 |
+
)
|
| 1564 |
+
|
| 1565 |
+
# 7. TODO: Fetch social events (tweets, replies, etc.) for all discovered wallets
|
| 1566 |
+
# - Query tables like 'x_posts', 'pump_replies'.
|
| 1567 |
+
# - Use the pooler to get indices for text and media.
|
| 1568 |
+
|
| 1569 |
+
# Sort the combined event sequence by precise execution order
|
| 1570 |
+
event_sequence_entries.sort(key=lambda entry: entry[0])
|
| 1571 |
+
event_sequence = [event for _, event in event_sequence_entries]
|
| 1572 |
+
|
| 1573 |
+
anchor_timestamp_int = int(_timestamp_to_order_value(T_cutoff))
|
| 1574 |
+
anchor_price = None
|
| 1575 |
+
if aggregation_trades:
|
| 1576 |
+
for trade in reversed(aggregation_trades):
|
| 1577 |
+
price_val = trade.get('price_usd')
|
| 1578 |
+
if price_val is not None:
|
| 1579 |
+
anchor_price = float(price_val)
|
| 1580 |
+
break
|
| 1581 |
+
if self.num_outputs > 0 and anchor_price is None:
|
| 1582 |
+
print(f"INFO: Skipping token {token_address} (no pre-cutoff price for labeling).")
|
| 1583 |
+
return None
|
| 1584 |
+
|
| 1585 |
+
future_price_series: List[Tuple[int, float]] = []
|
| 1586 |
+
if (self.num_outputs > 0 and max_horizon_seconds > 0 and
|
| 1587 |
+
anchor_price is not None):
|
| 1588 |
+
timeline = [(anchor_timestamp_int, anchor_price)]
|
| 1589 |
+
for trade in future_trades_for_labels:
|
| 1590 |
+
price_val = trade.get('price_usd')
|
| 1591 |
+
if price_val is None:
|
| 1592 |
+
continue
|
| 1593 |
+
ts_int = int(_timestamp_to_order_value(trade.get('timestamp')))
|
| 1594 |
+
if ts_int <= timeline[-1][0]:
|
| 1595 |
+
continue
|
| 1596 |
+
timeline.append((ts_int, float(price_val)))
|
| 1597 |
+
if len(timeline) > 1:
|
| 1598 |
+
future_price_series = timeline
|
| 1599 |
+
|
| 1600 |
+
debug_label_entries: List[Dict[str, Any]] = []
|
| 1601 |
+
if self.num_outputs > 0:
|
| 1602 |
+
labels_tensor, labels_mask_tensor, debug_label_entries = self._compute_future_return_labels(
|
| 1603 |
+
anchor_price, anchor_timestamp_int, future_price_series
|
| 1604 |
+
)
|
| 1605 |
+
if labels_mask_tensor.sum() == 0:
|
| 1606 |
+
print(f"INFO: Skipping token {token_address} (no valid horizons in future).")
|
| 1607 |
+
return None
|
| 1608 |
+
print("\n[Label Debug]")
|
| 1609 |
+
for entry in debug_label_entries:
|
| 1610 |
+
print(f" Horizon {entry['horizon']}s -> target_ts={entry['target_ts']}, "
|
| 1611 |
+
f"future_price={entry['future_price']}, return={entry['return']:.6f}, "
|
| 1612 |
+
f"mask={int(entry['mask'])}")
|
| 1613 |
+
else:
|
| 1614 |
+
labels_tensor = torch.zeros(0)
|
| 1615 |
+
labels_mask_tensor = torch.zeros(0)
|
| 1616 |
+
|
| 1617 |
+
# For now, we'll return the item with mint and trade events
|
| 1618 |
+
item = {
|
| 1619 |
+
'event_sequence': event_sequence,
|
| 1620 |
+
'wallets': wallet_data,
|
| 1621 |
+
'tokens': all_token_data, # FIXED: Use the comprehensive token data
|
| 1622 |
+
'graph_links': graph_links, # NEW: Add the fetched graph links
|
| 1623 |
+
'embedding_pooler': pooler,
|
| 1624 |
+
'labels': labels_tensor,
|
| 1625 |
+
'labels_mask': labels_mask_tensor}
|
| 1626 |
+
|
| 1627 |
+
# --- NEW: Comprehensive logging before returning the item ---
|
| 1628 |
+
print("\n--- Dataset Item Generation Summary ---")
|
| 1629 |
+
print(f"Token Address: {token_address}"
|
| 1630 |
+
)
|
| 1631 |
+
print(f"\n[Event Sequence] ({len(item['event_sequence'])} events):")
|
| 1632 |
+
for i, event in enumerate(item['event_sequence']):
|
| 1633 |
+
print(f" - Event {i}: {event}")
|
| 1634 |
+
|
| 1635 |
+
print(f"\n[Wallets] ({len(item['wallets'])} wallets):")
|
| 1636 |
+
for i, (addr, data) in enumerate(item['wallets'].items()):
|
| 1637 |
+
print(f" - Wallet {addr}:")
|
| 1638 |
+
print(f" - Profile: {data.get('profile', {})}")
|
| 1639 |
+
print(f" - Socials: {data.get('socials', {})}")
|
| 1640 |
+
|
| 1641 |
+
print(f"\n[Tokens] ({len(item['tokens'])} tokens):")
|
| 1642 |
+
for addr, data in item['tokens'].items():
|
| 1643 |
+
print(f" - Token {addr}: {data}")
|
| 1644 |
+
|
| 1645 |
+
if self.num_outputs > 0:
|
| 1646 |
+
print(f"\n[Labels]")
|
| 1647 |
+
for h_idx, horizon in enumerate(self.horizons_seconds):
|
| 1648 |
+
offset = h_idx * len(self.quantiles)
|
| 1649 |
+
values = item['labels'][offset:offset + len(self.quantiles)]
|
| 1650 |
+
masks = item['labels_mask'][offset:offset + len(self.quantiles)]
|
| 1651 |
+
print(f" Horizon {horizon}s:")
|
| 1652 |
+
for q_idx, quantile in enumerate(self.quantiles):
|
| 1653 |
+
print(f" q={quantile:.2f}: value={values[q_idx]:.6f}, mask={masks[q_idx]:.0f}")
|
| 1654 |
+
|
| 1655 |
+
print("--- End Summary ---\n")
|
| 1656 |
+
|
| 1657 |
+
return item
|
data/ohlc_stats.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f39f15281440244b927a46d14a85537afd891163556d46ee3a79c80c25b6f36b
|
| 3 |
+
size 1660
|
data/preprocess_distribution.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Preprocess distribution statistics for OHLC normalization and token history coverage.
|
| 4 |
+
|
| 5 |
+
This script:
|
| 6 |
+
1. Computes global mean/std figures for price/volume so downstream code can normalize.
|
| 7 |
+
2. Prints descriptive stats about how much price history (in seconds) each token has,
|
| 8 |
+
helping decide which horizons are realistic.
|
| 9 |
+
|
| 10 |
+
All configuration is done via environment variables (see below).
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import pathlib
|
| 15 |
+
import sys
|
| 16 |
+
from typing import List
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import clickhouse_connect
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# --- Configuration (override via env vars if needed) ---
# ClickHouse connection settings; defaults target a local instance.
CLICKHOUSE_HOST = os.getenv("CLICKHOUSE_HOST", "localhost")
CLICKHOUSE_PORT = int(os.getenv("CLICKHOUSE_PORT", "8123"))
CLICKHOUSE_USERNAME = os.getenv("CLICKHOUSE_USERNAME", "default")
CLICKHOUSE_PASSWORD = os.getenv("CLICKHOUSE_PASSWORD", "")
CLICKHOUSE_DATABASE = os.getenv("CLICKHOUSE_DATABASE", "default")

# Destination for the computed normalization stats (numpy .npz archive).
OUTPUT_PATH = pathlib.Path(os.getenv("OHLC_STATS_PATH", "ohlc_stats.npz"))
# Lower bounds applied to trades before aggregation (used as strict '>' filters).
MIN_PRICE_USD = float(os.getenv("OHLC_MIN_PRICE_USD", "0.0"))
MIN_VOLUME_USD = float(os.getenv("OHLC_MIN_VOLUME_USD", "0.0"))

# Optional comma-separated allow-list of token mint addresses; None means all tokens.
TOKEN_ADDRESSES_ENV = os.getenv("OHLC_TOKEN_ADDRESSES", "")
TOKEN_ADDRESSES = tuple(addr.strip() for addr in TOKEN_ADDRESSES_ENV.split(",") if addr.strip()) or None
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def build_where_clause() -> List[str]:
    """Assemble the parameterized WHERE-clause fragments shared by both queries.

    Always filters on minimum price and minimum trade value; a token-address
    restriction is appended only when ``TOKEN_ADDRESSES`` was configured via
    the environment. Placeholder values are bound at query time.
    """
    filters: List[str] = [
        "t.price_usd > %(min_price)s",
        "t.total_usd > %(min_vol)s",
    ]
    if TOKEN_ADDRESSES:
        filters = filters + ["t.base_address IN %(token_addresses)s"]
    return filters
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def build_stats_query(where_sql: str) -> str:
    """Return the SQL computing global mean/std for price and trade value.

    Uses ClickHouse aggregates (AVG, stddevPop) over ``trades`` joined to
    ``mints``. ``where_sql`` is spliced in verbatim and is expected to come
    from ``build_where_clause``; its ``%(name)s`` placeholders are bound when
    the query is executed.
    """
    return f"""
    SELECT
        AVG(t.price_usd) AS mean_price_usd,
        stddevPop(t.price_usd) AS std_price_usd,
        AVG(t.price) AS mean_price_native,
        stddevPop(t.price) AS std_price_native,
        AVG(t.total_usd) AS mean_trade_value_usd,
        stddevPop(t.total_usd) AS std_trade_value_usd
    FROM trades AS t
    INNER JOIN mints AS m
        ON m.mint_address = t.base_address
    WHERE {where_sql}
    """
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def build_history_query(where_sql: str) -> str:
    """Return the SQL computing per-token trade-history coverage in seconds.

    For each token (``base_address``) the query yields the first/last trade
    timestamps (as Unix seconds) and their difference (``history_seconds``).
    ``where_sql`` is spliced in verbatim; placeholders are bound at query time.
    """
    return f"""
    SELECT
        t.base_address AS token_address,
        toUnixTimestamp(min(t.timestamp)) AS first_ts,
        toUnixTimestamp(max(t.timestamp)) AS last_ts,
        toUnixTimestamp(max(t.timestamp)) - toUnixTimestamp(min(t.timestamp)) AS history_seconds
    FROM trades AS t
    INNER JOIN mints AS m
        ON m.mint_address = t.base_address
    WHERE {where_sql}
    GROUP BY token_address
    """
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def summarize_histories(histories: np.ndarray) -> None:
    """Print descriptive statistics of per-token history durations (seconds).

    Prints count, min, median, mean, 90th percentile and max, each rendered
    with hour/day equivalents. Emits a short notice and returns early when
    ``histories`` is empty.
    """
    if histories.size == 0:
        print("No token history stats available (no qualifying trades).")
        return

    def _fmt(sec: float) -> str:
        # Render seconds alongside their hour/day equivalents.
        hrs = sec / 3600.0
        return f"{sec:.0f}s ({hrs:.2f}h / {hrs / 24.0:.2f}d)"

    print("\nToken history coverage (seconds):")
    print(f"  Tokens analyzed: {int(histories.size)}")
    print(f"  Min history: {_fmt(histories.min())}")
    print(f"  Median history: {_fmt(float(np.median(histories)))}")
    print(f"  Mean history: {_fmt(histories.mean())}")
    print(f"  90th percentile: {_fmt(float(np.percentile(histories, 90)))}")
    print(f"  Max history: {_fmt(histories.max())}")
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def main() -> int:
    """Compute and persist OHLC normalization stats, then print history coverage.

    Workflow:
      1. Build the shared WHERE clause and bind parameters from env config.
      2. Query ClickHouse for global price/volume mean and std.
      3. Save the stats to ``OUTPUT_PATH`` as a .npz archive and echo them.
      4. Query per-token history durations and print summary statistics.

    Returns:
        Process exit code: 0 on success, 1 when the stats query yields no rows.
    """
    where_clauses = build_where_clause()
    # Fall back to a tautology so the WHERE keyword always has a predicate.
    where_sql = " AND ".join(where_clauses) if where_clauses else "1"
    params: dict[str, object] = {
        # Clamp to non-negative so the '>' filters never admit negative values.
        "min_price": max(MIN_PRICE_USD, 0.0),
        "min_vol": max(MIN_VOLUME_USD, 0.0),
    }
    if TOKEN_ADDRESSES:
        params["token_addresses"] = TOKEN_ADDRESSES

    client = clickhouse_connect.get_client(
        host=CLICKHOUSE_HOST,
        port=CLICKHOUSE_PORT,
        username=CLICKHOUSE_USERNAME,
        password=CLICKHOUSE_PASSWORD,
        database=CLICKHOUSE_DATABASE,
    )

    # --- Price/volume stats ---
    stats_query = build_stats_query(where_sql)
    stats_result = client.query(stats_query, parameters=params)
    if not stats_result.result_rows:
        print("ERROR: Stats query returned no rows. Check filters / connectivity.", file=sys.stderr)
        return 1
    # Column order matches the SELECT list in build_stats_query.
    (
        mean_price_usd,
        std_price_usd,
        mean_price_native,
        std_price_native,
        mean_trade_value_usd,
        std_trade_value_usd,
    ) = stats_result.result_rows[0]

    # NULL/zero stds default to 1.0 so downstream normalization never divides
    # by zero; NULL means default to 0.0.
    stats = {
        "mean_price_usd": float(mean_price_usd or 0.0),
        "std_price_usd": float(std_price_usd or 1.0),
        "mean_price_native": float(mean_price_native or 0.0),
        "std_price_native": float(std_price_native or 1.0),
        "mean_trade_value_usd": float(mean_trade_value_usd or 0.0),
        "std_trade_value_usd": float(std_trade_value_usd or 1.0),
    }

    OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    np.savez(OUTPUT_PATH, **stats)

    print(f"Saved stats to {OUTPUT_PATH.resolve()}:")
    for key, value in stats.items():
        print(f"  {key}: {value:.6f}")

    # --- Token history coverage ---
    history_query = build_history_query(where_sql)
    history_result = client.query(history_query, parameters=params)
    # row[3] is history_seconds (see build_history_query); NULLs are dropped.
    history_seconds = np.array(
        [float(row[3]) for row in history_result.result_rows if row[3] is not None],
        dtype=np.float64
    )
    summarize_histories(history_seconds)
    return 0
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
if __name__ == "__main__":
|
| 164 |
+
raise SystemExit(main())
|
graph_schema.rs
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/// Tracks direct capital flow and identifies funding chains.
pub struct TransferLink {
    /// Transaction signature of the transfer.
    pub signature: String,
    /// Sending wallet address.
    pub source: String,
    /// Receiving wallet address.
    pub destination: String,
    /// Mint (token) address that was transferred.
    pub mint: String,
    /// Transfer time (presumably Unix seconds — confirm units against ingestion).
    pub timestamp: i64,
}
|
| 9 |
+
|
| 10 |
+
/// Identifies wallets trading the same token in the same slot.
pub struct BundleTradeLink {
    /// Signatures of the co-located trades.
    pub signatures: Vec<String>,
    /// First wallet in the bundle pair.
    pub wallet_a: String,
    /// Second wallet in the bundle pair.
    pub wallet_b: String,
    /// Token mint both wallets traded.
    pub mint: String,
    /// Slot in which both trades landed.
    pub slot: i64,
    /// Trade time (presumably Unix seconds — confirm units).
    pub timestamp: i64,
}
|
| 19 |
+
|
| 20 |
+
/// Reveals a behavioral pattern of one wallet mirroring another's successful trade.
pub struct CopiedTradeLink {
    /// Link time (presumably Unix seconds — confirm units).
    pub timestamp: i64,
    /// Signature of the leader's opening buy.
    pub leader_buy_sig: String,
    /// Signature of the leader's closing sell.
    pub leader_sell_sig: String,
    /// Signature of the follower's mirrored buy.
    pub follower_buy_sig: String,
    /// Signature of the follower's mirrored sell.
    pub follower_sell_sig: String,
    /// Wallet that copied the trade.
    pub follower: String,
    /// Wallet that was copied.
    pub leader: String,
    /// Token mint the round-trip was traded on.
    pub mint: String,
    /// Delay between leader and follower buys, in seconds.
    pub time_gap_on_buy_sec: i64,
    /// Delay between leader and follower sells, in seconds.
    pub time_gap_on_sell_sec: i64,
    /// Leader's profit/loss on the round trip.
    pub leader_pnl: f64,
    /// Follower's profit/loss on the round trip.
    pub follower_pnl: f64,

    /// Leader's total buy value.
    pub leader_buy_total: f64,
    /// Leader's total sell value.
    pub leader_sell_total: f64,

    /// Follower's total buy value.
    pub follower_buy_total: f64,
    /// Follower's total sell value.
    pub follower_sell_total: f64,
    /// Slippage the follower incurred on the buy leg.
    pub follower_buy_slippage: f32,
    /// Slippage the follower incurred on the sell leg.
    pub follower_sell_slippage: f32,
}
|
| 43 |
+
|
| 44 |
+
/// Represents a link where a group of wallets re-engage with a token in a coordinated manner.
pub struct CoordinatedActivityLink {
    /// Time the link was observed (Unix epoch seconds — TODO confirm unit).
    pub timestamp: i64,
    /// Signature of the leader's first engagement with the token.
    pub leader_first_sig: String,
    /// Signature of the leader's second engagement.
    pub leader_second_sig: String,
    /// Signature of the follower's first engagement.
    pub follower_first_sig: String,
    /// Signature of the follower's second engagement.
    pub follower_second_sig: String,
    /// Wallet that followed the leader's activity.
    pub follower: String,
    /// Wallet that acted first.
    pub leader: String,
    /// Mint address of the token both wallets engaged with.
    pub mint: String,
    /// Seconds between leader's and follower's first engagements.
    pub time_gap_on_first_sec: i64,
    /// Seconds between leader's and follower's second engagements.
    pub time_gap_on_second_sec: i64,
}
|
| 57 |
+
|
| 58 |
+
/// Links a token to its original creator.
///
/// NOTE(review): carries no wallet/mint fields — presumably the endpoints
/// are stored on the graph edge itself; verify against the writer.
pub struct MintedLink {
    /// Signature of the mint (token-creation) transaction.
    pub signature: String,
    /// Time of the mint (Unix epoch seconds — TODO confirm unit).
    pub timestamp: i64,
    /// Amount the creator bought at creation (units unverified — TODO confirm).
    pub buy_amount: f64,
}
|
| 64 |
+
|
| 65 |
+
/// Connects a token to its successful first-movers.
pub struct SnipedLink {
    /// Time of the snipe (Unix epoch seconds — TODO confirm unit).
    pub timestamp: i64,
    /// Signature of the sniping buy transaction.
    pub signature: String,
    /// Order of this sniper among the token's earliest buyers (1 = first — TODO confirm base).
    pub rank: i64,
    /// Amount acquired in the snipe (units unverified — TODO confirm).
    pub sniped_amount: f64,
}
|
| 72 |
+
|
| 73 |
+
/// Links a token to a wallet that locked part of its supply.
pub struct LockedSupplyLink {
    /// Time of the lock (Unix epoch seconds — TODO confirm unit).
    pub timestamp: i64,
    /// Signature of the locking transaction.
    pub signature: String,
    /// Amount of supply locked (units unverified — TODO confirm).
    pub amount: f64,
    /// When the locked supply becomes available again.
    /// NOTE(review): `u64` while every other timestamp here is `i64` — confirm intentional.
    pub unlock_timestamp: u64,
}
|
| 80 |
+
|
| 81 |
+
/// Links a token to a wallet that burned tokens.
pub struct BurnedLink {
    /// Signature of the burn transaction.
    pub signature: String,
    /// Amount of tokens burned (units unverified — TODO confirm).
    pub amount: f64,
    /// Time of the burn (Unix epoch seconds — TODO confirm unit).
    pub timestamp: i64,
}
|
| 87 |
+
|
| 88 |
+
/// Identifies wallets that provided liquidity, signaling high conviction.
pub struct ProvidedLiquidityLink {
    /// Signature of the liquidity-provision transaction.
    pub signature: String,
    /// Wallet that supplied the liquidity.
    pub wallet: String,
    /// Mint address of the token side of the pool.
    pub token: String,
    /// Address of the liquidity pool.
    pub pool_address: String,
    /// Amount of the base asset deposited (units unverified — TODO confirm).
    pub amount_base: f64,
    /// Amount of the quote asset deposited (units unverified — TODO confirm).
    pub amount_quote: f64,
    /// Time of the deposit (Unix epoch seconds — TODO confirm unit).
    pub timestamp: i64,
}
|
| 98 |
+
|
| 99 |
+
/// A derived link connecting a token to its largest holders.
pub struct WhaleOfLink {
    /// Time the link was derived (Unix epoch seconds — TODO confirm unit).
    pub timestamp: i64,
    /// Holder wallet address.
    pub wallet: String,
    /// Mint address of the held token.
    pub token: String,
    /// Wallet's holding as a percentage of supply when the link was made.
    pub holding_pct_at_creation: f32,
    /// Token's all-time-high price in USD when the link was made.
    pub ath_usd_at_creation: f64,
}
|
| 107 |
+
|
| 108 |
+
/// A derived link connecting a token to its most profitable traders.
pub struct TopTraderOfLink {
    /// Time the link was derived (Unix epoch seconds — TODO confirm unit).
    pub timestamp: i64,
    /// Trader wallet address.
    pub wallet: String,
    /// Mint address of the traded token.
    pub token: String,
    /// The PNL that first triggered the link (presumably USD — verify against producer).
    pub pnl_at_creation: f64,
    /// Token's all-time-high price in USD when the link was made.
    pub ath_usd_at_creation: f64,
}
|
inference.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# inference.py
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import traceback
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
# Import all the necessary components from our project
|
| 8 |
+
from models.model import Oracle
|
| 9 |
+
from data.data_collator import MemecoinCollator
|
| 10 |
+
from models.multi_modal_processor import MultiModalEncoder
|
| 11 |
+
from data.data_loader import OracleDataset
|
| 12 |
+
from data.data_fetcher import DataFetcher
|
| 13 |
+
from models.helper_encoders import ContextualTimeEncoder
|
| 14 |
+
from models.token_encoder import TokenEncoder
|
| 15 |
+
from models.wallet_encoder import WalletEncoder
|
| 16 |
+
from models.graph_updater import GraphUpdater
|
| 17 |
+
from models.ohlc_embedder import OHLCEmbedder
|
| 18 |
+
import models.vocabulary as vocab
|
| 19 |
+
|
| 20 |
+
# --- NEW: Import database clients ---
|
| 21 |
+
from clickhouse_driver import Client as ClickHouseClient
|
| 22 |
+
from neo4j import GraphDatabase
|
| 23 |
+
|
| 24 |
+
# =============================================================================
|
| 25 |
+
# Inference/Test Script for the Oracle Model
|
| 26 |
+
# This script replicates the test logic previously in model.py
|
| 27 |
+
# =============================================================================
|
| 28 |
+
if __name__ == "__main__":
    # End-to-end smoke test of the Oracle encoding pipeline:
    # encoders -> collator -> Oracle.forward, driven by real samples fetched
    # from local ClickHouse + Neo4j instances. Exits early (via exit()) on any
    # setup failure; prints diagnostics throughout.
    print("--- Oracle Inference Script (Full Pipeline Test) ---")

    # --- 1. Define Configs ---
    OHLC_SEQ_LEN = 60  # fixed OHLC window length expected by collator and embedder
    print(f"Using {vocab.NUM_EVENT_TYPES} event types from vocabulary.")

    # Pick device and the matching reduced-precision dtype; CPU falls back to fp32.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
    if device.type == 'cpu': dtype = torch.float32
    print(f"Using device: {device}, dtype: {dtype}")

    # Quantile/horizon grid for the prediction head (5 horizons x 3 quantiles).
    _test_quantiles = [0.1, 0.5, 0.9]
    _test_horizons = [30, 60, 120, 240, 420]
    _test_num_outputs = len(_test_quantiles) * len(_test_horizons)

    # --- 2. Instantiate ALL Encoders ---
    print("Instantiating encoders (using defaults)...")
    try:
        multi_modal_encoder = MultiModalEncoder(dtype=dtype)
        real_time_enc = ContextualTimeEncoder(dtype=dtype)

        real_token_enc = TokenEncoder(
            multi_dim=multi_modal_encoder.embedding_dim,
            dtype=dtype
        )
        real_wallet_enc = WalletEncoder(encoder=multi_modal_encoder, dtype=dtype)
        real_graph_upd = GraphUpdater(time_encoder=real_time_enc, dtype=dtype)

        real_ohlc_emb = OHLCEmbedder(
            num_intervals=vocab.NUM_OHLC_INTERVALS,
            sequence_length=OHLC_SEQ_LEN,
            dtype=dtype
        )

        print(f"TokenEncoder default output_dim: {real_token_enc.output_dim}")
        print(f"WalletEncoder default d_model: {real_wallet_enc.d_model}")
        print(f"OHLCEmbedder default output_dim: {real_ohlc_emb.output_dim}")

        print("Encoders instantiated.")
    except Exception as e:
        print(f"Failed to instantiate encoders: {e}")
        traceback.print_exc()
        exit()

    # --- 3. Instantiate the Collator ---
    collator = MemecoinCollator(
        event_type_to_id=vocab.EVENT_TO_ID,
        device=device,
        multi_modal_encoder=multi_modal_encoder,
        dtype=dtype,
        ohlc_seq_len=OHLC_SEQ_LEN,
        max_seq_len=50
    )
    print("MemecoinCollator (fast batcher) instantiated.")

    # --- 4. Instantiate the Oracle Model ---
    print("Instantiating Oracle (full pipeline)...")
    model = Oracle(
        token_encoder=real_token_enc,
        wallet_encoder=real_wallet_enc,
        graph_updater=real_graph_upd,
        time_encoder=real_time_enc,
        multi_modal_dim=multi_modal_encoder.embedding_dim,
        num_event_types=vocab.NUM_EVENT_TYPES,
        event_pad_id=vocab.EVENT_TO_ID['__PAD__'],
        event_type_to_id=vocab.EVENT_TO_ID,
        model_config_name="Qwen/Qwen3-0.6B",
        quantiles=_test_quantiles,
        horizons_seconds=_test_horizons,
        dtype=dtype,
        ohlc_embedder=real_ohlc_emb
    ).to(device)
    model.eval()
    print(f"Oracle d_model: {model.d_model}")

    # --- 5. Create Dataset and run pre-collation step ---
    print("Creating Dataset...")

    # --- NEW: Initialize real database clients and DataFetcher ---
    try:
        print("Connecting to databases...")
        # ClickHouse running locally, native TCP protocol on port 9000, no auth
        clickhouse_client = ClickHouseClient(host='localhost', port=9000)
        # Neo4j running locally on bolt port 7687 with no auth
        neo4j_driver = GraphDatabase.driver("bolt://localhost:7687", auth=None)

        data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
        print("Database clients and DataFetcher initialized.")

        # --- Fetch mints to get the first token for processing ---
        all_mints = data_fetcher.get_all_mints()
        if not all_mints:
            print("\n❌ No mints found in the database. Exiting test.")
            exit()

        # --- FIXED: Instantiate the dataset in REAL mode, removing is_test flag ---
        # max_samples caps how many mints the dataset will sample for this run.
        dataset = OracleDataset(
            data_fetcher=data_fetcher,
            horizons_seconds=_test_horizons,
            quantiles=_test_quantiles,
            max_samples=57)

    except Exception as e:
        print(f"FATAL: Could not initialize database connections or dataset: {e}")
        traceback.print_exc()
        exit()

    # --- PRODUCTION-READY: Process a full batch of items from the dataset ---
    # Samples that come back as None (failed fetch/build) are silently skipped.
    print(f"\n--- Processing a batch of up to {len(dataset)} items from the dataset ---")
    batch_items = []
    for i in range(len(dataset)):
        token_addr = dataset.sampled_mints[i].get('mint_address', 'unknown')
        print(f" - Attempting to process sample {i+1}/{len(dataset)} ({token_addr})...")
        fetch_start = time.time()
        sample = dataset[i]
        fetch_elapsed = time.time() - fetch_start
        print(f" ... fetch completed in {fetch_elapsed:.2f}s")
        if sample is not None:
            batch_items.append(sample)
            print(f" ... Success! Sample added to batch.")

    if not batch_items:
        print("\n❌ No valid samples could be generated from the dataset. Exiting.")
        exit()

    # --- 6. Run Collator AND Model ---
    print("\n--- Testing Pipeline (Collator + Model.forward) ---")
    try:
        # 1. Collator
        collate_start = time.time()
        collated_batch = collator(batch_items)
        collate_elapsed = time.time() - collate_start
        print("Collation successful!")
        print(f"Collation time for batch of {len(batch_items)} tokens: {collate_elapsed:.2f}s")

        # --- Check collator output ---
        # B = batch size, L = padded event-sequence length for this batch.
        B = len(batch_items)
        L = collated_batch['attention_mask'].shape[1]
        assert 'ohlc_price_tensors' in collated_batch
        ohlc_price_tensors = collated_batch['ohlc_price_tensors']
        # Expected layout: (num_segments, 2 [open/close], OHLC_SEQ_LEN).
        assert ohlc_price_tensors.dim() == 3, f"Expected 3D OHLC tensor, got shape {tuple(ohlc_price_tensors.shape)}"
        assert ohlc_price_tensors.shape[1] == 2, f"Expected OHLC tensor with 2 rows (open/close), got {ohlc_price_tensors.shape[1]}"
        assert ohlc_price_tensors.shape[2] == OHLC_SEQ_LEN, f"Expected OHLC seq len {OHLC_SEQ_LEN}, got {ohlc_price_tensors.shape[2]}"
        assert collated_batch['ohlc_interval_ids'].shape[0] == ohlc_price_tensors.shape[0], "Interval ids must align with OHLC segments"
        assert ohlc_price_tensors.dtype == dtype, f"OHLC tensor dtype {ohlc_price_tensors.dtype} != expected {dtype}"
        print(f"Collator produced {ohlc_price_tensors.shape[0]} OHLC segment(s).")

        # --- FIXED: Update assertions for event-specific data which is mostly empty for now ---
        # Each per-event feature tensor is padded to (B, L, feature_width).
        assert collated_batch['dest_wallet_indices'].shape == (B, L)
        assert collated_batch['transfer_numerical_features'].shape == (B, L, 4)
        assert collated_batch['trade_numerical_features'].shape == (B, L, 8) # Corrected from 10
        assert collated_batch['deployer_trade_numerical_features'].shape == (B, L, 8) # Corrected from 10
        assert collated_batch['smart_wallet_trade_numerical_features'].shape == (B, L, 8) # Corrected from 10
        assert collated_batch['pool_created_numerical_features'].shape == (B, L, 2)
        assert collated_batch['liquidity_change_numerical_features'].shape == (B, L, 1)
        assert collated_batch['fee_collected_numerical_features'].shape == (B, L, 1)
        assert collated_batch['token_burn_numerical_features'].shape == (B, L, 2)
        assert collated_batch['supply_lock_numerical_features'].shape == (B, L, 2)
        assert collated_batch['onchain_snapshot_numerical_features'].shape == (B, L, 14)
        assert collated_batch['trending_token_numerical_features'].shape == (B, L, 1)
        assert collated_batch['boosted_token_numerical_features'].shape == (B, L, 2)
        # assert len(collated_batch['holder_snapshot_raw_data']) == 1 # No holder snapshots yet
        # assert len(collated_batch['textual_event_data']) == 8 # No textual events yet
        assert collated_batch['dexboost_paid_numerical_features'].shape == (B, L, 2)
        print("Collator correctly processed all event-specific numerical data into their respective tensors.")

        # --- NEW: Comprehensive Debugging Output ---
        print("\n--- Collated Batch Debug Output ---")
        print(f"Batch Size: {B}, Max Sequence Length: {L}")

        # Print shapes of key tensors
        print("\n[Core Tensors]")
        print(f" event_type_ids: {collated_batch['event_type_ids'].shape}")
        print(f" attention_mask: {collated_batch['attention_mask'].shape}")
        print(f" timestamps_float: {collated_batch['timestamps_float'].shape}")

        print("\n[Pointer Tensors]")
        print(f" wallet_indices: {collated_batch['wallet_indices'].shape}")
        print(f" token_indices: {collated_batch['token_indices'].shape}")

        print("\n[Encoder Inputs]")
        print(f" embedding_pool: {collated_batch['embedding_pool'].shape}")
        # --- FIXED: Check for a key that still exists after removing address embeddings ---
        if collated_batch['token_encoder_inputs']['name_embed_indices'].numel() > 0:
            print(f" token_encoder_inputs contains {collated_batch['token_encoder_inputs']['name_embed_indices'].shape[0]} tokens.")
        else:
            print(" token_encoder_inputs is empty.")
        if collated_batch['wallet_encoder_inputs']['profile_rows']:
            print(f" wallet_encoder_inputs contains {len(collated_batch['wallet_encoder_inputs']['profile_rows'])} wallets.")
        else:
            print(" wallet_encoder_inputs is empty.")

        print("\n[Graph Links]")
        if collated_batch['graph_updater_links']:
            for link_name, data in collated_batch['graph_updater_links'].items():
                print(f" - {link_name}: {data['edge_index'].shape[1]} edges")
        else:
            print(" No graph links in this batch.")
        print("--- End Debug Output ---\n")

        # Sanity check: every name-embedding index must fall inside the pool.
        print("Embedding pool size:", collated_batch["embedding_pool"].shape[0])
        print("Max name_emb_idx:", collated_batch["token_encoder_inputs"]["name_embed_indices"].max().item())

        # 2. Model Forward Pass
        with torch.no_grad():
            model_outputs = model(collated_batch)
            quantile_logits = model_outputs["quantile_logits"]
            hidden_states = model_outputs["hidden_states"]
            attention_mask = model_outputs["attention_mask"]
            pooled_states = model_outputs["pooled_states"]
        print("Model forward pass successful!")

        # --- 7. Verify Output ---
        print("\n--- Test Results ---")
        D_MODEL = model.d_model

        print(f"Final hidden_states shape: {hidden_states.shape}")
        print(f"Final attention_mask shape: {attention_mask.shape}")

        assert hidden_states.shape == (B, L, D_MODEL)
        assert attention_mask.shape == (B, L)
        assert hidden_states.dtype == dtype

        print(f"Output mean (sanity check): {hidden_states.mean().item()}")
        print(f"Pooled state shape: {pooled_states.shape}")
        print(f"Quantile logits shape: {quantile_logits.shape}")

        # Reshape the flat head output into (batch, horizon, quantile) for display.
        quantile_grid = quantile_logits.view(B, len(_test_horizons), len(_test_quantiles))
        print("\n[Quantile Predictions]")
        for b_idx in range(B):
            print(f" Sample {b_idx}:")
            for h_idx, horizon in enumerate(_test_horizons):
                row = quantile_grid[b_idx, h_idx]
                print(f" Horizon {horizon}s -> " + ", ".join(
                    f"q={q:.2f}: {row[q_idx].item():.6f}"
                    for q_idx, q in enumerate(_test_quantiles)
                ))

        print("\n✅ **Test Passed!** Full ENCODING pipeline is working.")

    except Exception as e:
        print(f"\n❌ Error during pipeline test: {e}")
        traceback.print_exc()
|
link_graph.rs
ADDED
|
@@ -0,0 +1,2275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
use crate::aggregator::graph_schema::{
|
| 2 |
+
BundleTradeLink, BurnedLink, CoordinatedActivityLink, CopiedTradeLink, LockedSupplyLink,
|
| 3 |
+
MintedLink, ProvidedLiquidityLink, SnipedLink, TopTraderOfLink, TransferLink, WhaleOfLink,
|
| 4 |
+
};
|
| 5 |
+
use crate::handlers::constants::{
|
| 6 |
+
NATIVE_MINT, PROTOCOL_PUMPFUN_LAUNCHPAD, USD1_MINT, USDC_MINT, USDT_MINT,
|
| 7 |
+
};
|
| 8 |
+
use crate::types::{
|
| 9 |
+
BurnRow, EventPayload, EventType, LiquidityRow, MintRow, SupplyLockRow, TradeRow, TransferRow,
|
| 10 |
+
};
|
| 11 |
+
use anyhow::{Result, anyhow};
|
| 12 |
+
use chrono::Utc;
|
| 13 |
+
use clickhouse::{Client, Row};
|
| 14 |
+
use futures::stream::{self, StreamExt};
|
| 15 |
+
use itertools::Itertools;
|
| 16 |
+
use neo4rs::{BoltType, Graph, query};
|
| 17 |
+
use once_cell::sync::Lazy;
|
| 18 |
+
use serde::Deserialize;
|
| 19 |
+
use solana_sdk::native_token::LAMPORTS_PER_SOL;
|
| 20 |
+
use std::collections::{HashMap, HashSet, VecDeque};
|
| 21 |
+
use std::future::Future;
|
| 22 |
+
use std::str::FromStr;
|
| 23 |
+
use std::sync::Arc;
|
| 24 |
+
use std::sync::atomic::{AtomicUsize, Ordering};
|
| 25 |
+
use std::time::Duration;
|
| 26 |
+
use tokio::sync::{Mutex, mpsc};
|
| 27 |
+
use tokio::time::sleep;
|
| 28 |
+
use tokio::time::{Instant, MissedTickBehavior, interval};
|
| 29 |
+
use tokio::try_join;
|
| 30 |
+
|
| 31 |
+
fn decimals_for_quote(mint: &str) -> u8 {
|
| 32 |
+
if mint == NATIVE_MINT {
|
| 33 |
+
9
|
| 34 |
+
} else if mint == USDC_MINT || mint == USDT_MINT || mint == USD1_MINT {
|
| 35 |
+
6
|
| 36 |
+
} else {
|
| 37 |
+
9 // default assumption if unknown
|
| 38 |
+
}
|
| 39 |
+
}
|
| 40 |
+
|
/// Tunable runtime configuration for the link-graph pipeline.
///
/// Populated once from `LINK_GRAPH_*` environment variables by the
/// `LINK_GRAPH_CONFIG` lazy initializer below (which also documents the
/// defaults). Field comments below describe usage visible in this file;
/// some knobs are consumed by code outside this chunk.
#[derive(Debug)]
struct LinkGraphConfig {
    // Event-time length of one processing window, in seconds.
    time_window_seconds: u32,
    // Window for copy-trade detection, seconds — consumed elsewhere; verify against detector.
    copied_trade_window_seconds: i64,
    // Used in snipe detection both as a per-mint LIMIT and as a max holder count.
    sniper_rank_threshold: u64,
    // Threshold for whale-link detection — consumed elsewhere; verify against detector.
    whale_rank_threshold: u64,
    // Minimum PnL for top-trader links — consumed elsewhere; verify against detector.
    min_top_trader_pnl: f32,
    // Trades below this USD total are dropped as dust in process_time_window.
    min_trade_total_usd: f64,
    // ATH price threshold — consumed elsewhere; verify against detector.
    ath_price_threshold_usd: f64,
    // Max wall-clock time a window may stay open before a timeout flush.
    window_max_wait_ms: u64,
    // Extra slack for late-arriving events beyond the window end.
    late_slack_ms: u64,
    // ClickHouse IN-list chunk size (used e.g. by the snipe holder pre-check).
    chunk_size_large: usize,
    // Chunk size for historical queries — consumed elsewhere.
    chunk_size_historical: usize,
    // Chunk sizes for mint-related queries — consumed elsewhere.
    chunk_size_mint_small: usize,
    chunk_size_mint_large: usize,
    // Chunk size for token queries — consumed elsewhere.
    chunk_size_token: usize,
    // Hard cap on (maker, mint) entries kept in the in-memory trade cache.
    trade_cache_max_entries: usize,
    // TTL (seconds) for cache entries and for the per-pair recent ring.
    trade_cache_ttl_secs: u32,
    // Max number of recent trades kept per (maker, mint) pair.
    trade_cache_max_recent: usize,
    // Capacity of the writer mpsc channel — presumably applied where the channel is built.
    writer_channel_capacity: usize,
    // Max parameter rows per Neo4j write statement (writer_task chunking).
    writer_max_batch_rows: usize,
    // Deadlock retry budget and linear backoff base for writer_task.
    writer_retry_attempts: u32,
    writer_retry_backoff_ms: u64,
    // Chunk size for ATH fetches — consumed elsewhere.
    ath_fetch_chunk_size: usize,
    // Retry budget and linear backoff base for ClickHouse ops (with_ch_retry).
    ch_retry_attempts: u32,
    ch_retry_backoff_ms: u64,
    // Fail-fast toggle for ClickHouse errors — consumed elsewhere; verify.
    ch_fail_fast: bool,
}
| 69 |
+
|
/// Global configuration instance, built lazily on first access.
///
/// Each field reads the named `LINK_GRAPH_*` environment variable via
/// `env_parse`, falling back to the literal default shown when the variable
/// is unset or fails to parse.
static LINK_GRAPH_CONFIG: Lazy<LinkGraphConfig> = Lazy::new(|| LinkGraphConfig {
    time_window_seconds: env_parse("LINK_GRAPH_TIME_WINDOW_SECONDS", 120_u32),
    copied_trade_window_seconds: env_parse("LINK_GRAPH_COPIED_TRADE_WINDOW_SECONDS", 60_i64),
    sniper_rank_threshold: env_parse("LINK_GRAPH_SNIPER_RANK_THRESHOLD", 45_u64),
    whale_rank_threshold: env_parse("LINK_GRAPH_WHALE_RANK_THRESHOLD", 5_u64),
    min_top_trader_pnl: env_parse("LINK_GRAPH_MIN_TOP_TRADER_PNL", 1.0_f32),
    min_trade_total_usd: env_parse("LINK_GRAPH_MIN_TRADE_TOTAL_USD", 20.0_f64),
    ath_price_threshold_usd: env_parse("LINK_GRAPH_ATH_PRICE_THRESHOLD_USD", 0.0002000_f64),
    window_max_wait_ms: env_parse("LINK_GRAPH_WINDOW_MAX_WAIT_MS", 250_u64),
    late_slack_ms: env_parse("LINK_GRAPH_LATE_SLACK_MS", 2000_u64),
    chunk_size_large: env_parse("LINK_GRAPH_CHUNK_SIZE_LARGE", 3000_usize),
    chunk_size_historical: env_parse("LINK_GRAPH_CHUNK_SIZE_HISTORICAL", 1000_usize),
    chunk_size_mint_small: env_parse("LINK_GRAPH_CHUNK_SIZE_MINT_SMALL", 1500_usize),
    chunk_size_mint_large: env_parse("LINK_GRAPH_CHUNK_SIZE_MINT_LARGE", 3000_usize),
    chunk_size_token: env_parse("LINK_GRAPH_CHUNK_SIZE_TOKEN", 3000_usize),
    trade_cache_max_entries: env_parse("LINK_GRAPH_TRADE_CACHE_MAX_ENTRIES", 1_000_000_usize),
    trade_cache_ttl_secs: env_parse("LINK_GRAPH_TRADE_CACHE_TTL_SECS", 600_u32),
    trade_cache_max_recent: env_parse("LINK_GRAPH_TRADE_CACHE_MAX_RECENT", 16_usize),
    writer_channel_capacity: env_parse("LINK_GRAPH_WRITER_CHANNEL_CAPACITY", 5000_usize),
    writer_max_batch_rows: env_parse("LINK_GRAPH_WRITER_MAX_BATCH_ROWS", 1000_usize),
    writer_retry_attempts: env_parse("LINK_GRAPH_WRITER_RETRY_ATTEMPTS", 3_u32),
    writer_retry_backoff_ms: env_parse("LINK_GRAPH_WRITER_RETRY_BACKOFF_MS", 250_u64),
    ath_fetch_chunk_size: env_parse("LINK_GRAPH_ATH_FETCH_CHUNK_SIZE", 500_usize),
    ch_retry_attempts: env_parse("LINK_GRAPH_CH_RETRY_ATTEMPTS", 3_u32),
    ch_retry_backoff_ms: env_parse("LINK_GRAPH_CH_RETRY_BACKOFF_MS", 500_u64),
    ch_fail_fast: env_parse("LINK_GRAPH_CH_FAIL_FAST", true),
});
| 97 |
+
|
/// Read `key` from the process environment and parse it as `T`.
///
/// Returns `default` when the variable is unset, not valid UTF-8, or
/// fails to parse as `T`. Parse failures are silent by design.
fn env_parse<T: FromStr>(key: &str, default: T) -> T {
    match std::env::var(key) {
        Ok(raw) => raw.parse::<T>().unwrap_or(default),
        Err(_) => default,
    }
}
| 104 |
+
|
/// One historical trade, either selected from ClickHouse or reconstructed
/// from the in-memory cache (see `cached_to_full`).
///
/// NOTE(review): `clickhouse::Row` appears to bind columns positionally —
/// keep this field order in sync with the corresponding SELECT lists.
#[derive(Row, Deserialize, Clone)]
struct FullHistTrade {
    maker: String,
    base_address: String,
    // Unix seconds.
    timestamp: u32,
    signature: String,
    // 0 = buy, 1 = sell (see update_trade_cache / process_mints).
    trade_type: u8,
    total_usd: f64,
    slippage: f32,
}
| 115 |
+
|
/// Outcome of follower-activity detection between two wallets: either one
/// wallet copying another's trade, or two wallets acting in coordination.
/// Split back into the two concrete link types in `process_trade_patterns`.
enum FollowerLink {
    Copied(CopiedTradeLink),
    Coordinated(CoordinatedActivityLink),
}
| 120 |
+
|
/// Consumes on-chain events from a channel and materializes wallet/token
/// relationship links into Neo4j, using ClickHouse for historical lookups.
pub struct LinkGraph {
    // ClickHouse client for historical queries.
    db_client: Client,
    // Neo4j handle (used directly for the startup connectivity check;
    // bulk writes go through `write_sender` to the writer task).
    neo4j_client: Arc<Graph>,
    // Inbound event stream.
    rx: mpsc::Receiver<EventPayload>,
    // Gauge of items pending in `rx`; decremented as events are received.
    link_graph_depth: Arc<AtomicUsize>,
    // Per-worker lock serializing window processing.
    write_lock: Mutex<()>,
    // (maker, mint) -> cached trading state; see `CachedPairState`.
    trade_cache: Arc<Mutex<HashMap<(String, String), CachedPairState>>>,
    // Channel to the background Neo4j writer task.
    write_sender: mpsc::Sender<WriteJob>,
    // Gauge of jobs pending in the writer channel.
    writer_depth: Arc<AtomicUsize>,
}
| 131 |
+
|
// Global Neo4j write lock to serialize batches across workers and avoid deadlocks.
// Held for the full duration of `process_batch_with_retry`.
static NEO4J_WRITE_LOCK: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));
| 134 |
+
|
/// Row type for the `SELECT 1 as alive` ClickHouse connectivity check in `new`.
#[derive(Row, Deserialize, Debug)]
struct Ping {
    alive: u8,
}
/// Single-column `count` result row for ClickHouse aggregate queries
/// (consumers are outside this chunk).
#[derive(Row, Deserialize, Debug)]
struct CountResult {
    count: u64,
}
| 143 |
+
|
/// Lightweight copy of the trade fields kept by the in-memory trade cache
/// (see `cached_trade_from_trade` / `CachedPairState`).
#[derive(Clone, Debug)]
struct CachedTrade {
    maker: String,
    base_address: String,
    // Unix seconds.
    timestamp: u32,
    signature: String,
    // 0 = buy, 1 = sell (see update_trade_cache).
    trade_type: u8,
    total_usd: f64,
    slippage: f32,
}
| 154 |
+
|
/// Per-(maker, mint) cached trading state, the value type of
/// `LinkGraph::trade_cache`. Maintained by `update_trade_cache`.
#[derive(Debug)]
struct CachedPairState {
    // Earliest observed buy (trade_type == 0) for this pair, if any.
    first_buy: Option<CachedTrade>,
    // Earliest observed sell (trade_type == 1) for this pair, if any.
    first_sell: Option<CachedTrade>,
    // Ring of most recent trades, capped at trade_cache_max_recent and
    // pruned by trade_cache_ttl_secs.
    recent: VecDeque<CachedTrade>,
    // Newest trade timestamp seen for this pair; drives TTL/LRU eviction.
    last_seen: u32,
}
| 162 |
+
|
/// One batched Neo4j write: a Cypher statement plus its per-row parameter
/// maps. Consumed by `writer_task`, which chunks the rows and binds each
/// chunk under the Cypher parameter name "x".
#[derive(Debug)]
pub struct WriteJob {
    query: String,
    params: Vec<HashMap<String, BoltType>>,
}
| 168 |
+
|
| 169 |
+
impl LinkGraph {
|
| 170 |
+
pub async fn new(
|
| 171 |
+
db_client: Client,
|
| 172 |
+
neo4j_client: Arc<Graph>,
|
| 173 |
+
rx: mpsc::Receiver<EventPayload>,
|
| 174 |
+
link_graph_depth: Arc<AtomicUsize>,
|
| 175 |
+
write_sender: mpsc::Sender<WriteJob>,
|
| 176 |
+
writer_depth: Arc<AtomicUsize>,
|
| 177 |
+
) -> Result<Self> {
|
| 178 |
+
let _: Ping = db_client.query("SELECT 1 as alive").fetch_one().await?;
|
| 179 |
+
neo4j_client.run(query("MATCH (n) RETURN count(n)")).await?;
|
| 180 |
+
println!("[WalletGraph] ✔️ Connected to ClickHouse, Neo4j. Listening on channel.");
|
| 181 |
+
Ok(Self {
|
| 182 |
+
db_client,
|
| 183 |
+
neo4j_client,
|
| 184 |
+
rx,
|
| 185 |
+
link_graph_depth,
|
| 186 |
+
write_lock: Mutex::new(()),
|
| 187 |
+
trade_cache: Arc::new(Mutex::new(HashMap::new())),
|
| 188 |
+
write_sender,
|
| 189 |
+
writer_depth,
|
| 190 |
+
})
|
| 191 |
+
}
|
| 192 |
+
|
    /// Run a ClickHouse operation with bounded retries and linear backoff.
    ///
    /// `op` is re-invoked on every attempt. After `ch_retry_attempts`
    /// failures the last error is wrapped with `label` and returned.
    /// Backoff grows linearly: `ch_retry_backoff_ms * attempt_number`.
    async fn with_ch_retry<T, F, Fut>(&self, mut op: F, label: &str) -> Result<T>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = Result<T>>,
    {
        let cfg = &*LINK_GRAPH_CONFIG;
        let mut attempts = 0;
        loop {
            attempts += 1;
            match op().await {
                Ok(res) => return Ok(res),
                Err(e) => {
                    // Budget exhausted: surface the final error with context.
                    if attempts >= cfg.ch_retry_attempts {
                        return Err(anyhow!(
                            "[LinkGraph] {} failed after {} attempts: {}",
                            label,
                            attempts,
                            e
                        ));
                    }
                    // Linear backoff: attempt 1 waits 1x base, attempt 2 waits 2x, ...
                    let backoff = cfg.ch_retry_backoff_ms * attempts as u64;
                    eprintln!(
                        "[LinkGraph] ⚠️ {} retry {}/{} after {}ms: {}",
                        label, attempts, cfg.ch_retry_attempts, backoff, e
                    );
                    sleep(Duration::from_millis(backoff)).await;
                }
            }
        }
    }
| 223 |
+
|
    /// Main event loop: accumulate incoming events into event-time windows
    /// and process each window as one batch.
    ///
    /// A window opens at the first event's timestamp and spans
    /// `time_window_seconds` (plus `late_slack_ms` of late-arrival slack).
    /// It is flushed when an event arrives past that bound, or when
    /// `window_max_wait_ms` of wall-clock time elapses. Fatal processing
    /// errors and a closed input channel both terminate the process.
    pub async fn run(&mut self) -> Result<()> {
        let cfg = &*LINK_GRAPH_CONFIG;
        let mut message_buffer: Vec<EventPayload> = Vec::new();
        // Event-time start of the currently open window (None = no open window).
        let mut current_window_start: Option<u32> = None;
        // Wall-clock instant the window opened; drives the timeout flush path.
        let mut window_opened_at: Option<Instant> = None;
        let mut flush_check = interval(Duration::from_millis(cfg.window_max_wait_ms.max(50)));
        flush_check.set_missed_tick_behavior(MissedTickBehavior::Delay);
        let late_slack_secs: u32 = (cfg.late_slack_ms / 1000) as u32;

        loop {
            tokio::select! {
                maybe_payload = self.rx.recv() => {
                    match maybe_payload {
                        Some(payload) => {
                            // one item left the channel
                            self.link_graph_depth.fetch_sub(1, Ordering::Relaxed);
                            if current_window_start.is_none() {
                                current_window_start = Some(payload.timestamp);
                                window_opened_at = Some(Instant::now());
                            }

                            let window_end = current_window_start.unwrap() + cfg.time_window_seconds;
                            if payload.timestamp <= window_end + late_slack_secs {
                                // Still inside the window (allowing for late arrivals).
                                message_buffer.push(payload);
                            } else {
                                // Past the window: flush the buffered window, then
                                // open a new window at this payload's timestamp.
                                if !message_buffer.is_empty() {
                                    message_buffer.sort_by_key(|p| p.timestamp);
                                    let batch = std::mem::take(&mut message_buffer);
                                    if let Err(e) = self.process_batch_with_retry(batch).await {
                                        eprintln!("[LinkGraph] 🔴 Fatal processing window: {}", e);
                                        std::process::exit(1);
                                    }
                                }
                                current_window_start = Some(payload.timestamp);
                                window_opened_at = Some(Instant::now());
                                message_buffer.push(payload);
                            }
                        }
                        None => {
                            eprintln!("[LinkGraph] 🔴 Input channel closed. Exiting.");
                            // Best-effort flush of whatever is buffered before exiting.
                            if !message_buffer.is_empty() {
                                message_buffer.sort_by_key(|p| p.timestamp);
                                let batch = std::mem::take(&mut message_buffer);
                                if let Err(e) = self.process_batch_with_retry(batch).await {
                                    eprintln!("[LinkGraph] 🔴 Fatal processing final window: {}", e);
                                }
                            }
                            // Fatal: the producer is gone. Exit so it's obvious.
                            std::process::exit(1);
                        }
                    }
                }
                _ = flush_check.tick() => {
                    // Timeout flush: don't let a quiet window sit open indefinitely.
                    if !message_buffer.is_empty() {
                        if let Some(opened) = window_opened_at {
                            if opened.elapsed() >= Duration::from_millis(cfg.window_max_wait_ms) {
                                message_buffer.sort_by_key(|p| p.timestamp);
                                let batch = std::mem::take(&mut message_buffer);
                                if let Err(e) = self.process_batch_with_retry(batch).await {
                                    eprintln!("[LinkGraph] 🔴 Fatal processing timed window: {}", e);
                                    std::process::exit(1);
                                }
                                current_window_start = None;
                                window_opened_at = None;
                            }
                        }
                    }
                }
            }
        }

        // Unreachable: the loop above never breaks (exits are via process::exit).
        Ok(())
    }
| 297 |
+
|
    /// Demultiplex one window of events by type and fan the buckets out to
    /// the per-type link detectors, which run concurrently via `try_join!`.
    async fn process_time_window(&self, payloads: &[EventPayload]) -> Result<()> {
        let cfg = &*LINK_GRAPH_CONFIG;
        let mut unique_wallets = HashSet::new();
        let mut unique_tokens = HashSet::new();
        let mut trades = Vec::new();
        let mut transfers = Vec::new();
        let mut mints = Vec::new();
        let mut supply_locks = Vec::new();
        let mut burns = Vec::new();
        let mut liquidity_events = Vec::new();

        for payload in payloads {
            match &payload.event {
                EventType::Trade(trade) => {
                    // Skip dust trades to reduce noise in downstream links/datasets
                    if trade.total_usd >= cfg.min_trade_total_usd {
                        unique_wallets.insert(trade.maker.clone());
                        unique_tokens.insert(trade.base_address.clone());
                        trades.push(trade.clone());
                    }
                }
                EventType::Transfer(transfer) => {
                    unique_wallets.insert(transfer.source.clone());
                    unique_wallets.insert(transfer.destination.clone());
                    transfers.push(transfer.clone());
                }
                EventType::Mint(mint) => {
                    unique_wallets.insert(mint.creator_address.clone());
                    unique_tokens.insert(mint.mint_address.clone());
                    mints.push(mint.clone());
                }
                EventType::SupplyLock(lock) => {
                    unique_wallets.insert(lock.sender.clone());
                    unique_wallets.insert(lock.recipient.clone());
                    unique_tokens.insert(lock.mint_address.clone());
                    supply_locks.push(lock.clone());
                }
                EventType::Burn(burn) => {
                    unique_wallets.insert(burn.source.clone());
                    unique_tokens.insert(burn.mint_address.clone());
                    burns.push(burn.clone());
                }
                EventType::Liquidity(liquidity) => {
                    // Only liquidity additions are of interest here.
                    if liquidity.change_type == 0 {
                        // 0 = Add Liquidity
                        unique_wallets.insert(liquidity.lp_provider.clone());
                        liquidity_events.push(liquidity.clone());
                    }
                }
                _ => {}
            }
        }

        // Run link detection in parallel; writes remain serialized by the global Neo4j lock.
        let parallel_start = Instant::now();
        try_join!(
            self.process_mints(&mints, &trades),
            self.process_transfers_and_funding(&transfers),
            self.process_supply_locks(&supply_locks),
            self.process_burns(&burns),
            self.process_liquidity_events(&liquidity_events),
            self.process_trade_patterns(&trades, &mints),
        )?;
        println!(
            "[LinkGraph] [TimeWindow] Parallel link processing finished in: {:?}",
            parallel_start.elapsed()
        );
        Ok(())
    }
| 367 |
+
|
| 368 |
+
async fn process_batch(&self, mut payloads: Vec<EventPayload>) -> Result<()> {
|
| 369 |
+
if payloads.is_empty() {
|
| 370 |
+
return Ok(());
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
// Payloads are already a complete time-window. We just need to sort them.
|
| 374 |
+
payloads.sort_by_key(|p| p.timestamp);
|
| 375 |
+
|
| 376 |
+
// Process the entire batch as a single logical unit with a per-worker write lock.
|
| 377 |
+
let _guard = self.write_lock.lock().await;
|
| 378 |
+
self.process_time_window(&payloads).await?;
|
| 379 |
+
|
| 380 |
+
println!(
|
| 381 |
+
"[LinkGraph] Finished processing batch of {} events.",
|
| 382 |
+
payloads.len()
|
| 383 |
+
);
|
| 384 |
+
Ok(())
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
async fn process_batch_with_retry(&self, payloads: Vec<EventPayload>) -> Result<()> {
|
| 388 |
+
// Serialize across all workers to avoid Neo4j deadlocks.
|
| 389 |
+
let _global_lock = NEO4J_WRITE_LOCK.lock().await;
|
| 390 |
+
let mut attempts = 0;
|
| 391 |
+
let max_retries = 3;
|
| 392 |
+
loop {
|
| 393 |
+
match self.process_batch(payloads.clone()).await {
|
| 394 |
+
Ok(_) => return Ok(()),
|
| 395 |
+
Err(e) => {
|
| 396 |
+
let err_str = e.to_string();
|
| 397 |
+
if err_str.contains("DeadlockDetected") && attempts < max_retries {
|
| 398 |
+
attempts += 1;
|
| 399 |
+
let backoff_ms = 200 * attempts;
|
| 400 |
+
eprintln!(
|
| 401 |
+
"[LinkGraph] ⚠️ Deadlock detected, retrying {}/{} after {}ms",
|
| 402 |
+
attempts, max_retries, backoff_ms
|
| 403 |
+
);
|
| 404 |
+
sleep(Duration::from_millis(backoff_ms as u64)).await;
|
| 405 |
+
continue;
|
| 406 |
+
} else {
|
| 407 |
+
return Err(e);
|
| 408 |
+
}
|
| 409 |
+
}
|
| 410 |
+
}
|
| 411 |
+
}
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
// --- Main Logic for Pattern Detection ---
|
| 415 |
+
fn cached_trade_from_trade(trade: &TradeRow) -> CachedTrade {
|
| 416 |
+
CachedTrade {
|
| 417 |
+
maker: trade.maker.clone(),
|
| 418 |
+
base_address: trade.base_address.clone(),
|
| 419 |
+
timestamp: trade.timestamp,
|
| 420 |
+
signature: trade.signature.clone(),
|
| 421 |
+
trade_type: trade.trade_type,
|
| 422 |
+
total_usd: trade.total_usd,
|
| 423 |
+
slippage: trade.slippage,
|
| 424 |
+
}
|
| 425 |
+
}
|
| 426 |
+
|
    /// Fold a slice of trades into the in-memory (maker, mint) cache.
    ///
    /// Maintains, per pair: the earliest buy, the earliest sell, and a
    /// bounded ring of recent trades. Also performs three kinds of eviction:
    /// TTL-based (relative to the newest trade in this slice), per-pair
    /// ring caps, and a global LRU trim to `trade_cache_max_entries`.
    async fn update_trade_cache(&self, trades: &[&TradeRow]) -> Result<()> {
        if trades.is_empty() {
            return Ok(());
        }
        let cfg = &*LINK_GRAPH_CONFIG;
        // "Now" is event time: the max timestamp in this slice, not wall clock.
        let now_ts = trades.iter().map(|t| t.timestamp).max().unwrap_or(0);
        let cutoff = now_ts.saturating_sub(cfg.trade_cache_ttl_secs);

        let mut cache = self.trade_cache.lock().await;
        // TTL eviction of whole pairs that went quiet.
        cache.retain(|_, state| state.last_seen >= cutoff);

        for trade in trades {
            let key = (trade.maker.clone(), trade.base_address.clone());
            let entry = cache.entry(key).or_insert_with(|| CachedPairState {
                first_buy: None,
                first_sell: None,
                recent: VecDeque::new(),
                last_seen: 0,
            });

            entry.last_seen = entry.last_seen.max(trade.timestamp);

            let ct = Self::cached_trade_from_trade(trade);
            // Track the earliest buy (trade_type 0) / sell (trade_type 1).
            if trade.trade_type == 0 {
                if entry
                    .first_buy
                    .as_ref()
                    .map_or(true, |b| ct.timestamp < b.timestamp)
                {
                    entry.first_buy = Some(ct.clone());
                }
            } else if trade.trade_type == 1 {
                if entry
                    .first_sell
                    .as_ref()
                    .map_or(true, |s| ct.timestamp < s.timestamp)
                {
                    entry.first_sell = Some(ct.clone());
                }
            }

            // Append to the recent ring; cap by count, then by age.
            entry.recent.push_back(ct);
            while entry.recent.len() > cfg.trade_cache_max_recent {
                entry.recent.pop_front();
            }
            while let Some(front) = entry.recent.front() {
                if front.timestamp + cfg.trade_cache_ttl_secs < now_ts {
                    entry.recent.pop_front();
                } else {
                    break;
                }
            }
        }

        // Global LRU trim: drop the oldest pairs (by last_seen) beyond the cap.
        if cache.len() > cfg.trade_cache_max_entries {
            let mut entries: Vec<_> = cache
                .iter()
                .map(|(k, v)| (k.clone(), v.last_seen))
                .collect();
            entries.sort_by_key(|(_, ts)| *ts);
            let to_drop = entries.len().saturating_sub(cfg.trade_cache_max_entries);
            for (key, _) in entries.into_iter().take(to_drop) {
                cache.remove(&key);
            }
        }
        Ok(())
    }
| 494 |
+
|
| 495 |
+
async fn build_histories_from_cache(
|
| 496 |
+
&self,
|
| 497 |
+
pairs: &[(String, String)],
|
| 498 |
+
) -> Result<HashMap<(String, String), Vec<FullHistTrade>>> {
|
| 499 |
+
let mut map = HashMap::new();
|
| 500 |
+
let cache = self.trade_cache.lock().await;
|
| 501 |
+
for pair in pairs {
|
| 502 |
+
if let Some(state) = cache.get(pair) {
|
| 503 |
+
let mut collected = Vec::new();
|
| 504 |
+
if let Some(b) = &state.first_buy {
|
| 505 |
+
collected.push(Self::cached_to_full(b));
|
| 506 |
+
}
|
| 507 |
+
if let Some(s) = &state.first_sell {
|
| 508 |
+
collected.push(Self::cached_to_full(s));
|
| 509 |
+
}
|
| 510 |
+
for t in state.recent.iter() {
|
| 511 |
+
collected.push(Self::cached_to_full(t));
|
| 512 |
+
}
|
| 513 |
+
|
| 514 |
+
if !collected.is_empty() {
|
| 515 |
+
collected.sort_by_key(|t| t.timestamp);
|
| 516 |
+
collected.dedup_by(|a, b| a.signature == b.signature);
|
| 517 |
+
map.insert(pair.clone(), collected);
|
| 518 |
+
}
|
| 519 |
+
}
|
| 520 |
+
}
|
| 521 |
+
Ok(map)
|
| 522 |
+
}
|
| 523 |
+
|
| 524 |
+
fn cached_to_full(ct: &CachedTrade) -> FullHistTrade {
|
| 525 |
+
FullHistTrade {
|
| 526 |
+
maker: ct.maker.clone(),
|
| 527 |
+
base_address: ct.base_address.clone(),
|
| 528 |
+
timestamp: ct.timestamp,
|
| 529 |
+
signature: ct.signature.clone(),
|
| 530 |
+
trade_type: ct.trade_type,
|
| 531 |
+
total_usd: ct.total_usd,
|
| 532 |
+
slippage: ct.slippage,
|
| 533 |
+
}
|
| 534 |
+
}
|
| 535 |
+
|
    /// Background task draining `WriteJob`s and executing them against Neo4j.
    ///
    /// Each job's parameter rows are chunked to `writer_max_batch_rows` and
    /// bound under the Cypher parameter name "x". Deadlocks are retried with
    /// linear backoff; any other error (or retry exhaustion) aborts the
    /// process, as does the write channel closing.
    pub async fn writer_task(
        mut rx: mpsc::Receiver<WriteJob>,
        neo4j_client: Arc<Graph>,
        writer_depth: Arc<AtomicUsize>,
    ) {
        let cfg = &*LINK_GRAPH_CONFIG;
        while let Some(job) = rx.recv().await {
            // One job dequeued: keep the depth gauge in sync with enqueue_write.
            writer_depth.fetch_sub(1, Ordering::Relaxed);
            let batches = job
                .params
                .chunks(cfg.writer_max_batch_rows.max(1))
                .map(|chunk| chunk.to_vec())
                .collect::<Vec<_>>();

            for (idx, params) in batches.iter().enumerate() {
                let q = query(&job.query).param("x", params.clone());
                let mut attempts = 0;
                loop {
                    let start = Instant::now();
                    match neo4j_client.run(q.clone()).await {
                        Ok(_) => {
                            println!(
                                "[LinkGraph] [Writer] ✅ wrote {} rows (chunk {}/{}) in {:?}",
                                params.len(),
                                idx + 1,
                                batches.len(),
                                start.elapsed()
                            );
                            break;
                        }
                        Err(e) => {
                            let msg = e.to_string();
                            attempts += 1;
                            // Only deadlocks are retriable; everything else is fatal.
                            if msg.contains("DeadlockDetected")
                                && attempts <= cfg.writer_retry_attempts
                            {
                                let backoff = cfg.writer_retry_backoff_ms * attempts as u64;
                                eprintln!(
                                    "[LinkGraph] [Writer] ⚠️ deadlock, retry {}/{} after {}ms: {}",
                                    attempts, cfg.writer_retry_attempts, backoff, msg
                                );
                                sleep(Duration::from_millis(backoff)).await;
                                continue;
                            } else {
                                eprintln!(
                                    "[LinkGraph] 🔴 Writer fatal after {} attempts: {}",
                                    attempts, msg
                                );
                                std::process::exit(1);
                            }
                        }
                    }
                }
            }
        }
        // All senders dropped: treat as fatal so the failure is obvious.
        eprintln!("[LinkGraph] 🔴 Writer channel closed.");
        std::process::exit(1);
    }
| 594 |
+
|
| 595 |
+
async fn enqueue_write(
|
| 596 |
+
&self,
|
| 597 |
+
cypher: &str,
|
| 598 |
+
params: Vec<HashMap<String, BoltType>>,
|
| 599 |
+
) -> Result<()> {
|
| 600 |
+
let job = WriteJob {
|
| 601 |
+
query: cypher.to_string(),
|
| 602 |
+
params,
|
| 603 |
+
};
|
| 604 |
+
self.write_sender
|
| 605 |
+
.send(job)
|
| 606 |
+
.await
|
| 607 |
+
.map_err(|e| anyhow!("[LinkGraph] Failed to enqueue write: {}", e))?;
|
| 608 |
+
self.writer_depth.fetch_add(1, Ordering::Relaxed);
|
| 609 |
+
Ok(())
|
| 610 |
+
}
|
| 611 |
+
|
| 612 |
+
async fn process_mints(
|
| 613 |
+
&self,
|
| 614 |
+
mints: &[MintRow],
|
| 615 |
+
all_trades_in_batch: &[TradeRow],
|
| 616 |
+
) -> Result<()> {
|
| 617 |
+
let start = Instant::now();
|
| 618 |
+
if mints.is_empty() {
|
| 619 |
+
return Ok(());
|
| 620 |
+
}
|
| 621 |
+
let mut links = Vec::new();
|
| 622 |
+
|
| 623 |
+
for mint in mints {
|
| 624 |
+
let dev_buy = all_trades_in_batch.iter().find(
|
| 625 |
+
|t| {
|
| 626 |
+
t.maker == mint.creator_address
|
| 627 |
+
&& t.base_address == mint.mint_address
|
| 628 |
+
&& t.trade_type == 0
|
| 629 |
+
}, // 0 = Buy
|
| 630 |
+
);
|
| 631 |
+
let buy_amount_decimals = dev_buy.map_or(0.0, |t| {
|
| 632 |
+
let quote_decimals = decimals_for_quote(&t.quote_address);
|
| 633 |
+
t.quote_amount as f64 / 10f64.powi(quote_decimals as i32)
|
| 634 |
+
});
|
| 635 |
+
links.push(MintedLink {
|
| 636 |
+
signature: mint.signature.clone(),
|
| 637 |
+
timestamp: mint.timestamp as i64,
|
| 638 |
+
buy_amount: buy_amount_decimals,
|
| 639 |
+
});
|
| 640 |
+
}
|
| 641 |
+
self.write_minted_links(&links, mints).await?;
|
| 642 |
+
println!(
|
| 643 |
+
"[LinkGraph] [Profile] process_mints: {} mints in {:?}",
|
| 644 |
+
mints.len(),
|
| 645 |
+
start.elapsed()
|
| 646 |
+
);
|
| 647 |
+
Ok(())
|
| 648 |
+
}
|
| 649 |
+
|
| 650 |
+
async fn process_supply_locks(&self, locks: &[SupplyLockRow]) -> Result<()> {
|
| 651 |
+
let start = Instant::now();
|
| 652 |
+
if locks.is_empty() {
|
| 653 |
+
return Ok(());
|
| 654 |
+
}
|
| 655 |
+
let links: Vec<_> = locks
|
| 656 |
+
.iter()
|
| 657 |
+
.map(|l| LockedSupplyLink {
|
| 658 |
+
signature: l.signature.clone(),
|
| 659 |
+
amount: l.total_locked_amount as f64,
|
| 660 |
+
timestamp: l.timestamp as i64,
|
| 661 |
+
unlock_timestamp: l.final_unlock_timestamp,
|
| 662 |
+
})
|
| 663 |
+
.collect();
|
| 664 |
+
self.write_locked_supply_links(&links, locks).await?;
|
| 665 |
+
println!(
|
| 666 |
+
"[LinkGraph] [Profile] process_supply_locks: {} locks in {:?}",
|
| 667 |
+
locks.len(),
|
| 668 |
+
start.elapsed()
|
| 669 |
+
);
|
| 670 |
+
Ok(())
|
| 671 |
+
}
|
| 672 |
+
|
| 673 |
+
async fn process_burns(&self, burns: &[BurnRow]) -> Result<()> {
|
| 674 |
+
let start = Instant::now();
|
| 675 |
+
if burns.is_empty() {
|
| 676 |
+
return Ok(());
|
| 677 |
+
}
|
| 678 |
+
let links: Vec<_> = burns
|
| 679 |
+
.iter()
|
| 680 |
+
.map(|b| BurnedLink {
|
| 681 |
+
signature: b.signature.clone(),
|
| 682 |
+
amount: b.amount_decimal,
|
| 683 |
+
timestamp: b.timestamp as i64,
|
| 684 |
+
})
|
| 685 |
+
.collect();
|
| 686 |
+
self.write_burned_links(&links, burns).await?;
|
| 687 |
+
println!(
|
| 688 |
+
"[LinkGraph] [Profile] process_burns: {} burns in {:?}",
|
| 689 |
+
burns.len(),
|
| 690 |
+
start.elapsed()
|
| 691 |
+
);
|
| 692 |
+
Ok(())
|
| 693 |
+
}
|
| 694 |
+
|
| 695 |
+
async fn process_transfers_and_funding(&self, transfers: &[TransferRow]) -> Result<()> {
|
| 696 |
+
let start = Instant::now();
|
| 697 |
+
if transfers.is_empty() {
|
| 698 |
+
return Ok(());
|
| 699 |
+
}
|
| 700 |
+
|
| 701 |
+
// Directly map every TransferRow to a TransferLink without any extra logic.
|
| 702 |
+
let transfer_links: Vec<TransferLink> = transfers
|
| 703 |
+
.iter()
|
| 704 |
+
.map(|transfer| TransferLink {
|
| 705 |
+
source: transfer.source.clone(),
|
| 706 |
+
destination: transfer.destination.clone(),
|
| 707 |
+
signature: transfer.signature.clone(),
|
| 708 |
+
mint: transfer.mint_address.clone(),
|
| 709 |
+
timestamp: transfer.timestamp as i64,
|
| 710 |
+
amount: transfer.amount_decimal,
|
| 711 |
+
})
|
| 712 |
+
.collect();
|
| 713 |
+
|
| 714 |
+
self.write_transfer_links(&transfer_links).await?;
|
| 715 |
+
println!(
|
| 716 |
+
"[LinkGraph] [Profile] process_transfers: {} transfers in {:?}",
|
| 717 |
+
transfers.len(),
|
| 718 |
+
start.elapsed()
|
| 719 |
+
);
|
| 720 |
+
Ok(())
|
| 721 |
+
}
|
| 722 |
+
|
| 723 |
+
async fn process_trade_patterns(
|
| 724 |
+
&self,
|
| 725 |
+
trades: &[TradeRow],
|
| 726 |
+
mints_in_batch: &[MintRow],
|
| 727 |
+
) -> Result<()> {
|
| 728 |
+
let start = Instant::now();
|
| 729 |
+
if trades.is_empty() {
|
| 730 |
+
return Ok(());
|
| 731 |
+
}
|
| 732 |
+
|
| 733 |
+
let creator_map: HashMap<String, String> = mints_in_batch
|
| 734 |
+
.iter()
|
| 735 |
+
.map(|m| (m.mint_address.clone(), m.creator_address.clone()))
|
| 736 |
+
.collect();
|
| 737 |
+
|
| 738 |
+
let mut processed_pairs = HashSet::new();
|
| 739 |
+
|
| 740 |
+
let bundle_links = self.detect_bundle_trades(trades, &mut processed_pairs);
|
| 741 |
+
if !bundle_links.is_empty() {
|
| 742 |
+
self.write_bundle_trade_links(&bundle_links).await?;
|
| 743 |
+
}
|
| 744 |
+
|
| 745 |
+
let follower_links = self
|
| 746 |
+
.detect_follower_activity(trades, &mut processed_pairs)
|
| 747 |
+
.await?;
|
| 748 |
+
if !follower_links.is_empty() {
|
| 749 |
+
let mut copied_links = Vec::new();
|
| 750 |
+
let mut coordinated_links = Vec::new();
|
| 751 |
+
for link in follower_links {
|
| 752 |
+
match link {
|
| 753 |
+
FollowerLink::Copied(l) => copied_links.push(l),
|
| 754 |
+
FollowerLink::Coordinated(l) => coordinated_links.push(l),
|
| 755 |
+
}
|
| 756 |
+
}
|
| 757 |
+
if !copied_links.is_empty() {
|
| 758 |
+
self.write_copied_trade_links(&copied_links).await?;
|
| 759 |
+
}
|
| 760 |
+
if !coordinated_links.is_empty() {
|
| 761 |
+
self.write_coordinated_activity_links(&coordinated_links)
|
| 762 |
+
.await?;
|
| 763 |
+
}
|
| 764 |
+
}
|
| 765 |
+
|
| 766 |
+
self.detect_and_write_snipes(trades, creator_map).await?;
|
| 767 |
+
self.detect_and_write_whale_links(trades).await?;
|
| 768 |
+
self.detect_and_write_top_trader_links(trades).await?;
|
| 769 |
+
|
| 770 |
+
println!(
|
| 771 |
+
"[LinkGraph] [Profile] process_trade_patterns: {} trades in {:?}",
|
| 772 |
+
trades.len(),
|
| 773 |
+
start.elapsed()
|
| 774 |
+
);
|
| 775 |
+
Ok(())
|
| 776 |
+
}
|
| 777 |
+
|
| 778 |
+
async fn detect_and_write_snipes(
|
| 779 |
+
&self,
|
| 780 |
+
_trades: &[TradeRow],
|
| 781 |
+
creator_map: HashMap<String, String>,
|
| 782 |
+
) -> Result<()> {
|
| 783 |
+
let start = Instant::now();
|
| 784 |
+
let cfg = &*LINK_GRAPH_CONFIG;
|
| 785 |
+
let mut links: Vec<SnipedLink> = Vec::new();
|
| 786 |
+
let mut snipers_map: HashMap<String, (String, String)> = HashMap::new();
|
| 787 |
+
// Limit sniper detection to Pump.fun launchpad trades only.
|
| 788 |
+
let pump_trades: Vec<&TradeRow> = _trades
|
| 789 |
+
.iter()
|
| 790 |
+
.filter(|t| t.protocol == PROTOCOL_PUMPFUN_LAUNCHPAD)
|
| 791 |
+
.collect();
|
| 792 |
+
if pump_trades.is_empty() {
|
| 793 |
+
return Ok(());
|
| 794 |
+
}
|
| 795 |
+
|
| 796 |
+
let unique_mints: HashSet<String> =
|
| 797 |
+
pump_trades.iter().map(|t| t.base_address.clone()).collect();
|
| 798 |
+
if unique_mints.is_empty() {
|
| 799 |
+
return Ok(());
|
| 800 |
+
}
|
| 801 |
+
|
| 802 |
+
// This pre-flight check remains the same
|
| 803 |
+
#[derive(Row, Deserialize, Debug)]
|
| 804 |
+
struct TokenHolderInfo {
|
| 805 |
+
token_address: String,
|
| 806 |
+
unique_holders: u32,
|
| 807 |
+
}
|
| 808 |
+
|
| 809 |
+
let holder_check_query = "
|
| 810 |
+
SELECT token_address, unique_holders
|
| 811 |
+
FROM token_metrics_latest
|
| 812 |
+
WHERE token_address IN ?
|
| 813 |
+
ORDER BY token_address, updated_at DESC
|
| 814 |
+
LIMIT 1 BY token_address
|
| 815 |
+
";
|
| 816 |
+
let mut holder_infos: Vec<TokenHolderInfo> = Vec::new();
|
| 817 |
+
let unique_mints_vec: Vec<_> = unique_mints.iter().cloned().collect();
|
| 818 |
+
|
| 819 |
+
for chunk in unique_mints_vec.chunks(cfg.chunk_size_large) {
|
| 820 |
+
let mut chunk_results = self
|
| 821 |
+
.with_ch_retry(
|
| 822 |
+
|| async {
|
| 823 |
+
self.db_client
|
| 824 |
+
.query(holder_check_query)
|
| 825 |
+
.bind(chunk)
|
| 826 |
+
.fetch_all()
|
| 827 |
+
.await
|
| 828 |
+
.map_err(anyhow::Error::from)
|
| 829 |
+
},
|
| 830 |
+
"Snipes-HolderCheck chunk",
|
| 831 |
+
)
|
| 832 |
+
.await?;
|
| 833 |
+
holder_infos.append(&mut chunk_results);
|
| 834 |
+
}
|
| 835 |
+
|
| 836 |
+
let token_holder_map: HashMap<String, u32> = holder_infos
|
| 837 |
+
.into_iter()
|
| 838 |
+
.map(|t| (t.token_address, t.unique_holders))
|
| 839 |
+
.collect();
|
| 840 |
+
|
| 841 |
+
#[derive(Row, Deserialize, Clone, Debug)]
|
| 842 |
+
struct SniperInfo {
|
| 843 |
+
maker: String,
|
| 844 |
+
first_sig: String,
|
| 845 |
+
first_total: f64,
|
| 846 |
+
first_ts: u32,
|
| 847 |
+
}
|
| 848 |
+
|
| 849 |
+
#[derive(Row, Deserialize, Debug)]
|
| 850 |
+
struct TokenCreator {
|
| 851 |
+
creator_address: String,
|
| 852 |
+
}
|
| 853 |
+
|
| 854 |
+
// OPTIMIZATION: Parallelize the database queries for each mint.
|
| 855 |
+
let query_futures = unique_mints
|
| 856 |
+
.into_iter()
|
| 857 |
+
.filter(|mint| {
|
| 858 |
+
// Pre-filter mints that are too established
|
| 859 |
+
let holder_count = token_holder_map.get(mint).cloned().unwrap_or(0);
|
| 860 |
+
holder_count <= cfg.sniper_rank_threshold as u32
|
| 861 |
+
})
|
| 862 |
+
.map(|mint| {
|
| 863 |
+
let db_client = self.db_client.clone();
|
| 864 |
+
let creator_map_clone = creator_map.clone();
|
| 865 |
+
// Create an async block (a future) for each query
|
| 866 |
+
async move {
|
| 867 |
+
let snipers_query = "
|
| 868 |
+
SELECT maker,
|
| 869 |
+
argMin(signature, timestamp) as first_sig,
|
| 870 |
+
argMin(total, timestamp) as first_total,
|
| 871 |
+
min(toUInt32(timestamp)) as first_ts
|
| 872 |
+
FROM trades WHERE base_address = ? AND trade_type = 0
|
| 873 |
+
GROUP BY maker ORDER BY min(timestamp) ASC LIMIT ?
|
| 874 |
+
";
|
| 875 |
+
|
| 876 |
+
let result = db_client
|
| 877 |
+
.query(snipers_query)
|
| 878 |
+
.bind(mint.clone()) // Keep this bind
|
| 879 |
+
.bind(cfg.sniper_rank_threshold) // And this one
|
| 880 |
+
.fetch_all::<SniperInfo>()
|
| 881 |
+
.await
|
| 882 |
+
.map_err(|e| {
|
| 883 |
+
anyhow!(
|
| 884 |
+
"[SNIPER_FAIL]: Sniper fetch for mint '{}' failed. Error: {}",
|
| 885 |
+
mint,
|
| 886 |
+
e
|
| 887 |
+
)
|
| 888 |
+
});
|
| 889 |
+
|
| 890 |
+
(mint, result)
|
| 891 |
+
}
|
| 892 |
+
});
|
| 893 |
+
|
| 894 |
+
// Execute the futures concurrently with a limit of 20 at a time.
|
| 895 |
+
let results = stream::iter(query_futures)
|
| 896 |
+
.buffer_unordered(20) // CONCURRENCY LIMIT
|
| 897 |
+
.collect::<Vec<_>>()
|
| 898 |
+
.await;
|
| 899 |
+
|
| 900 |
+
// Process the results after they have all completed
|
| 901 |
+
for (mint, result) in results {
|
| 902 |
+
match result {
|
| 903 |
+
Ok(sniper_candidates) => {
|
| 904 |
+
for (i, sniper) in sniper_candidates.iter().enumerate() {
|
| 905 |
+
links.push(SnipedLink {
|
| 906 |
+
timestamp: sniper.first_ts as i64,
|
| 907 |
+
signature: sniper.first_sig.clone(),
|
| 908 |
+
rank: (i + 1) as i64,
|
| 909 |
+
sniped_amount: sniper.first_total,
|
| 910 |
+
});
|
| 911 |
+
snipers_map.insert(
|
| 912 |
+
sniper.first_sig.clone(),
|
| 913 |
+
(sniper.maker.clone(), mint.clone()),
|
| 914 |
+
);
|
| 915 |
+
}
|
| 916 |
+
}
|
| 917 |
+
Err(e) => eprintln!("[Snipers] Error processing mint {}: {}", mint, e),
|
| 918 |
+
}
|
| 919 |
+
}
|
| 920 |
+
|
| 921 |
+
if !links.is_empty() {
|
| 922 |
+
self.write_sniped_links(&links, &snipers_map).await?;
|
| 923 |
+
}
|
| 924 |
+
println!(
|
| 925 |
+
"[LinkGraph] [Profile] detect_and_write_snipes: {} links in {:?}",
|
| 926 |
+
links.len(),
|
| 927 |
+
start.elapsed()
|
| 928 |
+
);
|
| 929 |
+
Ok(())
|
| 930 |
+
}
|
| 931 |
+
|
| 932 |
+
fn detect_bundle_trades(
|
| 933 |
+
&self,
|
| 934 |
+
trades: &[TradeRow],
|
| 935 |
+
processed_pairs: &mut HashSet<(String, String)>,
|
| 936 |
+
) -> Vec<BundleTradeLink> {
|
| 937 |
+
let mut links = Vec::new();
|
| 938 |
+
let trades_by_slot_mint = trades
|
| 939 |
+
.iter()
|
| 940 |
+
.into_group_map_by(|t| (t.slot, t.base_address.clone()));
|
| 941 |
+
|
| 942 |
+
for ((slot, mint), trades_in_bundle) in trades_by_slot_mint {
|
| 943 |
+
let unique_makers: Vec<_> =
|
| 944 |
+
trades_in_bundle.iter().map(|t| &t.maker).unique().collect();
|
| 945 |
+
if unique_makers.len() <= 1 {
|
| 946 |
+
continue;
|
| 947 |
+
}
|
| 948 |
+
|
| 949 |
+
// Leader Election: Find the trade with the max `quote_amount`.
|
| 950 |
+
// Includes a deterministic tie-breaker using the wallet address.
|
| 951 |
+
let leader_trade = match trades_in_bundle.iter().max_by(|a, b| {
|
| 952 |
+
match a.quote_amount.cmp(&b.quote_amount) {
|
| 953 |
+
std::cmp::Ordering::Equal => b.maker.cmp(&a.maker),
|
| 954 |
+
other => other,
|
| 955 |
+
}
|
| 956 |
+
}) {
|
| 957 |
+
Some(trade) => trade,
|
| 958 |
+
None => continue,
|
| 959 |
+
};
|
| 960 |
+
let leader_wallet = &leader_trade.maker;
|
| 961 |
+
|
| 962 |
+
let all_bundle_signatures: Vec<String> = trades_in_bundle
|
| 963 |
+
.iter()
|
| 964 |
+
.map(|t| t.signature.clone())
|
| 965 |
+
.collect();
|
| 966 |
+
|
| 967 |
+
for follower_trade in trades_in_bundle
|
| 968 |
+
.iter()
|
| 969 |
+
.filter(|t| &t.maker != leader_wallet)
|
| 970 |
+
{
|
| 971 |
+
let follower_wallet = &follower_trade.maker;
|
| 972 |
+
|
| 973 |
+
let mut combo_sorted = vec![leader_wallet.clone(), follower_wallet.clone()];
|
| 974 |
+
combo_sorted.sort();
|
| 975 |
+
let pair_key = (combo_sorted[0].clone(), combo_sorted[1].clone());
|
| 976 |
+
|
| 977 |
+
// Populate the processed_pairs set and create the link.
|
| 978 |
+
if processed_pairs.insert(pair_key) {
|
| 979 |
+
links.push(BundleTradeLink {
|
| 980 |
+
signatures: all_bundle_signatures.clone(),
|
| 981 |
+
wallet_a: leader_wallet.clone(),
|
| 982 |
+
wallet_b: follower_wallet.clone(),
|
| 983 |
+
mint: mint.clone(),
|
| 984 |
+
slot: slot as i64,
|
| 985 |
+
timestamp: leader_trade.timestamp as i64,
|
| 986 |
+
});
|
| 987 |
+
}
|
| 988 |
+
}
|
| 989 |
+
}
|
| 990 |
+
links
|
| 991 |
+
}
|
| 992 |
+
|
| 993 |
+
    /// Classifies leader/follower relationships between wallets trading the
    /// same mint. For each wallet pair whose *first* interactions with a mint
    /// fall within `copied_trade_window_seconds` of each other, emits either a
    /// `FollowerLink::Copied` (both first trades are buys) or a
    /// `FollowerLink::Coordinated` (any other combination). Pairs already in
    /// `processed_pairs` (e.g. from bundle detection) are skipped, and newly
    /// classified pairs are added to it.
    async fn detect_follower_activity(
        &self,
        trades: &[TradeRow],
        processed_pairs: &mut HashSet<(String, String)>,
    ) -> Result<Vec<FollowerLink>> {
        let cfg = &*LINK_GRAPH_CONFIG;
        let mut links = Vec::new();
        let min_usd_value = cfg.min_trade_total_usd;

        // Ignore dust: only trades above the configured USD floor participate.
        let significant_trades: Vec<&TradeRow> = trades
            .iter()
            .filter(|t| t.total_usd >= min_usd_value)
            .collect();

        if significant_trades.len() < 2 {
            return Ok(links);
        }

        // One (wallet, mint) key per trader/token combination in this batch.
        let unique_pairs: Vec<(String, String)> = significant_trades
            .iter()
            .map(|t| (t.maker.clone(), t.base_address.clone()))
            .unique()
            .collect();
        // Update and read from the bounded in-memory cache; fallback to CH only on misses.
        self.update_trade_cache(&significant_trades).await?;
        let mut historical_trades_map = self.build_histories_from_cache(&unique_pairs).await?;

        // Pairs the cache could not serve are backfilled from ClickHouse.
        let missing_pairs: Vec<(String, String)> = unique_pairs
            .iter()
            .filter(|k| !historical_trades_map.contains_key(*k))
            .cloned()
            .collect();
        if !missing_pairs.is_empty() {
            let historical_query = "
            SELECT maker, base_address, toUnixTimestamp(timestamp) as timestamp, signature, trade_type, total_usd, slippage
            FROM trades
            WHERE (maker, base_address) IN ?
            ";
            // Chunked to keep the tuple IN (...) clause bounded.
            for chunk in missing_pairs.chunks(cfg.chunk_size_historical) {
                let chunk_results: Vec<FullHistTrade> = self
                    .db_client
                    .query(historical_query)
                    .bind(chunk)
                    .fetch_all()
                    .await
                    .map_err(|e| {
                        anyhow!(
                            "[FOLLOWER_FAIL]: Historical trade fetch failed. Error: {}",
                            e
                        )
                    })?;

                for trade in chunk_results {
                    historical_trades_map
                        .entry((trade.maker.clone(), trade.base_address.clone()))
                        .or_default()
                        .push(trade);
                }
            }
        }

        let trades_by_mint = significant_trades
            .into_iter()
            .into_group_map_by(|t| t.base_address.clone());

        for (mint, trades_in_batch) in trades_by_mint {
            if trades_in_batch.len() < 2 {
                continue;
            }

            // The batch "leader" for a mint is simply its earliest trader.
            let Some(leader_trade) = trades_in_batch.iter().min_by_key(|t| t.timestamp) else {
                continue;
            };
            let leader_wallet = &leader_trade.maker;

            for follower_trade in trades_in_batch.iter().filter(|t| &t.maker != leader_wallet) {
                let follower_wallet = &follower_trade.maker;

                // Canonical sorted pair key, matching bundle detection's keying.
                let mut pair_key_vec = vec![leader_wallet.to_string(), follower_wallet.to_string()];
                pair_key_vec.sort();
                let pair_key = (pair_key_vec[0].clone(), pair_key_vec[1].clone());
                if processed_pairs.contains(&pair_key) {
                    continue;
                }

                // Both wallets need a known history with this mint to classify.
                if let (Some(leader_hist_ref), Some(follower_hist_ref)) = (
                    historical_trades_map.get(&(leader_wallet.clone(), mint.clone())),
                    historical_trades_map.get(&(follower_wallet.clone(), mint.clone())),
                ) {
                    // Clone so the histories can be sorted without mutating the map.
                    let mut leader_hist = leader_hist_ref.clone();
                    let mut follower_hist = follower_hist_ref.clone();
                    leader_hist.sort_by_key(|t| t.timestamp);
                    follower_hist.sort_by_key(|t| t.timestamp);

                    let leader_first_trade = leader_hist.get(0);
                    let follower_first_trade = follower_hist.get(0);

                    // --- THE CRITICAL FIX ---
                    // Base the decision on the very first interaction.
                    if let (Some(l1), Some(f1)) = (leader_first_trade, follower_first_trade) {
                        // Absolute gap, so it also matches if the "follower" acted first.
                        let first_gap = (f1.timestamp as i64 - l1.timestamp as i64).abs();

                        // first_gap > 0 excludes same-second trades (bundle territory).
                        if first_gap > 0 && first_gap <= cfg.copied_trade_window_seconds {
                            processed_pairs.insert(pair_key); // Process this pair only once

                            // A) If the FIRST trades are BOTH BUYS, it's a COPIED_TRADE.
                            // NOTE(review): trade_type 0 appears to mean buy, 1 sell —
                            // confirm against the trades schema.
                            if l1.trade_type == 0 && f1.trade_type == 0 {
                                let l_buy = l1; // Already have the first buy
                                let f_buy = f1; // Already have the first buy

                                let leader_sells: Vec<_> =
                                    leader_hist.iter().filter(|t| t.trade_type == 1).collect();
                                let follower_sells: Vec<_> =
                                    follower_hist.iter().filter(|t| t.trade_type == 1).collect();
                                let leader_sell_total: f64 =
                                    leader_sells.iter().map(|t| t.total_usd).sum();
                                let follower_sell_total: f64 =
                                    follower_sells.iter().map(|t| t.total_usd).sum();
                                // Realized PnL relative to the *first* buy only;
                                // guarded against divide-by-zero on a zero-USD buy.
                                let leader_pnl = if l_buy.total_usd > 0.0 {
                                    (leader_sell_total - l_buy.total_usd) / l_buy.total_usd
                                } else {
                                    0.0
                                };
                                let follower_pnl = if f_buy.total_usd > 0.0 {
                                    (follower_sell_total - f_buy.total_usd) / f_buy.total_usd
                                } else {
                                    0.0
                                };
                                let leader_first_sell =
                                    leader_sells.iter().min_by_key(|t| t.timestamp);
                                let follower_first_sell =
                                    follower_sells.iter().min_by_key(|t| t.timestamp);

                                // Sell-side details default to zero/empty when either
                                // side has not sold yet.
                                let (sell_gap, l_sell_sig, f_sell_sig, f_sell_slip) =
                                    if let (Some(l_sell), Some(f_sell)) =
                                        (leader_first_sell, follower_first_sell)
                                    {
                                        (
                                            (f_sell.timestamp as i64 - l_sell.timestamp as i64)
                                                .abs(),
                                            l_sell.signature.clone(),
                                            f_sell.signature.clone(),
                                            f_sell.slippage,
                                        )
                                    } else {
                                        (0, "".to_string(), "".to_string(), 0.0)
                                    };

                                links.push(FollowerLink::Copied(CopiedTradeLink {
                                    timestamp: f_buy.timestamp as i64,
                                    follower: follower_wallet.clone(),
                                    leader: leader_wallet.clone(),
                                    mint: mint.clone(),
                                    time_gap_on_buy_sec: first_gap, // Use the already calculated gap
                                    time_gap_on_sell_sec: sell_gap,
                                    leader_pnl,
                                    follower_pnl,
                                    leader_buy_sig: l_buy.signature.clone(),
                                    leader_sell_sig: l_sell_sig,
                                    follower_buy_sig: f_buy.signature.clone(),
                                    follower_sell_sig: f_sell_sig,
                                    leader_buy_total: l_buy.total_usd,
                                    leader_sell_total,
                                    follower_buy_total: f_buy.total_usd,
                                    follower_sell_total,
                                    follower_buy_slippage: f_buy.slippage,
                                    follower_sell_slippage: f_sell_slip,
                                }));
                            }
                            // B) ELSE, if the first trades are not both buys, it's a COORDINATED_ACTIVITY.
                            else {
                                let leader_second_trade = leader_hist.get(1);
                                let follower_second_trade = follower_hist.get(1);

                                // Second-trade fields are optional context; empty
                                // defaults when either side only traded once.
                                let (l2_sig, f2_sig, second_gap) = if let (Some(l2), Some(f2)) =
                                    (leader_second_trade, follower_second_trade)
                                {
                                    (
                                        l2.signature.clone(),
                                        f2.signature.clone(),
                                        (f2.timestamp as i64 - l2.timestamp as i64).abs(),
                                    )
                                } else {
                                    ("".to_string(), "".to_string(), 0)
                                };

                                links.push(FollowerLink::Coordinated(CoordinatedActivityLink {
                                    timestamp: l1.timestamp as i64,
                                    leader: leader_wallet.clone(),
                                    follower: follower_wallet.clone(),
                                    mint: mint.clone(),
                                    leader_first_sig: l1.signature.clone(),
                                    follower_first_sig: f1.signature.clone(),
                                    time_gap_on_first_sec: first_gap,
                                    leader_second_sig: l2_sig,
                                    follower_second_sig: f2_sig,
                                    time_gap_on_second_sec: second_gap,
                                }));
                            }
                        }
                    }
                }
            }
        }
        Ok(links)
    }
|
| 1199 |
+
|
| 1200 |
+
    /// For each fully tracked mint touched by this batch, finds the single
    /// most profitable holder (top realized PnL above `min_top_trader_pnl`)
    /// and, when the token's ATH clears `ath_price_threshold_usd`, writes a
    /// `TopTraderOfLink` to the graph. Tokens without a creation record in
    /// `mints` are excluded so PnL is only trusted for fully observed tokens.
    async fn detect_and_write_top_trader_links(&self, trades: &[TradeRow]) -> Result<()> {
        let start = Instant::now();
        let cfg = &*LINK_GRAPH_CONFIG;
        // Every distinct (trader, mint) combination present in the batch.
        let active_trader_pairs: Vec<(String, String)> = trades
            .iter()
            .map(|t| (t.maker.clone(), t.base_address.clone()))
            .unique()
            .collect();

        if active_trader_pairs.is_empty() {
            return Ok(());
        }

        // --- NEW: CONFIDENCE FILTER ---
        // 1. Get all unique mints from the active pairs.
        let unique_mints: Vec<String> = active_trader_pairs
            .iter()
            .map(|(_, mint)| mint.clone())
            .unique()
            .collect();

        // Minimal row shape for the "do we know this mint's creation?" probe.
        #[derive(Row, Deserialize, Debug)]
        struct MintCheck {
            mint_address: String,
        }
        let mint_query = "SELECT DISTINCT mint_address FROM mints WHERE mint_address IN ?";

        let mut fully_tracked_mints = HashSet::new();
        let mint_chunk_small = cfg.chunk_size_mint_small;

        // Chunked IN-list probe, with the shared ClickHouse retry wrapper.
        for chunk in unique_mints.chunks(mint_chunk_small) {
            let chunk_rows: Vec<MintCheck> = self
                .with_ch_retry(
                    || async {
                        self.db_client
                            .query(mint_query)
                            .bind(chunk)
                            .fetch_all()
                            .await
                            .map_err(anyhow::Error::from)
                    },
                    "TopTrader mint check chunk",
                )
                .await?;
            for mint_row in chunk_rows {
                fully_tracked_mints.insert(mint_row.mint_address);
            }
        }

        // 2. Filter the active pairs to only include those for fully tracked tokens.
        let confident_trader_pairs: Vec<(String, String)> = active_trader_pairs
            .into_iter()
            .filter(|(_, mint)| fully_tracked_mints.contains(mint))
            .collect();

        // Used only as a guard: no confident pairs means nothing to link.
        if confident_trader_pairs.is_empty() {
            return Ok(());
        }
        // --- END CONFIDENCE FILTER ---

        // NOTE(review): this queries all tracked mints, not just those from
        // confident pairs — confirm the broader scope is intentional.
        let mints_to_query: Vec<String> = fully_tracked_mints.iter().cloned().collect();
        if mints_to_query.is_empty() {
            return Ok(());
        }

        // Latest all-time-high price per mint (retried helper).
        let ath_map = self
            .fetch_latest_ath_map_with_retry(&mints_to_query)
            .await?;
        if ath_map.is_empty() {
            return Ok(());
        }

        #[derive(Row, Deserialize, Debug)]
        struct TraderContextInfo {
            wallet_address: String,
            mint_address: String,
            realized_profit_pnl: f32,
        }

        // QUALIFY ... = 1 keeps exactly the single best-PnL wallet per mint.
        let pnl_query = "
        SELECT
            wh.wallet_address, wh.mint_address, wh.realized_profit_pnl
        FROM wallet_holdings_latest AS wh
        WHERE wh.mint_address IN ?
          AND wh.realized_profit_pnl > ?
        QUALIFY ROW_NUMBER() OVER (PARTITION BY wh.mint_address ORDER BY wh.realized_profit_pnl DESC) = 1
        ";

        let mut top_traders: Vec<TraderContextInfo> = Vec::new();

        for chunk in mints_to_query.chunks(cfg.chunk_size_mint_large) {
            let chunk_results = self
                .db_client
                .query(pnl_query)
                .bind(chunk)
                .bind(cfg.min_top_trader_pnl)
                .fetch_all()
                .await
                .map_err(|e| anyhow!("[TOPTRADER_FAIL]: Top-1 PNL fetch failed. Error: {}", e))?;
            top_traders.extend(chunk_results);
        }

        // Keep only traders of tokens whose ATH clears the USD threshold.
        let links: Vec<TopTraderOfLink> = top_traders
            .into_iter()
            .filter_map(|trader| {
                ath_map
                    .get(&trader.mint_address)
                    .filter(|ath| **ath >= cfg.ath_price_threshold_usd)
                    .map(|ath| TopTraderOfLink {
                        // Wall-clock stamp of link creation, not an on-chain time.
                        timestamp: Utc::now().timestamp(),
                        wallet: trader.wallet_address,
                        token: trader.mint_address,
                        pnl_at_creation: trader.realized_profit_pnl as f64,
                        ath_usd_at_creation: *ath,
                    })
            })
            .collect();

        if !links.is_empty() {
            self.write_top_trader_of_links(&links).await?;
        }

        println!(
            "[LinkGraph] [Profile] detect_and_write_top_trader_links: {} links in {:?}",
            links.len(),
            start.elapsed()
        );
        Ok(())
    }
|
| 1329 |
+
|
| 1330 |
+
async fn process_liquidity_events(&self, liquidity_adds: &[LiquidityRow]) -> Result<()> {
|
| 1331 |
+
let cfg = &*LINK_GRAPH_CONFIG;
|
| 1332 |
+
if liquidity_adds.is_empty() {
|
| 1333 |
+
return Ok(());
|
| 1334 |
+
}
|
| 1335 |
+
let unique_pools: HashSet<String> = liquidity_adds
|
| 1336 |
+
.iter()
|
| 1337 |
+
.map(|l| l.pool_address.clone())
|
| 1338 |
+
.collect();
|
| 1339 |
+
if unique_pools.is_empty() {
|
| 1340 |
+
return Ok(());
|
| 1341 |
+
}
|
| 1342 |
+
|
| 1343 |
+
#[derive(Row, Deserialize, Debug)]
|
| 1344 |
+
struct PoolInfo {
|
| 1345 |
+
pool_address: String,
|
| 1346 |
+
base_address: String,
|
| 1347 |
+
base_decimals: Option<u8>,
|
| 1348 |
+
quote_decimals: Option<u8>,
|
| 1349 |
+
}
|
| 1350 |
+
|
| 1351 |
+
let pool_query = "SELECT pool_address, base_address, base_decimals, quote_decimals FROM pool_creations WHERE pool_address IN ?";
|
| 1352 |
+
let mut pools_info: Vec<PoolInfo> = Vec::new();
|
| 1353 |
+
let unique_pools_vec: Vec<_> = unique_pools.iter().cloned().collect();
|
| 1354 |
+
|
| 1355 |
+
for chunk in unique_pools_vec.chunks(cfg.chunk_size_large) {
|
| 1356 |
+
let mut chunk_results = self
|
| 1357 |
+
.db_client
|
| 1358 |
+
.query(pool_query)
|
| 1359 |
+
.bind(chunk)
|
| 1360 |
+
.fetch_all()
|
| 1361 |
+
.await
|
| 1362 |
+
.map_err(|e| anyhow!("[LIQUIDITY_FAIL]: PoolQuery chunk failed. Error: {}", e))?;
|
| 1363 |
+
pools_info.append(&mut chunk_results);
|
| 1364 |
+
}
|
| 1365 |
+
|
| 1366 |
+
let pool_to_token_map: HashMap<String, (String, Option<u8>, Option<u8>)> = pools_info
|
| 1367 |
+
.into_iter()
|
| 1368 |
+
.map(|p| {
|
| 1369 |
+
(
|
| 1370 |
+
p.pool_address,
|
| 1371 |
+
(p.base_address, p.base_decimals, p.quote_decimals),
|
| 1372 |
+
)
|
| 1373 |
+
})
|
| 1374 |
+
.collect();
|
| 1375 |
+
|
| 1376 |
+
let links: Vec<_> = liquidity_adds
|
| 1377 |
+
.iter()
|
| 1378 |
+
.filter_map(|l| {
|
| 1379 |
+
pool_to_token_map.get(&l.pool_address).map(
|
| 1380 |
+
|(token_address, base_decimals, quote_decimals)| {
|
| 1381 |
+
let base_scale = 10f64.powi(base_decimals.unwrap_or(0) as i32);
|
| 1382 |
+
let quote_scale = 10f64.powi(quote_decimals.unwrap_or(0) as i32);
|
| 1383 |
+
ProvidedLiquidityLink {
|
| 1384 |
+
signature: l.signature.clone(),
|
| 1385 |
+
wallet: l.lp_provider.clone(),
|
| 1386 |
+
token: token_address.clone(),
|
| 1387 |
+
pool_address: l.pool_address.clone(),
|
| 1388 |
+
amount_base: l.base_amount as f64 / base_scale,
|
| 1389 |
+
amount_quote: l.quote_amount as f64 / quote_scale,
|
| 1390 |
+
timestamp: l.timestamp as i64,
|
| 1391 |
+
}
|
| 1392 |
+
},
|
| 1393 |
+
)
|
| 1394 |
+
})
|
| 1395 |
+
.collect();
|
| 1396 |
+
|
| 1397 |
+
if !links.is_empty() {
|
| 1398 |
+
self.write_provided_liquidity_links(&links).await?;
|
| 1399 |
+
}
|
| 1400 |
+
Ok(())
|
| 1401 |
+
}
|
| 1402 |
+
|
| 1403 |
+
    /// Finds "whale" holders for tokens touched by this batch: for every fully
    /// tracked mint whose ATH clears the USD threshold, fetches its top
    /// balance holders (rank <= `whale_rank_threshold`) and writes
    /// `WhaleOfLink`s with each holder's supply percentage. All ClickHouse
    /// reads are chunked and retried with exponential-ish backoff.
    async fn detect_and_write_whale_links(&self, trades: &[TradeRow]) -> Result<()> {
        let start = Instant::now();
        let cfg = &*LINK_GRAPH_CONFIG;
        let unique_mints_in_batch: Vec<String> = trades
            .iter()
            .map(|t| t.base_address.clone())
            .unique()
            .collect();
        if unique_mints_in_batch.is_empty() {
            return Ok(());
        }

        // --- NEW: CONFIDENCE FILTER ---
        // 1. Check which of the mints in the batch have a creation event in our DB.
        #[derive(Row, Deserialize, Debug)]
        struct MintCheck {
            mint_address: String,
        }
        let mint_query = "SELECT DISTINCT mint_address FROM mints WHERE mint_address IN ?";

        let mut fully_tracked_mints = HashSet::new();
        // Chunked probe using the shared retry wrapper.
        for chunk in unique_mints_in_batch.chunks(cfg.chunk_size_mint_large) {
            let chunk_rows: Vec<MintCheck> = self
                .with_ch_retry(
                    || async {
                        self.db_client
                            .query(mint_query)
                            .bind(chunk)
                            .fetch_all()
                            .await
                            .map_err(anyhow::Error::from)
                    },
                    "Whale mint check chunk",
                )
                .await?;
            for mint_row in chunk_rows {
                fully_tracked_mints.insert(mint_row.mint_address);
            }
        }

        if fully_tracked_mints.is_empty() {
            return Ok(());
        }
        let confident_mints: Vec<String> = fully_tracked_mints.iter().cloned().collect();
        // Latest all-time-high price per confident mint.
        let ath_map = self
            .fetch_latest_ath_map_with_retry(&confident_mints)
            .await?;
        if ath_map.is_empty() {
            return Ok(());
        }
        // --- END CONFIDENCE FILTER ---

        #[derive(Row, Deserialize, Debug)]
        struct TokenInfo {
            token_address: String,
            total_supply: u64,
            decimals: u8,
        }

        // FINAL forces merge-time dedup so each token yields one current row.
        let token_query = "SELECT token_address, total_supply, decimals FROM tokens FINAL WHERE token_address IN ?";

        // --- RE-INTRODUCED CHUNKING for the token pre-filter ---
        // token -> (raw total supply, ATH USD, decimals), only for tokens
        // whose ATH clears the threshold.
        let mut context_map: HashMap<String, (u64, f64, u8)> = HashMap::new();

        for chunk in confident_mints.chunks(cfg.chunk_size_token) {
            // Hand-rolled retry loop with linear backoff (backoff_ms * attempt).
            // NOTE(review): duplicates what `with_ch_retry` does elsewhere —
            // consider unifying once their semantics are confirmed identical.
            let mut attempts = 0;
            loop {
                attempts += 1;
                let result: Result<Vec<TokenInfo>> = self
                    .db_client
                    .query(token_query)
                    .bind(chunk)
                    .fetch_all()
                    .await
                    .map_err(anyhow::Error::from);

                match result {
                    Ok(chunk_results) => {
                        for token in chunk_results {
                            if let Some(ath) = ath_map.get(&token.token_address) {
                                if *ath >= cfg.ath_price_threshold_usd {
                                    context_map.insert(
                                        token.token_address,
                                        (token.total_supply, *ath, token.decimals),
                                    );
                                }
                            }
                        }
                        break;
                    }
                    Err(e) => {
                        if attempts >= cfg.ch_retry_attempts {
                            return Err(anyhow!(
                                "[WHALE_FAIL]: Token pre-filter chunk failed after {} attempts: {}",
                                attempts,
                                e
                            ));
                        }
                        let backoff = cfg.ch_retry_backoff_ms * attempts as u64;
                        eprintln!(
                            "[LinkGraph] ⚠️ Whale token pre-filter retry {}/{} after {}ms: {}",
                            attempts, cfg.ch_retry_attempts, backoff, e
                        );
                        sleep(Duration::from_millis(backoff)).await;
                    }
                }
            }
        }
        // --- END CHUNKING ---

        if context_map.is_empty() {
            return Ok(());
        }

        let tokens_to_query: Vec<String> = context_map.keys().cloned().collect();

        #[derive(Row, Deserialize, Debug)]
        struct WhaleInfo {
            wallet_address: String,
            mint_address: String,
            current_balance: f64,
        }

        // Top-N holders per mint via QUALIFY window ranking.
        let whales_query = "
        SELECT wallet_address, mint_address, current_balance
        FROM wallet_holdings_latest
        WHERE mint_address IN ? AND current_balance > 0
        QUALIFY ROW_NUMBER() OVER (PARTITION BY mint_address ORDER BY current_balance DESC) <= ?
        ";

        // --- RE-INTRODUCED CHUNKING for the main whale query ---
        let mut top_holders: Vec<WhaleInfo> = Vec::new();
        for chunk in tokens_to_query.chunks(cfg.chunk_size_token) {
            // Same hand-rolled retry pattern as the token pre-filter above.
            let mut attempts = 0;
            loop {
                attempts += 1;
                let result: Result<Vec<WhaleInfo>> = self
                    .db_client
                    .query(whales_query)
                    .bind(chunk)
                    .bind(cfg.whale_rank_threshold)
                    .fetch_all()
                    .await
                    .map_err(anyhow::Error::from);

                match result {
                    Ok(chunk_results) => {
                        top_holders.extend(chunk_results);
                        break;
                    }
                    Err(e) => {
                        if attempts >= cfg.ch_retry_attempts {
                            return Err(anyhow!(
                                "[WHALE_FAIL]: Holder query chunk failed after {} attempts: {}",
                                attempts,
                                e
                            ));
                        }
                        let backoff = cfg.ch_retry_backoff_ms * attempts as u64;
                        eprintln!(
                            "[LinkGraph] ⚠️ Whale holder chunk retry {}/{} after {}ms: {}",
                            attempts, cfg.ch_retry_attempts, backoff, e
                        );
                        sleep(Duration::from_millis(backoff)).await;
                    }
                }
            }
        }
        // --- END CHUNKING ---

        let mut links = Vec::new();
        for holder in top_holders {
            if let Some((raw_total_supply, ath_usd, decimals)) =
                context_map.get(&holder.mint_address)
            {
                if *raw_total_supply == 0 {
                    continue;
                }

                // --- THE FIX ---
                // Adjust the total supply to be human-readable before dividing.
                let human_total_supply = *raw_total_supply as f64 / 10f64.powi(*decimals as i32);
                if human_total_supply == 0.0 {
                    continue;
                }
                // --- END FIX ---

                // Holder's share of total supply; `current_balance` is assumed
                // to already be decimal-adjusted — TODO confirm upstream.
                let holding_pct = (holder.current_balance / human_total_supply) as f32;

                links.push(WhaleOfLink {
                    // Wall-clock stamp of link creation, not an on-chain time.
                    timestamp: Utc::now().timestamp(),
                    wallet: holder.wallet_address.clone(),
                    token: holder.mint_address.clone(),
                    holding_pct_at_creation: holding_pct,
                    ath_usd_at_creation: *ath_usd,
                });
            }
        }

        if !links.is_empty() {
            self.write_whale_of_links(&links).await?;
        }

        println!(
            "[LinkGraph] [Profile] detect_and_write_whale_links: {} links in {:?}",
            links.len(),
            start.elapsed()
        );
        Ok(())
    }
|
| 1613 |
+
|
| 1614 |
+
async fn create_wallet_nodes(&self, wallets: &HashSet<String>) -> Result<()> {
|
| 1615 |
+
if wallets.is_empty() {
|
| 1616 |
+
return Ok(());
|
| 1617 |
+
}
|
| 1618 |
+
let cfg = &*LINK_GRAPH_CONFIG;
|
| 1619 |
+
|
| 1620 |
+
// Convert the HashSet to a Vec to be able to create chunks
|
| 1621 |
+
let wallet_vec: Vec<_> = wallets.iter().cloned().collect();
|
| 1622 |
+
|
| 1623 |
+
// Process the wallets in smaller, manageable chunks
|
| 1624 |
+
for chunk in wallet_vec.chunks(cfg.chunk_size_large) {
|
| 1625 |
+
let params: Vec<_> = chunk
|
| 1626 |
+
.iter()
|
| 1627 |
+
.map(|addr| HashMap::from([("address".to_string(), BoltType::from(addr.clone()))]))
|
| 1628 |
+
.collect();
|
| 1629 |
+
|
| 1630 |
+
let cypher = "
|
| 1631 |
+
UNWIND $wallets as wallet
|
| 1632 |
+
MERGE (w:Wallet {address: wallet.address})
|
| 1633 |
+
";
|
| 1634 |
+
|
| 1635 |
+
self.enqueue_write(cypher, params).await?;
|
| 1636 |
+
}
|
| 1637 |
+
Ok(())
|
| 1638 |
+
}
|
| 1639 |
+
|
| 1640 |
+
async fn create_token_nodes(&self, tokens: &HashSet<String>) -> Result<()> {
|
| 1641 |
+
if tokens.is_empty() {
|
| 1642 |
+
return Ok(());
|
| 1643 |
+
}
|
| 1644 |
+
let cfg = &*LINK_GRAPH_CONFIG;
|
| 1645 |
+
|
| 1646 |
+
// Convert the HashSet to a Vec to be able to create chunks
|
| 1647 |
+
let token_vec: Vec<_> = tokens.iter().cloned().collect();
|
| 1648 |
+
|
| 1649 |
+
// Process the tokens in smaller, manageable chunks
|
| 1650 |
+
for chunk in token_vec.chunks(cfg.chunk_size_large) {
|
| 1651 |
+
let params: Vec<_> = chunk
|
| 1652 |
+
.iter()
|
| 1653 |
+
.map(|addr| HashMap::from([("address".to_string(), BoltType::from(addr.clone()))]))
|
| 1654 |
+
.collect();
|
| 1655 |
+
|
| 1656 |
+
let cypher = "
|
| 1657 |
+
UNWIND $tokens as token
|
| 1658 |
+
MERGE (t:Token {address: token.address})
|
| 1659 |
+
ON CREATE SET t.created_ts = token.created_ts
|
| 1660 |
+
";
|
| 1661 |
+
|
| 1662 |
+
self.enqueue_write(cypher, params).await?;
|
| 1663 |
+
}
|
| 1664 |
+
Ok(())
|
| 1665 |
+
}
|
| 1666 |
+
|
| 1667 |
+
/// Persist BUNDLE_TRADE edges between wallet pairs that traded the same mint
/// in the same slot. The edge is keyed on (mint, slot), so re-ingesting the
/// same bundle is idempotent; ON MATCH keeps the *earliest* timestamp seen.
async fn write_bundle_trade_links(&self, links: &[BundleTradeLink]) -> Result<()> {
    if links.is_empty() {
        return Ok(());
    }
    // One Bolt parameter map per link row.
    let params: Vec<_> = links
        .iter()
        .map(|l| {
            HashMap::from([
                ("wa".to_string(), BoltType::from(l.wallet_a.clone())),
                ("wb".to_string(), BoltType::from(l.wallet_b.clone())),
                ("mint".to_string(), BoltType::from(l.mint.clone())),
                ("slot".to_string(), BoltType::from(l.slot)),
                ("timestamp".to_string(), BoltType::from(l.timestamp)),
                (
                    "signatures".to_string(),
                    BoltType::from(l.signatures.clone()),
                ),
            ])
        })
        .collect();
    // Corrected relationship name to BUNDLE_TRADE for consistency
    let cypher = "
        UNWIND $x as t
        MERGE (a:Wallet {address: t.wa})
        MERGE (b:Wallet {address: t.wb})
        MERGE (a)-[r:BUNDLE_TRADE {mint: t.mint, slot: t.slot}]->(b)
        ON CREATE SET r.timestamp = t.timestamp, r.signatures = t.signatures
        ON MATCH SET r.timestamp = CASE WHEN t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END
    ";
    self.enqueue_write(cypher, params).await
}
|
| 1698 |
+
|
| 1699 |
+
/// Persist TRANSFERRED_TO edges (source -> destination per mint).
/// Input is de-duplicated per (source, destination, mint) within this batch;
/// the first event's signature/timestamp/amount win on creation, and
/// ON MATCH keeps the earliest timestamp across batches.
/// NOTE(review): `unique_by` comes from itertools and only de-dups within
/// this call's batch — cross-batch duplicates are handled by MERGE.
async fn write_transfer_links(&self, links: &[TransferLink]) -> Result<()> {
    if links.is_empty() {
        return Ok(());
    }

    // --- THE FIX ---
    // Use `unique_by` to get the *entire first link object* for each unique path.
    // This preserves the signature and timestamp from the first event we see.
    let unique_links = links
        .iter()
        .unique_by(|l| (&l.source, &l.destination, &l.mint))
        .collect::<Vec<_>>();

    // Now build the parameters with the full data from the unique links.
    let params: Vec<_> = unique_links
        .iter()
        .map(|l| {
            HashMap::from([
                ("source".to_string(), BoltType::from(l.source.clone())),
                (
                    "destination".to_string(),
                    BoltType::from(l.destination.clone()),
                ),
                ("mint".to_string(), BoltType::from(l.mint.clone())),
                ("signature".to_string(), BoltType::from(l.signature.clone())), // Include the signature
                ("timestamp".to_string(), BoltType::from(l.timestamp)), // Include the on-chain timestamp
                ("amount".to_string(), BoltType::from(l.amount)),
            ])
        })
        .collect();

    // --- UPDATED CYPHER QUERY ---
    // The query now sets the signature and on-chain timestamp on the link when it's first created.
    let cypher = "
        UNWIND $x as t
        MERGE (s:Wallet {address: t.source})
        MERGE (d:Wallet {address: t.destination})
        MERGE (s)-[r:TRANSFERRED_TO {mint: t.mint}]->(d)
        ON CREATE SET
            r.signature = t.signature,
            r.timestamp = t.timestamp,
            r.amount = t.amount
        ON MATCH SET r.timestamp = CASE WHEN t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END
    ";

    self.enqueue_write(cypher, params).await
}
|
| 1746 |
+
|
| 1747 |
+
/// Persist COORDINATED_ACTIVITY edges (follower -> leader, keyed per mint).
/// Both paired signatures and the two time gaps are recorded on creation;
/// ON MATCH only keeps the earliest timestamp, so the first-seen evidence
/// (signatures/gaps) is never overwritten by later batches.
async fn write_coordinated_activity_links(
    &self,
    links: &[CoordinatedActivityLink],
) -> Result<()> {
    if links.is_empty() {
        return Ok(());
    }

    // One Bolt parameter map per link row.
    let params: Vec<_> = links
        .iter()
        .map(|l| {
            HashMap::from([
                ("leader".to_string(), BoltType::from(l.leader.clone())),
                ("follower".to_string(), BoltType::from(l.follower.clone())),
                ("mint".to_string(), BoltType::from(l.mint.clone())),
                ("timestamp".to_string(), BoltType::from(l.timestamp)),
                // Use the new, correct field names
                (
                    "l_sig_1".to_string(),
                    BoltType::from(l.leader_first_sig.clone()),
                ),
                (
                    "l_sig_2".to_string(),
                    BoltType::from(l.leader_second_sig.clone()),
                ),
                (
                    "f_sig_1".to_string(),
                    BoltType::from(l.follower_first_sig.clone()),
                ),
                (
                    "f_sig_2".to_string(),
                    BoltType::from(l.follower_second_sig.clone()),
                ),
                ("gap_1".to_string(), BoltType::from(l.time_gap_on_first_sec)),
                (
                    "gap_2".to_string(),
                    BoltType::from(l.time_gap_on_second_sec),
                ),
            ])
        })
        .collect();

    // This query now creates a single, comprehensive link per pair/mint
    let cypher = "
        UNWIND $x as t
        MERGE (l:Wallet {address: t.leader})
        MERGE (f:Wallet {address: t.follower})
        MERGE (f)-[r:COORDINATED_ACTIVITY {mint: t.mint}]->(l)
        ON CREATE SET
            r.timestamp = t.timestamp,
            r.leader_first_sig = t.l_sig_1,
            r.leader_second_sig = t.l_sig_2,
            r.follower_first_sig = t.f_sig_1,
            r.follower_second_sig = t.f_sig_2,
            r.time_gap_on_first_sec = t.gap_1,
            r.time_gap_on_second_sec = t.gap_2
        ON MATCH SET r.timestamp = CASE WHEN t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END
    ";

    self.enqueue_write(cypher, params).await
}
|
| 1808 |
+
|
| 1809 |
+
/// Persist COPIED_TRADE edges (follower -> leader, keyed per mint) with the
/// full buy/sell evidence (signatures, totals, slippage, PnL) captured at
/// creation time. ON MATCH only keeps the earliest timestamp.
/// NOTE(review): the ON CREATE block also sets r.follower / r.leader /
/// r.mint, which duplicate the node addresses and the MERGE key —
/// redundant but harmless; confirm before relying on them.
async fn write_copied_trade_links(&self, links: &[CopiedTradeLink]) -> Result<()> {
    if links.is_empty() {
        return Ok(());
    }
    // This uses the latest struct definition provided in the prompt.
    let params: Vec<_> = links
        .iter()
        .map(|l| {
            HashMap::from([
                ("follower".to_string(), BoltType::from(l.follower.clone())),
                ("leader".to_string(), BoltType::from(l.leader.clone())),
                ("mint".to_string(), BoltType::from(l.mint.clone())),
                ("buy_gap".to_string(), BoltType::from(l.time_gap_on_buy_sec)),
                (
                    "sell_gap".to_string(),
                    BoltType::from(l.time_gap_on_sell_sec),
                ),
                ("leader_pnl".to_string(), BoltType::from(l.leader_pnl)),
                ("follower_pnl".to_string(), BoltType::from(l.follower_pnl)),
                (
                    "l_buy_sig".to_string(),
                    BoltType::from(l.leader_buy_sig.clone()),
                ),
                (
                    "l_sell_sig".to_string(),
                    BoltType::from(l.leader_sell_sig.clone()),
                ),
                (
                    "f_buy_sig".to_string(),
                    BoltType::from(l.follower_buy_sig.clone()),
                ),
                (
                    "f_sell_sig".to_string(),
                    BoltType::from(l.follower_sell_sig.clone()),
                ),
                (
                    "l_buy_total".to_string(),
                    BoltType::from(l.leader_buy_total),
                ),
                (
                    "l_sell_total".to_string(),
                    BoltType::from(l.leader_sell_total),
                ),
                (
                    "f_buy_total".to_string(),
                    BoltType::from(l.follower_buy_total),
                ),
                (
                    "f_sell_total".to_string(),
                    BoltType::from(l.follower_sell_total),
                ),
                (
                    "f_buy_slip".to_string(),
                    BoltType::from(l.follower_buy_slippage),
                ),
                (
                    "f_sell_slip".to_string(),
                    BoltType::from(l.follower_sell_slippage),
                ),
                ("timestamp".to_string(), BoltType::from(l.timestamp)),
            ])
        })
        .collect();
    let cypher = "
        UNWIND $x as t
        MERGE (f:Wallet {address: t.follower})
        MERGE (l:Wallet {address: t.leader})
        MERGE (f)-[r:COPIED_TRADE {mint: t.mint}]->(l)
        ON CREATE SET
            r.timestamp = t.timestamp,
            r.follower = t.follower,
            r.leader = t.leader,
            r.mint = t.mint,
            r.buy_gap = t.buy_gap,
            r.sell_gap = t.sell_gap,
            r.leader_pnl = t.leader_pnl,
            r.follower_pnl = t.follower_pnl,
            r.l_buy_sig = t.l_buy_sig,
            r.l_sell_sig = t.l_sell_sig,
            r.f_buy_sig = t.f_buy_sig,
            r.f_sell_sig = t.f_sell_sig,
            r.l_buy_total = t.l_buy_total,
            r.l_sell_total = t.l_sell_total,
            r.f_buy_total = t.f_buy_total,
            r.f_sell_total = t.f_sell_total,
            r.f_buy_slip = t.f_buy_slip,
            r.f_sell_slip = t.f_sell_slip
        ON MATCH SET r.timestamp = CASE WHEN t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END
    ";
    self.enqueue_write(cypher, params).await
}
|
| 1900 |
+
|
| 1901 |
+
/// Persist MINTED edges (creator wallet -> token), joining each link to its
/// `MintRow` by transaction signature. Links whose signature has no matching
/// mint row are silently dropped by the `filter_map`.
async fn write_minted_links(&self, links: &[MintedLink], mints: &[MintRow]) -> Result<()> {
    if links.is_empty() {
        return Ok(());
    }
    // Index mint rows by signature for O(1) joins below.
    let mint_map: HashMap<_, _> = mints.iter().map(|m| (m.signature.clone(), m)).collect();

    let params: Vec<_> = links
        .iter()
        .filter_map(|l| {
            mint_map.get(&l.signature).map(|m| {
                HashMap::from([
                    (
                        "creator".to_string(),
                        BoltType::from(m.creator_address.clone()),
                    ),
                    ("token".to_string(), BoltType::from(m.mint_address.clone())),
                    ("signature".to_string(), BoltType::from(l.signature.clone())),
                    ("timestamp".to_string(), BoltType::from(l.timestamp)),
                    ("buy_amount".to_string(), BoltType::from(l.buy_amount)),
                ])
            })
        })
        .collect();

    // Everything may have been filtered out — avoid queuing an empty write.
    if params.is_empty() {
        return Ok(());
    }
    // --- MODIFIED: MERGE on the signature for idempotency ---
    let cypher = "
        UNWIND $x as t
        MERGE (c:Wallet {address: t.creator})
        MERGE (k:Token {address: t.token})
        MERGE (c)-[r:MINTED {signature: t.signature}]->(k)
        ON CREATE SET r.timestamp = t.timestamp, r.buy_amount = t.buy_amount
        ON MATCH SET r.timestamp = CASE WHEN t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END
    ";
    self.enqueue_write(cypher, params).await
}
|
| 1939 |
+
|
| 1940 |
+
/// Persist SNIPED edges (sniper wallet -> token), keyed on the transaction
/// signature. `snipers` maps signature -> (wallet, token); links without a
/// matching entry are silently dropped by the `filter_map`.
async fn write_sniped_links(
    &self,
    links: &[SnipedLink],
    snipers: &HashMap<String, (String, String)>,
) -> Result<()> {
    if links.is_empty() {
        return Ok(());
    }

    let params: Vec<_> = links
        .iter()
        .filter_map(|l| {
            snipers.get(&l.signature).map(|(wallet, token)| {
                HashMap::from([
                    ("wallet".to_string(), BoltType::from(wallet.clone())),
                    ("token".to_string(), BoltType::from(token.clone())),
                    ("signature".to_string(), BoltType::from(l.signature.clone())),
                    ("rank".to_string(), BoltType::from(l.rank)),
                    ("sniped_amount".to_string(), BoltType::from(l.sniped_amount)),
                    ("timestamp".to_string(), BoltType::from(l.timestamp)),
                ])
            })
        })
        .collect();

    // Everything may have been filtered out — avoid queuing an empty write.
    if params.is_empty() {
        return Ok(());
    }

    // --- MODIFIED: MERGE on signature ---
    let cypher = "
        UNWIND $x as t
        MERGE (w:Wallet {address: t.wallet})
        MERGE (k:Token {address: t.token})
        MERGE (w)-[r:SNIPED {signature: t.signature}]->(k)
        ON CREATE SET r.rank = t.rank, r.sniped_amount = t.sniped_amount, r.timestamp = t.timestamp
        ON MATCH SET r.timestamp = CASE WHEN t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END
    ";
    self.enqueue_write(cypher, params).await
}
|
| 1980 |
+
|
| 1981 |
+
/// Persist LOCKED_SUPPLY edges (sender wallet -> token), keyed on signature.
/// Each link is joined to its `SupplyLockRow` by signature for the sender /
/// recipient / mint addresses; unmatched links are dropped by `filter_map`.
async fn write_locked_supply_links(
    &self,
    links: &[LockedSupplyLink],
    locks: &[SupplyLockRow],
) -> Result<()> {
    if links.is_empty() {
        return Ok(());
    }
    // Index lock rows by signature for O(1) joins below.
    let lock_map: HashMap<_, _> = locks.iter().map(|l| (l.signature.clone(), l)).collect();

    let params: Vec<_> = links
        .iter()
        .filter_map(|l| {
            lock_map.get(&l.signature).map(|lock_row| {
                HashMap::from([
                    (
                        "sender".to_string(),
                        BoltType::from(lock_row.sender.clone()),
                    ),
                    (
                        "recipient".to_string(),
                        BoltType::from(lock_row.recipient.clone()),
                    ),
                    (
                        "mint".to_string(),
                        BoltType::from(lock_row.mint_address.clone()),
                    ),
                    ("signature".to_string(), BoltType::from(l.signature.clone())),
                    ("amount".to_string(), BoltType::from(l.amount)),
                    (
                        // unlock_timestamp is widened to i64 for the Bolt protocol.
                        "unlock_ts".to_string(),
                        BoltType::from(l.unlock_timestamp as i64),
                    ),
                    ("timestamp".to_string(), BoltType::from(l.timestamp)),
                ])
            })
        })
        .collect();

    // Everything may have been filtered out — avoid queuing an empty write.
    if params.is_empty() {
        return Ok(());
    }

    // --- THE CRITICAL FIX ---
    let cypher = "
        UNWIND $x as t
        MERGE (s:Wallet {address: t.sender})
        MERGE (k:Token {address: t.mint})
        MERGE (s)-[r:LOCKED_SUPPLY {signature: t.signature}]->(k)
        ON CREATE SET r.amount = t.amount, r.unlock_timestamp = t.unlock_ts, r.recipient = t.recipient, r.timestamp = t.timestamp
        ON MATCH SET r.timestamp = CASE WHEN t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END
    ";
    self.enqueue_write(cypher, params).await
}
|
| 2035 |
+
|
| 2036 |
+
/// Persist BURNED edges (burner wallet -> token), keyed on signature.
/// Links are joined to `BurnRow`s by signature; unmatched links are dropped.
/// NOTE(review): unlike the sibling writers, the cypher uses MATCH (not
/// MERGE) for the wallet/token nodes, so burns referencing nodes that do
/// not exist yet are silently skipped — confirm this is intentional.
async fn write_burned_links(&self, links: &[BurnedLink], burns: &[BurnRow]) -> Result<()> {
    if links.is_empty() {
        return Ok(());
    }
    // Index burn rows by signature for O(1) joins below.
    let burn_map: HashMap<_, _> = burns.iter().map(|b| (b.signature.clone(), b)).collect();

    let params: Vec<_> = links
        .iter()
        .filter_map(|l| {
            burn_map.get(&l.signature).map(|burn_row| {
                HashMap::from([
                    (
                        "wallet".to_string(),
                        BoltType::from(burn_row.source.clone()),
                    ),
                    (
                        "token".to_string(),
                        BoltType::from(burn_row.mint_address.clone()),
                    ),
                    ("signature".to_string(), BoltType::from(l.signature.clone())),
                    ("amount".to_string(), BoltType::from(l.amount)),
                    ("timestamp".to_string(), BoltType::from(l.timestamp)),
                ])
            })
        })
        .collect();

    // Everything may have been filtered out — avoid queuing an empty write.
    if params.is_empty() {
        return Ok(());
    }
    // --- MODIFIED: MERGE on signature ---
    let cypher = "
        UNWIND $x as t
        MATCH (w:Wallet {address: t.wallet}), (k:Token {address: t.token})
        MERGE (w)-[r:BURNED {signature: t.signature}]->(k)
        ON CREATE SET r.amount = t.amount, r.timestamp = t.timestamp
        ON MATCH SET r.timestamp = CASE WHEN t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END
    ";
    self.enqueue_write(cypher, params).await
}
|
| 2076 |
+
|
| 2077 |
+
/// Persist PROVIDED_LIQUIDITY edges (LP wallet -> token), keyed on the
/// transaction signature; pool address and both leg amounts are set on
/// creation, and ON MATCH keeps the earliest timestamp.
async fn write_provided_liquidity_links(&self, links: &[ProvidedLiquidityLink]) -> Result<()> {
    if links.is_empty() {
        return Ok(());
    }
    // One Bolt parameter map per link row.
    let params: Vec<_> = links
        .iter()
        .map(|l| {
            HashMap::from([
                ("wallet".to_string(), BoltType::from(l.wallet.clone())),
                ("token".to_string(), BoltType::from(l.token.clone())),
                ("signature".to_string(), BoltType::from(l.signature.clone())),
                (
                    "pool_address".to_string(),
                    BoltType::from(l.pool_address.clone()),
                ),
                ("amount_base".to_string(), BoltType::from(l.amount_base)),
                ("amount_quote".to_string(), BoltType::from(l.amount_quote)),
                ("timestamp".to_string(), BoltType::from(l.timestamp)),
            ])
        })
        .collect();

    // --- MODIFIED: MERGE on signature ---
    let cypher = "
        UNWIND $x as t
        MERGE (w:Wallet {address: t.wallet})
        MERGE (k:Token {address: t.token})
        MERGE (w)-[r:PROVIDED_LIQUIDITY {signature: t.signature}]->(k)
        ON CREATE SET r.pool_address = t.pool_address, r.amount_base = t.amount_base, r.amount_quote = t.amount_quote, r.timestamp = t.timestamp
        ON MATCH SET r.timestamp = CASE WHEN t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END
    ";
    self.enqueue_write(cypher, params).await
}
|
| 2110 |
+
|
| 2111 |
+
/// Persist TOP_TRADER_OF edges (wallet -> token). One edge per pair (no key
/// properties on the MERGE); the PnL/ATH snapshot taken at creation time is
/// frozen, and ON MATCH only keeps the earliest timestamp.
async fn write_top_trader_of_links(&self, links: &[TopTraderOfLink]) -> Result<()> {
    if links.is_empty() {
        return Ok(());
    }
    // One Bolt parameter map per link row.
    let params: Vec<_> = links
        .iter()
        .map(|l| {
            HashMap::from([
                ("wallet".to_string(), BoltType::from(l.wallet.clone())),
                ("token".to_string(), BoltType::from(l.token.clone())),
                // Add new params
                (
                    "pnl_at_creation".to_string(),
                    BoltType::from(l.pnl_at_creation),
                ),
                (
                    "ath_at_creation".to_string(),
                    BoltType::from(l.ath_usd_at_creation),
                ),
                ("timestamp".to_string(), BoltType::from(l.timestamp)),
            ])
        })
        .collect();

    // --- MODIFIED: The definitive Cypher query ---
    let cypher = "
        UNWIND $x as t
        MERGE (w:Wallet {address: t.wallet})
        MERGE (k:Token {address: t.token})
        MERGE (w)-[r:TOP_TRADER_OF]->(k)
        ON CREATE SET
            r.pnl_at_creation = t.pnl_at_creation,
            r.ath_usd_at_creation = t.ath_at_creation,
            r.timestamp = t.timestamp
        ON MATCH SET r.timestamp = CASE WHEN t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END
    ";
    self.enqueue_write(cypher, params).await
}
|
| 2149 |
+
|
| 2150 |
+
/// Persist WHALE_OF edges (wallet -> token). Mirrors `write_top_trader_of_links`:
/// one edge per pair, holding-percentage/ATH snapshot frozen at creation,
/// ON MATCH keeps the earliest timestamp.
async fn write_whale_of_links(&self, links: &[WhaleOfLink]) -> Result<()> {
    if links.is_empty() {
        return Ok(());
    }
    // One Bolt parameter map per link row.
    let params: Vec<_> = links
        .iter()
        .map(|l| {
            HashMap::from([
                ("wallet".to_string(), BoltType::from(l.wallet.clone())),
                ("token".to_string(), BoltType::from(l.token.clone())),
                // Add new params
                (
                    "pct_at_creation".to_string(),
                    BoltType::from(l.holding_pct_at_creation),
                ),
                (
                    "ath_at_creation".to_string(),
                    BoltType::from(l.ath_usd_at_creation),
                ),
                ("timestamp".to_string(), BoltType::from(l.timestamp)),
            ])
        })
        .collect();

    // --- MODIFIED: The definitive Cypher query ---
    let cypher = "
        UNWIND $x as t
        MERGE (w:Wallet {address: t.wallet})
        MERGE (k:Token {address: t.token})
        MERGE (w)-[r:WHALE_OF]->(k)
        ON CREATE SET
            r.holding_pct_at_creation = t.pct_at_creation,
            r.ath_usd_at_creation = t.ath_at_creation,
            r.timestamp = t.timestamp
        ON MATCH SET r.timestamp = CASE WHEN t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END
    ";
    self.enqueue_write(cypher, params).await
}
|
| 2188 |
+
|
| 2189 |
+
/// Fetch the latest ATH (all-time-high USD price) per token from ClickHouse,
/// chunked (`cfg.ath_fetch_chunk_size`) with bounded retry + linear backoff
/// per chunk. Returns token_address -> ath_price_usd.
/// NOTE(review): on retry exhaustion this calls `std::process::exit(1)` —
/// it kills the whole process instead of returning Err; confirm this
/// fail-fast policy is intentional for this service.
async fn fetch_latest_ath_map_with_retry(
    &self,
    token_addresses: &[String],
) -> Result<HashMap<String, f64>> {
    let mut ath_map = HashMap::new();
    if token_addresses.is_empty() {
        return Ok(ath_map);
    }
    let cfg = &*LINK_GRAPH_CONFIG;

    // Row shape for the ClickHouse result set.
    #[derive(Row, Deserialize, Debug)]
    struct AthInfo {
        token_address: String,
        ath_price_usd: f64,
    }

    // LIMIT 1 BY keeps only the newest row per token (ClickHouse dialect).
    let query = "
        SELECT token_address, ath_price_usd
        FROM token_metrics_latest
        WHERE token_address IN ?
        ORDER BY token_address, updated_at DESC
        LIMIT 1 BY token_address
    ";

    // `.max(1)` guards against a misconfigured chunk size of 0 (chunks(0) panics).
    for chunk in token_addresses.chunks(cfg.ath_fetch_chunk_size.max(1)) {
        let mut attempts = 0;
        loop {
            attempts += 1;
            let result: Result<Vec<AthInfo>> = self
                .db_client
                .query(query)
                .bind(chunk)
                .fetch_all()
                .await
                .map_err(|e| anyhow!("[LinkGraph] ATH fetch failed: {}", e));

            match result {
                Ok(mut chunk_rows) => {
                    for row in chunk_rows.drain(..) {
                        ath_map.insert(row.token_address, row.ath_price_usd);
                    }
                    break;
                }
                Err(e) => {
                    if attempts >= cfg.ch_retry_attempts {
                        eprintln!(
                            "[LinkGraph] 🔴 ATH fetch failed after {} attempts: {}",
                            attempts, e
                        );
                        std::process::exit(1);
                    }
                    // Linear backoff: base * attempt number.
                    let backoff = cfg.ch_retry_backoff_ms * attempts as u64;
                    eprintln!(
                        "[LinkGraph] ⚠️ ATH fetch retry {}/{} after {}ms: {}",
                        attempts, cfg.ch_retry_attempts, backoff, e
                    );
                    sleep(Duration::from_millis(backoff)).await;
                }
            }
        }
    }

    Ok(ath_map)
}
|
| 2253 |
+
|
| 2254 |
+
async fn fetch_pnl(&self, wallet_address: &str, mint_address: &str) -> Result<f64> {
|
| 2255 |
+
let q_str = format!(
|
| 2256 |
+
"SELECT realized_profit_pnl FROM wallet_holdings_latest WHERE wallet_address = '{}' AND mint_address = '{}'",
|
| 2257 |
+
wallet_address, mint_address
|
| 2258 |
+
);
|
| 2259 |
+
// Fetch the pre-calculated f32 value
|
| 2260 |
+
let pnl_f32 = self
|
| 2261 |
+
.with_ch_retry(
|
| 2262 |
+
|| async {
|
| 2263 |
+
self.db_client
|
| 2264 |
+
.query(&q_str)
|
| 2265 |
+
.fetch_one::<f32>()
|
| 2266 |
+
.await
|
| 2267 |
+
.map_err(anyhow::Error::from)
|
| 2268 |
+
},
|
| 2269 |
+
"Fetch PNL",
|
| 2270 |
+
)
|
| 2271 |
+
.await?;
|
| 2272 |
+
// Cast to f64 for the return type
|
| 2273 |
+
Ok(pnl_f32 as f64)
|
| 2274 |
+
}
|
| 2275 |
+
}
|
log.log
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:47b5b03f090da19eba850d54ea4cab1a97ebfdb7712ef4842cfc43804ec411b8
|
| 3 |
+
size 10517118
|
models/HoldersEncoder.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from typing import List, Dict, Any
|
| 4 |
+
|
| 5 |
+
class HolderDistributionEncoder(nn.Module):
    """
    Encodes a list of top holders (wallet embeddings + holding percentages)
    into a single fixed-size embedding representing the holder distribution.
    It uses a Transformer Encoder to capture patterns and relationships.
    """
    def __init__(self,
                 wallet_embedding_dim: int,
                 output_dim: int,
                 nhead: int = 4,
                 num_layers: int = 2,
                 dtype: torch.dtype = torch.float16):
        # NOTE(review): wallet_embedding_dim must be divisible by nhead
        # (nn.MultiheadAttention requirement); the default float16 dtype may
        # not run on CPU for all ops — confirm target device.
        super().__init__()
        self.wallet_embedding_dim = wallet_embedding_dim
        self.output_dim = output_dim
        self.dtype = dtype

        # 1. MLP to project holding percentage to the wallet embedding dimension
        self.pct_proj = nn.Sequential(
            nn.Linear(1, wallet_embedding_dim // 4),
            nn.GELU(),
            nn.Linear(wallet_embedding_dim // 4, wallet_embedding_dim)
        ).to(dtype)

        # 2. Transformer Encoder to process the sequence of holders
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=wallet_embedding_dim,
            nhead=nhead,
            dim_feedforward=wallet_embedding_dim * 4,
            dropout=0.1,
            activation='gelu',
            batch_first=True,
            dtype=dtype
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # 3. A learnable [CLS] token to aggregate the sequence information
        self.cls_token = nn.Parameter(torch.randn(1, 1, wallet_embedding_dim, dtype=dtype))

        # 4. Final projection layer to get the desired output dimension
        self.final_proj = nn.Linear(wallet_embedding_dim, output_dim).to(dtype)

    def forward(self, holder_data: List[Dict[str, Any]]) -> torch.Tensor:
        """
        Args:
            holder_data: A list of dictionaries, where each dict contains:
                'wallet_embedding': A tensor of shape [wallet_embedding_dim]
                'pct': The holding percentage as a float.

        Returns:
            A tensor of shape [1, output_dim] representing the entire distribution.
        """
        if not holder_data:
            # Return a zero tensor if there are no holders
            return torch.zeros(1, self.output_dim, device=self.cls_token.device, dtype=self.dtype)

        # Prepare inputs for the transformer
        # Stacked to [num_holders, wallet_embedding_dim]; pcts become [num_holders, 1].
        wallet_embeds = torch.stack([d['wallet_embedding'] for d in holder_data])
        holder_pcts = torch.tensor([[d['pct']] for d in holder_data], device=wallet_embeds.device, dtype=self.dtype)

        # Project percentages and add to wallet embeddings to create holder features
        pct_embeds = self.pct_proj(holder_pcts)
        holder_inputs = (wallet_embeds + pct_embeds).unsqueeze(0)  # Add batch dimension

        # Prepend the [CLS] token
        batch_size = holder_inputs.size(0)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        transformer_input = torch.cat((cls_tokens, holder_inputs), dim=1)

        # Pass through the transformer
        transformer_output = self.transformer_encoder(transformer_input)

        # Get the embedding of the [CLS] token (the first token)
        cls_embedding = transformer_output[:, 0, :]

        # Project to the final output dimension
        return self.final_proj(cls_embedding)
|
models/SocialEncoders.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from typing import Optional, Dict, Any
|
| 5 |
+
import models.vocabulary as vocab # For event type IDs
|
| 6 |
+
|
| 7 |
+
class XPostEncoder(nn.Module):
    """Encodes a single X post event into one d_model-sized embedding.

    Fuses <AuthorWalletEmbedding>, <PostTextEmbedding> and <MediaEmbedding>
    (each d_model wide) through a two-layer GELU MLP with a LayerNorm
    bottleneck.
    """
    def __init__(self, d_model: int, dtype: torch.dtype):
        super().__init__()
        # Input: Wallet (d_model) + Text (d_model) + Media (d_model)
        fuser = nn.Sequential(
            nn.Linear(d_model * 3, d_model * 2),
            nn.GELU(),
            nn.LayerNorm(d_model * 2),
            nn.Linear(d_model * 2, d_model),
        )
        self.mlp = fuser.to(dtype)

    def forward(self, author_emb: torch.Tensor, text_emb: torch.Tensor, media_emb: torch.Tensor) -> torch.Tensor:
        """Concatenate the three inputs on the last dim and project to d_model."""
        return self.mlp(torch.cat((author_emb, text_emb, media_emb), dim=-1))
|
| 22 |
+
|
| 23 |
+
class XRetweetEncoder(nn.Module):
    """ Encodes: <RetweeterWalletEmbedding>, <OriginalAuthorWalletEmbedding>, <OriginalPostTextEmbedding>, <OriginalPostMediaEmbedding> """
    def __init__(self, d_model: int, dtype: torch.dtype):
        super().__init__()
        # Input: Retweeter (d_model) + Original Author (d_model) + Original Text (d_model) + Original Media (d_model)
        # Four d_model inputs -> 2*d_model bottleneck -> d_model output.
        self.mlp = nn.Sequential(
            nn.Linear(d_model * 4, d_model * 2),
            nn.GELU(),
            nn.LayerNorm(d_model * 2),
            nn.Linear(d_model * 2, d_model)
        ).to(dtype)

    def forward(self,
                retweeter_emb: torch.Tensor,
                orig_author_emb: torch.Tensor,
                orig_text_emb: torch.Tensor,
                orig_media_emb: torch.Tensor) -> torch.Tensor:
        # Concatenate all four embeddings on the feature dim, then fuse.
        combined = torch.cat([retweeter_emb, orig_author_emb, orig_text_emb, orig_media_emb], dim=-1)
        return self.mlp(combined)
|
| 42 |
+
|
| 43 |
+
class XReplyEncoder(nn.Module):
    """ Encodes: <AuthorWalletEmbedding>, <PostTextEmbedding>, <MediaEmbedding>, <MainTweetEmbedding> """
    def __init__(self, d_model: int, dtype: torch.dtype):
        super().__init__()
        # Input: Author (d_model) + Reply Text (d_model) + Reply Media (d_model) + Main Tweet Text (d_model)
        # Four d_model inputs -> 2*d_model bottleneck -> d_model output.
        self.mlp = nn.Sequential(
            nn.Linear(d_model * 4, d_model * 2),
            nn.GELU(),
            nn.LayerNorm(d_model * 2),
            nn.Linear(d_model * 2, d_model)
        ).to(dtype)

    def forward(self,
                author_emb: torch.Tensor,
                text_emb: torch.Tensor,
                media_emb: torch.Tensor,
                main_tweet_emb: torch.Tensor) -> torch.Tensor:
        # Concatenate the reply's own features with the parent tweet's, then fuse.
        combined = torch.cat([author_emb, text_emb, media_emb, main_tweet_emb], dim=-1)
        return self.mlp(combined)
|
| 62 |
+
|
| 63 |
+
class XQuoteTweetEncoder(nn.Module):
    """ Encodes: <QuoterWalletEmbedding>, <QuoterTextEmbedding>, <OriginalAuthorWalletEmbedding>, <OriginalPostTextEmbedding>, <OriginalPostMediaEmbedding> """
    def __init__(self, d_model: int, dtype: torch.dtype):
        super().__init__()
        # Input: Quoter Wallet (d_model) + Quoter Text (d_model) + Orig Author (d_model) + Orig Text (d_model) + Orig Media (d_model)
        # Five d_model inputs -> 2*d_model bottleneck -> d_model output.
        self.mlp = nn.Sequential(
            nn.Linear(d_model * 5, d_model * 2),
            nn.GELU(),
            nn.LayerNorm(d_model * 2),
            nn.Linear(d_model * 2, d_model)
        ).to(dtype)

    def forward(self,
                quoter_wallet_emb: torch.Tensor,
                quoter_text_emb: torch.Tensor,
                orig_author_emb: torch.Tensor,
                orig_text_emb: torch.Tensor,
                orig_media_emb: torch.Tensor) -> torch.Tensor:
        # Concatenate quoter-side and quoted-side features, then fuse.
        combined = torch.cat([quoter_wallet_emb, quoter_text_emb, orig_author_emb, orig_text_emb, orig_media_emb], dim=-1)
        return self.mlp(combined)
|
| 83 |
+
|
| 84 |
+
class PumpReplyEncoder(nn.Module):
    """Fuses a Pump reply event into a single d_model embedding.

    Inputs: <UserWalletEmbedding>, <ReplyTextEmbedding>.
    """
    def __init__(self, d_model: int, dtype: torch.dtype):
        super().__init__()
        # Two d_model-wide inputs: user wallet and reply text.
        layers = [
            nn.Linear(d_model * 2, d_model * 2),
            nn.GELU(),
            nn.LayerNorm(d_model * 2),
            nn.Linear(d_model * 2, d_model),
        ]
        self.mlp = nn.Sequential(*layers).to(dtype)

    def forward(self, user_emb: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
        """Concatenate wallet and text embeddings and project to d_model."""
        fused = torch.cat((user_emb, text_emb), dim=-1)
        return self.mlp(fused)
|
| 99 |
+
|
| 100 |
+
# --- NEW: Encoders for other text-based events ---
|
| 101 |
+
class DexProfileUpdatedEncoder(nn.Module):
    """Fuses a DEX profile-update event from its four text embeddings.

    Inputs: <website_emb>, <twitter_emb>, <telegram_emb>, <description_emb>.
    (Profile flags are projected elsewhere; this module only fuses the
    four text embeddings.)
    """
    def __init__(self, d_model: int, dtype: torch.dtype):
        super().__init__()
        # Four d_model-wide text inputs: website, twitter, telegram, description.
        layers = [
            nn.Linear(d_model * 4, d_model * 2),
            nn.GELU(),
            nn.LayerNorm(d_model * 2),
            nn.Linear(d_model * 2, d_model),
        ]
        self.mlp = nn.Sequential(*layers).to(dtype)

    def forward(self, website_emb: torch.Tensor, twitter_emb: torch.Tensor, telegram_emb: torch.Tensor, description_emb: torch.Tensor) -> torch.Tensor:
        """Concatenate the four text embeddings and project to d_model."""
        fused = torch.cat((website_emb, twitter_emb, telegram_emb, description_emb), dim=-1)
        return self.mlp(fused)
|
| 116 |
+
|
| 117 |
+
class GlobalTrendingEncoder(nn.Module):
    """Projects a trending-hashtag embedding (<hashtag_emb>) through a small MLP.

    Input and output are both d_model wide; no LayerNorm is applied here,
    unlike the multi-input event encoders.
    """
    def __init__(self, d_model: int, dtype: torch.dtype):
        super().__init__()
        # Single d_model-wide input: the hashtag text embedding.
        layers = [
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        ]
        self.mlp = nn.Sequential(*layers).to(dtype)

    def forward(self, hashtag_emb: torch.Tensor) -> torch.Tensor:
        """Return the MLP projection of the hashtag embedding."""
        return self.mlp(hashtag_emb)
|
| 130 |
+
class SocialEncoder(nn.Module):
    """
    A single module to house all social event encoders.
    This simplifies instantiation in the main Oracle model.
    """
    def __init__(self, d_model: int, dtype: torch.dtype):
        """Instantiate one sub-encoder per textual event type.

        Args:
            d_model: Shared embedding width of every input and output.
            dtype: Parameter dtype for all sub-encoders.
        """
        super().__init__()
        self.x_post_encoder = XPostEncoder(d_model, dtype)
        self.x_retweet_encoder = XRetweetEncoder(d_model, dtype)
        self.x_reply_encoder = XReplyEncoder(d_model, dtype)
        self.x_quote_tweet_encoder = XQuoteTweetEncoder(d_model, dtype)
        self.pump_reply_encoder = PumpReplyEncoder(d_model, dtype)
        # --- NEW: Add the other text-based encoders ---
        self.dex_profile_encoder = DexProfileUpdatedEncoder(d_model, dtype)
        self.global_trending_encoder = GlobalTrendingEncoder(d_model, dtype)

        # Store for convenience
        self.d_model = d_model
        self.dtype = dtype

    def forward(self, batch: Dict[str, Any], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        REFACTORED: Processes all text-based events for the entire batch in a vectorized way.
        This replaces the inefficient loops in the main Oracle model.

        Args:
            batch: Collated batch; must provide 'textual_event_indices' (B, L) with
                1-based indices into 'textual_event_data' (0 = padding) and
                'event_type_ids' (B, L).
            gathered_embeds: Must provide 'wallet' (B, L, D), 'original_author'
                (B, L, D) and a 'precomputed' embedding lookup table.

        Returns:
            (B, L, D) tensor; positions whose event type is not handled here
            remain zero.
        """
        device = gathered_embeds['wallet'].device
        B, L, D = gathered_embeds['wallet'].shape
        final_embeds = torch.zeros(B, L, D, device=device, dtype=self.dtype)

        textual_event_indices = batch['textual_event_indices']
        textual_event_data = batch.get('textual_event_data', [])
        precomputed_lookup = gathered_embeds['precomputed']
        # NOTE: removed unused `zero_emb` allocation (it was never read).

        # --- Create masks for each event type ---
        # Unknown event names map to -1, which never matches a real type id.
        event_type_ids = batch['event_type_ids']
        event_masks = {
            'XPost': (event_type_ids == vocab.EVENT_TO_ID.get('XPost', -1)),
            'XReply': (event_type_ids == vocab.EVENT_TO_ID.get('XReply', -1)),
            'XRetweet': (event_type_ids == vocab.EVENT_TO_ID.get('XRetweet', -1)),
            'XQuoteTweet': (event_type_ids == vocab.EVENT_TO_ID.get('XQuoteTweet', -1)),
            'PumpReply': (event_type_ids == vocab.EVENT_TO_ID.get('PumpReply', -1)),
            'DexProfile_Updated': (event_type_ids == vocab.EVENT_TO_ID.get('DexProfile_Updated', -1)),
            'TikTok_Trending_Hashtag': (event_type_ids == vocab.EVENT_TO_ID.get('TikTok_Trending_Hashtag', -1)),
            'XTrending_Hashtag': (event_type_ids == vocab.EVENT_TO_ID.get('XTrending_Hashtag', -1)),
        }

        # --- Gather all necessary pre-computed embeddings in one go ---
        # Flatten indices for efficient lookup, then reshape
        flat_indices = textual_event_indices.flatten()
        # Create a default event structure for padding indices (idx=0)
        default_event = {'event_type': 'PAD'}
        # Use 1-based index from collator, so textual_event_data[idx-1]
        raw_events_flat = [textual_event_data[idx-1] if idx > 0 else default_event for idx in flat_indices.tolist()]

        # Helper to gather embeddings for a specific key.
        # Missing keys fall back to row 0 of the lookup table.
        def gather_precomputed(key: str) -> torch.Tensor:
            indices = torch.tensor([e.get(key, 0) for e in raw_events_flat], device=device, dtype=torch.long)
            return F.embedding(indices, precomputed_lookup).view(B, L, -1)

        # --- Process each event type ---
        # Each branch computes embeddings for every position, then zeroes the
        # irrelevant ones via its mask before accumulating.

        # XPost
        if event_masks['XPost'].any():
            text_emb = gather_precomputed('text_emb_idx')
            media_emb = gather_precomputed('media_emb_idx')
            post_embeds = self.x_post_encoder(gathered_embeds['wallet'], text_emb, media_emb)
            final_embeds += post_embeds * event_masks['XPost'].unsqueeze(-1)

        # XReply
        if event_masks['XReply'].any():
            text_emb = gather_precomputed('text_emb_idx')
            media_emb = gather_precomputed('media_emb_idx')
            main_tweet_emb = gather_precomputed('main_tweet_text_emb_idx')
            reply_embeds = self.x_reply_encoder(gathered_embeds['wallet'], text_emb, media_emb, main_tweet_emb)
            final_embeds += reply_embeds * event_masks['XReply'].unsqueeze(-1)

        # XRetweet
        if event_masks['XRetweet'].any():
            orig_text_emb = gather_precomputed('original_post_text_emb_idx')
            orig_media_emb = gather_precomputed('original_post_media_emb_idx')
            retweet_embeds = self.x_retweet_encoder(gathered_embeds['wallet'], gathered_embeds['original_author'], orig_text_emb, orig_media_emb)
            final_embeds += retweet_embeds * event_masks['XRetweet'].unsqueeze(-1)

        # XQuoteTweet
        if event_masks['XQuoteTweet'].any():
            quoter_text_emb = gather_precomputed('quoter_text_emb_idx')
            orig_text_emb = gather_precomputed('original_post_text_emb_idx')
            orig_media_emb = gather_precomputed('original_post_media_emb_idx')
            quote_embeds = self.x_quote_tweet_encoder(gathered_embeds['wallet'], quoter_text_emb, gathered_embeds['original_author'], orig_text_emb, orig_media_emb)
            final_embeds += quote_embeds * event_masks['XQuoteTweet'].unsqueeze(-1)

        # PumpReply
        if event_masks['PumpReply'].any():
            text_emb = gather_precomputed('reply_text_emb_idx')
            pump_reply_embeds = self.pump_reply_encoder(gathered_embeds['wallet'], text_emb)
            final_embeds += pump_reply_embeds * event_masks['PumpReply'].unsqueeze(-1)

        # DexProfile_Updated
        if event_masks['DexProfile_Updated'].any():
            website_emb = gather_precomputed('website_emb_idx')
            twitter_emb = gather_precomputed('twitter_link_emb_idx')
            telegram_emb = gather_precomputed('telegram_link_emb_idx')
            description_emb = gather_precomputed('description_emb_idx')
            profile_embeds = self.dex_profile_encoder(website_emb, twitter_emb, telegram_emb, description_emb)
            # Note: The flags are handled separately in the main model now, so we just add the text embeds
            final_embeds += profile_embeds * event_masks['DexProfile_Updated'].unsqueeze(-1)

        # Global Trending Hashtags
        # TikTok and X trending events share one encoder, hence the OR-mask.
        trending_mask = event_masks['TikTok_Trending_Hashtag'] | event_masks['XTrending_Hashtag']
        if trending_mask.any():
            hashtag_emb = gather_precomputed('hashtag_name_emb_idx')
            trending_embeds = self.global_trending_encoder(hashtag_emb)
            final_embeds += trending_embeds * trending_mask.unsqueeze(-1)

        return final_embeds
|
models/__init__.py
ADDED
|
File without changes
|
models/graph_updater.py
ADDED
|
@@ -0,0 +1,486 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
# We still use GATv2Conv, just not the to_hetero wrapper
|
| 4 |
+
from torch_geometric.nn import GATv2Conv
|
| 5 |
+
from torch_geometric.data import HeteroData
|
| 6 |
+
from typing import Dict, List, Any
|
| 7 |
+
from collections import defaultdict # For easy aggregation
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
from models.helper_encoders import ContextualTimeEncoder # Type hint for constructor compatibility
|
| 11 |
+
# Import the actual ID_TO_LINK_TYPE mapping
|
| 12 |
+
from models.vocabulary import ID_TO_LINK_TYPE
|
| 13 |
+
# Import other modules needed for the test block
|
| 14 |
+
import models.vocabulary
|
| 15 |
+
from models.wallet_encoder import WalletEncoder
|
| 16 |
+
from models.token_encoder import TokenEncoder
|
| 17 |
+
from models.multi_modal_processor import MultiModalEncoder
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class _TransferLinkEncoder(nn.Module):
|
| 21 |
+
"""Encodes: transfer amount only (timestamps removed)."""
|
| 22 |
+
def __init__(self, out_dim: int, dtype: torch.dtype):
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.proj = nn.Sequential(
|
| 25 |
+
nn.Linear(1, out_dim),
|
| 26 |
+
nn.GELU(),
|
| 27 |
+
nn.Linear(out_dim, out_dim)
|
| 28 |
+
)
|
| 29 |
+
self.dtype = dtype
|
| 30 |
+
|
| 31 |
+
def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
|
| 32 |
+
return torch.sign(x) * torch.log1p(torch.abs(x))
|
| 33 |
+
|
| 34 |
+
def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
|
| 35 |
+
amounts = torch.tensor([[l.get('amount', 0.0)] for l in links], device=device, dtype=self.dtype)
|
| 36 |
+
features = self._safe_signed_log(amounts)
|
| 37 |
+
|
| 38 |
+
return self.proj(features)
|
| 39 |
+
|
| 40 |
+
class _BundleTradeLinkEncoder(nn.Module):
|
| 41 |
+
"""Encodes: total_amount across bundle (timestamps removed)."""
|
| 42 |
+
def __init__(self, out_dim: int, dtype: torch.dtype):
|
| 43 |
+
super().__init__()
|
| 44 |
+
self.proj = nn.Sequential(
|
| 45 |
+
nn.Linear(1, out_dim),
|
| 46 |
+
nn.GELU(),
|
| 47 |
+
nn.Linear(out_dim, out_dim)
|
| 48 |
+
)
|
| 49 |
+
self.dtype = dtype
|
| 50 |
+
|
| 51 |
+
def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
|
| 52 |
+
return torch.sign(x) * torch.log1p(torch.abs(x))
|
| 53 |
+
|
| 54 |
+
def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
|
| 55 |
+
totals = torch.tensor([[l.get('total_amount', 0.0)] for l in links], device=device, dtype=self.dtype)
|
| 56 |
+
total_embeds = self._safe_signed_log(totals)
|
| 57 |
+
|
| 58 |
+
return self.proj(total_embeds)
|
| 59 |
+
|
| 60 |
+
class _CopiedTradeLinkEncoder(nn.Module):
|
| 61 |
+
""" Encodes: 10 numerical features """
|
| 62 |
+
def __init__(self, in_features: int, out_dim: int, dtype: torch.dtype): # Added dtype
|
| 63 |
+
super().__init__()
|
| 64 |
+
self.in_features = in_features
|
| 65 |
+
self.norm = nn.LayerNorm(in_features)
|
| 66 |
+
self.mlp = nn.Sequential(
|
| 67 |
+
nn.Linear(in_features, out_dim * 2), nn.GELU(),
|
| 68 |
+
nn.Linear(out_dim * 2, out_dim)
|
| 69 |
+
)
|
| 70 |
+
self.dtype = dtype # Store dtype
|
| 71 |
+
|
| 72 |
+
def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
|
| 73 |
+
return torch.sign(x) * torch.log1p(torch.abs(x))
|
| 74 |
+
|
| 75 |
+
def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
|
| 76 |
+
num_data = []
|
| 77 |
+
for l in links:
|
| 78 |
+
# --- FIXED: Only use the 6 essential features ---
|
| 79 |
+
num_data.append([
|
| 80 |
+
l.get('time_gap_on_buy_sec', 0), l.get('time_gap_on_sell_sec', 0),
|
| 81 |
+
l.get('leader_pnl', 0), l.get('follower_pnl', 0),
|
| 82 |
+
l.get('follower_buy_total', 0), l.get('follower_sell_total', 0)
|
| 83 |
+
])
|
| 84 |
+
# Create tensor with correct dtype
|
| 85 |
+
x = torch.tensor(num_data, device=device, dtype=self.dtype)
|
| 86 |
+
# Input to norm must match norm's dtype
|
| 87 |
+
x_norm = self.norm(self._safe_signed_log(x))
|
| 88 |
+
return self.mlp(x_norm)
|
| 89 |
+
|
| 90 |
+
class _CoordinatedActivityLinkEncoder(nn.Module):
|
| 91 |
+
""" Encodes: 2 numerical features """
|
| 92 |
+
def __init__(self, in_features: int, out_dim: int, dtype: torch.dtype): # Added dtype
|
| 93 |
+
super().__init__()
|
| 94 |
+
self.in_features = in_features
|
| 95 |
+
self.norm = nn.LayerNorm(in_features)
|
| 96 |
+
self.mlp = nn.Sequential(
|
| 97 |
+
nn.Linear(in_features, out_dim), nn.GELU(),
|
| 98 |
+
nn.Linear(out_dim, out_dim)
|
| 99 |
+
)
|
| 100 |
+
self.dtype = dtype # Store dtype
|
| 101 |
+
|
| 102 |
+
def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
|
| 103 |
+
return torch.sign(x) * torch.log1p(torch.abs(x))
|
| 104 |
+
|
| 105 |
+
def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
|
| 106 |
+
num_data = []
|
| 107 |
+
for l in links:
|
| 108 |
+
num_data.append([
|
| 109 |
+
l.get('time_gap_on_first_sec', 0), l.get('time_gap_on_second_sec', 0)
|
| 110 |
+
])
|
| 111 |
+
# Create tensor with correct dtype
|
| 112 |
+
x = torch.tensor(num_data, device=device, dtype=self.dtype)
|
| 113 |
+
x_norm = self.norm(self._safe_signed_log(x))
|
| 114 |
+
return self.mlp(x_norm)
|
| 115 |
+
|
| 116 |
+
class _MintedLinkEncoder(nn.Module):
|
| 117 |
+
"""Encodes: buy_amount only (timestamps removed)."""
|
| 118 |
+
def __init__(self, out_dim: int, dtype: torch.dtype):
|
| 119 |
+
super().__init__()
|
| 120 |
+
self.proj = nn.Sequential(
|
| 121 |
+
nn.Linear(1, out_dim),
|
| 122 |
+
nn.GELU(),
|
| 123 |
+
nn.Linear(out_dim, out_dim)
|
| 124 |
+
)
|
| 125 |
+
self.dtype = dtype # Store dtype
|
| 126 |
+
|
| 127 |
+
def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
|
| 128 |
+
return torch.sign(x) * torch.log1p(torch.abs(x))
|
| 129 |
+
|
| 130 |
+
def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
|
| 131 |
+
nums = torch.tensor([[l['buy_amount']] for l in links], device=device, dtype=self.dtype)
|
| 132 |
+
|
| 133 |
+
num_embeds = self._safe_signed_log(nums)
|
| 134 |
+
|
| 135 |
+
return self.proj(num_embeds)
|
| 136 |
+
|
| 137 |
+
class _SnipedLinkEncoder(nn.Module):
|
| 138 |
+
""" Encodes: rank, sniped_amount """
|
| 139 |
+
def __init__(self, in_features: int, out_dim: int, dtype: torch.dtype): # Added dtype
|
| 140 |
+
super().__init__()
|
| 141 |
+
self.norm = nn.LayerNorm(in_features)
|
| 142 |
+
self.mlp = nn.Sequential(nn.Linear(in_features, out_dim), nn.GELU(), nn.Linear(out_dim, out_dim))
|
| 143 |
+
self.dtype = dtype # Store dtype
|
| 144 |
+
|
| 145 |
+
def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
|
| 146 |
+
num_data = [[l.get('rank', 0), l.get('sniped_amount', 0)] for l in links]
|
| 147 |
+
# Create tensor with correct dtype
|
| 148 |
+
x = torch.tensor(num_data, device=device, dtype=self.dtype)
|
| 149 |
+
|
| 150 |
+
# --- FIXED: Selectively log-scale features ---
|
| 151 |
+
# Invert rank so 1 is highest, treat as linear. Log-scale sniped_amount.
|
| 152 |
+
x[:, 0] = 1.0 / torch.clamp(x[:, 0], min=1.0) # Invert rank, clamp to avoid division by zero
|
| 153 |
+
x[:, 1] = torch.sign(x[:, 1]) * torch.log1p(torch.abs(x[:, 1])) # Log-scale amount
|
| 154 |
+
|
| 155 |
+
x_norm = self.norm(x)
|
| 156 |
+
return self.mlp(x_norm)
|
| 157 |
+
|
| 158 |
+
class _LockedSupplyLinkEncoder(nn.Module):
|
| 159 |
+
""" Encodes: amount """
|
| 160 |
+
def __init__(self, out_dim: int, dtype: torch.dtype): # Removed time_encoder
|
| 161 |
+
super().__init__()
|
| 162 |
+
self.proj = nn.Sequential(
|
| 163 |
+
nn.Linear(1, out_dim),
|
| 164 |
+
nn.GELU(),
|
| 165 |
+
nn.Linear(out_dim, out_dim)
|
| 166 |
+
)
|
| 167 |
+
self.dtype = dtype # Store dtype
|
| 168 |
+
|
| 169 |
+
def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
|
| 170 |
+
return torch.sign(x) * torch.log1p(torch.abs(x))
|
| 171 |
+
|
| 172 |
+
def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
|
| 173 |
+
nums = torch.tensor([[l['amount']] for l in links], device=device, dtype=self.dtype)
|
| 174 |
+
num_embeds = self._safe_signed_log(nums)
|
| 175 |
+
return self.proj(num_embeds)
|
| 176 |
+
|
| 177 |
+
class _BurnedLinkEncoder(nn.Module):
|
| 178 |
+
"""Encodes: burned amount (timestamps removed)."""
|
| 179 |
+
def __init__(self, out_dim: int, dtype: torch.dtype):
|
| 180 |
+
super().__init__()
|
| 181 |
+
self.proj = nn.Sequential(
|
| 182 |
+
nn.Linear(1, out_dim),
|
| 183 |
+
nn.GELU(),
|
| 184 |
+
nn.Linear(out_dim, out_dim)
|
| 185 |
+
)
|
| 186 |
+
self.dtype = dtype
|
| 187 |
+
|
| 188 |
+
def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
|
| 189 |
+
return torch.sign(x) * torch.log1p(torch.abs(x))
|
| 190 |
+
|
| 191 |
+
def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
|
| 192 |
+
amounts = torch.tensor([[l.get('amount', 0.0)] for l in links], device=device, dtype=self.dtype)
|
| 193 |
+
amount_embeds = self._safe_signed_log(amounts)
|
| 194 |
+
|
| 195 |
+
return self.proj(amount_embeds)
|
| 196 |
+
|
| 197 |
+
class _ProvidedLiquidityLinkEncoder(nn.Module):
|
| 198 |
+
"""Encodes: quote amount (timestamps removed)."""
|
| 199 |
+
def __init__(self, out_dim: int, dtype: torch.dtype):
|
| 200 |
+
super().__init__()
|
| 201 |
+
self.proj = nn.Sequential(
|
| 202 |
+
nn.Linear(1, out_dim),
|
| 203 |
+
nn.GELU(),
|
| 204 |
+
nn.Linear(out_dim, out_dim)
|
| 205 |
+
)
|
| 206 |
+
self.dtype = dtype
|
| 207 |
+
|
| 208 |
+
def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
|
| 209 |
+
return torch.sign(x) * torch.log1p(torch.abs(x))
|
| 210 |
+
|
| 211 |
+
def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
|
| 212 |
+
quote_amounts = torch.tensor([[l.get('amount_quote', 0.0)] for l in links], device=device, dtype=self.dtype)
|
| 213 |
+
quote_embeds = self._safe_signed_log(quote_amounts)
|
| 214 |
+
|
| 215 |
+
return self.proj(quote_embeds)
|
| 216 |
+
|
| 217 |
+
class _WhaleOfLinkEncoder(nn.Module):
|
| 218 |
+
""" Encodes: holding_pct_at_creation """
|
| 219 |
+
def __init__(self, out_dim: int, dtype: torch.dtype):
|
| 220 |
+
super().__init__()
|
| 221 |
+
self.mlp = nn.Sequential(
|
| 222 |
+
nn.Linear(1, out_dim),
|
| 223 |
+
nn.GELU(),
|
| 224 |
+
nn.Linear(out_dim, out_dim)
|
| 225 |
+
)
|
| 226 |
+
self.dtype = dtype
|
| 227 |
+
|
| 228 |
+
def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
|
| 229 |
+
vals = torch.tensor([[l.get('holding_pct_at_creation', 0.0)] for l in links], device=device, dtype=self.dtype)
|
| 230 |
+
vals_log = torch.sign(vals) * torch.log1p(torch.abs(vals))
|
| 231 |
+
return self.mlp(vals_log)
|
| 232 |
+
|
| 233 |
+
class _TopTraderOfLinkEncoder(nn.Module):
|
| 234 |
+
""" Encodes: pnl_at_creation """
|
| 235 |
+
def __init__(self, out_dim: int, dtype: torch.dtype): # Removed in_features
|
| 236 |
+
super().__init__()
|
| 237 |
+
self.mlp = nn.Sequential(nn.Linear(1, out_dim), nn.GELU(), nn.Linear(out_dim, out_dim))
|
| 238 |
+
self.dtype = dtype
|
| 239 |
+
|
| 240 |
+
def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
|
| 241 |
+
num_data = [[l.get('pnl_at_creation', 0)] for l in links]
|
| 242 |
+
x = torch.tensor(num_data, device=device, dtype=self.dtype)
|
| 243 |
+
log_scaled_x = torch.sign(x) * torch.log1p(torch.abs(x))
|
| 244 |
+
return self.mlp(log_scaled_x)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class RelationalGATBlock(nn.Module):
    """
    Shared GATv2Conv that remains relation-aware by concatenating a learned
    relation embedding to every edge attribute before message passing.
    """

    def __init__(
        self,
        node_dim: int,
        edge_attr_dim: int,
        n_heads: int,
        relations: List[str],
        dtype: torch.dtype,
    ):
        super().__init__()
        # Relation name -> row in the relation-embedding table.
        self.rel_to_id = {name: idx for idx, name in enumerate(relations)}
        self.edge_attr_dim = edge_attr_dim
        self.rel_emb = nn.Embedding(len(relations), edge_attr_dim)
        # edge_dim doubles because each edge attr gets its relation vector appended.
        self.conv = GATv2Conv(
            in_channels=node_dim,
            out_channels=node_dim,
            heads=n_heads,
            concat=False,
            dropout=0.1,
            add_self_loops=False,
            edge_dim=edge_attr_dim * 2,  # concat of edge attr + relation emb
        ).to(dtype)

    def forward(
        self,
        x_src: torch.Tensor,
        x_dst: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        rel_type: str,
    ) -> torch.Tensor:
        """Run the shared conv over one relation's edges, tagging each edge
        with that relation's learned embedding."""
        n_edges = edge_index.size(1)

        # Missing attributes become all-zero vectors of the expected width.
        if edge_attr is None:
            edge_attr = torch.zeros(
                n_edges,
                self.edge_attr_dim,
                device=edge_index.device,
                dtype=x_src.dtype,
            )

        if rel_type not in self.rel_to_id:
            raise KeyError(f"Relation '{rel_type}' not registered in RelationalGATBlock.")
        rel_id = self.rel_to_id[rel_type]

        # Broadcast the relation vector across every edge and append it.
        rel_vec = self.rel_emb.weight[rel_id].to(edge_attr.dtype).expand(n_edges, -1)
        augmented = torch.cat([edge_attr, rel_vec], dim=-1)

        return self.conv((x_src, x_dst), edge_index, edge_attr=augmented)
|
| 303 |
+
# =============================================================================
|
| 304 |
+
# 2. The Main GraphUpdater (GNN) - MANUAL HETEROGENEOUS IMPLEMENTATION
|
| 305 |
+
# =============================================================================
|
| 306 |
+
|
| 307 |
+
class GraphUpdater(nn.Module):
|
| 308 |
+
"""
|
| 309 |
+
FIXED: Manually implements Heterogeneous GNN logic using separate GATv2Conv
|
| 310 |
+
layers for each edge type, bypassing the problematic `to_hetero` wrapper.
|
| 311 |
+
"""
|
| 312 |
+
|
| 313 |
+
    def __init__(self,time_encoder: ContextualTimeEncoder, edge_attr_dim: int = 64,
                 n_heads: int = 4, num_layers: int = 2, node_dim: int = 2048, dtype: torch.dtype = torch.float16):
        """Build per-link edge encoders and the stacked relational GAT layers.

        Args:
            time_encoder: Accepted for constructor compatibility; not used in
                this initializer (timestamps were removed from link features).
            edge_attr_dim: Width of every encoded edge attribute.
            n_heads: Attention heads per RelationalGATBlock.
            num_layers: Number of stacked conv layers.
            node_dim: Width of node embeddings (GAT in/out channels).
            dtype: Parameter dtype all submodules are cast to.
        """
        super().__init__()
        self.node_dim = node_dim
        self.edge_attr_dim = edge_attr_dim
        self.num_layers = num_layers
        self.dtype = dtype

        # --- Instantiate all 11 Link Feature Encoders --- (Unchanged)
        # Keys must match the link names used by vocabulary / edge_data_dict.
        self.edge_encoders = nn.ModuleDict({
            'TransferLink': _TransferLinkEncoder(edge_attr_dim, dtype=dtype),
            'TransferLinkToken': _TransferLinkEncoder(edge_attr_dim, dtype=dtype),
            'BundleTradeLink': _BundleTradeLinkEncoder(edge_attr_dim, dtype=dtype),
            'CopiedTradeLink': _CopiedTradeLinkEncoder(6, edge_attr_dim, dtype=dtype),  # FIXED: in_features=6
            'CoordinatedActivityLink': _CoordinatedActivityLinkEncoder(2, edge_attr_dim, dtype=dtype),
            'MintedLink': _MintedLinkEncoder(edge_attr_dim, dtype=dtype),
            'SnipedLink': _SnipedLinkEncoder(2, edge_attr_dim, dtype=dtype),
            'LockedSupplyLink': _LockedSupplyLinkEncoder(edge_attr_dim, dtype=dtype),  # FIXED: No time_encoder
            'BurnedLink': _BurnedLinkEncoder(edge_attr_dim, dtype=dtype),
            'ProvidedLiquidityLink': _ProvidedLiquidityLinkEncoder(edge_attr_dim, dtype=dtype),
            'WhaleOfLink': _WhaleOfLinkEncoder(edge_attr_dim, dtype=dtype),  # FIXED: No in_features
            'TopTraderOfLink': _TopTraderOfLinkEncoder(edge_attr_dim, dtype=dtype),  # FIXED: No in_features
        }).to(dtype)

        # --- Define shared relational GNN blocks per meta edge direction ---
        # One RelationalGATBlock per (src_type, dst_type) pair per layer;
        # relations within a pair share weights via the relation embedding.
        self.edge_groups = self._build_edge_groups()
        self.conv_layers = nn.ModuleList()
        for _ in range(num_layers):
            conv_dict = nn.ModuleDict()
            for (src_type, dst_type), relations in self.edge_groups.items():
                conv_dict[f"{src_type}__{dst_type}"] = RelationalGATBlock(
                    node_dim=node_dim,
                    edge_attr_dim=edge_attr_dim,
                    n_heads=n_heads,
                    relations=relations,
                    dtype=dtype,
                )
            self.conv_layers.append(conv_dict)

        self.norm = nn.LayerNorm(node_dim)
        self.to(dtype)  # Move norm layer and ModuleList container
|
| 354 |
+
|
| 355 |
+
def _build_edge_groups(self) -> Dict[tuple, List[str]]:
|
| 356 |
+
"""Group relations by (src_type, dst_type) so conv weights can be shared."""
|
| 357 |
+
groups: Dict[tuple, List[str]] = defaultdict(list)
|
| 358 |
+
|
| 359 |
+
wallet_wallet_links = ['TransferLink', 'BundleTradeLink', 'CopiedTradeLink', 'CoordinatedActivityLink']
|
| 360 |
+
wallet_token_links = [
|
| 361 |
+
'TransferLinkToken', 'MintedLink', 'SnipedLink', 'LockedSupplyLink',
|
| 362 |
+
'BurnedLink', 'ProvidedLiquidityLink', 'WhaleOfLink', 'TopTraderOfLink'
|
| 363 |
+
]
|
| 364 |
+
|
| 365 |
+
for link in wallet_wallet_links:
|
| 366 |
+
groups[('wallet', 'wallet')].append(link)
|
| 367 |
+
groups[('wallet', 'wallet')].append(f"rev_{link}")
|
| 368 |
+
|
| 369 |
+
for link in wallet_token_links:
|
| 370 |
+
groups[('wallet', 'token')].append(link)
|
| 371 |
+
groups[('token', 'wallet')].append(f"rev_{link}")
|
| 372 |
+
|
| 373 |
+
return groups
|
| 374 |
+
|
| 375 |
+
def forward(
    self,
    x_dict: Dict[str, torch.Tensor],
    edge_data_dict: Dict[str, Dict[str, Any]]
) -> Dict[str, torch.Tensor]:
    """Run one full relational-GNN pass over a heterogeneous node batch.

    Args:
        x_dict: node-type -> [num_nodes, node_dim] embeddings. Must contain
            a 'wallet' entry — its device anchors all edge tensors.
        edge_data_dict: link-name -> {'edge_index': LongTensor [2, E] or None,
            'links': list of raw link records for the per-type edge encoders}.
            NOTE(review): the exact record schema of 'links' is consumed by
            self.edge_encoders — confirm against the data collator.

    Returns:
        node-type -> updated embeddings with the same shapes as the input.
        Node types that received no messages are passed through unchanged.
    """
    device = x_dict['wallet'].device

    # --- 1. Encode Edge Attributes ---
    # Build (src, rel, dst)-keyed edge index/attribute dicts; for every
    # forward link a reversed edge is added that re-uses the same attributes.
    edge_index_dict = {}
    edge_attr_dict = {}

    for link_name, data in edge_data_dict.items():
        edge_index = data.get('edge_index')
        links = data.get('links', [])

        # Check if edge_index is valid before proceeding
        if edge_index is None or edge_index.numel() == 0 or not links:
            continue  # Skip if no links or index of this type

        edge_index = edge_index.to(device)

        # Use vocabulary to get the triplet (src, rel, dst)
        # Make sure ID_TO_LINK_TYPE is correctly populated
        if link_name not in vocabulary.LINK_NAME_TO_TRIPLET:
            print(f"Warning: Link name '{link_name}' not found in vocabulary.LINK_NAME_TO_TRIPLET. Skipping.")
            continue
        src_type, rel_type, dst_type = vocabulary.LINK_NAME_TO_TRIPLET[link_name]

        # Check if encoder exists for this link name
        if link_name not in self.edge_encoders:
            print(f"Warning: No edge encoder found for link type '{link_name}'. Skipping edge attributes.")
            edge_attr = None  # Or handle differently if attributes are essential
        else:
            edge_attr = self.edge_encoders[link_name](links, device).to(self.dtype)

        # Forward link
        fwd_key = (src_type, rel_type, dst_type)
        edge_index_dict[fwd_key] = edge_index
        if edge_attr is not None:
            edge_attr_dict[fwd_key] = edge_attr

        # Reverse link
        # Ensure edge_index has the right shape for flipping
        if edge_index.shape[0] == 2:
            rev_edge_index = edge_index[[1, 0]]
            rev_rel_type = f'rev_{rel_type}'
            rev_key = (dst_type, rev_rel_type, src_type)
            edge_index_dict[rev_key] = rev_edge_index
            if edge_attr is not None:
                # Re-use same attributes for reverse edge
                edge_attr_dict[rev_key] = edge_attr
        else:
            print(f"Warning: Edge index for {link_name} has unexpected shape {edge_index.shape}. Cannot create reverse edge.")

    # --- 2. Run GNN Layers MANUALLY ---
    # Each layer: compute messages per edge type, sum them per destination
    # node type, then apply a residual connection + LayerNorm.
    x_out = x_dict
    for i in range(self.num_layers):
        # Initialize aggregation tensors for each node type that exists in the input
        msg_aggregates = {
            node_type: torch.zeros_like(x_node)
            for node_type, x_node in x_out.items()
        }

        # --- Message Passing ---
        for edge_type_tuple in edge_index_dict.keys():  # Iterate through edges PRESENT in the batch
            src_type, rel_type, dst_type = edge_type_tuple
            edge_index = edge_index_dict[edge_type_tuple]
            edge_attr = edge_attr_dict.get(edge_type_tuple)  # Use .get() in case attr is None

            x_src = x_out.get(src_type)
            x_dst = x_out.get(dst_type)
            if x_src is None or x_dst is None:
                print(f"Warning: Missing node embeddings for types {src_type}->{dst_type}. Skipping.")
                continue

            block_key = f"{src_type}__{dst_type}"
            if block_key not in self.conv_layers[i]:
                print(f"Warning: Relational block for {block_key} not found in layer {i}. Skipping.")
                continue
            block = self.conv_layers[i][block_key]

            try:
                messages = block(x_src, x_dst, edge_index, edge_attr, rel_type)
            except KeyError:
                # NOTE(review): this also swallows KeyErrors raised inside the
                # block itself, not just a missing relation — confirm intended.
                print(f"Warning: Relation '{rel_type}' missing in block {block_key}. Skipping.")
                continue

            # *** THE FIX ***
            # Use scatter_add_ to accumulate messages for the destination node type.
            # This correctly handles multiple edge types pointing to the same node type.
            # `messages` is assumed to be [E, node_dim], one row per edge — TODO confirm.
            msg_aggregates[dst_type].scatter_add_(0, edge_index[1].unsqueeze(1).expand_as(messages), messages)

        # --- Aggregation & Update (Residual Connection) ---
        x_next = {}
        for node_type, x_original in x_out.items():
            # Check if messages were computed and stored correctly
            if node_type in msg_aggregates and msg_aggregates[node_type].shape[0] > 0:
                aggregated_msgs = msg_aggregates[node_type]
                # Ensure dimensions match before adding
                if x_original.shape == aggregated_msgs.shape:
                    x_next[node_type] = self.norm(x_original + aggregated_msgs)
                else:
                    print(f"Warning: Shape mismatch for node type {node_type} during update. Original: {x_original.shape}, Aggregated: {aggregated_msgs.shape}. Skipping residual connection.")
                    x_next[node_type] = x_original  # Fallback
            else:
                x_next[node_type] = x_original

        x_out = x_next

    return x_out
|
models/helper_encoders.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import math
|
| 4 |
+
import datetime
|
| 5 |
+
from typing import Dict, List, Any, Optional
|
| 6 |
+
|
| 7 |
+
class ContextualTimeEncoder(nn.Module):
    """Encode Unix timestamps into dense embeddings with mixed-precision support.

    The embedding concatenates three views of each timestamp:
      * a sinusoidal encoding of the raw timestamp (long-range position),
      * a cyclical (sin/cos) encoding of the UTC hour-of-day,
      * a cyclical (sin/cos) encoding of the UTC day-of-week,
    and linearly projects the concatenation to ``output_dim``.
    """

    def __init__(self, output_dim: int = 128, dtype: torch.dtype = torch.float32):
        """
        Args:
            output_dim (int): The final dimension of the output embedding.
            dtype (torch.dtype): The data type for the model's parameters (e.g., torch.float16).

        Raises:
            ValueError: If ``output_dim`` is below 12 — each of the three
                sub-encodings needs at least one sin/cos pair.
        """
        super().__init__()
        self.dtype = dtype
        if output_dim < 12:
            raise ValueError(f"output_dim must be at least 12, but got {output_dim}")

        # Split the budget roughly 2:1:1 between timestamp / hour / day parts.
        ts_dim = output_dim // 2
        hour_dim = output_dim // 4
        day_dim = output_dim - ts_dim - hour_dim

        # Round each sub-dimension up to an even number so the sin and cos
        # halves are the same size.
        self.ts_dim = ts_dim + (ts_dim % 2)
        self.hour_dim = hour_dim + (hour_dim % 2)
        self.day_dim = day_dim + (day_dim % 2)

        total_internal_dim = self.ts_dim + self.hour_dim + self.day_dim

        self.projection = nn.Linear(total_internal_dim, output_dim)

        # Cast the entire module to the specified dtype
        self.to(dtype)

    def _sinusoidal_encode(self, values: torch.Tensor, d_model: int) -> torch.Tensor:
        """Transformer-style sinusoidal encoding: [N] -> [N, d_model]."""
        device = values.device
        half_dim = d_model // 2

        # Calculations for sinusoidal encoding are more stable in float32
        div_term = torch.exp(torch.arange(0, half_dim, device=device).float() * -(math.log(10000.0) / half_dim))
        args = values.float().unsqueeze(-1) * div_term

        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

    def _cyclical_encode(self, values: torch.Tensor, d_model: int, max_val: float) -> torch.Tensor:
        """Sin/cos encoding of a periodic quantity in [0, max_val): [N] -> [N, d_model]."""
        norm_values = (values.float() / max_val) * 2 * math.pi

        half_dim = d_model // 2
        # Same angle repeated across the half-dimension; sin fills one half,
        # cos the other.
        angles = norm_values.unsqueeze(-1).repeat(1, half_dim)

        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, timestamps: torch.Tensor) -> torch.Tensor:
        """Encode ``timestamps`` (any shape, seconds since epoch, UTC) to
        embeddings of shape ``timestamps.shape + (output_dim,)``."""
        device = self.projection.weight.device

        # 1. Store original shape (e.g., [B, L]) and flatten.
        # FIX: move to the module's device up front so all three sub-encodings
        # live on the same device (previously ts_encoding stayed on the
        # input's device while hours/days were created on the module's,
        # breaking torch.cat when they differed).
        original_shape = timestamps.shape
        timestamps_flat = timestamps.flatten().to(device=device, dtype=torch.float32)  # Shape [N_total]

        # 2. Sinusoidal encode (vectorized)
        ts_encoding = self._sinusoidal_encode(timestamps_flat, self.ts_dim)

        # 3. FIX: vectorized UTC hour-of-day / day-of-week, replacing a
        # per-element datetime.fromtimestamp Python loop (which forced a
        # host sync and O(N) Python work). Floor division + remainder give
        # the same results as datetime for pre-epoch (negative) timestamps
        # too, since torch.remainder takes the sign of the divisor.
        hours = torch.remainder(torch.floor(timestamps_flat / 3600.0), 24.0)
        # 1970-01-01 was a Thursday, i.e. weekday() == 3.
        days = torch.remainder(torch.floor(timestamps_flat / 86400.0) + 3.0, 7.0)

        # 4. Cyclical encode (vectorized)
        hour_encoding = self._cyclical_encode(hours, self.hour_dim, max_val=24.0)
        day_encoding = self._cyclical_encode(days, self.day_dim, max_val=7.0)

        # 5. Combine and project
        combined_encoding = torch.cat([ts_encoding, hour_encoding, day_encoding], dim=1)
        projected = self.projection(combined_encoding.to(self.dtype))  # Shape [N_total, output_dim]

        # 6. Reshape to match original (e.g., [B, L, output_dim])
        output_shape = original_shape + (self.projection.out_features,)
        return projected.view(output_shape)
|
| 82 |
+
|
| 83 |
+
def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Masked mean over the sequence axis (dim 1).

    Positions where ``attention_mask`` is zero are excluded from the average;
    the denominator is clamped so fully-masked rows never divide by zero.
    """
    weights = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    token_sum = (last_hidden_state * weights).sum(1)
    token_count = weights.sum(1).clamp(min=1e-9)
    return token_sum / token_count
|
models/model.py
ADDED
|
@@ -0,0 +1,1009 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# model.py (REFACTORED AND FIXED)
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from transformers import AutoConfig, AutoModel
|
| 7 |
+
from typing import List, Dict, Any, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
# --- NOW, we import all the encoders ---
|
| 10 |
+
from models.helper_encoders import ContextualTimeEncoder
|
| 11 |
+
from models.token_encoder import TokenEncoder
|
| 12 |
+
from models.wallet_encoder import WalletEncoder
|
| 13 |
+
from models.graph_updater import GraphUpdater
|
| 14 |
+
from models.ohlc_embedder import OHLCEmbedder
|
| 15 |
+
from models.HoldersEncoder import HolderDistributionEncoder # NEW
|
| 16 |
+
from models.SocialEncoders import SocialEncoder # NEW
|
| 17 |
+
import models.vocabulary as vocab # For vocab sizes
|
| 18 |
+
|
| 19 |
+
class Oracle(nn.Module):
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
"""
|
| 23 |
+
def __init__(self,
             token_encoder: TokenEncoder,
             wallet_encoder: WalletEncoder,
             graph_updater: GraphUpdater,
             ohlc_embedder: OHLCEmbedder,  # NEW
             time_encoder: ContextualTimeEncoder,
             num_event_types: int,
             multi_modal_dim: int,
             event_pad_id: int,
             event_type_to_id: Dict[str, int],
             model_config_name: str = "Qwen/Qwen3-0.6B",
             quantiles: List[float] = [0.1, 0.5, 0.9],
             horizons_seconds: List[int] = [30, 60, 120, 240, 420],
             dtype: torch.dtype = torch.bfloat16):
    """Assemble the full Oracle pipeline.

    Wires together the pre-built sub-encoders, instantiates a Qwen3 backbone
    from configuration only (random weights — trained from scratch), and
    creates one normalization/projection/embedding layer per event feature
    family, all sized to the backbone's hidden dimension.

    Args:
        token_encoder / wallet_encoder / graph_updater / ohlc_embedder /
        time_encoder: pre-constructed sub-modules, stored as-is.
        num_event_types: vocabulary size for the event-type embedding.
        multi_modal_dim: width of pre-computed text/image embeddings.
        event_pad_id: padding index for the event-type embedding.
        event_type_to_id: mapping from event-type name to embedding index.
        model_config_name: HF config used for the backbone architecture.
        quantiles / horizons_seconds: grid for the quantile prediction head
            (num_outputs = len(quantiles) * len(horizons_seconds)).
            NOTE(review): mutable list defaults — only stored, never mutated
            here, but tuple defaults would be safer.
        dtype: parameter dtype for the whole model.
    """

    super().__init__()

    # Pick the compute device once; sub-tensors created later follow it.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    self.device = torch.device(device)
    self.dtype = dtype
    self.multi_modal_dim = multi_modal_dim

    self.quantiles = quantiles
    self.horizons_seconds = horizons_seconds
    self.num_outputs = len(quantiles) * len(horizons_seconds)
    # NOTE(review): self.dtype was already set above; duplicate assignment is harmless.
    self.dtype = dtype

    # --- 2. Load Qwen3 Configuration (architecture only; training from scratch) ---
    model_config = AutoConfig.from_pretrained(model_config_name, trust_remote_code=True)
    self.d_model = model_config.hidden_size
    self.model = AutoModel.from_config(model_config, trust_remote_code=True)
    self.model.to(self.device, dtype=self.dtype)

    # Quantile prediction head (maps pooled hidden state -> flattened horizon/quantile grid)
    self.quantile_head = nn.Sequential(
        nn.Linear(self.d_model, self.d_model),
        nn.GELU(),
        nn.Linear(self.d_model, self.num_outputs)
    )

    self.event_type_to_id = event_type_to_id

    # --- 1. Store All Encoders ---
    # Define Token Roles before using them
    self.token_roles = {'main': 0, 'quote': 1, 'trending': 2}  # Add trending for future use
    self.main_token_role_id = self.token_roles['main']
    self.quote_token_role_id = self.token_roles['quote']
    self.trending_token_role_id = self.token_roles['trending']

    self.token_encoder = token_encoder
    self.wallet_encoder = wallet_encoder
    self.graph_updater = graph_updater
    self.ohlc_embedder = ohlc_embedder
    self.time_encoder = time_encoder  # Store time_encoder

    self.social_encoder = SocialEncoder(d_model=self.d_model, dtype=self.dtype)  # Now self.d_model is defined

    # --- 4. Define Sequence Feature Embeddings ---
    self.event_type_embedding = nn.Embedding(num_event_types, self.d_model, padding_idx=event_pad_id)

    # --- NEW: Token Role Embeddings ---
    self.token_role_embedding = nn.Embedding(len(self.token_roles), self.d_model)

    # --- 5. Define Entity Padding (Learnable) ---
    # One learnable pad vector per entity modality, sized to that modality's
    # raw (pre-projection) embedding width.
    self.pad_wallet_emb = nn.Parameter(torch.zeros(1, self.wallet_encoder.d_model))
    self.pad_token_emb = nn.Parameter(torch.zeros(1, self.token_encoder.output_dim))
    self.pad_ohlc_emb = nn.Parameter(torch.zeros(1, self.ohlc_embedder.output_dim))
    self.pad_precomputed_emb = nn.Parameter(torch.zeros(1, self.multi_modal_dim))  # NEW: For text/images

    # --- NEW: Instantiate HolderDistributionEncoder internally ---
    self.holder_dist_encoder = HolderDistributionEncoder(
        wallet_embedding_dim=self.wallet_encoder.d_model,
        output_dim=self.d_model,
        dtype=self.dtype  # Pass the correct dtype
    )
    self.pad_holder_snapshot_emb = nn.Parameter(torch.zeros(1, self.d_model))  # Output of holder_dist_encoder is d_model

    # --- 6. Define Projection MLPs ---
    # Each raw encoder output is projected into the backbone's d_model space.
    self.time_proj = nn.Linear(self.time_encoder.projection.out_features, self.d_model)
    self.rel_ts_proj = nn.Linear(1, self.d_model)
    self.rel_ts_norm = nn.LayerNorm(1)
    self.wallet_proj = nn.Linear(self.wallet_encoder.d_model, self.d_model)
    self.token_proj = nn.Linear(self.token_encoder.output_dim, self.d_model)
    self.ohlc_proj = nn.Linear(self.ohlc_embedder.output_dim, self.d_model)
    # self.holder_snapshot_proj is no longer needed as HolderDistributionEncoder outputs directly to d_model

    # --- NEW: Layers for Transfer Numerical Features ---
    self.transfer_num_norm = nn.LayerNorm(4)  # Normalize the 4 features
    self.transfer_num_proj = nn.Linear(4, self.d_model)  # Project to d_model

    # --- NEW: Layers for Trade Numerical Features ---
    # --- FIXED: Size reduced from 10 to 8 ---
    self.trade_num_norm = nn.LayerNorm(8)
    self.trade_num_proj = nn.Linear(8, self.d_model)
    # --- NEW: Embedding for categorical dex_platform_id ---
    self.dex_platform_embedding = nn.Embedding(vocab.NUM_DEX_PLATFORMS, self.d_model)
    # --- NEW: Embedding for categorical trade_direction ---
    self.trade_direction_embedding = nn.Embedding(2, self.d_model)  # 0 for buy, 1 for sell
    # --- FIXED: Embedding for categorical mev_protection is now binary ---
    self.mev_protection_embedding = nn.Embedding(2, self.d_model)  # 0 for false, 1 for true
    # --- NEW: Embedding for categorical is_bundle ---
    self.is_bundle_embedding = nn.Embedding(2, self.d_model)  # 0 for false, 1 for true

    # --- NEW: Separate Layers for Deployer Trade Numerical Features ---
    # --- FIXED: Size reduced from 10 to 8 ---
    self.deployer_trade_num_norm = nn.LayerNorm(8)
    self.deployer_trade_num_proj = nn.Linear(8, self.d_model)

    # --- NEW: Separate Layers for Smart Wallet Trade Numerical Features ---
    # --- FIXED: Size reduced from 10 to 8 ---
    self.smart_wallet_trade_num_norm = nn.LayerNorm(8)
    self.smart_wallet_trade_num_proj = nn.Linear(8, self.d_model)

    # --- NEW: Layers for PoolCreated Numerical Features ---
    # NOTE(review): an earlier note said "reduced from 5 to 4", but the layer
    # width here is 2 — confirm the intended feature count with the collator.
    self.pool_created_num_norm = nn.LayerNorm(2)
    self.pool_created_num_proj = nn.Linear(2, self.d_model)

    # --- NEW: Layers for LiquidityChange Numerical Features ---
    # NOTE(review): an earlier note said "reduced from 3 to 2", but the layer
    # width here is 1 — confirm the intended feature count with the collator.
    self.liquidity_change_num_norm = nn.LayerNorm(1)
    self.liquidity_change_num_proj = nn.Linear(1, self.d_model)
    # --- NEW: Embedding for categorical change_type_id ---
    # --- FIXED: Hardcoded the number of types (add/remove) as per user instruction ---
    self.liquidity_change_type_embedding = nn.Embedding(2, self.d_model)

    # --- NEW: Layers for FeeCollected Numerical Features ---
    self.fee_collected_num_norm = nn.LayerNorm(1)  # sol_amount only
    self.fee_collected_num_proj = nn.Linear(1, self.d_model)

    # --- NEW: Layers for TokenBurn Numerical Features ---
    self.token_burn_num_norm = nn.LayerNorm(2)  # amount_pct, amount_tokens
    self.token_burn_num_proj = nn.Linear(2, self.d_model)

    # --- NEW: Layers for SupplyLock Numerical Features ---
    self.supply_lock_num_norm = nn.LayerNorm(2)  # amount_pct, lock_duration
    self.supply_lock_num_proj = nn.Linear(2, self.d_model)

    # --- NEW: Layers for OnChain_Snapshot Numerical Features ---
    self.onchain_snapshot_num_norm = nn.LayerNorm(14)
    self.onchain_snapshot_num_proj = nn.Linear(14, self.d_model)

    # --- NEW: Layers for TrendingToken Numerical Features ---
    # --- FIXED: Size reduced from 3 to 1 (rank only) ---
    self.trending_token_num_norm = nn.LayerNorm(1)
    self.trending_token_num_proj = nn.Linear(1, self.d_model)
    # --- NEW: Embeddings for categorical IDs ---
    self.trending_list_source_embedding = nn.Embedding(vocab.NUM_TRENDING_LIST_SOURCES, self.d_model)
    self.trending_timeframe_embedding = nn.Embedding(vocab.NUM_TRENDING_LIST_TIMEFRAMES, self.d_model)

    # --- NEW: Layers for BoostedToken Numerical Features ---
    self.boosted_token_num_norm = nn.LayerNorm(2)  # total_boost_amount, rank
    self.boosted_token_num_proj = nn.Linear(2, self.d_model)

    # --- NEW: Layers for DexBoost_Paid Numerical Features ---
    self.dexboost_paid_num_norm = nn.LayerNorm(2)  # amount, total_amount_on_token
    self.dexboost_paid_num_proj = nn.Linear(2, self.d_model)

    # --- NEW: Layers for DexProfile_Updated Features ---
    self.dexprofile_updated_flags_proj = nn.Linear(4, self.d_model)  # Project the 4 boolean flags

    # --- NEW: Projection for all pre-computed embeddings (text/images) ---
    self.precomputed_proj = nn.Linear(self.multi_modal_dim, self.d_model)

    # --- NEW: Embedding for Protocol IDs (used in Migrated event) ---
    self.protocol_embedding = nn.Embedding(vocab.NUM_PROTOCOLS, self.d_model)

    # --- NEW: Embeddings for TrackerEncoder Events ---
    # Note: NUM_CALL_CHANNELS might need to be large and managed as vocab grows.
    self.alpha_group_embedding = nn.Embedding(vocab.NUM_ALPHA_GROUPS, self.d_model)
    self.call_channel_embedding = nn.Embedding(vocab.NUM_CALL_CHANNELS, self.d_model)
    self.cex_listing_embedding = nn.Embedding(vocab.NUM_EXCHANGES, self.d_model)

    # --- NEW: Layers for GlobalTrendingEncoder Events ---
    self.global_trending_num_norm = nn.LayerNorm(1)  # rank
    self.global_trending_num_proj = nn.Linear(1, self.d_model)

    # --- NEW: Layers for ChainSnapshot Events ---
    self.chainsnapshot_num_norm = nn.LayerNorm(2)  # native_token_price_usd, gas_fee
    self.chainsnapshot_num_proj = nn.Linear(2, self.d_model)

    # --- NEW: Layers for Lighthouse_Snapshot Events ---
    # --- FIXED: Size reduced from 7 to 5 ---
    self.lighthousesnapshot_num_norm = nn.LayerNorm(5)
    self.lighthousesnapshot_num_proj = nn.Linear(5, self.d_model)
    # --- NEW: Embedding for timeframe ID ---
    self.lighthouse_timeframe_embedding = nn.Embedding(vocab.NUM_LIGHTHOUSE_TIMEFRAMES, self.d_model)

    # --- NEW: Embeddings for Special Context Tokens ---
    self.special_context_tokens = {'Middle': 0, 'RECENT': 1}
    self.special_context_embedding = nn.Embedding(len(self.special_context_tokens), self.d_model)

    # --- 7. Prediction Head --- (Unchanged)
    # self.prediction_head = nn.Linear(self.d_model, self.num_outputs)

    # --- 8. Move all new modules to correct dtype ---
    self.to(dtype)
    print("Oracle model (full pipeline) initialized.")
|
| 227 |
+
|
| 228 |
+
def _normalize_and_project(self,
                           features: torch.Tensor,
                           norm_layer: nn.LayerNorm,
                           proj_layer: nn.Linear,
                           log_indices: Optional[List[int]] = None) -> torch.Tensor:
    """Optionally signed-log1p-scale selected feature channels, then apply
    LayerNorm and a linear projection to the whole feature block.
    """
    # Work on a copy so the caller's tensor is never modified in place.
    scaled = features.clone()

    if log_indices:
        # Drop any index that falls outside the feature dimension.
        in_range = [idx for idx in log_indices if idx < scaled.shape[-1]]
        if in_range:
            # Signed log1p keeps the sign while compressing magnitude;
            # computed in float32 for numerical stability.
            selected = scaled[:, :, in_range].to(torch.float32)
            compressed = selected.sign() * selected.abs().log1p()
            scaled[:, :, in_range] = compressed.to(scaled.dtype)

    # Cast to each layer's parameter dtype before applying it (mixed precision).
    normed = norm_layer(scaled.to(norm_layer.weight.dtype))
    return proj_layer(normed.to(proj_layer.weight.dtype))
|
| 253 |
+
|
| 254 |
+
def _run_snapshot_encoders(self,
                           batch: Dict[str, Any],
                           final_wallet_embeddings_raw: torch.Tensor,
                           wallet_addr_to_batch_idx: Dict[str, int]) -> Dict[str, torch.Tensor]:
    """
    Runs snapshot-style encoders that process raw data into embeddings.
    This is now truly end-to-end.

    Args:
        batch: collated batch; reads 'holder_snapshot_raw_data', a list (one
            entry per HolderSnapshot event) of lists of holder records with
            keys 'wallet' and 'holding_pct'.
        final_wallet_embeddings_raw: graph-updated wallet embeddings; row
            (idx - 1) corresponds to batch wallet index idx (index 0 is the
            padding slot, hence the 1-based adjustment below).
        wallet_addr_to_batch_idx: wallet address -> 1-based batch index
            (0 / missing means padding, i.e. the wallet is not in the batch).

    Returns:
        {'holder_snapshot': [num_snapshots, d_model] tensor} — empty
        (0, d_model) when the batch carries no holder snapshots.
    """
    device = self.device
    all_holder_snapshot_embeds = []

    # Iterate through each HolderSnapshot event's raw data
    for raw_holder_list in batch['holder_snapshot_raw_data']:
        processed_holder_data = []
        for holder in raw_holder_list:
            wallet_addr = holder['wallet']
            # Get the graph-updated wallet embedding using its index
            wallet_idx = wallet_addr_to_batch_idx.get(wallet_addr, 0)  # 0 is padding
            if wallet_idx > 0:  # If it's a valid wallet
                wallet_embedding = final_wallet_embeddings_raw[wallet_idx - 1]  # Adjust for 1-based indexing
                processed_holder_data.append({
                    'wallet_embedding': wallet_embedding,
                    'pct': holder['holding_pct']
                })
        # Pass the processed data to the HolderDistributionEncoder.
        # NOTE(review): processed_holder_data may be empty if no holder wallet
        # is in the batch — confirm HolderDistributionEncoder handles [].
        all_holder_snapshot_embeds.append(self.holder_dist_encoder(processed_holder_data))

    return {"holder_snapshot": torch.cat(all_holder_snapshot_embeds, dim=0) if all_holder_snapshot_embeds else torch.empty(0, self.d_model, device=device, dtype=self.dtype)}
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def _run_dynamic_encoders(self, batch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    """
    Runs all dynamic encoders and returns a dictionary of raw, unprojected embeddings.

    Pipeline:
      1a. TokenEncoder over pre-computed name/symbol/image embeddings gathered
          from ``batch['embedding_pool']`` (missing indices, i.e. -1, resolve to
          a zero pad row).
      1b. WalletEncoder, which may reference the freshly encoded token
          embeddings via an address -> embedding lookup.
      1c. OHLC embedder over the raw price tensors.
      1d. Graph updater over a padded wallet/token node dict (index 0 of each
          node tensor is the padding row, stripped again before returning).

    Returns:
        Dict with keys 'wallet', 'token' (graph-updated, padding stripped) and
        'ohlc' — all raw (pre-projection) embeddings.
    """
    device = self.device
    # Pre-computed embedding indices and the shared per-batch embedding pool.
    token_encoder_inputs = batch['token_encoder_inputs']
    wallet_encoder_inputs = batch['wallet_encoder_inputs']
    embedding_pool = batch['embedding_pool']

    ohlc_price_tensors = batch['ohlc_price_tensors'].to(device, self.dtype)
    ohlc_interval_ids = batch['ohlc_interval_ids'].to(device)
    graph_updater_links = batch['graph_updater_links']

    # 1a. Encode Tokens (skipped entirely when the batch holds no tokens).
    if token_encoder_inputs['name_embed_indices'].numel() > 0:
        # Remove keys that are not part of the TokenEncoder's signature.
        encoder_args = token_encoder_inputs.copy()
        encoder_args.pop('_addresses_for_lookup', None)  # This key is for the WalletEncoder
        encoder_args.pop('name_embed_indices', None)
        encoder_args.pop('symbol_embed_indices', None)
        encoder_args.pop('image_embed_indices', None)

        # SAFETY: create a padded view of the embedding pool so that missing
        # indices (-1) map to a zero pad row instead of crashing F.embedding.
        if embedding_pool.numel() > 0:
            pad_row = torch.zeros(1, embedding_pool.size(1), device=device, dtype=embedding_pool.dtype)
            pool_padded = torch.cat([pad_row, embedding_pool], dim=0)

            def pad_and_lookup(idx_tensor: torch.Tensor) -> torch.Tensor:
                # Map valid indices >= 0 to +1 (shift), invalid (< 0) to 0 (pad).
                shifted = torch.where(idx_tensor >= 0, idx_tensor + 1, torch.zeros_like(idx_tensor))
                return F.embedding(shifted, pool_padded)

            name_embeds = pad_and_lookup(token_encoder_inputs['name_embed_indices'])
            symbol_embeds = pad_and_lookup(token_encoder_inputs['symbol_embed_indices'])
            image_embeds = pad_and_lookup(token_encoder_inputs['image_embed_indices'])
        else:
            # Empty pool: provide zeros with correct shapes.
            n = token_encoder_inputs['name_embed_indices'].shape[0]
            d = self.multi_modal_dim
            zeros = torch.zeros(n, d, device=device, dtype=self.dtype)
            name_embeds = zeros
            symbol_embeds = zeros
            image_embeds = zeros

        batch_token_embeddings_unupd = self.token_encoder(
            name_embeds=name_embeds,
            symbol_embeds=symbol_embeds,
            image_embeds=image_embeds,
            # Pass all other keys like protocol_ids, is_vanity_flags, etc.
            **encoder_args
        )
    else:
        batch_token_embeddings_unupd = torch.empty(0, self.token_encoder.output_dim, device=device, dtype=self.dtype)

    # 1b. Encode Wallets
    if wallet_encoder_inputs['profile_rows']:
        # Address -> pre-graph token embedding, for wallets referencing tokens.
        temp_token_lookup = {
            addr: batch_token_embeddings_unupd[i]
            for i, addr in enumerate(batch['token_encoder_inputs']['_addresses_for_lookup'])  # Use helper key
        }
        initial_wallet_embeddings = self.wallet_encoder(
            **wallet_encoder_inputs,
            token_vibe_lookup=temp_token_lookup,
            embedding_pool=embedding_pool
        )
    else:
        initial_wallet_embeddings = torch.empty(0, self.wallet_encoder.d_model, device=device, dtype=self.dtype)

    # 1c. Encode OHLC
    if ohlc_price_tensors.shape[0] > 0:
        batch_ohlc_embeddings_raw = self.ohlc_embedder(ohlc_price_tensors, ohlc_interval_ids)
    else:
        batch_ohlc_embeddings_raw = torch.empty(0, self.ohlc_embedder.output_dim, device=device, dtype=self.dtype)

    # 1d. Run Graph Updater. Each node tensor gets its padding row at index 0.
    pad_wallet_raw = self.pad_wallet_emb.to(self.dtype)
    pad_token_raw = self.pad_token_emb.to(self.dtype)
    padded_wallet_tensor = torch.cat([pad_wallet_raw, initial_wallet_embeddings], dim=0)
    padded_token_tensor = torch.cat([pad_token_raw, batch_token_embeddings_unupd], dim=0)

    x_dict_initial = {}
    # > 1 because row 0 is always the padding row; only include node types
    # that actually have entities in this batch.
    if padded_wallet_tensor.shape[0] > 1: x_dict_initial['wallet'] = padded_wallet_tensor
    if padded_token_tensor.shape[0] > 1: x_dict_initial['token'] = padded_token_tensor

    if x_dict_initial and graph_updater_links:
        final_entity_embeddings_dict = self.graph_updater(x_dict_initial, graph_updater_links)
        final_padded_wallet_embs = final_entity_embeddings_dict.get('wallet', padded_wallet_tensor)
        final_padded_token_embs = final_entity_embeddings_dict.get('token', padded_token_tensor)
    else:
        final_padded_wallet_embs = padded_wallet_tensor
        final_padded_token_embs = padded_token_tensor

    # Strip padding before returning
    final_wallet_embeddings_raw = final_padded_wallet_embs[1:]
    final_token_embeddings_raw = final_padded_token_embs[1:]

    return {
        "wallet": final_wallet_embeddings_raw,
        "token": final_token_embeddings_raw,
        "ohlc": batch_ohlc_embeddings_raw
    }
def _project_and_gather_embeddings(self, raw_embeds: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Projects raw embeddings to d_model and gathers them into sequence-aligned tensors.

    Each lookup table gets its projected padding embedding prepended at index 0,
    so the index tensors in ``batch`` (1-based, 0 = "no entity") can be used
    directly with ``F.embedding``.

    Returns:
        Per-event-slot tensors for wallets, tokens (main/quote/trending/boosted,
        with role embeddings added where applicable), OHLC and holder snapshots,
        plus the full projected pre-computed-embedding lookup table under
        'precomputed' (a table, not gathered per slot).
    """
    # Project raw embeddings to d_model
    final_wallet_proj = self.wallet_proj(raw_embeds['wallet'])
    final_token_proj = self.token_proj(raw_embeds['token'])
    final_ohlc_proj = self.ohlc_proj(raw_embeds['ohlc'])

    # Project padding embeddings to d_model
    pad_wallet = self.wallet_proj(self.pad_wallet_emb.to(self.dtype))
    pad_token = self.token_proj(self.pad_token_emb.to(self.dtype))
    pad_ohlc = self.ohlc_proj(self.pad_ohlc_emb.to(self.dtype))
    pad_holder_snapshot = self.pad_holder_snapshot_emb.to(self.dtype) # Already d_model

    # Project pre-computed embeddings and create lookup
    final_precomputed_proj = self.precomputed_proj(batch['embedding_pool'])
    pad_precomputed = self.precomputed_proj(self.pad_precomputed_emb.to(self.dtype))
    final_precomputed_lookup = torch.cat([pad_precomputed, final_precomputed_proj], dim=0)

    # Create final lookup tables with padding at index 0
    final_wallet_lookup = torch.cat([pad_wallet, final_wallet_proj], dim=0)
    final_token_lookup = torch.cat([pad_token, final_token_proj], dim=0)
    final_ohlc_lookup = torch.cat([pad_ohlc, final_ohlc_proj], dim=0)

    # Role embeddings distinguish how a token participates in an event.
    # FIX: the main-token role embedding was looked up here but never applied
    # (the main token is the baseline role), so the dead lookup was removed.
    quote_role_emb = self.token_role_embedding(torch.tensor(self.quote_token_role_id, device=self.device))
    trending_role_emb = self.token_role_embedding(torch.tensor(self.trending_token_role_id, device=self.device))

    # Gather base embeddings
    gathered_main_token_embs = F.embedding(batch['token_indices'], final_token_lookup)
    gathered_quote_token_embs = F.embedding(batch['quote_token_indices'], final_token_lookup)
    gathered_trending_token_embs = F.embedding(batch['trending_token_indices'], final_token_lookup)
    gathered_boosted_token_embs = F.embedding(batch['boosted_token_indices'], final_token_lookup)

    # Holder snapshots are already d_model; just prepend the pad row.
    final_holder_snapshot_lookup = torch.cat([pad_holder_snapshot, raw_embeds['holder_snapshot']], dim=0)

    # Gather embeddings for each event in the sequence
    return {
        "wallet": F.embedding(batch['wallet_indices'], final_wallet_lookup),
        "token": gathered_main_token_embs, # This is the baseline, no role needed
        "ohlc": F.embedding(batch['ohlc_indices'], final_ohlc_lookup),
        "original_author": F.embedding(batch['original_author_indices'], final_wallet_lookup),
        "dest_wallet": F.embedding(batch['dest_wallet_indices'], final_wallet_lookup), # Also gather dest wallet
        "quote_token": gathered_quote_token_embs + quote_role_emb,
        "trending_token": gathered_trending_token_embs + trending_role_emb,
        "boosted_token": gathered_boosted_token_embs + trending_role_emb, # Same role as trending
        "holder_snapshot": F.embedding(batch['holder_snapshot_indices'], final_holder_snapshot_lookup),
        "precomputed": final_precomputed_lookup # Pass the full lookup table
    }
def _get_transfer_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Builds the extra per-slot embedding for Transfer / LargeTransfer events:
    destination-wallet embedding plus projected numerical features, zeroed at
    every other sequence position.
    """
    type_ids = batch['event_type_ids']

    # token_amount (idx 0) and priority_fee (idx 3) are log-scaled; the two
    # pct features (idx 1, 2) stay linear.
    numeric_proj = self._normalize_and_project(
        batch['transfer_numerical_features'], self.transfer_num_norm, self.transfer_num_proj, log_indices=[0, 3]
    )

    # Positions whose event type is Transfer or LargeTransfer.
    relevant_ids = torch.tensor(
        [self.event_type_to_id.get(name, -1) for name in ('Transfer', 'LargeTransfer')],
        device=self.device,
    )
    is_transfer = torch.isin(type_ids, relevant_ids).unsqueeze(-1)

    # Combine destination wallet and numerical features, then apply the mask.
    return (gathered_embeds['dest_wallet'] + numeric_proj) * is_transfer
def _get_trade_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Builds the extra per-slot embedding for Trade / LargeTrade events:
    projected numerical features plus the dex / direction / MEV-protection /
    bundle categorical embeddings, zeroed at all other positions.
    """
    type_ids = batch['event_type_ids']

    # sol_amount (idx 0), priority_fee (idx 1) and total_usd (idx 7) are
    # log-scaled; pcts, slippage, price impact and success flags stay linear.
    numeric_proj = self._normalize_and_project(
        batch['trade_numerical_features'], self.trade_num_norm, self.trade_num_proj, log_indices=[0, 1, 7]
    )

    # Categorical components shared by generic and large trades.
    dex_emb = self.dex_platform_embedding(batch['trade_dex_ids'])
    dir_emb = self.trade_direction_embedding(batch['trade_direction_ids'])
    mev_emb = self.mev_protection_embedding(batch['trade_mev_protection_ids'])
    bundle_emb = self.is_bundle_embedding(batch['trade_is_bundle_ids'])

    # One layer handles both generic and large trades.
    relevant_ids = torch.tensor(
        [self.event_type_to_id.get(name, -1) for name in ('Trade', 'LargeTrade')],
        device=self.device,
    )
    is_trade = torch.isin(type_ids, relevant_ids).unsqueeze(-1)

    return (numeric_proj + dex_emb + dir_emb + mev_emb + bundle_emb) * is_trade
def _get_deployer_trade_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Calculates the special embeddings for Deployer_Trade events using its own
    normalization/projection layers; re-uses the shared categorical embedding
    layers (dex, direction, MEV protection, bundle). Zeroed at all other
    positions. (Removed an unused `device` local.)
    """
    deployer_trade_numerical_features = batch['deployer_trade_numerical_features']
    trade_dex_ids = batch['trade_dex_ids']  # Re-use the same ID tensor as Trade events
    trade_direction_ids = batch['trade_direction_ids']
    trade_mev_protection_ids = batch['trade_mev_protection_ids']
    trade_is_bundle_ids = batch['trade_is_bundle_ids']
    event_type_ids = batch['event_type_ids']

    # Log scale: sol_amount (idx 0), priority_fee (idx 1), total_usd (idx 7)
    projected_deployer_trade_features = self._normalize_and_project(
        deployer_trade_numerical_features, self.deployer_trade_num_norm, self.deployer_trade_num_proj, log_indices=[0, 1, 7]
    )

    dex_id_embeds = self.dex_platform_embedding(trade_dex_ids)
    direction_embeds = self.trade_direction_embedding(trade_direction_ids)
    mev_embeds = self.mev_protection_embedding(trade_mev_protection_ids)
    bundle_embeds = self.is_bundle_embedding(trade_is_bundle_ids)

    # Mask out every position that is not a Deployer_Trade event.
    deployer_trade_mask = (event_type_ids == self.event_type_to_id.get('Deployer_Trade', -1)).unsqueeze(-1)
    return (projected_deployer_trade_features + dex_id_embeds + direction_embeds + mev_embeds + bundle_embeds) * deployer_trade_mask
def _get_smart_wallet_trade_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Calculates the special embeddings for SmartWallet_Trade events using its
    own normalization/projection layers; re-uses the shared categorical
    embedding layers. Zeroed at all other positions. (Removed an unused
    `device` local.)
    """
    smart_wallet_trade_numerical_features = batch['smart_wallet_trade_numerical_features']
    trade_dex_ids = batch['trade_dex_ids']  # Re-use the same ID tensor as Trade events
    trade_direction_ids = batch['trade_direction_ids']
    trade_mev_protection_ids = batch['trade_mev_protection_ids']
    trade_is_bundle_ids = batch['trade_is_bundle_ids']
    event_type_ids = batch['event_type_ids']

    # Log scale: sol_amount (idx 0), priority_fee (idx 1), total_usd (idx 7)
    projected_features = self._normalize_and_project(
        smart_wallet_trade_numerical_features, self.smart_wallet_trade_num_norm, self.smart_wallet_trade_num_proj, log_indices=[0, 1, 7]
    )

    dex_id_embeds = self.dex_platform_embedding(trade_dex_ids)
    direction_embeds = self.trade_direction_embedding(trade_direction_ids)
    mev_embeds = self.mev_protection_embedding(trade_mev_protection_ids)
    bundle_embeds = self.is_bundle_embedding(trade_is_bundle_ids)

    # Mask out every position that is not a SmartWallet_Trade event.
    mask = (event_type_ids == self.event_type_to_id.get('SmartWallet_Trade', -1)).unsqueeze(-1)
    return (projected_features + dex_id_embeds + direction_embeds + mev_embeds + bundle_embeds) * mask
def _get_pool_created_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Calculates the special embeddings for PoolCreated events: quote-token
    embedding + projected numerical features + protocol embedding, zeroed at
    all other positions. (Removed an unused `device` local.)
    """
    pool_created_numerical_features = batch['pool_created_numerical_features']
    pool_created_protocol_ids = batch['pool_created_protocol_ids']
    event_type_ids = batch['event_type_ids']

    # Log scale: base_amount (idx 0), quote_amount (idx 1)
    # Linear scale: pcts (idx 2, 3)
    projected_features = self._normalize_and_project(
        pool_created_numerical_features, self.pool_created_num_norm, self.pool_created_num_proj, log_indices=[0, 1]
    )
    # Embedding for the categorical protocol_id
    protocol_id_embeds = self.protocol_embedding(pool_created_protocol_ids)

    # Create mask for the event
    mask = (event_type_ids == self.event_type_to_id.get('PoolCreated', -1)).unsqueeze(-1)

    # Combine Quote Token embedding with projected numericals
    return (gathered_embeds['quote_token'] + projected_features + protocol_id_embeds) * mask
def _get_liquidity_change_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Calculates the special embeddings for LiquidityChange events: quote-token
    embedding + projected numerical features + change-type embedding, zeroed at
    all other positions. (Removed an unused `device` local.)
    """
    liquidity_change_numerical_features = batch['liquidity_change_numerical_features']
    liquidity_change_type_ids = batch['liquidity_change_type_ids']
    event_type_ids = batch['event_type_ids']

    # Log scale: quote_amount (idx 0)
    projected_features = self._normalize_and_project(
        liquidity_change_numerical_features, self.liquidity_change_num_norm, self.liquidity_change_num_proj, log_indices=[0]
    )
    # Embedding for the categorical change_type_id (add/remove etc. —
    # semantics defined by the collator, not visible here)
    change_type_embeds = self.liquidity_change_type_embedding(liquidity_change_type_ids)

    # Create mask for the event
    mask = (event_type_ids == self.event_type_to_id.get('LiquidityChange', -1)).unsqueeze(-1)

    # Combine Quote Token embedding with projected numericals
    return (gathered_embeds['quote_token'] + projected_features + change_type_embeds) * mask
def _get_fee_collected_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Calculates the special embeddings for FeeCollected events: a single
    log-scaled amount feature, projected and zeroed at all other positions.
    (Removed an unused `device` local.)
    """
    fee_collected_numerical_features = batch['fee_collected_numerical_features']
    event_type_ids = batch['event_type_ids']

    # Single amount feature, log-scaled.
    projected_features = self._normalize_and_project(
        fee_collected_numerical_features, self.fee_collected_num_norm, self.fee_collected_num_proj, log_indices=[0]
    )

    # Create mask for the event
    mask = (event_type_ids == self.event_type_to_id.get('FeeCollected', -1)).unsqueeze(-1)

    return projected_features * mask
def _get_token_burn_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Calculates the special embeddings for TokenBurn events, zeroed at all
    other positions. (Removed an unused `device` local.)
    """
    token_burn_numerical_features = batch['token_burn_numerical_features']
    event_type_ids = batch['event_type_ids']

    # Log scale: amount_tokens_burned (idx 1)
    # Linear scale: amount_pct_of_total_supply (idx 0)
    projected_features = self._normalize_and_project(
        token_burn_numerical_features, self.token_burn_num_norm, self.token_burn_num_proj, log_indices=[1]
    )
    # Create mask for the event
    mask = (event_type_ids == self.event_type_to_id.get('TokenBurn', -1)).unsqueeze(-1)

    return projected_features * mask
def _get_supply_lock_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Calculates the special embeddings for SupplyLock events, zeroed at all
    other positions. (Removed an unused `device` local.)
    """
    supply_lock_numerical_features = batch['supply_lock_numerical_features']
    event_type_ids = batch['event_type_ids']

    # Log scale: lock_duration (idx 1)
    # Linear scale: amount_pct_of_total_supply (idx 0)
    projected_features = self._normalize_and_project(
        supply_lock_numerical_features, self.supply_lock_num_norm, self.supply_lock_num_proj, log_indices=[1]
    )
    # Create mask for the event
    mask = (event_type_ids == self.event_type_to_id.get('SupplyLock', -1)).unsqueeze(-1)

    return projected_features * mask
def _get_onchain_snapshot_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Calculates the special embeddings for OnChain_Snapshot events, zeroed at
    all other positions. (Removed an unused `device` local.)
    """
    onchain_snapshot_numerical_features = batch['onchain_snapshot_numerical_features']
    event_type_ids = batch['event_type_ids']

    # Log scale: counts, market_cap, liquidity, volume, fees (almost all)
    # Linear scale: growth_rate, holder_pcts (indices 3, 4, 5, 6, 7)
    projected_features = self._normalize_and_project(
        onchain_snapshot_numerical_features, self.onchain_snapshot_num_norm, self.onchain_snapshot_num_proj, log_indices=[0, 1, 2, 8, 9, 10, 11, 12, 13]
    )
    # Create mask for the event
    mask = (event_type_ids == self.event_type_to_id.get('OnChain_Snapshot', -1)).unsqueeze(-1)

    return projected_features * mask
def _get_trending_token_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Calculates the special embeddings for TrendingToken events: trending-token
    embedding + projected numericals + source/timeframe embeddings, zeroed at
    all other positions. (Removed an unused `device` local.)
    """
    trending_token_numerical_features = batch['trending_token_numerical_features']
    trending_token_source_ids = batch['trending_token_source_ids']
    trending_token_timeframe_ids = batch['trending_token_timeframe_ids']
    event_type_ids = batch['event_type_ids']

    # Rank is already inverted to 0-1 upstream, so all features stay linear.
    projected_features = self._normalize_and_project(
        trending_token_numerical_features, self.trending_token_num_norm, self.trending_token_num_proj, log_indices=None
    )

    # Embeddings for the categorical IDs
    source_embeds = self.trending_list_source_embedding(trending_token_source_ids)
    timeframe_embeds = self.trending_timeframe_embedding(trending_token_timeframe_ids)

    # Create mask for the event
    mask = (event_type_ids == self.event_type_to_id.get('TrendingToken', -1)).unsqueeze(-1)

    # Combine Trending Token embedding with its projected numericals
    return (gathered_embeds['trending_token'] + projected_features + source_embeds + timeframe_embeds) * mask
def _get_boosted_token_specific_embeddings(self, batch: Dict[str, torch.Tensor], gathered_embeds: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Calculates the special embeddings for BoostedToken events: boosted-token
    embedding + projected numericals, zeroed at all other positions. (Removed
    an unused `device` local.)
    """
    boosted_token_numerical_features = batch['boosted_token_numerical_features']
    event_type_ids = batch['event_type_ids']

    # Log scale: total_boost_amount (idx 0)
    # Linear scale: inverted rank (idx 1)
    projected_features = self._normalize_and_project(
        boosted_token_numerical_features, self.boosted_token_num_norm, self.boosted_token_num_proj, log_indices=[0]
    )
    # Create mask for the event
    mask = (event_type_ids == self.event_type_to_id.get('BoostedToken', -1)).unsqueeze(-1)

    # Combine Boosted Token embedding with its projected numericals
    return (gathered_embeds['boosted_token'] + projected_features) * mask
def _get_dexboost_paid_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Calculates the special embeddings for DexBoost_Paid events, zeroed at all
    other positions. (Removed an unused `device` local.)
    """
    dexboost_paid_numerical_features = batch['dexboost_paid_numerical_features']
    event_type_ids = batch['event_type_ids']

    # Both features are amounts, so both are log-scaled.
    projected_features = self._normalize_and_project(
        dexboost_paid_numerical_features, self.dexboost_paid_num_norm, self.dexboost_paid_num_proj, log_indices=[0, 1]
    )
    # Create mask for the event
    mask = (event_type_ids == self.event_type_to_id.get('DexBoost_Paid', -1)).unsqueeze(-1)

    return projected_features * mask
def _get_alphagroup_call_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
| 742 |
+
"""
|
| 743 |
+
Handles AlphaGroup_Call events by looking up the group_id embedding.
|
| 744 |
+
"""
|
| 745 |
+
device = self.device
|
| 746 |
+
group_ids = batch['alpha_group_ids']
|
| 747 |
+
event_type_ids = batch['event_type_ids']
|
| 748 |
+
|
| 749 |
+
# Look up the embedding for the group ID
|
| 750 |
+
group_embeds = self.alpha_group_embedding(group_ids)
|
| 751 |
+
|
| 752 |
+
# Create mask for the event
|
| 753 |
+
mask = (event_type_ids == self.event_type_to_id.get('AlphaGroup_Call', -1)).unsqueeze(-1)
|
| 754 |
+
return group_embeds * mask
|
| 755 |
+
|
| 756 |
+
def _get_channel_call_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
| 757 |
+
"""
|
| 758 |
+
Handles Channel_Call events by looking up the channel_id embedding.
|
| 759 |
+
"""
|
| 760 |
+
device = self.device
|
| 761 |
+
channel_ids = batch['channel_ids']
|
| 762 |
+
event_type_ids = batch['event_type_ids']
|
| 763 |
+
|
| 764 |
+
channel_embeds = self.call_channel_embedding(channel_ids)
|
| 765 |
+
mask = (event_type_ids == self.event_type_to_id.get('Channel_Call', -1)).unsqueeze(-1)
|
| 766 |
+
return channel_embeds * mask
|
| 767 |
+
|
| 768 |
+
def _get_cexlisting_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
| 769 |
+
"""
|
| 770 |
+
Handles CexListing events by looking up the exchange_id embedding.
|
| 771 |
+
"""
|
| 772 |
+
device = self.device
|
| 773 |
+
exchange_ids = batch['exchange_ids']
|
| 774 |
+
event_type_ids = batch['event_type_ids']
|
| 775 |
+
|
| 776 |
+
exchange_embeds = self.cex_listing_embedding(exchange_ids)
|
| 777 |
+
mask = (event_type_ids == self.event_type_to_id.get('CexListing', -1)).unsqueeze(-1)
|
| 778 |
+
return exchange_embeds * mask
|
| 779 |
+
|
| 780 |
+
def _get_chainsnapshot_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Handles ChainSnapshot events, zeroed at all other positions. (Removed an
    unused `device` local.)
    """
    numerical_features = batch['chainsnapshot_numerical_features']
    event_type_ids = batch['event_type_ids']

    # Both features are amounts/prices, so both are log-scaled.
    projected_features = self._normalize_and_project(
        numerical_features, self.chainsnapshot_num_norm, self.chainsnapshot_num_proj, log_indices=[0, 1]
    )
    mask = (event_type_ids == self.event_type_to_id.get('ChainSnapshot', -1)).unsqueeze(-1)
    return projected_features * mask
def _get_lighthousesnapshot_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Handles Lighthouse_Snapshot events: projected numericals plus protocol and
    timeframe embeddings, zeroed at all other positions. (Removed an unused
    `device` local.)
    """
    numerical_features = batch['lighthousesnapshot_numerical_features']
    protocol_ids = batch['lighthousesnapshot_protocol_ids']
    timeframe_ids = batch['lighthousesnapshot_timeframe_ids']
    event_type_ids = batch['event_type_ids']

    # All features are counts/volumes, so all are log-scaled.
    projected_features = self._normalize_and_project(
        numerical_features, self.lighthousesnapshot_num_norm, self.lighthousesnapshot_num_proj, log_indices=[0, 1, 2, 3, 4]
    )
    # Categorical embeddings; re-uses the main protocol embedding layer.
    protocol_embeds = self.protocol_embedding(protocol_ids)
    timeframe_embeds = self.lighthouse_timeframe_embedding(timeframe_ids)

    mask = (event_type_ids == self.event_type_to_id.get('Lighthouse_Snapshot', -1)).unsqueeze(-1)
    return (projected_features + protocol_embeds + timeframe_embeds) * mask
def _get_migrated_specific_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
| 818 |
+
"""
|
| 819 |
+
Handles Migrated events by looking up the protocol_id embedding.
|
| 820 |
+
"""
|
| 821 |
+
device = self.device
|
| 822 |
+
protocol_ids = batch['migrated_protocol_ids']
|
| 823 |
+
event_type_ids = batch['event_type_ids']
|
| 824 |
+
|
| 825 |
+
# Look up the embedding for the protocol ID
|
| 826 |
+
protocol_embeds = self.protocol_embedding(protocol_ids)
|
| 827 |
+
|
| 828 |
+
# Create mask for the event
|
| 829 |
+
mask = (event_type_ids == self.event_type_to_id.get('Migrated', -1)).unsqueeze(-1)
|
| 830 |
+
return protocol_embeds * mask
|
| 831 |
+
|
| 832 |
+
def _get_special_context_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
| 833 |
+
"""
|
| 834 |
+
Handles special context tokens like 'Middle' and 'RECENT' by adding their unique learnable embeddings.
|
| 835 |
+
"""
|
| 836 |
+
device = self.device
|
| 837 |
+
event_type_ids = batch['event_type_ids']
|
| 838 |
+
B, L = event_type_ids.shape
|
| 839 |
+
|
| 840 |
+
middle_id = self.event_type_to_id.get('Middle', -1)
|
| 841 |
+
recent_id = self.event_type_to_id.get('RECENT', -1)
|
| 842 |
+
|
| 843 |
+
middle_mask = (event_type_ids == middle_id)
|
| 844 |
+
recent_mask = (event_type_ids == recent_id)
|
| 845 |
+
|
| 846 |
+
middle_emb = self.special_context_embedding(torch.tensor(self.special_context_tokens['Middle'], device=device))
|
| 847 |
+
recent_emb = self.special_context_embedding(torch.tensor(self.special_context_tokens['RECENT'], device=device))
|
| 848 |
+
|
| 849 |
+
# Add the embeddings at the correct locations
|
| 850 |
+
return middle_mask.unsqueeze(-1) * middle_emb + recent_mask.unsqueeze(-1) * recent_emb
|
| 851 |
+
|
| 852 |
+
def _pool_hidden_states(self,
|
| 853 |
+
hidden_states: torch.Tensor,
|
| 854 |
+
attention_mask: torch.Tensor) -> torch.Tensor:
|
| 855 |
+
"""
|
| 856 |
+
Pools variable-length hidden states into a single embedding per sequence by
|
| 857 |
+
selecting the last non-masked token for each batch element.
|
| 858 |
+
"""
|
| 859 |
+
if hidden_states.size(0) == 0:
|
| 860 |
+
return torch.empty(0, self.d_model, device=hidden_states.device, dtype=hidden_states.dtype)
|
| 861 |
+
|
| 862 |
+
seq_lengths = attention_mask.long().sum(dim=1)
|
| 863 |
+
last_indices = torch.clamp(seq_lengths - 1, min=0)
|
| 864 |
+
batch_indices = torch.arange(hidden_states.size(0), device=hidden_states.device)
|
| 865 |
+
return hidden_states[batch_indices, last_indices]
|
| 866 |
+
|
| 867 |
+
def forward(self, batch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    """
    Run the full event-sequence model over one collated batch.

    Pipeline: (1) dynamic entity encoders, (2) snapshot encoders,
    (3) projection/gathering of entity embeddings into sequence positions,
    (4) summation of all per-event-type additive embeddings into
    `inputs_embeds`, then a forward pass through the backbone transformer,
    last-token pooling, and the quantile head.

    Args:
        batch: Collated batch dict; reads 'event_type_ids',
            'timestamps_float', 'relative_ts', 'attention_mask',
            'wallet_addr_to_batch_idx', 'dexprofile_updated_flags' plus the
            per-event-type keys consumed by the helper encoders.

    Returns:
        Dict with 'quantile_logits', 'pooled_states', 'hidden_states' and
        'attention_mask'. All tensors are empty when the batch is empty.
    """
    device = self.device

    # Unpack core sequence tensors.
    # FIXED: removed the dead `relative_ts = batch['relative_ts'].to(device, self.dtype)`
    # conversion — the value was never read; the fp32 path below re-reads the batch.
    event_type_ids = batch['event_type_ids'].to(device)
    timestamps_float = batch['timestamps_float'].to(device)
    attention_mask = batch['attention_mask'].to(device)

    B, L = event_type_ids.shape
    if B == 0 or L == 0:
        print("Warning: Received empty batch in Oracle forward.")
        empty_hidden = torch.empty(0, L, self.d_model, device=device, dtype=self.dtype)
        empty_mask = torch.empty(0, L, device=device, dtype=torch.long)
        empty_quantiles = torch.empty(0, self.num_outputs, device=device, dtype=self.dtype)
        return {
            'quantile_logits': empty_quantiles,
            'pooled_states': torch.empty(0, self.d_model, device=device, dtype=self.dtype),
            'hidden_states': empty_hidden,
            'attention_mask': empty_mask
        }

    # === 1. Run Dynamic Encoders (produces graph-updated entity embeddings) ===
    dynamic_raw_embeds = self._run_dynamic_encoders(batch)

    # === 2. Run Snapshot Encoders (uses dynamic_raw_embeds) ===
    wallet_addr_to_batch_idx = batch['wallet_addr_to_batch_idx']
    snapshot_raw_embeds = self._run_snapshot_encoders(batch, dynamic_raw_embeds['wallet'], wallet_addr_to_batch_idx)

    # === 3. Project Raw Embeddings and Gather for Sequence ===
    raw_embeds = {**dynamic_raw_embeds, **snapshot_raw_embeds}
    gathered_embeds = self._project_and_gather_embeddings(raw_embeds, batch)

    # === 4. Assemble Final `inputs_embeds` ===
    event_embeds = self.event_type_embedding(event_type_ids)
    ts_embeds = self.time_proj(self.time_encoder(timestamps_float))
    # Stabilize relative time: minutes scale + signed log1p + LayerNorm before projection.
    relative_ts_fp32 = batch['relative_ts'].to(device, torch.float32)
    rel_ts_minutes = relative_ts_fp32 / 60.0
    rel_ts_processed = torch.sign(rel_ts_minutes) * torch.log1p(torch.abs(rel_ts_minutes))
    # Match LayerNorm parameter dtype, then match Linear parameter dtype.
    norm_dtype = self.rel_ts_norm.weight.dtype
    proj_dtype = self.rel_ts_proj.weight.dtype
    rel_ts_normed = self.rel_ts_norm(rel_ts_processed.to(norm_dtype))
    rel_ts_embeds = self.rel_ts_proj(rel_ts_normed.to(proj_dtype))

    # Per-event-type additive embeddings (each helper masks its own positions).
    transfer_specific_embeds = self._get_transfer_specific_embeddings(batch, gathered_embeds)
    trade_specific_embeds = self._get_trade_specific_embeddings(batch)
    deployer_trade_specific_embeds = self._get_deployer_trade_specific_embeddings(batch)
    smart_wallet_trade_specific_embeds = self._get_smart_wallet_trade_specific_embeddings(batch)
    pool_created_specific_embeds = self._get_pool_created_specific_embeddings(batch, gathered_embeds)
    liquidity_change_specific_embeds = self._get_liquidity_change_specific_embeddings(batch, gathered_embeds)
    fee_collected_specific_embeds = self._get_fee_collected_specific_embeddings(batch)
    token_burn_specific_embeds = self._get_token_burn_specific_embeddings(batch)
    supply_lock_specific_embeds = self._get_supply_lock_specific_embeddings(batch)
    onchain_snapshot_specific_embeds = self._get_onchain_snapshot_specific_embeddings(batch)
    trending_token_specific_embeds = self._get_trending_token_specific_embeddings(batch, gathered_embeds)
    boosted_token_specific_embeds = self._get_boosted_token_specific_embeddings(batch, gathered_embeds)
    dexboost_paid_specific_embeds = self._get_dexboost_paid_specific_embeddings(batch)

    # Tracker events.
    alphagroup_call_specific_embeds = self._get_alphagroup_call_specific_embeddings(batch)
    channel_call_specific_embeds = self._get_channel_call_specific_embeddings(batch)
    cexlisting_specific_embeds = self._get_cexlisting_specific_embeddings(batch)

    # Chain and Lighthouse snapshots.
    chainsnapshot_specific_embeds = self._get_chainsnapshot_specific_embeddings(batch)
    lighthousesnapshot_specific_embeds = self._get_lighthousesnapshot_specific_embeddings(batch)

    migrated_specific_embeds = self._get_migrated_specific_embeddings(batch)

    # DexProfile_Updated flags are projected directly (no helper encoder).
    dexprofile_updated_flags = batch['dexprofile_updated_flags']
    dexprofile_flags_embeds = self.dexprofile_updated_flags_proj(dexprofile_updated_flags.to(self.dtype))

    # All text-based events (social, dexprofile, global trending) are handled
    # by the SocialEncoder in one batched call.
    textual_event_embeds = self.social_encoder(
        batch=batch,
        gathered_embeds=gathered_embeds
    )

    # Special context injection tokens ('Middle', 'RECENT').
    special_context_embeds = self._get_special_context_embeddings(batch)

    # --- Combine all features ---
    # Sum in float32 for numerical stability, then cast back to model dtype.
    components = [
        event_embeds, ts_embeds, rel_ts_embeds,
        gathered_embeds['wallet'], gathered_embeds['token'], gathered_embeds['original_author'], gathered_embeds['ohlc'],
        transfer_specific_embeds, trade_specific_embeds, deployer_trade_specific_embeds, smart_wallet_trade_specific_embeds,
        pool_created_specific_embeds, liquidity_change_specific_embeds, fee_collected_specific_embeds,
        token_burn_specific_embeds, supply_lock_specific_embeds, onchain_snapshot_specific_embeds,
        trending_token_specific_embeds, boosted_token_specific_embeds, dexboost_paid_specific_embeds,
        alphagroup_call_specific_embeds, channel_call_specific_embeds, cexlisting_specific_embeds,
        migrated_specific_embeds, special_context_embeds, gathered_embeds['holder_snapshot'], textual_event_embeds,
        dexprofile_flags_embeds, chainsnapshot_specific_embeds, lighthousesnapshot_specific_embeds
    ]
    inputs_embeds = sum(t.float() for t in components).to(self.dtype)

    hf_attention_mask = attention_mask.to(device=device, dtype=torch.long)
    outputs = self.model(
        inputs_embeds=inputs_embeds,
        attention_mask=hf_attention_mask,
        return_dict=True
    )
    sequence_hidden = outputs.last_hidden_state
    pooled_states = self._pool_hidden_states(sequence_hidden, hf_attention_mask)
    quantile_logits = self.quantile_head(pooled_states)

    return {
        'quantile_logits': quantile_logits,
        'pooled_states': pooled_states,
        'hidden_states': sequence_hidden,
        'attention_mask': hf_attention_mask
    }
|
models/multi_modal_processor.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# multi_modal_processor.py
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from transformers import AutoModel, AutoProcessor, AutoConfig
|
| 6 |
+
from typing import List, Union
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import requests
|
| 9 |
+
import io
|
| 10 |
+
import os
|
| 11 |
+
import traceback
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
# Suppress warnings
|
| 15 |
+
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
| 16 |
+
|
| 17 |
+
class MultiModalEncoder:
    """
    Encodes text OR images into a shared, NORMALIZED embedding space
    using google/siglip-so400m-patch16-256-i18n.
    This class is intended for creating embeddings for vector search:
    outputs are unit-normalized, so cosine similarity is a plain dot product.
    """

    def __init__(self, model_id="google/siglip-so400m-patch16-256-i18n", dtype: torch.dtype = torch.bfloat16):
        """
        Load the SigLIP processor and model onto CUDA if available, else CPU.

        Args:
            model_id (str): Hugging Face model identifier.
            dtype (torch.dtype): Parameter dtype for the loaded model.

        Raises:
            Exception: re-raised after logging if any component fails to load.
        """
        self.model_id = model_id
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.dtype = dtype


        try:
            # --- SigLIP Loading with Config Fix ---
            self.processor = AutoProcessor.from_pretrained(
                self.model_id,
                use_fast=True
            )

            config = AutoConfig.from_pretrained(self.model_id)

            # Some configs lack projection_dim; fall back to the text tower's
            # hidden size so embedding_dim below is always defined.
            if not hasattr(config, 'projection_dim'):
                # print("❗ Config missing projection_dim, patching...")
                config.projection_dim = config.text_config.hidden_size

            # NOTE(review): `dtype=` is only accepted by recent transformers
            # releases; older versions expect `torch_dtype=` — confirm the
            # pinned transformers version supports this kwarg.
            self.model = AutoModel.from_pretrained(
                self.model_id,
                config=config,
                dtype=self.dtype,
                trust_remote_code=False
            ).to(self.device).eval()
            # -----------------------------------------------

            self.embedding_dim = config.projection_dim

        except Exception as e:
            print(f"❌ Failed to load SigLIP model or components: {e}")
            traceback.print_exc()
            raise

    @torch.no_grad()
    def __call__(self, x: Union[List[str], List[Image.Image]]) -> torch.Tensor:
        """
        Encode a batch of text or images into normalized [batch_size, embedding_dim] vectors.
        This is correct for storing in a vector DB for cosine similarity.

        The batch must be homogeneous: modality is decided by the type of x[0].
        NOTE(review): on any encoding failure this logs and returns an EMPTY
        [0, embedding_dim] tensor rather than raising — callers must handle
        the size mismatch.
        """
        if not x:
            return torch.empty(0, self.embedding_dim).to(self.device)

        is_text = isinstance(x[0], str)

        # autocast only makes sense for half precisions; full precision disables it.
        autocast_dtype = self.dtype if self.dtype in [torch.float16, torch.bfloat16] else None

        print(f"\n[MME LOG] ENTERING __call__ for {'TEXT' if is_text else 'IMAGE'} batch of size {len(x)}")
        print(f"[MME LOG] Input data preview: {str(x[0])[:100] if is_text else x[0]}")

        with torch.amp.autocast(device_type=self.device, enabled=(self.device == 'cuda' and autocast_dtype is not None), dtype=autocast_dtype):
            try:
                if is_text:
                    inputs = self.processor(
                        text=x,
                        return_tensors="pt",
                        padding="max_length",
                        truncation=True
                    ).to(self.device)
                    print(f"[MME LOG] Text processor output shape: {inputs['input_ids'].shape}")
                    embeddings = self.model.get_text_features(**inputs)
                else:
                    # SigLIP expects RGB; convert palette/grayscale inputs first.
                    rgb_images = [img.convert("RGB") if img.mode != 'RGB' else img for img in x]
                    inputs = self.processor(
                        images=rgb_images,
                        return_tensors="pt"
                    ).to(self.device)

                    # Processor may emit float32 pixel values; align with model dtype.
                    if 'pixel_values' in inputs and inputs['pixel_values'].dtype != self.dtype:
                        inputs['pixel_values'] = inputs['pixel_values'].to(self.dtype)

                    embeddings = self.model.get_image_features(**inputs)

                print(f"[MME LOG] Raw model output embeddings shape: {embeddings.shape}, dtype: {embeddings.dtype}")

                # L2-normalize in float32 for numerical stability, then cast back.
                embeddings = F.normalize(embeddings.float(), p=2, dim=-1)
                print(f"[MME LOG] Normalized embeddings shape: {embeddings.shape}, dtype: {embeddings.dtype}")

                final_embeddings = embeddings.to(self.dtype)
                print(f"[MME LOG] Final embeddings shape: {final_embeddings.shape}, dtype: {final_embeddings.dtype}. EXITING __call__.")
                return final_embeddings

            except Exception as e:
                print(f"❌ [MME LOG] FATAL ERROR during encoding {'text' if is_text else 'images'}: {e}")
                traceback.print_exc()
                return torch.empty(0, self.embedding_dim).to(self.device)
|
| 113 |
+
|
| 114 |
+
# --- Test block (SigLIP) ---
|
| 115 |
+
if __name__ == "__main__":
    # Smoke test: download one image, build one synthetic image, encode both
    # modalities separately, and print an image x text cosine-similarity matrix.
    # Requires network access for the model weights and the test image.

    MODEL_ID = "google/siglip-so400m-patch16-256-i18n"
    print(f"\n--- MultiModalEncoder Test ({MODEL_ID}) ---")

    texts = [
        "Uranus",   # Text 0
        "Anus",     # Text 1
        "Ass",      # Text 2
        "Planet",   # Text 3
        "Dog"       # Text 4
    ]

    try:
        img_urls = [
            "https://pbs.twimg.com/media/G3ra9C8W0AAGR8V.jpg",  # Image 0: Uranus meme pic
        ]
        # A browser-like User-Agent avoids 403s from the CDN.
        headers = {"User-Agent": "Mozilla/5.0"}
        images = [
            Image.open(io.BytesIO(requests.get(u, headers=headers).content))
            for u in img_urls
        ]

        size = 256  # Model's expected size
        images.append(Image.new("RGB", (size, size), color="green"))  # Image 1: Green Square
        print(f"✅ Downloaded test image and created green square (size {size}x{size}).")

    except Exception as e:
        print(f"❌ Failed to load images: {e}")
        traceback.print_exc()
        exit()

    try:
        # 1. Initialize the encoder (loads the model; may download weights).
        encoder = MultiModalEncoder(model_id=MODEL_ID)

        print("\n--- Encoding Texts (Separately) ---")
        text_embeddings = encoder(texts)  # Uses __call__
        print(f"Shape: {text_embeddings.shape}")

        print("\n--- Encoding Images (Separately) ---")
        image_embeddings = encoder(images)  # Uses __call__
        print(f"Shape: {image_embeddings.shape}")


        print("\n--- Similarity Check (Your Goal) ---")

        # 2. Cosine similarity reduces to a dot product because __call__
        # already L2-normalized both embedding sets.
        similarity_matrix = torch.matmul(image_embeddings.cpu(), text_embeddings.cpu().T).numpy()

        np.set_printoptions(precision=4, suppress=True)
        print("\nCosine Similarity matrix (image × text):")
        # Row: Images (0: Uranus Pic, 1: Green)
        # Col: Texts (0: Uranus, 1: Anus, 2: Ass, 3: Planet, 4: Dog)
        print(similarity_matrix)

        print("\nSpecific Similarity Scores (Cosine Similarity, -1.0 to 1.0):")
        print(f"Image 0 (Uranus pic) vs Text 0 (Uranus): {similarity_matrix[0][0]:.4f}")
        print(f"Image 0 (Uranus pic) vs Text 1 (Anus): {similarity_matrix[0][1]:.4f}")
        print(f"Image 0 (Uranus pic) vs Text 3 (Planet): {similarity_matrix[0][3]:.4f}")
        print(f"Image 0 (Uranus pic) vs Text 4 (Dog): {similarity_matrix[0][4]:.4f}")
        print(f"Image 1 (Green) vs Text 4 (Dog): {similarity_matrix[1][4]:.4f}")

    except Exception as e:
        print(f"\n--- An error occurred during the SigLIP test run ---")
        print(f"Error: {e}")
        traceback.print_exc()
|
models/ohlc_embedder.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
# --- Import vocabulary for the test block ---
|
| 7 |
+
import models.vocabulary as vocab
|
| 8 |
+
|
| 9 |
+
class OHLCEmbedder(nn.Module):
    """
    Embeds a sequence of Open and Close prices together with its interval.

    A stack of Conv1d -> MaxPool -> ReLU layers extracts chart-pattern
    features; these are global-average-pooled, concatenated with a learned
    interval embedding, and projected to ``output_dim`` by a small MLP.
    """
    def __init__(
        self,
        num_intervals: int,
        input_channels: int = 2,  # Open, Close
        sequence_length: int = 300,
        cnn_channels: List[int] = [16, 32, 64],
        kernel_sizes: List[int] = [3, 3, 3],
        interval_embed_dim: int = 32,
        output_dim: int = 4096,
        dtype: torch.dtype = torch.float16
    ):
        """
        Args:
            num_intervals: Vocabulary size of the interval embedding (id 0 is padding).
            input_channels: Number of price channels per timestep (Open, Close).
            sequence_length: Expected length of every input sequence.
            cnn_channels: Output channels of each Conv1d stage.
            kernel_sizes: Kernel size of each Conv1d stage (same length as cnn_channels).
            interval_embed_dim: Dimension of the interval embedding.
            output_dim: Dimension of the final embedding.
            dtype: Parameter dtype for the whole module.
        """
        super().__init__()
        assert len(cnn_channels) == len(kernel_sizes), "cnn_channels and kernel_sizes must have the same length"

        self.dtype = dtype
        self.sequence_length = sequence_length
        self.cnn_layers = nn.ModuleList()
        self.output_dim = output_dim

        # FIXED: dropped the unused enumerate index and the dead
        # current_seq_len tracking (its value was never read).
        in_channels = input_channels
        for out_channels, k_size in zip(cnn_channels, kernel_sizes):
            # 'same' padding keeps the temporal length through each conv;
            # the MaxPool then halves it.
            self.cnn_layers.append(nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=k_size,
                padding='same'
            ))
            self.cnn_layers.append(nn.MaxPool1d(kernel_size=2, stride=2))
            self.cnn_layers.append(nn.ReLU())
            in_channels = out_channels

        # Collapses whatever temporal length remains to a single feature vector.
        self.global_pool = nn.AdaptiveAvgPool1d(1)

        final_cnn_channels = cnn_channels[-1]

        # Interval embedding (id 0 reserved as padding).
        self.interval_embedding = nn.Embedding(num_intervals, interval_embed_dim, padding_idx=0)

        # MLP consumes CNN features concatenated with the interval embedding.
        mlp_input_dim = final_cnn_channels + interval_embed_dim

        self.mlp = nn.Sequential(
            nn.Linear(mlp_input_dim, mlp_input_dim * 2),
            nn.GELU(),
            nn.LayerNorm(mlp_input_dim * 2),
            nn.Linear(mlp_input_dim * 2, output_dim),
            nn.LayerNorm(output_dim)
        )

        self.to(dtype)

    def forward(self, x: torch.Tensor, interval_ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Batch of normalized OHLC sequences.
                Shape: [batch_size, 2, sequence_length]
            interval_ids (torch.Tensor): Batch of interval IDs.
                Shape: [batch_size]

        Returns:
            torch.Tensor: Batch of OHLC embeddings.
                Shape: [batch_size, output_dim]

        Raises:
            ValueError: if x does not have shape [B, 2, sequence_length].
        """
        if x.shape[1] != 2 or x.shape[2] != self.sequence_length:
            raise ValueError(f"Input tensor shape mismatch. Expected [B, 2, {self.sequence_length}], got {x.shape}")

        x = x.to(self.dtype)

        # 1. CNN feature extraction.
        for layer in self.cnn_layers:
            x = layer(x)

        # 2. Global average pooling over the remaining temporal axis.
        x = self.global_pool(x)

        # 3. Flatten to [batch_size, final_cnn_channels].
        x = x.squeeze(-1)

        # 4. Interval embedding: [batch_size, interval_embed_dim].
        interval_embed = self.interval_embedding(interval_ids)

        # 5. Concatenate pattern features with interval features.
        combined = torch.cat([x, interval_embed], dim=1)

        # 6. Project to the final embedding: [batch_size, output_dim].
        x = self.mlp(combined)

        return x
|
| 114 |
+
|
models/token_encoder.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# token_encoder.py (FIXED)
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from typing import List, Any
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
from models.multi_modal_processor import MultiModalEncoder
|
| 9 |
+
from models.wallet_set_encoder import WalletSetEncoder # Using your set encoder
|
| 10 |
+
from models.vocabulary import NUM_PROTOCOLS
|
| 11 |
+
|
| 12 |
+
class TokenEncoder(nn.Module):
    """
    Encodes a token's core identity into a single <TokenVibeEmbedding>.

    Fuses pre-computed name/symbol/image embeddings with protocol and
    vanity-flag embeddings via a set transformer, then projects to a
    smaller output dimension suitable as input to a larger model.
    """
    def __init__(
        self,
        multi_dim: int,                  # dimension of the upstream multimodal encoder outputs
        output_dim: int = 2048,
        internal_dim: int = 1024,        # balance between bottleneck and capacity
        protocol_embed_dim: int = 64,
        vanity_embed_dim: int = 32,      # small embedding for the vanity flag
        nhead: int = 4,
        num_layers: int = 1,
        dtype: torch.dtype = torch.float16
    ):
        """
        Initializes the TokenEncoder.

        Args:
            multi_dim (int): The embedding dimension of the multimodal encoder
                (e.g., 1152 for SigLIP).
            output_dim (int):
                The final dimension of the <TokenVibeEmbedding> (e.g., 2048).
            internal_dim (int):
                The shared dimension for the internal fusion transformer (e.g., 1024).
            protocol_embed_dim (int):
                Small dimension for the protocol ID (e.g., 64).
            vanity_embed_dim (int):
                Small dimension for the is_vanity boolean flag.
            nhead (int):
                Attention heads for the fusion transformer.
            num_layers (int):
                Layers for the fusion transformer.
            dtype (torch.dtype):
                The data type (e.g., torch.float16).
        """
        super().__init__()
        self.output_dim = output_dim
        self.internal_dim = internal_dim
        self.dtype = dtype

        # Fixed output dim of the multimodal encoder (e.g., 1152 for SigLIP).
        self.multi_dim = multi_dim

        # --- 1. Projection Layers ---
        # Project each multimodal feature to the shared internal_dim.
        self.name_proj = nn.Linear(self.multi_dim, internal_dim)
        self.symbol_proj = nn.Linear(self.multi_dim, internal_dim)
        self.image_proj = nn.Linear(self.multi_dim, internal_dim)

        # --- 2. Categorical & Boolean Feature Embeddings ---

        # Small vocab, small embed dim.
        self.protocol_embedding = nn.Embedding(NUM_PROTOCOLS, protocol_embed_dim)
        # Project from small dim (e.g., 64) up to internal_dim (e.g., 1024).
        self.protocol_proj = nn.Linear(protocol_embed_dim, internal_dim)

        # Embedding for the is_vanity boolean flag.
        self.vanity_embedding = nn.Embedding(2, vanity_embed_dim)  # 2 classes: True/False
        self.vanity_proj = nn.Linear(vanity_embed_dim, internal_dim)

        # --- 3. Fusion Encoder ---
        # Re-use WalletSetEncoder to fuse the sequence of per-token features.
        self.fusion_transformer = WalletSetEncoder(
            d_model=internal_dim,
            nhead=nhead,
            num_layers=num_layers,
            dim_feedforward=internal_dim * 4,  # standard 4x feedforward
            dtype=dtype
        )

        # --- 4. Final Output Projection ---
        # From the transformer's output (internal_dim) to the final output_dim.
        self.final_projection = nn.Sequential(
            nn.Linear(internal_dim, internal_dim * 2),
            nn.GELU(),
            nn.LayerNorm(internal_dim * 2),
            nn.Linear(internal_dim * 2, output_dim),
            nn.LayerNorm(output_dim)
        )

        # NOTE(review): moving the module to a device inside __init__ is a
        # constructor side effect; callers cannot choose the device later
        # without an explicit .to(). Confirm this is intended.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.to(device=device, dtype=dtype)
|
| 100 |
+
|
| 101 |
+
def forward(
    self,
    name_embeds: torch.Tensor,
    symbol_embeds: torch.Tensor,
    image_embeds: torch.Tensor,
    protocol_ids: torch.Tensor,
    is_vanity_flags: torch.Tensor,
) -> torch.Tensor:
    """
    Processes a batch of token data to create a batch of embeddings.

    Each of the five per-token features is projected to ``internal_dim``,
    stacked into a length-5 sequence, fused by the CLS-pooling
    ``fusion_transformer``, and finally projected to ``output_dim``.

    Args:
        name_embeds (torch.Tensor): Pre-computed embeddings for token names. Shape: [B, siglip_dim]
        symbol_embeds (torch.Tensor): Pre-computed embeddings for token symbols. Shape: [B, siglip_dim]
        image_embeds (torch.Tensor): Pre-computed embeddings for token images. Shape: [B, siglip_dim]
        protocol_ids (torch.Tensor): Batch of protocol IDs. Shape: [B]
        is_vanity_flags (torch.Tensor): Batch of boolean flags for vanity addresses. Shape: [B]

    Returns:
        torch.Tensor: The final <TokenVibeEmbedding> batch.
            Shape: [batch_size, output_dim]
    """
    # All inputs are assumed to already live on the same device as
    # name_embeds — TODO confirm callers guarantee this.
    device = name_embeds.device
    batch_size = name_embeds.shape[0]

    # 2. Get Protocol embedding (small)
    # NOTE(review): these print() calls run on every forward pass; consider
    # switching to the logging module with a debug level before production.
    print(f"\n--- [TokenEncoder LOG] ENTERING FORWARD PASS (Batch Size: {batch_size}) ---")
    print(f"[TokenEncoder LOG] Input protocol_ids (shape {protocol_ids.shape}):\n{protocol_ids}")
    print(f"[TokenEncoder LOG] Protocol Embedding Vocab Size: {self.protocol_embedding.num_embeddings}")

    # Cast to long on the target device before the embedding lookup.
    protocol_ids_long = protocol_ids.to(device, dtype=torch.long)
    protocol_emb_raw = self.protocol_embedding(protocol_ids_long)  # [B, 64]
    print(f"[TokenEncoder LOG] Raw protocol embeddings shape: {protocol_emb_raw.shape}")

    # NEW: Get vanity embedding (boolean flag -> 2-row embedding table)
    vanity_ids_long = is_vanity_flags.to(device, dtype=torch.long)
    vanity_emb_raw = self.vanity_embedding(vanity_ids_long)  # [B, 32]

    # 3. Project all features to internal_dim (e.g., 1024)
    print(f"[TokenEncoder LOG] Projecting features to internal_dim: {self.internal_dim}")
    name_emb = self.name_proj(name_embeds)
    symbol_emb = self.symbol_proj(symbol_embeds)
    image_emb = self.image_proj(image_embeds)
    protocol_emb = self.protocol_proj(protocol_emb_raw)
    vanity_emb = self.vanity_proj(vanity_emb_raw)  # NEW

    # 4. Stack all projected features into a sequence: [B, 5, internal_dim]
    feature_sequence = torch.stack([
        name_emb,
        symbol_emb,
        image_emb,
        protocol_emb,
        vanity_emb,  # NEW: Add the vanity embedding to the sequence
    ], dim=1)

    print(f"[TokenEncoder LOG] Stacked feature_sequence shape: {feature_sequence.shape}")
    print(f"  - name_emb shape: {name_emb.shape}")
    print(f"  - symbol_emb shape: {symbol_emb.shape}")
    print(f"  - image_emb shape: {image_emb.shape}")
    print(f"  - protocol_emb shape: {protocol_emb.shape}")
    print(f"  - vanity_emb shape: {vanity_emb.shape}")  # ADDED: Log the new vanity embedding shape

    # 5. Create the padding mask (all False, since we have a fixed number of features for all)
    padding_mask = torch.zeros(batch_size, feature_sequence.shape[1], device=device, dtype=torch.bool)
    print(f"[TokenEncoder LOG] Created padding_mask of shape: {padding_mask.shape}")

    # 6. Fuse the sequence with the Transformer Encoder
    # This returns the [CLS] token output.
    # Shape: [B, internal_dim]
    fused_embedding = self.fusion_transformer(
        item_embeds=feature_sequence,
        src_key_padding_mask=padding_mask
    )
    print(f"[TokenEncoder LOG] Fused embedding shape after transformer: {fused_embedding.shape}")

    # 7. Project to the final output dimension
    # Shape: [B, output_dim]
    token_vibe_embedding = self.final_projection(fused_embedding)
    print(f"[TokenEncoder LOG] Final token_vibe_embedding shape: {token_vibe_embedding.shape}")
    print(f"--- [TokenEncoder LOG] EXITING FORWARD PASS ---\n")

    return token_vibe_embedding
|
models/vocabulary.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# vocabulary.py
"""
Defines the vocabulary and mappings for categorical features.

Every vocabulary below follows the same convention: a canonical ordered
list of names, a forward (name -> id) dict, usually a reverse (id -> name)
dict, and a size constant. The lookups are produced by `_index_maps`.
"""


def _index_maps(names):
    """Return (name -> id, id -> name) lookup tables for an ordered vocabulary."""
    forward = {name: idx for idx, name in enumerate(names)}
    reverse = {idx: name for name, idx in forward.items()}
    return forward, reverse


# --- Event Type Mappings ---
EVENT_NAMES = [
    '__PAD__', 'Chart_Segment', 'Mint',
    'Transfer', 'LargeTransfer',
    'Trade',
    'Deployer_Trade',
    'SmartWallet_Trade',
    'LargeTrade',
    'PoolCreated',
    'LiquidityChange',
    'FeeCollected',
    'TokenBurn',
    'SupplyLock',
    'OnChain_Snapshot',
    'HolderSnapshot',
    'TrendingToken',
    'BoostedToken',
    'XPost',
    'XRetweet',
    'XReply',
    'XQuoteTweet',
    'PumpReply',
    'DexBoost_Paid',
    'DexProfile_Updated',
    'AlphaGroup_Call',
    'Channel_Call',
    'CexListing',
    'TikTok_Trending_Hashtag',
    'XTrending_Hashtag',
    'ChainSnapshot',
    'Lighthouse_Snapshot',
    'Migrated',
    'MIDDLE',
    'RECENT'
]
EVENT_TO_ID, ID_TO_EVENT = _index_maps(EVENT_NAMES)
NUM_EVENT_TYPES = len(EVENT_NAMES)

# --- Protocol Mappings ---

# The canonical list of protocol names
PROTOCOL_NAMES = [
    "Unknown",
    "Pump V1",
    "Pump AMM",
    "Bonk",
    "Raydium CPMM"
]

PROTOCOL_TO_ID, ID_TO_PROTOCOL = _index_maps(PROTOCOL_NAMES)
NUM_PROTOCOLS = len(PROTOCOL_NAMES)


# --- Neo4J Link Type Mappings ---
# Link types mirrored from the Neo4j schema (see neo4j.rs).
LINK_TYPES = [
    "TransferLink",
    "TransferLinkToken",
    "BundleTradeLink",
    "CopiedTradeLink",
    "CoordinatedActivityLink",
    "MintedLink",
    "SnipedLink",
    "LockedSupplyLink",
    "BurnedLink",
    "ProvidedLiquidityLink",
    "WhaleOfLink",
    "TopTraderOfLink",
]

LINK_TYPE_TO_ID, ID_TO_LINK_TYPE = _index_maps(LINK_TYPES)
NUM_LINK_TYPES = len(LINK_TYPES)

# (source node type, relation name, destination node type) per link name.
LINK_NAME_TO_TRIPLET = {
    # Wallet <-> Wallet Links
    "TransferLink": ('wallet', 'TransferLink', 'wallet'),
    "BundleTradeLink": ('wallet', 'BundleTradeLink', 'wallet'),
    "CopiedTradeLink": ('wallet', 'CopiedTradeLink', 'wallet'),
    "CoordinatedActivityLink": ('wallet', 'CoordinatedActivityLink', 'wallet'),

    # Wallet -> Token Links
    "TransferLinkToken": ('wallet', 'TransferLinkToken', 'token'),
    "MintedLink": ('wallet', 'MintedLink', 'token'),
    "SnipedLink": ('wallet', 'SnipedLink', 'token'),
    "LockedSupplyLink": ('wallet', 'LockedSupplyLink', 'token'),
    "BurnedLink": ('wallet', 'BurnedLink', 'token'),
    "ProvidedLiquidityLink": ('wallet', 'ProvidedLiquidityLink', 'token'),
    "WhaleOfLink": ('wallet', 'WhaleOfLink', 'token'),
    "TopTraderOfLink": ('wallet', 'TopTraderOfLink', 'token'),
}


# --- OHLC Interval Mappings ---
OHLC_INTERVALS = [
    "Unknown",  # ID 0
    "1s",       # ID 1
    "30s",      # ID 2
]

INTERVAL_TO_ID, ID_TO_INTERVAL = _index_maps(OHLC_INTERVALS)
NUM_OHLC_INTERVALS = len(OHLC_INTERVALS)

DEX_NAMES = [
    "Unknown",
    "Axiom",
    "Bullx",
    "OXK",
    "Trojan",
    "Jupyter"
]

DEX_TO_ID, ID_TO_DEX = _index_maps(DEX_NAMES)
NUM_DEX_PLATFORMS = len(DEX_NAMES)

# --- Trending List Source Mappings ---
TRENDING_LIST_SOURCES = [
    "Unknown",
    "Phantom",
    "Dexscreener"
]

TRENDING_LIST_SOURCE_TO_ID, ID_TO_TRENDING_LIST_SOURCE = _index_maps(TRENDING_LIST_SOURCES)
NUM_TRENDING_LIST_SOURCES = len(TRENDING_LIST_SOURCES)

# --- Trending List Timeframe Mappings ---
TRENDING_LIST_TIMEFRAMES = [
    "Unknown",
    "5m",
    "1h",
    "24h"
]
TRENDING_LIST_TIMEFRAME_TO_ID, ID_TO_TRENDING_LIST_TIMEFRAME = _index_maps(TRENDING_LIST_TIMEFRAMES)
NUM_TRENDING_LIST_TIMEFRAMES = len(TRENDING_LIST_TIMEFRAMES)

# --- Lighthouse Snapshot Timeframe Mappings ---
# NOTE: only the forward mapping is exported for this vocabulary.
LIGHTHOUSE_TIMEFRAMES = [
    "Unknown",
    "5m",
    "1h",
    "6h",
    "24h"
]
LIGHTHOUSE_TIMEFRAME_TO_ID = _index_maps(LIGHTHOUSE_TIMEFRAMES)[0]
NUM_LIGHTHOUSE_TIMEFRAMES = len(LIGHTHOUSE_TIMEFRAMES)

# --- TrackerEncoder Vocabularies ---

# Alpha Groups (Discord)
ALPHA_GROUPS = [
    "unknown",
    "Potion",
    "Serenity",
    "Digi World"
]
ALPHA_GROUPS_TO_ID, ID_TO_ALPHA_GROUPS = _index_maps(ALPHA_GROUPS)
NUM_ALPHA_GROUPS = len(ALPHA_GROUPS)

# Call Channels (Telegram)
CALL_CHANNELS = [
    "unknown",
    "MarcosCalls",
    "kobecalls",
    "DEGEMSCALLS"
]
CALL_CHANNELS_TO_ID, ID_TO_CALL_CHANNELS = _index_maps(CALL_CHANNELS)
NUM_CALL_CHANNELS = len(CALL_CHANNELS)

# CEX Exchanges
EXCHANGES = [
    "unknown", "mexc", "weex", "binance", "kraken"
]
EXCHANGES_TO_ID, ID_TO_EXCHANGES = _index_maps(EXCHANGES)
NUM_EXCHANGES = len(EXCHANGES)
|
models/wallet_encoder.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from typing import List, Dict, Any, Optional
|
| 5 |
+
from PIL import Image
|
| 6 |
+
|
| 7 |
+
# We assume these helper modules are in the same directory
|
| 8 |
+
from models.multi_modal_processor import MultiModalEncoder
|
| 9 |
+
from models.wallet_set_encoder import WalletSetEncoder
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class WalletEncoder(nn.Module):
    """
    Encodes a wallet's full identity into a single <WalletEmbedding>.

    Three sub-encoders each produce a d_model vector:
      1. Profile  — 38 log-normalized numerical stats -> MLP.
      2. Social   — 4 boolean flags + a pre-computed username embedding -> MLP.
      3. Holdings — per-holding (token vibe + 12 stats) rows -> MLP -> set encoder.
    The three vectors are concatenated and fused by `fusion_mlp` into the
    final d_model embedding.

    FIX: `_encode_holdings_batch` previously built the vibe tensor only from
    rows that had a 'mint_address' key while building the numerical tensor
    from ALL rows, so the two tensors could disagree in length and crash (or
    silently misalign) in `torch.cat`. Rows without a mint address now fall
    back to the zero vibe vector, keeping the tensors row-aligned.
    """

    def __init__(
        self,
        encoder: MultiModalEncoder,
        d_model: int = 2048,               # Standardized to d_model
        token_vibe_dim: int = 2048,        # Expects token vibe of d_model
        set_encoder_nhead: int = 8,
        set_encoder_nlayers: int = 2,
        dtype: torch.dtype = torch.float16
    ):
        """
        Initializes the WalletEncoder.

        Args:
            encoder (MultiModalEncoder): Instantiated SigLIP encoder; only its
                `embedding_dim` and `device` attributes are used here.
            d_model (int): The final output dimension (e.g., 2048).
            token_vibe_dim (int): Dimension of the pre-computed
                <TokenVibeEmbedding> vectors supplied via `token_vibe_lookup`.
            set_encoder_nhead (int): Attention heads for the holdings set encoder.
            set_encoder_nlayers (int): Transformer layers for the holdings set encoder.
            dtype (torch.dtype): Data type for all parameters.
        """
        super().__init__()
        self.d_model = d_model
        self.dtype = dtype
        self.encoder = encoder

        # --- Dimensions ---
        self.token_vibe_dim = token_vibe_dim
        self.mmp_dim = self.encoder.embedding_dim  # e.g., 1152 for SigLIP

        # === 1. Profile Encoder ===
        # 1 age + 5 deployer_stats + 1 balance + 4 lifetime_counts +
        # 3 lifetime_trading + 12 1d_stats + 12 7d_stats = 38
        self.profile_numerical_features = 38
        self.profile_num_norm = nn.LayerNorm(self.profile_numerical_features)

        # Input is purely numerical (no bool / deployed-token embeddings).
        profile_mlp_in_dim = self.profile_numerical_features  # 38
        self.profile_encoder_mlp = self._build_mlp(profile_mlp_in_dim, d_model)

        # === 2. Social Encoder ===
        # 4 booleans: has_pf, has_twitter, has_telegram, is_exchange_wallet
        self.social_bool_embed = nn.Embedding(2, 16)
        # (4 flags * 16-dim embed) + username embedding
        social_mlp_in_dim = (16 * 4) + self.mmp_dim
        self.social_encoder_mlp = self._build_mlp(social_mlp_in_dim, d_model)

        # === 3. Holdings Encoder ===
        # 11 original stats + 1 holding_time = 12
        self.holding_numerical_features = 12
        self.holding_num_norm = nn.LayerNorm(self.holding_numerical_features)

        # Per-holding row: <TokenVibeEmbedding> concatenated with the 12 stats.
        holding_row_in_dim = (
            self.token_vibe_dim +
            self.holding_numerical_features
        )
        self.holding_row_encoder_mlp = self._build_mlp(holding_row_in_dim, d_model)

        self.holdings_set_encoder = WalletSetEncoder(
            d_model, set_encoder_nhead, set_encoder_nlayers, dtype=dtype
        )

        # === 4. Final Fusion Encoder ===
        # Fuses the 3 component embeddings: Profile, Social, Holdings.
        self.fusion_mlp = nn.Sequential(
            nn.Linear(d_model * 3, d_model * 2),
            nn.GELU(),
            nn.LayerNorm(d_model * 2),
            nn.Linear(d_model * 2, d_model),
            nn.LayerNorm(d_model)
        )
        self.to(dtype)

    def _build_mlp(self, in_dim, out_dim):
        """Two-layer GELU MLP (in -> 2*out -> out) cast to the encoder dtype."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim * 2),
            nn.GELU(),
            nn.LayerNorm(out_dim * 2),
            nn.Linear(out_dim * 2, out_dim),
        ).to(self.dtype)

    def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
        """Sign-preserving log1p compression for wide-range numerical features."""
        return torch.sign(x) * torch.log1p(torch.abs(x))

    def _get_device(self) -> torch.device:
        # All compute follows the SigLIP encoder's device.
        return self.encoder.device

    def forward(
        self,
        profile_rows: List[Dict[str, Any]],
        social_rows: List[Dict[str, Any]],
        holdings_batch: List[List[Dict[str, Any]]],
        token_vibe_lookup: Dict[str, torch.Tensor],
        embedding_pool: torch.Tensor,
        username_embed_indices: torch.Tensor
    ) -> torch.Tensor:
        """
        Encode a batch of wallets.

        Args:
            profile_rows: One dict of profile stats per wallet.
            social_rows: One dict of social flags per wallet (keys MUST include
                'has_pf_profile', 'has_twitter', 'has_telegram', 'is_exchange_wallet').
            holdings_batch: Per-wallet list of holding dicts.
            token_vibe_lookup: mint_address -> pre-computed <TokenVibeEmbedding>.
            embedding_pool: [P, mmp_dim] pool of pre-computed username embeddings.
            username_embed_indices: [B] indices into the pool; -1 means "missing".

        Returns:
            torch.Tensor: [B, d_model] wallet embeddings.
        """
        device = self._get_device()

        profile_embed = self._encode_profile_batch(profile_rows, device)
        social_embed = self._encode_social_batch(social_rows, embedding_pool, username_embed_indices, device)
        holdings_embed = self._encode_holdings_batch(holdings_batch, token_vibe_lookup, device)

        fused = torch.cat([profile_embed, social_embed, holdings_embed], dim=1)
        return self.fusion_mlp(fused)

    def _encode_profile_batch(self, profile_rows, device):
        """Encode the 38 numerical profile features per wallet into [B, d_model]."""
        batch_size = len(profile_rows)
        num_tensor = torch.zeros(batch_size, self.profile_numerical_features, device=device, dtype=self.dtype)

        for i, row in enumerate(profile_rows):
            # 38 numerical features; missing keys default to 0.0.
            num_data = [
                # 1. Age
                row.get('age', 0.0),
                # 2. Deployed Token Aggregates (5)
                row.get('deployed_tokens_count', 0.0),
                row.get('deployed_tokens_migrated_pct', 0.0),
                row.get('deployed_tokens_avg_lifetime_sec', 0.0),
                row.get('deployed_tokens_avg_peak_mc_usd', 0.0),
                row.get('deployed_tokens_median_peak_mc_usd', 0.0),
                # 3. Balance (1)
                row.get('balance', 0.0),
                # 4. Lifetime Transaction Counts (4)
                row.get('transfers_in_count', 0.0), row.get('transfers_out_count', 0.0),
                row.get('spl_transfers_in_count', 0.0), row.get('spl_transfers_out_count', 0.0),
                # 5. Lifetime Trading Stats (3)
                row.get('total_buys_count', 0.0), row.get('total_sells_count', 0.0),
                row.get('total_winrate', 0.0),
                # 6. 1-Day Stats (12)
                row.get('stats_1d_realized_profit_sol', 0.0), row.get('stats_1d_realized_profit_pnl', 0.0),
                row.get('stats_1d_buy_count', 0.0), row.get('stats_1d_sell_count', 0.0),
                row.get('stats_1d_transfer_in_count', 0.0), row.get('stats_1d_transfer_out_count', 0.0),
                row.get('stats_1d_avg_holding_period', 0.0), row.get('stats_1d_total_bought_cost_sol', 0.0),
                row.get('stats_1d_total_sold_income_sol', 0.0), row.get('stats_1d_total_fee', 0.0),
                row.get('stats_1d_winrate', 0.0), row.get('stats_1d_tokens_traded', 0.0),
                # 7. 7-Day Stats (12)
                row.get('stats_7d_realized_profit_sol', 0.0), row.get('stats_7d_realized_profit_pnl', 0.0),
                row.get('stats_7d_buy_count', 0.0), row.get('stats_7d_sell_count', 0.0),
                row.get('stats_7d_transfer_in_count', 0.0), row.get('stats_7d_transfer_out_count', 0.0),
                row.get('stats_7d_avg_holding_period', 0.0), row.get('stats_7d_total_bought_cost_sol', 0.0),
                row.get('stats_7d_total_sold_income_sol', 0.0), row.get('stats_7d_total_fee', 0.0),
                row.get('stats_7d_winrate', 0.0), row.get('stats_7d_tokens_traded', 0.0),
            ]
            # Build directly on the target device to avoid a per-row host copy.
            num_tensor[i] = torch.tensor(num_data, dtype=self.dtype, device=device)

        # Log-normalize all numerical features (age, stats, etc.)
        num_embed = self.profile_num_norm(self._safe_signed_log(num_tensor))

        # The profile representation is purely numerical.
        return self.profile_encoder_mlp(num_embed)

    def _encode_social_batch(self, social_rows, embedding_pool, username_embed_indices, device):
        """Encode social flags + username embedding into [B, d_model]."""
        batch_size = len(social_rows)
        # 4 boolean features; all MUST be present in each row.
        bool_tensor = torch.zeros(batch_size, 4, device=device, dtype=torch.long)
        for i, row in enumerate(social_rows):
            bool_tensor[i, 0] = 1 if row['has_pf_profile'] else 0
            bool_tensor[i, 1] = 1 if row['has_twitter'] else 0
            bool_tensor[i, 2] = 1 if row['has_telegram'] else 0
            bool_tensor[i, 3] = 1 if row['is_exchange_wallet'] else 0

        bool_embeds = self.social_bool_embed(bool_tensor).view(batch_size, -1)  # [B, 64]

        # Look up pre-computed username embeddings, handling an empty pool.
        if embedding_pool.numel() > 0:
            # Prepend a zero row so that missing indices (-1) map to a zero vector.
            pad_row = torch.zeros(1, embedding_pool.size(1), device=device, dtype=embedding_pool.dtype)
            pool_padded = torch.cat([pad_row, embedding_pool], dim=0)
            shifted_idx = torch.where(username_embed_indices >= 0, username_embed_indices + 1, torch.zeros_like(username_embed_indices))
            username_embed = F.embedding(shifted_idx, pool_padded)
        else:
            # No embeddings available: substitute zeros of the expected shape.
            username_embed = torch.zeros(batch_size, self.mmp_dim, device=device, dtype=self.dtype)
        social_fused = torch.cat([bool_embeds, username_embed], dim=1)
        return self.social_encoder_mlp(social_fused)

    def _encode_holdings_batch(self, holdings_batch, token_vibe_lookup, device):
        """Encode each wallet's variable-length holdings into [B, d_model] via the set encoder."""
        batch_size = len(holdings_batch)
        max_len = max(len(h) for h in holdings_batch) if any(holdings_batch) else 1
        seq_embeds = torch.zeros(batch_size, max_len, self.d_model, device=device, dtype=self.dtype)
        # True == padded position (ignored by the set encoder's attention).
        mask = torch.ones(batch_size, max_len, device=device, dtype=torch.bool)
        default_vibe = torch.zeros(self.token_vibe_dim, device=device, dtype=self.dtype)

        for i, holdings in enumerate(holdings_batch):
            if not holdings:
                continue
            h_len = min(len(holdings), max_len)
            holdings = holdings[:h_len]

            # FIX: keep one vibe per row so vibe_tensor and num_tensor stay
            # row-aligned — rows with a missing/unknown mint get the zero vibe.
            vibes = [
                token_vibe_lookup.get(row.get('mint_address'), default_vibe)
                for row in holdings
            ]
            vibe_tensor = torch.stack(vibes)

            num_data_list = []
            for row in holdings:
                # 12 numerical features; .get() with 0.0 default for safety.
                num_data = [
                    row.get('holding_time', 0.0),
                    row.get('balance_pct_to_supply', 0.0),
                    row.get('history_bought_cost_sol', 0.0),
                    # NOTE(review): this key is not in the schema and will
                    # default to 0 — confirm the intended field name.
                    row.get('bought_amount_sol_pct_to_native_balance', 0.0),
                    row.get('history_total_buys', 0.0),
                    row.get('history_total_sells', 0.0),
                    row.get('realized_profit_pnl', 0.0),
                    row.get('realized_profit_sol', 0.0),
                    row.get('history_transfer_in', 0.0),
                    row.get('history_transfer_out', 0.0),
                    row.get('avarage_trade_gap_seconds', 0.0),
                    row.get('total_fees', 0.0)
                ]
                num_data_list.append(num_data)

            num_tensor = torch.tensor(num_data_list, device=device, dtype=self.dtype)

            # Log-normalize all numerical features (holding_time, stats, etc.)
            num_embed = self.holding_num_norm(self._safe_signed_log(num_tensor))

            # Per-row fusion: [h_len, token_vibe_dim + 12] -> [h_len, d_model]
            fused_rows = torch.cat([vibe_tensor, num_embed], dim=1)
            encoded_rows = self.holding_row_encoder_mlp(fused_rows)
            seq_embeds[i, :h_len] = encoded_rows
            mask[i, :h_len] = False

        return self.holdings_set_encoder(seq_embeds, mask)
|
models/wallet_set_encoder.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
class WalletSetEncoder(nn.Module):
    """
    Pools a variable-length set of embeddings into one fixed-size vector.

    A learnable [CLS] token is prepended to the item sequence, the whole
    sequence runs through a Transformer encoder, and the [CLS] output —
    layer-normalized — is returned as the set representation.

    Used to pool:
      1. A wallet's `wallet_holdings` (a set of [holding_embeds]).
      2. A wallet's `Neo4J links` (a set of [link_embeds]).
      3. A wallet's `deployed_tokens` (a set of [token_name_embeds]).
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        dtype: torch.dtype = torch.float16
    ):
        """
        Build the set encoder.

        Args:
            d_model (int): Input/output embedding dimension.
            nhead (int): Number of attention heads.
            num_layers (int): Number of transformer layers.
            dim_feedforward (int): Hidden size of each layer's feedforward net.
            dropout (float): Dropout rate.
            dtype (torch.dtype): Parameter data type.
        """
        super().__init__()
        self.d_model = d_model
        self.dtype = dtype

        # Learnable [CLS] token that aggregates the set representation.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls_token, std=0.02)

        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            layer,
            num_layers=num_layers
        )
        self.output_norm = nn.LayerNorm(d_model)

        self.to(dtype)

    def forward(
        self,
        item_embeds: torch.Tensor,
        src_key_padding_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Pool one batch of item sets.

        Args:
            item_embeds (torch.Tensor): Item embeddings,
                shape [batch_size, seq_len, d_model].
            src_key_padding_mask (torch.Tensor): Boolean mask,
                shape [batch_size, seq_len]; True marks padded positions
                that attention must ignore.

        Returns:
            torch.Tensor: Pooled set embedding, shape [batch_size, d_model].
        """
        n_sets = item_embeds.shape[0]

        # Prepend one [CLS] token per set.
        cls_prefix = self.cls_token.to(self.dtype).expand(n_sets, -1, -1)
        sequence = torch.cat((cls_prefix, item_embeds), dim=1)

        # The [CLS] position is never padded, so its mask entry is False.
        cls_visible = torch.zeros(
            n_sets, 1, device=src_key_padding_mask.device, dtype=torch.bool
        )
        full_mask = torch.cat((cls_visible, src_key_padding_mask), dim=1)

        encoded = self.transformer_encoder(
            sequence,
            src_key_padding_mask=full_mask
        )

        # The set summary is the [CLS] output (position 0), normalized.
        return self.output_norm(encoded[:, 0, :])
|
neo4j.rs
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Nodes

/// Graph node for a token, keyed by its on-chain address
/// (presumably the mint address — confirm against the ingester).
pub struct Token {
    // Private: constructed by graph code within this crate.
    address: String,
}

/// Graph node for a wallet account, keyed by its on-chain address.
pub struct Wallet {
    // Private: constructed by graph code within this crate.
    address: String,
}
|
| 10 |
+
|
| 11 |
+
// Links
|
| 12 |
+
|
| 13 |
+
/// Tracks direct capital flow and identifies funding chains.
pub struct TransferLink {
    /// Transaction signature of the transfer.
    pub signature: String,
    /// Sending wallet address.
    pub source: String,
    /// Receiving wallet address.
    pub destination: String,
    /// Mint of the transferred token.
    pub mint: String,
    /// Transfer time (unix seconds assumed — confirm against ingester).
    pub timestamp: i64,
}
|
| 21 |
+
|
| 22 |
+
/// Identifies wallets trading the same token in the same slot.
pub struct BundleTradeLink {
    /// Signatures of the trades that formed the bundle.
    pub signatures: Vec<String>,
    /// The two wallets observed trading together.
    pub wallet_a: String,
    pub wallet_b: String,
    /// Mint of the token traded.
    pub mint: String,
    /// Slot in which both wallets traded.
    pub slot: i64,
    /// Time of the bundle (unix seconds assumed — confirm against ingester).
    pub timestamp: i64,
}
|
| 31 |
+
|
| 32 |
+
/// Reveals a behavioral pattern of one wallet mirroring another's successful trade.
pub struct CopiedTradeLink {
    /// Leader's entry and exit transaction signatures.
    pub leader_buy_sig: String,
    pub leader_sell_sig: String,
    /// Follower's entry and exit transaction signatures.
    pub follower_buy_sig: String,
    pub follower_sell_sig: String,
    /// Wallet addresses of the two parties.
    pub follower: String,
    pub leader: String,
    /// Mint of the copied token.
    pub mint: String,
    /// Delay between leader and follower on the buy / sell side, in seconds.
    pub time_gap_on_buy_sec: i64,
    pub time_gap_on_sell_sec: i64,
    /// Realized PnL of each party (currency/denomination not defined here — confirm).
    pub leader_pnl: f64,
    pub follower_pnl: f64,

    /// Leader's total spent on entry / received on exit.
    pub leader_buy_total: f64,
    pub leader_sell_total: f64,

    /// Follower's total spent on entry / received on exit.
    pub follower_buy_total: f64,
    pub follower_sell_total: f64,
    /// Slippage observed on the follower's entry and exit.
    pub follower_buy_slippage: f32,
    pub follower_sell_slippage: f32,
}
|
| 54 |
+
|
| 55 |
+
/// Represents a link where a group of wallets re-engage with a token in a coordinated manner.
pub struct CoordinatedActivityLink {
    /// Leader's first and second transaction signatures.
    pub leader_first_sig: String,
    pub leader_second_sig: String,
    /// Follower's first and second transaction signatures.
    pub follower_first_sig: String,
    pub follower_second_sig: String,
    /// Wallet addresses of the two parties.
    pub follower: String,
    pub leader: String,
    /// Mint of the token both wallets acted on.
    pub mint: String,
    /// Leader→follower delay on the first / second action, in seconds.
    pub time_gap_on_first_sec: i64,
    pub time_gap_on_second_sec: i64,
}
|
| 67 |
+
|
| 68 |
+
/// Links a token to its original creator.
pub struct MintedLink {
    /// Signature of the mint transaction.
    pub signature: String,
    /// Mint time (unix seconds assumed — confirm against ingester).
    pub timestamp: i64,
    /// Amount the creator bought at mint (dev-buy).
    pub buy_amount: f64,
}
|
| 74 |
+
|
| 75 |
+
/// Connects a token to its successful first-movers.
pub struct SnipedLink {
    /// Signature of the sniping buy.
    pub signature: String,
    /// Position in the buy order (1 = first buyer, presumably — confirm).
    pub rank: i64,
    /// Amount acquired by the sniper.
    pub sniped_amount: f64,
}
|
| 81 |
+
|
| 82 |
+
/// Connects a wallet to a token whose supply it locked.
pub struct LockedSupplyLink {
    /// Signature of the locking transaction.
    pub signature: String,
    /// Amount of tokens locked.
    pub amount: f64,
    /// When the lock expires (unix seconds assumed; note: u64 here while other
    /// timestamps in this file are i64 — confirm this asymmetry is intended).
    pub unlock_timestamp: u64,
}
|
| 88 |
+
|
| 89 |
+
/// Links a wallet to a token it burned.
pub struct BurnedLink {
    /// Signature of the burn transaction.
    pub signature: String,
    /// Amount of tokens burned.
    pub amount: f64,
    /// Burn time (unix seconds assumed — confirm against ingester).
    pub timestamp: i64,
}
|
| 95 |
+
|
| 96 |
+
/// Identifies wallets that provided liquidity, signaling high conviction.
pub struct ProvidedLiquidityLink {
    /// Signature of the LP-add transaction.
    pub signature: String,
    /// LP provider's wallet address.
    pub wallet: String,
    /// Token the liquidity was provided for.
    pub token: String,
    /// Pool that received the liquidity.
    pub pool_address: String,
    /// Amounts added on the base / quote side of the pool.
    pub amount_base: f64,
    pub amount_quote: f64,
    /// Provision time (unix seconds assumed — confirm against ingester).
    pub timestamp: i64,
}
|
| 106 |
+
|
| 107 |
+
/// A derived link connecting a token to its largest holders.
/// The `*_at_creation` fields are frozen at link-creation time and are not updated afterwards.
pub struct WhaleOfLink {
    /// Holder's wallet address.
    pub wallet: String,
    /// Token being held.
    pub token: String,
    pub holding_pct_at_creation: f32, // Holding % when the link was made
    pub ath_usd_at_creation: f64, // Token's ATH when the link was made
}
|
| 114 |
+
|
| 115 |
+
/// A derived link connecting a token to its most profitable traders.
/// The `*_at_creation` fields are frozen at link-creation time and are not updated afterwards.
pub struct TopTraderOfLink {
    /// Trader's wallet address.
    pub wallet: String,
    /// Token that was traded.
    pub token: String,
    pub pnl_at_creation: f64, // The PNL that first triggered the link
    pub ath_usd_at_creation: f64, // Token's ATH when the link was made
}
|
ohlc_stats.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6f56037cf2ad8502213ee2c8470c314eef83a4cd93063290581ef45fadea5d48
|
| 3 |
+
size 1660
|
onchain.sql
ADDED
|
@@ -0,0 +1,599 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-- Raw swap events, one row per executed trade leg.
CREATE TABLE trades
(
    timestamp DateTime('UTC'),
    signature String,

    slot UInt64,
    transaction_index UInt32,
    instruction_index UInt16,
    success Boolean,
    error Nullable(String),

    -- Fee Structure
    priority_fee Float64,
    bribe_fee Float64,
    coin_creator_fee Float64,
    mev_protection UInt8,

    -- Parties
    maker String,

    -- Balances (Pre & Post)
    base_balance Float64,

    quote_balance Float64,

    -- Trade Semantics
    -- NOTE(review): trade_type/protocol/platform code meanings are not defined
    -- in this file — confirm against the ingester's enums.
    trade_type UInt8,
    protocol UInt8,
    platform UInt8,

    -- Asset Info
    pool_address String,
    base_address String,
    quote_address String,

    -- Trade Details
    slippage Float32,
    price_impact Float32,

    base_amount UInt64,
    quote_amount UInt64,

    price Float64,
    price_usd Float64,

    total Float64,
    total_usd Float64

)
ENGINE = MergeTree()
ORDER BY (base_address, timestamp, maker, signature);
|
| 54 |
+
|
| 55 |
+
--- mint
-- Token mint (launch) events.
CREATE TABLE mints
(
    -- === Transaction Details ===
    -- Solana signature is usually 88 characters, but we use String for flexibility.
    signature String,
    -- Converted to DateTime for easier time-based operations in ClickHouse.
    timestamp DateTime('UTC'),
    slot UInt64,
    success Boolean,
    error Nullable(String),
    priority_fee Float64,

    -- === Protocol & Platform ===
    -- Protocol codes: 0=Unknown, 1=PumpFunLaunchpad, 2=RaydiumLaunchpad,
    -- 3=PumpFunAMM, 4=RaydiumCPMM, 5=MeteoraBonding
    protocol UInt8,

    -- === Mint & Pool Details ===
    mint_address String,
    creator_address String,
    pool_address String,

    -- === Liquidity Details ===
    initial_base_liquidity UInt64,
    initial_quote_liquidity UInt64,

    -- === Token Metadata ===
    token_name Nullable(String),
    token_symbol Nullable(String),
    token_uri Nullable(String),
    token_decimals UInt8,
    total_supply UInt64,

    is_mutable Boolean,
    update_authority Nullable(String),
    mint_authority Nullable(String),
    -- Fix: removed the trailing comma that followed this column — it made
    -- the CREATE TABLE statement unparsable.
    freeze_authority Nullable(String)
)
ENGINE = MergeTree()
ORDER BY (timestamp, creator_address, mint_address);
|
| 97 |
+
|
| 98 |
+
-- Pool migration events (presumably launchpad virtual pool -> AMM pool — confirm).
CREATE TABLE migrations
(
    -- Transaction Details
    timestamp DateTime('UTC'),

    signature String,
    slot UInt64,
    success Boolean,
    error Nullable(String),
    priority_fee Float64,

    -- Protocol & Platform
    protocol UInt8,

    -- Migration Details
    mint_address String,
    virtual_pool_address String,
    pool_address String,

    -- Liquidity Details
    migrated_base_liquidity Nullable(UInt64),
    migrated_quote_liquidity Nullable(UInt64)
)
ENGINE = MergeTree()
ORDER BY (mint_address, virtual_pool_address, pool_address, timestamp);
|
| 123 |
+
|
| 124 |
+
-- Fee collection events: amounts paid out from a vault to a recipient.
CREATE TABLE fee_collections
(
    -- Transaction Details
    timestamp DateTime('UTC'),

    signature String,
    slot UInt64,
    success Boolean,
    error Nullable(String),
    priority_fee Float64,

    -- Protocol & Platform
    protocol UInt8,

    -- Fee Details
    vault_address String,
    recipient_address String,

    -- Collected Amounts (second leg is Nullable: may be absent)
    token_0_mint_address String,
    token_0_amount Float64,
    token_1_mint_address Nullable(String),
    token_1_amount Nullable(Float64)
)
ENGINE = MergeTree()
ORDER BY (vault_address, recipient_address, timestamp);
|
| 150 |
+
|
| 151 |
+
-- Liquidity-provision events (adds/removes) per pool.
CREATE TABLE liquidity
(
    -- Transaction Details --
    signature String,
    timestamp DateTime('UTC'),
    slot UInt64,
    success Boolean,
    error Nullable(String),
    priority_fee Float64,

    -- Protocol Info --
    protocol UInt8,

    -- LP Action Details --
    -- NOTE(review): change_type code meanings are not defined in this file —
    -- confirm against the ingester's enums.
    change_type UInt8,
    lp_provider String,
    pool_address String,

    -- Token Amounts --
    base_amount UInt64,
    quote_amount UInt64
)
ENGINE = MergeTree()
ORDER BY (timestamp, pool_address, lp_provider);
|
| 175 |
+
|
| 176 |
+
-- Pool creation events.
CREATE TABLE pool_creations (
    -- Transaction Details --
    signature String,
    -- Fix: was `Datetime('UTC')` — normalized casing to `DateTime` for
    -- consistency with every other table in this schema.
    timestamp DateTime('UTC'),
    slot UInt64,
    success Boolean,
    error Nullable(String),
    priority_fee Float64,

    -- Protocol Info --
    protocol UInt8,

    -- Pool & Token Details --
    creator_address String,
    pool_address String,
    base_address String,
    quote_address String,
    lp_token_address String,

    -- Optional Initial State --
    initial_base_liquidity Nullable(UInt64),
    initial_quote_liquidity Nullable(UInt64),
    base_decimals Nullable(UInt8),
    quote_decimals Nullable(UInt8)
)
ENGINE = MergeTree()
ORDER BY (base_address, creator_address);
|
| 203 |
+
|
| 204 |
+
-- SPL token transfer events between wallets.
CREATE TABLE transfers
(
    -- Transaction Details
    timestamp DateTime('UTC'),
    signature String,
    slot UInt64,
    success Boolean,
    error Nullable(String),
    priority_fee Float64,

    -- Transfer Details
    source String,
    destination String,

    -- Amount & Mint Details
    mint_address String,
    amount UInt64,
    amount_decimal Float64,

    -- Balance Context (post-transfer balances assumed — confirm against ingester)
    source_balance Float64,
    destination_balance Float64
)
ENGINE = MergeTree()
ORDER BY (source, destination, mint_address, timestamp);
|
| 229 |
+
|
| 230 |
+
-- Supply lock / vesting contract creations.
CREATE TABLE supply_locks
(
    -- === Transaction Details ===
    timestamp DateTime('UTC'),

    signature String,
    slot UInt64,
    success Boolean,
    error Nullable(String),
    priority_fee Float64,

    -- === Protocol Info ===
    protocol UInt8,

    -- === Vesting Details ===
    contract_address String,
    sender String,
    recipient String,
    mint_address String,
    total_locked_amount Float64,
    final_unlock_timestamp UInt64  -- unix timestamp (seconds assumed — confirm)
)
ENGINE = MergeTree()
ORDER BY (timestamp, mint_address, sender, recipient);
|
| 254 |
+
|
| 255 |
+
-- Actions against an existing supply-lock contract (withdraw / top-up).
CREATE TABLE supply_lock_actions
(
    -- === Transaction Details ===

    signature String,
    timestamp DateTime('UTC'),
    slot UInt64,
    success Boolean,
    error Nullable(String),
    priority_fee Float64,

    -- === Protocol Info ===
    protocol UInt8,

    -- === Action Details ===
    action_type UInt8, -- e.g., 0 for Withdraw, 1 for Topup
    contract_address String,
    user String,
    mint_address String,
    amount Float64
)
ENGINE = MergeTree()
ORDER BY (timestamp, mint_address, user);
|
| 278 |
+
|
| 279 |
+
-- Token burn events.
CREATE TABLE burns
(
    -- Transaction Details
    timestamp DateTime('UTC'),
    signature String,
    slot UInt64,
    success Boolean,
    error Nullable(String),
    priority_fee Float64,

    -- Burn Details
    mint_address String,
    source String,
    amount UInt64,
    amount_decimal Float64,

    -- Source's balance (post-burn assumed — confirm against ingester)
    source_balance Float64
)
ENGINE = MergeTree()
ORDER BY (mint_address, source, timestamp);
|
| 299 |
+
|
| 300 |
+
-------- Wallet schema

-- One logical row per wallet; ReplacingMergeTree keeps the row with the
-- greatest updated_at per wallet_address (after merges / FINAL).
CREATE TABLE wallet_profiles
(
    updated_at DateTime('UTC'),
    first_seen_ts DateTime('UTC'),
    last_seen_ts DateTime('UTC'),

    wallet_address String,
    tags Array(String),
    deployed_tokens Array(String),

    funded_from String,
    -- NOTE(review): UInt32 unix seconds here, while other timestamps use
    -- DateTime — confirm the asymmetry is intended.
    funded_timestamp UInt32,
    funded_signature String,
    funded_amount Float64
)
ENGINE = ReplacingMergeTree(updated_at)
PRIMARY KEY (wallet_address)
ORDER BY (wallet_address);
|
| 320 |
+
|
| 321 |
+
-- Append-only history of wallet metrics snapshots (one row per update);
-- see wallet_profile_metrics_latest for the deduplicated read-side view.
CREATE TABLE wallet_profile_metrics
(
    updated_at DateTime('UTC'),
    wallet_address String,
    balance Float64,

    -- Lifetime transfer counters (native vs SPL)
    transfers_in_count UInt32,
    transfers_out_count UInt32,
    spl_transfers_in_count UInt32,
    spl_transfers_out_count UInt32,

    -- Lifetime trading counters
    total_buys_count UInt32,
    total_sells_count UInt32,
    total_winrate Float32,

    -- Rolling 1-day window stats
    stats_1d_realized_profit_sol Float64,
    stats_1d_realized_profit_usd Float64,
    stats_1d_realized_profit_pnl Float32,
    stats_1d_buy_count UInt32,
    stats_1d_sell_count UInt32,
    stats_1d_transfer_in_count UInt32,
    stats_1d_transfer_out_count UInt32,
    stats_1d_avg_holding_period Float32,
    stats_1d_total_bought_cost_sol Float64,
    stats_1d_total_bought_cost_usd Float64,
    stats_1d_total_sold_income_sol Float64,
    stats_1d_total_sold_income_usd Float64,
    stats_1d_total_fee Float64,
    stats_1d_winrate Float32,
    stats_1d_tokens_traded UInt32,

    -- Rolling 7-day window stats
    stats_7d_realized_profit_sol Float64,
    stats_7d_realized_profit_usd Float64,
    stats_7d_realized_profit_pnl Float32,
    stats_7d_buy_count UInt32,
    stats_7d_sell_count UInt32,
    stats_7d_transfer_in_count UInt32,
    stats_7d_transfer_out_count UInt32,
    stats_7d_avg_holding_period Float32,
    stats_7d_total_bought_cost_sol Float64,
    stats_7d_total_bought_cost_usd Float64,
    stats_7d_total_sold_income_sol Float64,
    stats_7d_total_sold_income_usd Float64,
    stats_7d_total_fee Float64,
    stats_7d_winrate Float32,
    stats_7d_tokens_traded UInt32,

    -- Rolling 30-day window stats
    stats_30d_realized_profit_sol Float64,
    stats_30d_realized_profit_usd Float64,
    stats_30d_realized_profit_pnl Float32,
    stats_30d_buy_count UInt32,
    stats_30d_sell_count UInt32,
    stats_30d_transfer_in_count UInt32,
    stats_30d_transfer_out_count UInt32,
    stats_30d_avg_holding_period Float32,
    stats_30d_total_bought_cost_sol Float64,
    stats_30d_total_bought_cost_usd Float64,
    stats_30d_total_sold_income_sol Float64,
    stats_30d_total_sold_income_usd Float64,
    stats_30d_total_fee Float64,
    stats_30d_winrate Float32,
    stats_30d_tokens_traded UInt32
)
ENGINE = MergeTree
ORDER BY (wallet_address, updated_at);
|
| 386 |
+
|
| 387 |
+
-- Append-only history of per-wallet, per-token holding snapshots;
-- see wallet_holdings_latest for the deduplicated read-side view.
CREATE TABLE wallet_holdings
(
    updated_at DateTime('UTC'),
    start_holding_at DateTime('UTC'),

    wallet_address String,
    mint_address String,
    current_balance Float64,

    realized_profit_pnl Float32,
    realized_profit_sol Float64,
    realized_profit_usd Float64,

    history_transfer_in UInt32,
    history_transfer_out UInt32,

    history_bought_amount Float64,
    history_bought_cost_sol Float64,
    history_sold_amount Float64,
    history_sold_income_sol Float64
)
ENGINE = MergeTree
ORDER BY (wallet_address, mint_address, updated_at);
|
| 410 |
+
|
| 411 |
+
-- Token registry; ReplacingMergeTree(updated_at) keeps the newest version
-- per (token_address, updated_at) sort key.
CREATE TABLE tokens (
    updated_at DateTime('UTC'),
    created_at DateTime('UTC'),

    -- Core Identifiers
    token_address String,
    name String,
    symbol String,
    token_uri String,

    -- Token Metadata
    decimals UInt8,
    creator_address String,
    pool_addresses Array(String), -- Map Vec<String> to Array(String)

    -- Protocol/Launchpad
    launchpad UInt8,
    protocol UInt8,
    total_supply UInt64,

    -- Authorities/Flags
    is_mutable Boolean, -- Alias for UInt8, but Boolean is clearer/modern
    update_authority Nullable(String), -- Map Option<String> to Nullable(String)
    mint_authority Nullable(String),
    freeze_authority Nullable(String)
)
ENGINE = ReplacingMergeTree(updated_at)
PRIMARY KEY (token_address)
ORDER BY (token_address, updated_at);
|
| 440 |
+
|
| 441 |
+
-- Latest tokens (one row per token_address)
-- Same columns as `tokens`, but deduplicated on token_address alone for fast point reads.
CREATE TABLE tokens_latest
(
    updated_at DateTime('UTC'),
    created_at DateTime('UTC'),

    token_address String,
    name String,
    symbol String,
    token_uri String,

    decimals UInt8,
    creator_address String,
    pool_addresses Array(String),

    launchpad UInt8,
    protocol UInt8,
    total_supply UInt64,

    is_mutable Boolean,
    update_authority Nullable(String),
    mint_authority Nullable(String),
    freeze_authority Nullable(String)
)
ENGINE = ReplacingMergeTree(updated_at)
ORDER BY (token_address);
|
| 467 |
+
|
| 468 |
+
-- Append-only history of token-level aggregates (one row per update);
-- see token_metrics_latest for the deduplicated read-side view.
CREATE TABLE token_metrics (
    updated_at DateTime('UTC'),
    token_address String,
    total_volume_usd Float64,
    total_buys UInt32,
    total_sells UInt32,
    unique_holders UInt32,
    ath_price_usd Float64
)
ENGINE = MergeTree
ORDER BY (token_address, updated_at);
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
-- ========= Latest snapshot helper tables =========
-- Keep full history in the base tables above, but read fast from these ReplacingMergeTree snapshots.

-- Latest wallet profile metrics (one row per wallet_address)
CREATE TABLE wallet_profile_metrics_latest
(
    updated_at DateTime('UTC'),
    wallet_address String,
    balance Float64,

    -- Lifetime transfer counters (native vs SPL)
    transfers_in_count UInt32,
    transfers_out_count UInt32,
    spl_transfers_in_count UInt32,
    spl_transfers_out_count UInt32,

    -- Lifetime trading counters
    total_buys_count UInt32,
    total_sells_count UInt32,
    total_winrate Float32,

    -- Rolling 1-day window stats
    stats_1d_realized_profit_sol Float64,
    stats_1d_realized_profit_usd Float64,
    stats_1d_realized_profit_pnl Float32,
    stats_1d_buy_count UInt32,
    stats_1d_sell_count UInt32,
    stats_1d_transfer_in_count UInt32,
    stats_1d_transfer_out_count UInt32,
    stats_1d_avg_holding_period Float32,
    stats_1d_total_bought_cost_sol Float64,
    stats_1d_total_bought_cost_usd Float64,
    stats_1d_total_sold_income_sol Float64,
    stats_1d_total_sold_income_usd Float64,
    stats_1d_total_fee Float64,
    stats_1d_winrate Float32,
    stats_1d_tokens_traded UInt32,

    -- Rolling 7-day window stats
    stats_7d_realized_profit_sol Float64,
    stats_7d_realized_profit_usd Float64,
    stats_7d_realized_profit_pnl Float32,
    stats_7d_buy_count UInt32,
    stats_7d_sell_count UInt32,
    stats_7d_transfer_in_count UInt32,
    stats_7d_transfer_out_count UInt32,
    stats_7d_avg_holding_period Float32,
    stats_7d_total_bought_cost_sol Float64,
    stats_7d_total_bought_cost_usd Float64,
    stats_7d_total_sold_income_sol Float64,
    stats_7d_total_sold_income_usd Float64,
    stats_7d_total_fee Float64,
    stats_7d_winrate Float32,
    stats_7d_tokens_traded UInt32,

    -- Rolling 30-day window stats
    stats_30d_realized_profit_sol Float64,
    stats_30d_realized_profit_usd Float64,
    stats_30d_realized_profit_pnl Float32,
    stats_30d_buy_count UInt32,
    stats_30d_sell_count UInt32,
    stats_30d_transfer_in_count UInt32,
    stats_30d_transfer_out_count UInt32,
    stats_30d_avg_holding_period Float32,
    stats_30d_total_bought_cost_sol Float64,
    stats_30d_total_bought_cost_usd Float64,
    stats_30d_total_sold_income_sol Float64,
    stats_30d_total_sold_income_usd Float64,
    stats_30d_total_fee Float64,
    stats_30d_winrate Float32,
    stats_30d_tokens_traded UInt32
)
ENGINE = ReplacingMergeTree(updated_at)
ORDER BY (wallet_address);
|
| 551 |
+
|
| 552 |
+
-- Latest wallet holdings (one row per wallet_address + mint_address)
CREATE TABLE wallet_holdings_latest
(
    updated_at DateTime('UTC'),
    start_holding_at DateTime('UTC'),

    wallet_address String,
    mint_address String,
    current_balance Float64,

    realized_profit_pnl Float32,
    realized_profit_sol Float64,
    realized_profit_usd Float64,

    history_transfer_in UInt32,
    history_transfer_out UInt32,

    history_bought_amount Float64,
    history_bought_cost_sol Float64,
    history_sold_amount Float64,
    history_sold_income_sol Float64
)
ENGINE = ReplacingMergeTree(updated_at)
ORDER BY (wallet_address, mint_address);
|
| 576 |
+
|
| 577 |
+
-- Latest token metrics (one row per token_address)
CREATE TABLE token_metrics_latest
(
    updated_at DateTime('UTC'),
    token_address String,
    total_volume_usd Float64,
    total_buys UInt32,
    total_sells UInt32,
    unique_holders UInt32,
    ath_price_usd Float64
)
ENGINE = ReplacingMergeTree(updated_at)
ORDER BY (token_address);
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
-- Manually curated labels for well-known addresses (fee vaults, DEX authorities, KOLs, exchanges).
CREATE TABLE known_wallets
(
    `wallet_address` String,
    `name` String, -- e.g., "Pump.fun Fee Vault", "Raydium CPMM Authority V4", "KOL - Ansem"
    -- Fix: removed the trailing comma that followed this column — it made
    -- the CREATE TABLE statement unparsable.
    `tag` String -- e.g., "fee_vault", "dex_authority", "kol", "exchange"
)
ENGINE = ReplacingMergeTree()
ORDER BY (wallet_address);
|
pre_cache.sh
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Pre-generate and cache dataset samples via scripts/cache_dataset.py.
#
# Fix: the previous invocation passed flags that cache_dataset.py does not
# define (--offset-utc, --out-dir, --clickhouse-host/--clickhouse-port,
# --neo4j-uri), so argparse rejected it. The script's actual flags are
# --output-dir / --max-samples / --start-date, and database endpoints are
# read from environment variables instead.
set -euo pipefail

CLICKHOUSE_HOST="${CLICKHOUSE_HOST:-localhost}" \
CLICKHOUSE_PORT="${CLICKHOUSE_PORT:-8123}" \
NEO4J_URI="${NEO4J_URI:-bolt://localhost:7687}" \
python scripts/cache_dataset.py \
    --start-date 2024-01-01 \
    --max-samples 100 \
    --output-dir data/cache/epoch_851
|
scripts/cache_dataset.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Script to pre-generate and cache dataset items from the OracleDataset.
|
| 4 |
+
|
| 5 |
+
This script connects to the databases, instantiates the data loader in 'online' mode,
|
| 6 |
+
and iterates through the requested number of samples, saving each processed item
|
| 7 |
+
to a file. This avoids costly data fetching and processing during training.
|
| 8 |
+
|
| 9 |
+
Example usage:
|
| 10 |
+
python scripts/cache_dataset.py --output-dir ./data/cached_dataset --max-samples 1000 --start-date 2024-05-01
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import datetime
|
| 15 |
+
import os
|
| 16 |
+
import sys
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import clickhouse_connect
|
| 21 |
+
from neo4j import GraphDatabase
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
|
| 24 |
+
# Add apollo to path to import modules
|
| 25 |
+
sys.path.append(str(Path(__file__).resolve().parents[1]))
|
| 26 |
+
|
| 27 |
+
from data.data_loader import OracleDataset
|
| 28 |
+
from data.data_fetcher import DataFetcher
|
| 29 |
+
|
| 30 |
+
# --- Database Connection Details (can be overridden by env vars) ---
|
| 31 |
+
CLICKHOUSE_HOST = os.getenv("CLICKHOUSE_HOST", "localhost")
|
| 32 |
+
CLICKHOUSE_PORT = int(os.getenv("CLICKHOUSE_PORT", "8123"))
|
| 33 |
+
CLICKHOUSE_USER = os.getenv("CLICKHOUSE_USER", "default")
|
| 34 |
+
CLICKHOUSE_PASSWORD = os.getenv("CLICKHOUSE_PASSWORD", "")
|
| 35 |
+
CLICKHOUSE_DATABASE = os.getenv("CLICKHOUSE_DATABASE", "default")
|
| 36 |
+
|
| 37 |
+
NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
|
| 38 |
+
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
|
| 39 |
+
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "password")
|
| 40 |
+
|
| 41 |
+
def parse_args():
    """Parse command-line options controlling cache output and dataset construction."""
    parser = argparse.ArgumentParser(description="Cache OracleDataset items to disk.")
    # One (flag, kwargs) spec per option keeps the declarations compact.
    option_specs = [
        ("--output-dir", dict(type=str, required=True,
                              help="Directory to save the cached .pt files.")),
        ("--max-samples", dict(type=int, default=None,
                               help="Maximum number of samples to generate and cache. Defaults to all available.")),
        ("--start-date", dict(type=str, default=None,
                              help="Start date for fetching mints in YYYY-MM-DD format. Fetches all mints on or after this UTC date.")),
        ("--t-cutoff-seconds", dict(type=int, default=60,
                                    help="Time in seconds after mint to set the data cutoff (T_cutoff).")),
        ("--ohlc-stats-path", dict(type=str, default="./data/ohlc_stats.npz",
                                   help="Path to the OHLC stats file for normalization.")),
        ("--min-trade-usd", dict(type=float, default=5.0,
                                 help="Minimum USD value for a trade to be included in the event sequence. Defaults to 5.0.")),
    ]
    for flag, kwargs in option_specs:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
|
| 80 |
+
|
| 81 |
+
def main():
    """Connect to ClickHouse/Neo4j, build an OracleDataset, and cache every sample as a .pt file.

    Fix vs. original: the database connections are now closed in a ``finally``
    block, so they are released even when dataset construction or the caching
    loop raises, or when the function returns early on an empty dataset.
    """
    args = parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"INFO: Caching dataset to {output_dir.resolve()}")

    # Optional UTC start date used to filter which mints are sampled.
    start_date_dt = None
    if args.start_date:
        try:
            start_date_dt = datetime.datetime.strptime(args.start_date, "%Y-%m-%d").replace(tzinfo=datetime.timezone.utc)
            print(f"INFO: Filtering mints on or after {start_date_dt}")
        except ValueError:
            print(f"ERROR: Invalid start-date format. Please use YYYY-MM-DD.", file=sys.stderr)
            sys.exit(1)

    # --- 1. Set up database connections ---
    try:
        print("INFO: Connecting to ClickHouse...")
        clickhouse_client = clickhouse_connect.get_client(host=CLICKHOUSE_HOST, port=CLICKHOUSE_PORT, user=CLICKHOUSE_USER, password=CLICKHOUSE_PASSWORD, database=CLICKHOUSE_DATABASE)
        print("INFO: Connecting to Neo4j...")
        neo4j_driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
    except Exception as e:
        print(f"ERROR: Failed to connect to databases: {e}", file=sys.stderr)
        sys.exit(1)

    try:
        # --- 2. Initialize DataFetcher and OracleDataset ---
        data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)

        dataset = OracleDataset(
            data_fetcher=data_fetcher,
            max_samples=args.max_samples,
            start_date=start_date_dt,
            t_cutoff_seconds=args.t_cutoff_seconds,
            ohlc_stats_path=args.ohlc_stats_path,
            horizons_seconds=[60, 300, 900, 1800, 3600],
            quantiles=[0.5],
            min_trade_usd=args.min_trade_usd
        )

        if len(dataset) == 0:
            print("WARNING: Dataset initialization resulted in 0 samples. Nothing to cache.")
            return

        # --- 3. Iterate and cache each item ---
        print(f"INFO: Starting to generate and cache {len(dataset)} samples...")
        skipped_count = 0
        for i in tqdm(range(len(dataset)), desc="Caching samples"):
            try:
                item = dataset.__cacheitem__(i)
                if item is None:
                    # The loader signals an unusable sample with None; count and move on.
                    skipped_count += 1
                    continue
                output_path = output_dir / f"sample_{i}.pt"
                torch.save(item, output_path)
            except Exception as e:
                # Best-effort: a single bad sample must not abort the whole run.
                print(f"\nERROR: Failed to generate or save sample {i} for mint '{dataset.sampled_mints[i]['mint_address']}'. Error: {e}", file=sys.stderr)
                skipped_count += 1
                continue

        print(f"\n--- Caching Complete ---\nSuccessfully cached: {len(dataset) - skipped_count} items.\nSkipped: {skipped_count} items.\nCache location: {output_dir.resolve()}")
    finally:
        # --- 4. Close connections (always, even on error / early return) ---
        clickhouse_client.close()
        neo4j_driver.close()


if __name__ == "__main__":
    main()
|
scripts/download_epoch_artifacts.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Download a specific epoch's parquet/Neo4j artifacts from Hugging Face.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
HF_TOKEN=your_token \
|
| 7 |
+
python scripts/download_epoch_artifacts.py --epoch 851
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import os
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import List
|
| 14 |
+
|
| 15 |
+
from huggingface_hub import snapshot_download
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
REPO_ID = "zirobtc/pump-fun-dataset"
|
| 19 |
+
REPO_TYPE = "model" # dataset is not used here per user note
|
| 20 |
+
DEFAULT_DEST_DIR = "./data/pump_fun"
|
| 21 |
+
|
| 22 |
+
# File stems that are suffixed with `_epoch_{epoch}.parquet`
|
| 23 |
+
PARQUET_STEMS = [
|
| 24 |
+
"wallet_profiles",
|
| 25 |
+
"wallet_holdings",
|
| 26 |
+
"trades",
|
| 27 |
+
"transfers",
|
| 28 |
+
"burns",
|
| 29 |
+
"tokens",
|
| 30 |
+
"mints",
|
| 31 |
+
"liquidity",
|
| 32 |
+
"pool_creations",
|
| 33 |
+
"token_metrics",
|
| 34 |
+
"wallet_profile_metrics",
|
| 35 |
+
"migrations",
|
| 36 |
+
"fee_collections",
|
| 37 |
+
"supply_locks",
|
| 38 |
+
"supply_lock_actions",
|
| 39 |
+
"known_wallets",
|
| 40 |
+
]
|
| 41 |
+
|
| 42 |
+
# Single Neo4j dump name
|
| 43 |
+
NEO4J_FILENAME = "neo4j_epoch_{epoch}.dump"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def build_patterns(epoch: int) -> List[str]:
    """Return the artifact filenames for *epoch*: one parquet per stem plus the Neo4j dump."""
    tag = str(epoch)
    patterns = [f"{stem}_epoch_{tag}.parquet" for stem in PARQUET_STEMS]
    patterns.append(NEO4J_FILENAME.format(epoch=tag))
    return patterns
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def parse_args() -> argparse.Namespace:
    """Parse CLI options: the epoch to fetch and an optional Hugging Face token."""
    parser = argparse.ArgumentParser(description="Download epoch artifacts from Hugging Face.")
    parser.add_argument(
        "--epoch",
        type=int,
        default=851,
        required=False,
        help="Epoch number to download (e.g., 851)",
    )
    parser.add_argument(
        "--token",
        type=str,
        default=None,
        required=False,
        help="Hugging Face token (or set HF_TOKEN env var)",
    )
    return parser.parse_args()
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def main() -> None:
    """Resolve the epoch and token from the CLI/env, then download the matching artifacts."""
    args = parse_args()
    # CLI --token takes precedence; otherwise fall back to the HF_TOKEN env var.
    token = args.token or os.environ.get("HF_TOKEN")


    patterns = build_patterns(args.epoch)
    # Artifacts land under <DEFAULT_DEST_DIR>/epoch_<N>/.
    dest_root = Path(DEFAULT_DEST_DIR).expanduser()
    dest_dir = dest_root / f"epoch_{args.epoch}"
    dest_dir.mkdir(parents=True, exist_ok=True)

    print(f"Downloading epoch {args.epoch} files from {REPO_ID} to {dest_dir}")
    print("Files:")
    for p in patterns:
        print(f"  - {p}")

    # Only the allow-listed epoch files are pulled from the repo snapshot;
    # real files are written (no symlinks) and interrupted downloads resume.
    snapshot_download(
        repo_id=REPO_ID,
        repo_type=REPO_TYPE,
        local_dir=str(dest_dir),
        local_dir_use_symlinks=False,
        allow_patterns=patterns,
        resume_download=True,
        token=token,
    )

    print("Download complete.")


if __name__ == "__main__":
    main()
|
scripts/ingest_epoch.py
ADDED
|
@@ -0,0 +1,713 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
ETL Pipeline: Download epoch Parquet files, ingest into ClickHouse, and delete local files.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python scripts/ingest_epoch.py --epoch 851
|
| 7 |
+
|
| 8 |
+
Environment Variables:
|
| 9 |
+
HF_TOKEN: Hugging Face token for downloading private datasets.
|
| 10 |
+
CLICKHOUSE_HOST, CLICKHOUSE_PORT, CLICKHOUSE_USER, CLICKHOUSE_PASSWORD, CLICKHOUSE_DATABASE
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import os
|
| 15 |
+
import sys
|
| 16 |
+
import time
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
import clickhouse_connect
|
| 20 |
+
from huggingface_hub import snapshot_download
|
| 21 |
+
from tqdm import tqdm
|
| 22 |
+
|
| 23 |
+
# Hugging Face config
|
| 24 |
+
REPO_ID = "zirobtc/pump-fun-dataset"
|
| 25 |
+
REPO_TYPE = "model"
|
| 26 |
+
DEFAULT_DEST_DIR = "./data/pump_fun"
|
| 27 |
+
CLICKHOUSE_DOCKER_CONTAINER = "db-clickhouse"
|
| 28 |
+
CLICKHOUSE_INSERT_SETTINGS = "max_insert_threads=1,max_block_size=65536"
|
| 29 |
+
NEO4J_DOCKER_CONTAINER = "neo4j"
|
| 30 |
+
NEO4J_TARGET_DB = "neo4j"
|
| 31 |
+
NEO4J_TEMP_DB_PREFIX = "epoch"
|
| 32 |
+
NEO4J_MERGE_BATCH_SIZE = 2000
|
| 33 |
+
NEO4J_URI = "bolt://localhost:7687"
|
| 34 |
+
NEO4J_USER = None
|
| 35 |
+
NEO4J_PASSWORD = None
|
| 36 |
+
|
| 37 |
+
# Parquet file stems -> ClickHouse table names
|
| 38 |
+
# Maps the file stem to the target table. Usually they match.
|
| 39 |
+
PARQUET_TABLE_MAP = {
|
| 40 |
+
"wallet_profiles": "wallet_profiles",
|
| 41 |
+
"wallet_holdings": "wallet_holdings",
|
| 42 |
+
"trades": "trades",
|
| 43 |
+
"transfers": "transfers",
|
| 44 |
+
"burns": "burns",
|
| 45 |
+
"tokens": "tokens",
|
| 46 |
+
"mints": "mints",
|
| 47 |
+
"liquidity": "liquidity",
|
| 48 |
+
"pool_creations": "pool_creations",
|
| 49 |
+
"token_metrics": "token_metrics",
|
| 50 |
+
"wallet_profile_metrics": "wallet_profile_metrics",
|
| 51 |
+
"migrations": "migrations",
|
| 52 |
+
"fee_collections": "fee_collections",
|
| 53 |
+
"supply_locks": "supply_locks",
|
| 54 |
+
"supply_lock_actions": "supply_lock_actions",
|
| 55 |
+
"known_wallets": "known_wallets",
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
# Neo4j dump filename pattern
|
| 59 |
+
NEO4J_FILENAME = "neo4j_epoch_{epoch}.dump"
|
| 60 |
+
|
| 61 |
+
# ClickHouse connection defaults (can be overridden by env vars)
|
| 62 |
+
CH_HOST = os.getenv("CLICKHOUSE_HOST", "localhost")
|
| 63 |
+
CH_PORT = int(os.getenv("CLICKHOUSE_PORT", "8123"))
|
| 64 |
+
CH_USER = os.getenv("CLICKHOUSE_USER", "default")
|
| 65 |
+
CH_PASSWORD = os.getenv("CLICKHOUSE_PASSWORD", "")
|
| 66 |
+
CH_DATABASE = os.getenv("CLICKHOUSE_DATABASE", "default")
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def build_patterns(epoch: int) -> list[str]:
    """Build the list of file patterns to download for a given epoch."""
    tag = str(epoch)
    names = [f"{stem}_epoch_{tag}.parquet" for stem in PARQUET_TABLE_MAP]
    names.append(NEO4J_FILENAME.format(epoch=tag))
    return names
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def download_epoch(epoch: int, dest_dir: Path, token: str | None) -> None:
    """Download epoch artifacts from Hugging Face.

    Args:
        epoch: Epoch number whose parquet files and Neo4j dump should be fetched.
        dest_dir: Local directory to download into (created if missing).
        token: Optional Hugging Face token for private repo access.
    """
    patterns = build_patterns(epoch)
    dest_dir.mkdir(parents=True, exist_ok=True)

    print(f"📥 Downloading epoch {epoch} from {REPO_ID}...")
    # Only the allow-listed epoch files are pulled; real files (no symlinks),
    # and interrupted downloads are resumed.
    snapshot_download(
        repo_id=REPO_ID,
        repo_type=REPO_TYPE,
        local_dir=str(dest_dir),
        local_dir_use_symlinks=False,
        allow_patterns=patterns,
        resume_download=True,
        token=token,
    )
    print("✅ Download complete.")
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def ingest_parquet(client, table_name: str, parquet_path: Path, dry_run: bool = False) -> bool:
    """
    Ingest a Parquet file into a ClickHouse table.

    Strategy: first try a locally installed ``clickhouse-client`` with
    ``FROM INFILE``; if the binary is missing, fall back to copying the file
    into the ClickHouse docker container and running the insert there.

    Args:
        client: ClickHouse client handle. Unused by the CLI-based ingestion
            paths; kept for interface compatibility with callers.
        table_name: Target ClickHouse table.
        parquet_path: Local path of the Parquet file to insert.
        dry_run: When True, only print what would be done and return True.

    Returns True on success, False otherwise (failures are printed, not raised).
    """
    import subprocess

    if dry_run:
        print(f" [DRY-RUN] insert {parquet_path.name} -> {table_name}")
        return True

    try:
        # Cheap sanity check: a real Parquet file starts with the PAR1 magic bytes.
        with parquet_path.open("rb") as fh:
            if fh.read(4) != b"PAR1":
                print(f" ⚠️ Skipping {parquet_path.name}: not a Parquet file.")
                return False

        infile_query = f"INSERT INTO {table_name} FROM INFILE '{parquet_path.resolve()}' FORMAT Parquet"
        # NOTE(review): CH_PORT defaults to the HTTP port (8123) while the
        # native clickhouse-client protocol usually listens on 9000 — confirm.
        try:
            cmd = [
                "clickhouse-client",
                "--host", CH_HOST,
                "--port", str(CH_PORT),
                "--user", CH_USER,
                "--password", CH_PASSWORD,
                "--database", CH_DATABASE,
                "--query", infile_query,
            ]
            subprocess.run(cmd, check=True)
            return True
        except FileNotFoundError:
            # clickhouse-client binary not installed; try the docker fallback.
            pass

        # Docker fallback: copy the file into the ClickHouse container,
        # insert from inside the container, then clean up the temp copy.
        ch_container = CLICKHOUSE_DOCKER_CONTAINER
        try:
            tmp_path = f"/tmp/{parquet_path.name}"
            subprocess.run(
                ["docker", "cp", str(parquet_path), f"{ch_container}:{tmp_path}"],
                check=True,
            )
            docker_cmd = [
                "docker", "exec", ch_container,
                "clickhouse-client",
                "--query", f"INSERT INTO {table_name} FROM INFILE '{tmp_path}' FORMAT Parquet",
            ]
            subprocess.run(docker_cmd, check=True)
            subprocess.run(["docker", "exec", ch_container, "rm", "-f", tmp_path], check=True)
            return True
        except FileNotFoundError:
            # Caught by the outer handler below, which prints and returns False.
            raise RuntimeError(
                "clickhouse-client not found and docker is unavailable. Install clickhouse-client or use a ClickHouse container."
            )
    except Exception as e:
        print(f" ❌ Failed to ingest {parquet_path.name}: {e}")
        return False
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def run_etl(epoch: int, dest_dir: Path, client, dry_run: bool = False, token: str | None = None, skip_neo4j: bool = False, skip_clickhouse: bool = False) -> None:
    """
    Full ETL pipeline over already-downloaded epoch artifacts:
    1. Use local Parquet files in *dest_dir* (no download)
    2. Ingest each into its ClickHouse table (unless skip_clickhouse)
    3. Merge the epoch's Neo4j dump into the target graph (unless skip_neo4j)
    Local files are kept (no deletion).

    Raises FileNotFoundError if *dest_dir* does not exist.
    NOTE(review): ``token`` is accepted but never used here, and ``client`` is
    only forwarded to ingest_parquet (which ignores it) — confirm intent.
    """
    if not dest_dir.exists():
        raise FileNotFoundError(f"Epoch directory not found: {dest_dir}")

    if not skip_clickhouse:
        # Step 2: Ingest each Parquet file
        print(f"\n📤 Ingesting Parquet files into ClickHouse...")
        for stem, table_name in tqdm(PARQUET_TABLE_MAP.items(), desc="Ingesting"):
            parquet_path = dest_dir / f"{stem}_epoch_{epoch}.parquet"
            # Missing files are skipped, not fatal — not every epoch ships every table.
            if not parquet_path.exists():
                print(f" ⚠️ Skipping {stem}: file not found.")
                continue

            ingest_parquet(client, table_name, parquet_path, dry_run=dry_run)

        print("\n✅ ClickHouse ingestion complete.")
    else:
        print("\nℹ️ ClickHouse ingestion skipped.")

    # Step 4: Neo4j dump — merged via a temporary container when present.
    neo4j_path = dest_dir / NEO4J_FILENAME.format(epoch=epoch)
    if neo4j_path.exists() and not skip_neo4j:
        merge_neo4j_epoch_dump(epoch, neo4j_path, dry_run=dry_run)
    elif neo4j_path.exists() and skip_neo4j:
        print(f"\nℹ️ Neo4j dump found but skipped: {neo4j_path}")

    print("\n🎉 Full ETL pipeline complete.")
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def ingest_neo4j_dump(dump_path: Path, database: str = "neo4j", dry_run: bool = False) -> bool:
    """
    Load a Neo4j dump file into the database.
    Requires neo4j-admin CLI and the Neo4j service to be stopped.
    Falls back to a dockerized ``neo4j-admin`` (stopping/restarting the
    container around the offline load) when the CLI is not installed.
    Returns True on success.
    """
    import subprocess

    if not dump_path.exists():
        print(f" ⚠️ Neo4j dump not found: {dump_path}")
        return False

    import shutil

    # neo4j-admin expects the dump to be named <database>.dump inside the
    # load directory; copy/rename into a temp dir if it isn't already.
    expected_dump_name = f"{database}.dump"
    load_dir = dump_path.parent
    temp_load_dir = None
    if dump_path.name != expected_dump_name:
        temp_load_dir = dump_path.parent / f"_neo4j_load_{database}"
        temp_load_dir.mkdir(parents=True, exist_ok=True)
        load_dump_path = temp_load_dir / expected_dump_name
        shutil.copy2(dump_path, load_dump_path)
        load_dir = temp_load_dir

    # neo4j-admin database load requires a directory containing <database>.dump
    # For Neo4j 5.x: neo4j-admin database load --from-path=<dir> <database>
    # Note: User must clear the database before loading (no --overwrite flag)
    cmd = [
        "neo4j-admin", "database", "load",
        f"--from-path={load_dir.resolve()}",
        database,
    ]

    if dry_run:
        print(f" [DRY-RUN] {' '.join(cmd)}")
        return True

    print(f"🔄 Loading Neo4j dump into database '{database}'...")
    print(" ⚠️ Neo4j must be stopped for offline load.")

    try:
        # Native path: run the locally installed neo4j-admin.
        # NOTE(review): `result` is unused, and on this success path
        # temp_load_dir is never cleaned up — confirm whether that's intended.
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
        print(" ✅ Neo4j dump loaded successfully.")
        return True
    except FileNotFoundError:
        # Fall back to dockerized neo4j-admin if available
        docker_container = NEO4J_DOCKER_CONTAINER
        try:
            # List all containers (running or not) as "name<TAB>image" pairs.
            docker_ps = subprocess.run(
                ["docker", "ps", "-a", "--format", "{{.Names}}\t{{.Image}}"],
                capture_output=True,
                text=True,
                check=True,
            )
        except FileNotFoundError:
            print(" ❌ neo4j-admin not found and docker is unavailable.")
            return False
        except subprocess.CalledProcessError as e:
            print(f" ❌ Failed to list docker containers: {e.stderr}")
            return False

        containers = [line.strip().split("\t") for line in docker_ps.stdout.splitlines() if line.strip()]
        container_names = {name for name, _ in containers}
        if docker_container not in container_names:
            # Try to auto-detect a neo4j container if the default name isn't found.
            neo4j_candidates = [name for name, image in containers if image.startswith("neo4j")]
            if neo4j_candidates:
                docker_container = neo4j_candidates[0]
                print(f" ℹ️ Using detected Neo4j container '{docker_container}'.")
            else:
                print(f" ❌ neo4j-admin not found and docker container '{docker_container}' does not exist.")
                return False

        # The container must be stopped for an offline load; remember whether
        # it was running so we can restore its state afterwards.
        docker_running = subprocess.run(
            ["docker", "ps", "--format", "{{.Names}}"],
            capture_output=True,
            text=True,
            check=True,
        )
        running = set(line.strip() for line in docker_running.stdout.splitlines() if line.strip())
        was_running = docker_container in running

        if was_running:
            print(f" 🛑 Stopping Neo4j container '{docker_container}' for offline load...")
            if dry_run:
                print(f" [DRY-RUN] docker stop {docker_container}")
            else:
                subprocess.run(["docker", "stop", docker_container], check=True)

        # Run neo4j-admin in a throwaway container sharing the target's data
        # volumes, with the load dir bind-mounted at /dump.
        dump_name = dump_path.name
        docker_cmd = [
            "docker", "run", "--rm",
            "--volumes-from", docker_container,
            "-v", f"{load_dir.resolve()}:/dump",
            "neo4j:latest",
            "neo4j-admin", "database", "load",
            f"--from-path=/dump",
            "--overwrite-destination",
            database,
        ]

        if dry_run:
            print(f" [DRY-RUN] {' '.join(docker_cmd)}")
        else:
            print(f" 🔄 Running neo4j-admin in docker for {dump_name}...")
            subprocess.run(docker_cmd, check=True)
            print(" ✅ Neo4j dump loaded successfully (docker).")

        if was_running:
            print(f" ▶️ Starting Neo4j container '{docker_container}'...")
            if dry_run:
                print(f" [DRY-RUN] docker start {docker_container}")
            else:
                subprocess.run(["docker", "start", docker_container], check=True)
                # Block until Bolt is accepting connections again.
                _wait_for_bolt(NEO4J_URI)
        if temp_load_dir and not dry_run:
            shutil.rmtree(temp_load_dir, ignore_errors=True)
        return True
    except subprocess.CalledProcessError as e:
        print(f" ❌ Failed to load Neo4j dump: {e.stderr}")
        if temp_load_dir and not dry_run:
            shutil.rmtree(temp_load_dir, ignore_errors=True)
        return False
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def _neo4j_driver():
    """Create a Bolt driver for the target Neo4j, honoring the module-level credentials."""
    from neo4j import GraphDatabase

    # Auth is only supplied when at least one credential is configured.
    auth = None
    if NEO4J_USER is not None or NEO4J_PASSWORD is not None:
        auth = (NEO4J_USER, NEO4J_PASSWORD)
    return GraphDatabase.driver(NEO4J_URI, auth=auth)
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def _run_merge_batch(tx, query: str, rows: list[dict]) -> None:
|
| 324 |
+
tx.run(query, rows=rows)
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def _stream_merge(temp_session, target_session, match_query: str, merge_query: str, label: str) -> None:
    """Stream rows out of the temp (epoch) database and MERGE them into the target in batches.

    `label` is a human-readable tag for the entity being merged (currently unused).
    """
    pending = []
    cursor = temp_session.run(match_query, fetch_size=NEO4J_MERGE_BATCH_SIZE)
    for rec in cursor:
        pending.append(rec.data())
        # Flush a full batch, then start a fresh list.
        if len(pending) >= NEO4J_MERGE_BATCH_SIZE:
            target_session.execute_write(_run_merge_batch, merge_query, pending)
            pending = []
    # Flush the final partial batch, if any.
    if pending:
        target_session.execute_write(_run_merge_batch, merge_query, pending)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def _wait_for_bolt(uri: str, timeout_sec: int = 60) -> None:
    """Block until a Neo4j server at *uri* answers a trivial Bolt query.

    Polls once per second with an unauthenticated driver; raises RuntimeError
    after *timeout_sec* seconds without a successful round-trip.
    """
    from neo4j import GraphDatabase
    start = time.time()
    while True:
        try:
            # NOTE(review): on a failed attempt the driver created here is not
            # closed (the close only runs on success) — confirm acceptable.
            temp_driver = GraphDatabase.driver(uri, auth=None)
            with temp_driver.session(database="neo4j") as session:
                session.run("RETURN 1").consume()
            temp_driver.close()
            return
        except Exception:
            if time.time() - start > timeout_sec:
                raise RuntimeError(f"Timed out waiting for Neo4j at {uri}")
            time.sleep(1)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def _start_temp_neo4j_from_dump(epoch: int, dump_path: Path) -> tuple[str, str, str, Path]:
    """Spin up a throwaway Neo4j container pre-loaded with an epoch dump.

    Steps: copy the dump as neo4j.dump into a temp load dir, create a docker
    volume, run neo4j-admin in a one-shot container to load the dump into the
    volume, then start a detached auth-less Neo4j on a random host port.

    Returns (container_id, bolt_uri, volume_name, temp_load_dir); the caller
    is responsible for tearing these down.
    """
    import subprocess
    import shutil

    # neo4j-admin requires the dump to be named neo4j.dump in the load dir.
    expected_dump_name = "neo4j.dump"
    temp_load_dir = dump_path.parent / f"_neo4j_load_{epoch}"
    temp_load_dir.mkdir(parents=True, exist_ok=True)
    load_dump_path = temp_load_dir / expected_dump_name
    shutil.copy2(dump_path, load_dump_path)

    # Dedicated data volume so the temp instance never touches real data.
    volume_name = f"neo4j_tmp_{epoch}"
    subprocess.run(["docker", "volume", "create", volume_name], check=True)

    # One-shot container: load the dump into the volume, then exit (--rm).
    subprocess.run(
        [
            "docker", "run", "--rm",
            "-v", f"{volume_name}:/data",
            "-v", f"{temp_load_dir.resolve()}:/dump",
            "neo4j:latest",
            "neo4j-admin", "database", "load",
            "--from-path=/dump",
            "--overwrite-destination",
            "neo4j",
        ],
        check=True,
    )

    # Start the temp server detached, with auth disabled and Bolt published
    # on a random free host port ("-p 0:7687").
    container_id = subprocess.check_output(
        [
            "docker", "run", "-d", "--rm",
            "-e", "NEO4J_AUTH=none",
            "-v", f"{volume_name}:/data",
            "-p", "0:7687",
            "neo4j:latest",
        ],
        text=True,
    ).strip()

    # Ask docker which host port was assigned to Bolt.
    port_out = subprocess.check_output(
        ["docker", "port", container_id, "7687/tcp"],
        text=True,
    ).strip()
    host_port = port_out.split(":")[-1]
    bolt_uri = f"bolt://localhost:{host_port}"
    return container_id, bolt_uri, volume_name, temp_load_dir
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def merge_neo4j_epoch_dump(epoch: int, dump_path: Path, dry_run: bool = False) -> None:
    """Merge an epoch's Neo4j dump into the main target database.

    Boots a throwaway Neo4j container from the dump, streams every node and
    relationship type out of it via ``_stream_merge``, and MERGEs them into
    ``NEO4J_TARGET_DB`` so repeated ingests stay idempotent (timestamps keep
    their earliest observed value).

    Args:
        epoch: Epoch number; used to name the temporary docker volume.
        dump_path: Path to the ``neo4j-admin`` dump file for the epoch.
        dry_run: When True, only boots the temp container (left running so it
            can be inspected) and skips the merge itself.

    BUGFIX: cleanup of the temp container/volume/load-dir previously ran only
    on the success path — any exception raised during the merge leaked all
    three. Cleanup now lives in an outer ``finally`` so it always runs.
    """
    print(f"\n🧩 Merging Neo4j dump into '{NEO4J_TARGET_DB}' via temp container...")
    if dry_run:
        _start_temp_neo4j_from_dump(epoch, dump_path)
        print(" [DRY-RUN] merge skipped.")
        return

    temp_container_id = None
    temp_volume = None
    temp_load_dir = None
    temp_driver = None
    temp_db_name = "neo4j"

    try:
        temp_container_id, temp_bolt_uri, temp_volume, temp_load_dir = _start_temp_neo4j_from_dump(epoch, dump_path)
        _wait_for_bolt(temp_bolt_uri)
        from neo4j import GraphDatabase
        temp_driver = GraphDatabase.driver(temp_bolt_uri, auth=None)

        _wait_for_bolt(NEO4J_URI)
        driver = _neo4j_driver()
        try:
            with temp_driver.session(database=temp_db_name) as temp_session, driver.session(database=NEO4J_TARGET_DB) as target_session:
                # Wallet nodes
                _stream_merge(
                    temp_session,
                    target_session,
                    "MATCH (w:Wallet) RETURN w.address AS address",
                    "UNWIND $rows AS t MERGE (w:Wallet {address: t.address})",
                    "wallets",
                )

                # Token nodes (created_ts keeps the earliest known value)
                _stream_merge(
                    temp_session,
                    target_session,
                    "MATCH (t:Token) RETURN t.address AS address, t.created_ts AS created_ts",
                    "UNWIND $rows AS t MERGE (k:Token {address: t.address}) "
                    "ON CREATE SET k.created_ts = t.created_ts "
                    "ON MATCH SET k.created_ts = CASE WHEN k.created_ts IS NULL OR "
                    "t.created_ts < k.created_ts THEN t.created_ts ELSE k.created_ts END",
                    "tokens",
                )

                # BUNDLE_TRADE
                _stream_merge(
                    temp_session,
                    target_session,
                    "MATCH (a:Wallet)-[r:BUNDLE_TRADE]->(b:Wallet) "
                    "RETURN a.address AS wa, b.address AS wb, r.mint AS mint, r.slot AS slot, "
                    "r.timestamp AS timestamp, r.signatures AS signatures",
                    "UNWIND $rows AS t "
                    "MERGE (a:Wallet {address: t.wa}) "
                    "MERGE (b:Wallet {address: t.wb}) "
                    "MERGE (a)-[r:BUNDLE_TRADE {mint: t.mint, slot: t.slot}]->(b) "
                    "ON CREATE SET r.timestamp = t.timestamp, r.signatures = t.signatures "
                    "ON MATCH SET r.timestamp = CASE WHEN r.timestamp IS NULL OR "
                    "t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END",
                    "bundle_trade",
                )

                # TRANSFERRED_TO
                _stream_merge(
                    temp_session,
                    target_session,
                    "MATCH (s:Wallet)-[r:TRANSFERRED_TO]->(d:Wallet) "
                    "RETURN s.address AS source, d.address AS destination, r.mint AS mint, "
                    "r.signature AS signature, r.timestamp AS timestamp, r.amount AS amount",
                    "UNWIND $rows AS t "
                    "MERGE (s:Wallet {address: t.source}) "
                    "MERGE (d:Wallet {address: t.destination}) "
                    "MERGE (s)-[r:TRANSFERRED_TO {mint: t.mint}]->(d) "
                    "ON CREATE SET r.signature = t.signature, r.timestamp = t.timestamp, r.amount = t.amount "
                    "ON MATCH SET r.timestamp = CASE WHEN r.timestamp IS NULL OR "
                    "t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END",
                    "transfer",
                )

                # COORDINATED_ACTIVITY
                _stream_merge(
                    temp_session,
                    target_session,
                    "MATCH (f:Wallet)-[r:COORDINATED_ACTIVITY]->(l:Wallet) "
                    "RETURN f.address AS follower, l.address AS leader, r.mint AS mint, r.timestamp AS timestamp, "
                    "r.leader_first_sig AS leader_first_sig, r.leader_second_sig AS leader_second_sig, "
                    "r.follower_first_sig AS follower_first_sig, r.follower_second_sig AS follower_second_sig, "
                    "r.time_gap_on_first_sec AS gap_1, r.time_gap_on_second_sec AS gap_2",
                    "UNWIND $rows AS t "
                    "MERGE (l:Wallet {address: t.leader}) "
                    "MERGE (f:Wallet {address: t.follower}) "
                    "MERGE (f)-[r:COORDINATED_ACTIVITY {mint: t.mint}]->(l) "
                    "ON CREATE SET r.timestamp = t.timestamp, r.leader_first_sig = t.leader_first_sig, "
                    "r.leader_second_sig = t.leader_second_sig, r.follower_first_sig = t.follower_first_sig, "
                    "r.follower_second_sig = t.follower_second_sig, r.time_gap_on_first_sec = t.gap_1, "
                    "r.time_gap_on_second_sec = t.gap_2 "
                    "ON MATCH SET r.timestamp = CASE WHEN r.timestamp IS NULL OR "
                    "t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END",
                    "coordinated_activity",
                )

                # COPIED_TRADE
                _stream_merge(
                    temp_session,
                    target_session,
                    "MATCH (f:Wallet)-[r:COPIED_TRADE]->(l:Wallet) "
                    "RETURN f.address AS follower, l.address AS leader, r.mint AS mint, r.timestamp AS timestamp, "
                    "r.buy_gap AS buy_gap, r.sell_gap AS sell_gap, r.leader_pnl AS leader_pnl, "
                    "r.follower_pnl AS follower_pnl, r.l_buy_sig AS l_buy_sig, r.l_sell_sig AS l_sell_sig, "
                    "r.f_buy_sig AS f_buy_sig, r.f_sell_sig AS f_sell_sig, r.l_buy_total AS l_buy_total, "
                    "r.l_sell_total AS l_sell_total, r.f_buy_total AS f_buy_total, r.f_sell_total AS f_sell_total, "
                    "r.f_buy_slip AS f_buy_slip, r.f_sell_slip AS f_sell_slip",
                    "UNWIND $rows AS t "
                    "MERGE (f:Wallet {address: t.follower}) "
                    "MERGE (l:Wallet {address: t.leader}) "
                    "MERGE (f)-[r:COPIED_TRADE {mint: t.mint}]->(l) "
                    "ON CREATE SET r.timestamp = t.timestamp, r.follower = t.follower, r.leader = t.leader, "
                    "r.mint = t.mint, r.buy_gap = t.buy_gap, r.sell_gap = t.sell_gap, r.leader_pnl = t.leader_pnl, "
                    "r.follower_pnl = t.follower_pnl, r.l_buy_sig = t.l_buy_sig, r.l_sell_sig = t.l_sell_sig, "
                    "r.f_buy_sig = t.f_buy_sig, r.f_sell_sig = t.f_sell_sig, r.l_buy_total = t.l_buy_total, "
                    "r.l_sell_total = t.l_sell_total, r.f_buy_total = t.f_buy_total, r.f_sell_total = t.f_sell_total, "
                    "r.f_buy_slip = t.f_buy_slip, r.f_sell_slip = t.f_sell_slip "
                    "ON MATCH SET r.timestamp = CASE WHEN r.timestamp IS NULL OR "
                    "t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END",
                    "copied_trade",
                )

                # MINTED
                _stream_merge(
                    temp_session,
                    target_session,
                    "MATCH (c:Wallet)-[r:MINTED]->(k:Token) "
                    "RETURN c.address AS creator, k.address AS token, r.signature AS signature, "
                    "r.timestamp AS timestamp, r.buy_amount AS buy_amount",
                    "UNWIND $rows AS t "
                    "MERGE (c:Wallet {address: t.creator}) "
                    "MERGE (k:Token {address: t.token}) "
                    "MERGE (c)-[r:MINTED {signature: t.signature}]->(k) "
                    "ON CREATE SET r.timestamp = t.timestamp, r.buy_amount = t.buy_amount "
                    "ON MATCH SET r.timestamp = CASE WHEN r.timestamp IS NULL OR "
                    "t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END",
                    "minted",
                )

                # SNIPED
                _stream_merge(
                    temp_session,
                    target_session,
                    "MATCH (w:Wallet)-[r:SNIPED]->(k:Token) "
                    "RETURN w.address AS wallet, k.address AS token, r.signature AS signature, "
                    "r.rank AS rank, r.sniped_amount AS sniped_amount, r.timestamp AS timestamp",
                    "UNWIND $rows AS t "
                    "MERGE (w:Wallet {address: t.wallet}) "
                    "MERGE (k:Token {address: t.token}) "
                    "MERGE (w)-[r:SNIPED {signature: t.signature}]->(k) "
                    "ON CREATE SET r.rank = t.rank, r.sniped_amount = t.sniped_amount, r.timestamp = t.timestamp "
                    "ON MATCH SET r.timestamp = CASE WHEN r.timestamp IS NULL OR "
                    "t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END",
                    "sniped",
                )

                # LOCKED_SUPPLY
                _stream_merge(
                    temp_session,
                    target_session,
                    "MATCH (s:Wallet)-[r:LOCKED_SUPPLY]->(k:Token) "
                    "RETURN s.address AS sender, k.address AS mint, r.signature AS signature, "
                    "r.amount AS amount, r.unlock_timestamp AS unlock_ts, r.recipient AS recipient, "
                    "r.timestamp AS timestamp",
                    "UNWIND $rows AS t "
                    "MERGE (s:Wallet {address: t.sender}) "
                    "MERGE (k:Token {address: t.mint}) "
                    "MERGE (s)-[r:LOCKED_SUPPLY {signature: t.signature}]->(k) "
                    "ON CREATE SET r.amount = t.amount, r.unlock_timestamp = t.unlock_ts, "
                    "r.recipient = t.recipient, r.timestamp = t.timestamp "
                    "ON MATCH SET r.timestamp = CASE WHEN r.timestamp IS NULL OR "
                    "t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END",
                    "locked_supply",
                )

                # BURNED
                _stream_merge(
                    temp_session,
                    target_session,
                    "MATCH (w:Wallet)-[r:BURNED]->(k:Token) "
                    "RETURN w.address AS wallet, k.address AS token, r.signature AS signature, "
                    "r.amount AS amount, r.timestamp AS timestamp",
                    "UNWIND $rows AS t "
                    "MERGE (w:Wallet {address: t.wallet}) "
                    "MERGE (k:Token {address: t.token}) "
                    "MERGE (w)-[r:BURNED {signature: t.signature}]->(k) "
                    "ON CREATE SET r.amount = t.amount, r.timestamp = t.timestamp "
                    "ON MATCH SET r.timestamp = CASE WHEN r.timestamp IS NULL OR "
                    "t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END",
                    "burned",
                )

                # PROVIDED_LIQUIDITY
                _stream_merge(
                    temp_session,
                    target_session,
                    "MATCH (w:Wallet)-[r:PROVIDED_LIQUIDITY]->(k:Token) "
                    "RETURN w.address AS wallet, k.address AS token, r.signature AS signature, "
                    "r.pool_address AS pool_address, r.amount_base AS amount_base, "
                    "r.amount_quote AS amount_quote, r.timestamp AS timestamp",
                    "UNWIND $rows AS t "
                    "MERGE (w:Wallet {address: t.wallet}) "
                    "MERGE (k:Token {address: t.token}) "
                    "MERGE (w)-[r:PROVIDED_LIQUIDITY {signature: t.signature}]->(k) "
                    "ON CREATE SET r.pool_address = t.pool_address, r.amount_base = t.amount_base, "
                    "r.amount_quote = t.amount_quote, r.timestamp = t.timestamp "
                    "ON MATCH SET r.timestamp = CASE WHEN r.timestamp IS NULL OR "
                    "t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END",
                    "provided_liquidity",
                )

                # TOP_TRADER_OF
                _stream_merge(
                    temp_session,
                    target_session,
                    "MATCH (w:Wallet)-[r:TOP_TRADER_OF]->(k:Token) "
                    "RETURN w.address AS wallet, k.address AS token, r.pnl_at_creation AS pnl_at_creation, "
                    "r.ath_usd_at_creation AS ath_at_creation, r.timestamp AS timestamp",
                    "UNWIND $rows AS t "
                    "MERGE (w:Wallet {address: t.wallet}) "
                    "MERGE (k:Token {address: t.token}) "
                    "MERGE (w)-[r:TOP_TRADER_OF]->(k) "
                    "ON CREATE SET r.pnl_at_creation = t.pnl_at_creation, r.ath_usd_at_creation = t.ath_at_creation, "
                    "r.timestamp = t.timestamp "
                    "ON MATCH SET r.timestamp = CASE WHEN r.timestamp IS NULL OR "
                    "t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END",
                    "top_trader_of",
                )

                # WHALE_OF
                _stream_merge(
                    temp_session,
                    target_session,
                    "MATCH (w:Wallet)-[r:WHALE_OF]->(k:Token) "
                    "RETURN w.address AS wallet, k.address AS token, r.holding_pct_at_creation AS pct_at_creation, "
                    "r.ath_usd_at_creation AS ath_at_creation, r.timestamp AS timestamp",
                    "UNWIND $rows AS t "
                    "MERGE (w:Wallet {address: t.wallet}) "
                    "MERGE (k:Token {address: t.token}) "
                    "MERGE (w)-[r:WHALE_OF]->(k) "
                    "ON CREATE SET r.holding_pct_at_creation = t.pct_at_creation, "
                    "r.ath_usd_at_creation = t.ath_at_creation, r.timestamp = t.timestamp "
                    "ON MATCH SET r.timestamp = CASE WHEN r.timestamp IS NULL OR "
                    "t.timestamp < r.timestamp THEN t.timestamp ELSE r.timestamp END",
                    "whale_of",
                )
        finally:
            driver.close()
    finally:
        # Always tear down the temp resources, even when the merge raised.
        # subprocess/shutil are module-level imports (used elsewhere in this
        # file), so no local re-import is needed.
        try:
            if temp_driver:
                temp_driver.close()
            if temp_container_id:
                subprocess.run(["docker", "stop", temp_container_id], check=True)
            if temp_volume:
                subprocess.run(["docker", "volume", "rm", "-f", temp_volume], check=True)
            if temp_load_dir:
                shutil.rmtree(temp_load_dir, ignore_errors=True)
            print(" 🧹 Dropped temp Neo4j container.")
        except Exception as e:
            print(f" ⚠️ Failed to clean temp Neo4j container: {e}")
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the epoch ETL run."""
    parser = argparse.ArgumentParser(
        description="ETL: Download, Ingest, Delete epoch Parquet files."
    )
    parser.add_argument("--epoch", type=int, required=True,
                        help="Epoch number to process (e.g., 851)")
    # Boolean switches share the same registration shape, so declare them as data.
    flag_specs = [
        (("-c", "--skip-clickhouse"), "Skip ClickHouse ingestion"),
        (("--dry-run",), "Print queries without executing"),
        (("--skip-neo4j",), "Skip Neo4j dump loading"),
    ]
    for names, help_text in flag_specs:
        parser.add_argument(*names, action="store_true", help=help_text)
    parser.add_argument("--token", type=str, default=None,
                        help="Hugging Face token (or set HF_TOKEN env var)")
    return parser.parse_args()
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
def main() -> None:
    """Entry point: connect to ClickHouse and run the epoch ETL pipeline."""
    args = parse_args()
    hf_token = args.token or os.environ.get("HF_TOKEN")

    epoch_dir = Path(DEFAULT_DEST_DIR).expanduser() / f"epoch_{args.epoch}"

    # Establish the ClickHouse connection up front; abort the run if it fails.
    print(f"🔌 Connecting to ClickHouse at {CH_HOST}:{CH_PORT}...")
    try:
        ch_client = clickhouse_connect.get_client(
            host=CH_HOST,
            port=CH_PORT,
            username=CH_USER,
            password=CH_PASSWORD,
            database=CH_DATABASE,
        )
    except Exception as exc:
        print(f"❌ Failed to connect to ClickHouse: {exc}")
        sys.exit(1)

    run_etl(
        epoch=args.epoch,
        dest_dir=epoch_dir,
        client=ch_client,
        dry_run=args.dry_run,
        token=hf_token,
        skip_neo4j=args.skip_neo4j,
        skip_clickhouse=args.skip_clickhouse,
    )
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
# Script entry point: run the ETL pipeline only when executed directly,
# so importing this module has no side effects.
if __name__ == "__main__":
    main()
|
train.py
ADDED
|
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import math
|
| 4 |
+
import logging
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
# Ensure torch/dill have a writable tmp dir
# Point the standard temp-dir env vars at a local writable directory (created
# here if missing). setdefault leaves any TMPDIR/TMP/TEMP the user already
# exported untouched; TMPDIR_OVERRIDE selects a custom location.
_DEFAULT_TMP = Path(os.getenv("TMPDIR_OVERRIDE", "./.tmp"))
_DEFAULT_TMP.mkdir(parents=True, exist_ok=True)
resolved_tmp = str(_DEFAULT_TMP.resolve())
for key in ("TMPDIR", "TMP", "TEMP"):
    os.environ.setdefault(key, resolved_tmp)
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
from torch.utils.data import DataLoader
|
| 18 |
+
from torch.optim import AdamW
|
| 19 |
+
|
| 20 |
+
# --- Accelerate & Transformers ---
|
| 21 |
+
from accelerate import Accelerator
|
| 22 |
+
from accelerate.logging import get_logger
|
| 23 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
| 24 |
+
from transformers import get_linear_schedule_with_warmup
|
| 25 |
+
|
| 26 |
+
# Logging
|
| 27 |
+
from tqdm.auto import tqdm
|
| 28 |
+
|
| 29 |
+
# DB Clients
|
| 30 |
+
from clickhouse_driver import Client as ClickHouseClient
|
| 31 |
+
from neo4j import GraphDatabase
|
| 32 |
+
|
| 33 |
+
# Local Imports
|
| 34 |
+
from data.data_fetcher import DataFetcher
|
| 35 |
+
from data.data_loader import OracleDataset
|
| 36 |
+
from data.data_collator import MemecoinCollator
|
| 37 |
+
from models.multi_modal_processor import MultiModalEncoder
|
| 38 |
+
from models.helper_encoders import ContextualTimeEncoder
|
| 39 |
+
from models.token_encoder import TokenEncoder
|
| 40 |
+
from models.wallet_encoder import WalletEncoder
|
| 41 |
+
from models.graph_updater import GraphUpdater
|
| 42 |
+
from models.ohlc_embedder import OHLCEmbedder
|
| 43 |
+
from models.model import Oracle
|
| 44 |
+
import models.vocabulary as vocab
|
| 45 |
+
|
| 46 |
+
# Setup Logger
|
| 47 |
+
logger = get_logger(__name__)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def compute_gradient_stats(model: nn.Module) -> Tuple[Optional[Dict[str, float]], Dict[str, float]]:
    """Return overall and per-module gradient statistics for logging.

    The first element is ``None`` when no parameter currently holds a
    gradient; otherwise it maps ``grad_norm_mean`` / ``grad_norm_max`` /
    ``grad_abs_max`` to floats. The second element maps each top-level
    module name (the first dotted component of the parameter name) to the
    L2 norm of all gradients beneath it, accumulated in float32.
    """
    per_param_norms: List[float] = []
    largest_abs_entry = 0.0
    squared_sums_by_module: Dict[str, float] = {}

    for param_name, param in model.named_parameters():
        if param.grad is None:
            continue
        g = param.grad.detach()
        per_param_norms.append(g.norm().item())
        largest_abs_entry = max(largest_abs_entry, g.abs().max().item())

        top_module = param_name.split(".", 1)[0]
        sq_sum = float(g.float().pow(2).sum().item())
        squared_sums_by_module[top_module] = squared_sums_by_module.get(top_module, 0.0) + sq_sum

    if not per_param_norms:
        return None, {}

    per_module = {mod: math.sqrt(total) for mod, total in squared_sums_by_module.items()}
    overall = {
        "grad_norm_mean": float(sum(per_param_norms) / len(per_param_norms)),
        "grad_norm_max": float(max(per_param_norms)),
        "grad_abs_max": float(largest_abs_entry),
    }
    return overall, per_module
|
| 78 |
+
|
| 79 |
+
def quantile_pinball_loss(preds: torch.Tensor,
                          targets: torch.Tensor,
                          mask: torch.Tensor,
                          quantiles: List[float]) -> torch.Tensor:
    """Compute the Pinball (quantile) loss over masked predictions.

    ``preds``/``targets``/``mask`` are ``[B, horizons * quantiles]`` with the
    quantile index varying fastest, so quantile ``i`` occupies columns
    ``i::len(quantiles)``. Returns the masked sum of pinball terms divided by
    the total mask count; a fully-masked batch yields 0.0.
    """
    if mask.sum() == 0:
        return torch.tensor(0.0, device=preds.device, dtype=preds.dtype)

    q_count = len(quantiles)
    total = None
    for q_idx, q in enumerate(quantiles):
        # Strided slices pick out this quantile's columns for every horizon.
        p = preds[:, q_idx::q_count]
        t = targets[:, q_idx::q_count]
        m = mask[:, q_idx::q_count]

        residual = t - p
        per_elem = torch.maximum((q - 1.0) * residual, q * residual)
        contrib = (per_elem * m).sum()
        total = contrib if total is None else total + contrib

    return total / mask.sum().clamp_min(1.0)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def filtered_collate(collator: MemecoinCollator,
                     batch: List[Optional[Dict[str, Any]]]) -> Optional[Dict[str, Any]]:
    """Drop ``None`` entries produced by the dataset, then delegate to the collator.

    Returns ``None`` when nothing in the batch survives filtering, which the
    training loop treats as a signal to skip the batch.
    """
    surviving = [sample for sample in batch if sample is not None]
    return collator(surviving) if surviving else None
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for Oracle training."""
    p = argparse.ArgumentParser(description="Train the Oracle quantile model.")

    # Optimization schedule
    p.add_argument("--epochs", type=int, default=1)
    p.add_argument("--batch_size", type=int, default=1)
    p.add_argument("--learning_rate", type=float, default=5e-5)
    p.add_argument("--warmup_ratio", type=float, default=0.1)
    p.add_argument("--grad_accum_steps", type=int, default=1)
    p.add_argument("--max_grad_norm", type=float, default=1.0)
    p.add_argument("--seed", type=int, default=42)

    # Logging / checkpointing
    p.add_argument("--log_every", type=int, default=1)
    p.add_argument("--save_every", type=int, default=1000)
    p.add_argument("--tensorboard_dir", type=str, default="runs/oracle")
    p.add_argument("--checkpoint_dir", type=str, default="checkpoints")
    p.add_argument("--mixed_precision", type=str, default="bf16")

    # Model / data shape
    p.add_argument("--max_seq_len", type=int, default=16000)
    p.add_argument("--ohlc_seq_len", type=int, default=60)
    p.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
    p.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
    p.add_argument("--max_samples", type=int, default=None)
    p.add_argument("--ohlc_stats_path", type=str, default="./data/ohlc_stats.npz")
    p.add_argument("--t_cutoff_seconds", type=int, default=60)

    # DataLoader behaviour (paired on/off switches share a dest)
    p.add_argument("--shuffle", dest="shuffle", action="store_true", default=True)
    p.add_argument("--no-shuffle", dest="shuffle", action="store_false")
    p.add_argument("--num_workers", type=int, default=0)
    p.add_argument("--pin_memory", dest="pin_memory", action="store_true", default=False)
    p.add_argument("--no-pin_memory", dest="pin_memory", action="store_false")

    # External services
    p.add_argument("--clickhouse_host", type=str, default="localhost")
    p.add_argument("--clickhouse_port", type=int, default=9000)
    p.add_argument("--neo4j_uri", type=str, default="bolt://localhost:7687")
    p.add_argument("--neo4j_user", type=str, default=None)
    p.add_argument("--neo4j_password", type=str, default=None)

    return p.parse_args()
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def main() -> None:
|
| 149 |
+
args = parse_args()
|
| 150 |
+
epochs = args.epochs
|
| 151 |
+
batch_size = args.batch_size
|
| 152 |
+
learning_rate = args.learning_rate
|
| 153 |
+
warmup_ratio = args.warmup_ratio
|
| 154 |
+
grad_accum_steps = args.grad_accum_steps
|
| 155 |
+
max_grad_norm = args.max_grad_norm
|
| 156 |
+
seed = args.seed
|
| 157 |
+
|
| 158 |
+
log_every = args.log_every
|
| 159 |
+
save_every = args.save_every
|
| 160 |
+
|
| 161 |
+
tensorboard_dir = Path(args.tensorboard_dir).expanduser()
|
| 162 |
+
checkpoint_dir = Path(args.checkpoint_dir).expanduser()
|
| 163 |
+
|
| 164 |
+
# --- 1. Initialize Accelerator ---
|
| 165 |
+
project_config = ProjectConfiguration(project_dir=str(checkpoint_dir), logging_dir=str(tensorboard_dir))
|
| 166 |
+
accelerator = Accelerator(
|
| 167 |
+
gradient_accumulation_steps=grad_accum_steps,
|
| 168 |
+
log_with="tensorboard",
|
| 169 |
+
project_config=project_config,
|
| 170 |
+
mixed_precision=args.mixed_precision # Default to bf16 for stability
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
# Make one log on every process with the configuration for debugging.
|
| 174 |
+
logging.basicConfig(
|
| 175 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 176 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 177 |
+
level=logging.INFO,
|
| 178 |
+
)
|
| 179 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 180 |
+
|
| 181 |
+
# Set seed for reproducibility
|
| 182 |
+
set_seed(seed)
|
| 183 |
+
|
| 184 |
+
if accelerator.is_main_process:
|
| 185 |
+
logger.info("Initialized with CLI arguments.")
|
| 186 |
+
tensorboard_dir.mkdir(parents=True, exist_ok=True)
|
| 187 |
+
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
| 188 |
+
accelerator.init_trackers("oracle_training")
|
| 189 |
+
|
| 190 |
+
device = accelerator.device
|
| 191 |
+
|
| 192 |
+
# Determine dtype for model initialization
|
| 193 |
+
init_dtype = torch.float32
|
| 194 |
+
if accelerator.mixed_precision == 'bf16':
|
| 195 |
+
init_dtype = torch.bfloat16
|
| 196 |
+
elif accelerator.mixed_precision == 'fp16':
|
| 197 |
+
init_dtype = torch.float16
|
| 198 |
+
|
| 199 |
+
# --- 2. Data Setup ---
|
| 200 |
+
horizons = args.horizons_seconds
|
| 201 |
+
quantiles = args.quantiles
|
| 202 |
+
max_seq_len = args.max_seq_len
|
| 203 |
+
ohlc_seq_len = args.ohlc_seq_len
|
| 204 |
+
|
| 205 |
+
logger.info(f"Initializing Encoders with dtype={init_dtype}...")
|
| 206 |
+
|
| 207 |
+
# Encoders
|
| 208 |
+
multi_modal_encoder = MultiModalEncoder(dtype=init_dtype)
|
| 209 |
+
time_encoder = ContextualTimeEncoder(dtype=init_dtype)
|
| 210 |
+
token_encoder = TokenEncoder(multi_dim=multi_modal_encoder.embedding_dim, dtype=init_dtype)
|
| 211 |
+
wallet_encoder = WalletEncoder(encoder=multi_modal_encoder, dtype=init_dtype)
|
| 212 |
+
graph_updater = GraphUpdater(time_encoder=time_encoder, dtype=init_dtype)
|
| 213 |
+
ohlc_embedder = OHLCEmbedder(
|
| 214 |
+
num_intervals=vocab.NUM_OHLC_INTERVALS,
|
| 215 |
+
sequence_length=ohlc_seq_len,
|
| 216 |
+
dtype=init_dtype
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
collator = MemecoinCollator(
|
| 220 |
+
event_type_to_id=vocab.EVENT_TO_ID,
|
| 221 |
+
device=device, # Note: Collator will handle basic moves, Accelerate handles the rest
|
| 222 |
+
multi_modal_encoder=multi_modal_encoder,
|
| 223 |
+
dtype=init_dtype,
|
| 224 |
+
ohlc_seq_len=ohlc_seq_len,
|
| 225 |
+
max_seq_len=max_seq_len
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
# DB Connections
|
| 229 |
+
clickhouse_client = ClickHouseClient(
|
| 230 |
+
host=args.clickhouse_host,
|
| 231 |
+
port=int(args.clickhouse_port)
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
neo4j_auth = None
|
| 235 |
+
if args.neo4j_user is not None:
|
| 236 |
+
neo4j_auth = (args.neo4j_user, args.neo4j_password or "")
|
| 237 |
+
neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=neo4j_auth)
|
| 238 |
+
|
| 239 |
+
data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
|
| 240 |
+
|
| 241 |
+
dataset = OracleDataset(
|
| 242 |
+
data_fetcher=data_fetcher,
|
| 243 |
+
horizons_seconds=horizons,
|
| 244 |
+
quantiles=quantiles,
|
| 245 |
+
max_samples=args.max_samples,
|
| 246 |
+
ohlc_stats_path=args.ohlc_stats_path,
|
| 247 |
+
t_cutoff_seconds=int(args.t_cutoff_seconds)
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
if len(dataset) == 0:
|
| 251 |
+
raise RuntimeError("Dataset is empty.")
|
| 252 |
+
|
| 253 |
+
dataloader = DataLoader(
|
| 254 |
+
dataset,
|
| 255 |
+
batch_size=batch_size,
|
| 256 |
+
shuffle=bool(args.shuffle),
|
| 257 |
+
num_workers=int(args.num_workers),
|
| 258 |
+
pin_memory=bool(args.pin_memory),
|
| 259 |
+
collate_fn=lambda batch: filtered_collate(collator, batch)
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
# --- 3. Model Init ---
|
| 263 |
+
logger.info("Initializing Oracle Model...")
|
| 264 |
+
model = Oracle(
|
| 265 |
+
token_encoder=token_encoder,
|
| 266 |
+
wallet_encoder=wallet_encoder,
|
| 267 |
+
graph_updater=graph_updater,
|
| 268 |
+
ohlc_embedder=ohlc_embedder,
|
| 269 |
+
time_encoder=time_encoder,
|
| 270 |
+
num_event_types=vocab.NUM_EVENT_TYPES,
|
| 271 |
+
multi_modal_dim=multi_modal_encoder.embedding_dim,
|
| 272 |
+
event_pad_id=vocab.EVENT_TO_ID["__PAD__"],
|
| 273 |
+
event_type_to_id=vocab.EVENT_TO_ID,
|
| 274 |
+
model_config_name="Qwen/Qwen3-0.6B",
|
| 275 |
+
quantiles=quantiles,
|
| 276 |
+
horizons_seconds=horizons,
|
| 277 |
+
dtype=init_dtype
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
# Memory Optimization: Delete unused embedding layer from Qwen backbone
|
| 281 |
+
if hasattr(model.model, 'embed_tokens'):
|
| 282 |
+
del model.model.embed_tokens
|
| 283 |
+
logger.info("Freed unused Qwen embedding layer memory.")
|
| 284 |
+
|
| 285 |
+
# --- 4. Optimizer & Scheduler ---
|
| 286 |
+
optimizer = AdamW(model.parameters(), lr=learning_rate)
|
| 287 |
+
|
| 288 |
+
# Calculate training steps
|
| 289 |
+
num_update_steps_per_epoch = math.ceil(len(dataloader) / grad_accum_steps)
|
| 290 |
+
max_train_steps = epochs * num_update_steps_per_epoch
|
| 291 |
+
num_warmup_steps = int(max_train_steps * warmup_ratio)
|
| 292 |
+
|
| 293 |
+
scheduler = get_linear_schedule_with_warmup(
|
| 294 |
+
optimizer,
|
| 295 |
+
num_warmup_steps=num_warmup_steps,
|
| 296 |
+
num_training_steps=max_train_steps
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
# --- 5. Accelerate Prepare ---
|
| 300 |
+
model, optimizer, dataloader, scheduler = accelerator.prepare(
|
| 301 |
+
model, optimizer, dataloader, scheduler
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
# --- 6. Resume Training Logic ---
|
| 305 |
+
# Load checkpoint if it exists
|
| 306 |
+
starting_epoch = 0
|
| 307 |
+
resume_step = 0
|
| 308 |
+
|
| 309 |
+
# Check for existing checkpoints
|
| 310 |
+
if checkpoint_dir.exists():
|
| 311 |
+
# Look for subfolders named 'checkpoint-X' or 'epoch_X'
|
| 312 |
+
# Accelerate saves to folders.
|
| 313 |
+
dirs = [d for d in checkpoint_dir.iterdir() if d.is_dir()]
|
| 314 |
+
if dirs:
|
| 315 |
+
# Sort by modification time or name to find latest
|
| 316 |
+
dirs.sort(key=lambda x: x.stat().st_mtime)
|
| 317 |
+
latest_checkpoint = dirs[-1]
|
| 318 |
+
logger.info(f"Found checkpoint: {latest_checkpoint}. Resuming training...")
|
| 319 |
+
accelerator.load_state(str(latest_checkpoint))
|
| 320 |
+
|
| 321 |
+
# Try to infer epoch/step from folder name or saved state if custom tracking
|
| 322 |
+
# Accelerate restores DataLoader state, so we mainly need to know where we are for logging
|
| 323 |
+
# Assuming standard naming or just relying on DataLoader restore.
|
| 324 |
+
# Simple approach: Just trust Accelerate/DataLoader to skip.
|
| 325 |
+
# If you need precise epoch/step recovery for logging display:
|
| 326 |
+
# You could save a metadata.json inside the checkpoint folder.
|
| 327 |
+
|
| 328 |
+
logger.info("Checkpoint loaded. DataLoader state restored.")
|
| 329 |
+
else:
|
| 330 |
+
logger.info("No checkpoint found. Starting fresh.")
|
| 331 |
+
|
| 332 |
+
# --- 7. Training Loop ---
|
| 333 |
+
total_steps = 0
|
| 334 |
+
|
| 335 |
+
logger.info("***** Running training *****")
|
| 336 |
+
logger.info(f" Num examples = {len(dataset)}")
|
| 337 |
+
logger.info(f" Num Epochs = {epochs}")
|
| 338 |
+
logger.info(f" Instantaneous batch size per device = {batch_size}")
|
| 339 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {batch_size * accelerator.num_processes * grad_accum_steps}")
|
| 340 |
+
logger.info(f" Gradient Accumulation steps = {grad_accum_steps}")
|
| 341 |
+
logger.info(f" Total optimization steps = {max_train_steps}")
|
| 342 |
+
|
| 343 |
+
for epoch in range(starting_epoch, epochs):
|
| 344 |
+
model.train()
|
| 345 |
+
epoch_loss = 0.0
|
| 346 |
+
valid_batches = 0
|
| 347 |
+
|
| 348 |
+
# Tqdm only on main process
|
| 349 |
+
progress_bar = tqdm(
|
| 350 |
+
dataloader,
|
| 351 |
+
desc=f"Epoch {epoch+1}/{epochs}",
|
| 352 |
+
disable=not accelerator.is_local_main_process,
|
| 353 |
+
initial=resume_step # If you calculate resume_step from checkpoint
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
for step, batch in enumerate(progress_bar):
|
| 357 |
+
# Skip steps if resuming (Accelerate dataloader might handle this automatically if configured,
|
| 358 |
+
# but 'skip_first_batches' is often manual.
|
| 359 |
+
# For simplicity here, we assume load_state restored the dataloader iterator.)
|
| 360 |
+
|
| 361 |
+
if batch is None:
|
| 362 |
+
continue
|
| 363 |
+
|
| 364 |
+
# Safety Patch for missing social data
|
| 365 |
+
if 'textual_event_indices' not in batch:
|
| 366 |
+
B, L = batch['event_type_ids'].shape
|
| 367 |
+
batch['textual_event_indices'] = torch.zeros((B, L), dtype=torch.long, device=accelerator.device)
|
| 368 |
+
if 'textual_event_data' not in batch:
|
| 369 |
+
batch['textual_event_data'] = []
|
| 370 |
+
|
| 371 |
+
grad_stats: Optional[Dict[str, float]] = None
|
| 372 |
+
module_grad_stats: Dict[str, float] = {}
|
| 373 |
+
with accelerator.accumulate(model):
|
| 374 |
+
outputs = model(batch)
|
| 375 |
+
|
| 376 |
+
preds = outputs["quantile_logits"]
|
| 377 |
+
labels = batch["labels"]
|
| 378 |
+
labels_mask = batch["labels_mask"]
|
| 379 |
+
|
| 380 |
+
if labels_mask.sum() == 0:
|
| 381 |
+
loss = torch.tensor(0.0, requires_grad=True, device=accelerator.device)
|
| 382 |
+
else:
|
| 383 |
+
loss = quantile_pinball_loss(preds, labels, labels_mask, quantiles)
|
| 384 |
+
|
| 385 |
+
accelerator.backward(loss)
|
| 386 |
+
|
| 387 |
+
if accelerator.sync_gradients:
|
| 388 |
+
accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
|
| 389 |
+
grad_stats, module_grad_stats = compute_gradient_stats(model)
|
| 390 |
+
if grad_stats and accelerator.is_main_process:
|
| 391 |
+
logger.info(
|
| 392 |
+
"Gradients - mean norm: %.4f | max norm: %.4f | max abs: %.4f",
|
| 393 |
+
grad_stats["grad_norm_mean"],
|
| 394 |
+
grad_stats["grad_norm_max"],
|
| 395 |
+
grad_stats["grad_abs_max"],
|
| 396 |
+
)
|
| 397 |
+
if module_grad_stats:
|
| 398 |
+
module_entries = " | ".join(
|
| 399 |
+
f"{name}: {norm:.4f}" for name, norm in sorted(module_grad_stats.items())
|
| 400 |
+
)
|
| 401 |
+
logger.info("Per-module grad norms: %s", module_entries)
|
| 402 |
+
|
| 403 |
+
optimizer.step()
|
| 404 |
+
scheduler.step()
|
| 405 |
+
optimizer.zero_grad()
|
| 406 |
+
|
| 407 |
+
# Logging
|
| 408 |
+
if accelerator.sync_gradients:
|
| 409 |
+
total_steps += 1
|
| 410 |
+
current_loss = loss.item()
|
| 411 |
+
epoch_loss += current_loss
|
| 412 |
+
valid_batches += 1
|
| 413 |
+
|
| 414 |
+
if total_steps % log_every == 0:
|
| 415 |
+
lr = scheduler.get_last_lr()[0]
|
| 416 |
+
log_payload = {
|
| 417 |
+
"train/loss": current_loss,
|
| 418 |
+
"train/learning_rate": lr,
|
| 419 |
+
"train/epoch": epoch + (step / len(dataloader))
|
| 420 |
+
}
|
| 421 |
+
if grad_stats:
|
| 422 |
+
log_payload.update({
|
| 423 |
+
"train/grad_norm_mean": grad_stats["grad_norm_mean"],
|
| 424 |
+
"train/grad_norm_max": grad_stats["grad_norm_max"],
|
| 425 |
+
"train/grad_abs_max": grad_stats["grad_abs_max"],
|
| 426 |
+
})
|
| 427 |
+
accelerator.log(log_payload, step=total_steps)
|
| 428 |
+
|
| 429 |
+
if accelerator.is_main_process:
|
| 430 |
+
progress_bar.set_postfix({"loss": f"{current_loss:.4f}", "lr": f"{lr:.2e}"})
|
| 431 |
+
if grad_stats:
|
| 432 |
+
logger.info(
|
| 433 |
+
"Step %d | loss %.4f | grad_norm %.4f",
|
| 434 |
+
total_steps,
|
| 435 |
+
current_loss,
|
| 436 |
+
grad_stats["grad_norm_mean"],
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
# Save Checkpoint periodically
|
| 440 |
+
if total_steps % save_every == 0:
|
| 441 |
+
if accelerator.is_main_process:
|
| 442 |
+
save_path = checkpoint_dir / f"checkpoint-{total_steps}"
|
| 443 |
+
accelerator.save_state(output_dir=str(save_path))
|
| 444 |
+
logger.info(f"Saved checkpoint to {save_path}")
|
| 445 |
+
|
| 446 |
+
# End of Epoch Handling
|
| 447 |
+
if valid_batches > 0:
|
| 448 |
+
avg_loss = epoch_loss / valid_batches
|
| 449 |
+
if accelerator.is_main_process:
|
| 450 |
+
logger.info(f"Epoch {epoch+1} complete. Avg loss: {avg_loss:.6f}")
|
| 451 |
+
accelerator.log({"train/loss_epoch": avg_loss}, step=global_step)
|
| 452 |
+
|
| 453 |
+
# Save Checkpoint at end of epoch
|
| 454 |
+
save_path = checkpoint_dir / f"epoch_{epoch+1}"
|
| 455 |
+
accelerator.save_state(output_dir=str(save_path))
|
| 456 |
+
logger.info(f"Saved checkpoint to {save_path}")
|
| 457 |
+
else:
|
| 458 |
+
if accelerator.is_main_process:
|
| 459 |
+
logger.warning(f"Epoch {epoch+1}: No valid batches processed.")
|
| 460 |
+
|
| 461 |
+
accelerator.end_training()
|
| 462 |
+
neo4j_driver.close()
|
| 463 |
+
|
| 464 |
+
if __name__ == "__main__":
|
| 465 |
+
main()
|
train.sh
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Launch Oracle training via Hugging Face Accelerate (multi-GPU aware).
# Flags mirror the keys in train.yaml; CLI values here take precedence.
# NOTE: comments cannot be interleaved below — backslash continuations
# would be broken by a comment line.
accelerate launch train.py \
    --epochs 1 \
    --batch_size 1 \
    --learning_rate 1e-4 \
    --warmup_ratio 0.1 \
    --grad_accum_steps 1 \
    --max_grad_norm 1.0 \
    --seed 42 \
    --log_every 1 \
    --save_every 1000 \
    --tensorboard_dir runs/oracle \
    --checkpoint_dir checkpoints \
    --mixed_precision bf16 \
    --max_seq_len 50 \
    --ohlc_seq_len 300 \
    --horizons_seconds 30 60 120 240 420 \
    --quantiles 0.1 0.5 0.9 \
    --ohlc_stats_path ./data/ohlc_stats.npz \
    --t_cutoff_seconds 60 \
    --num_workers 4 \
    --clickhouse_host localhost \
    --clickhouse_port 9000 \
    --neo4j_uri bolt://localhost:7687
|
train.yaml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Configuration consumed by validate.py (and train.py) via load_config().
# CLI flags override these values where both are provided.

training:
  epochs: 1
  batch_size: 1
  learning_rate: 5.0e-05
  use_amp: true                     # automatic mixed precision
  log_every: 1                      # log metrics every N optimizer steps
  disable_tqdm: false
  tensorboard_logdir: runs/oracle
  checkpoint_path: checkpoints/oracle_checkpoint.pt

data:
  max_samples: null                 # null = use the full dataset
  horizons_seconds: [30, 60, 120, 240, 420]   # prediction horizons
  quantiles: [0.1, 0.5, 0.9]        # quantile levels per horizon
  max_seq_len: 50                   # max event-sequence length
  ohlc_seq_len: 300                 # OHLC candles fed to the embedder
  ohlc_stats_path: ./data/ohlc_stats.npz      # normalization statistics
  t_cutoff_seconds: 60              # observation cutoff after mint
  shuffle: true
  num_workers: 0
  pin_memory: false

databases:
  clickhouse:
    host: localhost
    port: 9000                      # native-protocol port
  neo4j:
    uri: bolt://localhost:7687
    user: null                      # null = connect without auth
    password: null
|
utils.sql
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
-- Ad-hoc maintenance snippets for the on-chain ClickHouse schema.
-- Run sections selectively — the TRUNCATE and DROP sections below are
-- DESTRUCTIVE and will wipe ingested data.

-- Force background merges so ReplacingMergeTree-style "*_latest" tables
-- collapse to their final row versions immediately.
OPTIMIZE TABLE wallet_profiles FINAL;
OPTIMIZE TABLE wallet_profile_metrics_latest FINAL;
OPTIMIZE TABLE wallet_holdings_latest FINAL;
OPTIMIZE TABLE tokens_latest FINAL;
OPTIMIZE TABLE token_metrics_latest FINAL;

-- DANGER: empty all event/ingest tables (schema is kept).
TRUNCATE TABLE wallet_holdings;
TRUNCATE TABLE trades;
TRUNCATE TABLE transfers;
TRUNCATE TABLE burns;
TRUNCATE TABLE tokens;
TRUNCATE TABLE mints;
TRUNCATE TABLE liquidity;
TRUNCATE TABLE pool_creations;
TRUNCATE TABLE token_metrics;
TRUNCATE TABLE wallet_profile_metrics;
TRUNCATE TABLE migrations;
TRUNCATE TABLE fee_collections;
TRUNCATE TABLE supply_locks;
TRUNCATE TABLE supply_lock_actions;

-- DANGER: empty the derived "*_latest" snapshot tables as well.
TRUNCATE TABLE wallet_profile_metrics_latest;
TRUNCATE TABLE wallet_holdings_latest;
TRUNCATE TABLE token_metrics_latest;
TRUNCATE TABLE tokens_latest;
TRUNCATE TABLE wallet_profiles;

-- DANGER: drop the entire schema (tables and data).
DROP TABLE IF EXISTS trades;
DROP TABLE IF EXISTS mints;
DROP TABLE IF EXISTS migrations;
DROP TABLE IF EXISTS fee_collections;
DROP TABLE IF EXISTS liquidity;
DROP TABLE IF EXISTS pool_creations;
DROP TABLE IF EXISTS transfers;
DROP TABLE IF EXISTS burns;
DROP TABLE IF EXISTS wallet_profiles;
DROP TABLE IF EXISTS wallet_holdings;
DROP TABLE IF EXISTS wallet_profile_metrics;
DROP TABLE IF EXISTS wallet_profile_metrics_latest;
DROP TABLE IF EXISTS tokens;
DROP TABLE IF EXISTS token_metrics;
DROP TABLE IF EXISTS token_metrics_latest;
DROP TABLE IF EXISTS supply_locks;
DROP TABLE IF EXISTS supply_lock_actions;
DROP TABLE IF EXISTS wallet_holdings_latest;
DROP TABLE IF EXISTS tokens_latest;


-- Backfilling Logic

-- Staging table for token-metadata backfill; no version/dedup column, so
-- dedup (if needed) must happen on insert into the main tables.
CREATE TABLE IF NOT EXISTS tokens_backfill
(
    token_address String,
    name String,
    symbol String,
    token_uri String,
    is_mutable UInt8,
    update_authority Nullable(String),   -- Nullable: authority may be revoked
    mint_authority Nullable(String),
    freeze_authority Nullable(String),
    protocol UInt8                       -- protocol enum id; see ingest code
)
ENGINE = MergeTree
ORDER BY token_address;
|
validate.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import yaml
|
| 3 |
+
import torch
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any, Dict, List, Optional
|
| 6 |
+
|
| 7 |
+
from clickhouse_driver import Client as ClickHouseClient
|
| 8 |
+
from neo4j import GraphDatabase
|
| 9 |
+
|
| 10 |
+
from data.data_fetcher import DataFetcher
|
| 11 |
+
from data.data_loader import OracleDataset
|
| 12 |
+
from data.data_collator import MemecoinCollator
|
| 13 |
+
from models.multi_modal_processor import MultiModalEncoder
|
| 14 |
+
from models.helper_encoders import ContextualTimeEncoder
|
| 15 |
+
from models.token_encoder import TokenEncoder
|
| 16 |
+
from models.wallet_encoder import WalletEncoder
|
| 17 |
+
from models.graph_updater import GraphUpdater
|
| 18 |
+
from models.ohlc_embedder import OHLCEmbedder
|
| 19 |
+
from models.model import Oracle
|
| 20 |
+
import models.vocabulary as vocab
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def quantile_pinball_loss(preds: torch.Tensor,
                          targets: torch.Tensor,
                          mask: torch.Tensor,
                          quantiles: List[float]) -> torch.Tensor:
    """Masked pinball (quantile) loss averaged over the number of valid targets.

    Columns are interleaved by quantile: column ``j`` corresponds to quantile
    ``quantiles[j % len(quantiles)]``.  Positions where ``mask`` is zero
    contribute nothing; when no position is valid the loss is exactly 0.
    """
    total_valid = mask.sum()
    if total_valid == 0:
        return torch.tensor(0.0, device=preds.device, dtype=preds.dtype)

    n_q = len(quantiles)
    accumulated = torch.zeros((), device=preds.device, dtype=preds.dtype)
    for offset, level in enumerate(quantiles):
        # Strided slice picks out every column belonging to this quantile.
        residual = targets[:, offset::n_q] - preds[:, offset::n_q]
        # Pinball penalty: level * residual when under-predicting,
        # (level - 1) * residual when over-predicting.
        penalty = torch.maximum((level - 1.0) * residual, level * residual)
        accumulated = accumulated + (penalty * mask[:, offset::n_q]).sum()
    return accumulated / total_valid.clamp_min(1.0)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def load_config(path: str) -> Dict[str, Any]:
    """Read and parse a YAML configuration file.

    Returns an empty dict when the file parses to a falsy value (e.g. a
    blank file).  Raises FileNotFoundError when *path* does not exist.
    """
    cfg_path = Path(path)
    if not cfg_path.exists():
        raise FileNotFoundError(f"Config file not found: {cfg_path}")
    with cfg_path.open("r") as handle:
        parsed = yaml.safe_load(handle)
    return parsed or {}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def parse_args() -> argparse.Namespace:
    """Define and parse the CLI for validating a checkpoint on one token."""
    parser = argparse.ArgumentParser(description="Validate Oracle checkpoint on a single token.")
    # (flag, type, default, help) — table-driven to keep the options scannable.
    option_specs = [
        ("--config", str, "train.yaml", "Path to training YAML config."),
        ("--checkpoint", str, None, "Checkpoint to load. Defaults to config training.checkpoint_path."),
        ("--sample-idx", int, 0, "Dataset index to validate."),
        ("--token-address", str, None, "Optional mint address to pick instead of index."),
        ("--t-cutoff-seconds", int, None, "Override cutoff horizon (seconds after mint)."),
    ]
    for flag, flag_type, default, help_text in option_specs:
        parser.add_argument(flag, type=flag_type, default=default, help=help_text)
    return parser.parse_args()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def resolve_sample_index(dataset: OracleDataset,
                         sample_idx: int,
                         token_address: Optional[str]) -> int:
    """Map the CLI selection to a concrete dataset index.

    When *token_address* is given, search ``dataset.sampled_mints`` for the
    matching mint and return its position; otherwise validate *sample_idx*
    against the dataset length.  Raises ValueError when the token is absent
    or the index is out of range.
    """
    if token_address:
        mints = getattr(dataset, "sampled_mints", [])
        match = next(
            (position for position, entry in enumerate(mints)
             if entry.get("mint_address") == token_address),
            None,
        )
        if match is None:
            raise ValueError(f"Token {token_address} not found in loaded dataset.")
        return match
    if not 0 <= sample_idx < len(dataset):
        raise ValueError(f"Sample index {sample_idx} out of range (len={len(dataset)}).")
    return sample_idx
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def move_to_device(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]:
    """Move every tensor value of *batch* onto *device*, in place.

    Non-tensor values are left untouched; the same (mutated) dict is
    returned for convenience.
    """
    # Iterate over a snapshot of the keys since values are reassigned.
    for key in list(batch):
        candidate = batch[key]
        if torch.is_tensor(candidate):
            batch[key] = candidate.to(device)
    return batch
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def main() -> None:
    """Validate an Oracle checkpoint on a single dataset sample.

    Loads config + checkpoint, builds the encoder stack and dataset,
    runs one forward pass on the selected token, and prints per-horizon
    quantile predictions vs. labels along with the masked pinball loss.
    """
    args = parse_args()
    config = load_config(args.config)

    training_cfg = config.get("training", {})
    data_cfg = config.get("data", {})
    db_cfg = config.get("databases", {})

    # CLI checkpoint overrides the config entry.
    checkpoint_path = Path(args.checkpoint or training_cfg.get("checkpoint_path", "checkpoints/oracle_checkpoint.pt")).expanduser()
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Prefer bf16 on capable GPUs, fp16 otherwise; CPU always runs fp32.
    dtype = torch.bfloat16 if device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float16
    if device.type == "cpu":
        dtype = torch.float32

    quantiles = data_cfg.get("quantiles", [0.1, 0.5, 0.9])
    horizons = data_cfg.get("horizons_seconds", [30, 60, 120, 240, 420])
    max_samples = data_cfg.get("max_samples", None)
    max_seq_len = data_cfg.get("max_seq_len", 50)
    ohlc_seq_len = data_cfg.get("ohlc_seq_len", 60)
    default_t_cutoff = int(data_cfg.get("t_cutoff_seconds", 60))
    t_cutoff_seconds = int(args.t_cutoff_seconds) if args.t_cutoff_seconds is not None else default_t_cutoff
    ohlc_stats_path = data_cfg.get("ohlc_stats_path", "./data/ohlc_stats.npz")

    # --- Encoder stack (must mirror the configuration used at train time) ---
    multi_modal_encoder = MultiModalEncoder(dtype=dtype)
    time_encoder = ContextualTimeEncoder(dtype=dtype)
    token_encoder = TokenEncoder(multi_dim=multi_modal_encoder.embedding_dim, dtype=dtype)
    wallet_encoder = WalletEncoder(encoder=multi_modal_encoder, dtype=dtype)
    graph_updater = GraphUpdater(time_encoder=time_encoder, dtype=dtype)
    ohlc_embedder = OHLCEmbedder(
        num_intervals=vocab.NUM_OHLC_INTERVALS,
        sequence_length=ohlc_seq_len,
        dtype=dtype
    )

    collator = MemecoinCollator(
        event_type_to_id=vocab.EVENT_TO_ID,
        device=device,
        multi_modal_encoder=multi_modal_encoder,
        dtype=dtype,
        ohlc_seq_len=ohlc_seq_len,
        max_seq_len=max_seq_len
    )

    clickhouse_cfg = db_cfg.get("clickhouse", {})
    clickhouse_client = ClickHouseClient(
        host=clickhouse_cfg.get("host", "localhost"),
        port=int(clickhouse_cfg.get("port", 9000))
    )

    neo4j_cfg = db_cfg.get("neo4j", {})
    neo4j_auth = None
    if neo4j_cfg.get("user") is not None:
        neo4j_auth = (neo4j_cfg.get("user"), neo4j_cfg.get("password") or "")
    neo4j_driver = GraphDatabase.driver(neo4j_cfg.get("uri", "bolt://localhost:7687"), auth=neo4j_auth)

    # FIX: previously the driver was only closed on the success path; any
    # exception below leaked the Neo4j connection pool.  try/finally
    # guarantees cleanup.  (The ClickHouse client has no explicit close here;
    # it is dropped with the process.)
    try:
        data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
        dataset = OracleDataset(
            data_fetcher=data_fetcher,
            horizons_seconds=horizons,
            quantiles=quantiles,
            max_samples=max_samples,
            ohlc_stats_path=ohlc_stats_path,
            token_allowlist=[args.token_address] if args.token_address else None,
            t_cutoff_seconds=t_cutoff_seconds
        )
        if len(dataset) == 0:
            raise RuntimeError("Dataset is empty; cannot validate.")

        sample_idx = resolve_sample_index(dataset, args.sample_idx, args.token_address)
        sample = dataset[sample_idx]
        if sample is None:
            raise RuntimeError(f"Dataset returned None for sample index {sample_idx}.")

        token_address = getattr(dataset, "sampled_mints", [{}])[sample_idx].get("mint_address", "Unknown")
        print(f"Validating token {token_address} (dataset idx {sample_idx}) with T_cutoff {t_cutoff_seconds} second(s) after mint")

        collated = collator([sample])
        collated = move_to_device(collated, device)

        model = Oracle(
            token_encoder=token_encoder,
            wallet_encoder=wallet_encoder,
            graph_updater=graph_updater,
            ohlc_embedder=ohlc_embedder,
            time_encoder=time_encoder,
            num_event_types=vocab.NUM_EVENT_TYPES,
            multi_modal_dim=multi_modal_encoder.embedding_dim,
            event_pad_id=vocab.EVENT_TO_ID["__PAD__"],
            event_type_to_id=vocab.EVENT_TO_ID,
            quantiles=quantiles,
            horizons_seconds=horizons,
            dtype=dtype
        ).to(device)
        # NOTE(review): torch.load on an untrusted checkpoint can execute
        # arbitrary code (pickle); acceptable here only for locally produced
        # checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()

        with torch.no_grad():
            outputs = model(collated)
        preds = outputs["quantile_logits"]
        labels = collated["labels"]
        labels_mask = collated["labels_mask"]

        loss = quantile_pinball_loss(preds, labels, labels_mask, quantiles).item()
        print(f"Pinball loss (masked): {loss:.6f}")

        # Reshape the flat (horizon x quantile) vectors into grids for display.
        B = preds.shape[0]
        grid = preds.view(B, len(horizons), len(quantiles))
        label_grid = labels.view(B, len(horizons), len(quantiles))
        mask_grid = labels_mask.view(B, len(horizons), len(quantiles))

        for b in range(B):
            print(f"\nSample {b} predictions:")
            for h_idx, horizon in enumerate(horizons):
                pred_row = grid[b, h_idx]
                label_row = label_grid[b, h_idx]
                mask_row = mask_grid[b, h_idx]
                row_str = ", ".join(
                    f"q={quantiles[q_idx]:.2f}: pred={pred_row[q_idx].item():.6f}, "
                    f"label={label_row[q_idx].item():.6f}, mask={int(mask_row[q_idx].item())}"
                    for q_idx in range(len(quantiles))
                )
                print(f" Horizon {horizon:>4}s -> {row_str}")
    finally:
        neo4j_driver.close()


if __name__ == "__main__":
    main()
|
validate.sh
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Validate a trained Oracle checkpoint on a single token mint.
# Uses train.yaml for model/data settings; the token address and cutoff
# below override the config. (No comments between lines: backslash
# continuations would break.)
python validate.py \
    --config train.yaml \
    --checkpoint checkpoints/oracle_checkpoint.pt \
    --t-cutoff-seconds 240 \
    --token-address 'czaE9hrSWJ6g21bxS6qh9GbbczoRa5F5Lx2eo1apump'
|
| 6 |
+
|