#!/usr/bin/env python3
"""Parallel caching script - runs multiple workers to speed up caching. Supports resume."""
import os
import sys
import argparse
import multiprocessing as mp
from pathlib import Path

from dotenv import load_dotenv

# Make the project root importable (this script lives one directory below it).
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Dataset configuration shared by the planner (main) and the workers.
# BUG FIX: main() previously built its OracleDataset with a *different*
# horizons/quantiles configuration ([60, 180, 300, 600, 1800, 3600, 7200] /
# [0.5]) than the workers, so the reported total — and therefore the chunk
# split handed to the workers — could disagree with what the workers actually
# cache. Both sides now use these constants (the workers' original values,
# since those define the cached data).
HORIZONS_SECONDS = [30, 60, 120, 240, 420]
QUANTILES = [0.1, 0.5, 0.9]


def _build_dataset(ch, neo):
    """Build an OracleDataset restricted to mints that have a return class.

    Returns (dataset, return_map, fetcher). NOTE(review): OracleDataset and
    DataFetcher are project types; their exact semantics are assumed from
    usage here — verify against data/data_loader.py.
    """
    from data.data_loader import OracleDataset
    from data.data_fetcher import DataFetcher
    from scripts.analyze_distribution import get_return_class_map

    fetcher = DataFetcher(clickhouse_client=ch, neo4j_driver=neo)
    return_map, _ = get_return_class_map(ch)
    ds = OracleDataset(
        data_fetcher=fetcher,
        horizons_seconds=HORIZONS_SECONDS,
        quantiles=QUANTILES,
    )
    # Only cache mints for which a target class is known.
    ds.sampled_mints = [m for m in ds.sampled_mints if m['mint_address'] in return_map]
    return ds, return_map, fetcher


def cache_chunk(args):
    """Worker entry point: cache samples [start_idx, end_idx) to output_dir.

    Args:
        args: tuple of (worker_id, start_idx, end_idx, output_dir, db_args)
              where db_args is a dict of connection settings. Packed as one
              tuple so it can be passed through multiprocessing.Pool.map.

    Returns:
        (cached, skipped, errors) counts for this chunk.
    """
    worker_id, start_idx, end_idx, output_dir, db_args = args

    # Imports happen inside the worker so each process builds its own
    # connections and torch state (fork-safety).
    import torch
    from scripts.compute_quality_score import get_token_quality_scores
    from clickhouse_driver import Client as ClickHouseClient
    from neo4j import GraphDatabase

    ch = ClickHouseClient(host=db_args['ch_host'], port=db_args['ch_port'])
    neo = GraphDatabase.driver(
        db_args['neo4j_uri'],
        auth=(db_args['neo4j_user'], db_args['neo4j_password']),
    )

    cached = 0
    skipped = 0
    errors = 0
    try:
        ds, return_map, _ = _build_dataset(ch, neo)
        quality_map = get_token_quality_scores(ch)

        for i in range(start_idx, min(end_idx, len(ds))):
            # RESUME SUPPORT: skip if the sample file already exists.
            output_path = Path(output_dir) / f"sample_{i}.pt"
            if output_path.exists():
                skipped += 1
                continue
            mint = ds.sampled_mints[i]['mint_address']
            try:
                # NOTE(review): __cacheitem__ is a project-defined method on
                # OracleDataset (not a standard dunder); assumed to build one
                # cacheable sample dict — confirm in data/data_loader.py.
                item = ds.__cacheitem__(i)
                # Samples without a quality score are silently dropped
                # (not cached, not counted) — preserved original behavior.
                if item and mint in quality_map:
                    item["quality_score"] = quality_map[mint]
                    item["class_id"] = return_map[mint]
                    torch.save(item, output_path)
                    cached += 1
                    if cached % 50 == 0:
                        print(f"[W{worker_id}] Cached {cached} | Skipped {skipped} | Errors {errors}")
            except Exception as e:
                # Best-effort per-sample caching: log the first few failures
                # and keep going so one bad sample doesn't kill the chunk.
                errors += 1
                if errors < 5:
                    print(f"[W{worker_id}] Error at {i}: {e}")
    finally:
        # BUG FIX: connections were previously only closed on the happy path;
        # a failure during dataset construction leaked them.
        ch.disconnect()
        neo.close()
    return cached, skipped, errors


def main():
    """Plan the work split, then fan caching out over a process pool."""
    load_dotenv()

    parser = argparse.ArgumentParser()
    parser.add_argument("--workers", type=int, default=16)
    parser.add_argument("--output_dir", type=str, default="data/cache")
    args = parser.parse_args()

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Count existing files (workers will skip these — resume support).
    existing = len(list(Path(args.output_dir).glob("sample_*.pt")))
    print(f"Found {existing} existing cache files (will be skipped)")

    # Get the total sample count using the SAME dataset configuration the
    # workers use, so the chunk split matches what they will iterate.
    from clickhouse_driver import Client
    from neo4j import GraphDatabase

    ch = Client(
        host=os.getenv("CLICKHOUSE_HOST", "localhost"),
        port=int(os.getenv("CLICKHOUSE_PORT", 9000)),
    )
    neo = GraphDatabase.driver(
        os.getenv("NEO4J_URI", "bolt://localhost:7687"),
        auth=(os.getenv("NEO4J_USER", "neo4j"), os.getenv("NEO4J_PASSWORD", "neo4j123")),
    )
    try:
        ds, _, _ = _build_dataset(ch, neo)
        total = len(ds)
    finally:
        ch.disconnect()
        neo.close()

    print(f"Total samples: {total}")
    print(f"Remaining to cache: ~{total - existing}")

    # Split work into contiguous index ranges, one per worker. The last
    # chunk may extend past `total`; workers clamp with min(end, len(ds)).
    chunk_size = (total + args.workers - 1) // args.workers
    db_args = {
        'ch_host': os.getenv("CLICKHOUSE_HOST", "localhost"),
        'ch_port': int(os.getenv("CLICKHOUSE_PORT", 9000)),
        'neo4j_uri': os.getenv("NEO4J_URI", "bolt://localhost:7687"),
        'neo4j_user': os.getenv("NEO4J_USER", "neo4j"),
        'neo4j_password': os.getenv("NEO4J_PASSWORD", "neo4j123"),
    }
    tasks = [
        (i, i * chunk_size, (i + 1) * chunk_size, args.output_dir, db_args)
        for i in range(args.workers)
    ]

    print(f"Starting {args.workers} workers...")
    with mp.Pool(args.workers) as pool:
        results = pool.map(cache_chunk, tasks)

    total_cached = sum(r[0] for r in results)
    total_skipped = sum(r[1] for r in results)
    total_errors = sum(r[2] for r in results)
    print(f"Done! Cached {total_cached} new samples, skipped {total_skipped} existing, {total_errors} errors.")


if __name__ == "__main__":
    main()