zirobtc committed on
Commit
e125fa3
·
1 Parent(s): 4dd4ab4

Upload folder using huggingface_hub

Browse files
data/data_fetcher.py CHANGED
@@ -628,81 +628,29 @@ class DataFetcher:
628
 
629
  def fetch_trades_for_token(self, token_address: str, T_cutoff: datetime.datetime, count_threshold: int, early_limit: int, recent_limit: int, full_history: bool = False) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
630
  """
631
- Fetches trades for a token.
632
- If full_history is True, fetches ALL trades (ignores H/B/H limits).
633
- Otherwise, uses the 3-part H/B/H strategy if the total count exceeds a threshold.
634
- Returns three lists: early_trades, middle_trades, recent_trades.
 
 
 
 
635
  """
636
  if not token_address:
637
  return [], [], []
638
 
639
  params = {'token_address': token_address, 'T_cutoff': T_cutoff}
640
-
641
- # 1. Get the total count if we care about H/B/H logic
642
- if not full_history:
643
- count_query = "SELECT count() FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s"
644
- try:
645
- total_trades = self.db_client.execute(count_query, params)[0][0]
646
- print(f"INFO: Found {total_trades} total trades for token {token_address} before {T_cutoff}.")
647
- except Exception as e:
648
- print(f"ERROR: Could not count trades for token {token_address}: {e}")
649
- return [], [], []
650
- else:
651
- total_trades = 0 # Dummy value, ignored
652
-
653
- # 2. Decide which query to use
654
- # If full_history is ON, or count is low, fetch everything.
655
- if full_history or total_trades < count_threshold:
656
- mode = "Full History" if full_history else "Low Count"
657
- # print(f"INFO: Fetching all trades ({mode}).")
658
- query = "SELECT * FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s ORDER BY timestamp ASC"
659
- try:
660
- rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
661
- if not rows: return [], [], []
662
- columns = [col[0] for col in columns_info]
663
- all_trades = [dict(zip(columns, row)) for row in rows]
664
- # When not using HBH or fetching full history, all trades are considered "early" (or just one big block)
665
- return all_trades, [], []
666
- except Exception as e:
667
- print(f"ERROR: Failed to fetch all trades for token {token_address}: {e}")
668
- return [], [], []
669
-
670
- # 3. Use the H/B/H strategy if the count is high AND not full_history
671
- print("INFO: Fetching trades using 3-part High-Def/Blurry/High-Def strategy.")
672
  try:
673
- # Fetch Early (High-Def)
674
- early_query = "SELECT * FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s ORDER BY timestamp ASC LIMIT %(limit)s"
675
- early_rows, early_cols_info = self.db_client.execute(early_query, {'token_address': token_address, 'T_cutoff': T_cutoff, 'limit': early_limit}, with_column_types=True)
676
- early_trades = [dict(zip([c[0] for c in early_cols_info], r)) for r in early_rows] if early_rows else []
677
-
678
- # Fetch Recent (High-Def)
679
- recent_query = "SELECT * FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s ORDER BY timestamp DESC LIMIT %(limit)s"
680
- recent_rows, recent_cols_info = self.db_client.execute(recent_query, {'token_address': token_address, 'T_cutoff': T_cutoff, 'limit': recent_limit}, with_column_types=True)
681
- recent_trades = [dict(zip([c[0] for c in recent_cols_info], r)) for r in recent_rows] if recent_rows else []
682
- recent_trades.reverse() # Order ASC
683
-
684
- # Fetch Middle (Blurry - successful trades only)
685
- middle_trades = []
686
- if early_trades and recent_trades:
687
- start_middle_ts = early_trades[-1]['timestamp']
688
- end_middle_ts = recent_trades[0]['timestamp']
689
- if start_middle_ts < end_middle_ts:
690
- middle_query = """
691
- SELECT * FROM trades
692
- WHERE base_address = %(token_address)s
693
- AND success = true
694
- AND timestamp > %(start_ts)s
695
- AND timestamp < %(end_ts)s
696
- ORDER BY timestamp ASC
697
- """
698
- middle_params = {'token_address': token_address, 'start_ts': start_middle_ts, 'end_ts': end_middle_ts}
699
- middle_rows, middle_cols_info = self.db_client.execute(middle_query, middle_params, with_column_types=True)
700
- middle_trades = [dict(zip([c[0] for c in middle_cols_info], r)) for r in middle_rows] if middle_rows else []
701
-
702
- return early_trades, middle_trades, recent_trades
703
-
704
  except Exception as e:
705
- print(f"ERROR: Failed to fetch H/B/H trades for token {token_address}: {e}")
706
  return [], [], []
707
 
708
  def fetch_future_trades_for_token(self,
 
628
 
629
  def fetch_trades_for_token(self, token_address: str, T_cutoff: datetime.datetime, count_threshold: int, early_limit: int, recent_limit: int, full_history: bool = False) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
630
  """
631
+ Fetches ALL trades for a token up to T_cutoff, ordered by time.
632
+
633
+ Notes:
634
+ - This intentionally does NOT apply the older fetch-time H/B/H (High-Def / Blurry / High-Def)
635
+ sampling logic. Sequence-length control is handled later in data_loader.py via event-level
636
+ head/tail sampling with MIDDLE/RECENT markers.
637
+ - The function signature still includes legacy H/B/H parameters for compatibility.
638
+ Returns: (all_trades, [], [])
639
  """
640
  if not token_address:
641
  return [], [], []
642
 
643
  params = {'token_address': token_address, 'T_cutoff': T_cutoff}
644
+ query = "SELECT * FROM trades WHERE base_address = %(token_address)s AND timestamp <= %(T_cutoff)s ORDER BY timestamp ASC"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
645
  try:
646
+ rows, columns_info = self.db_client.execute(query, params, with_column_types=True)
647
+ if not rows:
648
+ return [], [], []
649
+ columns = [col[0] for col in columns_info]
650
+ all_trades = [dict(zip(columns, row)) for row in rows]
651
+ return all_trades, [], []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
652
  except Exception as e:
653
+ print(f"ERROR: Failed to fetch trades for token {token_address}: {e}")
654
  return [], [], []
655
 
656
  def fetch_future_trades_for_token(self,
data/data_loader.py CHANGED
@@ -142,6 +142,10 @@ class OracleDataset(Dataset):
142
 
143
  self.fetcher = data_fetcher
144
  self.cache_dir = Path(cache_dir) if cache_dir else None
 
 
 
 
145
 
146
  # If a fetcher is provided, we can determine the number of samples.
147
  # Otherwise, we are likely in a test mode where __len__ might not be called
@@ -149,7 +153,13 @@ class OracleDataset(Dataset):
149
  self.t_cutoff_seconds = max(0, int(t_cutoff_seconds or 0))
150
  self.token_allowlist = set(token_allowlist) if token_allowlist else None
151
 
152
- if self.cache_dir and self.cache_dir.is_dir():
 
 
 
 
 
 
153
  print(f"INFO: Initializing dataset in offline (cached) mode from: {self.cache_dir}")
154
  # Scan for cached files to determine length
155
  self.cached_files = sorted(self.cache_dir.glob("sample_*.pt"), key=lambda p: int(p.stem.split('_')[1]))
@@ -1201,7 +1211,8 @@ class OracleDataset(Dataset):
1201
  pooler=pooler,
1202
  sample_idx=idx,
1203
  cached_holders_list=raw_data.get('holder_snapshots_list'),
1204
- cached_ohlc_1s=raw_data.get('ohlc_1s')
 
1205
  )
1206
 
1207
  def __cacheitem__(self, idx: int) -> Optional[Dict[str, Any]]:
@@ -1394,7 +1405,8 @@ class OracleDataset(Dataset):
1394
  pooler: EmbeddingPooler,
1395
  sample_idx: Optional[int] = None,
1396
  cached_holders_list: List[List[str]] = None,
1397
- cached_ohlc_1s: Optional[torch.Tensor] = None
 
1398
  ) -> Optional[Dict[str, Any]]:
1399
  """
1400
  Processes raw token data into a structured dataset item for a specific T_cutoff.
@@ -1683,7 +1695,7 @@ class OracleDataset(Dataset):
1683
  'embedding_pooler': pooler,
1684
  'labels': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
1685
  'labels_mask': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
1686
- 'quality_score': torch.tensor(raw_data['quality_score'], dtype=torch.float32)
1687
  }
1688
 
1689
  # Ensure sorted
@@ -1759,5 +1771,5 @@ class OracleDataset(Dataset):
1759
  'embedding_pooler': pooler,
1760
  'labels': torch.tensor(label_values, dtype=torch.float32),
1761
  'labels_mask': torch.tensor(mask_values, dtype=torch.float32),
1762
- 'quality_score': torch.tensor(raw_data['quality_score'], dtype=torch.float32)
1763
  }
 
142
 
143
  self.fetcher = data_fetcher
144
  self.cache_dir = Path(cache_dir) if cache_dir else None
145
+ # Always define these so DataLoader workers don't crash with AttributeError if
146
+ # initialization falls through an unexpected branch.
147
+ self.cached_files = []
148
+ self.weights_list = []
149
 
150
  # If a fetcher is provided, we can determine the number of samples.
151
  # Otherwise, we are likely in a test mode where __len__ might not be called
 
153
  self.t_cutoff_seconds = max(0, int(t_cutoff_seconds or 0))
154
  self.token_allowlist = set(token_allowlist) if token_allowlist else None
155
 
156
+ if self.cache_dir:
157
+ if not self.cache_dir.is_dir():
158
+ raise RuntimeError(
159
+ f"Cache directory '{self.cache_dir}' was provided but is not a directory. "
160
+ "Fix the path or disable cached mode."
161
+ )
162
+ # Cached/offline mode
163
  print(f"INFO: Initializing dataset in offline (cached) mode from: {self.cache_dir}")
164
  # Scan for cached files to determine length
165
  self.cached_files = sorted(self.cache_dir.glob("sample_*.pt"), key=lambda p: int(p.stem.split('_')[1]))
 
1211
  pooler=pooler,
1212
  sample_idx=idx,
1213
  cached_holders_list=raw_data.get('holder_snapshots_list'),
1214
+ cached_ohlc_1s=raw_data.get('ohlc_1s'),
1215
+ quality_score=raw_data.get('quality_score')
1216
  )
1217
 
1218
  def __cacheitem__(self, idx: int) -> Optional[Dict[str, Any]]:
 
1405
  pooler: EmbeddingPooler,
1406
  sample_idx: Optional[int] = None,
1407
  cached_holders_list: List[List[str]] = None,
1408
+ cached_ohlc_1s: Optional[torch.Tensor] = None,
1409
+ quality_score: Optional[float] = None
1410
  ) -> Optional[Dict[str, Any]]:
1411
  """
1412
  Processes raw token data into a structured dataset item for a specific T_cutoff.
 
1695
  'embedding_pooler': pooler,
1696
  'labels': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
1697
  'labels_mask': torch.zeros(len(self.horizons_seconds), dtype=torch.float32),
1698
+ 'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32)
1699
  }
1700
 
1701
  # Ensure sorted
 
1771
  'embedding_pooler': pooler,
1772
  'labels': torch.tensor(label_values, dtype=torch.float32),
1773
  'labels_mask': torch.tensor(mask_values, dtype=torch.float32),
1774
+ 'quality_score': torch.tensor(quality_score if quality_score is not None else 0.0, dtype=torch.float32)
1775
  }
data/ohlc_stats.npz CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:6f2c86bf03e5761e7fb319a54274e032f7aa1d01dd5873f2f44a52c9e0be5244
3
  size 1660
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46809f070aa1dfcb4f53d7390b1b6ff370e6828e198df4c0df5632ac6fa9f607
3
  size 1660
log.log CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:461e55d31752fd72f09aa30c5bcc3a619654ae86ddf1e759c9c57b0dc5db53f6
3
- size 21794
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:41885991264f1522ec8b539dd4f3f738d537102a65103a800578229feef13880
3
+ size 18007
models/model.py CHANGED
@@ -3,7 +3,7 @@
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
- from transformers import AutoConfig, AutoModel
7
  from typing import List, Dict, Any, Optional, Tuple
8
  import os
9
  import json
@@ -32,7 +32,7 @@ class Oracle(nn.Module):
32
  multi_modal_dim: int,
33
  event_pad_id: int,
34
  event_type_to_id: Dict[str, int],
35
- model_config_name: str = "Qwen/Qwen3-0.6B",
36
  quantiles: List[float] = [0.1, 0.5, 0.9],
37
  horizons_seconds: List[int] = [30, 60, 120, 240, 420],
38
  dtype: torch.dtype = torch.bfloat16):
@@ -53,12 +53,43 @@ class Oracle(nn.Module):
53
  self.num_outputs = len(quantiles) * len(horizons_seconds)
54
  self.dtype = dtype
55
 
56
- # --- 2. Load Qwen3 Configuration (architecture only; training from scratch) ---
57
- hf_token = os.getenv("Hf_TOKEN") or os.getenv("HF_TOKEN")
58
- hf_kwargs = {"token": hf_token} if hf_token else {}
59
- model_config = AutoConfig.from_pretrained(model_config_name, trust_remote_code=True, **hf_kwargs)
60
- self.d_model = model_config.hidden_size
61
- self.model = AutoModel.from_config(model_config, trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  self.model.to(self.device, dtype=self.dtype)
63
 
64
  # Quantile prediction head (maps pooled hidden state -> flattened horizon/quantile grid)
@@ -225,8 +256,9 @@ class Oracle(nn.Module):
225
  # --- NEW: Embedding for timeframe ID (re-uses protocol_embedding) ---
226
  self.lighthouse_timeframe_embedding = nn.Embedding(vocab.NUM_LIGHTHOUSE_TIMEFRAMES, self.d_model)
227
 
228
- # --- NEW: Embeddings for Special Context Tokens ---
229
- self.special_context_tokens = {'Middle': 0, 'RECENT': 1}
 
230
  self.special_context_embedding = nn.Embedding(len(self.special_context_tokens), self.d_model)
231
 
232
 
@@ -906,19 +938,19 @@ class Oracle(nn.Module):
906
 
907
  def _get_special_context_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
908
  """
909
- Handles special context tokens like 'Middle' and 'RECENT' by adding their unique learnable embeddings.
910
  """
911
  device = self.device
912
  event_type_ids = batch['event_type_ids']
913
  B, L = event_type_ids.shape
914
 
915
- middle_id = self.event_type_to_id.get('Middle', -1)
916
  recent_id = self.event_type_to_id.get('RECENT', -1)
917
 
918
  middle_mask = (event_type_ids == middle_id)
919
  recent_mask = (event_type_ids == recent_id)
920
 
921
- middle_emb = self.special_context_embedding(torch.tensor(self.special_context_tokens['Middle'], device=device))
922
  recent_emb = self.special_context_embedding(torch.tensor(self.special_context_tokens['RECENT'], device=device))
923
 
924
  # Add the embeddings at the correct locations
 
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
+ from transformers import AutoModel, LlamaConfig
7
  from typing import List, Dict, Any, Optional, Tuple
8
  import os
9
  import json
 
32
  multi_modal_dim: int,
33
  event_pad_id: int,
34
  event_type_to_id: Dict[str, int],
35
+ model_config_name: str = "llama3-12l-768d-gqa4-8k-random",
36
  quantiles: List[float] = [0.1, 0.5, 0.9],
37
  horizons_seconds: List[int] = [30, 60, 120, 240, 420],
38
  dtype: torch.dtype = torch.bfloat16):
 
53
  self.num_outputs = len(quantiles) * len(horizons_seconds)
54
  self.dtype = dtype
55
 
56
+ # --- 2. Backbone: Llama-style decoder, RANDOM INIT (no pretrained weights) ---
57
+ # This gives you RoPE + modern decoder blocks and lets HF use optimized attention
58
+ # implementations (SDPA / FlashAttention) without us implementing a transformer.
59
+ #
60
+ # Size target: ~80-120M params, suitable for 8k-ish seq caps with your data regime.
61
+ attn_impl = os.getenv("HF_ATTN_IMPL", "sdpa") # "sdpa" (safe) or "flash_attention_2" (if installed)
62
+ llama_cfg = LlamaConfig(
63
+ # Model size
64
+ hidden_size=768,
65
+ intermediate_size=3072,
66
+ num_hidden_layers=12,
67
+ num_attention_heads=12,
68
+ # GQA-style KV heads (Llama 3-style efficiency knob)
69
+ num_key_value_heads=4,
70
+ # Long context (must be >= your effective max sequence length)
71
+ max_position_embeddings=8192,
72
+ # Llama 3 uses a large theta; harmless for random init and helps longer contexts.
73
+ rope_theta=500000.0,
74
+ rms_norm_eps=1e-5,
75
+ # Unused when providing inputs_embeds, but required by config
76
+ vocab_size=32000,
77
+ )
78
+ self.d_model = llama_cfg.hidden_size
79
+ # Older transformers versions may not support attn_implementation in from_config.
80
+ # Also, flash_attention_2 requires optional deps; fall back to SDPA if unavailable.
81
+ try:
82
+ self.model = AutoModel.from_config(llama_cfg, attn_implementation=attn_impl)
83
+ except TypeError:
84
+ self.model = AutoModel.from_config(llama_cfg)
85
+ except Exception:
86
+ if attn_impl != "sdpa":
87
+ self.model = AutoModel.from_config(llama_cfg, attn_implementation="sdpa")
88
+ else:
89
+ raise
90
+ # Disable KV cache during training (saves memory; not used for full-seq training).
91
+ if hasattr(self.model, "config"):
92
+ self.model.config.use_cache = False
93
  self.model.to(self.device, dtype=self.dtype)
94
 
95
  # Quantile prediction head (maps pooled hidden state -> flattened horizon/quantile grid)
 
256
  # --- NEW: Embedding for timeframe ID (re-uses protocol_embedding) ---
257
  self.lighthouse_timeframe_embedding = nn.Embedding(vocab.NUM_LIGHTHOUSE_TIMEFRAMES, self.d_model)
258
 
259
+ # --- Embeddings for Special Context Tokens ---
260
+ # Must match vocabulary event names (see models/vocabulary.py).
261
+ self.special_context_tokens = {'MIDDLE': 0, 'RECENT': 1}
262
  self.special_context_embedding = nn.Embedding(len(self.special_context_tokens), self.d_model)
263
 
264
 
 
938
 
939
  def _get_special_context_embeddings(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
940
  """
941
+ Handles special context tokens like 'MIDDLE' and 'RECENT' by adding their unique learnable embeddings.
942
  """
943
  device = self.device
944
  event_type_ids = batch['event_type_ids']
945
  B, L = event_type_ids.shape
946
 
947
+ middle_id = self.event_type_to_id.get('MIDDLE', -1)
948
  recent_id = self.event_type_to_id.get('RECENT', -1)
949
 
950
  middle_mask = (event_type_ids == middle_id)
951
  recent_mask = (event_type_ids == recent_id)
952
 
953
+ middle_emb = self.special_context_embedding(torch.tensor(self.special_context_tokens['MIDDLE'], device=device))
954
  recent_emb = self.special_context_embedding(torch.tensor(self.special_context_tokens['RECENT'], device=device))
955
 
956
  # Add the embeddings at the correct locations
scripts/analyze_hyperparams.py CHANGED
@@ -1,255 +1,301 @@
1
  import os
2
- import sys
3
- import torch
4
- import numpy as np
5
  import argparse
6
- from tqdm import tqdm
7
- from datetime import datetime, timezone
8
- from collections import defaultdict
9
 
 
 
10
 
11
- # Add project root to path
12
- sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
13
- from data.data_loader import OracleDataset, DataFetcher
14
 
15
- import os
16
- import sys
17
- import numpy as np
18
- import argparse
19
- from tqdm import tqdm
20
- from datetime import datetime, timezone
21
- import collections
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- # Add project root to path
24
- sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
25
- from data.data_loader import DataFetcher
26
 
27
- import os
28
- import sys
29
- import numpy as np
30
- import argparse
31
- from tqdm import tqdm
32
- from datetime import datetime, timezone
33
- import collections
34
- from dotenv import load_dotenv
35
- from clickhouse_driver import Client as ClickHouseClient
36
- from neo4j import GraphDatabase
 
37
 
38
- # Add project root to path
39
- sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
40
- from data.data_loader import DataFetcher
41
 
42
- def parse_args():
43
- parser = argparse.ArgumentParser(description="Analyze dataset to tune hyperparameters (Horizons, Seq Len)")
44
- parser.add_argument("--max_samples", type=int, default=5000, help="Max samples to analyze")
45
- parser.add_argument("--token_address", type=str, default=None, help="Specific token address to analyze")
46
- return parser.parse_args()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- def main():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  load_dotenv()
50
  args = parse_args()
51
-
52
- print("--- Hyperparameter Calibration Analysis (SQL) ---")
53
-
54
- # DB Connection
55
- ch_host = os.getenv("CLICKHOUSE_HOST", "localhost")
56
- ch_port = int(os.getenv("CLICKHOUSE_NATIVE_PORT", 9000))
57
- neo_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
58
- neo_user = os.getenv("NEO4J_USER", "neo4j")
59
- neo_pass = os.getenv("NEO4J_PASSWORD", "password")
60
-
61
- print(f"Connecting to ClickHouse at {ch_host}:{ch_port}...")
62
- clickhouse_client = ClickHouseClient(host=ch_host, port=ch_port)
63
-
64
- print(f"Connecting to Neo4j at {neo_uri}...")
65
- neo4j_driver = GraphDatabase.driver(neo_uri, auth=(neo_user, neo_pass))
66
-
67
- # 1. Initialize DataFetcher
68
- fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
69
- print("DataFetcher initialized.")
70
-
71
- # 2. Fetch Sample Mints
72
  if args.token_address:
73
- print(f"Analyzing specific token: {args.token_address}")
74
- # Try to find mint timestamp
75
- query = f"SELECT mint_address, timestamp FROM mints WHERE mint_address = '{args.token_address}'"
76
- mints = fetcher.db_client.execute(query)
77
- if not mints:
78
- print("Token not found in mints table. Trying to use first trade timestamp...")
79
- # Fallback if not in mints table
80
- q2 = f"SELECT base_address, min(timestamp) FROM trades WHERE base_address = '{args.token_address}' GROUP BY base_address"
81
- mints = fetcher.db_client.execute(q2)
82
-
83
- if not mints:
84
- print("Token not found in trades either (or no trades). Exiting.")
85
  return
86
-
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  else:
88
- print(f"Fetching {args.max_samples} sample tokens...")
89
- # Fetch random mints
90
- query = f"""
91
- SELECT mint_address, timestamp FROM mints
92
- ORDER BY rand()
93
- LIMIT {args.max_samples}
94
- """
95
- mints = fetcher.db_client.execute(query)
96
- print(f"Fetched {len(mints)} tokens.")
97
-
98
- # Metrics to collect
99
- lifespans = [] # Time from mint to last trade
100
- time_to_ath = [] # Time from mint to highest price
101
-
102
- # Sequence Length estimations
103
- windows_to_test = [5, 10, 30, 60] # Minutes
104
- event_counts = {w: [] for w in windows_to_test}
105
- full_history_counts = []
106
-
107
- print(f"Analyzing trades for {len(mints)} tokens...")
108
-
109
- for mint_addr, mint_ts in tqdm(mints):
110
- try:
111
- if isinstance(mint_ts, datetime) and mint_ts.tzinfo is None:
112
- mint_ts = mint_ts.replace(tzinfo=timezone.utc)
113
- t0 = mint_ts.timestamp()
114
-
115
- # Fetch ALL trades for this token
116
- # We don't need full enrichments, just timestamp and price
117
- # Args: token_addr, T_cutoff, count_threshold, early_lim, recent_lim, full_history
118
- now_ts = datetime.now(timezone.utc)
119
- trades, _, _ = fetcher.fetch_trades_for_token(mint_addr, now_ts, 0, 0, 0, full_history=True)
120
-
121
- if not trades: continue
122
-
123
- # Trades are usually sorted, but ensure
124
- trades.sort(key=lambda x: x['timestamp'])
125
-
126
- # Lifespan
127
- last_ts = trades[-1]['timestamp'].timestamp()
128
- lifespans.append(last_ts - t0)
129
-
130
- # Time to ATH
131
- max_price = -1.0
132
- ath_ts = 0.0
133
-
134
- valid_trades = []
135
- for t in trades:
136
- p = float(t.get('price_usd', 0.0))
137
- # Basic filter for garbage prints
138
- if p > 0:
139
- valid_trades.append(t)
140
- if p > max_price:
141
- max_price = p
142
- ath_ts = t['timestamp'].timestamp()
143
-
144
- if max_price > 0:
145
- time_to_ath.append(ath_ts - t0)
146
-
147
- # --- Sequence Length Metrics ---
148
- full_history_counts.append(len(valid_trades))
149
-
150
- # Windowed counts
151
- counts_in_window = {w: 0 for w in windows_to_test}
152
-
153
- for t in valid_trades:
154
- ts_val = t['timestamp'].timestamp()
155
- elapsed_min = (ts_val - t0) / 60.0
156
-
157
- for w in windows_to_test:
158
- if elapsed_min <= w:
159
- counts_in_window[w] += 1
160
-
161
- for w in windows_to_test:
162
- event_counts[w].append(counts_in_window[w])
163
-
164
- except Exception as e:
165
- print(f"Error processing {mint_addr}: {e}")
166
- import traceback
167
- traceback.print_exc()
168
- pass
169
-
170
- # --- Stats Calculation ---
171
- def print_stats(name, data):
172
- if not data:
173
- print(f"{name}: No Data")
174
  return
175
- # Convert to numpy array for easier filtering if needed, though they are lists
176
- arr = np.array(data)
177
- p25 = np.percentile(arr, 25)
178
- p50 = np.percentile(arr, 50)
179
- p75 = np.percentile(arr, 75)
180
- p90 = np.percentile(arr, 90)
181
- p95 = np.percentile(arr, 95)
182
- p99 = np.percentile(arr, 99)
183
- max_val = np.max(arr)
184
- print(f"[{name}]")
185
- print(f" Mean: {np.mean(arr):.2f} | Median: {p50:.2f} | Max: {max_val:.2f}")
186
- print(f" 25%: {p25:.2f} | 75%: {p75:.2f} | 90%: {p90:.2f} | 95%: {p95:.2f} | 99%: {p99:.2f}")
187
-
188
- print("\n" + "="*40)
189
- print("RESULTS (ALL TOKENS)")
190
- print("="*40)
191
-
192
- # Time Stats
193
- lifespans_min = [x/60.0 for x in lifespans]
194
- time_to_ath_min = [x/60.0 for x in time_to_ath]
195
-
196
- print_stats("Token Lifespan (Minutes)", lifespans_min)
197
- print("\n")
198
- print_stats("Time to ATH (Minutes)", time_to_ath_min)
199
-
200
- print("\n" + "-"*20)
201
- print("SEQUENCE LENGTHS (Trades Only)")
202
- print("-"*20)
203
-
204
- print_stats("Full History Length", full_history_counts)
205
-
206
- for w in windows_to_test:
207
- print("\n")
208
- print_stats(f"Trades in First {w} Minutes", event_counts[w])
209
-
210
- # --- High Activity Subset ---
211
- print("\n" + "="*40)
212
- print("RESULTS (HIGH ACTIVITY SUBSET)")
213
- print("Filter: > 50 trades AND > 5 min lifespan")
214
- print("="*40)
215
-
216
- # Filter indices
217
- valid_indices = []
218
- for i, count in enumerate(full_history_counts):
219
- if count > 50 and lifespans_min[i] > 5.0:
220
- valid_indices.append(i)
221
-
222
- if not valid_indices:
223
- print("No high activity tokens found.")
224
- else:
225
- print(f"Found {len(valid_indices)} high activity tokens out of {len(full_history_counts)}.")
226
-
227
- subset_lifespans = [lifespans_min[i] for i in valid_indices]
228
- subset_ath = [time_to_ath_min[i] for i in valid_indices if i < len(time_to_ath_min)] # careful with length if sizes differ? they shouldn't by logic, but time_to_ath depends on if trade > 0
229
-
230
- # indices are aligned with loop order
231
- # But wait, time_to_ath was appended only if max_price > 0.
232
- # This misalignment is risky.
233
-
234
- # Better: Store dicts or tuples in the main loop instead of parallel lists.
235
- # Quick fix: Just recalc stats on lists is hard if not aligned?
236
- # Actually time_to_ath might be shorter than lifespans.
237
- # Let's just print what we can, assuming simple filtering on `event_counts` which aligns 1:1 with loop (except exceptions).
238
-
239
- # Re-collect logic for subsets is cleaner if we store objects.
240
- # But let's just do Event Counts which are critical for seq_len.
241
-
242
- subset_history = [full_history_counts[i] for i in valid_indices]
243
- print_stats("Subset: Full History Length", subset_history)
244
-
245
- for w in windows_to_test:
246
- subset_w = [event_counts[w][i] for i in valid_indices]
247
- print("\n")
248
- print_stats(f"Subset: Trades in First {w} Min", subset_w)
249
-
250
- print("\nRecommendation Logic:")
251
- print("1. Horizons: Look at 'Time to ATH' p90 (or p90 of Subset).")
252
- print("2. Max Seq Len: Look at 'Trades in First X Minutes' (X ~= Max Horizon).")
253
 
254
  if __name__ == "__main__":
255
  main()
 
1
  import os
 
 
 
2
  import argparse
3
+ from typing import List, Optional, Sequence, Tuple
 
 
4
 
5
+ from dotenv import load_dotenv
6
+ from clickhouse_driver import Client as ClickHouseClient
7
 
 
 
 
8
 
9
+ def parse_args() -> argparse.Namespace:
10
+ parser = argparse.ArgumentParser(
11
+ description="Fast SQL-based hyperparameter analysis (trades-only) for seq_len + horizons."
12
+ )
13
+ parser.add_argument("--token_address", type=str, default=None, help="Analyze a single token address.")
14
+ parser.add_argument(
15
+ "--windows_min",
16
+ type=str,
17
+ default="5,10,30,60",
18
+ help="Comma-separated trade-count windows in minutes (e.g. '5,10,30,60').",
19
+ )
20
+ parser.add_argument(
21
+ "--min_price_usd",
22
+ type=float,
23
+ default=0.0,
24
+ help="Treat trades with price_usd <= min_price_usd as invalid (default: 0.0).",
25
+ )
26
+ return parser.parse_args()
27
 
 
 
 
28
 
29
+ def _parse_windows(windows_min: str) -> List[int]:
30
+ out: List[int] = []
31
+ for part in (windows_min or "").split(","):
32
+ part = part.strip()
33
+ if not part:
34
+ continue
35
+ out.append(int(part))
36
+ out = sorted(set([w for w in out if w > 0]))
37
+ if not out:
38
+ raise ValueError("No valid --windows_min provided.")
39
+ return out
40
 
 
 
 
41
 
42
+ def _connect_clickhouse_from_env() -> ClickHouseClient:
43
+ ch_host = os.getenv("CLICKHOUSE_HOST", "localhost")
44
+ ch_port = int(os.getenv("CLICKHOUSE_NATIVE_PORT", "9000"))
45
+ ch_user = os.getenv("CLICKHOUSE_USER", None)
46
+ ch_pass = os.getenv("CLICKHOUSE_PASSWORD", None)
47
+ ch_db = os.getenv("CLICKHOUSE_DB", None)
48
+
49
+ kwargs = {"host": ch_host, "port": ch_port}
50
+ if ch_user:
51
+ kwargs["user"] = ch_user
52
+ if ch_pass:
53
+ kwargs["password"] = ch_pass
54
+ if ch_db:
55
+ kwargs["database"] = ch_db
56
+ return ClickHouseClient(**kwargs)
57
+
58
+
59
+ def _quantile_levels() -> Sequence[float]:
60
+ # Keep these aligned with the printed labels below.
61
+ return (0.25, 0.5, 0.75, 0.90, 0.95, 0.99)
62
+
63
+
64
+ def _fmt_q_tuple(q: Tuple[float, ...]) -> str:
65
+ # Labels match _quantile_levels()
66
+ labels = ["25%", "50%", "75%", "90%", "95%", "99%"]
67
+ parts = []
68
+ for lbl, v in zip(labels, q):
69
+ parts.append(f"{lbl}: {float(v):.2f}")
70
+ return " | ".join(parts)
71
+
72
+
73
+ def _print_row(prefix: str, mean_v: float, q_tuple: Tuple[float, ...], max_v: float) -> None:
74
+ print(f"[{prefix}]")
75
+ print(f" Mean: {float(mean_v):.2f} | Median: {float(q_tuple[1]):.2f} | Max: {float(max_v):.2f}")
76
+ print(f" {_fmt_q_tuple(q_tuple)}")
77
+
78
+
79
def fetch_aggregated_stats_sql(
    ch: ClickHouseClient,
    windows_min: List[int],
    min_price_usd: float,
    token_address: Optional[str] = None,
) -> List[tuple]:
    """
    Compute distribution statistics in one ClickHouse query (no per-token loop in Python).

    A ``per_token`` CTE joins ``mints`` to ``trades`` and derives, per mint
    address: the number of valid trades, the lifespan (last valid trade minus
    mint time), the time to the highest-priced valid trade, and the number of
    valid trades inside each requested window. Tokens with no valid trade are
    dropped. The outer query aggregates those rows into two groups:
      - grp='all'
      - grp='subset' where trades_full > 50 and lifespan_sec > 300 (5 minutes)

    Args:
        ch: Client whose ``execute(query, params)`` runs the statement.
        windows_min: Window lengths in minutes; each adds a ``trades_{w}m``
            metric counting valid trades with ``trade_ts - mint_ts <= w * 60``.
            May be empty.
        min_price_usd: A trade is valid when ``price_usd`` is strictly greater
            than this value.
        token_address: When given, restricts the statistics to that mint address.

    Returns:
        The result of ``ch.execute``: one row per group, ordered by ``grp``.
        Columns are ``grp, tokens`` followed by ``(mean, quantiles, max)``
        triples for trades_full, lifespan in minutes, time-to-ATH in minutes,
        and then each window in the order given.
    """
    q_levels_sql = ", ".join(str(q) for q in _quantile_levels())
    # Coerce once so the seconds bound and the column alias always agree.
    windows = [int(w) for w in windows_min]

    def _summary_cols(expr: str, alias: str) -> List[str]:
        """Return the mean / quantiles / max column triple for one metric."""
        return [
            f"avg({expr}) AS {alias}_mean",
            f"quantilesExact({q_levels_sql})({expr}) AS {alias}_q",
            f"max({expr}) AS {alias}_max",
        ]

    per_token_cols = [
        "m.mint_address AS mint_address",
        "toUnixTimestamp(m.timestamp) AS mint_ts",
        "countIf(is_valid) AS trades_full",
        "(maxIf(trade_ts, is_valid) - mint_ts) AS lifespan_sec",
        "(toUnixTimestamp(argMaxIf(t.timestamp, t.price_usd, is_valid)) - mint_ts) AS time_to_ath_sec",
    ]
    per_token_cols.extend(
        f"countIf(is_valid AND (trade_ts - mint_ts) <= {w * 60}) AS trades_{w}m"
        for w in windows
    )

    agg_cols = ["grp", "count() AS tokens"]
    agg_cols += _summary_cols("trades_full", "trades_full")
    agg_cols += _summary_cols("lifespan_sec / 60.0", "lifespan_min")
    agg_cols += _summary_cols("time_to_ath_sec / 60.0", "tta_min")
    for w in windows:
        agg_cols += _summary_cols(f"trades_{w}m", f"trades_{w}m")

    params = {"min_price": float(min_price_usd)}
    if token_address:
        params["token"] = token_address
        mint_filter = "AND m.mint_address = %(token)s"
        # Push the token filter into the trades subquery too, so a single-token
        # request does not read the trades of every minted token.
        trades_filter = "base_address = %(token)s"
    else:
        mint_filter = ""
        # Note: we pre-filter trades to only minted tokens for speed.
        trades_filter = "base_address IN (SELECT mint_address FROM mints)"

    # Column lists are joined here (not interpolated after a hard-coded comma)
    # so an empty windows_min cannot leave a dangling comma before FROM.
    per_token_sep = ",\n            "
    agg_sep = ",\n        "
    query = f"""
    WITH
    per_token AS (
        SELECT
            {per_token_sep.join(per_token_cols)}
        FROM mints AS m
        INNER JOIN
        (
            SELECT
                base_address,
                timestamp,
                toUnixTimestamp(timestamp) AS trade_ts,
                price_usd,
                (price_usd > %(min_price)s) AS is_valid
            FROM trades
            WHERE {trades_filter}
        ) AS t
        ON t.base_address = m.mint_address
        WHERE 1=1
        {mint_filter}
        GROUP BY
            mint_address,
            mint_ts
        HAVING
            trades_full > 0
    )
    SELECT
        {agg_sep.join(agg_cols)}
    FROM per_token
    ARRAY JOIN ['all', 'subset'] AS grp
    WHERE (grp = 'all')
       OR (grp = 'subset' AND trades_full > 50 AND lifespan_sec > 300)
    GROUP BY grp
    ORDER BY grp
    """

    return ch.execute(query, params)
171
+
172
+
173
def fetch_single_token_sql(
    ch: "ClickHouseClient",
    windows_min: List[int],
    min_price_usd: float,
    token_address: str,
) -> Optional[tuple]:
    """Fetch the per-token trade statistics for a single mint address.

    Joins ``mints`` to that token's ``trades`` and computes the number of valid
    trades, the lifespan (last valid trade minus mint time), the time to the
    highest-priced valid trade, and the number of valid trades inside each
    requested window.

    Args:
        ch: Client whose ``execute(query, params)`` runs the statement.
        windows_min: Window lengths in minutes; each adds a ``trades_{w}m``
            column counting valid trades with ``trade_ts - mint_ts <= w * 60``.
            May be empty.
        min_price_usd: A trade is valid when ``price_usd`` is strictly greater
            than this value.
        token_address: Mint address to look up.

    Returns:
        The first result row, laid out as ``(mint_address, mint_ts,
        trades_full, lifespan_sec, time_to_ath_sec, trades_{w}m...)``, or
        ``None`` when the query returns no rows (the HAVING clause drops a
        token with no valid trades).
    """
    # Coerce once so the seconds bound and the column alias always agree.
    windows = [int(w) for w in windows_min]

    select_cols = [
        "m.mint_address AS mint_address",
        "toUnixTimestamp(m.timestamp) AS mint_ts",
        "countIf(is_valid) AS trades_full",
        "(maxIf(trade_ts, is_valid) - mint_ts) AS lifespan_sec",
        "(toUnixTimestamp(argMaxIf(t.timestamp, t.price_usd, is_valid)) - mint_ts) AS time_to_ath_sec",
    ]
    select_cols.extend(
        f"countIf(is_valid AND (trade_ts - mint_ts) <= {w * 60}) AS trades_{w}m"
        for w in windows
    )

    params = {"min_price": float(min_price_usd), "token": token_address}
    # The column list is joined here (not interpolated after a hard-coded
    # comma) so an empty windows_min cannot leave a dangling comma before FROM.
    col_sep = ",\n        "
    query = f"""
    SELECT
        {col_sep.join(select_cols)}
    FROM mints AS m
    INNER JOIN
    (
        SELECT
            base_address,
            timestamp,
            toUnixTimestamp(timestamp) AS trade_ts,
            price_usd,
            (price_usd > %(min_price)s) AS is_valid
        FROM trades
        WHERE base_address = %(token)s
    ) AS t
    ON t.base_address = m.mint_address
    WHERE m.mint_address = %(token)s
    GROUP BY
        mint_address,
        mint_ts
    HAVING
        trades_full > 0
    """
    rows = ch.execute(query, params)
    return rows[0] if rows else None
217
+
218
+
219
def main() -> None:
    """Run the calibration report for one token or for the whole token distribution.

    Parses the CLI arguments, connects to ClickHouse, and then either prints
    the statistics of the token given by ``args.token_address`` or, when no
    token is given, prints the aggregated distribution for every group
    returned by ``fetch_aggregated_stats_sql``.
    """
    load_dotenv()
    args = parse_args()
    windows_min = _parse_windows(args.windows_min)

    print("--- Hyperparameter Calibration Analysis (FAST SQL) ---")
    print(f"Windows (min): {windows_min}")
    min_price_usd = float(args.min_price_usd)
    print(f"Valid trade filter: price_usd > {min_price_usd}")

    ch = _connect_clickhouse_from_env()
    banner = "=" * 40

    if args.token_address:
        token_row = fetch_single_token_sql(
            ch=ch,
            windows_min=windows_min,
            min_price_usd=min_price_usd,
            token_address=args.token_address,
        )
        if not token_row:
            print("Token not found (or no valid trades).")
            return

        # Row layout: mint_address, mint_ts, trades_full, lifespan_sec,
        # time_to_ath_sec, then one trade count per window.
        mint_addr = token_row[0]
        trades_full = int(token_row[2])
        lifespan_min = float(token_row[3]) / 60.0
        tta_min = float(token_row[4]) / 60.0
        print("\n" + banner)
        print("RESULTS (SINGLE TOKEN)")
        print(banner)
        print(f"Token: {mint_addr}")
        print(f"Valid trades: {trades_full}")
        print(f"Lifespan (min): {lifespan_min:.2f}")
        print(f"Time to ATH (min): {tta_min:.2f}")
        for offset, w in enumerate(windows_min, start=5):
            print(f"Trades in first {w}m: {int(token_row[offset])}")
        return

    group_rows = fetch_aggregated_stats_sql(
        ch=ch,
        windows_min=windows_min,
        min_price_usd=min_price_usd,
        token_address=None,
    )
    if not group_rows:
        print("No tokens found with valid trades.")
        return

    print("\n" + banner)
    print("RESULTS (DISTRIBUTION)")
    print(banner)

    # Row layout:
    #   grp, tokens,
    #   trades_full_mean, trades_full_q(tuple), trades_full_max,
    #   lifespan_min_mean, lifespan_min_q(tuple), lifespan_min_max,
    #   tta_min_mean, tta_min_q(tuple), tta_min_max,
    #   repeated for each window: mean, q(tuple), max
    fixed_sections = (
        "Trades (Full History, Valid Only)",
        "Token Lifespan (Minutes)",
        "Time to ATH (Minutes)",
    )
    first_window_col = 2 + 3 * len(fixed_sections)
    for stats in group_rows:
        grp = stats[0]
        tokens = int(stats[1])
        print(f"\n--- Group: {grp} (tokens={tokens}) ---")

        for k, title in enumerate(fixed_sections):
            if k:
                print("")
            col = 2 + 3 * k
            _print_row(title, stats[col], stats[col + 1], stats[col + 2])

        for k, w in enumerate(windows_min):
            col = first_window_col + 3 * k
            mean_v = stats[col]
            q_v = stats[col + 1]
            max_v = stats[col + 2]
            print("")
            _print_row(f"Trades in First {w} Minutes (Valid Only)", mean_v, q_v, max_v)

    # NOTE(review): the recommendation text is assumed to belong to the
    # distribution run only (it refers to "all vs subset") — confirm.
    print("\nRecommendation Logic (Trades-only):")
    print("- Horizons: look at Time-to-ATH p90/p95 (all vs subset).")
    print("- Max seq len: look at Trades-in-first-(max horizon) p95/p99.")
    print(" Then add headroom for non-trade events (transfers/pool/liquidity/etc).")
298
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
  if __name__ == "__main__":
301
  main()
scripts/cache_dataset.py CHANGED
@@ -309,19 +309,17 @@ def main():
309
  n_burns = len(item.get("burns", []))
310
  n_supply_locks = len(item.get("supply_locks", []))
311
  n_migrations = len(item.get("migrations", []))
 
312
  n_ohlc = len(item.get("ohlc_1s", [])) if item.get("ohlc_1s") is not None else 0
313
  n_snapshots_5m = len(item.get("snapshots_5m", []))
314
  n_holders = len(item.get("holder_snapshots_list", []))
315
 
316
- tqdm.write(f" + Cached: {mint_addr} | Class: {class_id} | Q: {q_score:.4f}")
317
  tqdm.write(
318
- " Events | "
319
- f"Trades: {n_trades} | Transfers: {n_transfers} | Pool Creations: {n_pool_creations} | "
320
- f"Liquidity Changes: {n_liquidity_changes} | Fee Collections: {n_fee_collections} | "
321
- f"Burns: {n_burns} | Supply Locks: {n_supply_locks} | Migrations: {n_migrations}"
322
- )
323
- tqdm.write(
324
- f" Derived | Mint: 1 | Ohlc 1s: {n_ohlc} | Snapshots 5m: {n_snapshots_5m} | Holder Snapshots: {n_holders}"
325
  )
326
 
327
  except Exception as e:
 
309
  n_burns = len(item.get("burns", []))
310
  n_supply_locks = len(item.get("supply_locks", []))
311
  n_migrations = len(item.get("migrations", []))
312
+ n_mints = 1 if item.get("mint_timestamp") else 0
313
  n_ohlc = len(item.get("ohlc_1s", [])) if item.get("ohlc_1s") is not None else 0
314
  n_snapshots_5m = len(item.get("snapshots_5m", []))
315
  n_holders = len(item.get("holder_snapshots_list", []))
316
 
 
317
  tqdm.write(
318
+ f" + Cached: {mint_addr} | Class: {class_id} | Q: {q_score:.4f} | "
319
+ f"Events: Mint {n_mints}, Trades {n_trades}, Transfers {n_transfers}, Pool Creations {n_pool_creations}, "
320
+ f"Liquidity Changes {n_liquidity_changes}, Fee Collections {n_fee_collections}, "
321
+ f"Burns {n_burns}, Supply Locks {n_supply_locks}, Migrations {n_migrations} | "
322
+ f"Derived: Ohlc 1s {n_ohlc}, Snapshots 5m {n_snapshots_5m}, Holder Snapshots {n_holders}"
 
 
323
  )
324
 
325
  except Exception as e:
train.py CHANGED
@@ -339,15 +339,21 @@ def main() -> None:
339
  else:
340
  logger.info("INFO: Weights found but shuffle=False. Ignoring weights (sequential mode).")
341
 
342
- dataloader = DataLoader(
343
- dataset,
344
  batch_size=batch_size,
345
  shuffle=shuffle,
346
  sampler=sampler,
347
  num_workers=int(args.num_workers),
348
  pin_memory=bool(args.pin_memory),
349
- collate_fn=functools.partial(filtered_collate, collator)
350
  )
 
 
 
 
 
 
351
 
352
  # --- 3. Model Init ---
353
  logger.info("Initializing Oracle Model...")
@@ -361,16 +367,16 @@ def main() -> None:
361
  multi_modal_dim=multi_modal_encoder.embedding_dim,
362
  event_pad_id=vocab.EVENT_TO_ID["__PAD__"],
363
  event_type_to_id=vocab.EVENT_TO_ID,
364
- model_config_name="Qwen/Qwen3-0.6B",
365
  quantiles=quantiles,
366
  horizons_seconds=horizons,
367
  dtype=init_dtype
368
  )
369
 
370
- # Memory Optimization: Delete unused embedding layer from Qwen backbone
371
  if hasattr(model.model, 'embed_tokens'):
372
  del model.model.embed_tokens
373
- logger.info("Freed unused Qwen embedding layer memory.")
374
 
375
  # --- 4. Optimizer & Scheduler ---
376
  optimizer = AdamW(model.parameters(), lr=learning_rate)
 
339
  else:
340
  logger.info("INFO: Weights found but shuffle=False. Ignoring weights (sequential mode).")
341
 
342
+ dl_kwargs = dict(
343
+ dataset=dataset,
344
  batch_size=batch_size,
345
  shuffle=shuffle,
346
  sampler=sampler,
347
  num_workers=int(args.num_workers),
348
  pin_memory=bool(args.pin_memory),
349
+ collate_fn=functools.partial(filtered_collate, collator),
350
  )
351
+ if int(args.num_workers) > 0:
352
+ # Keeps workers alive across epochs. Otherwise each epoch respawns workers and
353
+ # re-initializes heavy per-worker state (e.g. SigLIP MultiModalEncoder).
354
+ dl_kwargs["persistent_workers"] = True
355
+ dl_kwargs["prefetch_factor"] = 2
356
+ dataloader = DataLoader(**dl_kwargs)
357
 
358
  # --- 3. Model Init ---
359
  logger.info("Initializing Oracle Model...")
 
367
  multi_modal_dim=multi_modal_encoder.embedding_dim,
368
  event_pad_id=vocab.EVENT_TO_ID["__PAD__"],
369
  event_type_to_id=vocab.EVENT_TO_ID,
370
+ model_config_name="llama3-12l-768d-gqa4-8k-random",
371
  quantiles=quantiles,
372
  horizons_seconds=horizons,
373
  dtype=init_dtype
374
  )
375
 
376
+ # Memory optimization: embedding layer isn't used when providing inputs_embeds.
377
  if hasattr(model.model, 'embed_tokens'):
378
  del model.model.embed_tokens
379
+ logger.info("Freed unused backbone embedding layer memory.")
380
 
381
  # --- 4. Optimizer & Scheduler ---
382
  optimizer = AdamW(model.parameters(), lr=learning_rate)