"""Parallel caching script - runs multiple workers to speed up caching. Supports resume."""
import os
import sys
import argparse
import multiprocessing as mp
from pathlib import Path
from dotenv import load_dotenv

# Make the project root importable so `data.*` and `scripts.*` resolve when
# this file is run directly (the script appears to live one level below the
# project root — the parent of this file's directory is appended).
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
def cache_chunk(args):
    """Worker entry point: cache dataset samples in [start_idx, end_idx) to disk.

    Runs in a separate process, so all heavy imports and DB connections are
    created locally. Samples already on disk are skipped, enabling resume.

    Args:
        args: Tuple of (worker_id, start_idx, end_idx, output_dir, db_args)
            where ``db_args`` is a plain dict of connection settings (kept
            picklable so it can cross the process boundary).

    Returns:
        Tuple ``(cached, skipped, errors)`` — counts for this worker's range.
    """
    worker_id, start_idx, end_idx, output_dir, db_args = args

    # Imports are done inside the worker so the parent process stays light
    # and non-fork-safe DB clients are never shared across processes.
    import torch
    from data.data_loader import OracleDataset
    from data.data_fetcher import DataFetcher
    from scripts.analyze_distribution import get_return_class_map
    from scripts.compute_quality_score import get_token_quality_scores
    from clickhouse_driver import Client as ClickHouseClient
    from neo4j import GraphDatabase

    ch = ClickHouseClient(host=db_args['ch_host'], port=db_args['ch_port'])
    neo = GraphDatabase.driver(db_args['neo4j_uri'],
                               auth=(db_args['neo4j_user'], db_args['neo4j_password']))

    cached = 0
    skipped = 0
    errors = 0
    try:
        fetcher = DataFetcher(clickhouse_client=ch, neo4j_driver=neo)
        return_map, _ = get_return_class_map(ch)
        quality_map = get_token_quality_scores(ch)

        # NOTE(review): confirm these horizons/quantiles match the dataset
        # main() builds when computing chunk boundaries — a mismatch would
        # make worker indices point at different samples.
        ds = OracleDataset(
            data_fetcher=fetcher,
            horizons_seconds=[30, 60, 120, 240, 420],
            quantiles=[0.1, 0.5, 0.9],
        )
        # Keep only mints that have a return class, so class_id is always
        # available below.
        ds.sampled_mints = [m for m in ds.sampled_mints if m['mint_address'] in return_map]

        for i in range(start_idx, min(end_idx, len(ds))):
            # Resume support: never recompute a sample already on disk.
            output_path = Path(output_dir) / f"sample_{i}.pt"
            if output_path.exists():
                skipped += 1
                continue

            mint = ds.sampled_mints[i]['mint_address']
            try:
                item = ds.__cacheitem__(i)
                # Only persist samples that also have a quality score;
                # samples without one are silently dropped (not counted).
                if item and mint in quality_map:
                    item["quality_score"] = quality_map[mint]
                    item["class_id"] = return_map[mint]
                    torch.save(item, output_path)
                    cached += 1
                    if cached % 50 == 0:
                        print(f"[W{worker_id}] Cached {cached} | Skipped {skipped} | Errors {errors}")
            except Exception as e:
                # Best-effort: one bad sample must not kill the worker.
                errors += 1
                if errors < 5:  # only surface the first few errors to keep logs quiet
                    print(f"[W{worker_id}] Error at {i}: {e}")
    finally:
        # Always release DB connections, even if dataset setup raises —
        # previously an exception here leaked both clients.
        ch.disconnect()
        neo.close()
    return cached, skipped, errors
|
|
def main():
    """Launch a pool of worker processes that cache dataset samples to disk.

    Counts existing cache files (resume), measures the dataset size once to
    split it into contiguous per-worker index chunks, then fans the chunks
    out to ``cache_chunk`` via a multiprocessing pool.
    """
    load_dotenv()
    parser = argparse.ArgumentParser()
    parser.add_argument("--workers", type=int, default=16)
    parser.add_argument("--output_dir", type=str, default="data/cache")
    args = parser.parse_args()

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Resume support: files already on disk are skipped by the workers.
    existing = len(list(Path(args.output_dir).glob("sample_*.pt")))
    print(f"Found {existing} existing cache files (will be skipped)")

    # Heavy/project imports are deferred so `--help` stays fast.
    from clickhouse_driver import Client
    from neo4j import GraphDatabase
    from data.data_loader import OracleDataset
    from data.data_fetcher import DataFetcher
    from scripts.analyze_distribution import get_return_class_map

    # Read connection settings exactly once; the same dict is both used for
    # the sizing pass below and handed to every worker, so parent and
    # children are guaranteed to hit the same databases. (Previously the
    # env vars were read twice, in two separate places.)
    db_args = {
        'ch_host': os.getenv("CLICKHOUSE_HOST", "localhost"),
        'ch_port': int(os.getenv("CLICKHOUSE_PORT", 9000)),
        'neo4j_uri': os.getenv("NEO4J_URI", "bolt://localhost:7687"),
        'neo4j_user': os.getenv("NEO4J_USER", "neo4j"),
        'neo4j_password': os.getenv("NEO4J_PASSWORD", "neo4j123"),
    }

    ch = Client(host=db_args['ch_host'], port=db_args['ch_port'])
    neo = GraphDatabase.driver(db_args['neo4j_uri'],
                               auth=(db_args['neo4j_user'], db_args['neo4j_password']))
    try:
        # Build the dataset only to learn its length. The parameters MUST
        # match cache_chunk's, otherwise the chunk boundaries computed here
        # would index a different sample list than the one the workers cache.
        # (Previously this used horizons [60..7200] / quantiles [0.5], which
        # disagreed with the workers' configuration.)
        fetcher = DataFetcher(clickhouse_client=ch, neo4j_driver=neo)
        return_map, _ = get_return_class_map(ch)
        ds = OracleDataset(
            data_fetcher=fetcher,
            horizons_seconds=[30, 60, 120, 240, 420],
            quantiles=[0.1, 0.5, 0.9],
        )
        ds.sampled_mints = [m for m in ds.sampled_mints if m['mint_address'] in return_map]
        total = len(ds)
    finally:
        # Release the sizing-pass connections even if dataset setup raises.
        ch.disconnect()
        neo.close()

    print(f"Total samples: {total}")
    print(f"Remaining to cache: ~{total - existing}")

    # Split [0, total) into equal contiguous chunks, one per worker; the
    # last worker's end index may overshoot and is clamped inside cache_chunk.
    chunk_size = (total + args.workers - 1) // args.workers
    tasks = [(i, i * chunk_size, (i + 1) * chunk_size, args.output_dir, db_args)
             for i in range(args.workers)]

    print(f"Starting {args.workers} workers...")
    with mp.Pool(args.workers) as pool:
        results = pool.map(cache_chunk, tasks)

    total_cached = sum(r[0] for r in results)
    total_skipped = sum(r[1] for r in results)
    total_errors = sum(r[2] for r in results)
    print(f"Done! Cached {total_cached} new samples, skipped {total_skipped} existing, {total_errors} errors.")
|
|
# Script entry point. The guard also prevents workers from re-executing
# main() when multiprocessing uses the 'spawn' start method.
if __name__ == "__main__":
    main()
|
|