zirobtc commited on
Commit
0e3516b
·
1 Parent(s): 1167296

Upload folder using huggingface_hub

Browse files
.gitignore CHANGED
@@ -6,9 +6,10 @@ __pycache__/
6
  runs/
7
 
8
  data/pump_fun
9
-
10
  .env
11
 
12
  data/cache
13
  .tmp/
14
- .cache/
 
 
6
  runs/
7
 
8
  data/pump_fun
9
+ data/cache
10
  .env
11
 
12
  data/cache
13
  .tmp/
14
+ .cache/
15
+ checkpoints/
data/data_collator.py CHANGED
@@ -6,11 +6,26 @@ from torch.nn.utils.rnn import pad_sequence
6
  from typing import List, Dict, Any, Tuple, Optional, Union
7
  from collections import defaultdict
8
  from PIL import Image
9
- from models.multi_modal_processor import MultiModalEncoder
 
10
 
11
- # Encoders are NO LONGER imported here
12
- import models.vocabulary as vocab # For IDs, config sizes
13
- from data.data_loader import EmbeddingPooler # Import for type hinting and instantiation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  NATIVE_MINT = "So11111111111111111111111111111111111111112"
16
  QUOTE_MINTS = {
@@ -28,19 +43,19 @@ class MemecoinCollator:
28
  def __init__(self,
29
  event_type_to_id: Dict[str, int],
30
  device: torch.device,
31
- multi_modal_encoder: MultiModalEncoder,
32
  dtype: torch.dtype,
33
- ohlc_seq_len: int = 300,
34
- max_seq_len: Optional[int] = None
35
  ):
36
  self.event_type_to_id = event_type_to_id
37
  self.pad_token_id = event_type_to_id.get('__PAD__', 0)
38
- self.multi_modal_encoder = multi_modal_encoder
 
39
  self.entity_pad_idx = 0
40
 
41
  self.device = device
42
  self.dtype = dtype
43
- self.ohlc_seq_len = ohlc_seq_len
44
  self.max_seq_len = max_seq_len
45
 
46
  def _collate_features_for_encoder(self, entities: List[Dict], feature_keys: List[str], device: torch.device, entity_type: str) -> Dict[str, Any]:
@@ -205,12 +220,15 @@ class MemecoinCollator:
205
  all_items_sorted = batch_wide_pooler.get_all_items()
206
  texts_to_encode = [d['item'] for d in all_items_sorted if isinstance(d['item'], str)]
207
  images_to_encode = [d['item'] for d in all_items_sorted if isinstance(d['item'], Image.Image)]
 
 
 
208
 
209
- text_embeds = self.multi_modal_encoder(texts_to_encode) if texts_to_encode else torch.empty(0)
210
- image_embeds = self.multi_modal_encoder(images_to_encode) if images_to_encode else torch.empty(0)
211
 
212
  # Create the final lookup tensor and fill it based on original item type
213
- batch_embedding_pool = torch.zeros(len(all_items_sorted), self.multi_modal_encoder.embedding_dim, device=self.device, dtype=self.dtype)
214
  text_cursor, image_cursor = 0, 0
215
  for i, item_data in enumerate(all_items_sorted):
216
  if isinstance(item_data['item'], str):
 
6
  from typing import List, Dict, Any, Tuple, Optional, Union
7
  from collections import defaultdict
8
  from PIL import Image
9
+ # --- GLOBAL SINGLETON FOR WORKER PROCESSES ---
10
+ _WORKER_ENCODER = None
11
 
12
+ def _get_worker_encoder(model_id: str, dtype: torch.dtype, device: torch.device):
13
+ """
14
+ Lazy-loads the encoder on the worker process.
15
+ FORCED TO CPU to save VRAM when using multiple workers.
16
+ """
17
+ global _WORKER_ENCODER
18
+ if _WORKER_ENCODER is None:
19
+ print(f"[Worker] Initializing MultiModalEncoder (SigLIP) on CPU (VRAM optimization)...")
20
+ # Local import to avoid top-level dependency issues
21
+ from models.multi_modal_processor import MultiModalEncoder
22
+ # Explicitly pass device="cpu"
23
+ _WORKER_ENCODER = MultiModalEncoder(model_id=model_id, dtype=dtype, device="cpu")
24
+
25
+ return _WORKER_ENCODER
26
+
27
+ import models.vocabulary as vocab
28
+ from data.data_loader import EmbeddingPooler
29
 
30
  NATIVE_MINT = "So11111111111111111111111111111111111111112"
31
  QUOTE_MINTS = {
 
43
  def __init__(self,
44
  event_type_to_id: Dict[str, int],
45
  device: torch.device,
 
46
  dtype: torch.dtype,
47
+ max_seq_len: Optional[int] = None,
48
+ model_id: str = "google/siglip-so400m-patch16-256-i18n"
49
  ):
50
  self.event_type_to_id = event_type_to_id
51
  self.pad_token_id = event_type_to_id.get('__PAD__', 0)
52
+ # self.multi_modal_encoder = multi_modal_encoder # DEPRECATED
53
+ self.model_id = model_id
54
  self.entity_pad_idx = 0
55
 
56
  self.device = device
57
  self.dtype = dtype
58
+ self.ohlc_seq_len = 300 # HARDCODED
59
  self.max_seq_len = max_seq_len
60
 
61
  def _collate_features_for_encoder(self, entities: List[Dict], feature_keys: List[str], device: torch.device, entity_type: str) -> Dict[str, Any]:
 
220
  all_items_sorted = batch_wide_pooler.get_all_items()
221
  texts_to_encode = [d['item'] for d in all_items_sorted if isinstance(d['item'], str)]
222
  images_to_encode = [d['item'] for d in all_items_sorted if isinstance(d['item'], Image.Image)]
223
+
224
+ # LAZY LOAD ENCODER
225
+ encoder = _get_worker_encoder(self.model_id, self.dtype, self.device)
226
 
227
+ text_embeds = encoder(texts_to_encode).to(self.device) if texts_to_encode else torch.empty(0)
228
+ image_embeds = encoder(images_to_encode).to(self.device) if images_to_encode else torch.empty(0)
229
 
230
  # Create the final lookup tensor and fill it based on original item type
231
+ batch_embedding_pool = torch.zeros(len(all_items_sorted), encoder.embedding_dim, device=self.device, dtype=self.dtype)
232
  text_cursor, image_cursor = 0, 0
233
  for i, item_data in enumerate(all_items_sorted):
234
  if isinstance(item_data['item'], str):
data/data_fetcher.py CHANGED
@@ -626,9 +626,11 @@ class DataFetcher:
626
 
627
  return token_details
628
 
629
- def fetch_trades_for_token(self, token_address: str, T_cutoff: datetime.datetime, count_threshold: int, early_limit: int, recent_limit: int) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
630
  """
631
- Fetches trades for a token, using a 3-part H/B/H strategy if the total count exceeds a threshold.
 
 
632
  Returns three lists: early_trades, middle_trades, recent_trades.
633
  """
634
  if not token_address:
@@ -636,31 +638,36 @@ class DataFetcher:
636
 
637
  params = {'token_address': token_address, 'T_cutoff': T_cutoff}
638
 
639
- # 1. Get the total count of trades for the token before the cutoff
640
- count_query = "SELECT count() FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s"
641
- try:
642
- total_trades = self.db_client.execute(count_query, params)[0][0]
643
- print(f"INFO: Found {total_trades} total trades for token {token_address} before {T_cutoff}.")
644
- except Exception as e:
645
- print(f"ERROR: Could not count trades for token {token_address}: {e}")
646
- return [], [], []
647
-
648
- # 2. Decide which query to use based on the count
649
- if total_trades < count_threshold:
650
- print("INFO: Fetching all trades (count is below H/B/H threshold).")
 
 
 
 
 
651
  query = "SELECT * FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s ORDER BY timestamp ASC"
652
  try:
653
  rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
654
  if not rows: return [], [], []
655
  columns = [col[0] for col in columns_info]
656
  all_trades = [dict(zip(columns, row)) for row in rows]
657
- # When not using HBH, all trades are considered "early"
658
  return all_trades, [], []
659
  except Exception as e:
660
  print(f"ERROR: Failed to fetch all trades for token {token_address}: {e}")
661
  return [], [], []
662
 
663
- # 3. Use the H/B/H strategy if the count is high
664
  print("INFO: Fetching trades using 3-part High-Def/Blurry/High-Def strategy.")
665
  try:
666
  # Fetch Early (High-Def)
@@ -792,7 +799,7 @@ class DataFetcher:
792
  ORDER BY timestamp ASC
793
  """
794
  params = {'token_address': token_address, 'T_cutoff': T_cutoff}
795
- print(f"INFO: Fetching pool creation events for {token_address}.")
796
 
797
  try:
798
  rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
@@ -831,7 +838,7 @@ class DataFetcher:
831
  ORDER BY timestamp ASC
832
  """
833
  params = {'pool_addresses': pool_addresses, 'T_cutoff': T_cutoff}
834
- print(f"INFO: Fetching liquidity change events for {len(pool_addresses)} pool(s).")
835
 
836
  try:
837
  rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
@@ -870,7 +877,7 @@ class DataFetcher:
870
  ORDER BY timestamp ASC
871
  """
872
  params = {'token': token_address, 'T_cutoff': T_cutoff}
873
- print(f"INFO: Fetching fee collection events for {token_address}.")
874
 
875
  try:
876
  rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
@@ -908,7 +915,7 @@ class DataFetcher:
908
  ORDER BY timestamp ASC
909
  """
910
  params = {'token': token_address, 'T_cutoff': T_cutoff}
911
- print(f"INFO: Fetching migrations for {token_address}.")
912
  try:
913
  rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
914
  if not rows:
@@ -946,7 +953,7 @@ class DataFetcher:
946
  ORDER BY timestamp ASC
947
  """
948
  params = {'token': token_address, 'T_cutoff': T_cutoff}
949
- print(f"INFO: Fetching burn events for {token_address}.")
950
 
951
  try:
952
  rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
@@ -987,7 +994,7 @@ class DataFetcher:
987
  ORDER BY timestamp ASC
988
  """
989
  params = {'token': token_address, 'T_cutoff': T_cutoff}
990
- print(f"INFO: Fetching supply lock events for {token_address}.")
991
 
992
  try:
993
  rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
@@ -1020,7 +1027,7 @@ class DataFetcher:
1020
  LIMIT %(limit)s;
1021
  """
1022
  params = {'token': token_address, 'T_cutoff': T_cutoff, 'limit': int(limit)}
1023
- print(f"INFO: Fetching top holders for snapshot for {token_address} (limit {limit}).")
1024
  try:
1025
  rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
1026
  if not rows:
@@ -1050,7 +1057,7 @@ class DataFetcher:
1050
  WHERE rn_per_holding = 1 AND current_balance > 0;
1051
  """
1052
  params = {'token': token_address, 'T_cutoff': T_cutoff}
1053
- print(f"INFO: Counting total holders for {token_address} at cutoff.")
1054
  try:
1055
  rows = self.db_client.execute(query, params)
1056
  if not rows:
@@ -1067,12 +1074,20 @@ class DataFetcher:
1067
  max_horizon_seconds: int = 3600,
1068
  include_wallet_data: bool = True,
1069
  include_graph: bool = True,
1070
- min_trades: int = 0
 
 
 
1071
  ) -> Optional[Dict[str, Any]]:
1072
  """
1073
  Fetches ALL available data for a token up to the maximum horizon.
1074
  This data is agnostic of T_cutoff and will be masked/filtered dynamically during training.
1075
  Wallet/graph data can be skipped to avoid caching T_cutoff-dependent features.
 
 
 
 
 
1076
  """
1077
 
1078
  # 1. Calculate the absolute maximum timestamp we care about (mint + max_horizon)
@@ -1086,8 +1101,9 @@ class DataFetcher:
1086
  # So we pass max_limit_time as the "cutoff" for the purpose of raw data collection.
1087
 
1088
  # We use a large enough limit to get all relevant trades for the session
 
1089
  early_trades, middle_trades, recent_trades = self.fetch_trades_for_token(
1090
- token_address, max_limit_time, 30000, 10000, 15000
1091
  )
1092
 
1093
  # Combine and deduplicate trades
@@ -1099,12 +1115,26 @@ class DataFetcher:
1099
 
1100
  sorted_trades = sorted(list(all_trades.values()), key=lambda x: x['timestamp'])
1101
 
 
 
 
 
 
 
 
 
1102
  if len(sorted_trades) < min_trades:
1103
  print(f" SKIP: Token {token_address} has only {len(sorted_trades)} trades (min required: {min_trades}). skipping fetches.")
1104
  return None
1105
 
1106
  # 3. Fetch other events
1107
- transfers = self.fetch_transfers_for_token(token_address, max_limit_time, 0.0) # 0.0 means fetch all
 
 
 
 
 
 
1108
  pool_creations = self.fetch_pool_creations_for_token(token_address, max_limit_time)
1109
 
1110
  # Collect pool addresses to fetch liquidity changes
 
626
 
627
  return token_details
628
 
629
+ def fetch_trades_for_token(self, token_address: str, T_cutoff: datetime.datetime, count_threshold: int, early_limit: int, recent_limit: int, full_history: bool = False) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
630
  """
631
+ Fetches trades for a token.
632
+ If full_history is True, fetches ALL trades (ignores H/B/H limits).
633
+ Otherwise, uses the 3-part H/B/H strategy if the total count exceeds a threshold.
634
  Returns three lists: early_trades, middle_trades, recent_trades.
635
  """
636
  if not token_address:
 
638
 
639
  params = {'token_address': token_address, 'T_cutoff': T_cutoff}
640
 
641
+ # 1. Get the total count if we care about H/B/H logic
642
+ if not full_history:
643
+ count_query = "SELECT count() FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s"
644
+ try:
645
+ total_trades = self.db_client.execute(count_query, params)[0][0]
646
+ print(f"INFO: Found {total_trades} total trades for token {token_address} before {T_cutoff}.")
647
+ except Exception as e:
648
+ print(f"ERROR: Could not count trades for token {token_address}: {e}")
649
+ return [], [], []
650
+ else:
651
+ total_trades = 0 # Dummy value, ignored
652
+
653
+ # 2. Decide which query to use
654
+ # If full_history is ON, or count is low, fetch everything.
655
+ if full_history or total_trades < count_threshold:
656
+ mode = "Full History" if full_history else "Low Count"
657
+ # print(f"INFO: Fetching all trades ({mode}).")
658
  query = "SELECT * FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s ORDER BY timestamp ASC"
659
  try:
660
  rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
661
  if not rows: return [], [], []
662
  columns = [col[0] for col in columns_info]
663
  all_trades = [dict(zip(columns, row)) for row in rows]
664
+ # When not using HBH or fetching full history, all trades are considered "early" (or just one big block)
665
  return all_trades, [], []
666
  except Exception as e:
667
  print(f"ERROR: Failed to fetch all trades for token {token_address}: {e}")
668
  return [], [], []
669
 
670
+ # 3. Use the H/B/H strategy if the count is high AND not full_history
671
  print("INFO: Fetching trades using 3-part High-Def/Blurry/High-Def strategy.")
672
  try:
673
  # Fetch Early (High-Def)
 
799
  ORDER BY timestamp ASC
800
  """
801
  params = {'token_address': token_address, 'T_cutoff': T_cutoff}
802
+ # print(f"INFO: Fetching pool creation events for {token_address}.")
803
 
804
  try:
805
  rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
 
838
  ORDER BY timestamp ASC
839
  """
840
  params = {'pool_addresses': pool_addresses, 'T_cutoff': T_cutoff}
841
+ # print(f"INFO: Fetching liquidity change events for {len(pool_addresses)} pool(s).")
842
 
843
  try:
844
  rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
 
877
  ORDER BY timestamp ASC
878
  """
879
  params = {'token': token_address, 'T_cutoff': T_cutoff}
880
+ # print(f"INFO: Fetching fee collection events for {token_address}.")
881
 
882
  try:
883
  rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
 
915
  ORDER BY timestamp ASC
916
  """
917
  params = {'token': token_address, 'T_cutoff': T_cutoff}
918
+ # print(f"INFO: Fetching migrations for {token_address}.")
919
  try:
920
  rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
921
  if not rows:
 
953
  ORDER BY timestamp ASC
954
  """
955
  params = {'token': token_address, 'T_cutoff': T_cutoff}
956
+ # print(f"INFO: Fetching burn events for {token_address}.")
957
 
958
  try:
959
  rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
 
994
  ORDER BY timestamp ASC
995
  """
996
  params = {'token': token_address, 'T_cutoff': T_cutoff}
997
+ # print(f"INFO: Fetching supply lock events for {token_address}.")
998
 
999
  try:
1000
  rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
 
1027
  LIMIT %(limit)s;
1028
  """
1029
  params = {'token': token_address, 'T_cutoff': T_cutoff, 'limit': int(limit)}
1030
+ # print(f"INFO: Fetching top holders for snapshot for {token_address} (limit {limit}).")
1031
  try:
1032
  rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
1033
  if not rows:
 
1057
  WHERE rn_per_holding = 1 AND current_balance > 0;
1058
  """
1059
  params = {'token': token_address, 'T_cutoff': T_cutoff}
1060
+ # print(f"INFO: Counting total holders for {token_address} at timestamp {T_cutoff}.")
1061
  try:
1062
  rows = self.db_client.execute(query, params)
1063
  if not rows:
 
1074
  max_horizon_seconds: int = 3600,
1075
  include_wallet_data: bool = True,
1076
  include_graph: bool = True,
1077
+ min_trades: int = 0,
1078
+ full_history: bool = False,
1079
+ prune_failed: bool = False,
1080
+ prune_transfers: bool = False
1081
  ) -> Optional[Dict[str, Any]]:
1082
  """
1083
  Fetches ALL available data for a token up to the maximum horizon.
1084
  This data is agnostic of T_cutoff and will be masked/filtered dynamically during training.
1085
  Wallet/graph data can be skipped to avoid caching T_cutoff-dependent features.
1086
+
1087
+ Args:
1088
+ full_history: If True, fetches ALL trades ignoring H/B/H limits.
1089
+ prune_failed: If True, filters out failed trades from the result.
1090
+ prune_transfers: If True, skips fetching transfers entirely.
1091
  """
1092
 
1093
  # 1. Calculate the absolute maximum timestamp we care about (mint + max_horizon)
 
1101
  # So we pass max_limit_time as the "cutoff" for the purpose of raw data collection.
1102
 
1103
  # We use a large enough limit to get all relevant trades for the session
1104
+ # If full_history is True, these limits are ignored inside the method.
1105
  early_trades, middle_trades, recent_trades = self.fetch_trades_for_token(
1106
+ token_address, max_limit_time, 30000, 10000, 15000, full_history=full_history
1107
  )
1108
 
1109
  # Combine and deduplicate trades
 
1115
 
1116
  sorted_trades = sorted(list(all_trades.values()), key=lambda x: x['timestamp'])
1117
 
1118
+ # --- PRUNING FAILED TRADES ---
1119
+ if prune_failed:
1120
+ original_count = len(sorted_trades)
1121
+ sorted_trades = [t for t in sorted_trades if t.get('success', False)]
1122
+ if len(sorted_trades) < original_count:
1123
+ # print(f" INFO: Pruned {original_count - len(sorted_trades)} failed trades.")
1124
+ pass
1125
+
1126
  if len(sorted_trades) < min_trades:
1127
  print(f" SKIP: Token {token_address} has only {len(sorted_trades)} trades (min required: {min_trades}). skipping fetches.")
1128
  return None
1129
 
1130
  # 3. Fetch other events
1131
+ # --- PRUNING TRANSFERS ---
1132
+ if prune_transfers:
1133
+ transfers = []
1134
+ # print(" INFO: Pruning transfers (skipping fetch).")
1135
+ else:
1136
+ transfers = self.fetch_transfers_for_token(token_address, max_limit_time, 0.0) # 0.0 means fetch all
1137
+
1138
  pool_creations = self.fetch_pool_creations_for_token(token_address, max_limit_time)
1139
 
1140
  # Collect pool addresses to fetch liquidity changes
data/data_loader.py CHANGED
@@ -97,11 +97,11 @@ class OracleDataset(Dataset):
97
  input sequence for the model.
98
  """
99
  def __init__(self,
100
- data_fetcher: DataFetcher, # NEW: Pass the fetcher instance
101
  horizons_seconds: List[int] = [],
102
  quantiles: List[float] = [],
103
  max_samples: Optional[int] = None,
104
- ohlc_stats_path: Union[str, Path] = "./data/ohlc_stats.npz", # NEW: Add stats path parameter
105
  token_allowlist: Optional[List[str]] = None,
106
  t_cutoff_seconds: int = 60,
107
  cache_dir: Optional[Union[str, Path]] = None,
@@ -273,7 +273,8 @@ class OracleDataset(Dataset):
273
  aggregation_trades: List[Dict[str, Any]],
274
  wallet_data: Dict[str, Any],
275
  total_supply_dec: float,
276
- _register_event_fn
 
277
  ) -> None:
278
  # Prepare helper sets and maps (static sniper set based on earliest buyers)
279
  all_buy_trades = sorted([e for e in trade_events if e.get('trade_direction') == 0 and e.get('success', False)], key=lambda x: x['timestamp'])
@@ -304,14 +305,25 @@ class OracleDataset(Dataset):
304
 
305
  buyers_seen_global = set()
306
  prev_holders_count = 0
307
- for ts_value in oc_snapshot_times:
308
  window_start = ts_value - interval_sec
309
  trades_win = [e for e in trade_events if e.get('success', False) and window_start < e['timestamp'] <= ts_value]
310
  xfers_win = [e for e in transfer_events if window_start < e['timestamp'] <= ts_value]
311
 
312
  # Per-snapshot holder distribution at ts_value
313
- cutoff_dt_ts = datetime.datetime.fromtimestamp(ts_value, tz=datetime.timezone.utc)
314
- holder_records_ts = self.fetcher.fetch_token_holders_for_snapshot(token_address, cutoff_dt_ts, limit=HOLDER_SNAPSHOT_TOP_K)
 
 
 
 
 
 
 
 
 
 
 
315
  holder_entries_ts = []
316
  for rec in holder_records_ts:
317
  addr = rec.get('wallet_address')
@@ -363,8 +375,7 @@ class OracleDataset(Dataset):
363
  buyers_seen_global.add(wa)
364
 
365
  # Compute growth against previous snapshot endpoint.
366
- end_dt = datetime.datetime.fromtimestamp(ts_value, tz=datetime.timezone.utc)
367
- holders_end = self.fetcher.fetch_total_holders_count_for_token(token_address, end_dt)
368
  total_holders = float(holders_end)
369
  delta_holders = holders_end - prev_holders_count
370
  holder_growth_rate = float(delta_holders)
@@ -415,7 +426,7 @@ class OracleDataset(Dataset):
415
 
416
  # Fetch all token details in ONE batch query
417
  all_deployed_token_details = {}
418
- if all_deployed_tokens:
419
  all_deployed_token_details = self.fetcher.fetch_deployed_token_details(list(all_deployed_tokens), T_cutoff)
420
 
421
  for addr, profile in profiles.items():
@@ -454,18 +465,24 @@ class OracleDataset(Dataset):
454
  profile['deployed_tokens_avg_peak_mc_usd'] = torch.mean(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0
455
  profile['deployed_tokens_median_peak_mc_usd'] = torch.median(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0
456
 
457
- def _process_wallet_data(self, wallet_addresses: List[str], token_data: Dict[str, Any], pooler: EmbeddingPooler, T_cutoff: datetime.datetime) -> tuple[Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]]:
 
458
  """
459
- Fetches and processes profile, social, and holdings data for a list of wallets.
460
- Uses a T_cutoff to ensure data is point-in-time accurate.
461
  """
462
  if not wallet_addresses:
463
  return {}, token_data
464
 
465
- print(f"INFO: Processing wallet data for {len(wallet_addresses)} unique wallets...")
466
- # Bulk fetch all data
467
- profiles, socials = self.fetcher.fetch_wallet_profiles_and_socials(wallet_addresses, T_cutoff)
468
- holdings = self.fetcher.fetch_wallet_holdings(wallet_addresses, T_cutoff)
 
 
 
 
 
 
469
 
470
  valid_wallets = [addr for addr in wallet_addresses if addr in profiles]
471
  dropped_wallets = set(wallet_addresses) - set(valid_wallets)
@@ -618,8 +635,11 @@ class OracleDataset(Dataset):
618
  return {}
619
 
620
  if token_data is None:
621
- print(f"INFO: Processing token data for {len(token_addresses)} unique tokens...")
622
- token_data = self.fetcher.fetch_token_data(token_addresses, T_cutoff)
 
 
 
623
 
624
  # --- NEW: Print the raw fetched token data as requested ---
625
  print("\n--- RAW TOKEN DATA FROM DATABASE ---")
@@ -793,14 +813,13 @@ class OracleDataset(Dataset):
793
  try:
794
  raw_data = torch.load(filepath, map_location='cpu', weights_only=False)
795
  except Exception as e:
796
- print(f"ERROR: Could not load cached item {filepath}: {e}")
797
- return None
798
  else:
799
  # Online mode fallback
800
  raw_data = self.__cacheitem__(idx)
801
 
802
  if not raw_data:
803
- return None
804
 
805
  required_keys = [
806
  "mint_timestamp",
@@ -822,8 +841,8 @@ class OracleDataset(Dataset):
822
  f"Cached sample missing raw fields ({missing_keys}). Rebuild cache with raw caching enabled."
823
  )
824
 
825
- if not self.fetcher:
826
- raise RuntimeError("Data fetcher required for T_cutoff-dependent data.")
827
 
828
  def _timestamp_to_order_value(ts_value: Any) -> float:
829
  if isinstance(ts_value, datetime.datetime):
@@ -904,34 +923,53 @@ class OracleDataset(Dataset):
904
  if _timestamp_to_order_value(liq.get('timestamp')) <= cutoff_ts:
905
  _add_wallet(liq.get('lp_provider'), wallets_to_fetch)
906
 
907
- holder_records = self.fetcher.fetch_token_holders_for_snapshot(
908
- token_address,
909
- T_cutoff,
910
- limit=HOLDER_SNAPSHOT_TOP_K
911
- )
 
 
 
 
 
 
912
  for holder in holder_records:
913
  _add_wallet(holder.get('wallet_address'), wallets_to_fetch)
914
 
915
  pooler = EmbeddingPooler()
916
- main_token_data = self._process_token_data([token_address], pooler, T_cutoff)
 
 
917
  if not main_token_data:
918
  return None
919
 
 
 
 
 
 
 
 
 
 
920
  wallet_data, all_token_data = self._process_wallet_data(
921
  list(wallets_to_fetch),
922
  main_token_data.copy(),
923
  pooler,
924
- T_cutoff
 
 
 
925
  )
926
 
927
  graph_entities = {}
928
  graph_links = {}
929
- if wallets_to_fetch:
930
- graph_entities, graph_links = self.fetcher.fetch_graph_links(
931
- list(wallets_to_fetch),
932
- T_cutoff,
933
- max_degrees=1
934
- )
935
 
936
  # Generate the item
937
  return self._generate_dataset_item(
@@ -960,13 +998,14 @@ class OracleDataset(Dataset):
960
  graph_seed_entities=wallets_to_fetch,
961
  all_graph_entities=graph_entities,
962
  future_trades_for_labels=raw_data['trades'], # We utilize full trade history for labels!
963
- pooler=pooler
 
964
  )
965
 
966
  def __cacheitem__(self, idx: int) -> Optional[Dict[str, Any]]:
967
  """
968
  Fetches cutoff-agnostic raw token data for caching/online sampling.
969
- Random T_cutoff sampling happens later in __getitem__.
970
  """
971
 
972
  if not self.sampled_mints:
@@ -984,6 +1023,7 @@ class OracleDataset(Dataset):
984
  if not self.fetcher:
985
  raise RuntimeError("Dataset has no data fetcher; cannot load raw data.")
986
 
 
987
  raw_data = self.fetcher.fetch_raw_token_data(
988
  token_address=token_address,
989
  creator_address=creator_address,
@@ -991,7 +1031,10 @@ class OracleDataset(Dataset):
991
  max_horizon_seconds=self.max_cache_horizon_seconds,
992
  include_wallet_data=False,
993
  include_graph=False,
994
- min_trades=50
 
 
 
995
  )
996
  if raw_data is None:
997
  return None
@@ -1005,56 +1048,134 @@ class OracleDataset(Dataset):
1005
  return float(ts_value)
1006
  except (TypeError, ValueError):
1007
  return 0.0
1008
-
1009
- trade_ts_values = [
1010
- _timestamp_to_order_value(trade.get('timestamp'))
1011
- for trade in raw_data.get('trades', [])
1012
- if trade.get('timestamp') is not None
1013
- ]
1014
  if not trade_ts_values:
1015
  print(f" SKIP: No valid trades found for {token_address}.")
1016
  return None
 
 
 
1017
 
1018
- horizons = sorted(self.horizons_seconds)
1019
- first_horizon = horizons[0] if horizons else 60
1020
- min_label = max(60, first_horizon)
1021
- min_window = 30
1022
 
1023
- # 2. Strict Duration Check
1024
- # We enforce the exact same logic as __getitem__ to ensure the sample is usable.
1025
- # Logic:
1026
- # lower_bound = max(min_window, first_trade - mint)
1027
- # upper_bound = (last_trade - mint) - required_horizon
1028
- # We need upper_bound >= lower_bound.
1029
 
1030
- last_trade_ts_val = max(trade_ts_values)
1031
- first_trade_ts_val = min(trade_ts_values)
1032
- t0_val = _timestamp_to_order_value(t0)
 
 
 
 
 
 
 
 
 
 
1033
 
1034
- # Calculate offsets relative to mint
1035
- start_offset = max(0.0, first_trade_ts_val - t0_val)
1036
- end_offset = max(0.0, last_trade_ts_val - t0_val)
 
 
 
 
 
 
 
 
 
 
1037
 
1038
- lower_bound = max(min_window, int(start_offset))
 
 
 
 
1039
 
1040
- # We use the FIRST horizon to determine minimum validity,
1041
- # but technically we'd prefer to satisfy at least one horizon.
1042
- # Using min_label (which is max(60, first_horizon)) is safe.
1043
- required_horizon = min_label
1044
- upper_bound = end_offset - required_horizon
1045
 
1046
- if upper_bound < lower_bound:
1047
- # Diagnose the failure reason for the log
1048
- reason = []
1049
- if end_offset < (min_window + required_horizon):
1050
- reason.append(f"total duration {end_offset:.1f}s < {(min_window + required_horizon)}s")
1051
- if (last_trade_ts_val - first_trade_ts_val) < required_horizon:
1052
- reason.append(f"trade span {(last_trade_ts_val - first_trade_ts_val):.1f}s < {required_horizon}s")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1053
 
1054
- reason_str = ", ".join(reason) or "insufficient window overlap"
 
 
 
 
1055
 
1056
- print(f" SKIP: {token_address} does not fit sampling window. ({reason_str}) (Trades: {len(trade_ts_values)})")
1057
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1058
 
1059
  raw_data["protocol_id"] = initial_mint_record.get("protocol")
1060
  return raw_data
@@ -1078,7 +1199,8 @@ class OracleDataset(Dataset):
1078
  graph_seed_entities: set,
1079
  all_graph_entities: Dict[str, str],
1080
  future_trades_for_labels: List[Dict[str, Any]],
1081
- pooler: EmbeddingPooler
 
1082
  ) -> Optional[Dict[str, Any]]:
1083
  """
1084
  Processes raw token data into a structured dataset item for a specific T_cutoff.
@@ -1305,7 +1427,8 @@ class OracleDataset(Dataset):
1305
  aggregation_trades,
1306
  wallet_data,
1307
  total_supply_dec,
1308
- _register_event
 
1309
  )
1310
 
1311
  # 7. Finalize Sequence
 
97
  input sequence for the model.
98
  """
99
  def __init__(self,
100
+ data_fetcher: Optional[DataFetcher] = None, # OPTIONAL: Only needed for caching (Writer)
101
  horizons_seconds: List[int] = [],
102
  quantiles: List[float] = [],
103
  max_samples: Optional[int] = None,
104
+ ohlc_stats_path: Union[str, Path] = "./data/ohlc_stats.npz",
105
  token_allowlist: Optional[List[str]] = None,
106
  t_cutoff_seconds: int = 60,
107
  cache_dir: Optional[Union[str, Path]] = None,
 
273
  aggregation_trades: List[Dict[str, Any]],
274
  wallet_data: Dict[str, Any],
275
  total_supply_dec: float,
276
+ _register_event_fn,
277
+ cached_holders_list: List[List[str]] = None
278
  ) -> None:
279
  # Prepare helper sets and maps (static sniper set based on earliest buyers)
280
  all_buy_trades = sorted([e for e in trade_events if e.get('trade_direction') == 0 and e.get('success', False)], key=lambda x: x['timestamp'])
 
305
 
306
  buyers_seen_global = set()
307
  prev_holders_count = 0
308
+ for i, ts_value in enumerate(oc_snapshot_times):
309
  window_start = ts_value - interval_sec
310
  trades_win = [e for e in trade_events if e.get('success', False) and window_start < e['timestamp'] <= ts_value]
311
  xfers_win = [e for e in transfer_events if window_start < e['timestamp'] <= ts_value]
312
 
313
  # Per-snapshot holder distribution at ts_value
314
+ holder_records_ts = []
315
+ holders_end = 0
316
+ if cached_holders_list is not None and i < len(cached_holders_list):
317
+ # Use cached list of addresses
318
+ holder_records_ts = [{'wallet_address': addr, 'current_balance': 0} for addr in cached_holders_list[i]]
319
+ holders_end = len(cached_holders_list[i])
320
+ elif self.fetcher:
321
+ cutoff_dt_ts = datetime.datetime.fromtimestamp(ts_value, tz=datetime.timezone.utc)
322
+ holder_records_ts = self.fetcher.fetch_token_holders_for_snapshot(token_address, cutoff_dt_ts, limit=HOLDER_SNAPSHOT_TOP_K)
323
+ holders_end = self.fetcher.fetch_total_holders_count_for_token(token_address, cutoff_dt_ts)
324
+ else:
325
+ holder_records_ts = []
326
+ holders_end = 0
327
  holder_entries_ts = []
328
  for rec in holder_records_ts:
329
  addr = rec.get('wallet_address')
 
375
  buyers_seen_global.add(wa)
376
 
377
  # Compute growth against previous snapshot endpoint.
378
+ # total_holders = float(holders_end) # already handled above
 
379
  total_holders = float(holders_end)
380
  delta_holders = holders_end - prev_holders_count
381
  holder_growth_rate = float(delta_holders)
 
426
 
427
  # Fetch all token details in ONE batch query
428
  all_deployed_token_details = {}
429
+ if all_deployed_tokens and self.fetcher:
430
  all_deployed_token_details = self.fetcher.fetch_deployed_token_details(list(all_deployed_tokens), T_cutoff)
431
 
432
  for addr, profile in profiles.items():
 
465
  profile['deployed_tokens_avg_peak_mc_usd'] = torch.mean(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0
466
  profile['deployed_tokens_median_peak_mc_usd'] = torch.median(torch.tensor(peak_mcs)).item() if peak_mcs else 0.0
467
 
468
+ def _process_wallet_data(self, wallet_addresses: List[str], token_data: Dict[str, Any], pooler: EmbeddingPooler, T_cutoff: datetime.datetime,
469
+ profiles_override: Optional[Dict] = None, socials_override: Optional[Dict] = None, holdings_override: Optional[Dict] = None) -> tuple[Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]]:
470
  """
471
+ Fetches or uses cached profile, social, and holdings data.
 
472
  """
473
  if not wallet_addresses:
474
  return {}, token_data
475
 
476
+ if profiles_override is not None and socials_override is not None:
477
+ profiles, socials = profiles_override, socials_override
478
+ holdings = holdings_override if holdings_override is not None else {}
479
+ else:
480
+ print(f"INFO: Processing wallet data for {len(wallet_addresses)} unique wallets...")
481
+ if self.fetcher:
482
+ profiles, socials = self.fetcher.fetch_wallet_profiles_and_socials(wallet_addresses, T_cutoff)
483
+ holdings = self.fetcher.fetch_wallet_holdings(wallet_addresses, T_cutoff)
484
+ else:
485
+ profiles, socials, holdings = {}, {}, {}
486
 
487
  valid_wallets = [addr for addr in wallet_addresses if addr in profiles]
488
  dropped_wallets = set(wallet_addresses) - set(valid_wallets)
 
635
  return {}
636
 
637
  if token_data is None:
638
+ if self.fetcher:
639
+ print(f"INFO: Processing token data for {len(token_addresses)} unique tokens...")
640
+ token_data = self.fetcher.fetch_token_data(token_addresses, T_cutoff)
641
+ else:
642
+ token_data = {}
643
 
644
  # --- NEW: Print the raw fetched token data as requested ---
645
  print("\n--- RAW TOKEN DATA FROM DATABASE ---")
 
813
  try:
814
  raw_data = torch.load(filepath, map_location='cpu', weights_only=False)
815
  except Exception as e:
816
+ raise RuntimeError(f"ERROR: Could not load cached item {filepath}: {e}")
 
817
  else:
818
  # Online mode fallback
819
  raw_data = self.__cacheitem__(idx)
820
 
821
  if not raw_data:
822
+ raise RuntimeError(f"No raw data loaded for index {idx}")
823
 
824
  required_keys = [
825
  "mint_timestamp",
 
841
  f"Cached sample missing raw fields ({missing_keys}). Rebuild cache with raw caching enabled."
842
  )
843
 
844
+ # if not self.fetcher:
845
+ # raise RuntimeError("Data fetcher required for T_cutoff-dependent data.")
846
 
847
  def _timestamp_to_order_value(ts_value: Any) -> float:
848
  if isinstance(ts_value, datetime.datetime):
 
923
  if _timestamp_to_order_value(liq.get('timestamp')) <= cutoff_ts:
924
  _add_wallet(liq.get('lp_provider'), wallets_to_fetch)
925
 
926
+ # Offline Holder Lookup using raw_data['holder_snapshots_list']
927
+ # We need the snapshot corresponding to T_cutoff.
928
+ # Intervals are every 300s from mint_ts.
929
+ # idx = (T_cutoff - mint) // 300
930
+ elapsed = (T_cutoff - mint_timestamp).total_seconds()
931
+ snap_idx = int(elapsed // 300)
932
+ holder_records = []
933
+ cached_holders_list = raw_data.get('holder_snapshots_list', [])
934
+ if 0 <= snap_idx < len(cached_holders_list):
935
+ # Format expected by _add_wallet: dict with 'wallet_address'
936
+ holder_records = [{'wallet_address': addr} for addr in cached_holders_list[snap_idx]]
937
  for holder in holder_records:
938
  _add_wallet(holder.get('wallet_address'), wallets_to_fetch)
939
 
940
  pooler = EmbeddingPooler()
941
+ # Prepare offline token data
942
+ offline_token_data = {token_address: raw_data} # Assuming raw_data contains token metadata at root
943
+ main_token_data = self._process_token_data([token_address], pooler, T_cutoff, token_data=offline_token_data)
944
  if not main_token_data:
945
  return None
946
 
947
+ # Prepare offline wallet data
948
+ # raw_data['socials'] structure: {'profiles': {...}, 'socials': {...}} usually.
949
+ # But wait, cached raw_data['socials'] might be just the dict we need?
950
+ # Let's handle graceful empty if not found.
951
+ cached_social_bundle = raw_data.get('socials', {})
952
+ offline_profiles = cached_social_bundle.get('profiles', {})
953
+ offline_socials = cached_social_bundle.get('socials', {})
954
+ offline_holdings = {} # Holdings not cached usually due to size
955
+
956
  wallet_data, all_token_data = self._process_wallet_data(
957
  list(wallets_to_fetch),
958
  main_token_data.copy(),
959
  pooler,
960
+ T_cutoff,
961
+ profiles_override=offline_profiles,
962
+ socials_override=offline_socials,
963
+ holdings_override=offline_holdings
964
  )
965
 
966
  graph_entities = {}
967
  graph_links = {}
968
+ graph_entities = {}
969
+ graph_links = {}
970
+ # if wallets_to_fetch:
971
+ # graph_entities, graph_links = self.fetcher.fetch_graph_links(...)
972
+ # Offline Graph: check if raw_data has graph? Assuming no for now.
 
973
 
974
  # Generate the item
975
  return self._generate_dataset_item(
 
998
  graph_seed_entities=wallets_to_fetch,
999
  all_graph_entities=graph_entities,
1000
  future_trades_for_labels=raw_data['trades'], # We utilize full trade history for labels!
1001
+ pooler=pooler,
1002
+ cached_holders_list=raw_data.get('holder_snapshots_list')
1003
  )
1004
 
1005
  def __cacheitem__(self, idx: int) -> Optional[Dict[str, Any]]:
1006
  """
1007
  Fetches cutoff-agnostic raw token data for caching/online sampling.
1008
+ Generates dense time-series (1s OHLC, Snapshots) and prunes raw logs.
1009
  """
1010
 
1011
  if not self.sampled_mints:
 
1023
  if not self.fetcher:
1024
  raise RuntimeError("Dataset has no data fetcher; cannot load raw data.")
1025
 
1026
+ # --- FETCH FULL HISTORY with PRUNING ---
1027
  raw_data = self.fetcher.fetch_raw_token_data(
1028
  token_address=token_address,
1029
  creator_address=creator_address,
 
1031
  max_horizon_seconds=self.max_cache_horizon_seconds,
1032
  include_wallet_data=False,
1033
  include_graph=False,
1034
+ min_trades=50,
1035
+ full_history=True, # Bypass H/B/H limits
1036
+ prune_failed=True, # Drop failed trades
1037
+ prune_transfers=True # Drop transfers (captured in snapshots)
1038
  )
1039
  if raw_data is None:
1040
  return None
 
1048
  return float(ts_value)
1049
  except (TypeError, ValueError):
1050
  return 0.0
1051
+
1052
+ trades = raw_data.get('trades', [])
1053
+ trade_ts_values = [_timestamp_to_order_value(t.get('timestamp')) for t in trades]
1054
+
 
 
1055
  if not trade_ts_values:
1056
  print(f" SKIP: No valid trades found for {token_address}.")
1057
  return None
1058
+
1059
+ t0_val = _timestamp_to_order_value(t0)
1060
+ last_trade_ts_val = max(trade_ts_values)
1061
 
1062
+ # --- GENERATE DENSE 1s OHLC ---
1063
+ duration_seconds = int(last_trade_ts_val - t0_val) + 120 # Add buffer
1064
+ ohlc_1s = torch.zeros((duration_seconds, 2), dtype=torch.float32)
 
1065
 
1066
+ # Sort trades by time
1067
+ # raw_data trades are already sorted by fetcher, but let's be safe
1068
+ trades.sort(key=lambda x: _timestamp_to_order_value(x['timestamp']))
 
 
 
1069
 
1070
+ # Fill OHLC
1071
+ # A faster way: group by second
1072
+ # We can use a simple loop update or numpy accumulation.
1073
+ # Given standard density, simple loop is fine for caching.
1074
+
1075
+ trades_by_sec = defaultdict(list)
1076
+ for t in trades:
1077
+ ts = _timestamp_to_order_value(t['timestamp'])
1078
+ sec_idx = int(ts - t0_val)
1079
+ if 0 <= sec_idx < duration_seconds:
1080
+ trades_by_sec[sec_idx].append(t['price_usd'])
1081
+
1082
+ last_close = float(trades[0]['price_usd'])
1083
 
1084
+ for i in range(duration_seconds):
1085
+ if i in trades_by_sec:
1086
+ prices = trades_by_sec[i]
1087
+ op = prices[0]
1088
+ cl = prices[-1]
1089
+ last_close = cl
1090
+ else:
1091
+ op = cl = last_close
1092
+
1093
+ ohlc_1s[i, 0] = float(op)
1094
+ ohlc_1s[i, 1] = float(cl)
1095
+
1096
+ raw_data['ohlc_1s'] = ohlc_1s
1097
 
1098
+ # --- GENERATE ON-CHAIN SNAPSHOTS (5m Interval) ---
1099
+ interval = 300 # 5 minutes
1100
+ num_intervals = (duration_seconds // interval) + 1
1101
+ # Feature columns: [volume, tx_count, buy_count, sell_count, total_holders, top_10_holder_pct]
1102
+ # We start with basic trade stats. Holder stats require DB queries.
1103
 
1104
+ snapshot_stats = torch.zeros((num_intervals, 6), dtype=torch.float32)
 
 
 
 
1105
 
1106
+ print(f" INFO: Generating {num_intervals} snapshots (Interval: {interval}s)...")
1107
+
1108
+ cum_volume = 0.0
1109
+ cum_tx = 0
1110
+ cum_buys = 0
1111
+ cum_sells = 0
1112
+
1113
+ # Pre-group trades into 5m buckets for windowed volume
1114
+ buckets = defaultdict(list)
1115
+ for t in trades:
1116
+ ts = _timestamp_to_order_value(t['timestamp'])
1117
+ bucket_idx = int(ts - t0_val) // interval
1118
+ if bucket_idx >= 0:
1119
+ buckets[bucket_idx].append(t)
1120
+
1121
+ # To avoid spamming DB, we might query holders less frequently or batch?
1122
+ # For now, query every step. 288 queries for 24h is fine.
1123
+
1124
+ fetched_holders_cache = {} # Map bucket_idx -> (count, top10_pct)
1125
+ holder_snapshots_list = [] # List of (timestamp, holders_list)
1126
+
1127
+ for i in range(num_intervals):
1128
+ bucket_trades = buckets[i]
1129
 
1130
+ # Windowed Stats
1131
+ vol = sum(t.get('total_usd', 0.0) for t in bucket_trades)
1132
+ tx = len(bucket_trades)
1133
+ buys = sum(1 for t in bucket_trades if t.get('trade_direction') == 0 or t.get('trade_type') == 0) # 0=Buy
1134
+ sells = tx - buys
1135
 
1136
+ # DB Stats: Holders (Point-in-Time)
1137
+ # Time is end of bucket
1138
+ snapshot_ts = t0 + datetime.timedelta(seconds=(i+1)*interval)
1139
+
1140
+ # These queries can be slow.
1141
+ count = self.fetcher.fetch_total_holders_count_for_token(token_address, snapshot_ts)
1142
+ # Fetch Top 200 as per constant
1143
+ top_holders = self.fetcher.fetch_token_holders_for_snapshot(token_address, snapshot_ts, limit=HOLDER_SNAPSHOT_TOP_K)
1144
+
1145
+ total_supply = raw_data.get('total_supply', 0) or 1
1146
+ if raw_data.get('decimals'):
1147
+ total_supply /= (10 ** raw_data['decimals'])
1148
+
1149
+ top10_bal = sum(h.get('current_balance', 0) for h in top_holders[:10])
1150
+ top10_pct = (top10_bal / total_supply) if total_supply > 0 else 0.0
1151
+
1152
+ snapshot_stats[i, 0] = float(vol)
1153
+ snapshot_stats[i, 1] = float(tx)
1154
+ snapshot_stats[i, 2] = float(buys)
1155
+ snapshot_stats[i, 3] = float(sells)
1156
+ snapshot_stats[i, 4] = float(count)
1157
+ snapshot_stats[i, 5] = float(top10_pct)
1158
+
1159
+ # Save the holder identities for the event stream
1160
+ # Make it JSON-serializable-ish (no datetime objects)
1161
+ holder_snapshots_list.append({
1162
+ 'timestamp': int(snapshot_ts.timestamp()),
1163
+ 'holders': top_holders # [{wallet, balance}, ...]
1164
+ })
1165
+
1166
+ raw_data['snapshots_5m'] = snapshot_stats
1167
+ raw_data['holder_snapshots_list'] = holder_snapshots_list # Save the list
1168
+
1169
+ # --- Summary Log ---
1170
+ print(f" [Cache Summary]")
1171
+ print(f" - 1s Candles: {len(ohlc_1s)}")
1172
+ print(f" - 5m Snapshots: {len(snapshot_stats)}")
1173
+ print(f" - Trades (Succ): {len(trades)}")
1174
+ print(f" - Pool Events: {len(raw_data.get('pool_creations', []))}")
1175
+ print(f" - Liquidity Chgs: {len(raw_data.get('liquidity_changes', []))}")
1176
+ print(f" - Burns: {len(raw_data.get('burns', []))}")
1177
+ print(f" - Supply Locks: {len(raw_data.get('supply_locks', []))}")
1178
+ print(f" - Migrations: {len(raw_data.get('migrations', []))}")
1179
 
1180
  raw_data["protocol_id"] = initial_mint_record.get("protocol")
1181
  return raw_data
 
1199
  graph_seed_entities: set,
1200
  all_graph_entities: Dict[str, str],
1201
  future_trades_for_labels: List[Dict[str, Any]],
1202
+ pooler: EmbeddingPooler,
1203
+ cached_holders_list: List[List[str]] = None
1204
  ) -> Optional[Dict[str, Any]]:
1205
  """
1206
  Processes raw token data into a structured dataset item for a specific T_cutoff.
 
1427
  aggregation_trades,
1428
  wallet_data,
1429
  total_supply_dec,
1430
+ _register_event,
1431
+ cached_holders_list=cached_holders_list
1432
  )
1433
 
1434
  # 7. Finalize Sequence
install.sh ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ sudo apt update
2
+ sudo apt install -y curl wget gnupg apt-transport-https ca-certificates dirmngr
3
+
4
+ sudo apt update
5
+ sudo apt install -y pkg-config libudev-dev
6
+
7
+ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
8
+ source $HOME/.cargo/env
9
+
10
+ # ClickHouse (add repo and install)
11
+ sudo apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv 8919F6BD2B48D754
12
+ echo "deb https://packages.clickhouse.com/deb stable main" | sudo tee /etc/apt/sources.list.d/clickhouse.list
13
+ sudo apt update
14
+ sudo apt install -y clickhouse-server clickhouse-client
15
+
16
+ # Neo4j (add repo and install)
17
+ sudo wget -O - https://debian.neo4j.com/neotechnology.gpg.key | sudo gpg --dearmor -o /usr/share/keyrings/neo4j.gpg
18
+ echo "deb [signed-by=/usr/share/keyrings/neo4j.gpg] https://debian.neo4j.com stable latest" | sudo tee -a /etc/apt/sources.list.d/neo4j.list
19
+ sudo apt update
20
+ sudo apt install -y neo4j
21
+
22
+ # Start Neo4j (Runs on bolt://localhost:7687)
23
+ sudo neo4j-admin dbms set-initial-password neo4j123
24
+ neo4j start
25
+
26
+ clickhouse-server
log.log CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:fb6fc43f8ae6467768fb090cfdda9ef48e68d361874317db93e5eee126539989
3
- size 143685
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2bfaace3cf2aadc0acf9e9714d8df00c44bc545db23c87e7497a7844ba3c98a9
3
+ size 6115919
models/model.py CHANGED
@@ -5,6 +5,8 @@ import torch.nn as nn
5
  import torch.nn.functional as F
6
  from transformers import AutoConfig, AutoModel
7
  from typing import List, Dict, Any, Optional, Tuple
 
 
8
 
9
  # --- NOW, we import all the encoders ---
10
  from models.helper_encoders import ContextualTimeEncoder
@@ -43,6 +45,9 @@ class Oracle(nn.Module):
43
  self.multi_modal_dim = multi_modal_dim
44
 
45
 
 
 
 
46
  self.quantiles = quantiles
47
  self.horizons_seconds = horizons_seconds
48
  self.num_outputs = len(quantiles) * len(horizons_seconds)
@@ -225,6 +230,77 @@ class Oracle(nn.Module):
225
  self.to(dtype)
226
  print("Oracle model (full pipeline) initialized.")
227
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  def _normalize_and_project(self,
229
  features: torch.Tensor,
230
  norm_layer: nn.LayerNorm,
 
5
  import torch.nn.functional as F
6
  from transformers import AutoConfig, AutoModel
7
  from typing import List, Dict, Any, Optional, Tuple
8
+ import os
9
+ import json
10
 
11
  # --- NOW, we import all the encoders ---
12
  from models.helper_encoders import ContextualTimeEncoder
 
45
  self.multi_modal_dim = multi_modal_dim
46
 
47
 
48
+ self.num_event_types = num_event_types
49
+ self.event_pad_id = event_pad_id
50
+ self.model_config_name = model_config_name
51
  self.quantiles = quantiles
52
  self.horizons_seconds = horizons_seconds
53
  self.num_outputs = len(quantiles) * len(horizons_seconds)
 
230
  self.to(dtype)
231
  print("Oracle model (full pipeline) initialized.")
232
 
233
+ def save_pretrained(self, save_directory: str):
234
+ """
235
+ Saves the model in a Hugging Face-compatible way.
236
+ """
237
+ if not os.path.exists(save_directory):
238
+ os.makedirs(save_directory)
239
+
240
+ # 1. Save the inner transformer model using its own save_pretrained
241
+ # This gives us the standard HF config.json and pytorch_model.bin for the backbone
242
+ self.model.save_pretrained(save_directory)
243
+
244
+ # 2. Save the whole Oracle state dict (includes transformer + all custom encoders)
245
+ # We use 'oracle_model.bin' for the full state.
246
+ torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
247
+
248
+ # 3. Save Oracle specific metadata for reconstruction
249
+ oracle_config = {
250
+ "num_event_types": self.num_event_types,
251
+ "multi_modal_dim": self.multi_modal_dim,
252
+ "event_pad_id": self.event_pad_id,
253
+ "model_config_name": self.model_config_name,
254
+ "quantiles": self.quantiles,
255
+ "horizons_seconds": self.horizons_seconds,
256
+ "dtype": str(self.dtype),
257
+ "event_type_to_id": self.event_type_to_id
258
+ }
259
+ with open(os.path.join(save_directory, "oracle_config.json"), "w") as f:
260
+ json.dump(oracle_config, f, indent=2)
261
+
262
+ print(f"✅ Oracle model saved to {save_directory}")
263
+
264
+ @classmethod
265
+ def from_pretrained(cls, load_directory: str,
266
+ token_encoder, wallet_encoder, graph_updater, ohlc_embedder, time_encoder):
267
+ """
268
+ Loads the Oracle model from a saved directory.
269
+ Note: You must still provide the initialized sub-encoders (or we can refactor to save them too).
270
+ """
271
+ config_path = os.path.join(load_directory, "oracle_config.json")
272
+ with open(config_path, "r") as f:
273
+ config = json.load(f)
274
+
275
+ # Determine dtype from string
276
+ dtype = torch.bfloat16 # Default
277
+ if "float32" in config["dtype"]: dtype = torch.float32
278
+ elif "float16" in config["dtype"]: dtype = torch.float16
279
+
280
+ # Instantiate model
281
+ model = cls(
282
+ token_encoder=token_encoder,
283
+ wallet_encoder=wallet_encoder,
284
+ graph_updater=graph_updater,
285
+ ohlc_embedder=ohlc_embedder,
286
+ time_encoder=time_encoder,
287
+ num_event_types=config["num_event_types"],
288
+ multi_modal_dim=config["multi_modal_dim"],
289
+ event_pad_id=config["event_pad_id"],
290
+ event_type_to_id=config["event_type_to_id"],
291
+ model_config_name=config["model_config_name"],
292
+ quantiles=config["quantiles"],
293
+ horizons_seconds=config["horizons_seconds"],
294
+ dtype=dtype
295
+ )
296
+
297
+ # Load weights
298
+ weight_path = os.path.join(load_directory, "pytorch_model.bin")
299
+ state_dict = torch.load(weight_path, map_location="cpu")
300
+ model.load_state_dict(state_dict)
301
+ print(f"✅ Oracle model loaded from {load_directory}")
302
+ return model
303
+
304
  def _normalize_and_project(self,
305
  features: torch.Tensor,
306
  norm_layer: nn.LayerNorm,
models/multi_modal_processor.py CHANGED
@@ -21,9 +21,12 @@ class MultiModalEncoder:
21
  This class is intended for creating embeddings for vector search.
22
  """
23
 
24
- def __init__(self, model_id="google/siglip-so400m-patch16-256-i18n", dtype: torch.dtype = torch.bfloat16):
25
  self.model_id = model_id
26
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
27
 
28
  self.dtype = dtype
29
 
 
21
  This class is intended for creating embeddings for vector search.
22
  """
23
 
24
+ def __init__(self, model_id="google/siglip-so400m-patch16-256-i18n", dtype: torch.dtype = torch.bfloat16, device: str = None):
25
  self.model_id = model_id
26
+ if device:
27
+ self.device = device
28
+ else:
29
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
30
 
31
  self.dtype = dtype
32
 
models/ohlc_embedder.py CHANGED
@@ -18,7 +18,7 @@ class OHLCEmbedder(nn.Module):
18
  # --- NEW: Interval vocab size ---
19
  num_intervals: int,
20
  input_channels: int = 2, # Open, Close
21
- sequence_length: int = 300,
22
  cnn_channels: List[int] = [16, 32, 64],
23
  kernel_sizes: List[int] = [3, 3, 3],
24
  # --- NEW: Interval embedding dim ---
@@ -30,12 +30,12 @@ class OHLCEmbedder(nn.Module):
30
  assert len(cnn_channels) == len(kernel_sizes), "cnn_channels and kernel_sizes must have the same length"
31
 
32
  self.dtype = dtype
33
- self.sequence_length = sequence_length
34
  self.cnn_layers = nn.ModuleList()
35
  self.output_dim = output_dim
36
 
37
  in_channels = input_channels
38
- current_seq_len = sequence_length
39
 
40
  for i, (out_channels, k_size) in enumerate(zip(cnn_channels, kernel_sizes)):
41
  conv = nn.Conv1d(
 
18
  # --- NEW: Interval vocab size ---
19
  num_intervals: int,
20
  input_channels: int = 2, # Open, Close
21
+ # sequence_length: int = 300, # REMOVED: HARDCODED
22
  cnn_channels: List[int] = [16, 32, 64],
23
  kernel_sizes: List[int] = [3, 3, 3],
24
  # --- NEW: Interval embedding dim ---
 
30
  assert len(cnn_channels) == len(kernel_sizes), "cnn_channels and kernel_sizes must have the same length"
31
 
32
  self.dtype = dtype
33
+ self.sequence_length = 300 # HARDCODED
34
  self.cnn_layers = nn.ModuleList()
35
  self.output_dim = output_dim
36
 
37
  in_channels = input_channels
38
+ current_seq_len = 300
39
 
40
  for i, (out_channels, k_size) in enumerate(zip(cnn_channels, kernel_sizes)):
41
  conv = nn.Conv1d(
train.py CHANGED
@@ -4,6 +4,9 @@ import math
4
  import logging
5
  from pathlib import Path
6
  from typing import Any, Dict, List, Optional, Tuple
 
 
 
7
 
8
  # Ensure torch/dill have a writable tmp dir
9
  _DEFAULT_TMP = Path(os.getenv("TMPDIR_OVERRIDE", "./.tmp"))
@@ -12,6 +15,11 @@ resolved_tmp = str(_DEFAULT_TMP.resolve())
12
  for key in ("TMPDIR", "TMP", "TEMP"):
13
  os.environ.setdefault(key, resolved_tmp)
14
 
 
 
 
 
 
15
  import torch
16
  import torch.nn as nn
17
  from torch.utils.data import DataLoader
@@ -126,7 +134,6 @@ def parse_args() -> argparse.Namespace:
126
  parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
127
  parser.add_argument("--mixed_precision", type=str, default="bf16")
128
  parser.add_argument("--max_seq_len", type=int, default=16000)
129
- parser.add_argument("--ohlc_seq_len", type=int, default=60)
130
  parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
131
  parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
132
  parser.add_argument("--max_samples", type=int, default=None)
@@ -200,8 +207,8 @@ def main() -> None:
200
  horizons = args.horizons_seconds
201
  quantiles = args.quantiles
202
  max_seq_len = args.max_seq_len
203
- ohlc_seq_len = args.ohlc_seq_len
204
-
205
  logger.info(f"Initializing Encoders with dtype={init_dtype}...")
206
 
207
  # Encoders
@@ -212,39 +219,29 @@ def main() -> None:
212
  graph_updater = GraphUpdater(time_encoder=time_encoder, dtype=init_dtype)
213
  ohlc_embedder = OHLCEmbedder(
214
  num_intervals=vocab.NUM_OHLC_INTERVALS,
215
- sequence_length=ohlc_seq_len,
216
  dtype=init_dtype
217
  )
218
 
219
  collator = MemecoinCollator(
220
  event_type_to_id=vocab.EVENT_TO_ID,
221
  device=device, # Note: Collator will handle basic moves, Accelerate handles the rest
222
- multi_modal_encoder=multi_modal_encoder,
223
  dtype=init_dtype,
224
- ohlc_seq_len=ohlc_seq_len,
225
  max_seq_len=max_seq_len
226
  )
227
 
228
- # DB Connections
229
- clickhouse_client = ClickHouseClient(
230
- host=args.clickhouse_host,
231
- port=int(args.clickhouse_port)
232
- )
233
-
234
- neo4j_auth = ("neo4j", "neo4j123")
235
- if args.neo4j_user is not None:
236
- neo4j_auth = (args.neo4j_user, args.neo4j_password or "")
237
- neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=neo4j_auth)
238
-
239
- data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
240
 
241
  dataset = OracleDataset(
242
- data_fetcher=data_fetcher,
243
  horizons_seconds=horizons,
244
  quantiles=quantiles,
245
  max_samples=args.max_samples,
246
  ohlc_stats_path=args.ohlc_stats_path,
247
- t_cutoff_seconds=int(args.t_cutoff_seconds),
248
  cache_dir="/workspace/apollo/data/cache"
249
  )
250
 
@@ -257,7 +254,7 @@ def main() -> None:
257
  shuffle=bool(args.shuffle),
258
  num_workers=int(args.num_workers),
259
  pin_memory=bool(args.pin_memory),
260
- collate_fn=lambda batch: filtered_collate(collator, batch)
261
  )
262
 
263
  # --- 3. Model Init ---
@@ -442,25 +439,36 @@ def main() -> None:
442
  if accelerator.is_main_process:
443
  save_path = checkpoint_dir / f"checkpoint-{total_steps}"
444
  accelerator.save_state(output_dir=str(save_path))
445
- logger.info(f"Saved checkpoint to {save_path}")
 
 
 
 
 
 
446
 
447
  # End of Epoch Handling
448
  if valid_batches > 0:
449
  avg_loss = epoch_loss / valid_batches
450
  if accelerator.is_main_process:
451
  logger.info(f"Epoch {epoch+1} complete. Avg loss: {avg_loss:.6f}")
452
- accelerator.log({"train/loss_epoch": avg_loss}, step=global_step)
453
 
454
- # Save Checkpoint at end of epoch
455
  save_path = checkpoint_dir / f"epoch_{epoch+1}"
456
- accelerator.save_state(output_dir=str(save_path))
457
- logger.info(f"Saved checkpoint to {save_path}")
 
 
 
 
 
458
  else:
459
  if accelerator.is_main_process:
460
  logger.warning(f"Epoch {epoch+1}: No valid batches processed.")
461
 
462
  accelerator.end_training()
463
- neo4j_driver.close()
464
 
465
  if __name__ == "__main__":
466
  main()
 
4
  import logging
5
  from pathlib import Path
6
  from typing import Any, Dict, List, Optional, Tuple
7
+ import functools
8
+
9
+ import torch.multiprocessing as mp
10
 
11
  # Ensure torch/dill have a writable tmp dir
12
  _DEFAULT_TMP = Path(os.getenv("TMPDIR_OVERRIDE", "./.tmp"))
 
15
  for key in ("TMPDIR", "TMP", "TEMP"):
16
  os.environ.setdefault(key, resolved_tmp)
17
 
18
+ try:
19
+ mp.set_start_method('spawn', force=True)
20
+ except RuntimeError:
21
+ pass
22
+
23
  import torch
24
  import torch.nn as nn
25
  from torch.utils.data import DataLoader
 
134
  parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
135
  parser.add_argument("--mixed_precision", type=str, default="bf16")
136
  parser.add_argument("--max_seq_len", type=int, default=16000)
 
137
  parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
138
  parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
139
  parser.add_argument("--max_samples", type=int, default=None)
 
207
  horizons = args.horizons_seconds
208
  quantiles = args.quantiles
209
  max_seq_len = args.max_seq_len
210
+ max_seq_len = args.max_seq_len
211
+
212
  logger.info(f"Initializing Encoders with dtype={init_dtype}...")
213
 
214
  # Encoders
 
219
  graph_updater = GraphUpdater(time_encoder=time_encoder, dtype=init_dtype)
220
  ohlc_embedder = OHLCEmbedder(
221
  num_intervals=vocab.NUM_OHLC_INTERVALS,
 
222
  dtype=init_dtype
223
  )
224
 
225
  collator = MemecoinCollator(
226
  event_type_to_id=vocab.EVENT_TO_ID,
227
  device=device, # Note: Collator will handle basic moves, Accelerate handles the rest
228
+ # multi_modal_encoder=multi_modal_encoder, # REMOVED: Uses lazy loading internally
229
  dtype=init_dtype,
 
230
  max_seq_len=max_seq_len
231
  )
232
 
233
+ # DB Connections - REMOVED for Training (Using Cache)
234
+ # clickhouse_client = ClickHouseClient(...)
235
+ # neo4j_driver = GraphDatabase.driver(...)
236
+ # data_fetcher = DataFetcher(...)
 
 
 
 
 
 
 
 
237
 
238
  dataset = OracleDataset(
239
+ data_fetcher=None, # Training Mode (Reader Only)
240
  horizons_seconds=horizons,
241
  quantiles=quantiles,
242
  max_samples=args.max_samples,
243
  ohlc_stats_path=args.ohlc_stats_path,
244
+ t_cutoff_seconds=int(args.t_cutoff_seconds) if hasattr(args, 't_cutoff_seconds') else 60,
245
  cache_dir="/workspace/apollo/data/cache"
246
  )
247
 
 
254
  shuffle=bool(args.shuffle),
255
  num_workers=int(args.num_workers),
256
  pin_memory=bool(args.pin_memory),
257
+ collate_fn=functools.partial(filtered_collate, collator)
258
  )
259
 
260
  # --- 3. Model Init ---
 
439
  if accelerator.is_main_process:
440
  save_path = checkpoint_dir / f"checkpoint-{total_steps}"
441
  accelerator.save_state(output_dir=str(save_path))
442
+
443
+ # NEW: Save in standard HF-loadable way
444
+ hf_save_path = save_path / "hf_model"
445
+ unwrapped_model = accelerator.unwrap_model(model)
446
+ unwrapped_model.save_pretrained(str(hf_save_path))
447
+
448
+ logger.info(f"Saved checkpoint and HF-style model to {save_path}")
449
 
450
  # End of Epoch Handling
451
  if valid_batches > 0:
452
  avg_loss = epoch_loss / valid_batches
453
  if accelerator.is_main_process:
454
  logger.info(f"Epoch {epoch+1} complete. Avg loss: {avg_loss:.6f}")
455
+ accelerator.log({"train/loss_epoch": avg_loss}, step=total_steps)
456
 
457
+ # Save Checkpoint at end of epoch (REMOVED: saving every epoch is too much)
458
  save_path = checkpoint_dir / f"epoch_{epoch+1}"
459
+ # accelerator.save_state(output_dir=str(save_path))
460
+ # hf_save_path = save_path / "hf_model"
461
+ # unwrapped_model = accelerator.unwrap_model(model)
462
+ # unwrapped_model.save_pretrained(str(hf_save_path))
463
+
464
+ # logger.info(f"Saved and HF-style model (EOF) to {save_path}")
465
+ pass
466
  else:
467
  if accelerator.is_main_process:
468
  logger.warning(f"Epoch {epoch+1}: No valid batches processed.")
469
 
470
  accelerator.end_training()
471
+ # neo4j_driver.close() # REMOVED
472
 
473
  if __name__ == "__main__":
474
  main()
train.sh CHANGED
@@ -1,4 +1,4 @@
1
- accelerate launch train.py \
2
  --epochs 10 \
3
  --batch_size 1 \
4
  --learning_rate 1e-4 \
@@ -7,16 +7,14 @@ accelerate launch train.py \
7
  --max_grad_norm 1.0 \
8
  --seed 42 \
9
  --log_every 1 \
10
- --save_every 1000 \
11
  --tensorboard_dir runs/oracle \
12
  --checkpoint_dir checkpoints \
13
  --mixed_precision bf16 \
14
- --max_seq_len 50 \
15
- --ohlc_seq_len 300 \
16
  --horizons_seconds 30 60 120 240 420 \
17
  --quantiles 0.1 0.5 0.9 \
18
  --ohlc_stats_path ./data/ohlc_stats.npz \
19
- --t_cutoff_seconds 60 \
20
  --num_workers 4 \
21
  --clickhouse_host localhost \
22
  --clickhouse_port 9000 \
 
1
+ /venv/main/bin/accelerate launch train.py \
2
  --epochs 10 \
3
  --batch_size 1 \
4
  --learning_rate 1e-4 \
 
7
  --max_grad_norm 1.0 \
8
  --seed 42 \
9
  --log_every 1 \
10
+ --save_every 10 \
11
  --tensorboard_dir runs/oracle \
12
  --checkpoint_dir checkpoints \
13
  --mixed_precision bf16 \
14
+ --max_seq_len 4096 \
 
15
  --horizons_seconds 30 60 120 240 420 \
16
  --quantiles 0.1 0.5 0.9 \
17
  --ohlc_stats_path ./data/ohlc_stats.npz \
 
18
  --num_workers 4 \
19
  --clickhouse_host localhost \
20
  --clickhouse_port 9000 \