# oracle/scripts/cache_parallel.py
# (Hugging Face upload page residue, kept as comments so the file parses:
#  uploader "zirobtc", "Upload folder using huggingface_hub", commit 3780496)
#!/usr/bin/env python3
"""Parallel caching script - runs multiple workers to speed up caching. Supports resume."""
import os
import sys
import argparse
import multiprocessing as mp
from pathlib import Path
from dotenv import load_dotenv
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
def cache_chunk(args):
    """Worker entry point: cache dataset samples in [start_idx, end_idx) as .pt files.

    Args (packed in one tuple so the function is usable with Pool.map):
        worker_id: int, used only to prefix log lines.
        start_idx, end_idx: half-open sample-index range this worker handles.
        output_dir: directory receiving "sample_<i>.pt" files.
        db_args: dict with ClickHouse/Neo4j connection settings.

    Returns:
        (cached, skipped, errors) counts for this worker.
    """
    worker_id, start_idx, end_idx, output_dir, db_args = args
    # Heavy imports happen inside the worker so each spawned process builds
    # its own torch/DB state instead of inheriting it from the parent.
    import torch
    from data.data_loader import OracleDataset
    from data.data_fetcher import DataFetcher
    from scripts.analyze_distribution import get_return_class_map
    from scripts.compute_quality_score import get_token_quality_scores
    from clickhouse_driver import Client as ClickHouseClient
    from neo4j import GraphDatabase
    ch = ClickHouseClient(host=db_args['ch_host'], port=db_args['ch_port'])
    neo = GraphDatabase.driver(db_args['neo4j_uri'], auth=(db_args['neo4j_user'], db_args['neo4j_password']))
    cached = 0
    skipped = 0
    errors = 0
    try:
        # FIX: everything after the connections are opened runs under
        # try/finally so ch/neo are closed even if setup or the loop raises
        # (previously an exception here leaked both connections).
        fetcher = DataFetcher(clickhouse_client=ch, neo4j_driver=neo)
        return_map, _ = get_return_class_map(ch)
        quality_map = get_token_quality_scores(ch)
        # NOTE(review): these horizons/quantiles differ from the ones main()
        # uses when it computes `total` and the chunk boundaries; if
        # OracleDataset's sampling depends on them, worker index ranges may
        # not line up with main's count — confirm this is intentional.
        ds = OracleDataset(
            data_fetcher=fetcher,
            horizons_seconds=[30, 60, 120, 240, 420],
            quantiles=[0.1, 0.5, 0.9],
        )
        # Keep only mints that have a return-class label.
        ds.sampled_mints = [m for m in ds.sampled_mints if m['mint_address'] in return_map]
        for i in range(start_idx, min(end_idx, len(ds))):
            # RESUME SUPPORT: skip if the cache file already exists.
            output_path = Path(output_dir) / f"sample_{i}.pt"
            if output_path.exists():
                skipped += 1
                continue
            mint = ds.sampled_mints[i]['mint_address']
            try:
                item = ds.__cacheitem__(i)
                # NOTE(review): a sample whose mint lacks a quality score is
                # dropped silently here (neither cached nor counted anywhere)
                # — confirm that is the intended behavior.
                if item and mint in quality_map:
                    item["quality_score"] = quality_map[mint]
                    item["class_id"] = return_map[mint]
                    torch.save(item, output_path)
                    cached += 1
                    if cached % 50 == 0:
                        print(f"[W{worker_id}] Cached {cached} | Skipped {skipped} | Errors {errors}")
            except Exception as e:
                errors += 1
                if errors < 5:  # cap per-worker error spam; later errors are only counted
                    print(f"[W{worker_id}] Error at {i}: {e}")
    finally:
        ch.disconnect()
        neo.close()
    return cached, skipped, errors
def main():
    """Count cacheable samples, then fan the work out over worker processes.

    Splits the sample index space into equal chunks and runs one
    `cache_chunk` worker per chunk via multiprocessing.Pool.
    """
    load_dotenv()
    parser = argparse.ArgumentParser()
    parser.add_argument("--workers", type=int, default=16)
    parser.add_argument("--output_dir", type=str, default="data/cache")
    args = parser.parse_args()
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    # Count existing files so the user knows roughly how much resume will skip.
    existing = len(list(Path(args.output_dir).glob("sample_*.pt")))
    print(f"Found {existing} existing cache files (will be skipped)")
    # Single source of truth for DB settings: built once, used both for the
    # count below and passed to the workers (previously the env vars were
    # read in two separate places).
    db_args = {
        'ch_host': os.getenv("CLICKHOUSE_HOST", "localhost"),
        'ch_port': int(os.getenv("CLICKHOUSE_PORT", 9000)),
        'neo4j_uri': os.getenv("NEO4J_URI", "bolt://localhost:7687"),
        'neo4j_user': os.getenv("NEO4J_USER", "neo4j"),
        'neo4j_password': os.getenv("NEO4J_PASSWORD", "neo4j123"),
    }
    # Get total count.
    from clickhouse_driver import Client
    from neo4j import GraphDatabase
    from data.data_loader import OracleDataset
    from data.data_fetcher import DataFetcher
    from scripts.analyze_distribution import get_return_class_map
    ch = Client(host=db_args['ch_host'], port=db_args['ch_port'])
    neo = GraphDatabase.driver(db_args['neo4j_uri'],
                               auth=(db_args['neo4j_user'], db_args['neo4j_password']))
    try:
        # FIX: close the counting connections even if dataset construction
        # raises (previously an exception here leaked both connections).
        fetcher = DataFetcher(clickhouse_client=ch, neo4j_driver=neo)
        return_map, _ = get_return_class_map(ch)
        # NOTE(review): these horizons/quantiles differ from the ones
        # cache_chunk() uses — if sampling depends on them, `total` may not
        # match the workers' dataset length. Confirm which config is correct.
        ds = OracleDataset(data_fetcher=fetcher,
                           horizons_seconds=[60, 180, 300, 600, 1800, 3600, 7200], quantiles=[0.5])
        ds.sampled_mints = [m for m in ds.sampled_mints if m['mint_address'] in return_map]
        total = len(ds)
    finally:
        ch.disconnect()
        neo.close()
    print(f"Total samples: {total}")
    print(f"Remaining to cache: ~{total - existing}")
    # Split the index space into equal chunks, one per worker (ceil division
    # so the last chunk absorbs the remainder; workers clamp with len(ds)).
    chunk_size = (total + args.workers - 1) // args.workers
    tasks = [(i, i*chunk_size, (i+1)*chunk_size, args.output_dir, db_args) for i in range(args.workers)]
    print(f"Starting {args.workers} workers...")
    with mp.Pool(args.workers) as pool:
        results = pool.map(cache_chunk, tasks)
    total_cached = sum(r[0] for r in results)
    total_skipped = sum(r[1] for r in results)
    total_errors = sum(r[2] for r in results)
    print(f"Done! Cached {total_cached} new samples, skipped {total_skipped} existing, {total_errors} errors.")
if __name__ == "__main__":
    main()