zirobtc committed on
Commit
88bc904
·
1 Parent(s): 85e02a7

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/cache_dataset.py +289 -208
scripts/cache_dataset.py CHANGED
@@ -12,6 +12,8 @@ from tqdm import tqdm
12
  from dotenv import load_dotenv
13
  import huggingface_hub
14
  import logging
 
 
15
 
16
  # Suppress noisy libraries
17
  logging.getLogger("httpx").setLevel(logging.WARNING)
@@ -21,8 +23,6 @@ logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
21
  # Add parent directory to path to import modules
22
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
23
 
24
- from data.data_loader import OracleDataset
25
- from data.data_fetcher import DataFetcher
26
  from scripts.analyze_distribution import get_return_class_map
27
  # Import quality score calculator
28
  from scripts.compute_quality_score import get_token_quality_scores, fetch_token_metrics, _bucket_id, _midrank_percentiles, EPS
@@ -30,18 +30,161 @@ from scripts.compute_quality_score import get_token_quality_scores, fetch_token_
30
  from clickhouse_driver import Client as ClickHouseClient
31
  from neo4j import GraphDatabase
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def compute_save_ohlc_stats(client: ClickHouseClient, output_path: str):
34
  """
35
  Computes global mean/std for price/volume from ClickHouse and saves to .npz
36
  This allows the dataset loader to normalize inputs correctly.
37
  """
38
  print(f"INFO: Computing OHLC stats (mean/std) from ClickHouse...")
39
-
40
- # Query matching preprocess_distribution.py logic
41
- # We use hardcoded min_price/vol filters to avoid skewing stats with dust
42
  min_price = 0.0
43
  min_vol = 0.0
44
-
45
  query = """
46
  SELECT
47
  AVG(t.price_usd) AS mean_price_usd,
@@ -53,9 +196,9 @@ def compute_save_ohlc_stats(client: ClickHouseClient, output_path: str):
53
  FROM trades AS t
54
  WHERE t.price_usd > %(min_price)s AND t.total_usd > %(min_vol)s
55
  """
56
-
57
  params = {"min_price": min_price, "min_vol": min_vol}
58
-
59
  try:
60
  result = client.execute(query, params=params)
61
  if not result or not result[0]:
@@ -67,10 +210,9 @@ def compute_save_ohlc_stats(client: ClickHouseClient, output_path: str):
67
  }
68
  else:
69
  row = result[0]
70
- # Handle potential None values if DB is empty
71
  def safe_float(x, default=0.0):
72
  return float(x) if x is not None else default
73
-
74
  def safe_std(x):
75
  val = safe_float(x, 1.0)
76
  return val if val > 1e-9 else 1.0
@@ -83,29 +225,24 @@ def compute_save_ohlc_stats(client: ClickHouseClient, output_path: str):
83
  "mean_trade_value_usd": safe_float(row[4]),
84
  "std_trade_value_usd": safe_std(row[5]),
85
  }
86
-
87
- # Save to NPZ
88
  out_p = Path(output_path)
89
  out_p.parent.mkdir(parents=True, exist_ok=True)
90
  np.savez(out_p, **stats)
91
-
92
  print(f"INFO: Saved OHLC stats to {out_p}")
93
  for k, v in stats.items():
94
  print(f" {k}: {v:.4f}")
95
-
96
  except Exception as e:
97
  print(f"ERROR: Failed to compute OHLC stats: {e}")
98
- # Don't crash, let it try to proceed (though dataset might complain if file missing)
99
 
100
  def build_quality_missing_reason_map(client: ClickHouseClient, max_ret: float = 1e9):
101
- """
102
- Build a map: token_address -> reason string for why a quality score is missing.
103
- This mirrors compute_quality_scores filtering and feature availability.
104
- """
105
  data = fetch_token_metrics(client)
106
  metrics_by_token = {d.get("token_address"): d for d in data if d.get("token_address")}
107
 
108
- # Build buckets with the same return filtering as compute_quality_scores
109
  buckets = {}
110
  for d in data:
111
  ret_val = d.get("ret")
@@ -117,7 +254,6 @@ def build_quality_missing_reason_map(client: ClickHouseClient, max_ret: float =
117
  d["bucket_id"] = b
118
  buckets.setdefault(b, []).append(d)
119
 
120
- # Same feature definitions as compute_quality_scores
121
  feature_defs = [
122
  ("fees_log", lambda d: math.log1p(d["fees_sol"]) if d.get("fees_sol") is not None else None, True),
123
  ("volume_log", lambda d: math.log1p(d["volume_usd"]) if d.get("volume_usd") is not None else None, True),
@@ -132,7 +268,6 @@ def build_quality_missing_reason_map(client: ClickHouseClient, max_ret: float =
132
  ("dev_hold_pct", lambda d: d.get("dev_hold_pct"), True),
133
  ]
134
 
135
- # Precompute percentiles per bucket + feature
136
  bucket_feature_percentiles = {}
137
  for b, items in buckets.items():
138
  feature_percentiles = {}
@@ -149,10 +284,10 @@ def build_quality_missing_reason_map(client: ClickHouseClient, max_ret: float =
149
  def _reason_for(token_address: str) -> str:
150
  d = metrics_by_token.get(token_address)
151
  if not d:
152
- return "no metrics found (missing from token_metrics/trades/mints joins)"
153
  ret_val = d.get("ret")
154
  if ret_val is None:
155
- return "ret is None (missing ATH/launch metrics)"
156
  if ret_val <= 0:
157
  return f"ret <= 0 ({ret_val})"
158
  if ret_val > max_ret:
@@ -160,27 +295,17 @@ def build_quality_missing_reason_map(client: ClickHouseClient, max_ret: float =
160
  b = _bucket_id(ret_val)
161
  if b == -1:
162
  return f"ret {ret_val} not in RETURN_THRESHOLDS"
163
- items = buckets.get(b, [])
164
- if not items:
165
- return f"bucket {b} empty after filtering"
166
- feature_percentiles = bucket_feature_percentiles.get(b, {})
167
- has_any = False
168
- missing_features = []
169
- for fname, _fget, _pos in feature_defs:
170
- if feature_percentiles.get(fname, {}).get(token_address) is None:
171
- missing_features.append(fname)
172
- else:
173
- has_any = True
174
- if not has_any:
175
- return "no valid feature percentiles for token (all features missing/invalid)"
176
- return f"unexpected: has feature percentiles but no score; missing features={','.join(missing_features)}"
177
 
178
  return _reason_for
179
 
 
180
  def main():
181
  load_dotenv()
182
-
183
- # Explicit Login
 
 
184
  hf_token = os.getenv("HF_TOKEN")
185
  if hf_token:
186
  print(f"INFO: Logging in to Hugging Face with token starting with: {hf_token[:4]}...")
@@ -195,15 +320,19 @@ def main():
195
  parser.add_argument("--ohlc_stats_path", type=str, default="data/ohlc_stats.npz")
196
  parser.add_argument("--min_trade_usd", type=float, default=0.0)
197
 
198
- # NEW: Context caching mode args
199
  parser.add_argument("--cache_mode", type=str, default="raw", choices=["raw", "context"],
200
- help="Cache mode: 'raw' caches raw token data (old behavior), 'context' caches fully processed training contexts (new behavior)")
201
  parser.add_argument("--context_length", type=int, default=8192,
202
- help="Max sequence length for context caching mode. Triggers H/B/H dynamic sampling when events exceed this limit.")
203
  parser.add_argument("--min_trades", type=int, default=10,
204
- help="Minimum number of trades required for T_cutoff sampling. Tokens with fewer trades are skipped.")
205
  parser.add_argument("--samples_per_token", type=int, default=1,
206
- help="Number of different T_cutoff samples to generate per token in context mode.")
 
 
 
 
207
 
208
  # DB Args
209
  parser.add_argument("--clickhouse_host", type=str, default=os.getenv("CLICKHOUSE_HOST", "localhost"))
@@ -213,36 +342,40 @@ def main():
213
  parser.add_argument("--neo4j_password", type=str, default=os.getenv("NEO4J_PASSWORD", "password"))
214
 
215
  args = parser.parse_args()
216
-
 
 
 
 
217
  output_dir = Path(args.output_dir)
218
  output_dir.mkdir(parents=True, exist_ok=True)
219
-
220
  start_date_dt = None
221
  if args.start_date:
222
  start_date_dt = datetime.datetime.strptime(args.start_date, "%Y-%m-%d")
223
-
224
  print(f"INFO: Initializing DB Connections...")
225
  clickhouse_client = ClickHouseClient(host=args.clickhouse_host, port=args.clickhouse_port)
226
  neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=(args.neo4j_user, args.neo4j_password))
227
-
228
  try:
229
  # --- 1. Compute OHLC Stats (Global) ---
230
  compute_save_ohlc_stats(clickhouse_client, args.ohlc_stats_path)
231
 
232
- # --- 2. Initialize DataFetcher and OracleDataset ---
 
 
 
233
  data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
234
-
235
- # Pre-fetch the Return Class Map
236
  print("INFO: Fetching Return Classification Map...")
237
  return_class_map, thresholds = get_return_class_map(clickhouse_client)
238
  print(f"INFO: Loaded {len(return_class_map)} valid classified tokens.")
239
 
240
- # Pre-fetch Quality Scores
241
  print("INFO: Fetching Token Quality Scores...")
242
  quality_scores_map = get_token_quality_scores(clickhouse_client)
243
- quality_missing_reason = build_quality_missing_reason_map(clickhouse_client, max_ret=1e9)
244
  print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
245
-
246
  dataset = OracleDataset(
247
  data_fetcher=data_fetcher,
248
  max_samples=args.max_samples,
@@ -251,19 +384,18 @@ def main():
251
  horizons_seconds=[60, 180, 300, 600, 1800, 3600, 7200],
252
  quantiles=[0.5],
253
  min_trade_usd=args.min_trade_usd,
254
- max_seq_len=args.context_length # Pass context_length for H/B/H threshold
255
  )
256
-
257
  if len(dataset) == 0:
258
  print("WARNING: Dataset initialization resulted in 0 samples. Nothing to cache.")
259
  return
260
 
261
- # --- FILTER DATASET BY CLASS MAP ---
262
- # Only keep mints that are classified (valid return, sufficient data)
263
  original_size = len(dataset)
264
  print(f"INFO: Filtering dataset... Original size: {original_size}")
265
  dataset.sampled_mints = [
266
- m for m in dataset.sampled_mints
267
  if m['mint_address'] in return_class_map
268
  ]
269
  filtered_size = len(dataset)
@@ -274,156 +406,101 @@ def main():
274
  print("WARNING: No tokens remain after filtering by return_class_map.")
275
  return
276
 
277
- # --- 3. Iterate and cache based on mode ---
278
  print(f"INFO: Cache mode: {args.cache_mode}")
279
- print(f"INFO: Starting to generate and cache from {len(dataset)} tokens...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
 
 
281
  skipped_count = 0
282
- cached_count = 0
283
- global_sample_idx = 0 # Global counter for unique sample filenames
284
-
285
- # Track class distribution for balanced sampling metadata
286
  class_distribution = {}
287
 
288
- if args.cache_mode == "context":
289
- # =========================================================================
290
- # CONTEXT MODE: Cache fully processed training contexts
291
- # - Samples T_cutoff during caching (non-deterministic moved to cache time)
292
- # - Applies H/B/H dynamic sampling based on context_length
293
- # - Avoids caching tokens that won't be seen (garbage filtered out)
294
- # - Training becomes fully deterministic (just loads cached contexts)
295
- # =========================================================================
296
- print(f"INFO: Context mode settings:")
297
- print(f" - context_length (H/B/H threshold): {args.context_length}")
298
- print(f" - min_trades (T_cutoff threshold): {args.min_trades}")
299
- print(f" - samples_per_token: {args.samples_per_token}")
300
-
301
- for i in tqdm(range(len(dataset)), desc="Caching contexts"):
302
- mint_addr = dataset.sampled_mints[i]['mint_address']
303
- class_id = return_class_map[mint_addr]
304
-
305
- try:
306
- # Generate multiple training contexts per token
307
- contexts = dataset.__cacheitem_context__(i, num_samples_per_token=args.samples_per_token)
308
-
309
- if not contexts:
310
- skipped_count += 1
311
- continue
312
-
313
- # Require quality score
314
- if mint_addr not in quality_scores_map:
315
- reason = quality_missing_reason(mint_addr)
316
- raise RuntimeError(
317
- f"Missing quality score for mint {mint_addr}. Reason: {reason}."
318
- )
319
- q_score = quality_scores_map[mint_addr]
320
-
321
- # Save each context as a separate sample
322
- for ctx in contexts:
323
- ctx["quality_score"] = q_score
324
- ctx["class_id"] = class_id
325
- ctx["source_token"] = mint_addr # Track origin for debugging
326
- ctx["cache_mode"] = "context"
327
-
328
- filename = f"sample_{global_sample_idx}.pt"
329
- output_path = output_dir / filename
330
- torch.save(ctx, output_path)
331
-
332
- # Track class distribution
333
- class_distribution[class_id] = class_distribution.get(class_id, 0) + 1
334
-
335
- global_sample_idx += 1
336
- cached_count += 1
337
-
338
- n_events = len(contexts[0].get("event_sequence", [])) if contexts else 0
339
- tqdm.write(
340
- f" + Cached {len(contexts)} contexts: {mint_addr} | Class: {class_id} | Q: {q_score:.4f} | Events: {n_events}"
341
- )
342
-
343
- except Exception as e:
344
- error_msg = str(e)
345
- if "FATAL" in error_msg or "AuthenticationRateLimit" in error_msg:
346
- print(f"\nCRITICAL: Fatal error processing sample {i}. Stopping.\nError: {e}", file=sys.stderr)
347
- sys.exit(1)
348
-
349
- print(f"\nERROR: Failed to cache contexts for {mint_addr}. Error: {e}", file=sys.stderr)
350
- import traceback
351
- traceback.print_exc()
352
  skipped_count += 1
353
- continue
354
-
 
355
  else:
356
- # =========================================================================
357
- # RAW MODE: Cache raw token data (original behavior)
358
- # - T_cutoff sampling happens at runtime
359
- # - H/B/H applied at runtime
360
- # - Non-deterministic training
361
- # =========================================================================
362
- for i in tqdm(range(len(dataset)), desc="Caching raw samples"):
363
- mint_addr = dataset.sampled_mints[i]['mint_address']
364
- class_id = return_class_map[mint_addr]
365
-
366
- try:
367
- item = dataset.__cacheitem__(i)
368
- if item is None:
369
- skipped_count += 1
370
- continue
371
-
372
- if mint_addr not in quality_scores_map:
373
- reason = quality_missing_reason(mint_addr)
374
- raise RuntimeError(
375
- f"Missing quality score for mint {mint_addr}. Reason: {reason}."
376
- )
377
- q_score = quality_scores_map[mint_addr]
378
-
379
- item["quality_score"] = q_score
380
- item["class_id"] = class_id
381
- item["cache_mode"] = "raw"
382
-
383
- filename = f"sample_{i}.pt"
384
- output_path = output_dir / filename
385
- torch.save(item, output_path)
386
-
387
- # Track class distribution
388
- class_distribution[class_id] = class_distribution.get(class_id, 0) + 1
389
-
390
- cached_count += 1
391
-
392
- n_trades = len(item.get("trades", []))
393
- n_transfers = len(item.get("transfers", []))
394
- n_pool_creations = len(item.get("pool_creations", []))
395
- n_liquidity_changes = len(item.get("liquidity_changes", []))
396
- n_fee_collections = len(item.get("fee_collections", []))
397
- n_burns = len(item.get("burns", []))
398
- n_supply_locks = len(item.get("supply_locks", []))
399
- n_migrations = len(item.get("migrations", []))
400
- n_mints = 1 if item.get("mint_timestamp") else 0
401
- n_ohlc = len(item.get("ohlc_1s", [])) if item.get("ohlc_1s") is not None else 0
402
- n_snapshots_5m = len(item.get("snapshots_5m", []))
403
- n_holders = len(item.get("holder_snapshots_list", []))
404
-
405
- tqdm.write(
406
- f" + Cached: {mint_addr} | Class: {class_id} | Q: {q_score:.4f} | "
407
- f"Events: Mint {n_mints}, Trades {n_trades}, Transfers {n_transfers}, Pool Creations {n_pool_creations}, "
408
- f"Liquidity Changes {n_liquidity_changes}, Fee Collections {n_fee_collections}, "
409
- f"Burns {n_burns}, Supply Locks {n_supply_locks}, Migrations {n_migrations} | "
410
- f"Derived: Ohlc 1s {n_ohlc}, Snapshots 5m {n_snapshots_5m}, Holder Snapshots {n_holders}"
411
- )
412
-
413
- except Exception as e:
414
- error_msg = str(e)
415
- if "FATAL" in error_msg or "AuthenticationRateLimit" in error_msg:
416
- print(f"\nCRITICAL: Fatal error processing sample {i}. Stopping.\nError: {e}", file=sys.stderr)
417
- sys.exit(1)
418
-
419
- print(f"\nERROR: Failed to cache sample {i} for {mint_addr}. Error: {e}", file=sys.stderr)
420
- import traceback
421
- traceback.print_exc()
422
- skipped_count += 1
423
- continue
424
-
425
- # --- Save class metadata for balanced sampling ---
426
- # Build file_class_map for the metadata cache
427
  file_class_map = {}
428
  for sample_file in sorted(output_dir.glob("sample_*.pt")):
429
  try:
@@ -437,11 +514,12 @@ def main():
437
  with open(metadata_path, 'w') as f:
438
  json.dump({
439
  'file_class_map': file_class_map,
440
- 'class_distribution': class_distribution,
441
  'cache_mode': args.cache_mode,
442
  'context_length': args.context_length if args.cache_mode == "context" else None,
443
  'min_trades': args.min_trades if args.cache_mode == "context" else None,
444
  'samples_per_token': args.samples_per_token if args.cache_mode == "context" else None,
 
445
  }, f, indent=2)
446
  print(f"INFO: Saved class metadata to {metadata_path}")
447
  except Exception as e:
@@ -449,16 +527,19 @@ def main():
449
 
450
  print(f"\n--- Caching Complete ---")
451
  print(f"Cache mode: {args.cache_mode}")
452
- print(f"Successfully cached: {cached_count} samples.")
453
- print(f"Filtered (Invalid/High Return): {filtered_count} tokens.")
454
- print(f"Skipped (Errors/Empty): {skipped_count} tokens.")
 
 
 
455
  print(f"Class distribution: {class_distribution}")
456
  print(f"Cache location: {output_dir.resolve()}")
457
 
458
  finally:
459
- # --- 4. Close connections ---
460
  clickhouse_client.disconnect()
461
  neo4j_driver.close()
462
 
 
463
  if __name__ == "__main__":
464
  main()
 
12
  from dotenv import load_dotenv
13
  import huggingface_hub
14
  import logging
15
+ from concurrent.futures import ProcessPoolExecutor, as_completed
16
+ import multiprocessing as mp
17
 
18
  # Suppress noisy libraries
19
  logging.getLogger("httpx").setLevel(logging.WARNING)
 
23
  # Add parent directory to path to import modules
24
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
25
 
 
 
26
  from scripts.analyze_distribution import get_return_class_map
27
  # Import quality score calculator
28
  from scripts.compute_quality_score import get_token_quality_scores, fetch_token_metrics, _bucket_id, _midrank_percentiles, EPS
 
30
  from clickhouse_driver import Client as ClickHouseClient
31
  from neo4j import GraphDatabase
32
 
33
# Per-process state for pool workers; filled in by _init_worker() once per process.
_worker_dataset = None
_worker_return_class_map = None
_worker_quality_scores_map = None


def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
    """Set up one worker process: private DB connections, dataset, and lookup maps.

    Intended as the ProcessPoolExecutor ``initializer`` — it runs exactly once
    per spawned process and stashes everything the per-token worker functions
    need in module-level globals.
    """
    global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map

    # Imported lazily so each spawned process performs its own project imports.
    from data.data_loader import OracleDataset
    from data.data_fetcher import DataFetcher

    # DB clients are not picklable across processes, so every worker opens its own.
    ch_client = ClickHouseClient(
        host=db_config['clickhouse_host'],
        port=db_config['clickhouse_port'],
    )
    graph_driver = GraphDatabase.driver(
        db_config['neo4j_uri'],
        auth=(db_config['neo4j_user'], db_config['neo4j_password']),
    )

    fetcher = DataFetcher(clickhouse_client=ch_client, neo4j_driver=graph_driver)

    _worker_dataset = OracleDataset(
        data_fetcher=fetcher,
        max_samples=dataset_config['max_samples'],
        start_date=dataset_config['start_date'],
        ohlc_stats_path=dataset_config['ohlc_stats_path'],
        horizons_seconds=dataset_config['horizons_seconds'],
        quantiles=dataset_config['quantiles'],
        min_trade_usd=dataset_config['min_trade_usd'],
        max_seq_len=dataset_config['max_seq_len'],
    )

    # Overwrite whatever the dataset sampled on its own with the pre-filtered
    # mint list chosen by the main process, so task indices line up in every worker.
    _worker_dataset.sampled_mints = dataset_config['sampled_mints']

    _worker_return_class_map = return_class_map
    _worker_quality_scores_map = quality_scores_map
76
def _process_single_token_context(args):
    """Worker: generate and save training contexts for one token (context mode).

    Args (packed in one tuple so the function can be submitted to a pool):
        idx: index into the worker dataset's ``sampled_mints``.
        mint_addr: token mint address (also used to derive output filenames).
        samples_per_token: number of T_cutoff contexts to generate.
        output_dir: directory (str) to write ``sample_*.pt`` files into.

    Returns a result dict with ``status`` of 'success' | 'skipped' | 'error';
    never raises — failures are captured with a traceback string so the
    parent process can report them.
    """
    import hashlib

    idx, mint_addr, samples_per_token, output_dir = args

    global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map

    try:
        class_id = _worker_return_class_map.get(mint_addr)
        if class_id is None:
            return {'status': 'skipped', 'reason': 'not in class map', 'mint': mint_addr}

        # T_cutoff sampling and H/B/H dynamic sampling happen inside the dataset.
        contexts = _worker_dataset.__cacheitem_context__(idx, num_samples_per_token=samples_per_token)

        if not contexts:
            return {'status': 'skipped', 'reason': 'no valid contexts', 'mint': mint_addr}

        q_score = _worker_quality_scores_map.get(mint_addr)
        if q_score is None:
            return {'status': 'skipped', 'reason': 'no quality score', 'mint': mint_addr}

        # Collision-safe filename stem: a readable address prefix plus a short
        # digest of the FULL address. (A bare 16-char prefix is NOT a hash and
        # could let two distinct mints overwrite each other's files.)
        stem = f"{mint_addr[:16]}_{hashlib.sha1(mint_addr.encode()).hexdigest()[:8]}"

        saved_files = []
        for ctx_idx, ctx in enumerate(contexts):
            ctx["quality_score"] = q_score
            ctx["class_id"] = class_id
            ctx["source_token"] = mint_addr  # track origin for debugging
            ctx["cache_mode"] = "context"

            filename = f"sample_{stem}_{ctx_idx}.pt"
            torch.save(ctx, Path(output_dir) / filename)
            saved_files.append(filename)

        return {
            'status': 'success',
            'mint': mint_addr,
            'class_id': class_id,
            'q_score': q_score,
            'n_contexts': len(contexts),
            # contexts is guaranteed non-empty here (checked above).
            'n_events': len(contexts[0].get('event_sequence', [])),
            'files': saved_files,
        }

    except Exception as e:
        import traceback
        return {
            'status': 'error',
            'mint': mint_addr,
            'error': str(e),
            'traceback': traceback.format_exc(),
        }
131
def _process_single_token_raw(args):
    """Worker: fetch and save raw cached data for one token (raw mode).

    Args (packed in one tuple so the function can be submitted to a pool):
        idx: index into the worker dataset's ``sampled_mints``.
        mint_addr: token mint address (also used to derive the output filename).
        output_dir: directory (str) to write the ``sample_*.pt`` file into.

    Returns a result dict with ``status`` of 'success' | 'skipped' | 'error';
    never raises — failures are captured with a traceback string so the
    parent process can report them.
    """
    import hashlib

    idx, mint_addr, output_dir = args

    global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map

    try:
        class_id = _worker_return_class_map.get(mint_addr)
        if class_id is None:
            return {'status': 'skipped', 'reason': 'not in class map', 'mint': mint_addr}

        item = _worker_dataset.__cacheitem__(idx)

        if item is None:
            return {'status': 'skipped', 'reason': 'cacheitem returned None', 'mint': mint_addr}

        q_score = _worker_quality_scores_map.get(mint_addr)
        if q_score is None:
            return {'status': 'skipped', 'reason': 'no quality score', 'mint': mint_addr}

        item["quality_score"] = q_score
        item["class_id"] = class_id
        item["cache_mode"] = "raw"

        # Collision-safe filename: address prefix plus a short digest of the
        # FULL address — a bare 16-char prefix could let two distinct mints
        # silently overwrite each other (one file per token in raw mode).
        filename = f"sample_{mint_addr[:16]}_{hashlib.sha1(mint_addr.encode()).hexdigest()[:8]}.pt"
        torch.save(item, Path(output_dir) / filename)

        return {
            'status': 'success',
            'mint': mint_addr,
            'class_id': class_id,
            'q_score': q_score,
            'n_trades': len(item.get('trades', [])),
            'files': [filename],
        }

    except Exception as e:
        import traceback
        return {
            'status': 'error',
            'mint': mint_addr,
            'error': str(e),
            'traceback': traceback.format_exc(),
        }
177
+
178
  def compute_save_ohlc_stats(client: ClickHouseClient, output_path: str):
179
  """
180
  Computes global mean/std for price/volume from ClickHouse and saves to .npz
181
  This allows the dataset loader to normalize inputs correctly.
182
  """
183
  print(f"INFO: Computing OHLC stats (mean/std) from ClickHouse...")
184
+
 
 
185
  min_price = 0.0
186
  min_vol = 0.0
187
+
188
  query = """
189
  SELECT
190
  AVG(t.price_usd) AS mean_price_usd,
 
196
  FROM trades AS t
197
  WHERE t.price_usd > %(min_price)s AND t.total_usd > %(min_vol)s
198
  """
199
+
200
  params = {"min_price": min_price, "min_vol": min_vol}
201
+
202
  try:
203
  result = client.execute(query, params=params)
204
  if not result or not result[0]:
 
210
  }
211
  else:
212
  row = result[0]
 
213
  def safe_float(x, default=0.0):
214
  return float(x) if x is not None else default
215
+
216
  def safe_std(x):
217
  val = safe_float(x, 1.0)
218
  return val if val > 1e-9 else 1.0
 
225
  "mean_trade_value_usd": safe_float(row[4]),
226
  "std_trade_value_usd": safe_std(row[5]),
227
  }
228
+
 
229
  out_p = Path(output_path)
230
  out_p.parent.mkdir(parents=True, exist_ok=True)
231
  np.savez(out_p, **stats)
232
+
233
  print(f"INFO: Saved OHLC stats to {out_p}")
234
  for k, v in stats.items():
235
  print(f" {k}: {v:.4f}")
236
+
237
  except Exception as e:
238
  print(f"ERROR: Failed to compute OHLC stats: {e}")
239
+
240
 
241
  def build_quality_missing_reason_map(client: ClickHouseClient, max_ret: float = 1e9):
242
+ """Build a map: token_address -> reason string for why a quality score is missing."""
 
 
 
243
  data = fetch_token_metrics(client)
244
  metrics_by_token = {d.get("token_address"): d for d in data if d.get("token_address")}
245
 
 
246
  buckets = {}
247
  for d in data:
248
  ret_val = d.get("ret")
 
254
  d["bucket_id"] = b
255
  buckets.setdefault(b, []).append(d)
256
 
 
257
  feature_defs = [
258
  ("fees_log", lambda d: math.log1p(d["fees_sol"]) if d.get("fees_sol") is not None else None, True),
259
  ("volume_log", lambda d: math.log1p(d["volume_usd"]) if d.get("volume_usd") is not None else None, True),
 
268
  ("dev_hold_pct", lambda d: d.get("dev_hold_pct"), True),
269
  ]
270
 
 
271
  bucket_feature_percentiles = {}
272
  for b, items in buckets.items():
273
  feature_percentiles = {}
 
284
  def _reason_for(token_address: str) -> str:
285
  d = metrics_by_token.get(token_address)
286
  if not d:
287
+ return "no metrics found"
288
  ret_val = d.get("ret")
289
  if ret_val is None:
290
+ return "ret is None"
291
  if ret_val <= 0:
292
  return f"ret <= 0 ({ret_val})"
293
  if ret_val > max_ret:
 
295
  b = _bucket_id(ret_val)
296
  if b == -1:
297
  return f"ret {ret_val} not in RETURN_THRESHOLDS"
298
+ return "unknown"
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
  return _reason_for
301
 
302
+
303
  def main():
304
  load_dotenv()
305
+
306
+ # Use spawn method for multiprocessing (safer with CUDA/DB connections)
307
+ mp.set_start_method('spawn', force=True)
308
+
309
  hf_token = os.getenv("HF_TOKEN")
310
  if hf_token:
311
  print(f"INFO: Logging in to Hugging Face with token starting with: {hf_token[:4]}...")
 
320
  parser.add_argument("--ohlc_stats_path", type=str, default="data/ohlc_stats.npz")
321
  parser.add_argument("--min_trade_usd", type=float, default=0.0)
322
 
323
+ # Context caching mode args
324
  parser.add_argument("--cache_mode", type=str, default="raw", choices=["raw", "context"],
325
+ help="Cache mode: 'raw' or 'context'")
326
  parser.add_argument("--context_length", type=int, default=8192,
327
+ help="Max sequence length for H/B/H threshold")
328
  parser.add_argument("--min_trades", type=int, default=10,
329
+ help="Minimum trades for T_cutoff sampling")
330
  parser.add_argument("--samples_per_token", type=int, default=1,
331
+ help="Number of T_cutoff samples per token")
332
+
333
+ # Parallelization args
334
+ parser.add_argument("--num_workers", type=int, default=1,
335
+ help="Number of parallel workers (default: 1, use 0 for auto-detect)")
336
 
337
  # DB Args
338
  parser.add_argument("--clickhouse_host", type=str, default=os.getenv("CLICKHOUSE_HOST", "localhost"))
 
342
  parser.add_argument("--neo4j_password", type=str, default=os.getenv("NEO4J_PASSWORD", "password"))
343
 
344
  args = parser.parse_args()
345
+
346
+ # Auto-detect workers if set to 0
347
+ if args.num_workers == 0:
348
+ args.num_workers = max(1, mp.cpu_count() - 4)
349
+
350
  output_dir = Path(args.output_dir)
351
  output_dir.mkdir(parents=True, exist_ok=True)
352
+
353
  start_date_dt = None
354
  if args.start_date:
355
  start_date_dt = datetime.datetime.strptime(args.start_date, "%Y-%m-%d")
356
+
357
  print(f"INFO: Initializing DB Connections...")
358
  clickhouse_client = ClickHouseClient(host=args.clickhouse_host, port=args.clickhouse_port)
359
  neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=(args.neo4j_user, args.neo4j_password))
360
+
361
  try:
362
  # --- 1. Compute OHLC Stats (Global) ---
363
  compute_save_ohlc_stats(clickhouse_client, args.ohlc_stats_path)
364
 
365
+ # --- 2. Initialize DataFetcher and OracleDataset (main process) ---
366
+ from data.data_loader import OracleDataset
367
+ from data.data_fetcher import DataFetcher
368
+
369
  data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
370
+
 
371
  print("INFO: Fetching Return Classification Map...")
372
  return_class_map, thresholds = get_return_class_map(clickhouse_client)
373
  print(f"INFO: Loaded {len(return_class_map)} valid classified tokens.")
374
 
 
375
  print("INFO: Fetching Token Quality Scores...")
376
  quality_scores_map = get_token_quality_scores(clickhouse_client)
 
377
  print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
378
+
379
  dataset = OracleDataset(
380
  data_fetcher=data_fetcher,
381
  max_samples=args.max_samples,
 
384
  horizons_seconds=[60, 180, 300, 600, 1800, 3600, 7200],
385
  quantiles=[0.5],
386
  min_trade_usd=args.min_trade_usd,
387
+ max_seq_len=args.context_length
388
  )
389
+
390
  if len(dataset) == 0:
391
  print("WARNING: Dataset initialization resulted in 0 samples. Nothing to cache.")
392
  return
393
 
394
+ # Filter dataset by class map
 
395
  original_size = len(dataset)
396
  print(f"INFO: Filtering dataset... Original size: {original_size}")
397
  dataset.sampled_mints = [
398
+ m for m in dataset.sampled_mints
399
  if m['mint_address'] in return_class_map
400
  ]
401
  filtered_size = len(dataset)
 
406
  print("WARNING: No tokens remain after filtering by return_class_map.")
407
  return
408
 
409
+ # --- 3. Parallel caching ---
410
  print(f"INFO: Cache mode: {args.cache_mode}")
411
+ print(f"INFO: Number of workers: {args.num_workers}")
412
+ print(f"INFO: Starting to cache {len(dataset)} tokens...")
413
+
414
+ # Prepare configs for workers
415
+ db_config = {
416
+ 'clickhouse_host': args.clickhouse_host,
417
+ 'clickhouse_port': args.clickhouse_port,
418
+ 'neo4j_uri': args.neo4j_uri,
419
+ 'neo4j_user': args.neo4j_user,
420
+ 'neo4j_password': args.neo4j_password,
421
+ }
422
+
423
+ dataset_config = {
424
+ 'max_samples': args.max_samples,
425
+ 'start_date': start_date_dt,
426
+ 'ohlc_stats_path': args.ohlc_stats_path,
427
+ 'horizons_seconds': [60, 180, 300, 600, 1800, 3600, 7200],
428
+ 'quantiles': [0.5],
429
+ 'min_trade_usd': args.min_trade_usd,
430
+ 'max_seq_len': args.context_length,
431
+ 'sampled_mints': dataset.sampled_mints, # Pass filtered mints
432
+ }
433
+
434
+ # Prepare task list
435
+ tasks = []
436
+ for i in range(len(dataset)):
437
+ mint_addr = dataset.sampled_mints[i]['mint_address']
438
+ if args.cache_mode == "context":
439
+ tasks.append((i, mint_addr, args.samples_per_token, str(output_dir)))
440
+ else:
441
+ tasks.append((i, mint_addr, str(output_dir)))
442
 
443
+ # Track results
444
+ success_count = 0
445
  skipped_count = 0
446
+ error_count = 0
 
 
 
447
  class_distribution = {}
448
 
449
+ if args.num_workers == 1:
450
+ # Single-threaded mode (no multiprocessing overhead)
451
+ print("INFO: Running in single-threaded mode...")
452
+ _init_worker(db_config, dataset_config, return_class_map, quality_scores_map)
453
+
454
+ process_fn = _process_single_token_context if args.cache_mode == "context" else _process_single_token_raw
455
+
456
+ for task in tqdm(tasks, desc="Caching"):
457
+ result = process_fn(task)
458
+
459
+ if result['status'] == 'success':
460
+ success_count += 1
461
+ cid = result['class_id']
462
+ class_distribution[cid] = class_distribution.get(cid, 0) + 1
463
+ if args.cache_mode == "context":
464
+ tqdm.write(f" + {result['mint'][:16]} | Class: {cid} | Q: {result['q_score']:.4f} | Contexts: {result['n_contexts']} | Events: {result['n_events']}")
465
+ else:
466
+ tqdm.write(f" + {result['mint'][:16]} | Class: {cid} | Q: {result['q_score']:.4f} | Trades: {result['n_trades']}")
467
+ elif result['status'] == 'skipped':
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468
  skipped_count += 1
469
+ else:
470
+ error_count += 1
471
+ tqdm.write(f" ERROR: {result['mint'][:16]} - {result['error']}")
472
  else:
473
+ # Multi-process mode
474
+ print(f"INFO: Running with {args.num_workers} parallel workers...")
475
+
476
+ process_fn = _process_single_token_context if args.cache_mode == "context" else _process_single_token_raw
477
+
478
+ with ProcessPoolExecutor(
479
+ max_workers=args.num_workers,
480
+ initializer=_init_worker,
481
+ initargs=(db_config, dataset_config, return_class_map, quality_scores_map)
482
+ ) as executor:
483
+ futures = {executor.submit(process_fn, task): task for task in tasks}
484
+
485
+ for future in tqdm(as_completed(futures), total=len(futures), desc="Caching"):
486
+ try:
487
+ result = future.result(timeout=300) # 5 min timeout per token
488
+
489
+ if result['status'] == 'success':
490
+ success_count += 1
491
+ cid = result['class_id']
492
+ class_distribution[cid] = class_distribution.get(cid, 0) + 1
493
+ elif result['status'] == 'skipped':
494
+ skipped_count += 1
495
+ else:
496
+ error_count += 1
497
+ tqdm.write(f" ERROR: {result.get('mint', 'unknown')[:16]} - {result.get('error', 'unknown')}")
498
+ except Exception as e:
499
+ error_count += 1
500
+ tqdm.write(f" WORKER ERROR: {e}")
501
+
502
+ # --- 4. Build metadata ---
503
+ print("INFO: Building class metadata...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
504
  file_class_map = {}
505
  for sample_file in sorted(output_dir.glob("sample_*.pt")):
506
  try:
 
514
  with open(metadata_path, 'w') as f:
515
  json.dump({
516
  'file_class_map': file_class_map,
517
+ 'class_distribution': {str(k): v for k, v in class_distribution.items()},
518
  'cache_mode': args.cache_mode,
519
  'context_length': args.context_length if args.cache_mode == "context" else None,
520
  'min_trades': args.min_trades if args.cache_mode == "context" else None,
521
  'samples_per_token': args.samples_per_token if args.cache_mode == "context" else None,
522
+ 'num_workers': args.num_workers,
523
  }, f, indent=2)
524
  print(f"INFO: Saved class metadata to {metadata_path}")
525
  except Exception as e:
 
527
 
528
  print(f"\n--- Caching Complete ---")
529
  print(f"Cache mode: {args.cache_mode}")
530
+ print(f"Workers used: {args.num_workers}")
531
+ print(f"Successfully cached: {success_count} tokens")
532
+ print(f"Total files: {len(file_class_map)}")
533
+ print(f"Filtered: {filtered_count} tokens")
534
+ print(f"Skipped: {skipped_count} tokens")
535
+ print(f"Errors: {error_count} tokens")
536
  print(f"Class distribution: {class_distribution}")
537
  print(f"Cache location: {output_dir.resolve()}")
538
 
539
  finally:
 
540
  clickhouse_client.disconnect()
541
  neo4j_driver.close()
542
 
543
+
544
if __name__ == "__main__":
    # Run the caching pipeline only when executed as a script, not on import.
    main()