Upload folder using huggingface_hub
Browse files
- .ipynb_checkpoints/.gitignore-checkpoint +19 -0
- data/.ipynb_checkpoints/data_fetcher-checkpoint.py +1263 -0
- data/data_collator.py +13 -4
- data/ohlc_stats.npz +1 -1
- log.log +2 -2
- models/graph_updater.py +5 -6
- sample_12LJX4a83B4tCuZ1_3.json +0 -0
- scripts/.ipynb_checkpoints/cache_dataset-checkpoint.py +431 -0
- scripts/analyze_distribution.py +100 -0
- scripts/dump_cache_sample.py +146 -0
- train.py +2 -2
- train.sh +4 -4
.ipynb_checkpoints/.gitignore-checkpoint
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ignore the __pycache__ directory anywhere in the repository
|
| 2 |
+
__pycache__/
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Ignore the 'runs' directory anywhere in the repository, regardless of nesting
|
| 6 |
+
runs/
|
| 7 |
+
|
| 8 |
+
data/pump_fun
|
| 9 |
+
data/cache
|
| 10 |
+
.env
|
| 11 |
+
|
| 12 |
+
data/cache
|
| 13 |
+
.tmp/
|
| 14 |
+
.cache/
|
| 15 |
+
checkpoints/
|
| 16 |
+
metadata/
|
| 17 |
+
store/
|
| 18 |
+
preprocessed_configs/
|
| 19 |
+
.early.coverage
|
data/.ipynb_checkpoints/data_fetcher-checkpoint.py
ADDED
|
@@ -0,0 +1,1263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# data_fetcher.py
|
| 2 |
+
|
| 3 |
+
from typing import List, Dict, Any, Tuple, Set, Optional
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
import datetime, time
|
| 6 |
+
|
| 7 |
+
# We need the vocabulary for mapping IDs
|
| 8 |
+
import models.vocabulary as vocab
|
| 9 |
+
|
| 10 |
+
class DataFetcher:
|
| 11 |
+
"""
|
| 12 |
+
A dedicated class to handle all database queries for ClickHouse and Neo4j.
|
| 13 |
+
This keeps data fetching logic separate from the dataset and model.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
# --- Explicit column definitions for wallet profile & social fetches ---
|
| 17 |
+
PROFILE_BASE_COLUMNS = [
|
| 18 |
+
'wallet_address',
|
| 19 |
+
'updated_at',
|
| 20 |
+
'first_seen_ts',
|
| 21 |
+
'last_seen_ts',
|
| 22 |
+
'tags',
|
| 23 |
+
'deployed_tokens',
|
| 24 |
+
'funded_from',
|
| 25 |
+
'funded_timestamp',
|
| 26 |
+
'funded_signature',
|
| 27 |
+
'funded_amount'
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
PROFILE_METRIC_COLUMNS = [
|
| 31 |
+
'balance',
|
| 32 |
+
'transfers_in_count',
|
| 33 |
+
'transfers_out_count',
|
| 34 |
+
'spl_transfers_in_count',
|
| 35 |
+
'spl_transfers_out_count',
|
| 36 |
+
'total_buys_count',
|
| 37 |
+
'total_sells_count',
|
| 38 |
+
'total_winrate',
|
| 39 |
+
'stats_1d_realized_profit_sol',
|
| 40 |
+
'stats_1d_realized_profit_usd',
|
| 41 |
+
'stats_1d_realized_profit_pnl',
|
| 42 |
+
'stats_1d_buy_count',
|
| 43 |
+
'stats_1d_sell_count',
|
| 44 |
+
'stats_1d_transfer_in_count',
|
| 45 |
+
'stats_1d_transfer_out_count',
|
| 46 |
+
'stats_1d_avg_holding_period',
|
| 47 |
+
'stats_1d_total_bought_cost_sol',
|
| 48 |
+
'stats_1d_total_bought_cost_usd',
|
| 49 |
+
'stats_1d_total_sold_income_sol',
|
| 50 |
+
'stats_1d_total_sold_income_usd',
|
| 51 |
+
'stats_1d_total_fee',
|
| 52 |
+
'stats_1d_winrate',
|
| 53 |
+
'stats_1d_tokens_traded',
|
| 54 |
+
'stats_7d_realized_profit_sol',
|
| 55 |
+
'stats_7d_realized_profit_usd',
|
| 56 |
+
'stats_7d_realized_profit_pnl',
|
| 57 |
+
'stats_7d_buy_count',
|
| 58 |
+
'stats_7d_sell_count',
|
| 59 |
+
'stats_7d_transfer_in_count',
|
| 60 |
+
'stats_7d_transfer_out_count',
|
| 61 |
+
'stats_7d_avg_holding_period',
|
| 62 |
+
'stats_7d_total_bought_cost_sol',
|
| 63 |
+
'stats_7d_total_bought_cost_usd',
|
| 64 |
+
'stats_7d_total_sold_income_sol',
|
| 65 |
+
'stats_7d_total_sold_income_usd',
|
| 66 |
+
'stats_7d_total_fee',
|
| 67 |
+
'stats_7d_winrate',
|
| 68 |
+
'stats_7d_tokens_traded',
|
| 69 |
+
'stats_30d_realized_profit_sol',
|
| 70 |
+
'stats_30d_realized_profit_usd',
|
| 71 |
+
'stats_30d_realized_profit_pnl',
|
| 72 |
+
'stats_30d_buy_count',
|
| 73 |
+
'stats_30d_sell_count',
|
| 74 |
+
'stats_30d_transfer_in_count',
|
| 75 |
+
'stats_30d_transfer_out_count',
|
| 76 |
+
'stats_30d_avg_holding_period',
|
| 77 |
+
'stats_30d_total_bought_cost_sol',
|
| 78 |
+
'stats_30d_total_bought_cost_usd',
|
| 79 |
+
'stats_30d_total_sold_income_sol',
|
| 80 |
+
'stats_30d_total_sold_income_usd',
|
| 81 |
+
'stats_30d_total_fee',
|
| 82 |
+
'stats_30d_winrate',
|
| 83 |
+
'stats_30d_tokens_traded'
|
| 84 |
+
]
|
| 85 |
+
|
| 86 |
+
PROFILE_COLUMNS_FOR_QUERY = PROFILE_BASE_COLUMNS + PROFILE_METRIC_COLUMNS
|
| 87 |
+
|
| 88 |
+
SOCIAL_COLUMNS_FOR_QUERY = [
|
| 89 |
+
'wallet_address',
|
| 90 |
+
'pumpfun_username',
|
| 91 |
+
'twitter_username',
|
| 92 |
+
'telegram_channel',
|
| 93 |
+
'kolscan_name',
|
| 94 |
+
'cabalspy_name',
|
| 95 |
+
'axiom_kol_name'
|
| 96 |
+
]
|
| 97 |
+
def __init__(self, clickhouse_client: Any, neo4j_driver: Any):
    """Store the database handles used by all fetch_* methods.

    Args:
        clickhouse_client: Connected ClickHouse client exposing ``execute``.
        neo4j_driver: Neo4j driver exposing ``session``.
    """
    # Keep the handles under the internal names the fetch methods expect.
    self.graph_client = neo4j_driver
    self.db_client = clickhouse_client
    print("DataFetcher instantiated.")
|
| 101 |
+
|
| 102 |
+
def get_all_mints(self, start_date: Optional[datetime.datetime] = None) -> List[Dict[str, Any]]:
    """
    Fetches a list of all mint events to serve as dataset samples.

    Can be filtered to only include mints on or after a given start_date.

    Args:
        start_date: If given, only rows with ``timestamp >= start_date`` are
            returned (bound as a query parameter).

    Returns:
        One dict per row of the ``mints`` table. On query failure a single
        mock record is returned as a development fallback.
    """
    query = "SELECT mint_address, timestamp, creator_address, protocol, token_name, token_symbol, token_uri, total_supply, token_decimals FROM mints"
    params = {}
    where_clauses = []

    if start_date:
        where_clauses.append("timestamp >= %(start_date)s")
        params['start_date'] = start_date

    if where_clauses:
        query += " WHERE " + " AND ".join(where_clauses)

    print(f"INFO: Executing query to get all mints: `{query}` with params: {params}")
    try:
        rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
        if not rows:
            return []
        columns = [col[0] for col in columns_info]
        # FIX: removed a redundant second `if not result: return []` check —
        # a non-empty `rows` always produces a non-empty result list.
        return [dict(zip(columns, row)) for row in rows]
    except Exception as e:
        print(f"ERROR: Failed to fetch token addresses from ClickHouse: {e}")
        print("INFO: Falling back to mock token addresses for development.")
        return [{'mint_address': 'tknA_real', 'timestamp': datetime.datetime.now(datetime.timezone.utc), 'creator_address': 'addr_Creator_Real', 'protocol': 0}]
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def fetch_mint_record(self, token_address: str) -> Dict[str, Any]:
    """
    Fetches the raw mint record for a token from the 'mints' table.

    Args:
        token_address: Mint address of the token to look up.

    Returns:
        Dict with keys ``timestamp``, ``creator_address``, ``mint_address``,
        ``protocol``. On query failure a mock record is returned as a
        development fallback.
    """
    # SECURITY FIX: the address used to be f-string-interpolated straight
    # into the SQL text, which was injectable. It is now a bound parameter.
    query = "SELECT timestamp, creator_address, mint_address, protocol FROM mints WHERE mint_address = %(token_address)s ORDER BY timestamp ASC LIMIT 1"
    params = {'token_address': token_address}
    print(f"INFO: Executing query to fetch mint record: `{query}`")

    # Column names must stay in the same order as the SELECT list above.
    columns = ['timestamp', 'creator_address', 'mint_address', 'protocol']
    try:
        result = self.db_client.execute(query, params)

        if not result or not result[0]:
            raise ValueError(f"No mint event found for token {token_address}")

        # Convert the tuple result into a dictionary
        return dict(zip(columns, result[0]))
    except Exception as e:
        print(f"ERROR: Failed to fetch mint record for {token_address}: {e}")
        print("INFO: Falling back to mock mint record for development.")
        # Fallback for development if DB connection fails
        return {
            'timestamp': datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(days=1),
            'creator_address': 'addr_Creator_Real',
            'mint_address': token_address,
            'protocol': vocab.PROTOCOL_TO_ID.get("Pump V1", 0)
        }
|
| 163 |
+
|
| 164 |
+
def fetch_wallet_profiles(self, wallet_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, Dict[str, Any]]:
    """
    Convenience wrapper around fetch_wallet_profiles_and_socials for profile-only data.
    """
    # Delegate to the combined fetch and discard the socials half.
    profile_map, _socials = self.fetch_wallet_profiles_and_socials(wallet_addresses, T_cutoff)
    return profile_map
|
| 170 |
+
|
| 171 |
+
def fetch_wallet_socials(self, wallet_addresses: List[str]) -> Dict[str, Dict[str, Any]]:
    """
    Fetches wallet social records for a list of wallet addresses.
    Batches queries to avoid "Max query size exceeded" errors.
    Returns a dictionary mapping wallet_address to its social data.
    """
    if not wallet_addresses:
        return {}

    BATCH_SIZE = 1000
    socials = {}
    total_wallets = len(wallet_addresses)
    print(f"INFO: Executing query to fetch wallet socials for {total_wallets} wallets in batches of {BATCH_SIZE}.")

    for offset in range(0, total_wallets, BATCH_SIZE):
        chunk = wallet_addresses[offset : offset + BATCH_SIZE]

        try:
            rows, columns_info = self.db_client.execute(
                "SELECT * FROM wallet_socials WHERE wallet_address IN %(addresses)s",
                {'addresses': chunk},
                with_column_types=True,
            )
        except Exception as e:
            # Best-effort: log and move on to the next batch.
            print(f"ERROR: Failed to fetch wallet socials for batch {offset}: {e}")
            continue

        if not rows:
            continue

        col_names = [meta[0] for meta in columns_info]
        for values in rows:
            record = dict(zip(col_names, values))
            addr = record.get('wallet_address')
            if addr:
                socials[addr] = record

    return socials
|
| 208 |
+
|
| 209 |
+
def fetch_wallet_profiles_and_socials(self,
                                      wallet_addresses: List[str],
                                      T_cutoff: datetime.datetime) -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]]:
    """
    Fetches wallet profiles (time-aware) and socials for all requested wallets.
    Batches queries to avoid "Max query size exceeded" errors.

    Args:
        wallet_addresses: Wallets to look up.
        T_cutoff: Only metric rows with ``updated_at <= T_cutoff`` are used.

    Returns:
        Two dictionaries keyed by wallet_address: (profiles, socials). A
        wallet appears in a dict only if at least one of its fetched columns
        was non-NULL.
    """
    if not wallet_addresses:
        return {}, {}

    # Column sets are class-level constants (PROFILE_* / SOCIAL_* above).
    social_columns = self.SOCIAL_COLUMNS_FOR_QUERY
    profile_base_cols = self.PROFILE_BASE_COLUMNS
    profile_metric_cols = self.PROFILE_METRIC_COLUMNS

    # Comma-joined projection strings used inside the CTEs below.
    profile_base_str = ",\n ".join(profile_base_cols)
    metric_projection_cols = ['wallet_address', 'updated_at'] + profile_metric_cols
    profile_metric_str = ",\n ".join(metric_projection_cols)

    # wallet_address is selected once from requested_wallets, so it is
    # excluded from the per-table projections to avoid duplicate columns.
    profile_base_select_cols = [col for col in profile_base_cols if col != 'wallet_address']
    profile_metric_select_cols = [
        col for col in profile_metric_cols if col not in ('wallet_address',)
    ]
    social_select_cols = [col for col in social_columns if col != 'wallet_address']

    # Build "alias.col AS prefix__col" expressions; the profile__ / social__
    # prefixes let us split one joined row back into two dicts afterwards.
    select_expressions = []
    for col in profile_base_select_cols:
        select_expressions.append(f"lp.{col} AS profile__{col}")
    for col in profile_metric_select_cols:
        select_expressions.append(f"lm.{col} AS profile__{col}")
    for col in social_select_cols:
        select_expressions.append(f"ws.{col} AS social__{col}")
    select_clause = ""
    if select_expressions:
        select_clause = ",\n " + ",\n ".join(select_expressions)

    # Prefixed result-column names expected back from the query.
    profile_keys = [f"profile__{col}" for col in (profile_base_select_cols + profile_metric_select_cols)]
    social_keys = [f"social__{col}" for col in social_select_cols]

    BATCH_SIZE = 1000
    all_profiles = {}
    all_socials = {}

    total_wallets = len(wallet_addresses)
    print(f"INFO: Fetching profiles+socials for {total_wallets} wallets in batches of {BATCH_SIZE}...")

    for i in range(0, total_wallets, BATCH_SIZE):
        batch_addresses = wallet_addresses[i : i + BATCH_SIZE]

        # One joined query per batch:
        #   latest_profiles  — newest wallet_profiles row per wallet.
        #     NOTE(review): this CTE is NOT filtered by T_cutoff (only the
        #     metrics CTE is) — confirm profile base data may come from
        #     after the cutoff.
        #   latest_metrics   — newest wallet_profile_metrics row per wallet
        #                      at or before T_cutoff.
        #   requested_wallets — arrayJoin over the input list so wallets with
        #                       no data still produce a row via LEFT JOINs.
        query = f"""
        WITH ranked_profiles AS (
            SELECT
                {profile_base_str},
                ROW_NUMBER() OVER (PARTITION BY wallet_address ORDER BY updated_at DESC) AS rn
            FROM wallet_profiles
            WHERE wallet_address IN %(addresses)s
        ),
        latest_profiles AS (
            SELECT
                {profile_base_str}
            FROM ranked_profiles
            WHERE rn = 1
        ),
        ranked_metrics AS (
            SELECT
                {profile_metric_str},
                ROW_NUMBER() OVER (PARTITION BY wallet_address ORDER BY updated_at DESC) AS rn
            FROM wallet_profile_metrics
            WHERE
                wallet_address IN %(addresses)s
                AND updated_at <= %(T_cutoff)s
        ),
        latest_metrics AS (
            SELECT
                {profile_metric_str}
            FROM ranked_metrics
            WHERE rn = 1
        ),
        requested_wallets AS (
            SELECT DISTINCT wallet_address
            FROM (SELECT arrayJoin(%(addresses)s) AS wallet_address)
        )
        SELECT
            rw.wallet_address AS wallet_address
            {select_clause}
        FROM requested_wallets AS rw
        LEFT JOIN latest_profiles AS lp ON rw.wallet_address = lp.wallet_address
        LEFT JOIN latest_metrics AS lm ON rw.wallet_address = lm.wallet_address
        LEFT JOIN wallet_socials AS ws ON rw.wallet_address = ws.wallet_address;
        """

        params = {'addresses': batch_addresses, 'T_cutoff': T_cutoff}

        try:
            rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
            if not rows:
                continue

            columns = [col[0] for col in columns_info]

            for row in rows:
                row_dict = dict(zip(columns, row))
                wallet_addr = row_dict.get('wallet_address')
                if not wallet_addr:
                    continue

                # Collect profile__ columns back into an unprefixed dict.
                profile_data = {}
                if profile_keys:
                    for pref_key in profile_keys:
                        if pref_key in row_dict:
                            value = row_dict[pref_key]
                            profile_data[pref_key.replace('profile__', '')] = value

                # Keep the wallet only if at least one value was non-NULL
                # (LEFT JOIN misses come back as all-None rows).
                if profile_data and any(value is not None for value in profile_data.values()):
                    profile_data['wallet_address'] = wallet_addr
                    all_profiles[wallet_addr] = profile_data

                # Same split for the social__ columns.
                social_data = {}
                if social_keys:
                    for pref_key in social_keys:
                        if pref_key in row_dict:
                            value = row_dict[pref_key]
                            social_data[pref_key.replace('social__', '')] = value

                if social_data and any(value is not None for value in social_data.values()):
                    social_data['wallet_address'] = wallet_addr
                    all_socials[wallet_addr] = social_data

        except Exception as e:
            print(f"ERROR: Combined profile/social query failed for batch {i}-{i+BATCH_SIZE}: {e}")
            # We continue to the next batch

    return all_profiles, all_socials
|
| 342 |
+
|
| 343 |
+
def fetch_wallet_holdings(self, wallet_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, List[Dict[str, Any]]]:
    """
    Fetches top 2 wallet holding records for a list of wallet addresses that were active at T_cutoff.
    Batches queries to avoid "Max query size exceeded" errors.
    Returns a dictionary mapping wallet_address to a LIST of its holding data.
    """
    if not wallet_addresses:
        return {}

    BATCH_SIZE = 1000
    holdings: Dict[str, List[Dict[str, Any]]] = {}
    total_wallets = len(wallet_addresses)
    print(f"INFO: Executing query to fetch wallet holdings for {total_wallets} wallets in batches of {BATCH_SIZE}.")

    # Time-aware query, same for every batch (only params change):
    # 1. For each holding, find the latest state at or before T_cutoff.
    # 2. Filter for holdings where the balance was greater than 0.
    # 3. Rank these active holdings by USD volume and take the top 2 per wallet.
    query = """
        WITH point_in_time_holdings AS (
            SELECT
                *,
                COALESCE(history_bought_cost_sol, 0) + COALESCE(history_sold_income_sol, 0) AS total_volume_usd,
                ROW_NUMBER() OVER(PARTITION BY wallet_address, mint_address ORDER BY updated_at DESC) as rn_per_holding
            FROM wallet_holdings
            WHERE
                wallet_address IN %(addresses)s
                AND updated_at <= %(T_cutoff)s
        ),
        ranked_active_holdings AS (
            SELECT *,
                ROW_NUMBER() OVER(PARTITION BY wallet_address ORDER BY total_volume_usd DESC) as rn_per_wallet
            FROM point_in_time_holdings
            WHERE rn_per_holding = 1 AND current_balance > 0
        )
        SELECT *
        FROM ranked_active_holdings
        WHERE rn_per_wallet <= 2;
    """

    for offset in range(0, total_wallets, BATCH_SIZE):
        chunk = wallet_addresses[offset : offset + BATCH_SIZE]

        try:
            rows, columns_info = self.db_client.execute(
                query,
                {'addresses': chunk, 'T_cutoff': T_cutoff},
                with_column_types=True,
            )
        except Exception as e:
            # Best-effort: log and continue with the next batch.
            print(f"ERROR: Failed to fetch wallet holdings for batch {offset}: {e}")
            continue

        if not rows:
            continue

        col_names = [meta[0] for meta in columns_info]
        for values in rows:
            record = dict(zip(col_names, values))
            addr = record.get('wallet_address')
            if addr:
                holdings.setdefault(addr, []).append(record)

    return holdings
|
| 404 |
+
|
| 405 |
+
def fetch_graph_links(self,
                      initial_addresses: List[str],
                      T_cutoff: datetime.datetime,
                      max_degrees: int = 1) -> Tuple[Dict[str, str], Dict[str, Dict[str, Any]]]:
    """
    Fetches graph links from Neo4j, traversing up to a max degree of separation.

    Only relationships with `r.timestamp <= T_cutoff` are followed, so the
    returned subgraph is point-in-time consistent with the other fetchers.

    Args:
        initial_addresses: A list of starting wallet or token addresses.
        T_cutoff: Cutoff converted to a unix timestamp for the Cypher filter.
        max_degrees: The maximum number of hops to traverse in the graph.

    Returns:
        A tuple containing:
        - A dictionary mapping entity addresses to their type ('Wallet' or 'Token').
        - A dictionary of aggregated links, structured for the GraphUpdater.

    Raises:
        RuntimeError: With a "FATAL:" prefix when Neo4j fails with a
            non-retryable error or all retries are exhausted.
    """
    if not initial_addresses:
        return {}, {}

    cutoff_ts = int(T_cutoff.timestamp())

    print(f"INFO: Fetching graph links up to {max_degrees} degrees for {len(initial_addresses)} initial entities...")

    # Retry with exponential backoff on rate-limit / transient Neo4j errors.
    max_retries = 3
    backoff_sec = 2

    for attempt in range(max_retries + 1):
        try:
            with self.graph_client.session() as session:
                all_entities = {addr: 'Token' for addr in initial_addresses}  # Assume initial are tokens
                newly_found_entities = set(initial_addresses)
                aggregated_links = defaultdict(lambda: {'links': [], 'edges': []})

                # Breadth-first expansion: each iteration traverses one more
                # hop from the frontier found in the previous iteration.
                for i in range(max_degrees):
                    if not newly_found_entities:
                        break

                    print(f" - Degree {i+1}: Traversing from {len(newly_found_entities)} new entities...")

                    # --- TIMING: Query execution ---
                    _t_query_start = time.perf_counter()

                    # Cypher query to find direct neighbors of the current frontier
                    # OPTIMIZED: Filter by timestamp IN Neo4j to avoid transferring 97%+ unused records
                    # NOTE(review): LIMIT 10000 silently truncates very dense frontiers.
                    query = """
                    MATCH (a)-[r]-(b)
                    WHERE a.address IN $addresses AND r.timestamp <= $cutoff_ts
                    RETURN a.address AS source_address, type(r) AS link_type, properties(r) AS link_props, b.address AS dest_address, labels(b)[0] AS dest_type
                    LIMIT 10000
                    """
                    params = {'addresses': list(newly_found_entities), 'cutoff_ts': cutoff_ts}
                    result = session.run(query, params)

                    _t_query_done = time.perf_counter()

                    # --- TIMING: Result processing ---
                    _t_process_start = time.perf_counter()
                    records_total = 0

                    current_degree_new_entities = set()
                    for record in result:
                        records_total += 1
                        link_type = record['link_type']
                        link_props = dict(record['link_props'])
                        source_addr = record['source_address']
                        dest_addr = record['dest_address']
                        dest_type = record['dest_type']

                        # Add the link and edge data
                        aggregated_links[link_type]['links'].append(link_props)
                        aggregated_links[link_type]['edges'].append((source_addr, dest_addr))

                        # If we found a new entity, add it to the set for the next iteration
                        if dest_addr not in all_entities.keys():
                            current_degree_new_entities.add(dest_addr)
                            all_entities[dest_addr] = dest_type

                    _t_process_done = time.perf_counter()

                    # --- TIMING: Print detailed stats ---
                    print(f" [NEO4J TIMING] query_exec: {(_t_query_done - _t_query_start)*1000:.1f}ms, "
                          f"result_process: {(_t_process_done - _t_process_start)*1000:.1f}ms")
                    print(f" [NEO4J STATS] records_returned: {records_total}, "
                          f"new_entities: {len(current_degree_new_entities)}")

                    newly_found_entities = current_degree_new_entities

                # --- Post-process: rename, map props, strip, cap ---
                MAX_LINKS_PER_TYPE = 500

                # Neo4j type -> collator type name
                _NEO4J_TO_COLLATOR_NAME = {
                    'TRANSFERRED_TO': 'TransferLink',
                    'BUNDLE_TRADE': 'BundleTradeLink',
                    'COPIED_TRADE': 'CopiedTradeLink',
                    'COORDINATED_ACTIVITY': 'CoordinatedActivityLink',
                    'SNIPED': 'SnipedLink',
                    'MINTED': 'MintedLink',
                    'LOCKED_SUPPLY': 'LockedSupplyLink',
                    'BURNED': 'BurnedLink',
                    'PROVIDED_LIQUIDITY': 'ProvidedLiquidityLink',
                    'WHALE_OF': 'WhaleOfLink',
                    'TOP_TRADER_OF': 'TopTraderOfLink',
                }

                # Neo4j prop name -> encoder prop name (for fields with mismatched names)
                _PROP_REMAP = {
                    'CopiedTradeLink': {
                        'buy_gap': 'time_gap_on_buy_sec',
                        'sell_gap': 'time_gap_on_sell_sec',
                        'f_buy_total': 'follower_buy_total',
                        'f_sell_total': 'follower_sell_total',
                        'leader_pnl': 'leader_pnl',
                        'follower_pnl': 'follower_pnl',
                    },
                }

                # Only keep fields each encoder actually reads
                _NEEDED_FIELDS = {
                    'TransferLink': ['amount', 'mint'],
                    'BundleTradeLink': ['signatures'],  # Neo4j has no total_amount; we derive it below
                    'CopiedTradeLink': ['time_gap_on_buy_sec', 'time_gap_on_sell_sec', 'leader_pnl', 'follower_pnl', 'follower_buy_total', 'follower_sell_total'],
                    'CoordinatedActivityLink': ['time_gap_on_first_sec', 'time_gap_on_second_sec'],
                    'SnipedLink': ['rank', 'sniped_amount'],
                    'MintedLink': ['buy_amount'],
                    'LockedSupplyLink': ['amount'],
                    'BurnedLink': ['amount'],
                    'ProvidedLiquidityLink': ['amount_quote'],
                    'WhaleOfLink': ['holding_pct_at_creation'],
                    'TopTraderOfLink': ['pnl_at_creation'],
                }

                cleaned_links = {}
                for neo4j_type, data in aggregated_links.items():
                    collator_name = _NEO4J_TO_COLLATOR_NAME.get(neo4j_type)
                    if not collator_name:
                        continue  # Skip unknown link types

                    links = data['links']
                    edges = data['edges']

                    # Cap (applied across all degrees combined)
                    links = links[:MAX_LINKS_PER_TYPE]
                    edges = edges[:MAX_LINKS_PER_TYPE]

                    # Remap property names if needed
                    remap = _PROP_REMAP.get(collator_name)
                    if remap:
                        links = [{remap.get(k, k): v for k, v in l.items()} for l in links]

                    # Strip to only needed fields; missing fields default to 0
                    needed = _NEEDED_FIELDS.get(collator_name, [])
                    links = [{f: l.get(f, 0) for f in needed} for l in links]

                    # BundleTradeLink: Neo4j has no total_amount; derive from signatures count
                    if collator_name == 'BundleTradeLink':
                        links = [{'total_amount': len(l.get('signatures', []) if isinstance(l.get('signatures'), list) else [])} for l in links]

                    cleaned_links[collator_name] = {'links': links, 'edges': edges}

                return all_entities, cleaned_links

        except Exception as e:
            # Classify the failure by message text since driver exception
            # classes may vary; only rate-limit / transient errors retry.
            msg = str(e)
            is_rate_limit = "AuthenticationRateLimit" in msg or "RateLimit" in msg
            is_transient = "ServiceUnavailable" in msg or "TransientError" in msg or "SessionExpired" in msg

            if is_rate_limit or is_transient:
                if attempt < max_retries:
                    sleep_time = backoff_sec * (2 ** attempt)
                    print(f"WARN: Neo4j error ({type(e).__name__}). Retrying in {sleep_time}s... (Attempt {attempt+1}/{max_retries})")
                    time.sleep(sleep_time)
                    continue

            # If we're here, it's either not retryable or we ran out of retries
            # Ensure we use "FATAL" prefix so the caller knows to stop if required
            raise RuntimeError(f"FATAL: Failed to fetch graph links from Neo4j: {e}") from e
|
| 582 |
+
|
| 583 |
+
def fetch_token_data(self, token_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, Dict[str, Any]]:
    """Fetch the latest known row per token at or before T_cutoff.

    The address list is processed in fixed-size batches so that no single
    query exceeds the server's maximum query size. Returns a mapping of
    token_address -> row dict (each row also carries an 'address' alias
    expected by the collator).
    """
    if not token_addresses:
        return {}

    batch_size = 1000
    out: Dict[str, Dict[str, Any]] = {}
    n_total = len(token_addresses)
    print(f"INFO: Executing query to fetch token data for {n_total} tokens in batches of {batch_size}.")

    # Point-in-time lookup: keep only the newest row (by updated_at) per token.
    sql = """
        WITH ranked_tokens AS (
            SELECT
                *,
                ROW_NUMBER() OVER (PARTITION BY token_address ORDER BY updated_at DESC) as rn
            FROM tokens
            WHERE
                token_address IN %(addresses)s
                AND updated_at <= %(T_cutoff)s
        )
        SELECT token_address, name, symbol, token_uri, protocol, total_supply, decimals
        FROM ranked_tokens
        WHERE rn = 1;
    """

    for start in range(0, n_total, batch_size):
        chunk = token_addresses[start:start + batch_size]
        bind = {'addresses': chunk, 'T_cutoff': T_cutoff}

        try:
            raw_rows, col_meta = self.db_client.execute(sql, bind, with_column_types=True)
            if not raw_rows:
                continue

            field_names = [meta[0] for meta in col_meta]
            for raw in raw_rows:
                record = dict(zip(field_names, raw))
                addr = record.get('token_address')
                if addr:
                    # The schema column is 'token_address' but the collator
                    # expects 'address'; mirror it for compatibility.
                    record['address'] = addr
                    out[addr] = record
        except Exception as e:
            # Best effort: log and move on to the next batch.
            print(f"ERROR: Failed to fetch token data for batch {start}: {e}")

    return out
|
| 639 |
+
|
| 640 |
+
def fetch_deployed_token_details(self, token_addresses: List[str], T_cutoff: datetime.datetime) -> Dict[str, Dict[str, Any]]:
    """Fetch point-in-time details for deployed tokens at or before T_cutoff.

    Joins the latest token row with the latest ath_price_usd metric and
    derives a has_migrated flag. Works in fixed-size address batches to
    avoid oversized queries. Returns a mapping keyed by token_address.
    """
    if not token_addresses:
        return {}

    batch_size = 1000
    details: Dict[str, Dict[str, Any]] = {}
    n_total = len(token_addresses)
    print(f"INFO: Executing query to fetch deployed token details for {n_total} tokens in batches of {batch_size}.")

    # Point-in-time query: latest token row + latest metric row, joined.
    sql = """
        WITH ranked_tokens AS (
            SELECT
                *,
                ROW_NUMBER() OVER (PARTITION BY token_address ORDER BY updated_at DESC) as rn
            FROM tokens
            WHERE
                token_address IN %(addresses)s
                AND updated_at <= %(T_cutoff)s
        ),
        ranked_token_metrics AS (
            SELECT
                token_address,
                ath_price_usd,
                ROW_NUMBER() OVER (PARTITION BY token_address ORDER BY updated_at DESC) as rn
            FROM token_metrics
            WHERE
                token_address IN %(addresses)s
                AND updated_at <= %(T_cutoff)s
        ),
        latest_tokens AS (
            SELECT *
            FROM ranked_tokens
            WHERE rn = 1
        ),
        latest_token_metrics AS (
            SELECT *
            FROM ranked_token_metrics
            WHERE rn = 1
        )
        SELECT
            lt.token_address,
            lt.created_at,
            lt.updated_at,
            ltm.ath_price_usd,
            lt.total_supply,
            lt.decimals,
            (lt.launchpad != lt.protocol) AS has_migrated
        FROM latest_tokens AS lt
        LEFT JOIN latest_token_metrics AS ltm
            ON lt.token_address = ltm.token_address;
    """

    for start in range(0, n_total, batch_size):
        chunk = token_addresses[start:start + batch_size]
        bind = {'addresses': chunk, 'T_cutoff': T_cutoff}

        try:
            raw_rows, col_meta = self.db_client.execute(sql, bind, with_column_types=True)
            if not raw_rows:
                continue

            field_names = [meta[0] for meta in col_meta]
            for raw in raw_rows:
                # First projected column is token_address; key the dict on it.
                details[raw[0]] = dict(zip(field_names, raw))
        except Exception as e:
            print(f"ERROR: Failed to fetch deployed token details for batch {start}: {e}")
            # Continue with the next batch.

    return details
|
| 714 |
+
|
| 715 |
+
def fetch_trades_for_token(self, token_address: str, T_cutoff: datetime.datetime, count_threshold: int, early_limit: int, recent_limit: int, full_history: bool = False) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Fetch every trade for a token up to T_cutoff, in chronological order.

    Sequence-length control happens downstream (event-level head/tail
    sampling in data_loader.py), so no fetch-time H/B/H sampling is applied
    here; the legacy limit parameters exist only for signature
    compatibility.

    Returns:
        (all_trades, [], []) -- the two trailing lists are legacy slots.
    """
    if not token_address:
        return [], [], []

    sql = "SELECT * FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s ORDER BY timestamp ASC"
    bind = {'token_address': token_address, 'T_cutoff': T_cutoff}
    try:
        raw_rows, col_meta = self.db_client.execute(sql, bind, with_column_types=True)
        if not raw_rows:
            return [], [], []
        field_names = [meta[0] for meta in col_meta]
        trades = [dict(zip(field_names, raw)) for raw in raw_rows]
        return trades, [], []
    except Exception as e:
        print(f"ERROR: Failed to fetch trades for token {token_address}: {e}")
        return [], [], []
|
| 741 |
+
|
| 742 |
+
def fetch_future_trades_for_token(self,
                                  token_address: str,
                                  start_ts: datetime.datetime,
                                  end_ts: datetime.datetime) -> List[Dict[str, Any]]:
    """Fetch successful trades in the half-open window (start_ts, end_ts].

    Used to construct label targets beyond the training cutoff. Returns []
    for a missing token, an empty/inverted window, or on query failure.
    """
    window_invalid = (not token_address) or start_ts is None or end_ts is None or start_ts >= end_ts
    if window_invalid:
        return []

    sql = """
        SELECT *
        FROM trades
        WHERE base_address = %(token_address)s
          AND success = true
          AND timestamp > %(start_ts)s
          AND timestamp <= %(end_ts)s
        ORDER BY timestamp ASC
    """
    bind = {
        'token_address': token_address,
        'start_ts': start_ts,
        'end_ts': end_ts,
    }

    try:
        raw_rows, col_meta = self.db_client.execute(sql, bind, with_column_types=True)
        if not raw_rows:
            return []
        field_names = [meta[0] for meta in col_meta]
        return [dict(zip(field_names, raw)) for raw in raw_rows]
    except Exception as e:
        print(f"ERROR: Failed to fetch future trades for token {token_address}: {e}")
        return []
|
| 777 |
+
|
| 778 |
+
def fetch_transfers_for_token(self, token_address: str, T_cutoff: datetime.datetime, min_amount_threshold: float = 10_000_000) -> List[Dict[str, Any]]:
    """Fetch all transfers of a token up to T_cutoff.

    Transfers below min_amount_threshold (in decimal units) are filtered
    out server-side; pass 0.0 to fetch everything.
    """
    if not token_address:
        return []

    sql = """
        SELECT * FROM transfers
        WHERE mint_address = %(token_address)s
          AND timestamp <= %(T_cutoff)s
          AND amount_decimal >= %(min_amount)s
        ORDER BY timestamp ASC
    """
    bind = {'token_address': token_address, 'T_cutoff': T_cutoff, 'min_amount': min_amount_threshold}
    print(f"INFO: Fetching significant transfers for {token_address} (amount >= {min_amount_threshold}).")

    try:
        raw_rows, col_meta = self.db_client.execute(sql, bind, with_column_types=True)
        if not raw_rows:
            return []
        field_names = [meta[0] for meta in col_meta]
        return [dict(zip(field_names, raw)) for raw in raw_rows]
    except Exception as e:
        print(f"ERROR: Failed to fetch transfers for token {token_address}: {e}")
        return []
|
| 804 |
+
|
| 805 |
+
def fetch_pool_creations_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
    """Fetch pool-creation events where the token is the base asset, up to T_cutoff."""
    if not token_address:
        return []

    sql = """
        SELECT
            signature,
            timestamp,
            slot,
            success,
            error,
            priority_fee,
            protocol,
            creator_address,
            pool_address,
            base_address,
            quote_address,
            lp_token_address,
            initial_base_liquidity,
            initial_quote_liquidity,
            base_decimals,
            quote_decimals
        FROM pool_creations
        WHERE base_address = %(token_address)s
          AND timestamp <= %(T_cutoff)s
        ORDER BY timestamp ASC
    """
    bind = {'token_address': token_address, 'T_cutoff': T_cutoff}

    try:
        raw_rows, col_meta = self.db_client.execute(sql, bind, with_column_types=True)
        if not raw_rows:
            return []
        field_names = [meta[0] for meta in col_meta]
        return [dict(zip(field_names, raw)) for raw in raw_rows]
    except Exception as e:
        print(f"ERROR: Failed to fetch pool creations for token {token_address}: {e}")
        return []
|
| 847 |
+
|
| 848 |
+
def fetch_liquidity_changes_for_pools(self, pool_addresses: List[str], T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
    """Fetch liquidity add/remove events for the given pools up to T_cutoff."""
    if not pool_addresses:
        return []

    sql = """
        SELECT
            signature,
            timestamp,
            slot,
            success,
            error,
            priority_fee,
            protocol,
            change_type,
            lp_provider,
            pool_address,
            base_amount,
            quote_amount
        FROM liquidity
        WHERE pool_address IN %(pool_addresses)s
          AND timestamp <= %(T_cutoff)s
        ORDER BY timestamp ASC
    """
    bind = {'pool_addresses': pool_addresses, 'T_cutoff': T_cutoff}

    try:
        raw_rows, col_meta = self.db_client.execute(sql, bind, with_column_types=True)
        if not raw_rows:
            return []
        field_names = [meta[0] for meta in col_meta]
        return [dict(zip(field_names, raw)) for raw in raw_rows]
    except Exception as e:
        print(f"ERROR: Failed to fetch liquidity changes for pools {pool_addresses}: {e}")
        return []
|
| 886 |
+
|
| 887 |
+
def fetch_fee_collections_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
    """Fetch fee-collection events where the token is token_0 or token_1, up to T_cutoff."""
    if not token_address:
        return []

    sql = """
        SELECT
            timestamp,
            signature,
            slot,
            success,
            error,
            priority_fee,
            protocol,
            recipient_address,
            token_0_mint_address,
            token_0_amount,
            token_1_mint_address,
            token_1_amount
        FROM fee_collections
        WHERE (token_0_mint_address = %(token)s OR token_1_mint_address = %(token)s)
          AND timestamp <= %(T_cutoff)s
        ORDER BY timestamp ASC
    """
    bind = {'token': token_address, 'T_cutoff': T_cutoff}

    try:
        raw_rows, col_meta = self.db_client.execute(sql, bind, with_column_types=True)
        if not raw_rows:
            return []
        field_names = [meta[0] for meta in col_meta]
        return [dict(zip(field_names, raw)) for raw in raw_rows]
    except Exception as e:
        print(f"ERROR: Failed to fetch fee collections for token {token_address}: {e}")
        return []
|
| 925 |
+
|
| 926 |
+
def fetch_migrations_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
    """Fetch migration events (virtual pool -> real pool) for a token up to T_cutoff."""
    if not token_address:
        return []

    sql = """
        SELECT
            timestamp,
            signature,
            slot,
            success,
            error,
            priority_fee,
            protocol,
            mint_address,
            virtual_pool_address,
            pool_address,
            migrated_base_liquidity,
            migrated_quote_liquidity
        FROM migrations
        WHERE mint_address = %(token)s
          AND timestamp <= %(T_cutoff)s
        ORDER BY timestamp ASC
    """
    bind = {'token': token_address, 'T_cutoff': T_cutoff}

    try:
        raw_rows, col_meta = self.db_client.execute(sql, bind, with_column_types=True)
        if not raw_rows:
            return []
        field_names = [meta[0] for meta in col_meta]
        return [dict(zip(field_names, raw)) for raw in raw_rows]
    except Exception as e:
        print(f"ERROR: Failed to fetch migrations for token {token_address}: {e}")
        return []
|
| 962 |
+
|
| 963 |
+
def fetch_burns_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
    """Fetch burn events for a token up to T_cutoff.

    Schema: burns(timestamp, signature, slot, success, error, priority_fee,
    mint_address, source, amount, amount_decimal, source_balance)
    """
    if not token_address:
        return []

    sql = """
        SELECT
            timestamp,
            signature,
            slot,
            success,
            error,
            priority_fee,
            mint_address,
            source,
            amount,
            amount_decimal,
            source_balance
        FROM burns
        WHERE mint_address = %(token)s
          AND timestamp <= %(T_cutoff)s
        ORDER BY timestamp ASC
    """
    bind = {'token': token_address, 'T_cutoff': T_cutoff}

    try:
        raw_rows, col_meta = self.db_client.execute(sql, bind, with_column_types=True)
        if not raw_rows:
            return []
        field_names = [meta[0] for meta in col_meta]
        return [dict(zip(field_names, raw)) for raw in raw_rows]
    except Exception as e:
        print(f"ERROR: Failed to fetch burns for token {token_address}: {e}")
        return []
|
| 1001 |
+
|
| 1002 |
+
def fetch_supply_locks_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> List[Dict[str, Any]]:
    """Fetch supply-lock events for a token up to T_cutoff.

    Schema: supply_locks(timestamp, signature, slot, success, error,
    priority_fee, protocol, contract_address, sender, recipient,
    mint_address, total_locked_amount, final_unlock_timestamp)
    """
    if not token_address:
        return []

    sql = """
        SELECT
            timestamp,
            signature,
            slot,
            success,
            error,
            priority_fee,
            protocol,
            contract_address,
            sender,
            recipient,
            mint_address,
            total_locked_amount,
            final_unlock_timestamp
        FROM supply_locks
        WHERE mint_address = %(token)s
          AND timestamp <= %(T_cutoff)s
        ORDER BY timestamp ASC
    """
    bind = {'token': token_address, 'T_cutoff': T_cutoff}

    try:
        raw_rows, col_meta = self.db_client.execute(sql, bind, with_column_types=True)
        if not raw_rows:
            return []
        field_names = [meta[0] for meta in col_meta]
        return [dict(zip(field_names, raw)) for raw in raw_rows]
    except Exception as e:
        print(f"ERROR: Failed to fetch supply locks for token {token_address}: {e}")
        return []
|
| 1042 |
+
|
| 1043 |
+
def fetch_token_holders_for_snapshot(self, token_address: str, T_cutoff: datetime.datetime, limit: int = 200) -> List[Dict[str, Any]]:
    """Fetch the top holders of a token as of T_cutoff.

    Takes, per (wallet, mint) pair, the newest wallet_holdings row at or
    before T_cutoff, keeps positive balances only, and returns up to
    ``limit`` rows of {wallet_address, current_balance} ordered by balance
    descending.
    """
    if not token_address:
        return []

    sql = """
        WITH point_in_time_holdings AS (
            SELECT *,
                ROW_NUMBER() OVER(PARTITION BY wallet_address, mint_address ORDER BY updated_at DESC) as rn_per_holding
            FROM wallet_holdings
            WHERE mint_address = %(token)s AND updated_at <= %(T_cutoff)s
        )
        SELECT wallet_address, current_balance
        FROM point_in_time_holdings
        WHERE rn_per_holding = 1 AND current_balance > 0
        ORDER BY current_balance DESC
        LIMIT %(limit)s;
    """
    bind = {'token': token_address, 'T_cutoff': T_cutoff, 'limit': int(limit)}

    try:
        raw_rows, col_meta = self.db_client.execute(sql, bind, with_column_types=True)
        if not raw_rows:
            return []
        field_names = [meta[0] for meta in col_meta]
        return [dict(zip(field_names, raw)) for raw in raw_rows]
    except Exception as e:
        print(f"ERROR: Failed to fetch token holders for {token_address}: {e}")
        return []
|
| 1074 |
+
|
| 1075 |
+
def fetch_total_holders_count_for_token(self, token_address: str, T_cutoff: datetime.datetime) -> int:
    """Count wallets holding the token (current_balance > 0) as of T_cutoff.

    Uses the newest wallet_holdings row per (wallet, mint) pair at or
    before the cutoff. Returns 0 for a missing token or on query failure.
    """
    if not token_address:
        return 0

    sql = """
        WITH point_in_time_holdings AS (
            SELECT *,
                ROW_NUMBER() OVER(PARTITION BY wallet_address, mint_address ORDER BY updated_at DESC) as rn_per_holding
            FROM wallet_holdings
            WHERE mint_address = %(token)s AND updated_at <= %(T_cutoff)s
        )
        SELECT count()
        FROM point_in_time_holdings
        WHERE rn_per_holding = 1 AND current_balance > 0;
    """
    bind = {'token': token_address, 'T_cutoff': T_cutoff}

    try:
        result_rows = self.db_client.execute(sql, bind)
        return int(result_rows[0][0]) if result_rows else 0
    except Exception as e:
        print(f"ERROR: Failed to count total holders for token {token_address}: {e}")
        return 0
|
| 1103 |
+
|
| 1104 |
+
def fetch_holder_snapshot_stats_for_token(self, token_address: str, T_cutoff: datetime.datetime, limit: int = 200) -> Tuple[int, List[Dict[str, Any]]]:
    """Return point-in-time holder statistics for a token.

    Only the total holder count is computed (delegated to
    fetch_total_holders_count_for_token, which reads the indexed
    wallet_holdings table -- efficient because of the mint_address filter).

    Args:
        token_address: Mint address of the token; empty/None yields (0, []).
        T_cutoff: Point in time for the snapshot.
        limit: Kept for interface compatibility; currently unused because
            the top-holders list is not fetched.

    Returns:
        (holder_count, top_holders) -- top_holders is always an empty list
        in the current implementation.
    """
    if not token_address:
        return 0, []

    holder_count = self.fetch_total_holders_count_for_token(token_address, T_cutoff)
    return holder_count, []
|
| 1115 |
+
def fetch_raw_token_data(
    self,
    token_address: str,
    creator_address: str,
    mint_timestamp: datetime.datetime,
    max_horizon_seconds: int = 3600,
    include_wallet_data: bool = True,
    include_graph: bool = True,
    min_trades: int = 0,
    full_history: bool = False,
    prune_failed: bool = False,
    prune_transfers: bool = False
) -> Optional[Dict[str, Any]]:
    """Fetch ALL available data for a token up to mint + max_horizon_seconds.

    The result is agnostic of T_cutoff; it is masked/filtered dynamically
    during training. Wallet/graph data can be skipped to avoid caching
    T_cutoff-dependent features.

    Args:
        token_address: Mint address of the token.
        creator_address: Wallet that deployed the token (seed for the wallet set).
        mint_timestamp: Token mint time; the fetch window starts here.
        max_horizon_seconds: Window length; everything up to
            mint_timestamp + this horizon is fetched.
        include_wallet_data: Fetch profiles/socials/holdings for involved wallets.
        include_graph: Fetch Neo4j graph entities/links for involved wallets.
        min_trades: Skip the token (return None) if it has fewer trades than this.
        full_history: If True, fetch ALL trades ignoring legacy H/B/H limits.
        prune_failed: If True, drop failed trades from the result.
        prune_transfers: If True, skip fetching transfers entirely.

    Returns:
        A dict bundling all raw event streams and (optionally) wallet/graph
        data, or None when the token has fewer than min_trades trades.
    """
    # Absolute upper bound of the session; used as the "cutoff" for every
    # raw fetch so we never pull months of irrelevant data.
    max_limit_time = mint_timestamp + datetime.timedelta(seconds=max_horizon_seconds)

    # Legacy H/B/H limits are passed for signature compatibility only; when
    # full_history is True they are ignored inside the method.
    early_trades, middle_trades, recent_trades = self.fetch_trades_for_token(
        token_address, max_limit_time, 30000, 10000, 15000, full_history=full_history
    )

    # Deduplicate across the three legacy buckets, keyed by the unique
    # on-chain coordinates of each fill.
    all_trades = {}
    for t in early_trades + middle_trades + recent_trades:
        key = (t.get('slot'), t.get('transaction_index'), t.get('instruction_index'), t.get('signature'))
        all_trades[key] = t

    sorted_trades = sorted(all_trades.values(), key=lambda x: x['timestamp'])

    if prune_failed:
        # Keep only trades explicitly marked successful.
        sorted_trades = [t for t in sorted_trades if t.get('success', False)]

    if len(sorted_trades) < min_trades:
        print(f" SKIP: Token {token_address} has only {len(sorted_trades)} trades (min required: {min_trades}). skipping fetches.")
        return None

    # Other event streams.
    if prune_transfers:
        transfers = []
    else:
        transfers = self.fetch_transfers_for_token(token_address, max_limit_time, 0.0)  # 0.0 means fetch all

    pool_creations = self.fetch_pool_creations_for_token(token_address, max_limit_time)

    # Liquidity changes are keyed by pool; collect pools created for this token.
    pool_addresses = [p['pool_address'] for p in pool_creations if p.get('pool_address')]
    liquidity_changes = []
    if pool_addresses:
        liquidity_changes = self.fetch_liquidity_changes_for_pools(pool_addresses, max_limit_time)

    fee_collections = self.fetch_fee_collections_for_token(token_address, max_limit_time)
    burns = self.fetch_burns_for_token(token_address, max_limit_time)
    supply_locks = self.fetch_supply_locks_for_token(token_address, max_limit_time)
    migrations = self.fetch_migrations_for_token(token_address, max_limit_time)

    profile_data = {}
    social_data = {}
    holdings_data = {}
    deployed_token_details = {}
    fetched_graph_entities = {}
    graph_links = {}

    # Wallets that touched the token in any event stream within the window.
    unique_wallets = set()
    if include_wallet_data or include_graph:
        unique_wallets.add(creator_address)
        for t in sorted_trades:
            if t.get('maker'):
                unique_wallets.add(t['maker'])
        for t in transfers:
            if t.get('source'):
                unique_wallets.add(t['source'])
            if t.get('destination'):
                unique_wallets.add(t['destination'])
        for p in pool_creations:
            if p.get('creator_address'):
                unique_wallets.add(p['creator_address'])
        for l in liquidity_changes:
            if l.get('lp_provider'):
                unique_wallets.add(l['lp_provider'])

    if include_wallet_data and unique_wallets:
        # Profiles/holdings are time-dependent; only fetch when explicitly requested.
        profile_data, social_data = self.fetch_wallet_profiles_and_socials(list(unique_wallets), max_limit_time)
        holdings_data = self.fetch_wallet_holdings(list(unique_wallets), max_limit_time)

        all_deployed_tokens = set()
        for profile in profile_data.values():
            all_deployed_tokens.update(profile.get('deployed_tokens', []))
        if all_deployed_tokens:
            deployed_token_details = self.fetch_deployed_token_details(list(all_deployed_tokens), max_limit_time)

    if include_graph and unique_wallets:
        fetched_graph_entities, graph_links = self.fetch_graph_links(
            list(unique_wallets),
            max_limit_time,
            max_degrees=1
        )

    return {
        "token_address": token_address,
        "creator_address": creator_address,
        "mint_timestamp": mint_timestamp,
        "max_limit_time": max_limit_time,
        "trades": sorted_trades,
        "transfers": transfers,
        "pool_creations": pool_creations,
        "liquidity_changes": liquidity_changes,
        "fee_collections": fee_collections,
        "burns": burns,
        "supply_locks": supply_locks,
        "migrations": migrations,
        "profiles": profile_data,
        "socials": social_data,
        "holdings": holdings_data,
        "deployed_token_details": deployed_token_details,
        "graph_entities": fetched_graph_entities,
        "graph_links": graph_links
    }
|
data/data_collator.py
CHANGED
|
@@ -144,23 +144,32 @@ class MemecoinCollator:
|
|
| 144 |
item_wallet_addr_to_global_idx = {addr: wallet_addr_to_batch_idx.get(addr, self.entity_pad_idx) for addr in item_wallets.keys()}
|
| 145 |
item_token_addr_to_global_idx = {addr: token_addr_to_batch_idx.get(addr, self.entity_pad_idx) for addr in item_tokens.keys()}
|
| 146 |
for link_name, data in item.get('graph_links', {}).items():
|
| 147 |
-
aggregated_links[link_name]['links_list'].extend(data.get('links', []))
|
| 148 |
triplet = vocab.LINK_NAME_TO_TRIPLET.get(link_name)
|
| 149 |
if not triplet: continue
|
| 150 |
src_type, _, dst_type = triplet
|
| 151 |
edges = data.get('edges')
|
| 152 |
-
|
|
|
|
|
|
|
| 153 |
src_map = item_wallet_addr_to_global_idx if src_type == 'wallet' else item_token_addr_to_global_idx
|
| 154 |
dst_map = item_wallet_addr_to_global_idx if dst_type == 'wallet' else item_token_addr_to_global_idx
|
|
|
|
| 155 |
remapped_edge_list = []
|
| 156 |
-
|
|
|
|
|
|
|
| 157 |
src_idx_global = src_map.get(src_addr, self.entity_pad_idx)
|
| 158 |
dst_idx_global = dst_map.get(dst_addr, self.entity_pad_idx)
|
|
|
|
| 159 |
if src_idx_global != self.entity_pad_idx and dst_idx_global != self.entity_pad_idx:
|
| 160 |
remapped_edge_list.append([src_idx_global, dst_idx_global])
|
|
|
|
|
|
|
| 161 |
if remapped_edge_list:
|
| 162 |
remapped_edge_tensor = torch.tensor(remapped_edge_list, device=self.device, dtype=torch.long).t()
|
| 163 |
aggregated_links[link_name]['edge_index_list'].append(remapped_edge_tensor)
|
|
|
|
| 164 |
if link_name == "TransferLink":
|
| 165 |
link_props = data.get('links', [])
|
| 166 |
derived_edges = []
|
|
@@ -737,7 +746,7 @@ class MemecoinCollator:
|
|
| 737 |
# Labels
|
| 738 |
'labels': torch.stack([item['labels'] for item in batch]) if batch and 'labels' in batch[0] else None,
|
| 739 |
'labels_mask': torch.stack([item['labels_mask'] for item in batch]) if batch and 'labels_mask' in batch[0] else None,
|
| 740 |
-
'quality_score': torch.stack([item['quality_score'] for item in batch]) if batch and 'quality_score' in batch[0] else None,
|
| 741 |
'class_id': torch.tensor([item.get('class_id', 0) for item in batch], dtype=torch.long),
|
| 742 |
# Debug info
|
| 743 |
'token_addresses': [item.get('token_address', 'unknown') for item in batch],
|
|
|
|
| 144 |
item_wallet_addr_to_global_idx = {addr: wallet_addr_to_batch_idx.get(addr, self.entity_pad_idx) for addr in item_wallets.keys()}
|
| 145 |
item_token_addr_to_global_idx = {addr: token_addr_to_batch_idx.get(addr, self.entity_pad_idx) for addr in item_tokens.keys()}
|
| 146 |
for link_name, data in item.get('graph_links', {}).items():
|
| 147 |
+
# aggregated_links[link_name]['links_list'].extend(data.get('links', [])) - REMOVED: Now handled inside the loop for sync
|
| 148 |
triplet = vocab.LINK_NAME_TO_TRIPLET.get(link_name)
|
| 149 |
if not triplet: continue
|
| 150 |
src_type, _, dst_type = triplet
|
| 151 |
edges = data.get('edges')
|
| 152 |
+
link_props_list = data.get('links', [])
|
| 153 |
+
if not edges or not link_props_list: continue
|
| 154 |
+
|
| 155 |
src_map = item_wallet_addr_to_global_idx if src_type == 'wallet' else item_token_addr_to_global_idx
|
| 156 |
dst_map = item_wallet_addr_to_global_idx if dst_type == 'wallet' else item_token_addr_to_global_idx
|
| 157 |
+
|
| 158 |
remapped_edge_list = []
|
| 159 |
+
valid_link_props = []
|
| 160 |
+
|
| 161 |
+
for (src_addr, dst_addr), props in zip(edges, link_props_list):
|
| 162 |
src_idx_global = src_map.get(src_addr, self.entity_pad_idx)
|
| 163 |
dst_idx_global = dst_map.get(dst_addr, self.entity_pad_idx)
|
| 164 |
+
|
| 165 |
if src_idx_global != self.entity_pad_idx and dst_idx_global != self.entity_pad_idx:
|
| 166 |
remapped_edge_list.append([src_idx_global, dst_idx_global])
|
| 167 |
+
valid_link_props.append(props)
|
| 168 |
+
|
| 169 |
if remapped_edge_list:
|
| 170 |
remapped_edge_tensor = torch.tensor(remapped_edge_list, device=self.device, dtype=torch.long).t()
|
| 171 |
aggregated_links[link_name]['edge_index_list'].append(remapped_edge_tensor)
|
| 172 |
+
aggregated_links[link_name]['links_list'].extend(valid_link_props)
|
| 173 |
if link_name == "TransferLink":
|
| 174 |
link_props = data.get('links', [])
|
| 175 |
derived_edges = []
|
|
|
|
| 746 |
# Labels
|
| 747 |
'labels': torch.stack([item['labels'] for item in batch]) if batch and 'labels' in batch[0] else None,
|
| 748 |
'labels_mask': torch.stack([item['labels_mask'] for item in batch]) if batch and 'labels_mask' in batch[0] else None,
|
| 749 |
+
'quality_score': torch.stack([item['quality_score'] if isinstance(item['quality_score'], torch.Tensor) else torch.tensor(item['quality_score'], dtype=torch.float32) for item in batch]) if batch and 'quality_score' in batch[0] else None,
|
| 750 |
'class_id': torch.tensor([item.get('class_id', 0) for item in batch], dtype=torch.long),
|
| 751 |
# Debug info
|
| 752 |
'token_addresses': [item.get('token_address', 'unknown') for item in batch],
|
data/ohlc_stats.npz
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1660
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3af6751fb5666ccfd4c61d27c549e5fcd71d964090836f9d3646d6f1d63224c0
|
| 3 |
size 1660
|
log.log
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:41a901f956af52a553855651ff68f78a817ad4fa5b108efde1034e22a16724a0
|
| 3 |
+
size 4577
|
models/graph_updater.py
CHANGED
|
@@ -400,10 +400,10 @@ class GraphUpdater(nn.Module):
|
|
| 400 |
|
| 401 |
# Use vocabulary to get the triplet (src, rel, dst)
|
| 402 |
# Make sure ID_TO_LINK_TYPE is correctly populated
|
| 403 |
-
if link_name not in vocabulary.LINK_NAME_TO_TRIPLET:
|
| 404 |
print(f"Warning: Link name '{link_name}' not found in vocabulary.LINK_NAME_TO_TRIPLET. Skipping.")
|
| 405 |
continue
|
| 406 |
-
src_type, rel_type, dst_type = vocabulary.LINK_NAME_TO_TRIPLET[link_name]
|
| 407 |
|
| 408 |
# Check if encoder exists for this link name
|
| 409 |
if link_name not in self.edge_encoders:
|
|
@@ -466,10 +466,9 @@ class GraphUpdater(nn.Module):
|
|
| 466 |
print(f"Warning: Relation '{rel_type}' missing in block {block_key}. Skipping.")
|
| 467 |
continue
|
| 468 |
|
| 469 |
-
#
|
| 470 |
-
#
|
| 471 |
-
|
| 472 |
-
msg_aggregates[dst_type].scatter_add_(0, edge_index[1].unsqueeze(1).expand_as(messages), messages)
|
| 473 |
|
| 474 |
# --- Aggregation & Update (Residual Connection) ---
|
| 475 |
x_next = {}
|
|
|
|
| 400 |
|
| 401 |
# Use vocabulary to get the triplet (src, rel, dst)
|
| 402 |
# Make sure ID_TO_LINK_TYPE is correctly populated
|
| 403 |
+
if link_name not in models.vocabulary.LINK_NAME_TO_TRIPLET:
|
| 404 |
print(f"Warning: Link name '{link_name}' not found in vocabulary.LINK_NAME_TO_TRIPLET. Skipping.")
|
| 405 |
continue
|
| 406 |
+
src_type, rel_type, dst_type = models.vocabulary.LINK_NAME_TO_TRIPLET[link_name]
|
| 407 |
|
| 408 |
# Check if encoder exists for this link name
|
| 409 |
if link_name not in self.edge_encoders:
|
|
|
|
| 466 |
print(f"Warning: Relation '{rel_type}' missing in block {block_key}. Skipping.")
|
| 467 |
continue
|
| 468 |
|
| 469 |
+
# GATv2Conv output is already per-destination-node (shape [num_dst_nodes, node_dim])
|
| 470 |
+
# NOT per-edge. So we directly accumulate, no scatter needed.
|
| 471 |
+
msg_aggregates[dst_type] += messages
|
|
|
|
| 472 |
|
| 473 |
# --- Aggregation & Update (Residual Connection) ---
|
| 474 |
x_next = {}
|
sample_12LJX4a83B4tCuZ1_3.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
scripts/.ipynb_checkpoints/cache_dataset-checkpoint.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import argparse
|
| 5 |
+
import numpy as np
|
| 6 |
+
import datetime
|
| 7 |
+
import torch
|
| 8 |
+
import json
|
| 9 |
+
import math
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
from dotenv import load_dotenv
|
| 13 |
+
import huggingface_hub
|
| 14 |
+
import logging
|
| 15 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 16 |
+
import multiprocessing as mp
|
| 17 |
+
|
| 18 |
+
logging.getLogger("httpx").setLevel(logging.WARNING)
|
| 19 |
+
logging.getLogger("transformers").setLevel(logging.ERROR)
|
| 20 |
+
logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
|
| 21 |
+
|
| 22 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 23 |
+
|
| 24 |
+
from scripts.analyze_distribution import get_return_class_map
|
| 25 |
+
from scripts.compute_quality_score import get_token_quality_scores, fetch_token_metrics, _bucket_id, _midrank_percentiles, EPS
|
| 26 |
+
|
| 27 |
+
from clickhouse_driver import Client as ClickHouseClient
|
| 28 |
+
from neo4j import GraphDatabase
|
| 29 |
+
|
| 30 |
+
_worker_dataset = None
|
| 31 |
+
_worker_return_class_map = None
|
| 32 |
+
_worker_quality_scores_map = None
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
|
| 36 |
+
global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
|
| 37 |
+
from data.data_loader import OracleDataset
|
| 38 |
+
from data.data_fetcher import DataFetcher
|
| 39 |
+
|
| 40 |
+
clickhouse_client = ClickHouseClient(host=db_config['clickhouse_host'], port=db_config['clickhouse_port'])
|
| 41 |
+
neo4j_driver = GraphDatabase.driver(db_config['neo4j_uri'], auth=(db_config['neo4j_user'], db_config['neo4j_password']))
|
| 42 |
+
data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
|
| 43 |
+
|
| 44 |
+
_worker_dataset = OracleDataset(
|
| 45 |
+
data_fetcher=data_fetcher,
|
| 46 |
+
max_samples=dataset_config['max_samples'],
|
| 47 |
+
start_date=dataset_config['start_date'],
|
| 48 |
+
ohlc_stats_path=dataset_config['ohlc_stats_path'],
|
| 49 |
+
horizons_seconds=dataset_config['horizons_seconds'],
|
| 50 |
+
quantiles=dataset_config['quantiles'],
|
| 51 |
+
min_trade_usd=dataset_config['min_trade_usd'],
|
| 52 |
+
max_seq_len=dataset_config['max_seq_len']
|
| 53 |
+
)
|
| 54 |
+
_worker_dataset.sampled_mints = dataset_config['sampled_mints']
|
| 55 |
+
_worker_return_class_map = return_class_map
|
| 56 |
+
_worker_quality_scores_map = quality_scores_map
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _process_single_token_context(args):
|
| 60 |
+
idx, mint_addr, samples_per_token, output_dir = args
|
| 61 |
+
global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
|
| 62 |
+
try:
|
| 63 |
+
class_id = _worker_return_class_map.get(mint_addr)
|
| 64 |
+
if class_id is None:
|
| 65 |
+
return {'status': 'skipped', 'reason': 'not in class map', 'mint': mint_addr}
|
| 66 |
+
contexts = _worker_dataset.__cacheitem_context__(idx, num_samples_per_token=samples_per_token)
|
| 67 |
+
if not contexts:
|
| 68 |
+
return {'status': 'skipped', 'reason': 'no valid contexts', 'mint': mint_addr}
|
| 69 |
+
q_score = _worker_quality_scores_map.get(mint_addr)
|
| 70 |
+
if q_score is None:
|
| 71 |
+
return {'status': 'skipped', 'reason': 'no quality score', 'mint': mint_addr}
|
| 72 |
+
saved_files = []
|
| 73 |
+
for ctx_idx, ctx in enumerate(contexts):
|
| 74 |
+
ctx["quality_score"] = q_score
|
| 75 |
+
ctx["class_id"] = class_id
|
| 76 |
+
ctx["source_token"] = mint_addr
|
| 77 |
+
ctx["cache_mode"] = "context"
|
| 78 |
+
filename = f"sample_{mint_addr[:16]}_{ctx_idx}.pt"
|
| 79 |
+
output_path = Path(output_dir) / filename
|
| 80 |
+
torch.save(ctx, output_path)
|
| 81 |
+
saved_files.append(filename)
|
| 82 |
+
return {'status': 'success', 'mint': mint_addr, 'class_id': class_id, 'q_score': q_score, 'n_contexts': len(contexts), 'n_events': len(contexts[0].get('event_sequence', [])) if contexts else 0, 'files': saved_files}
|
| 83 |
+
except Exception as e:
|
| 84 |
+
import traceback
|
| 85 |
+
return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()}
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _process_single_token_raw(args):
|
| 89 |
+
idx, mint_addr, output_dir = args
|
| 90 |
+
global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
|
| 91 |
+
try:
|
| 92 |
+
class_id = _worker_return_class_map.get(mint_addr)
|
| 93 |
+
if class_id is None:
|
| 94 |
+
return {'status': 'skipped', 'reason': 'not in class map', 'mint': mint_addr}
|
| 95 |
+
item = _worker_dataset.__cacheitem__(idx)
|
| 96 |
+
if item is None:
|
| 97 |
+
return {'status': 'skipped', 'reason': 'cacheitem returned None', 'mint': mint_addr}
|
| 98 |
+
q_score = _worker_quality_scores_map.get(mint_addr)
|
| 99 |
+
if q_score is None:
|
| 100 |
+
return {'status': 'skipped', 'reason': 'no quality score', 'mint': mint_addr}
|
| 101 |
+
item["quality_score"] = q_score
|
| 102 |
+
item["class_id"] = class_id
|
| 103 |
+
item["cache_mode"] = "raw"
|
| 104 |
+
filename = f"sample_{mint_addr[:16]}.pt"
|
| 105 |
+
output_path = Path(output_dir) / filename
|
| 106 |
+
torch.save(item, output_path)
|
| 107 |
+
return {'status': 'success', 'mint': mint_addr, 'class_id': class_id, 'q_score': q_score, 'n_trades': len(item.get('trades', [])), 'files': [filename]}
|
| 108 |
+
except Exception as e:
|
| 109 |
+
import traceback
|
| 110 |
+
return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()}
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def compute_save_ohlc_stats(client, output_path):
|
| 114 |
+
print(f"INFO: Computing OHLC stats...")
|
| 115 |
+
query = """SELECT AVG(t.price_usd), stddevPop(t.price_usd), AVG(t.price), stddevPop(t.price), AVG(t.total_usd), stddevPop(t.total_usd) FROM trades AS t WHERE t.price_usd > 0 AND t.total_usd > 0"""
|
| 116 |
+
try:
|
| 117 |
+
result = client.execute(query)
|
| 118 |
+
if result and result[0]:
|
| 119 |
+
row = result[0]
|
| 120 |
+
stats = {"mean_price_usd": float(row[0] or 0), "std_price_usd": float(row[1] or 1), "mean_price_native": float(row[2] or 0), "std_price_native": float(row[3] or 1), "mean_trade_value_usd": float(row[4] or 0), "std_trade_value_usd": float(row[5] or 1)}
|
| 121 |
+
else:
|
| 122 |
+
stats = {"mean_price_usd": 0.0, "std_price_usd": 1.0, "mean_price_native": 0.0, "std_price_native": 1.0, "mean_trade_value_usd": 0.0, "std_trade_value_usd": 1.0}
|
| 123 |
+
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
| 124 |
+
np.savez(output_path, **stats)
|
| 125 |
+
print(f"INFO: Saved OHLC stats to {output_path}")
|
| 126 |
+
except Exception as e:
|
| 127 |
+
print(f"ERROR: Failed to compute OHLC stats: {e}")
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def main():
|
| 131 |
+
load_dotenv()
|
| 132 |
+
mp.set_start_method('spawn', force=True)
|
| 133 |
+
|
| 134 |
+
hf_token = os.getenv("HF_TOKEN")
|
| 135 |
+
if hf_token:
|
| 136 |
+
print(f"INFO: Logging in to Hugging Face...")
|
| 137 |
+
huggingface_hub.login(token=hf_token)
|
| 138 |
+
|
| 139 |
+
parser = argparse.ArgumentParser()
|
| 140 |
+
parser.add_argument("--output_dir", type=str, default="data/cache")
|
| 141 |
+
parser.add_argument("--max_samples", type=int, default=None)
|
| 142 |
+
parser.add_argument("--start_date", type=str, default=None)
|
| 143 |
+
parser.add_argument("--ohlc_stats_path", type=str, default="data/ohlc_stats.npz")
|
| 144 |
+
parser.add_argument("--min_trade_usd", type=float, default=0.0)
|
| 145 |
+
parser.add_argument("--cache_mode", type=str, default="raw", choices=["raw", "context"])
|
| 146 |
+
parser.add_argument("--context_length", type=int, default=8192)
|
| 147 |
+
parser.add_argument("--min_trades", type=int, default=10)
|
| 148 |
+
parser.add_argument("--samples_per_token", type=int, default=1)
|
| 149 |
+
parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
|
| 150 |
+
parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
|
| 151 |
+
parser.add_argument("--num_workers", type=int, default=1)
|
| 152 |
+
parser.add_argument("--clickhouse_host", type=str, default=os.getenv("CLICKHOUSE_HOST", "localhost"))
|
| 153 |
+
parser.add_argument("--clickhouse_port", type=int, default=int(os.getenv("CLICKHOUSE_PORT", 9000)))
|
| 154 |
+
parser.add_argument("--neo4j_uri", type=str, default=os.getenv("NEO4J_URI", "bolt://localhost:7687"))
|
| 155 |
+
parser.add_argument("--neo4j_user", type=str, default=os.getenv("NEO4J_USER", "neo4j"))
|
| 156 |
+
parser.add_argument("--neo4j_password", type=str, default=os.getenv("NEO4J_PASSWORD", "password"))
|
| 157 |
+
args = parser.parse_args()
|
| 158 |
+
|
| 159 |
+
if args.num_workers == 0:
|
| 160 |
+
args.num_workers = max(1, mp.cpu_count() - 4)
|
| 161 |
+
|
| 162 |
+
output_dir = Path(args.output_dir)
|
| 163 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 164 |
+
start_date_dt = datetime.datetime.strptime(args.start_date, "%Y-%m-%d") if args.start_date else None
|
| 165 |
+
|
| 166 |
+
print(f"INFO: Initializing DB Connections...")
|
| 167 |
+
clickhouse_client = ClickHouseClient(host=args.clickhouse_host, port=args.clickhouse_port)
|
| 168 |
+
neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=(args.neo4j_user, args.neo4j_password))
|
| 169 |
+
|
| 170 |
+
try:
|
| 171 |
+
compute_save_ohlc_stats(clickhouse_client, args.ohlc_stats_path)
|
| 172 |
+
|
| 173 |
+
from data.data_loader import OracleDataset
|
| 174 |
+
from data.data_fetcher import DataFetcher
|
| 175 |
+
data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
|
| 176 |
+
|
| 177 |
+
print("INFO: Fetching Return Classification Map...")
|
| 178 |
+
return_class_map, _ = get_return_class_map(clickhouse_client)
|
| 179 |
+
print(f"INFO: Loaded {len(return_class_map)} classified tokens.")
|
| 180 |
+
|
| 181 |
+
print("INFO: Fetching Quality Scores...")
|
| 182 |
+
quality_scores_map = get_token_quality_scores(clickhouse_client)
|
| 183 |
+
print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
|
| 184 |
+
|
| 185 |
+
dataset = OracleDataset(data_fetcher=data_fetcher, max_samples=args.max_samples, start_date=start_date_dt, ohlc_stats_path=args.ohlc_stats_path, horizons_seconds=args.horizons_seconds, quantiles=args.quantiles, min_trade_usd=args.min_trade_usd, max_seq_len=args.context_length)
|
| 186 |
+
|
| 187 |
+
if len(dataset) == 0:
|
| 188 |
+
print("WARNING: No samples. Exiting.")
|
| 189 |
+
return
|
| 190 |
+
|
| 191 |
+
# Filter mints by return_class_map
|
| 192 |
+
original_size = len(dataset.sampled_mints)
|
| 193 |
+
filtered_mints = [m for m in dataset.sampled_mints if m['mint_address'] in return_class_map]
|
| 194 |
+
print(f"INFO: Filtered by class map: {original_size} -> {len(filtered_mints)} tokens")
|
| 195 |
+
|
| 196 |
+
# Pre-filter: only keep tokens with >= min_trades trades (fast ClickHouse count query)
|
| 197 |
+
print(f"INFO: Pre-filtering tokens by trade count (>= {args.min_trades} trades)...")
|
| 198 |
+
trade_counts = clickhouse_client.execute("""
|
| 199 |
+
SELECT base_address, count() as cnt
|
| 200 |
+
FROM trades
|
| 201 |
+
GROUP BY base_address
|
| 202 |
+
HAVING cnt >= %(min_trades)s
|
| 203 |
+
""", {'min_trades': args.min_trades})
|
| 204 |
+
valid_tokens = {row[0] for row in trade_counts}
|
| 205 |
+
pre_filter_size = len(filtered_mints)
|
| 206 |
+
filtered_mints = [m for m in filtered_mints if m['mint_address'] in valid_tokens]
|
| 207 |
+
print(f"INFO: Pre-filtered by trade count: {pre_filter_size} -> {len(filtered_mints)} tokens (removed {pre_filter_size - len(filtered_mints)} with < {args.min_trades} trades)")
|
| 208 |
+
|
| 209 |
+
# Also filter by quality score availability
|
| 210 |
+
pre_quality_size = len(filtered_mints)
|
| 211 |
+
filtered_mints = [m for m in filtered_mints if m['mint_address'] in quality_scores_map]
|
| 212 |
+
print(f"INFO: Filtered by quality score: {pre_quality_size} -> {len(filtered_mints)} tokens")
|
| 213 |
+
|
| 214 |
+
if len(filtered_mints) == 0:
|
| 215 |
+
print("WARNING: No tokens after filtering.")
|
| 216 |
+
return
|
| 217 |
+
|
| 218 |
+
print(f"INFO: Cache mode: {args.cache_mode}, Workers: {args.num_workers}")
|
| 219 |
+
|
| 220 |
+
db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
|
| 221 |
+
dataset_config = {'max_samples': args.max_samples, 'start_date': start_date_dt, 'ohlc_stats_path': args.ohlc_stats_path, 'horizons_seconds': args.horizons_seconds, 'quantiles': args.quantiles, 'min_trade_usd': args.min_trade_usd, 'max_seq_len': args.context_length, 'sampled_mints': filtered_mints}
|
| 222 |
+
|
| 223 |
+
# Build tasks with class-aware multi-sampling for balanced cache
|
| 224 |
+
import random
|
| 225 |
+
from collections import Counter, defaultdict
|
| 226 |
+
|
| 227 |
+
# Count eligible tokens per class
|
| 228 |
+
eligible_class_counts = Counter()
|
| 229 |
+
mints_by_class = defaultdict(list)
|
| 230 |
+
for i, m in enumerate(filtered_mints):
|
| 231 |
+
cid = return_class_map.get(m['mint_address'])
|
| 232 |
+
if cid is not None:
|
| 233 |
+
eligible_class_counts[cid] += 1
|
| 234 |
+
mints_by_class[cid].append((i, m))
|
| 235 |
+
|
| 236 |
+
print(f"INFO: Eligible tokens per class: {dict(sorted(eligible_class_counts.items()))}")
|
| 237 |
+
|
| 238 |
+
# Compute balanced samples_per_token for each class
|
| 239 |
+
num_classes = len(eligible_class_counts)
|
| 240 |
+
if args.max_samples:
|
| 241 |
+
target_total = args.max_samples
|
| 242 |
+
else:
|
| 243 |
+
target_total = 15000 # Default target: 15k balanced files
|
| 244 |
+
target_per_class = target_total // max(num_classes, 1)
|
| 245 |
+
|
| 246 |
+
class_multipliers = {}
|
| 247 |
+
class_token_caps = {}
|
| 248 |
+
for cid, count in eligible_class_counts.items():
|
| 249 |
+
if count >= target_per_class:
|
| 250 |
+
# Enough tokens — 1 sample each, cap token count
|
| 251 |
+
class_multipliers[cid] = 1
|
| 252 |
+
class_token_caps[cid] = target_per_class
|
| 253 |
+
else:
|
| 254 |
+
# Not enough tokens — multi-sample, use all tokens
|
| 255 |
+
class_multipliers[cid] = min(10, max(1, math.ceil(target_per_class / max(count, 1))))
|
| 256 |
+
class_token_caps[cid] = count
|
| 257 |
+
|
| 258 |
+
print(f"INFO: Target total: {target_total}, Target per class: {target_per_class}")
|
| 259 |
+
print(f"INFO: Class multipliers: {dict(sorted(class_multipliers.items()))}")
|
| 260 |
+
print(f"INFO: Class token caps: {dict(sorted(class_token_caps.items()))}")
|
| 261 |
+
|
| 262 |
+
# Build balanced task list
|
| 263 |
+
tasks = []
|
| 264 |
+
for cid, mint_list in mints_by_class.items():
|
| 265 |
+
random.shuffle(mint_list)
|
| 266 |
+
cap = class_token_caps.get(cid, len(mint_list))
|
| 267 |
+
spt = class_multipliers.get(cid, 1)
|
| 268 |
+
# Override with CLI --samples_per_token if explicitly set > 1
|
| 269 |
+
if args.samples_per_token > 1:
|
| 270 |
+
spt = args.samples_per_token
|
| 271 |
+
for i, m in mint_list[:cap]:
|
| 272 |
+
mint_addr = m['mint_address']
|
| 273 |
+
if args.cache_mode == "context":
|
| 274 |
+
tasks.append((i, mint_addr, spt, str(output_dir)))
|
| 275 |
+
else:
|
| 276 |
+
tasks.append((i, mint_addr, str(output_dir)))
|
| 277 |
+
|
| 278 |
+
random.shuffle(tasks) # Shuffle tasks for even load distribution across workers
|
| 279 |
+
expected_files = sum(
|
| 280 |
+
class_multipliers.get(cid, 1) * min(class_token_caps.get(cid, len(ml)), len(ml))
|
| 281 |
+
for cid, ml in mints_by_class.items()
|
| 282 |
+
)
|
| 283 |
+
print(f"INFO: Total tasks: {len(tasks)} (expected ~{expected_files} output files, target ~{target_total})")
|
| 284 |
+
|
| 285 |
+
success_count, skipped_count, error_count = 0, 0, 0
|
| 286 |
+
class_distribution = {}
|
| 287 |
+
|
| 288 |
+
# --- Resume support: skip tokens that already have cached files ---
|
| 289 |
+
existing_files = set(f.name for f in output_dir.glob("sample_*.pt"))
|
| 290 |
+
if existing_files:
|
| 291 |
+
pre_resume = len(tasks)
|
| 292 |
+
filtered_tasks = []
|
| 293 |
+
already_cached = 0
|
| 294 |
+
for task in tasks:
|
| 295 |
+
mint_addr = task[1] # task = (idx, mint_addr, ...)
|
| 296 |
+
# Check if any file exists for this mint (context mode: sample_MINT_0.pt, raw mode: sample_MINT.pt)
|
| 297 |
+
mint_prefix = f"sample_{mint_addr[:16]}"
|
| 298 |
+
has_cached = any(ef.startswith(mint_prefix) for ef in existing_files)
|
| 299 |
+
if has_cached:
|
| 300 |
+
already_cached += 1
|
| 301 |
+
# Count existing files toward class distribution
|
| 302 |
+
cid = return_class_map.get(mint_addr)
|
| 303 |
+
if cid is not None:
|
| 304 |
+
class_distribution[cid] = class_distribution.get(cid, 0) + 1
|
| 305 |
+
success_count += 1
|
| 306 |
+
else:
|
| 307 |
+
filtered_tasks.append(task)
|
| 308 |
+
tasks = filtered_tasks
|
| 309 |
+
print(f"INFO: Resume: {already_cached} tokens already cached, {len(tasks)} remaining (was {pre_resume})")
|
| 310 |
+
|
| 311 |
+
print(f"INFO: Starting to cache {len(tasks)} tokens...")
|
| 312 |
+
process_fn = _process_single_token_context if args.cache_mode == "context" else _process_single_token_raw
|
| 313 |
+
|
| 314 |
+
import time as _time
|
| 315 |
+
|
| 316 |
+
def _log_progress(task_num, total, start_time, recent_times, success_count, skipped_count, error_count):
|
| 317 |
+
"""Print progress with rolling ETA every 10 tokens."""
|
| 318 |
+
if (task_num + 1) % 10 == 0 and recent_times:
|
| 319 |
+
avg_time = sum(recent_times) / len(recent_times)
|
| 320 |
+
remaining = total - (task_num + 1)
|
| 321 |
+
eta_seconds = avg_time * remaining
|
| 322 |
+
eta_hours = eta_seconds / 3600
|
| 323 |
+
wall_elapsed = _time.perf_counter() - start_time
|
| 324 |
+
speed = (task_num + 1) / wall_elapsed
|
| 325 |
+
tqdm.write(
|
| 326 |
+
f" [PROGRESS] {task_num+1}/{total} | "
|
| 327 |
+
f"Speed: {speed:.1f} tok/s ({speed*60:.0f} tok/min) | "
|
| 328 |
+
f"Avg: {avg_time:.1f}s/tok | "
|
| 329 |
+
f"ETA: {eta_hours:.1f}h | "
|
| 330 |
+
f"OK: {success_count} Skip: {skipped_count} Err: {error_count}"
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
# Error log file for diagnosing failures
|
| 334 |
+
error_log_path = Path(args.output_dir) / "cache_errors.log"
|
| 335 |
+
error_samples = [] # First 20 unique error messages
|
| 336 |
+
|
| 337 |
+
if args.num_workers == 1:
|
| 338 |
+
print("INFO: Single-threaded mode...")
|
| 339 |
+
_init_worker(db_config, dataset_config, return_class_map, quality_scores_map)
|
| 340 |
+
start_time = _time.perf_counter()
|
| 341 |
+
recent_times = []
|
| 342 |
+
for task_num, task in enumerate(tqdm(tasks, desc="Caching", unit="tok")):
|
| 343 |
+
t0 = _time.perf_counter()
|
| 344 |
+
result = process_fn(task)
|
| 345 |
+
elapsed = _time.perf_counter() - t0
|
| 346 |
+
recent_times.append(elapsed)
|
| 347 |
+
if len(recent_times) > 50:
|
| 348 |
+
recent_times.pop(0)
|
| 349 |
+
if result['status'] == 'success':
|
| 350 |
+
success_count += 1
|
| 351 |
+
class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
|
| 352 |
+
elif result['status'] == 'skipped':
|
| 353 |
+
skipped_count += 1
|
| 354 |
+
else:
|
| 355 |
+
error_count += 1
|
| 356 |
+
err_msg = result.get('error', 'unknown')
|
| 357 |
+
tqdm.write(f"ERROR: {result['mint'][:16]} - {err_msg}")
|
| 358 |
+
if len(error_samples) < 20:
|
| 359 |
+
error_samples.append({'mint': result.get('mint'), 'error': err_msg, 'traceback': result.get('traceback', '')})
|
| 360 |
+
_log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
|
| 361 |
+
else:
|
| 362 |
+
print(f"INFO: Running with {args.num_workers} workers...")
|
| 363 |
+
start_time = _time.perf_counter()
|
| 364 |
+
recent_times = []
|
| 365 |
+
with ProcessPoolExecutor(max_workers=args.num_workers, initializer=_init_worker, initargs=(db_config, dataset_config, return_class_map, quality_scores_map)) as executor:
|
| 366 |
+
futures = {executor.submit(process_fn, task): task for task in tasks}
|
| 367 |
+
for task_num, future in enumerate(tqdm(as_completed(futures), total=len(futures), desc="Caching", unit="tok")):
|
| 368 |
+
t0 = _time.perf_counter()
|
| 369 |
+
try:
|
| 370 |
+
result = future.result(timeout=300)
|
| 371 |
+
elapsed = _time.perf_counter() - t0
|
| 372 |
+
recent_times.append(elapsed)
|
| 373 |
+
if len(recent_times) > 50:
|
| 374 |
+
recent_times.pop(0)
|
| 375 |
+
if result['status'] == 'success':
|
| 376 |
+
success_count += 1
|
| 377 |
+
class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
|
| 378 |
+
elif result['status'] == 'skipped':
|
| 379 |
+
skipped_count += 1
|
| 380 |
+
else:
|
| 381 |
+
error_count += 1
|
| 382 |
+
err_msg = result.get('error', 'unknown')
|
| 383 |
+
if len(error_samples) < 20:
|
| 384 |
+
error_samples.append({'mint': result.get('mint'), 'error': err_msg, 'traceback': result.get('traceback', '')})
|
| 385 |
+
if error_count <= 5:
|
| 386 |
+
tqdm.write(f"ERROR: {result.get('mint', '?')[:16]} - {err_msg}")
|
| 387 |
+
except Exception as e:
|
| 388 |
+
error_count += 1
|
| 389 |
+
tqdm.write(f"WORKER ERROR: {e}")
|
| 390 |
+
_log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
|
| 391 |
+
|
| 392 |
+
# Write error log
|
| 393 |
+
if error_samples:
|
| 394 |
+
with open(error_log_path, 'w') as ef:
|
| 395 |
+
for i, es in enumerate(error_samples):
|
| 396 |
+
ef.write(f"=== Error {i+1} === Token: {es['mint']}\n")
|
| 397 |
+
ef.write(f"Error: {es['error']}\n")
|
| 398 |
+
ef.write(f"Traceback:\n{es['traceback']}\n\n")
|
| 399 |
+
print(f"INFO: First {len(error_samples)} error tracebacks saved to {error_log_path}")
|
| 400 |
+
|
| 401 |
+
print("INFO: Building metadata...")
|
| 402 |
+
file_class_map = {}
|
| 403 |
+
for f in sorted(output_dir.glob("sample_*.pt")):
|
| 404 |
+
try:
|
| 405 |
+
file_class_map[f.name] = torch.load(f, map_location="cpu", weights_only=False).get("class_id", 0)
|
| 406 |
+
except:
|
| 407 |
+
pass
|
| 408 |
+
|
| 409 |
+
with open(output_dir / "class_metadata.json", 'w') as f:
|
| 410 |
+
json.dump({
|
| 411 |
+
'file_class_map': file_class_map,
|
| 412 |
+
'class_distribution': {str(k): v for k, v in class_distribution.items()},
|
| 413 |
+
'cache_mode': args.cache_mode,
|
| 414 |
+
'num_workers': args.num_workers,
|
| 415 |
+
'horizons_seconds': args.horizons_seconds,
|
| 416 |
+
'quantiles': args.quantiles,
|
| 417 |
+
'class_multipliers': {str(k): v for k, v in class_multipliers.items()},
|
| 418 |
+
'class_token_caps': {str(k): v for k, v in class_token_caps.items()},
|
| 419 |
+
'target_total': target_total,
|
| 420 |
+
'target_per_class': target_per_class,
|
| 421 |
+
}, f, indent=2)
|
| 422 |
+
|
| 423 |
+
print(f"\n--- Done ---\nSuccess: {success_count}, Skipped: {skipped_count}, Errors: {error_count}\nFiles: {len(file_class_map)}\nLocation: {output_dir.resolve()}")
|
| 424 |
+
|
| 425 |
+
finally:
|
| 426 |
+
clickhouse_client.disconnect()
|
| 427 |
+
neo4j_driver.close()
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
if __name__ == "__main__":
|
| 431 |
+
main()
|
scripts/analyze_distribution.py
CHANGED
|
@@ -313,8 +313,108 @@ def print_stats(name, values):
|
|
| 313 |
|
| 314 |
print(f" {name}: mean={mean:.4f} p50={p50:.4f} p90={p90:.4f} p99={p99:.4f} nonzero_rate={nonzero_rate:.3f} (n={len(vals)})")
|
| 315 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
def analyze():
|
| 317 |
client = get_client()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
data = fetch_all_metrics(client)
|
| 319 |
final_buckets, thresholds, count_manipulated = _classify_tokens(data)
|
| 320 |
|
|
|
|
| 313 |
|
| 314 |
print(f" {name}: mean={mean:.4f} p50={p50:.4f} p90={p90:.4f} p99={p99:.4f} nonzero_rate={nonzero_rate:.3f} (n={len(vals)})")
|
| 315 |
|
| 316 |
+
def fetch_wallet_pnl_stats(client):
    """Fetch aggregate quantile stats of per-wallet realized PnL (7d and 30d).

    wallet_profile_metrics is a time-series dump, so argMax(..., updated_at)
    collapses it to the most recent row per wallet before aggregating.
    Returns the single aggregate row tuple, or None when the query is empty.
    """
    print(" -> Fetching Wallet PnL Quantiles (7d, 30d) - Unique per wallet...")
    query = """
    WITH unique_wallets AS (
        SELECT
            wallet_address,
            argMax(stats_30d_realized_profit_pnl, updated_at) as pnl_30d,
            argMax(stats_7d_realized_profit_pnl, updated_at) as pnl_7d
        FROM wallet_profile_metrics
        GROUP BY wallet_address
    )
    SELECT
        count() as n,
        countIf(pnl_30d > 0.001) as pos_30d,
        quantiles(0.5, 0.9, 0.95, 0.99, 0.999)(pnl_30d) as q_30d,
        max(pnl_30d) as max_30d,

        countIf(pnl_7d > 0.001) as pos_7d,
        quantiles(0.5, 0.9, 0.95, 0.99, 0.999)(pnl_7d) as q_7d,
        max(pnl_7d) as max_7d
    FROM unique_wallets
    WHERE pnl_30d > -999 OR pnl_7d > -999
    """
    result_rows = client.execute(query)
    return result_rows[0] if result_rows else None
|
| 343 |
+
|
| 344 |
+
def fetch_trade_stats(client):
    """Fetch trade-size quantiles, both in USD and as a % of token supply.

    Tokens with unknown/zero supply are excluded so the supply-% division
    is well-defined. Returns the single aggregate row, or None if empty.
    """
    print(" -> Fetching Trade Quantiles (USD & Supply %)...")
    query = """
    SELECT
        count() as n,
        quantiles(0.5, 0.9, 0.95, 0.99, 0.999)(t.total_usd) as q_usd,
        quantiles(0.5, 0.9, 0.95, 0.99, 0.999)((t.base_amount / m.total_supply) * 100) as q_sup
    FROM trades t
    JOIN mints m ON t.base_address = m.mint_address
    WHERE m.total_supply > 0
    """
    result_rows = client.execute(query)
    return result_rows[0] if result_rows else None
|
| 358 |
+
|
| 359 |
+
def fetch_kol_stats(client):
    """Count wallets that have social metadata, and how many qualify as KOLs.

    A wallet counts as a KOL when any of the known KOL-name columns is
    non-empty. Returns a (total_wallets, kols) tuple; (0, 0) on empty result.
    """
    print(" -> Fetching KOL stats from wallet_socials...")
    query = """
    SELECT
        uniq(wallet_address) as total_wallets,
        uniqIf(wallet_address, kolscan_name != '' OR cabalspy_name != '' OR axiom_kol_name != '') as kols
    FROM wallet_socials
    """
    result_rows = client.execute(query)
    print(f" (DEBUG) KOL query result: {result_rows}")
    return result_rows[0] if result_rows else (0, 0)
|
| 372 |
+
|
| 373 |
+
def print_quantiles(name, n, pos_rate, q, max_val=None):
    """Pretty-print one quantile row.

    q holds [p50, p90, p95, p99, p99.9]. pos_rate (fraction in [0, 1]) and
    max_val are optional and only printed when provided.
    """
    print(f"\n[{name}] (n={n})")
    if pos_rate is not None:
        print(f" Positive Rate: {pos_rate*100:.1f}%")
    labels = ("p50", "p90", "p95", "p99", "p99.9")
    for idx in range(5):
        print(f" {labels[idx]}={q[idx]:.4f}")
    if max_val is not None:
        print(f" Max={max_val:.4f}")
|
| 385 |
+
|
| 386 |
+
def analyze_thresholds(client):
    """Run the DB-side distribution analysis: wallet PnL, trade sizes, KOLs."""
    print("\n=== THRESHOLD DISTRIBUTION ANALYSIS (DB-Side) ===")

    # 1. Wallet PnL quantiles (30d then 7d), with positive-rate from counts.
    pnl_row = fetch_wallet_pnl_stats(client)
    if pnl_row:
        n, pos_30d, q_30d, max_30d, pos_7d, q_7d, max_7d = pnl_row
        for label, pos, q, mx in (
            ("Wallet PnL (30d)", pos_30d, q_30d, max_30d),
            ("Wallet PnL (7d)", pos_7d, q_7d, max_7d),
        ):
            print_quantiles(label, n, pos / n if n > 0 else 0, q, mx)

    # 2. Trade-size quantiles (USD and % of supply share a single query).
    trade_row = fetch_trade_stats(client)
    if trade_row:
        n, q_usd, q_sup = trade_row
        print_quantiles("Trade USD Size", n, None, q_usd)
        print_quantiles("Trade Supply %", n, None, q_sup)

    # 3. KOL ratio among wallets with social metadata.
    total, kols = fetch_kol_stats(client)
    if total > 0:
        print("\n[KOL Statistics]")
        print(f" Total Wallets with Socials: {total}")
        print(f" Identified KOLs: {kols}")
        print(f" KOL Ratio: {(kols/total)*100:.2f}%")
|
| 410 |
+
|
| 411 |
+
|
| 412 |
def analyze():
|
| 413 |
client = get_client()
|
| 414 |
+
|
| 415 |
+
# Run new analysis first
|
| 416 |
+
analyze_thresholds(client)
|
| 417 |
+
|
| 418 |
data = fetch_all_metrics(client)
|
| 419 |
final_buckets, thresholds, count_manipulated = _classify_tokens(data)
|
| 420 |
|
scripts/dump_cache_sample.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Dump a cached .pt sample to JSON for manual debugging.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python scripts/dump_cache_sample.py # Dump first sample
|
| 7 |
+
python scripts/dump_cache_sample.py --index 5 # Dump sample at index 5
|
| 8 |
+
python scripts/dump_cache_sample.py --file data/cache/sample_ABC123.pt # Dump specific file
|
| 9 |
+
python scripts/dump_cache_sample.py --output debug.json # Custom output path
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import json
|
| 14 |
+
import sys
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
# Add project root to path so torch.load can find project modules when unpickling
|
| 18 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import numpy as np
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from datetime import datetime
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def convert_to_serializable(obj):
    """Recursively convert non-JSON-serializable objects.

    Tensors and arrays become tagged dicts ({"__type__", "shape", "dtype",
    "data"}) so the JSON dump stays self-describing; datetimes, bytes and
    sets get similar tagged wrappers. Unknown objects fall back to a
    truncated repr rather than failing the whole dump.

    Fixes vs previous version:
    - set elements are now converted recursively (a set of datetimes would
      previously survive as raw objects and break json.dump downstream);
    - the fallback catches Exception instead of a bare except, so
      KeyboardInterrupt/SystemExit still propagate.
    """
    if obj is None:
        return None
    # bool is a subclass of int, so this branch also passes bools through.
    if isinstance(obj, (str, int, float, bool)):
        return obj
    if isinstance(obj, (np.integer,)):
        return int(obj)
    if isinstance(obj, (np.floating,)):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return {"__type__": "ndarray", "shape": list(obj.shape), "dtype": str(obj.dtype), "data": obj.tolist()}
    if isinstance(obj, torch.Tensor):
        return {"__type__": "tensor", "shape": list(obj.shape), "dtype": str(obj.dtype), "data": obj.tolist()}
    if isinstance(obj, datetime):
        return {"__type__": "datetime", "value": obj.isoformat()}
    if isinstance(obj, bytes):
        # Only a hex preview of the first 100 bytes is kept, not the payload.
        return {"__type__": "bytes", "length": len(obj), "preview": obj[:100].hex() if len(obj) > 0 else ""}
    if isinstance(obj, dict):
        # JSON keys must be strings; non-string keys are stringified.
        return {str(k): convert_to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_to_serializable(item) for item in obj]
    if isinstance(obj, set):
        return {"__type__": "set", "data": [convert_to_serializable(item) for item in obj]}
    # Fallback: best-effort repr, truncated to keep the dump bounded.
    try:
        return {"__type__": type(obj).__name__, "repr": str(obj)[:500]}
    except Exception:
        return {"__type__": "unknown", "repr": "<not serializable>"}
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def main():
    """CLI entry point: load one cached .pt sample and dump it to JSON.

    The sample is chosen either by explicit --file path or by --index into
    the sorted sample_*.pt files under --cache_dir. Returns a process exit
    code: 0 on success, 1 on any user-facing error (missing file/dir,
    out-of-range index, load failure).
    """
    parser = argparse.ArgumentParser(description="Dump cached .pt sample to JSON")
    parser.add_argument("--index", "-i", type=int, default=0, help="Index of sample to dump (default: 0)")
    parser.add_argument("--file", "-f", type=str, default=None, help="Direct path to .pt file (overrides --index)")
    parser.add_argument("--cache_dir", "-c", type=str, default="data/cache", help="Cache directory (default: data/cache)")
    parser.add_argument("--output", "-o", type=str, default=None, help="Output JSON path (default: auto-generated)")
    parser.add_argument("--compact", action="store_true", help="Compact JSON output (no indentation)")
    args = parser.parse_args()

    # Determine which file to load: --file wins over --index selection.
    if args.file:
        filepath = Path(args.file)
        if not filepath.exists():
            print(f"ERROR: File not found: {filepath}")
            return 1
    else:
        cache_dir = Path(args.cache_dir)
        if not cache_dir.is_dir():
            print(f"ERROR: Cache directory not found: {cache_dir}")
            return 1

        # Sorted glob makes --index deterministic across runs.
        cached_files = sorted(cache_dir.glob("sample_*.pt"))
        if not cached_files:
            print(f"ERROR: No sample_*.pt files found in {cache_dir}")
            return 1

        if args.index >= len(cached_files):
            print(f"ERROR: Index {args.index} out of range. Found {len(cached_files)} files.")
            return 1

        filepath = cached_files[args.index]

    print(f"Loading: {filepath}")

    # Load the .pt file. weights_only=False is required because samples hold
    # arbitrary pickled Python objects, not just tensors — only safe on
    # trusted, locally-produced cache files.
    try:
        data = torch.load(filepath, map_location="cpu", weights_only=False)
    except Exception as e:
        print(f"ERROR: Failed to load file: {e}")
        return 1

    # Convert to JSON-serializable format (tensors/arrays -> tagged dicts).
    print("Converting to JSON-serializable format...")
    serializable_data = convert_to_serializable(data)

    # Wrap with provenance metadata so the dump is traceable to its source.
    output_data = {
        "__metadata__": {
            "source_file": str(filepath.absolute()),
            "dumped_at": datetime.now().isoformat(),
            "cache_mode": data.get("cache_mode", "unknown") if isinstance(data, dict) else "unknown"
        },
        "data": serializable_data
    }

    # Determine output path.
    if args.output:
        output_path = Path(args.output)
    else:
        # Default: Save to current directory (root) instead of inside cache dir
        output_path = Path.cwd() / filepath.with_suffix(".json").name

    # Write JSON; --compact drops indentation for smaller files.
    print(f"Writing to: {output_path}")
    indent = None if args.compact else 2
    with open(output_path, "w") as f:
        json.dump(output_data, f, indent=indent, ensure_ascii=False)

    # Print a human-oriented summary of well-known sample keys, when present.
    if isinstance(data, dict):
        print(f"\n=== Summary ===")
        print(f"Top-level keys: {list(data.keys())}")
        print(f"Cache mode: {data.get('cache_mode', 'not specified')}")
        if 'event_sequence' in data:
            print(f"Event count: {len(data['event_sequence'])}")
        if 'trades' in data:
            print(f"Trade count: {len(data['trades'])}")
        if 'source_token' in data:
            print(f"Source token: {data['source_token']}")
        if 'class_id' in data:
            print(f"Class ID: {data['class_id']}")
        if 'quality_score' in data:
            print(f"Quality score: {data['quality_score']}")

    print(f"\nDone! JSON saved to: {output_path}")
    return 0
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    exit(main())
|
train.py
CHANGED
|
@@ -406,7 +406,7 @@ def main() -> None:
|
|
| 406 |
hf_token = os.getenv("HF_TOKEN")
|
| 407 |
if hf_token:
|
| 408 |
print(f"Logging in to Hugging Face with token starting with: {hf_token[:4]}...")
|
| 409 |
-
huggingface_hub.login(token=hf_token)
|
| 410 |
else:
|
| 411 |
print("WARNING: HF_TOKEN not found in environment.")
|
| 412 |
|
|
@@ -437,7 +437,7 @@ def main() -> None:
|
|
| 437 |
collator_encoder = CollatorEncoder(
|
| 438 |
model_id=collator.model_id,
|
| 439 |
dtype=init_dtype,
|
| 440 |
-
device="
|
| 441 |
)
|
| 442 |
_set_worker_encoder(collator_encoder)
|
| 443 |
logger.info("SigLIP encoder pre-loaded successfully.")
|
|
|
|
| 406 |
hf_token = os.getenv("HF_TOKEN")
|
| 407 |
if hf_token:
|
| 408 |
print(f"Logging in to Hugging Face with token starting with: {hf_token[:4]}...")
|
| 409 |
+
pass # huggingface_hub.login(token=hf_token)
|
| 410 |
else:
|
| 411 |
print("WARNING: HF_TOKEN not found in environment.")
|
| 412 |
|
|
|
|
| 437 |
collator_encoder = CollatorEncoder(
|
| 438 |
model_id=collator.model_id,
|
| 439 |
dtype=init_dtype,
|
| 440 |
+
device="cuda" # Use GPU for encoding (requires num_workers=0)
|
| 441 |
)
|
| 442 |
_set_worker_encoder(collator_encoder)
|
| 443 |
logger.info("SigLIP encoder pre-loaded successfully.")
|
train.sh
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
accelerate launch train.py \
|
| 2 |
-
--epochs
|
| 3 |
--batch_size 8 \
|
| 4 |
--learning_rate 1e-4 \
|
| 5 |
--warmup_ratio 0.1 \
|
| 6 |
--grad_accum_steps 2 \
|
| 7 |
--max_grad_norm 1.0 \
|
| 8 |
--seed 42 \
|
| 9 |
-
--log_every
|
| 10 |
--save_every 2000 \
|
| 11 |
--tensorboard_dir runs/oracle \
|
| 12 |
--checkpoint_dir checkpoints \
|
|
@@ -15,8 +15,8 @@ accelerate launch train.py \
|
|
| 15 |
--horizons_seconds 30 60 120 240 420 \
|
| 16 |
--quantiles 0.1 0.5 0.9 \
|
| 17 |
--ohlc_stats_path ./data/ohlc_stats.npz \
|
| 18 |
-
--num_workers
|
| 19 |
--pin_memory \
|
| 20 |
--val_split 0.1 \
|
| 21 |
-
--val_every
|
| 22 |
"$@"
|
|
|
|
| 1 |
accelerate launch train.py \
|
| 2 |
+
--epochs 1 \
|
| 3 |
--batch_size 8 \
|
| 4 |
--learning_rate 1e-4 \
|
| 5 |
--warmup_ratio 0.1 \
|
| 6 |
--grad_accum_steps 2 \
|
| 7 |
--max_grad_norm 1.0 \
|
| 8 |
--seed 42 \
|
| 9 |
+
--log_every 3 \
|
| 10 |
--save_every 2000 \
|
| 11 |
--tensorboard_dir runs/oracle \
|
| 12 |
--checkpoint_dir checkpoints \
|
|
|
|
| 15 |
--horizons_seconds 30 60 120 240 420 \
|
| 16 |
--quantiles 0.1 0.5 0.9 \
|
| 17 |
--ohlc_stats_path ./data/ohlc_stats.npz \
|
| 18 |
+
--num_workers 0 \
|
| 19 |
--pin_memory \
|
| 20 |
--val_split 0.1 \
|
| 21 |
+
--val_every 50 \
|
| 22 |
"$@"
|