zirobtc committed
Commit 77ebb19 · 1 Parent(s): d2461e7

Update scripts/cache_dataset.py: speed + balance + correctness fixes

Files changed (1): scripts/cache_dataset.py (+95 -11)
scripts/cache_dataset.py CHANGED
@@ -146,6 +146,8 @@ def main():
     parser.add_argument("--context_length", type=int, default=8192)
     parser.add_argument("--min_trades", type=int, default=10)
     parser.add_argument("--samples_per_token", type=int, default=1)
+    parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
+    parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
     parser.add_argument("--num_workers", type=int, default=1)
     parser.add_argument("--clickhouse_host", type=str, default=os.getenv("CLICKHOUSE_HOST", "localhost"))
     parser.add_argument("--clickhouse_port", type=int, default=int(os.getenv("CLICKHOUSE_PORT", 9000)))
@@ -180,7 +182,7 @@ def main():
     quality_scores_map = get_token_quality_scores(clickhouse_client)
     print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")

-    dataset = OracleDataset(data_fetcher=data_fetcher, max_samples=args.max_samples, start_date=start_date_dt, ohlc_stats_path=args.ohlc_stats_path, horizons_seconds=[60, 180, 300, 600, 1800, 3600, 7200], quantiles=[0.5], min_trade_usd=args.min_trade_usd, max_seq_len=args.context_length)
+    dataset = OracleDataset(data_fetcher=data_fetcher, max_samples=args.max_samples, start_date=start_date_dt, ohlc_stats_path=args.ohlc_stats_path, horizons_seconds=args.horizons_seconds, quantiles=args.quantiles, min_trade_usd=args.min_trade_usd, max_seq_len=args.context_length)

     if len(dataset) == 0:
         print("WARNING: No samples. Exiting.")
@@ -189,7 +191,25 @@ def main():
     # Filter mints by return_class_map
     original_size = len(dataset.sampled_mints)
     filtered_mints = [m for m in dataset.sampled_mints if m['mint_address'] in return_class_map]
-    print(f"INFO: Filtered {original_size} -> {len(filtered_mints)} tokens")
+    print(f"INFO: Filtered by class map: {original_size} -> {len(filtered_mints)} tokens")
+
+    # Pre-filter: only keep tokens with >= min_trades trades (fast ClickHouse count query)
+    print(f"INFO: Pre-filtering tokens by trade count (>= {args.min_trades} trades)...")
+    trade_counts = clickhouse_client.execute("""
+        SELECT base_address, count() as cnt
+        FROM trades
+        GROUP BY base_address
+        HAVING cnt >= %(min_trades)s
+    """, {'min_trades': args.min_trades})
+    valid_tokens = {row[0] for row in trade_counts}
+    pre_filter_size = len(filtered_mints)
+    filtered_mints = [m for m in filtered_mints if m['mint_address'] in valid_tokens]
+    print(f"INFO: Pre-filtered by trade count: {pre_filter_size} -> {len(filtered_mints)} tokens (removed {pre_filter_size - len(filtered_mints)} with < {args.min_trades} trades)")
+
+    # Also filter by quality score availability
+    pre_quality_size = len(filtered_mints)
+    filtered_mints = [m for m in filtered_mints if m['mint_address'] in quality_scores_map]
+    print(f"INFO: Filtered by quality score: {pre_quality_size} -> {len(filtered_mints)} tokens")

     if len(filtered_mints) == 0:
         print("WARNING: No tokens after filtering.")
@@ -198,16 +218,69 @@ def main():
     print(f"INFO: Cache mode: {args.cache_mode}, Workers: {args.num_workers}")

     db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
-    dataset_config = {'max_samples': args.max_samples, 'start_date': start_date_dt, 'ohlc_stats_path': args.ohlc_stats_path, 'horizons_seconds': [60, 180, 300, 600, 1800, 3600, 7200], 'quantiles': [0.5], 'min_trade_usd': args.min_trade_usd, 'max_seq_len': args.context_length, 'sampled_mints': filtered_mints}
+    dataset_config = {'max_samples': args.max_samples, 'start_date': start_date_dt, 'ohlc_stats_path': args.ohlc_stats_path, 'horizons_seconds': args.horizons_seconds, 'quantiles': args.quantiles, 'min_trade_usd': args.min_trade_usd, 'max_seq_len': args.context_length, 'sampled_mints': filtered_mints}
+
+    # Build tasks with class-aware multi-sampling for balanced cache
+    import random
+    from collections import Counter, defaultdict
+
+    # Count eligible tokens per class
+    eligible_class_counts = Counter()
+    mints_by_class = defaultdict(list)
+    for i, m in enumerate(filtered_mints):
+        cid = return_class_map.get(m['mint_address'])
+        if cid is not None:
+            eligible_class_counts[cid] += 1
+            mints_by_class[cid].append((i, m))
+
+    print(f"INFO: Eligible tokens per class: {dict(sorted(eligible_class_counts.items()))}")
+
+    # Compute balanced samples_per_token for each class
+    num_classes = len(eligible_class_counts)
+    if args.max_samples:
+        target_total = args.max_samples
+    else:
+        target_total = 15000  # Default target: 15k balanced files
+    target_per_class = target_total // max(num_classes, 1)
+
+    class_multipliers = {}
+    class_token_caps = {}
+    for cid, count in eligible_class_counts.items():
+        if count >= target_per_class:
+            # Enough tokens: 1 sample each, cap token count
+            class_multipliers[cid] = 1
+            class_token_caps[cid] = target_per_class
+        else:
+            # Not enough tokens: multi-sample, use all tokens
+            class_multipliers[cid] = min(10, max(1, math.ceil(target_per_class / max(count, 1))))
+            class_token_caps[cid] = count
+
+    print(f"INFO: Target total: {target_total}, Target per class: {target_per_class}")
+    print(f"INFO: Class multipliers: {dict(sorted(class_multipliers.items()))}")
+    print(f"INFO: Class token caps: {dict(sorted(class_token_caps.items()))}")

-    # Build tasks from filtered_mints directly
+    # Build balanced task list
     tasks = []
-    for i, mint_record in enumerate(filtered_mints):
-        mint_addr = mint_record['mint_address']
-        if args.cache_mode == "context":
-            tasks.append((i, mint_addr, args.samples_per_token, str(output_dir)))
-        else:
-            tasks.append((i, mint_addr, str(output_dir)))
+    for cid, mint_list in mints_by_class.items():
+        random.shuffle(mint_list)
+        cap = class_token_caps.get(cid, len(mint_list))
+        spt = class_multipliers.get(cid, 1)
+        # Override with CLI --samples_per_token if explicitly set > 1
+        if args.samples_per_token > 1:
+            spt = args.samples_per_token
+        for i, m in mint_list[:cap]:
+            mint_addr = m['mint_address']
+            if args.cache_mode == "context":
+                tasks.append((i, mint_addr, spt, str(output_dir)))
+            else:
+                tasks.append((i, mint_addr, str(output_dir)))
+
+    random.shuffle(tasks)  # Shuffle tasks for even load distribution across workers
+    expected_files = sum(
+        class_multipliers.get(cid, 1) * min(class_token_caps.get(cid, len(ml)), len(ml))
+        for cid, ml in mints_by_class.items()
+    )
+    print(f"INFO: Total tasks: {len(tasks)} (expected ~{expected_files} output files, target ~{target_total})")

     print(f"INFO: Starting to cache {len(tasks)} tokens...")
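For intuition, here is the balancing arithmetic above run on made-up class counts. Note the sketch imports `math` itself; the script likewise needs `math` in scope at module level for `math.ceil` to resolve:

```python
import math

eligible_class_counts = {0: 9000, 1: 4000, 2: 300}  # hypothetical distribution
target_total = 15000
target_per_class = target_total // len(eligible_class_counts)  # 5000

for cid, count in sorted(eligible_class_counts.items()):
    if count >= target_per_class:
        mult, cap = 1, target_per_class  # abundant class: subsample tokens
    else:
        mult = min(10, max(1, math.ceil(target_per_class / max(count, 1))))
        cap = count                      # scarce class: reuse every token
    print(f"class {cid}: {mult} sample(s) x {cap} tokens = {mult * cap} files")

# class 0: 1 sample(s) x 5000 tokens = 5000 files  (exactly on target)
# class 1: 2 sample(s) x 4000 tokens = 8000 files  (ceil rounds up, so it overshoots)
# class 2: 10 sample(s) x 300 tokens = 3000 files  (multiplier clamped at 10, undershoots)
```

The rounding and clamping are why the task-count log line reports the expected file count only approximately against `target_total`.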
@@ -255,7 +328,18 @@ def main():
         pass

     with open(output_dir / "class_metadata.json", 'w') as f:
-        json.dump({'file_class_map': file_class_map, 'class_distribution': {str(k): v for k, v in class_distribution.items()}, 'cache_mode': args.cache_mode, 'num_workers': args.num_workers}, f, indent=2)
+        json.dump({
+            'file_class_map': file_class_map,
+            'class_distribution': {str(k): v for k, v in class_distribution.items()},
+            'cache_mode': args.cache_mode,
+            'num_workers': args.num_workers,
+            'horizons_seconds': args.horizons_seconds,
+            'quantiles': args.quantiles,
+            'class_multipliers': {str(k): v for k, v in class_multipliers.items()},
+            'class_token_caps': {str(k): v for k, v in class_token_caps.items()},
+            'target_total': target_total,
+            'target_per_class': target_per_class,
+        }, f, indent=2)

     print(f"\n--- Done ---\nSuccess: {success_count}, Skipped: {skipped_count}, Errors: {error_count}\nFiles: {len(file_class_map)}\nLocation: {output_dir.resolve()}")
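Since `class_metadata.json` now carries the full sampling configuration, a downstream loader can recover per-class weights without rescanning the cache. A hypothetical consumer sketch (the cache directory path is invented; key names follow the dump above):

```python
import json
from collections import Counter
from pathlib import Path

meta = json.loads(Path("cache_dir/class_metadata.json").read_text())
counts = Counter(meta["file_class_map"].values())  # files per class
total = sum(counts.values())
weights = {cls: total / (len(counts) * n) for cls, n in counts.items()}
print(weights)  # close to uniform if the balanced caching hit its targets
```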
 