Upload folder using huggingface_hub
Browse files- .gitignore +3 -2
- data/data_collator.py +30 -12
- data/data_fetcher.py +57 -27
- data/data_loader.py +202 -79
- install.sh +26 -0
- log.log +2 -2
- models/model.py +76 -0
- models/multi_modal_processor.py +5 -2
- models/ohlc_embedder.py +3 -3
- train.py +35 -27
- train.sh +3 -5
.gitignore
CHANGED
|
@@ -6,9 +6,10 @@ __pycache__/
|
|
| 6 |
runs/
|
| 7 |
|
| 8 |
data/pump_fun
|
| 9 |
-
|
| 10 |
.env
|
| 11 |
|
| 12 |
data/cache
|
| 13 |
.tmp/
|
| 14 |
-
.cache/
|
|
|
|
|
|
| 6 |
runs/
|
| 7 |
|
| 8 |
data/pump_fun
|
| 9 |
+
data/cache
|
| 10 |
.env
|
| 11 |
|
| 12 |
data/cache
|
| 13 |
.tmp/
|
| 14 |
+
.cache/
|
| 15 |
+
checkpoints/
|
data/data_collator.py
CHANGED
|
@@ -6,11 +6,26 @@ from torch.nn.utils.rnn import pad_sequence
|
|
| 6 |
from typing import List, Dict, Any, Tuple, Optional, Union
|
| 7 |
from collections import defaultdict
|
| 8 |
from PIL import Image
|
| 9 |
-
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
NATIVE_MINT = "So11111111111111111111111111111111111111112"
|
| 16 |
QUOTE_MINTS = {
|
|
@@ -28,19 +43,19 @@ class MemecoinCollator:
|
|
| 28 |
def __init__(self,
|
| 29 |
event_type_to_id: Dict[str, int],
|
| 30 |
device: torch.device,
|
| 31 |
-
multi_modal_encoder: MultiModalEncoder,
|
| 32 |
dtype: torch.dtype,
|
| 33 |
-
|
| 34 |
-
|
| 35 |
):
|
| 36 |
self.event_type_to_id = event_type_to_id
|
| 37 |
self.pad_token_id = event_type_to_id.get('__PAD__', 0)
|
| 38 |
-
self.multi_modal_encoder = multi_modal_encoder
|
|
|
|
| 39 |
self.entity_pad_idx = 0
|
| 40 |
|
| 41 |
self.device = device
|
| 42 |
self.dtype = dtype
|
| 43 |
-
self.ohlc_seq_len =
|
| 44 |
self.max_seq_len = max_seq_len
|
| 45 |
|
| 46 |
def _collate_features_for_encoder(self, entities: List[Dict], feature_keys: List[str], device: torch.device, entity_type: str) -> Dict[str, Any]:
|
|
@@ -205,12 +220,15 @@ class MemecoinCollator:
|
|
| 205 |
all_items_sorted = batch_wide_pooler.get_all_items()
|
| 206 |
texts_to_encode = [d['item'] for d in all_items_sorted if isinstance(d['item'], str)]
|
| 207 |
images_to_encode = [d['item'] for d in all_items_sorted if isinstance(d['item'], Image.Image)]
|
|
|
|
|
|
|
|
|
|
| 208 |
|
| 209 |
-
text_embeds =
|
| 210 |
-
image_embeds =
|
| 211 |
|
| 212 |
# Create the final lookup tensor and fill it based on original item type
|
| 213 |
-
batch_embedding_pool = torch.zeros(len(all_items_sorted),
|
| 214 |
text_cursor, image_cursor = 0, 0
|
| 215 |
for i, item_data in enumerate(all_items_sorted):
|
| 216 |
if isinstance(item_data['item'], str):
|
|
|
|
| 6 |
from typing import List, Dict, Any, Tuple, Optional, Union
|
| 7 |
from collections import defaultdict
|
| 8 |
from PIL import Image
|
| 9 |
+
# --- GLOBAL SINGLETON FOR WORKER PROCESSES ---
|
| 10 |
+
_WORKER_ENCODER = None
|
| 11 |
|
| 12 |
+
def _get_worker_encoder(model_id: str, dtype: torch.dtype, device: torch.device):
|
| 13 |
+
"""
|
| 14 |
+
Lazy-loads the encoder on the worker process.
|
| 15 |
+
FORCED TO CPU to save VRAM when using multiple workers.
|
| 16 |
+
"""
|
| 17 |
+
global _WORKER_ENCODER
|
| 18 |
+
if _WORKER_ENCODER is None:
|
| 19 |
+
print(f"[Worker] Initializing MultiModalEncoder (SigLIP) on CPU (VRAM optimization)...")
|
| 20 |
+
# Local import to avoid top-level dependency issues
|
| 21 |
+
from models.multi_modal_processor import MultiModalEncoder
|
| 22 |
+
# Explicitly pass device="cpu"
|
| 23 |
+
_WORKER_ENCODER = MultiModalEncoder(model_id=model_id, dtype=dtype, device="cpu")
|
| 24 |
+
|
| 25 |
+
return _WORKER_ENCODER
|
| 26 |
+
|
| 27 |
+
import models.vocabulary as vocab
|
| 28 |
+
from data.data_loader import EmbeddingPooler
|
| 29 |
|
| 30 |
NATIVE_MINT = "So11111111111111111111111111111111111111112"
|
| 31 |
QUOTE_MINTS = {
|
|
|
|
| 43 |
def __init__(self,
|
| 44 |
event_type_to_id: Dict[str, int],
|
| 45 |
device: torch.device,
|
|
|
|
| 46 |
dtype: torch.dtype,
|
| 47 |
+
max_seq_len: Optional[int] = None,
|
| 48 |
+
model_id: str = "google/siglip-so400m-patch16-256-i18n"
|
| 49 |
):
|
| 50 |
self.event_type_to_id = event_type_to_id
|
| 51 |
self.pad_token_id = event_type_to_id.get('__PAD__', 0)
|
| 52 |
+
# self.multi_modal_encoder = multi_modal_encoder # DEPRECATED
|
| 53 |
+
self.model_id = model_id
|
| 54 |
self.entity_pad_idx = 0
|
| 55 |
|
| 56 |
self.device = device
|
| 57 |
self.dtype = dtype
|
| 58 |
+
self.ohlc_seq_len = 300 # HARDCODED
|
| 59 |
self.max_seq_len = max_seq_len
|
| 60 |
|
| 61 |
def _collate_features_for_encoder(self, entities: List[Dict], feature_keys: List[str], device: torch.device, entity_type: str) -> Dict[str, Any]:
|
|
|
|
| 220 |
all_items_sorted = batch_wide_pooler.get_all_items()
|
| 221 |
texts_to_encode = [d['item'] for d in all_items_sorted if isinstance(d['item'], str)]
|
| 222 |
images_to_encode = [d['item'] for d in all_items_sorted if isinstance(d['item'], Image.Image)]
|
| 223 |
+
|
| 224 |
+
# LAZY LOAD ENCODER
|
| 225 |
+
encoder = _get_worker_encoder(self.model_id, self.dtype, self.device)
|
| 226 |
|
| 227 |
+
text_embeds = encoder(texts_to_encode).to(self.device) if texts_to_encode else torch.empty(0)
|
| 228 |
+
image_embeds = encoder(images_to_encode).to(self.device) if images_to_encode else torch.empty(0)
|
| 229 |
|
| 230 |
# Create the final lookup tensor and fill it based on original item type
|
| 231 |
+
batch_embedding_pool = torch.zeros(len(all_items_sorted), encoder.embedding_dim, device=self.device, dtype=self.dtype)
|
| 232 |
text_cursor, image_cursor = 0, 0
|
| 233 |
for i, item_data in enumerate(all_items_sorted):
|
| 234 |
if isinstance(item_data['item'], str):
|
data/data_fetcher.py
CHANGED
|
@@ -626,9 +626,11 @@ class DataFetcher:
|
|
| 626 |
|
| 627 |
return token_details
|
| 628 |
|
| 629 |
-
def fetch_trades_for_token(self, token_address: str, T_cutoff: datetime.datetime, count_threshold: int, early_limit: int, recent_limit: int) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
|
| 630 |
"""
|
| 631 |
-
Fetches trades for a token
|
|
|
|
|
|
|
| 632 |
Returns three lists: early_trades, middle_trades, recent_trades.
|
| 633 |
"""
|
| 634 |
if not token_address:
|
|
@@ -636,31 +638,36 @@ class DataFetcher:
|
|
| 636 |
|
| 637 |
params = {'token_address': token_address, 'T_cutoff': T_cutoff}
|
| 638 |
|
| 639 |
-
# 1. Get the total count
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 651 |
query = "SELECT * FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s ORDER BY timestamp ASC"
|
| 652 |
try:
|
| 653 |
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
| 654 |
if not rows: return [], [], []
|
| 655 |
columns = [col[0] for col in columns_info]
|
| 656 |
all_trades = [dict(zip(columns, row)) for row in rows]
|
| 657 |
-
# When not using HBH, all trades are considered "early"
|
| 658 |
return all_trades, [], []
|
| 659 |
except Exception as e:
|
| 660 |
print(f"ERROR: Failed to fetch all trades for token {token_address}: {e}")
|
| 661 |
return [], [], []
|
| 662 |
|
| 663 |
-
# 3. Use the H/B/H strategy if the count is high
|
| 664 |
print("INFO: Fetching trades using 3-part High-Def/Blurry/High-Def strategy.")
|
| 665 |
try:
|
| 666 |
# Fetch Early (High-Def)
|
|
@@ -792,7 +799,7 @@ class DataFetcher:
|
|
| 792 |
ORDER BY timestamp ASC
|
| 793 |
"""
|
| 794 |
params = {'token_address': token_address, 'T_cutoff': T_cutoff}
|
| 795 |
-
print(f"INFO: Fetching pool creation events for {token_address}.")
|
| 796 |
|
| 797 |
try:
|
| 798 |
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
|
@@ -831,7 +838,7 @@ class DataFetcher:
|
|
| 831 |
ORDER BY timestamp ASC
|
| 832 |
"""
|
| 833 |
params = {'pool_addresses': pool_addresses, 'T_cutoff': T_cutoff}
|
| 834 |
-
print(f"INFO: Fetching liquidity change events for {len(pool_addresses)} pool(s).")
|
| 835 |
|
| 836 |
try:
|
| 837 |
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
|
@@ -870,7 +877,7 @@ class DataFetcher:
|
|
| 870 |
ORDER BY timestamp ASC
|
| 871 |
"""
|
| 872 |
params = {'token': token_address, 'T_cutoff': T_cutoff}
|
| 873 |
-
print(f"INFO: Fetching fee collection events for {token_address}.")
|
| 874 |
|
| 875 |
try:
|
| 876 |
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
|
@@ -908,7 +915,7 @@ class DataFetcher:
|
|
| 908 |
ORDER BY timestamp ASC
|
| 909 |
"""
|
| 910 |
params = {'token': token_address, 'T_cutoff': T_cutoff}
|
| 911 |
-
print(f"INFO: Fetching migrations for {token_address}.")
|
| 912 |
try:
|
| 913 |
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
| 914 |
if not rows:
|
|
@@ -946,7 +953,7 @@ class DataFetcher:
|
|
| 946 |
ORDER BY timestamp ASC
|
| 947 |
"""
|
| 948 |
params = {'token': token_address, 'T_cutoff': T_cutoff}
|
| 949 |
-
print(f"INFO: Fetching burn events for {token_address}.")
|
| 950 |
|
| 951 |
try:
|
| 952 |
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
|
@@ -987,7 +994,7 @@ class DataFetcher:
|
|
| 987 |
ORDER BY timestamp ASC
|
| 988 |
"""
|
| 989 |
params = {'token': token_address, 'T_cutoff': T_cutoff}
|
| 990 |
-
print(f"INFO: Fetching supply lock events for {token_address}.")
|
| 991 |
|
| 992 |
try:
|
| 993 |
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
|
@@ -1020,7 +1027,7 @@ class DataFetcher:
|
|
| 1020 |
LIMIT %(limit)s;
|
| 1021 |
"""
|
| 1022 |
params = {'token': token_address, 'T_cutoff': T_cutoff, 'limit': int(limit)}
|
| 1023 |
-
print(f"INFO: Fetching top holders for snapshot for {token_address} (limit {limit}).")
|
| 1024 |
try:
|
| 1025 |
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
| 1026 |
if not rows:
|
|
@@ -1050,7 +1057,7 @@ class DataFetcher:
|
|
| 1050 |
WHERE rn_per_holding = 1 AND current_balance > 0;
|
| 1051 |
"""
|
| 1052 |
params = {'token': token_address, 'T_cutoff': T_cutoff}
|
| 1053 |
-
print(f"INFO: Counting total holders for {token_address} at
|
| 1054 |
try:
|
| 1055 |
rows = self.db_client.execute(query, params)
|
| 1056 |
if not rows:
|
|
@@ -1067,12 +1074,20 @@ class DataFetcher:
|
|
| 1067 |
max_horizon_seconds: int = 3600,
|
| 1068 |
include_wallet_data: bool = True,
|
| 1069 |
include_graph: bool = True,
|
| 1070 |
-
min_trades: int = 0
|
|
|
|
|
|
|
|
|
|
| 1071 |
) -> Optional[Dict[str, Any]]:
|
| 1072 |
"""
|
| 1073 |
Fetches ALL available data for a token up to the maximum horizon.
|
| 1074 |
This data is agnostic of T_cutoff and will be masked/filtered dynamically during training.
|
| 1075 |
Wallet/graph data can be skipped to avoid caching T_cutoff-dependent features.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1076 |
"""
|
| 1077 |
|
| 1078 |
# 1. Calculate the absolute maximum timestamp we care about (mint + max_horizon)
|
|
@@ -1086,8 +1101,9 @@ class DataFetcher:
|
|
| 1086 |
# So we pass max_limit_time as the "cutoff" for the purpose of raw data collection.
|
| 1087 |
|
| 1088 |
# We use a large enough limit to get all relevant trades for the session
|
|
|
|
| 1089 |
early_trades, middle_trades, recent_trades = self.fetch_trades_for_token(
|
| 1090 |
-
token_address, max_limit_time, 30000, 10000, 15000
|
| 1091 |
)
|
| 1092 |
|
| 1093 |
# Combine and deduplicate trades
|
|
@@ -1099,12 +1115,26 @@ class DataFetcher:
|
|
| 1099 |
|
| 1100 |
sorted_trades = sorted(list(all_trades.values()), key=lambda x: x['timestamp'])
|
| 1101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1102 |
if len(sorted_trades) < min_trades:
|
| 1103 |
print(f" SKIP: Token {token_address} has only {len(sorted_trades)} trades (min required: {min_trades}). skipping fetches.")
|
| 1104 |
return None
|
| 1105 |
|
| 1106 |
# 3. Fetch other events
|
| 1107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1108 |
pool_creations = self.fetch_pool_creations_for_token(token_address, max_limit_time)
|
| 1109 |
|
| 1110 |
# Collect pool addresses to fetch liquidity changes
|
|
|
|
| 626 |
|
| 627 |
return token_details
|
| 628 |
|
| 629 |
+
def fetch_trades_for_token(self, token_address: str, T_cutoff: datetime.datetime, count_threshold: int, early_limit: int, recent_limit: int, full_history: bool = False) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
|
| 630 |
"""
|
| 631 |
+
Fetches trades for a token.
|
| 632 |
+
If full_history is True, fetches ALL trades (ignores H/B/H limits).
|
| 633 |
+
Otherwise, uses the 3-part H/B/H strategy if the total count exceeds a threshold.
|
| 634 |
Returns three lists: early_trades, middle_trades, recent_trades.
|
| 635 |
"""
|
| 636 |
if not token_address:
|
|
|
|
| 638 |
|
| 639 |
params = {'token_address': token_address, 'T_cutoff': T_cutoff}
|
| 640 |
|
| 641 |
+
# 1. Get the total count if we care about H/B/H logic
|
| 642 |
+
if not full_history:
|
| 643 |
+
count_query = "SELECT count() FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s"
|
| 644 |
+
try:
|
| 645 |
+
total_trades = self.db_client.execute(count_query, params)[0][0]
|
| 646 |
+
print(f"INFO: Found {total_trades} total trades for token {token_address} before {T_cutoff}.")
|
| 647 |
+
except Exception as e:
|
| 648 |
+
print(f"ERROR: Could not count trades for token {token_address}: {e}")
|
| 649 |
+
return [], [], []
|
| 650 |
+
else:
|
| 651 |
+
total_trades = 0 # Dummy value, ignored
|
| 652 |
+
|
| 653 |
+
# 2. Decide which query to use
|
| 654 |
+
# If full_history is ON, or count is low, fetch everything.
|
| 655 |
+
if full_history or total_trades < count_threshold:
|
| 656 |
+
mode = "Full History" if full_history else "Low Count"
|
| 657 |
+
# print(f"INFO: Fetching all trades ({mode}).")
|
| 658 |
query = "SELECT * FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s ORDER BY timestamp ASC"
|
| 659 |
try:
|
| 660 |
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
| 661 |
if not rows: return [], [], []
|
| 662 |
columns = [col[0] for col in columns_info]
|
| 663 |
all_trades = [dict(zip(columns, row)) for row in rows]
|
| 664 |
+
# When not using HBH or fetching full history, all trades are considered "early" (or just one big block)
|
| 665 |
return all_trades, [], []
|
| 666 |
except Exception as e:
|
| 667 |
print(f"ERROR: Failed to fetch all trades for token {token_address}: {e}")
|
| 668 |
return [], [], []
|
| 669 |
|
| 670 |
+
# 3. Use the H/B/H strategy if the count is high AND not full_history
|
| 671 |
print("INFO: Fetching trades using 3-part High-Def/Blurry/High-Def strategy.")
|
| 672 |
try:
|
| 673 |
# Fetch Early (High-Def)
|
|
|
|
| 799 |
ORDER BY timestamp ASC
|
| 800 |
"""
|
| 801 |
params = {'token_address': token_address, 'T_cutoff': T_cutoff}
|
| 802 |
+
# print(f"INFO: Fetching pool creation events for {token_address}.")
|
| 803 |
|
| 804 |
try:
|
| 805 |
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
|
|
|
| 838 |
ORDER BY timestamp ASC
|
| 839 |
"""
|
| 840 |
params = {'pool_addresses': pool_addresses, 'T_cutoff': T_cutoff}
|
| 841 |
+
# print(f"INFO: Fetching liquidity change events for {len(pool_addresses)} pool(s).")
|
| 842 |
|
| 843 |
try:
|
| 844 |
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
|
|
|
| 877 |
ORDER BY timestamp ASC
|
| 878 |
"""
|
| 879 |
params = {'token': token_address, 'T_cutoff': T_cutoff}
|
| 880 |
+
# print(f"INFO: Fetching fee collection events for {token_address}.")
|
| 881 |
|
| 882 |
try:
|
| 883 |
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
|
|
|
| 915 |
ORDER BY timestamp ASC
|
| 916 |
"""
|
| 917 |
params = {'token': token_address, 'T_cutoff': T_cutoff}
|
| 918 |
+
# print(f"INFO: Fetching migrations for {token_address}.")
|
| 919 |
try:
|
| 920 |
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
| 921 |
if not rows:
|
|
|
|
| 953 |
ORDER BY timestamp ASC
|
| 954 |
"""
|
| 955 |
params = {'token': token_address, 'T_cutoff': T_cutoff}
|
| 956 |
+
# print(f"INFO: Fetching burn events for {token_address}.")
|
| 957 |
|
| 958 |
try:
|
| 959 |
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
|
|
|
| 994 |
ORDER BY timestamp ASC
|
| 995 |
"""
|
| 996 |
params = {'token': token_address, 'T_cutoff': T_cutoff}
|
| 997 |
+
# print(f"INFO: Fetching supply lock events for {token_address}.")
|
| 998 |
|
| 999 |
try:
|
| 1000 |
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
|
|
|
| 1027 |
LIMIT %(limit)s;
|
| 1028 |
"""
|
| 1029 |
params = {'token': token_address, 'T_cutoff': T_cutoff, 'limit': int(limit)}
|
| 1030 |
+
# print(f"INFO: Fetching top holders for snapshot for {token_address} (limit {limit}).")
|
| 1031 |
try:
|
| 1032 |
rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
|
| 1033 |
if not rows:
|
|
|
|
| 1057 |
WHERE rn_per_holding = 1 AND current_balance > 0;
|
| 1058 |
"""
|
| 1059 |
params = {'token': token_address, 'T_cutoff': T_cutoff}
|
| 1060 |
+
# print(f"INFO: Counting total holders for {token_address} at timestamp {T_cutoff}.")
|
| 1061 |
try:
|
| 1062 |
rows = self.db_client.execute(query, params)
|
| 1063 |
if not rows:
|
|
|
|
| 1074 |
max_horizon_seconds: int = 3600,
|
| 1075 |
include_wallet_data: bool = True,
|
| 1076 |
include_graph: bool = True,
|
| 1077 |
+
min_trades: int = 0,
|
| 1078 |
+
full_history: bool = False,
|
| 1079 |
+
prune_failed: bool = False,
|
| 1080 |
+
prune_transfers: bool = False
|
| 1081 |
) -> Optional[Dict[str, Any]]:
|
| 1082 |
"""
|
| 1083 |
Fetches ALL available data for a token up to the maximum horizon.
|
| 1084 |
This data is agnostic of T_cutoff and will be masked/filtered dynamically during training.
|
| 1085 |
Wallet/graph data can be skipped to avoid caching T_cutoff-dependent features.
|
| 1086 |
+
|
| 1087 |
+
Args:
|
| 1088 |
+
full_history: If True, fetches ALL trades ignoring H/B/H limits.
|
| 1089 |
+
prune_failed: If True, filters out failed trades from the result.
|
| 1090 |
+
prune_transfers: If True, skips fetching transfers entirely.
|
| 1091 |
"""
|
| 1092 |
|
| 1093 |
# 1. Calculate the absolute maximum timestamp we care about (mint + max_horizon)
|
|
|
|
| 1101 |
# So we pass max_limit_time as the "cutoff" for the purpose of raw data collection.
|
| 1102 |
|
| 1103 |
# We use a large enough limit to get all relevant trades for the session
|
| 1104 |
+
# If full_history is True, these limits are ignored inside the method.
|
| 1105 |
early_trades, middle_trades, recent_trades = self.fetch_trades_for_token(
|
| 1106 |
+
token_address, max_limit_time, 30000, 10000, 15000, full_history=full_history
|
| 1107 |
)
|
| 1108 |
|
| 1109 |
# Combine and deduplicate trades
|
|
|
|
| 1115 |
|
| 1116 |
sorted_trades = sorted(list(all_trades.values()), key=lambda x: x['timestamp'])
|
| 1117 |
|
| 1118 |
+
# --- PRUNING FAILED TRADES ---
|
| 1119 |
+
if prune_failed:
|
| 1120 |
+
original_count = len(sorted_trades)
|
| 1121 |
+
sorted_trades = [t for t in sorted_trades if t.get('success', False)]
|
| 1122 |
+
if len(sorted_trades) < original_count:
|
| 1123 |
+
# print(f" INFO: Pruned {original_count - len(sorted_trades)} failed trades.")
|
| 1124 |
+
pass
|
| 1125 |
+
|
| 1126 |
if len(sorted_trades) < min_trades:
|
| 1127 |
print(f" SKIP: Token {token_address} has only {len(sorted_trades)} trades (min required: {min_trades}). skipping fetches.")
|
| 1128 |
return None
|
| 1129 |
|
| 1130 |
# 3. Fetch other events
|
| 1131 |
+
# --- PRUNING TRANSFERS ---
|
| 1132 |
+
if prune_transfers:
|
| 1133 |
+
transfers = []
|
| 1134 |
+
# print(" INFO: Pruning transfers (skipping fetch).")
|
| 1135 |
+
else:
|
| 1136 |
+
transfers = self.fetch_transfers_for_token(token_address, max_limit_time, 0.0) # 0.0 means fetch all
|
| 1137 |
+
|
| 1138 |
pool_creations = self.fetch_pool_creations_for_token(token_address, max_limit_time)
|
| 1139 |
|
| 1140 |
# Collect pool addresses to fetch liquidity changes
|
data/data_loader.py
CHANGED
|
@@ -97,11 +97,11 @@ class OracleDataset(Dataset):
|
|
| 97 |
input sequence for the model.
|
| 98 |
"""
|
| 99 |
def __init__(self,
|
| 100 |
-
data_fetcher: DataFetcher, #
|
| 101 |
horizons_seconds: List[int] = [],
|
| 102 |
quantiles: List[float] = [],
|
| 103 |
max_samples: Optional[int] = None,
|
| 104 |
-
ohlc_stats_path: Union[str, Path] = "./data/ohlc_stats.npz",
|
| 105 |
token_allowlist: Optional[List[str]] = None,
|
| 106 |
t_cutoff_seconds: int = 60,
|
| 107 |
cache_dir: Optional[Union[str, Path]] = None,
|
|
@@ -273,7 +273,8 @@ class OracleDataset(Dataset):
|
|
| 273 |
aggregation_trades: List[Dict[str, Any]],
|
| 274 |
wallet_data: Dict[str, Any],
|
| 275 |
total_supply_dec: float,
|
| 276 |
-
_register_event_fn
|
|
|
|
| 277 |
) -> None:
|
| 278 |
# Prepare helper sets and maps (static sniper set based on earliest buyers)
|
| 279 |
all_buy_trades = sorted([e for e in trade_events if e.get('trade_direction') == 0 and e.get('success', False)], key=lambda x: x['timestamp'])
|
|
@@ -304,14 +305,25 @@ class OracleDataset(Dataset):
|
|
| 304 |
|
| 305 |
buyers_seen_global = set()
|
| 306 |
prev_holders_count = 0
|
| 307 |
-
for ts_value in oc_snapshot_times:
|
| 308 |
window_start = ts_value - interval_sec
|
| 309 |
trades_win = [e for e in trade_events if e.get('success', False) and window_start < e['timestamp'] <= ts_value]
|
| 310 |
xfers_win = [e for e in transfer_events if window_start < e['timestamp'] <= ts_value]
|
| 311 |
|
| 312 |
# Per-snapshot holder distribution at ts_value
|
| 313 |
-
|
| 314 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
holder_entries_ts = []
|
| 316 |
for rec in holder_records_ts:
|
| 317 |
addr = rec.get('wallet_address')
|
|
@@ -363,8 +375,7 @@ class OracleDataset(Dataset):
|
|
| 363 |
buyers_seen_global.add(wa)
|
| 364 |
|
| 365 |
# Compute growth against previous snapshot endpoint.
|
| 366 |
-
|
| 367 |
-
holders_end = self.fetcher.fetch_total_holders_count_for_token(token_address, end_dt)
|
| 368 |
total_holders = float(holders_end)
|
| 369 |
delta_holders = holders_end - prev_holders_count
|
| 370 |
holder_growth_rate = float(delta_holders)
|
|
@@ -415,7 +426,7 @@ class OracleDataset(Dataset):
|
|
| 415 |
|
| 416 |
# Fetch all token details in ONE batch query
|
| 417 |
all_deployed_token_details = {}
|
| 418 |
-
if all_deployed_tokens:
|
| 419 |
all_deployed_token_details = self.fetcher.fetch_deployed_token_details(list(all_deployed_tokens), T_cutoff)
|
| 420 |
|
| 421 |
for addr, profile in profiles.items():
|
|
@@ -454,18 +465,24 @@ class OracleDataset(Dataset):
|
|
| 454 |
profile['deployed_tokens_avg_peak_mc_usd'] = torch.mean(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0
|
| 455 |
profile['deployed_tokens_median_peak_mc_usd'] = torch.median(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0
|
| 456 |
|
| 457 |
-
def _process_wallet_data(self, wallet_addresses: List[str], token_data: Dict[str, Any], pooler: EmbeddingPooler, T_cutoff: datetime.datetime
|
|
|
|
| 458 |
"""
|
| 459 |
-
Fetches
|
| 460 |
-
Uses a T_cutoff to ensure data is point-in-time accurate.
|
| 461 |
"""
|
| 462 |
if not wallet_addresses:
|
| 463 |
return {}, token_data
|
| 464 |
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
|
| 470 |
valid_wallets = [addr for addr in wallet_addresses if addr in profiles]
|
| 471 |
dropped_wallets = set(wallet_addresses) - set(valid_wallets)
|
|
@@ -618,8 +635,11 @@ class OracleDataset(Dataset):
|
|
| 618 |
return {}
|
| 619 |
|
| 620 |
if token_data is None:
|
| 621 |
-
|
| 622 |
-
|
|
|
|
|
|
|
|
|
|
| 623 |
|
| 624 |
# --- NEW: Print the raw fetched token data as requested ---
|
| 625 |
print("\n--- RAW TOKEN DATA FROM DATABASE ---")
|
|
@@ -793,14 +813,13 @@ class OracleDataset(Dataset):
|
|
| 793 |
try:
|
| 794 |
raw_data = torch.load(filepath, map_location='cpu', weights_only=False)
|
| 795 |
except Exception as e:
|
| 796 |
-
|
| 797 |
-
return None
|
| 798 |
else:
|
| 799 |
# Online mode fallback
|
| 800 |
raw_data = self.__cacheitem__(idx)
|
| 801 |
|
| 802 |
if not raw_data:
|
| 803 |
-
|
| 804 |
|
| 805 |
required_keys = [
|
| 806 |
"mint_timestamp",
|
|
@@ -822,8 +841,8 @@ class OracleDataset(Dataset):
|
|
| 822 |
f"Cached sample missing raw fields ({missing_keys}). Rebuild cache with raw caching enabled."
|
| 823 |
)
|
| 824 |
|
| 825 |
-
if not self.fetcher:
|
| 826 |
-
|
| 827 |
|
| 828 |
def _timestamp_to_order_value(ts_value: Any) -> float:
|
| 829 |
if isinstance(ts_value, datetime.datetime):
|
|
@@ -904,34 +923,53 @@ class OracleDataset(Dataset):
|
|
| 904 |
if _timestamp_to_order_value(liq.get('timestamp')) <= cutoff_ts:
|
| 905 |
_add_wallet(liq.get('lp_provider'), wallets_to_fetch)
|
| 906 |
|
| 907 |
-
|
| 908 |
-
|
| 909 |
-
|
| 910 |
-
|
| 911 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 912 |
for holder in holder_records:
|
| 913 |
_add_wallet(holder.get('wallet_address'), wallets_to_fetch)
|
| 914 |
|
| 915 |
pooler = EmbeddingPooler()
|
| 916 |
-
|
|
|
|
|
|
|
| 917 |
if not main_token_data:
|
| 918 |
return None
|
| 919 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 920 |
wallet_data, all_token_data = self._process_wallet_data(
|
| 921 |
list(wallets_to_fetch),
|
| 922 |
main_token_data.copy(),
|
| 923 |
pooler,
|
| 924 |
-
T_cutoff
|
|
|
|
|
|
|
|
|
|
| 925 |
)
|
| 926 |
|
| 927 |
graph_entities = {}
|
| 928 |
graph_links = {}
|
| 929 |
-
|
| 930 |
-
|
| 931 |
-
|
| 932 |
-
|
| 933 |
-
|
| 934 |
-
)
|
| 935 |
|
| 936 |
# Generate the item
|
| 937 |
return self._generate_dataset_item(
|
|
@@ -960,13 +998,14 @@ class OracleDataset(Dataset):
|
|
| 960 |
graph_seed_entities=wallets_to_fetch,
|
| 961 |
all_graph_entities=graph_entities,
|
| 962 |
future_trades_for_labels=raw_data['trades'], # We utilize full trade history for labels!
|
| 963 |
-
pooler=pooler
|
|
|
|
| 964 |
)
|
| 965 |
|
| 966 |
def __cacheitem__(self, idx: int) -> Optional[Dict[str, Any]]:
|
| 967 |
"""
|
| 968 |
Fetches cutoff-agnostic raw token data for caching/online sampling.
|
| 969 |
-
|
| 970 |
"""
|
| 971 |
|
| 972 |
if not self.sampled_mints:
|
|
@@ -984,6 +1023,7 @@ class OracleDataset(Dataset):
|
|
| 984 |
if not self.fetcher:
|
| 985 |
raise RuntimeError("Dataset has no data fetcher; cannot load raw data.")
|
| 986 |
|
|
|
|
| 987 |
raw_data = self.fetcher.fetch_raw_token_data(
|
| 988 |
token_address=token_address,
|
| 989 |
creator_address=creator_address,
|
|
@@ -991,7 +1031,10 @@ class OracleDataset(Dataset):
|
|
| 991 |
max_horizon_seconds=self.max_cache_horizon_seconds,
|
| 992 |
include_wallet_data=False,
|
| 993 |
include_graph=False,
|
| 994 |
-
min_trades=50
|
|
|
|
|
|
|
|
|
|
| 995 |
)
|
| 996 |
if raw_data is None:
|
| 997 |
return None
|
|
@@ -1005,56 +1048,134 @@ class OracleDataset(Dataset):
|
|
| 1005 |
return float(ts_value)
|
| 1006 |
except (TypeError, ValueError):
|
| 1007 |
return 0.0
|
| 1008 |
-
|
| 1009 |
-
|
| 1010 |
-
|
| 1011 |
-
|
| 1012 |
-
if trade.get('timestamp') is not None
|
| 1013 |
-
]
|
| 1014 |
if not trade_ts_values:
|
| 1015 |
print(f" SKIP: No valid trades found for {token_address}.")
|
| 1016 |
return None
|
|
|
|
|
|
|
|
|
|
| 1017 |
|
| 1018 |
-
|
| 1019 |
-
|
| 1020 |
-
|
| 1021 |
-
min_window = 30
|
| 1022 |
|
| 1023 |
-
#
|
| 1024 |
-
#
|
| 1025 |
-
|
| 1026 |
-
# lower_bound = max(min_window, first_trade - mint)
|
| 1027 |
-
# upper_bound = (last_trade - mint) - required_horizon
|
| 1028 |
-
# We need upper_bound >= lower_bound.
|
| 1029 |
|
| 1030 |
-
|
| 1031 |
-
|
| 1032 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1033 |
|
| 1034 |
-
|
| 1035 |
-
|
| 1036 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1037 |
|
| 1038 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1039 |
|
| 1040 |
-
|
| 1041 |
-
# but technically we'd prefer to satisfy at least one horizon.
|
| 1042 |
-
# Using min_label (which is max(60, first_horizon)) is safe.
|
| 1043 |
-
required_horizon = min_label
|
| 1044 |
-
upper_bound = end_offset - required_horizon
|
| 1045 |
|
| 1046 |
-
|
| 1047 |
-
|
| 1048 |
-
|
| 1049 |
-
|
| 1050 |
-
|
| 1051 |
-
|
| 1052 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1053 |
|
| 1054 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1055 |
|
| 1056 |
-
|
| 1057 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1058 |
|
| 1059 |
raw_data["protocol_id"] = initial_mint_record.get("protocol")
|
| 1060 |
return raw_data
|
|
@@ -1078,7 +1199,8 @@ class OracleDataset(Dataset):
|
|
| 1078 |
graph_seed_entities: set,
|
| 1079 |
all_graph_entities: Dict[str, str],
|
| 1080 |
future_trades_for_labels: List[Dict[str, Any]],
|
| 1081 |
-
pooler: EmbeddingPooler
|
|
|
|
| 1082 |
) -> Optional[Dict[str, Any]]:
|
| 1083 |
"""
|
| 1084 |
Processes raw token data into a structured dataset item for a specific T_cutoff.
|
|
@@ -1305,7 +1427,8 @@ class OracleDataset(Dataset):
|
|
| 1305 |
aggregation_trades,
|
| 1306 |
wallet_data,
|
| 1307 |
total_supply_dec,
|
| 1308 |
-
_register_event
|
|
|
|
| 1309 |
)
|
| 1310 |
|
| 1311 |
# 7. Finalize Sequence
|
|
|
|
| 97 |
input sequence for the model.
|
| 98 |
"""
|
| 99 |
def __init__(self,
|
| 100 |
+
data_fetcher: Optional[DataFetcher] = None, # OPTIONAL: Only needed for caching (Writer)
|
| 101 |
horizons_seconds: List[int] = [],
|
| 102 |
quantiles: List[float] = [],
|
| 103 |
max_samples: Optional[int] = None,
|
| 104 |
+
ohlc_stats_path: Union[str, Path] = "./data/ohlc_stats.npz",
|
| 105 |
token_allowlist: Optional[List[str]] = None,
|
| 106 |
t_cutoff_seconds: int = 60,
|
| 107 |
cache_dir: Optional[Union[str, Path]] = None,
|
|
|
|
| 273 |
aggregation_trades: List[Dict[str, Any]],
|
| 274 |
wallet_data: Dict[str, Any],
|
| 275 |
total_supply_dec: float,
|
| 276 |
+
_register_event_fn,
|
| 277 |
+
cached_holders_list: List[List[str]] = None
|
| 278 |
) -> None:
|
| 279 |
# Prepare helper sets and maps (static sniper set based on earliest buyers)
|
| 280 |
all_buy_trades = sorted([e for e in trade_events if e.get('trade_direction') == 0 and e.get('success', False)], key=lambda x: x['timestamp'])
|
|
|
|
| 305 |
|
| 306 |
buyers_seen_global = set()
|
| 307 |
prev_holders_count = 0
|
| 308 |
+
for i, ts_value in enumerate(oc_snapshot_times):
|
| 309 |
window_start = ts_value - interval_sec
|
| 310 |
trades_win = [e for e in trade_events if e.get('success', False) and window_start < e['timestamp'] <= ts_value]
|
| 311 |
xfers_win = [e for e in transfer_events if window_start < e['timestamp'] <= ts_value]
|
| 312 |
|
| 313 |
# Per-snapshot holder distribution at ts_value
|
| 314 |
+
holder_records_ts = []
|
| 315 |
+
holders_end = 0
|
| 316 |
+
if cached_holders_list is not None and i < len(cached_holders_list):
|
| 317 |
+
# Use cached list of addresses
|
| 318 |
+
holder_records_ts = [{'wallet_address': addr, 'current_balance': 0} for addr in cached_holders_list[i]]
|
| 319 |
+
holders_end = len(cached_holders_list[i])
|
| 320 |
+
elif self.fetcher:
|
| 321 |
+
cutoff_dt_ts = datetime.datetime.fromtimestamp(ts_value, tz=datetime.timezone.utc)
|
| 322 |
+
holder_records_ts = self.fetcher.fetch_token_holders_for_snapshot(token_address, cutoff_dt_ts, limit=HOLDER_SNAPSHOT_TOP_K)
|
| 323 |
+
holders_end = self.fetcher.fetch_total_holders_count_for_token(token_address, cutoff_dt_ts)
|
| 324 |
+
else:
|
| 325 |
+
holder_records_ts = []
|
| 326 |
+
holders_end = 0
|
| 327 |
holder_entries_ts = []
|
| 328 |
for rec in holder_records_ts:
|
| 329 |
addr = rec.get('wallet_address')
|
|
|
|
| 375 |
buyers_seen_global.add(wa)
|
| 376 |
|
| 377 |
# Compute growth against previous snapshot endpoint.
|
| 378 |
+
# total_holders = float(holders_end) # already handled above
|
|
|
|
| 379 |
total_holders = float(holders_end)
|
| 380 |
delta_holders = holders_end - prev_holders_count
|
| 381 |
holder_growth_rate = float(delta_holders)
|
|
|
|
| 426 |
|
| 427 |
# Fetch all token details in ONE batch query
|
| 428 |
all_deployed_token_details = {}
|
| 429 |
+
if all_deployed_tokens and self.fetcher:
|
| 430 |
all_deployed_token_details = self.fetcher.fetch_deployed_token_details(list(all_deployed_tokens), T_cutoff)
|
| 431 |
|
| 432 |
for addr, profile in profiles.items():
|
|
|
|
| 465 |
profile['deployed_tokens_avg_peak_mc_usd'] = torch.mean(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0
|
| 466 |
profile['deployed_tokens_median_peak_mc_usd'] = torch.median(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0
|
| 467 |
|
| 468 |
+
def _process_wallet_data(self, wallet_addresses: List[str], token_data: Dict[str, Any], pooler: EmbeddingPooler, T_cutoff: datetime.datetime,
|
| 469 |
+
profiles_override: Optional[Dict] = None, socials_override: Optional[Dict] = None, holdings_override: Optional[Dict] = None) -> tuple[Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]]:
|
| 470 |
"""
|
| 471 |
+
Fetches or uses cached profile, social, and holdings data.
|
|
|
|
| 472 |
"""
|
| 473 |
if not wallet_addresses:
|
| 474 |
return {}, token_data
|
| 475 |
|
| 476 |
+
if profiles_override is not None and socials_override is not None:
|
| 477 |
+
profiles, socials = profiles_override, socials_override
|
| 478 |
+
holdings = holdings_override if holdings_override is not None else {}
|
| 479 |
+
else:
|
| 480 |
+
print(f"INFO: Processing wallet data for {len(wallet_addresses)} unique wallets...")
|
| 481 |
+
if self.fetcher:
|
| 482 |
+
profiles, socials = self.fetcher.fetch_wallet_profiles_and_socials(wallet_addresses, T_cutoff)
|
| 483 |
+
holdings = self.fetcher.fetch_wallet_holdings(wallet_addresses, T_cutoff)
|
| 484 |
+
else:
|
| 485 |
+
profiles, socials, holdings = {}, {}, {}
|
| 486 |
|
| 487 |
valid_wallets = [addr for addr in wallet_addresses if addr in profiles]
|
| 488 |
dropped_wallets = set(wallet_addresses) - set(valid_wallets)
|
|
|
|
| 635 |
return {}
|
| 636 |
|
| 637 |
if token_data is None:
|
| 638 |
+
if self.fetcher:
|
| 639 |
+
print(f"INFO: Processing token data for {len(token_addresses)} unique tokens...")
|
| 640 |
+
token_data = self.fetcher.fetch_token_data(token_addresses, T_cutoff)
|
| 641 |
+
else:
|
| 642 |
+
token_data = {}
|
| 643 |
|
| 644 |
# --- NEW: Print the raw fetched token data as requested ---
|
| 645 |
print("\n--- RAW TOKEN DATA FROM DATABASE ---")
|
|
|
|
| 813 |
try:
|
| 814 |
raw_data = torch.load(filepath, map_location='cpu', weights_only=False)
|
| 815 |
except Exception as e:
|
| 816 |
+
raise RuntimeError(f"ERROR: Could not load cached item {filepath}: {e}")
|
|
|
|
| 817 |
else:
|
| 818 |
# Online mode fallback
|
| 819 |
raw_data = self.__cacheitem__(idx)
|
| 820 |
|
| 821 |
if not raw_data:
|
| 822 |
+
raise RuntimeError(f"No raw data loaded for index {idx}")
|
| 823 |
|
| 824 |
required_keys = [
|
| 825 |
"mint_timestamp",
|
|
|
|
| 841 |
f"Cached sample missing raw fields ({missing_keys}). Rebuild cache with raw caching enabled."
|
| 842 |
)
|
| 843 |
|
| 844 |
+
# if not self.fetcher:
|
| 845 |
+
# raise RuntimeError("Data fetcher required for T_cutoff-dependent data.")
|
| 846 |
|
| 847 |
def _timestamp_to_order_value(ts_value: Any) -> float:
|
| 848 |
if isinstance(ts_value, datetime.datetime):
|
|
|
|
| 923 |
if _timestamp_to_order_value(liq.get('timestamp')) <= cutoff_ts:
|
| 924 |
_add_wallet(liq.get('lp_provider'), wallets_to_fetch)
|
| 925 |
|
| 926 |
+
# Offline Holder Lookup using raw_data['holder_snapshots_list']
|
| 927 |
+
# We need the snapshot corresponding to T_cutoff.
|
| 928 |
+
# Intervals are every 300s from mint_ts.
|
| 929 |
+
# idx = (T_cutoff - mint) // 300
|
| 930 |
+
elapsed = (T_cutoff - mint_timestamp).total_seconds()
|
| 931 |
+
snap_idx = int(elapsed // 300)
|
| 932 |
+
holder_records = []
|
| 933 |
+
cached_holders_list = raw_data.get('holder_snapshots_list', [])
|
| 934 |
+
if 0 <= snap_idx < len(cached_holders_list):
|
| 935 |
+
# Format expected by _add_wallet: dict with 'wallet_address'
|
| 936 |
+
holder_records = [{'wallet_address': addr} for addr in cached_holders_list[snap_idx]]
|
| 937 |
for holder in holder_records:
|
| 938 |
_add_wallet(holder.get('wallet_address'), wallets_to_fetch)
|
| 939 |
|
| 940 |
pooler = EmbeddingPooler()
|
| 941 |
+
# Prepare offline token data
|
| 942 |
+
offline_token_data = {token_address: raw_data} # Assuming raw_data contains token metadata at root
|
| 943 |
+
main_token_data = self._process_token_data([token_address], pooler, T_cutoff, token_data=offline_token_data)
|
| 944 |
if not main_token_data:
|
| 945 |
return None
|
| 946 |
|
| 947 |
+
# Prepare offline wallet data
|
| 948 |
+
# raw_data['socials'] structure: {'profiles': {...}, 'socials': {...}} usually.
|
| 949 |
+
# But wait, cached raw_data['socials'] might be just the dict we need?
|
| 950 |
+
# Let's handle graceful empty if not found.
|
| 951 |
+
cached_social_bundle = raw_data.get('socials', {})
|
| 952 |
+
offline_profiles = cached_social_bundle.get('profiles', {})
|
| 953 |
+
offline_socials = cached_social_bundle.get('socials', {})
|
| 954 |
+
offline_holdings = {} # Holdings not cached usually due to size
|
| 955 |
+
|
| 956 |
wallet_data, all_token_data = self._process_wallet_data(
|
| 957 |
list(wallets_to_fetch),
|
| 958 |
main_token_data.copy(),
|
| 959 |
pooler,
|
| 960 |
+
T_cutoff,
|
| 961 |
+
profiles_override=offline_profiles,
|
| 962 |
+
socials_override=offline_socials,
|
| 963 |
+
holdings_override=offline_holdings
|
| 964 |
)
|
| 965 |
|
| 966 |
graph_entities = {}
|
| 967 |
graph_links = {}
|
| 968 |
+
graph_entities = {}
|
| 969 |
+
graph_links = {}
|
| 970 |
+
# if wallets_to_fetch:
|
| 971 |
+
# graph_entities, graph_links = self.fetcher.fetch_graph_links(...)
|
| 972 |
+
# Offline Graph: check if raw_data has graph? Assuming no for now.
|
|
|
|
| 973 |
|
| 974 |
# Generate the item
|
| 975 |
return self._generate_dataset_item(
|
|
|
|
| 998 |
graph_seed_entities=wallets_to_fetch,
|
| 999 |
all_graph_entities=graph_entities,
|
| 1000 |
future_trades_for_labels=raw_data['trades'], # We utilize full trade history for labels!
|
| 1001 |
+
pooler=pooler,
|
| 1002 |
+
cached_holders_list=raw_data.get('holder_snapshots_list')
|
| 1003 |
)
|
| 1004 |
|
| 1005 |
def __cacheitem__(self, idx: int) -> Optional[Dict[str, Any]]:
|
| 1006 |
"""
|
| 1007 |
Fetches cutoff-agnostic raw token data for caching/online sampling.
|
| 1008 |
+
Generates dense time-series (1s OHLC, Snapshots) and prunes raw logs.
|
| 1009 |
"""
|
| 1010 |
|
| 1011 |
if not self.sampled_mints:
|
|
|
|
| 1023 |
if not self.fetcher:
|
| 1024 |
raise RuntimeError("Dataset has no data fetcher; cannot load raw data.")
|
| 1025 |
|
| 1026 |
+
# --- FETCH FULL HISTORY with PRUNING ---
|
| 1027 |
raw_data = self.fetcher.fetch_raw_token_data(
|
| 1028 |
token_address=token_address,
|
| 1029 |
creator_address=creator_address,
|
|
|
|
| 1031 |
max_horizon_seconds=self.max_cache_horizon_seconds,
|
| 1032 |
include_wallet_data=False,
|
| 1033 |
include_graph=False,
|
| 1034 |
+
min_trades=50,
|
| 1035 |
+
full_history=True, # Bypass H/B/H limits
|
| 1036 |
+
prune_failed=True, # Drop failed trades
|
| 1037 |
+
prune_transfers=True # Drop transfers (captured in snapshots)
|
| 1038 |
)
|
| 1039 |
if raw_data is None:
|
| 1040 |
return None
|
|
|
|
| 1048 |
return float(ts_value)
|
| 1049 |
except (TypeError, ValueError):
|
| 1050 |
return 0.0
|
| 1051 |
+
|
| 1052 |
+
trades = raw_data.get('trades', [])
|
| 1053 |
+
trade_ts_values = [_timestamp_to_order_value(t.get('timestamp')) for t in trades]
|
| 1054 |
+
|
|
|
|
|
|
|
| 1055 |
if not trade_ts_values:
|
| 1056 |
print(f" SKIP: No valid trades found for {token_address}.")
|
| 1057 |
return None
|
| 1058 |
+
|
| 1059 |
+
t0_val = _timestamp_to_order_value(t0)
|
| 1060 |
+
last_trade_ts_val = max(trade_ts_values)
|
| 1061 |
|
| 1062 |
+
# --- GENERATE DENSE 1s OHLC ---
|
| 1063 |
+
duration_seconds = int(last_trade_ts_val - t0_val) + 120 # Add buffer
|
| 1064 |
+
ohlc_1s = torch.zeros((duration_seconds, 2), dtype=torch.float32)
|
|
|
|
| 1065 |
|
| 1066 |
+
# Sort trades by time
|
| 1067 |
+
# raw_data trades are already sorted by fetcher, but let's be safe
|
| 1068 |
+
trades.sort(key=lambda x: _timestamp_to_order_value(x['timestamp']))
|
|
|
|
|
|
|
|
|
|
| 1069 |
|
| 1070 |
+
# Fill OHLC
|
| 1071 |
+
# A faster way: group by second
|
| 1072 |
+
# We can use a simple loop update or numpy accumulation.
|
| 1073 |
+
# Given standard density, simple loop is fine for caching.
|
| 1074 |
+
|
| 1075 |
+
trades_by_sec = defaultdict(list)
|
| 1076 |
+
for t in trades:
|
| 1077 |
+
ts = _timestamp_to_order_value(t['timestamp'])
|
| 1078 |
+
sec_idx = int(ts - t0_val)
|
| 1079 |
+
if 0 <= sec_idx < duration_seconds:
|
| 1080 |
+
trades_by_sec[sec_idx].append(t['price_usd'])
|
| 1081 |
+
|
| 1082 |
+
last_close = float(trades[0]['price_usd'])
|
| 1083 |
|
| 1084 |
+
for i in range(duration_seconds):
|
| 1085 |
+
if i in trades_by_sec:
|
| 1086 |
+
prices = trades_by_sec[i]
|
| 1087 |
+
op = prices[0]
|
| 1088 |
+
cl = prices[-1]
|
| 1089 |
+
last_close = cl
|
| 1090 |
+
else:
|
| 1091 |
+
op = cl = last_close
|
| 1092 |
+
|
| 1093 |
+
ohlc_1s[i, 0] = float(op)
|
| 1094 |
+
ohlc_1s[i, 1] = float(cl)
|
| 1095 |
+
|
| 1096 |
+
raw_data['ohlc_1s'] = ohlc_1s
|
| 1097 |
|
| 1098 |
+
# --- GENERATE ON-CHAIN SNAPSHOTS (5m Interval) ---
|
| 1099 |
+
interval = 300 # 5 minutes
|
| 1100 |
+
num_intervals = (duration_seconds // interval) + 1
|
| 1101 |
+
# Feature columns: [volume, tx_count, buy_count, sell_count, total_holders, top_10_holder_pct]
|
| 1102 |
+
# We start with basic trade stats. Holder stats require DB queries.
|
| 1103 |
|
| 1104 |
+
snapshot_stats = torch.zeros((num_intervals, 6), dtype=torch.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1105 |
|
| 1106 |
+
print(f" INFO: Generating {num_intervals} snapshots (Interval: {interval}s)...")
|
| 1107 |
+
|
| 1108 |
+
cum_volume = 0.0
|
| 1109 |
+
cum_tx = 0
|
| 1110 |
+
cum_buys = 0
|
| 1111 |
+
cum_sells = 0
|
| 1112 |
+
|
| 1113 |
+
# Pre-group trades into 5m buckets for windowed volume
|
| 1114 |
+
buckets = defaultdict(list)
|
| 1115 |
+
for t in trades:
|
| 1116 |
+
ts = _timestamp_to_order_value(t['timestamp'])
|
| 1117 |
+
bucket_idx = int(ts - t0_val) // interval
|
| 1118 |
+
if bucket_idx >= 0:
|
| 1119 |
+
buckets[bucket_idx].append(t)
|
| 1120 |
+
|
| 1121 |
+
# To avoid spamming DB, we might query holders less frequently or batch?
|
| 1122 |
+
# For now, query every step. 288 queries for 24h is fine.
|
| 1123 |
+
|
| 1124 |
+
fetched_holders_cache = {} # Map bucket_idx -> (count, top10_pct)
|
| 1125 |
+
holder_snapshots_list = [] # List of (timestamp, holders_list)
|
| 1126 |
+
|
| 1127 |
+
for i in range(num_intervals):
|
| 1128 |
+
bucket_trades = buckets[i]
|
| 1129 |
|
| 1130 |
+
# Windowed Stats
|
| 1131 |
+
vol = sum(t.get('total_usd', 0.0) for t in bucket_trades)
|
| 1132 |
+
tx = len(bucket_trades)
|
| 1133 |
+
buys = sum(1 for t in bucket_trades if t.get('trade_direction') == 0 or t.get('trade_type') == 0) # 0=Buy
|
| 1134 |
+
sells = tx - buys
|
| 1135 |
|
| 1136 |
+
# DB Stats: Holders (Point-in-Time)
|
| 1137 |
+
# Time is end of bucket
|
| 1138 |
+
snapshot_ts = t0 + datetime.timedelta(seconds=(i+1)*interval)
|
| 1139 |
+
|
| 1140 |
+
# These queries can be slow.
|
| 1141 |
+
count = self.fetcher.fetch_total_holders_count_for_token(token_address, snapshot_ts)
|
| 1142 |
+
# Fetch Top 200 as per constant
|
| 1143 |
+
top_holders = self.fetcher.fetch_token_holders_for_snapshot(token_address, snapshot_ts, limit=HOLDER_SNAPSHOT_TOP_K)
|
| 1144 |
+
|
| 1145 |
+
total_supply = raw_data.get('total_supply', 0) or 1
|
| 1146 |
+
if raw_data.get('decimals'):
|
| 1147 |
+
total_supply /= (10 ** raw_data['decimals'])
|
| 1148 |
+
|
| 1149 |
+
top10_bal = sum(h.get('current_balance', 0) for h in top_holders[:10])
|
| 1150 |
+
top10_pct = (top10_bal / total_supply) if total_supply > 0 else 0.0
|
| 1151 |
+
|
| 1152 |
+
snapshot_stats[i, 0] = float(vol)
|
| 1153 |
+
snapshot_stats[i, 1] = float(tx)
|
| 1154 |
+
snapshot_stats[i, 2] = float(buys)
|
| 1155 |
+
snapshot_stats[i, 3] = float(sells)
|
| 1156 |
+
snapshot_stats[i, 4] = float(count)
|
| 1157 |
+
snapshot_stats[i, 5] = float(top10_pct)
|
| 1158 |
+
|
| 1159 |
+
# Save the holder identities for the event stream
|
| 1160 |
+
# Make it JSON-serializable-ish (no datetime objects)
|
| 1161 |
+
holder_snapshots_list.append({
|
| 1162 |
+
'timestamp': int(snapshot_ts.timestamp()),
|
| 1163 |
+
'holders': top_holders # [{wallet, balance}, ...]
|
| 1164 |
+
})
|
| 1165 |
+
|
| 1166 |
+
raw_data['snapshots_5m'] = snapshot_stats
|
| 1167 |
+
raw_data['holder_snapshots_list'] = holder_snapshots_list # Save the list
|
| 1168 |
+
|
| 1169 |
+
# --- Summary Log ---
|
| 1170 |
+
print(f" [Cache Summary]")
|
| 1171 |
+
print(f" - 1s Candles: {len(ohlc_1s)}")
|
| 1172 |
+
print(f" - 5m Snapshots: {len(snapshot_stats)}")
|
| 1173 |
+
print(f" - Trades (Succ): {len(trades)}")
|
| 1174 |
+
print(f" - Pool Events: {len(raw_data.get('pool_creations', []))}")
|
| 1175 |
+
print(f" - Liquidity Chgs: {len(raw_data.get('liquidity_changes', []))}")
|
| 1176 |
+
print(f" - Burns: {len(raw_data.get('burns', []))}")
|
| 1177 |
+
print(f" - Supply Locks: {len(raw_data.get('supply_locks', []))}")
|
| 1178 |
+
print(f" - Migrations: {len(raw_data.get('migrations', []))}")
|
| 1179 |
|
| 1180 |
raw_data["protocol_id"] = initial_mint_record.get("protocol")
|
| 1181 |
return raw_data
|
|
|
|
| 1199 |
graph_seed_entities: set,
|
| 1200 |
all_graph_entities: Dict[str, str],
|
| 1201 |
future_trades_for_labels: List[Dict[str, Any]],
|
| 1202 |
+
pooler: EmbeddingPooler,
|
| 1203 |
+
cached_holders_list: List[List[str]] = None
|
| 1204 |
) -> Optional[Dict[str, Any]]:
|
| 1205 |
"""
|
| 1206 |
Processes raw token data into a structured dataset item for a specific T_cutoff.
|
|
|
|
| 1427 |
aggregation_trades,
|
| 1428 |
wallet_data,
|
| 1429 |
total_supply_dec,
|
| 1430 |
+
_register_event,
|
| 1431 |
+
cached_holders_list=cached_holders_list
|
| 1432 |
)
|
| 1433 |
|
| 1434 |
# 7. Finalize Sequence
|
install.sh
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
sudo apt update
|
| 2 |
+
sudo apt install -y curl wget gnupg apt-transport-https ca-certificates dirmngr
|
| 3 |
+
|
| 4 |
+
sudo apt update
|
| 5 |
+
sudo apt install -y pkg-config libudev-dev
|
| 6 |
+
|
| 7 |
+
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
| 8 |
+
source $HOME/.cargo/env
|
| 9 |
+
|
| 10 |
+
# ClickHouse (add repo and install)
|
| 11 |
+
sudo apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv 8919F6BD2B48D754
|
| 12 |
+
echo "deb https://packages.clickhouse.com/deb stable main" | sudo tee /etc/apt/sources.list.d/clickhouse.list
|
| 13 |
+
sudo apt update
|
| 14 |
+
sudo apt install -y clickhouse-server clickhouse-client
|
| 15 |
+
|
| 16 |
+
# Neo4j (add repo and install)
|
| 17 |
+
sudo wget -O - https://debian.neo4j.com/neotechnology.gpg.key | sudo gpg --dearmor -o /usr/share/keyrings/neo4j.gpg
|
| 18 |
+
echo "deb [signed-by=/usr/share/keyrings/neo4j.gpg] https://debian.neo4j.com stable latest" | sudo tee -a /etc/apt/sources.list.d/neo4j.list
|
| 19 |
+
sudo apt update
|
| 20 |
+
sudo apt install -y neo4j
|
| 21 |
+
|
| 22 |
+
# Start Neo4j (Runs on bolt://localhost:7687)
|
| 23 |
+
sudo neo4j-admin dbms set-initial-password neo4j123
|
| 24 |
+
neo4j start
|
| 25 |
+
|
| 26 |
+
clickhouse-server
|
log.log
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2bfaace3cf2aadc0acf9e9714d8df00c44bc545db23c87e7497a7844ba3c98a9
|
| 3 |
+
size 6115919
|
models/model.py
CHANGED
|
@@ -5,6 +5,8 @@ import torch.nn as nn
|
|
| 5 |
import torch.nn.functional as F
|
| 6 |
from transformers import AutoConfig, AutoModel
|
| 7 |
from typing import List, Dict, Any, Optional, Tuple
|
|
|
|
|
|
|
| 8 |
|
| 9 |
# --- NOW, we import all the encoders ---
|
| 10 |
from models.helper_encoders import ContextualTimeEncoder
|
|
@@ -43,6 +45,9 @@ class Oracle(nn.Module):
|
|
| 43 |
self.multi_modal_dim = multi_modal_dim
|
| 44 |
|
| 45 |
|
|
|
|
|
|
|
|
|
|
| 46 |
self.quantiles = quantiles
|
| 47 |
self.horizons_seconds = horizons_seconds
|
| 48 |
self.num_outputs = len(quantiles) * len(horizons_seconds)
|
|
@@ -225,6 +230,77 @@ class Oracle(nn.Module):
|
|
| 225 |
self.to(dtype)
|
| 226 |
print("Oracle model (full pipeline) initialized.")
|
| 227 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
def _normalize_and_project(self,
|
| 229 |
features: torch.Tensor,
|
| 230 |
norm_layer: nn.LayerNorm,
|
|
|
|
| 5 |
import torch.nn.functional as F
|
| 6 |
from transformers import AutoConfig, AutoModel
|
| 7 |
from typing import List, Dict, Any, Optional, Tuple
|
| 8 |
+
import os
|
| 9 |
+
import json
|
| 10 |
|
| 11 |
# --- NOW, we import all the encoders ---
|
| 12 |
from models.helper_encoders import ContextualTimeEncoder
|
|
|
|
| 45 |
self.multi_modal_dim = multi_modal_dim
|
| 46 |
|
| 47 |
|
| 48 |
+
self.num_event_types = num_event_types
|
| 49 |
+
self.event_pad_id = event_pad_id
|
| 50 |
+
self.model_config_name = model_config_name
|
| 51 |
self.quantiles = quantiles
|
| 52 |
self.horizons_seconds = horizons_seconds
|
| 53 |
self.num_outputs = len(quantiles) * len(horizons_seconds)
|
|
|
|
| 230 |
self.to(dtype)
|
| 231 |
print("Oracle model (full pipeline) initialized.")
|
| 232 |
|
| 233 |
+
def save_pretrained(self, save_directory: str):
|
| 234 |
+
"""
|
| 235 |
+
Saves the model in a Hugging Face-compatible way.
|
| 236 |
+
"""
|
| 237 |
+
if not os.path.exists(save_directory):
|
| 238 |
+
os.makedirs(save_directory)
|
| 239 |
+
|
| 240 |
+
# 1. Save the inner transformer model using its own save_pretrained
|
| 241 |
+
# This gives us the standard HF config.json and pytorch_model.bin for the backbone
|
| 242 |
+
self.model.save_pretrained(save_directory)
|
| 243 |
+
|
| 244 |
+
# 2. Save the whole Oracle state dict (includes transformer + all custom encoders)
|
| 245 |
+
# We use 'oracle_model.bin' for the full state.
|
| 246 |
+
torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
|
| 247 |
+
|
| 248 |
+
# 3. Save Oracle specific metadata for reconstruction
|
| 249 |
+
oracle_config = {
|
| 250 |
+
"num_event_types": self.num_event_types,
|
| 251 |
+
"multi_modal_dim": self.multi_modal_dim,
|
| 252 |
+
"event_pad_id": self.event_pad_id,
|
| 253 |
+
"model_config_name": self.model_config_name,
|
| 254 |
+
"quantiles": self.quantiles,
|
| 255 |
+
"horizons_seconds": self.horizons_seconds,
|
| 256 |
+
"dtype": str(self.dtype),
|
| 257 |
+
"event_type_to_id": self.event_type_to_id
|
| 258 |
+
}
|
| 259 |
+
with open(os.path.join(save_directory, "oracle_config.json"), "w") as f:
|
| 260 |
+
json.dump(oracle_config, f, indent=2)
|
| 261 |
+
|
| 262 |
+
print(f"✅ Oracle model saved to {save_directory}")
|
| 263 |
+
|
| 264 |
+
@classmethod
|
| 265 |
+
def from_pretrained(cls, load_directory: str,
|
| 266 |
+
token_encoder, wallet_encoder, graph_updater, ohlc_embedder, time_encoder):
|
| 267 |
+
"""
|
| 268 |
+
Loads the Oracle model from a saved directory.
|
| 269 |
+
Note: You must still provide the initialized sub-encoders (or we can refactor to save them too).
|
| 270 |
+
"""
|
| 271 |
+
config_path = os.path.join(load_directory, "oracle_config.json")
|
| 272 |
+
with open(config_path, "r") as f:
|
| 273 |
+
config = json.load(f)
|
| 274 |
+
|
| 275 |
+
# Determine dtype from string
|
| 276 |
+
dtype = torch.bfloat16 # Default
|
| 277 |
+
if "float32" in config["dtype"]: dtype = torch.float32
|
| 278 |
+
elif "float16" in config["dtype"]: dtype = torch.float16
|
| 279 |
+
|
| 280 |
+
# Instantiate model
|
| 281 |
+
model = cls(
|
| 282 |
+
token_encoder=token_encoder,
|
| 283 |
+
wallet_encoder=wallet_encoder,
|
| 284 |
+
graph_updater=graph_updater,
|
| 285 |
+
ohlc_embedder=ohlc_embedder,
|
| 286 |
+
time_encoder=time_encoder,
|
| 287 |
+
num_event_types=config["num_event_types"],
|
| 288 |
+
multi_modal_dim=config["multi_modal_dim"],
|
| 289 |
+
event_pad_id=config["event_pad_id"],
|
| 290 |
+
event_type_to_id=config["event_type_to_id"],
|
| 291 |
+
model_config_name=config["model_config_name"],
|
| 292 |
+
quantiles=config["quantiles"],
|
| 293 |
+
horizons_seconds=config["horizons_seconds"],
|
| 294 |
+
dtype=dtype
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
# Load weights
|
| 298 |
+
weight_path = os.path.join(load_directory, "pytorch_model.bin")
|
| 299 |
+
state_dict = torch.load(weight_path, map_location="cpu")
|
| 300 |
+
model.load_state_dict(state_dict)
|
| 301 |
+
print(f"✅ Oracle model loaded from {load_directory}")
|
| 302 |
+
return model
|
| 303 |
+
|
| 304 |
def _normalize_and_project(self,
|
| 305 |
features: torch.Tensor,
|
| 306 |
norm_layer: nn.LayerNorm,
|
models/multi_modal_processor.py
CHANGED
|
@@ -21,9 +21,12 @@ class MultiModalEncoder:
|
|
| 21 |
This class is intended for creating embeddings for vector search.
|
| 22 |
"""
|
| 23 |
|
| 24 |
-
def __init__(self, model_id="google/siglip-so400m-patch16-256-i18n", dtype: torch.dtype = torch.bfloat16):
|
| 25 |
self.model_id = model_id
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
self.dtype = dtype
|
| 29 |
|
|
|
|
| 21 |
This class is intended for creating embeddings for vector search.
|
| 22 |
"""
|
| 23 |
|
| 24 |
+
def __init__(self, model_id="google/siglip-so400m-patch16-256-i18n", dtype: torch.dtype = torch.bfloat16, device: str = None):
|
| 25 |
self.model_id = model_id
|
| 26 |
+
if device:
|
| 27 |
+
self.device = device
|
| 28 |
+
else:
|
| 29 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 30 |
|
| 31 |
self.dtype = dtype
|
| 32 |
|
models/ohlc_embedder.py
CHANGED
|
@@ -18,7 +18,7 @@ class OHLCEmbedder(nn.Module):
|
|
| 18 |
# --- NEW: Interval vocab size ---
|
| 19 |
num_intervals: int,
|
| 20 |
input_channels: int = 2, # Open, Close
|
| 21 |
-
sequence_length: int = 300,
|
| 22 |
cnn_channels: List[int] = [16, 32, 64],
|
| 23 |
kernel_sizes: List[int] = [3, 3, 3],
|
| 24 |
# --- NEW: Interval embedding dim ---
|
|
@@ -30,12 +30,12 @@ class OHLCEmbedder(nn.Module):
|
|
| 30 |
assert len(cnn_channels) == len(kernel_sizes), "cnn_channels and kernel_sizes must have the same length"
|
| 31 |
|
| 32 |
self.dtype = dtype
|
| 33 |
-
self.sequence_length =
|
| 34 |
self.cnn_layers = nn.ModuleList()
|
| 35 |
self.output_dim = output_dim
|
| 36 |
|
| 37 |
in_channels = input_channels
|
| 38 |
-
current_seq_len =
|
| 39 |
|
| 40 |
for i, (out_channels, k_size) in enumerate(zip(cnn_channels, kernel_sizes)):
|
| 41 |
conv = nn.Conv1d(
|
|
|
|
| 18 |
# --- NEW: Interval vocab size ---
|
| 19 |
num_intervals: int,
|
| 20 |
input_channels: int = 2, # Open, Close
|
| 21 |
+
# sequence_length: int = 300, # REMOVED: HARDCODED
|
| 22 |
cnn_channels: List[int] = [16, 32, 64],
|
| 23 |
kernel_sizes: List[int] = [3, 3, 3],
|
| 24 |
# --- NEW: Interval embedding dim ---
|
|
|
|
| 30 |
assert len(cnn_channels) == len(kernel_sizes), "cnn_channels and kernel_sizes must have the same length"
|
| 31 |
|
| 32 |
self.dtype = dtype
|
| 33 |
+
self.sequence_length = 300 # HARDCODED
|
| 34 |
self.cnn_layers = nn.ModuleList()
|
| 35 |
self.output_dim = output_dim
|
| 36 |
|
| 37 |
in_channels = input_channels
|
| 38 |
+
current_seq_len = 300
|
| 39 |
|
| 40 |
for i, (out_channels, k_size) in enumerate(zip(cnn_channels, kernel_sizes)):
|
| 41 |
conv = nn.Conv1d(
|
train.py
CHANGED
|
@@ -4,6 +4,9 @@ import math
|
|
| 4 |
import logging
|
| 5 |
from pathlib import Path
|
| 6 |
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
# Ensure torch/dill have a writable tmp dir
|
| 9 |
_DEFAULT_TMP = Path(os.getenv("TMPDIR_OVERRIDE", "./.tmp"))
|
|
@@ -12,6 +15,11 @@ resolved_tmp = str(_DEFAULT_TMP.resolve())
|
|
| 12 |
for key in ("TMPDIR", "TMP", "TEMP"):
|
| 13 |
os.environ.setdefault(key, resolved_tmp)
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
import torch
|
| 16 |
import torch.nn as nn
|
| 17 |
from torch.utils.data import DataLoader
|
|
@@ -126,7 +134,6 @@ def parse_args() -> argparse.Namespace:
|
|
| 126 |
parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
|
| 127 |
parser.add_argument("--mixed_precision", type=str, default="bf16")
|
| 128 |
parser.add_argument("--max_seq_len", type=int, default=16000)
|
| 129 |
-
parser.add_argument("--ohlc_seq_len", type=int, default=60)
|
| 130 |
parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
|
| 131 |
parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
|
| 132 |
parser.add_argument("--max_samples", type=int, default=None)
|
|
@@ -200,8 +207,8 @@ def main() -> None:
|
|
| 200 |
horizons = args.horizons_seconds
|
| 201 |
quantiles = args.quantiles
|
| 202 |
max_seq_len = args.max_seq_len
|
| 203 |
-
|
| 204 |
-
|
| 205 |
logger.info(f"Initializing Encoders with dtype={init_dtype}...")
|
| 206 |
|
| 207 |
# Encoders
|
|
@@ -212,39 +219,29 @@ def main() -> None:
|
|
| 212 |
graph_updater = GraphUpdater(time_encoder=time_encoder, dtype=init_dtype)
|
| 213 |
ohlc_embedder = OHLCEmbedder(
|
| 214 |
num_intervals=vocab.NUM_OHLC_INTERVALS,
|
| 215 |
-
sequence_length=ohlc_seq_len,
|
| 216 |
dtype=init_dtype
|
| 217 |
)
|
| 218 |
|
| 219 |
collator = MemecoinCollator(
|
| 220 |
event_type_to_id=vocab.EVENT_TO_ID,
|
| 221 |
device=device, # Note: Collator will handle basic moves, Accelerate handles the rest
|
| 222 |
-
multi_modal_encoder=multi_modal_encoder,
|
| 223 |
dtype=init_dtype,
|
| 224 |
-
ohlc_seq_len=ohlc_seq_len,
|
| 225 |
max_seq_len=max_seq_len
|
| 226 |
)
|
| 227 |
|
| 228 |
-
# DB Connections
|
| 229 |
-
clickhouse_client = ClickHouseClient(
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
)
|
| 233 |
-
|
| 234 |
-
neo4j_auth = ("neo4j", "neo4j123")
|
| 235 |
-
if args.neo4j_user is not None:
|
| 236 |
-
neo4j_auth = (args.neo4j_user, args.neo4j_password or "")
|
| 237 |
-
neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=neo4j_auth)
|
| 238 |
-
|
| 239 |
-
data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
|
| 240 |
|
| 241 |
dataset = OracleDataset(
|
| 242 |
-
data_fetcher=
|
| 243 |
horizons_seconds=horizons,
|
| 244 |
quantiles=quantiles,
|
| 245 |
max_samples=args.max_samples,
|
| 246 |
ohlc_stats_path=args.ohlc_stats_path,
|
| 247 |
-
t_cutoff_seconds=int(args.t_cutoff_seconds),
|
| 248 |
cache_dir="/workspace/apollo/data/cache"
|
| 249 |
)
|
| 250 |
|
|
@@ -257,7 +254,7 @@ def main() -> None:
|
|
| 257 |
shuffle=bool(args.shuffle),
|
| 258 |
num_workers=int(args.num_workers),
|
| 259 |
pin_memory=bool(args.pin_memory),
|
| 260 |
-
collate_fn=
|
| 261 |
)
|
| 262 |
|
| 263 |
# --- 3. Model Init ---
|
|
@@ -442,25 +439,36 @@ def main() -> None:
|
|
| 442 |
if accelerator.is_main_process:
|
| 443 |
save_path = checkpoint_dir / f"checkpoint-{total_steps}"
|
| 444 |
accelerator.save_state(output_dir=str(save_path))
|
| 445 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
|
| 447 |
# End of Epoch Handling
|
| 448 |
if valid_batches > 0:
|
| 449 |
avg_loss = epoch_loss / valid_batches
|
| 450 |
if accelerator.is_main_process:
|
| 451 |
logger.info(f"Epoch {epoch+1} complete. Avg loss: {avg_loss:.6f}")
|
| 452 |
-
accelerator.log({"train/loss_epoch": avg_loss}, step=
|
| 453 |
|
| 454 |
-
# Save Checkpoint at end of epoch
|
| 455 |
save_path = checkpoint_dir / f"epoch_{epoch+1}"
|
| 456 |
-
accelerator.save_state(output_dir=str(save_path))
|
| 457 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 458 |
else:
|
| 459 |
if accelerator.is_main_process:
|
| 460 |
logger.warning(f"Epoch {epoch+1}: No valid batches processed.")
|
| 461 |
|
| 462 |
accelerator.end_training()
|
| 463 |
-
neo4j_driver.close()
|
| 464 |
|
| 465 |
if __name__ == "__main__":
|
| 466 |
main()
|
|
|
|
| 4 |
import logging
|
| 5 |
from pathlib import Path
|
| 6 |
from typing import Any, Dict, List, Optional, Tuple
|
| 7 |
+
import functools
|
| 8 |
+
|
| 9 |
+
import torch.multiprocessing as mp
|
| 10 |
|
| 11 |
# Ensure torch/dill have a writable tmp dir
|
| 12 |
_DEFAULT_TMP = Path(os.getenv("TMPDIR_OVERRIDE", "./.tmp"))
|
|
|
|
| 15 |
for key in ("TMPDIR", "TMP", "TEMP"):
|
| 16 |
os.environ.setdefault(key, resolved_tmp)
|
| 17 |
|
| 18 |
+
try:
|
| 19 |
+
mp.set_start_method('spawn', force=True)
|
| 20 |
+
except RuntimeError:
|
| 21 |
+
pass
|
| 22 |
+
|
| 23 |
import torch
|
| 24 |
import torch.nn as nn
|
| 25 |
from torch.utils.data import DataLoader
|
|
|
|
| 134 |
parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
|
| 135 |
parser.add_argument("--mixed_precision", type=str, default="bf16")
|
| 136 |
parser.add_argument("--max_seq_len", type=int, default=16000)
|
|
|
|
| 137 |
parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
|
| 138 |
parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
|
| 139 |
parser.add_argument("--max_samples", type=int, default=None)
|
|
|
|
| 207 |
horizons = args.horizons_seconds
|
| 208 |
quantiles = args.quantiles
|
| 209 |
max_seq_len = args.max_seq_len
|
| 210 |
+
max_seq_len = args.max_seq_len
|
| 211 |
+
|
| 212 |
logger.info(f"Initializing Encoders with dtype={init_dtype}...")
|
| 213 |
|
| 214 |
# Encoders
|
|
|
|
| 219 |
graph_updater = GraphUpdater(time_encoder=time_encoder, dtype=init_dtype)
|
| 220 |
ohlc_embedder = OHLCEmbedder(
|
| 221 |
num_intervals=vocab.NUM_OHLC_INTERVALS,
|
|
|
|
| 222 |
dtype=init_dtype
|
| 223 |
)
|
| 224 |
|
| 225 |
collator = MemecoinCollator(
|
| 226 |
event_type_to_id=vocab.EVENT_TO_ID,
|
| 227 |
device=device, # Note: Collator will handle basic moves, Accelerate handles the rest
|
| 228 |
+
# multi_modal_encoder=multi_modal_encoder, # REMOVED: Uses lazy loading internally
|
| 229 |
dtype=init_dtype,
|
|
|
|
| 230 |
max_seq_len=max_seq_len
|
| 231 |
)
|
| 232 |
|
| 233 |
+
# DB Connections - REMOVED for Training (Using Cache)
|
| 234 |
+
# clickhouse_client = ClickHouseClient(...)
|
| 235 |
+
# neo4j_driver = GraphDatabase.driver(...)
|
| 236 |
+
# data_fetcher = DataFetcher(...)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
|
| 238 |
dataset = OracleDataset(
|
| 239 |
+
data_fetcher=None, # Training Mode (Reader Only)
|
| 240 |
horizons_seconds=horizons,
|
| 241 |
quantiles=quantiles,
|
| 242 |
max_samples=args.max_samples,
|
| 243 |
ohlc_stats_path=args.ohlc_stats_path,
|
| 244 |
+
t_cutoff_seconds=int(args.t_cutoff_seconds) if hasattr(args, 't_cutoff_seconds') else 60,
|
| 245 |
cache_dir="/workspace/apollo/data/cache"
|
| 246 |
)
|
| 247 |
|
|
|
|
| 254 |
shuffle=bool(args.shuffle),
|
| 255 |
num_workers=int(args.num_workers),
|
| 256 |
pin_memory=bool(args.pin_memory),
|
| 257 |
+
collate_fn=functools.partial(filtered_collate, collator)
|
| 258 |
)
|
| 259 |
|
| 260 |
# --- 3. Model Init ---
|
|
|
|
| 439 |
if accelerator.is_main_process:
|
| 440 |
save_path = checkpoint_dir / f"checkpoint-{total_steps}"
|
| 441 |
accelerator.save_state(output_dir=str(save_path))
|
| 442 |
+
|
| 443 |
+
# NEW: Save in standard HF-loadable way
|
| 444 |
+
hf_save_path = save_path / "hf_model"
|
| 445 |
+
unwrapped_model = accelerator.unwrap_model(model)
|
| 446 |
+
unwrapped_model.save_pretrained(str(hf_save_path))
|
| 447 |
+
|
| 448 |
+
logger.info(f"Saved checkpoint and HF-style model to {save_path}")
|
| 449 |
|
| 450 |
# End of Epoch Handling
|
| 451 |
if valid_batches > 0:
|
| 452 |
avg_loss = epoch_loss / valid_batches
|
| 453 |
if accelerator.is_main_process:
|
| 454 |
logger.info(f"Epoch {epoch+1} complete. Avg loss: {avg_loss:.6f}")
|
| 455 |
+
accelerator.log({"train/loss_epoch": avg_loss}, step=total_steps)
|
| 456 |
|
| 457 |
+
# Save Checkpoint at end of epoch (REMOVED: saving every epoch is too much)
|
| 458 |
save_path = checkpoint_dir / f"epoch_{epoch+1}"
|
| 459 |
+
# accelerator.save_state(output_dir=str(save_path))
|
| 460 |
+
# hf_save_path = save_path / "hf_model"
|
| 461 |
+
# unwrapped_model = accelerator.unwrap_model(model)
|
| 462 |
+
# unwrapped_model.save_pretrained(str(hf_save_path))
|
| 463 |
+
|
| 464 |
+
# logger.info(f"Saved and HF-style model (EOF) to {save_path}")
|
| 465 |
+
pass
|
| 466 |
else:
|
| 467 |
if accelerator.is_main_process:
|
| 468 |
logger.warning(f"Epoch {epoch+1}: No valid batches processed.")
|
| 469 |
|
| 470 |
accelerator.end_training()
|
| 471 |
+
# neo4j_driver.close() # REMOVED
|
| 472 |
|
| 473 |
if __name__ == "__main__":
|
| 474 |
main()
|
train.sh
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
accelerate launch train.py \
|
| 2 |
--epochs 10 \
|
| 3 |
--batch_size 1 \
|
| 4 |
--learning_rate 1e-4 \
|
|
@@ -7,16 +7,14 @@ accelerate launch train.py \
|
|
| 7 |
--max_grad_norm 1.0 \
|
| 8 |
--seed 42 \
|
| 9 |
--log_every 1 \
|
| 10 |
-
--save_every
|
| 11 |
--tensorboard_dir runs/oracle \
|
| 12 |
--checkpoint_dir checkpoints \
|
| 13 |
--mixed_precision bf16 \
|
| 14 |
-
--max_seq_len
|
| 15 |
-
--ohlc_seq_len 300 \
|
| 16 |
--horizons_seconds 30 60 120 240 420 \
|
| 17 |
--quantiles 0.1 0.5 0.9 \
|
| 18 |
--ohlc_stats_path ./data/ohlc_stats.npz \
|
| 19 |
-
--t_cutoff_seconds 60 \
|
| 20 |
--num_workers 4 \
|
| 21 |
--clickhouse_host localhost \
|
| 22 |
--clickhouse_port 9000 \
|
|
|
|
| 1 |
+
/venv/main/bin/accelerate launch train.py \
|
| 2 |
--epochs 10 \
|
| 3 |
--batch_size 1 \
|
| 4 |
--learning_rate 1e-4 \
|
|
|
|
| 7 |
--max_grad_norm 1.0 \
|
| 8 |
--seed 42 \
|
| 9 |
--log_every 1 \
|
| 10 |
+
--save_every 10 \
|
| 11 |
--tensorboard_dir runs/oracle \
|
| 12 |
--checkpoint_dir checkpoints \
|
| 13 |
--mixed_precision bf16 \
|
| 14 |
+
--max_seq_len 4096 \
|
|
|
|
| 15 |
--horizons_seconds 30 60 120 240 420 \
|
| 16 |
--quantiles 0.1 0.5 0.9 \
|
| 17 |
--ohlc_stats_path ./data/ohlc_stats.npz \
|
|
|
|
| 18 |
--num_workers 4 \
|
| 19 |
--clickhouse_host localhost \
|
| 20 |
--clickhouse_port 9000 \
|