File size: 4,762 Bytes
e605733
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1274e05
 
e605733
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3780496
e605733
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3780496
e605733
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#!/usr/bin/env python3
"""Parallel caching script - runs multiple workers to speed up caching. Supports resume."""
import os
import sys
import argparse
import multiprocessing as mp
from pathlib import Path
from dotenv import load_dotenv

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

def cache_chunk(args):
    """Worker entry point: cache dataset samples in [start_idx, end_idx).

    Args:
        args: Tuple of (worker_id, start_idx, end_idx, output_dir, db_args),
            where db_args is a dict of ClickHouse/Neo4j connection settings.

    Returns:
        Tuple (cached, skipped, errors) of per-worker counters.

    Skips samples whose output file already exists (resume support).
    """
    worker_id, start_idx, end_idx, output_dir, db_args = args

    # Deferred imports: each worker process builds its own connections and
    # dataset after fork/spawn instead of inheriting them from the parent.
    import torch
    from data.data_loader import OracleDataset
    from data.data_fetcher import DataFetcher
    from scripts.analyze_distribution import get_return_class_map
    from scripts.compute_quality_score import get_token_quality_scores
    from clickhouse_driver import Client as ClickHouseClient
    from neo4j import GraphDatabase

    ch = ClickHouseClient(host=db_args['ch_host'], port=db_args['ch_port'])
    neo = GraphDatabase.driver(db_args['neo4j_uri'], auth=(db_args['neo4j_user'], db_args['neo4j_password']))

    cached = 0
    skipped = 0
    errors = 0
    try:
        fetcher = DataFetcher(clickhouse_client=ch, neo4j_driver=neo)
        return_map, _ = get_return_class_map(ch)
        quality_map = get_token_quality_scores(ch)

        ds = OracleDataset(
            data_fetcher=fetcher,
            horizons_seconds=[30, 60, 120, 240, 420],
            quantiles=[0.1, 0.5, 0.9],
        )
        # Keep only mints with a return class; this filter must mirror the one
        # in main() so the index chunks handed to workers line up with len(ds).
        ds.sampled_mints = [m for m in ds.sampled_mints if m['mint_address'] in return_map]

        for i in range(start_idx, min(end_idx, len(ds))):
            # RESUME SUPPORT: Skip if file already exists
            output_path = Path(output_dir) / f"sample_{i}.pt"
            if output_path.exists():
                skipped += 1
                continue

            mint = ds.sampled_mints[i]['mint_address']
            try:
                item = ds.__cacheitem__(i)
                # Only persist samples that have both a quality score and a
                # return class; anything else is silently dropped (not counted
                # as an error) — matches the original behavior.
                if item and mint in quality_map:
                    item["quality_score"] = quality_map[mint]
                    item["class_id"] = return_map[mint]
                    torch.save(item, output_path)
                    cached += 1
                    if cached % 50 == 0:
                        print(f"[W{worker_id}] Cached {cached} | Skipped {skipped} | Errors {errors}")
            except Exception as e:
                errors += 1
                # Only echo the first few failures to keep worker output readable.
                if errors < 5:
                    print(f"[W{worker_id}] Error at {i}: {e}")
    finally:
        # BUGFIX: always release DB connections, even if setup or the caching
        # loop raises — the original leaked both on any uncaught exception.
        ch.disconnect()
        neo.close()
    return cached, skipped, errors

def main():
    """Split the caching workload across worker processes (resumable).

    Counts existing cache files, computes the total sample count, then
    fans out contiguous index ranges to `--workers` processes running
    `cache_chunk`.
    """
    load_dotenv()
    parser = argparse.ArgumentParser()
    parser.add_argument("--workers", type=int, default=16)
    parser.add_argument("--output_dir", type=str, default="data/cache")
    args = parser.parse_args()

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Count existing files (workers skip indices already cached on disk).
    existing = len(list(Path(args.output_dir).glob("sample_*.pt")))
    print(f"Found {existing} existing cache files (will be skipped)")

    # Get total count. Deferred imports keep module import light for spawned
    # children, which re-import this module.
    from clickhouse_driver import Client
    from neo4j import GraphDatabase
    from data.data_loader import OracleDataset
    from data.data_fetcher import DataFetcher
    from scripts.analyze_distribution import get_return_class_map

    ch = Client(host=os.getenv("CLICKHOUSE_HOST", "localhost"), port=int(os.getenv("CLICKHOUSE_PORT", 9000)))
    neo = GraphDatabase.driver(os.getenv("NEO4J_URI", "bolt://localhost:7687"),
                               auth=(os.getenv("NEO4J_USER", "neo4j"), os.getenv("NEO4J_PASSWORD", "neo4j123")))

    try:
        fetcher = DataFetcher(clickhouse_client=ch, neo4j_driver=neo)
        return_map, _ = get_return_class_map(ch)
        # BUGFIX: build the dataset with the SAME horizons/quantiles the
        # workers use in cache_chunk. The original used a different
        # configuration here, so `total` (and thus the per-worker index
        # chunks) could disagree with the workers' dataset length.
        # TODO(review): confirm whether len(OracleDataset) depends on these
        # params; aligning them is safe either way.
        ds = OracleDataset(data_fetcher=fetcher,
                           horizons_seconds=[30, 60, 120, 240, 420],
                           quantiles=[0.1, 0.5, 0.9])
        # Must mirror the filter applied in cache_chunk.
        ds.sampled_mints = [m for m in ds.sampled_mints if m['mint_address'] in return_map]
        total = len(ds)
    finally:
        # BUGFIX: release connections even if the count query raises.
        ch.disconnect()
        neo.close()

    print(f"Total samples: {total}")
    print(f"Remaining to cache: ~{total - existing}")

    # Split work into contiguous, equal-sized index ranges (last chunk may
    # overshoot `total`; workers clamp with min(end_idx, len(ds))).
    chunk_size = (total + args.workers - 1) // args.workers
    db_args = {
        'ch_host': os.getenv("CLICKHOUSE_HOST", "localhost"),
        'ch_port': int(os.getenv("CLICKHOUSE_PORT", 9000)),
        'neo4j_uri': os.getenv("NEO4J_URI", "bolt://localhost:7687"),
        'neo4j_user': os.getenv("NEO4J_USER", "neo4j"),
        'neo4j_password': os.getenv("NEO4J_PASSWORD", "neo4j123"),
    }

    tasks = [(i, i*chunk_size, (i+1)*chunk_size, args.output_dir, db_args) for i in range(args.workers)]

    print(f"Starting {args.workers} workers...")
    with mp.Pool(args.workers) as pool:
        results = pool.map(cache_chunk, tasks)

    total_cached = sum(r[0] for r in results)
    total_skipped = sum(r[1] for r in results)
    total_errors = sum(r[2] for r in results)
    print(f"Done! Cached {total_cached} new samples, skipped {total_skipped} existing, {total_errors} errors.")

# This guard is required: mp.Pool child processes re-import this module on
# spawn-based platforms, and without it each child would re-run main().
if __name__ == "__main__":
    main()