import os
import sys
import argparse
import datetime
import torch
import json
import math
from collections import defaultdict
from pathlib import Path
from tqdm import tqdm
from dotenv import load_dotenv
import huggingface_hub
import logging
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp

logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("huggingface_hub").setLevel(logging.WARNING)

# Make the repository root importable before pulling in project-local modules.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from scripts.analyze_distribution import get_return_class_map
from scripts.compute_quality_score import get_token_quality_scores, fetch_token_metrics, _bucket_id, _midrank_percentiles, EPS
from data.data_loader import summarize_context_window
from clickhouse_driver import Client as ClickHouseClient
from neo4j import GraphDatabase

# Per-process state populated by `_init_worker` and read by
# `_process_single_token_context`.
_worker_dataset = None
_worker_return_class_map = None
_worker_quality_scores_map = None


def _build_context_quota_plan(
    class_ids,
    target_contexts_per_class,
    target_contexts_total,
    good_ratio_nonzero,
    good_ratio_class0,
):
    """Build the per-class context quota plan.

    Args:
        class_ids: Iterable of class ids (coerced with ``int``); duplicates
            are collapsed.
        target_contexts_per_class: Per-class budget, or ``None``.
        target_contexts_total: Total budget split evenly (integer division,
            at least 1 per class) when no per-class budget is given, or
            ``None``.
        good_ratio_nonzero: Fraction of "good" contexts for classes other
            than 0; clamped to ``[0, 1]``.
        good_ratio_class0: Fraction of "good" contexts for class 0; clamped
            to ``[0, 1]``.

    Returns:
        ``{class_id: {"total_target", "good_target", "bad_target"}}``, or an
        empty dict when there are no classes or no budget was requested.

    Raises:
        RuntimeError: If the resolved per-class target is not positive.
    """
    unique_class_ids = sorted({int(cid) for cid in class_ids})
    if not unique_class_ids:
        return {}

    if target_contexts_per_class is not None:
        per_class_target = int(target_contexts_per_class)
    elif target_contexts_total is not None:
        per_class_target = max(1, int(target_contexts_total) // len(unique_class_ids))
    else:
        return {}

    if per_class_target <= 0:
        raise RuntimeError("Context quota target must be positive.")

    plan = {}
    for class_id in unique_class_ids:
        ratio = float(good_ratio_class0 if class_id == 0 else good_ratio_nonzero)
        ratio = max(0.0, min(1.0, ratio))
        good_target = int(round(per_class_target * ratio))
        plan[class_id] = {
            "total_target": per_class_target,
            "good_target": good_target,
            "bad_target": per_class_target - good_target,
        }
    return plan


def _quota_bucket_key(context_bucket):
    """Map a context bucket label onto the quota counters ("good" or "bad").

    Anything that is not exactly "good" is counted against the "bad" quota,
    so acceptance checks and accounting always use the same key.
    """
    return "good" if context_bucket == "good" else "bad"


def _should_accept_context(class_id, context_bucket, accepted_counts, quota_plan):
    """Return whether a context still fits into the quota plan.

    Args:
        class_id: Class of the token the context belongs to.
        context_bucket: Bucket label of the context; non-"good" labels are
            treated as "bad".
        accepted_counts: Mapping ``class_id -> {"total", "good", "bad"}`` of
            contexts accepted so far.
        quota_plan: Plan from `_build_context_quota_plan`; an empty plan
            accepts everything.

    Returns:
        ``True`` if the context may be stored, ``False`` otherwise.
    """
    if not quota_plan:
        return True
    if class_id not in quota_plan:
        return False

    class_plan = quota_plan[class_id]
    class_counts = accepted_counts[class_id]
    if class_counts["total"] >= class_plan["total_target"]:
        return False

    bucket_key = _quota_bucket_key(context_bucket)
    if class_counts[bucket_key] >= class_plan[f"{bucket_key}_target"]:
        return False
    return True


def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
    """Open DB connections and build the dataset for the current process.

    Stores the dataset and both lookup maps in module-level globals so that
    `_process_single_token_context` can use them. Used as the
    ``ProcessPoolExecutor`` initializer and called directly in
    single-process mode.
    """
    global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
    from data.data_loader import OracleDataset
    from data.data_fetcher import DataFetcher

    clickhouse_client = ClickHouseClient(
        host=db_config['clickhouse_host'],
        port=db_config['clickhouse_port'],
    )
    neo4j_driver = GraphDatabase.driver(
        db_config['neo4j_uri'],
        auth=(db_config['neo4j_user'], db_config['neo4j_password']),
    )
    data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)

    _worker_dataset = OracleDataset(
        data_fetcher=data_fetcher,
        min_trades=dataset_config['min_trades'],
        start_date=dataset_config['start_date'],
        horizons_seconds=dataset_config['horizons_seconds'],
        quantiles=dataset_config['quantiles'],
        min_trade_usd=dataset_config['min_trade_usd'],
        max_seq_len=dataset_config['max_seq_len'],
    )
    # Task indices refer to positions in this (already filtered) mint list.
    _worker_dataset.sampled_mints = dataset_config['sampled_mints']
    _worker_return_class_map = return_class_map
    _worker_quality_scores_map = quality_scores_map


def _process_single_token_context(args):
    """Build the contexts for one token using the per-process dataset.

    Args:
        args: Tuple ``(idx, mint_addr, samples_per_token, output_dir)``.

    Returns:
        A dict whose ``'status'`` is ``'success'`` (with ``class_id``,
        ``q_score``, ``n_contexts``, ``n_events`` and ``contexts``),
        ``'skipped'`` (with ``reason``) or ``'error'`` (with ``error`` and
        ``traceback``). Exceptions are captured, never propagated.
    """
    idx, mint_addr, samples_per_token, output_dir = args
    global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
    try:
        class_id = _worker_return_class_map.get(mint_addr)
        if class_id is None:
            return {'status': 'skipped', 'reason': 'not in class map', 'mint': mint_addr}

        contexts = _worker_dataset.__cacheitem_context__(idx, num_samples_per_token=samples_per_token)
        if not contexts:
            return {'status': 'skipped', 'reason': 'no valid contexts', 'mint': mint_addr}

        q_score = _worker_quality_scores_map.get(mint_addr)
        if q_score is None:
            return {'status': 'skipped', 'reason': 'no quality score', 'mint': mint_addr}

        return {
            'status': 'success',
            'mint': mint_addr,
            'class_id': class_id,
            'q_score': q_score,
            'n_contexts': len(contexts),
            'n_events': len(contexts[0].get('event_sequence', [])) if contexts else 0,
            'contexts': contexts,
        }
    except Exception as e:
        import traceback
        return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()}


def _save_context(ctx, context_summary, class_id, mint_addr, q_score, file_idx, output_dir):
    """Annotate ``ctx`` in place, write it with ``torch.save``, return the file name.

    The file is named ``sample_<first 16 chars of mint>_<file_idx>.pt`` and
    placed inside ``output_dir``.
    """
    ctx["quality_score"] = q_score
    ctx["class_id"] = class_id
    ctx["source_token"] = mint_addr
    ctx["context_bucket"] = context_summary["context_bucket"]
    ctx["context_score"] = context_summary["context_score"]
    filename = f"sample_{mint_addr[:16]}_{file_idx}.pt"
    torch.save(ctx, Path(output_dir) / filename)
    return filename


def main():
    """Command-line entry point: build the context cache and its metadata.

    Parses arguments, loads the class map and quality scores, filters the
    dataset's mints to classified tokens, caches contexts (optionally under
    a per-class quota, which requires a single worker) and finally writes
    ``class_metadata.json`` into the output directory.

    Raises:
        RuntimeError: If both cache budgets are given, or if a quota is
            requested together with more than one worker.
    """
    load_dotenv()
    mp.set_start_method('spawn', force=True)

    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        print("INFO: Logging in to Hugging Face...")
        huggingface_hub.login(token=hf_token)

    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="data/cache")
    parser.add_argument("--start_date", type=str, default=None)
    parser.add_argument("--min_trade_usd", type=float, default=0.0)
    parser.add_argument("--min_trades", type=int, default=10)
    parser.add_argument("--context_length", type=int, default=8192)
    parser.add_argument("--samples_per_token", type=int, default=1)
    parser.add_argument("--target_contexts_per_class", type=int, default=None)
    parser.add_argument("--target_contexts_total", type=int, default=None)
    parser.add_argument("--good_ratio_nonzero", type=float, default=0.5)
    parser.add_argument("--good_ratio_class0", type=float, default=0.0)
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument("--clickhouse_host", type=str, default=os.getenv("CLICKHOUSE_HOST", "localhost"))
    parser.add_argument("--clickhouse_port", type=int, default=int(os.getenv("CLICKHOUSE_PORT", 9000)))
    parser.add_argument("--neo4j_uri", type=str, default=os.getenv("NEO4J_URI", "bolt://localhost:7687"))
    parser.add_argument("--neo4j_user", type=str, default=os.getenv("NEO4J_USER", "neo4j"))
    parser.add_argument("--neo4j_password", type=str, default=os.getenv("NEO4J_PASSWORD", "password"))
    args = parser.parse_args()

    quota_requested = (
        args.target_contexts_per_class is not None
        or args.target_contexts_total is not None
    )
    if args.target_contexts_per_class is not None and args.target_contexts_total is not None:
        raise RuntimeError(
            "Choose exactly one cache budget: either --target_contexts_per_class or --target_contexts_total."
        )

    # 0 means "pick automatically": leave four cores free, but use at least one.
    if args.num_workers == 0:
        args.num_workers = max(1, mp.cpu_count() - 4)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    start_date_dt = datetime.datetime.strptime(args.start_date, "%Y-%m-%d") if args.start_date else None

    print("INFO: Initializing DB Connections...")
    clickhouse_client = ClickHouseClient(host=args.clickhouse_host, port=args.clickhouse_port)
    neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=(args.neo4j_user, args.neo4j_password))

    try:
        from data.data_loader import OracleDataset
        from data.data_fetcher import DataFetcher

        data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)

        print("INFO: Fetching Return Classification Map...")
        return_class_map, _ = get_return_class_map(clickhouse_client)
        print(f"INFO: Loaded {len(return_class_map)} classified tokens.")

        print("INFO: Fetching Quality Scores...")
        quality_scores_map = get_token_quality_scores(clickhouse_client)
        print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")

        # Shared by the parent dataset and the per-worker datasets.
        horizons_seconds = [60, 180, 300, 600, 1800, 3600, 7200]
        quantiles = [0.5]

        dataset = OracleDataset(
            data_fetcher=data_fetcher,
            min_trades=args.min_trades,
            start_date=start_date_dt,
            horizons_seconds=horizons_seconds,
            quantiles=quantiles,
            min_trade_usd=args.min_trade_usd,
            max_seq_len=args.context_length,
        )

        if len(dataset) == 0:
            print("WARNING: No samples. Exiting.")
            return

        # Filter mints by return_class_map
        original_size = len(dataset.sampled_mints)
        filtered_mints = [m for m in dataset.sampled_mints if m['mint_address'] in return_class_map]
        print(f"INFO: Filtered {original_size} -> {len(filtered_mints)} tokens")

        if len(filtered_mints) == 0:
            print("WARNING: No tokens after filtering.")
            return

        print(f"INFO: Building canonical context cache | Workers: {args.num_workers}")
        if args.num_workers != 1 and quota_requested:
            raise RuntimeError(
                "Quota-driven context caching currently requires --num_workers 1 so accepted contexts "
                "can be planned and written deterministically in one process."
            )

        db_config = {
            'clickhouse_host': args.clickhouse_host,
            'clickhouse_port': args.clickhouse_port,
            'neo4j_uri': args.neo4j_uri,
            'neo4j_user': args.neo4j_user,
            'neo4j_password': args.neo4j_password,
        }
        dataset_config = {
            'start_date': start_date_dt,
            'min_trades': args.min_trades,
            'horizons_seconds': list(horizons_seconds),
            'quantiles': list(quantiles),
            'min_trade_usd': args.min_trade_usd,
            'max_seq_len': args.context_length,
            'sampled_mints': filtered_mints,
        }

        # Build tasks from filtered_mints directly
        tasks = [
            (i, mint_record['mint_address'], args.samples_per_token, str(output_dir))
            for i, mint_record in enumerate(filtered_mints)
        ]

        print(f"INFO: Starting to cache {len(tasks)} tokens...")
        success_count, skipped_count, error_count = 0, 0, 0
        class_distribution = {}
        context_distribution = defaultdict(lambda: defaultdict(int))
        file_class_map = {}
        file_context_bucket_map = {}
        file_context_summary_map = {}
        process_fn = _process_single_token_context

        accepted_counts = defaultdict(lambda: {"total": 0, "good": 0, "bad": 0})
        accepted_per_token = defaultdict(int)
        quota_plan = _build_context_quota_plan(
            # filtered_mints only contains mints present in return_class_map.
            class_ids=[return_class_map[m['mint_address']] for m in filtered_mints],
            target_contexts_per_class=args.target_contexts_per_class,
            target_contexts_total=args.target_contexts_total,
            good_ratio_nonzero=args.good_ratio_nonzero,
            good_ratio_class0=args.good_ratio_class0,
        )
        if quota_plan:
            print("INFO: Context quota plan:")
            for class_id, plan in sorted(quota_plan.items()):
                print(
                    f"  Class {class_id}: total={plan['total_target']} "
                    f"(good={plan['good_target']}, bad={plan['bad_target']})"
                )

        def record_saved(filename, class_id, context_summary):
            """Register one written file in the metadata maps and distributions."""
            context_bucket = context_summary["context_bucket"]
            file_class_map[filename] = class_id
            file_context_bucket_map[filename] = context_bucket
            file_context_summary_map[filename] = context_summary
            class_distribution[class_id] = class_distribution.get(class_id, 0) + 1
            context_distribution[class_id][context_bucket] += 1

        def store_all_contexts(result):
            """Write every context of a successful result (no quota)."""
            class_id = result['class_id']
            mint_addr = result['mint']
            q_score = result['q_score']
            for ctx_idx, ctx in enumerate(result.get("contexts", [])):
                context_summary = summarize_context_window(ctx.get("labels"), ctx.get("labels_mask"))
                filename = _save_context(
                    ctx, context_summary, class_id, mint_addr, q_score, ctx_idx, output_dir
                )
                record_saved(filename, class_id, context_summary)

        def store_quota_contexts(result):
            """Write only contexts that still fit the quota; return True if any was saved."""
            class_id = result['class_id']
            mint_addr = result['mint']
            q_score = result['q_score']
            saved_any = False
            for ctx in result.get("contexts", []):
                context_summary = summarize_context_window(ctx.get("labels"), ctx.get("labels_mask"))
                context_bucket = context_summary["context_bucket"]
                if not _should_accept_context(class_id, context_bucket, accepted_counts, quota_plan):
                    continue
                # Files are numbered by accepted contexts so indices stay dense.
                filename = _save_context(
                    ctx, context_summary, class_id, mint_addr, q_score,
                    accepted_per_token[mint_addr], output_dir,
                )
                accepted_per_token[mint_addr] += 1
                counts = accepted_counts[class_id]
                counts["total"] += 1
                # Use the same good/bad folding as the acceptance check so an
                # unexpected bucket label cannot raise KeyError here.
                counts[_quota_bucket_key(context_bucket)] += 1
                record_saved(filename, class_id, context_summary)
                saved_any = True
            return saved_any

        if args.num_workers == 1:
            print("INFO: Single-threaded mode...")
            _init_worker(db_config, dataset_config, return_class_map, quality_scores_map)
            for task in tqdm(tasks, desc="Caching"):
                result = process_fn(task)
                if result['status'] == 'success':
                    if quota_plan:
                        if store_quota_contexts(result):
                            success_count += 1
                    else:
                        store_all_contexts(result)
                        success_count += 1
                elif result['status'] == 'skipped':
                    skipped_count += 1
                else:
                    error_count += 1
                    tqdm.write(f"ERROR: {result['mint'][:16]} - {result['error']}")
        else:
            print(f"INFO: Running with {args.num_workers} workers...")
            with ProcessPoolExecutor(
                max_workers=args.num_workers,
                initializer=_init_worker,
                initargs=(db_config, dataset_config, return_class_map, quality_scores_map),
            ) as executor:
                futures = {executor.submit(process_fn, task): task for task in tasks}
                for future in tqdm(as_completed(futures), total=len(futures), desc="Caching"):
                    try:
                        result = future.result(timeout=300)
                        if result['status'] == 'success':
                            store_all_contexts(result)
                            success_count += 1
                        elif result['status'] == 'skipped':
                            skipped_count += 1
                        else:
                            error_count += 1
                            # Report worker-side failures like single-process mode does.
                            tqdm.write(f"ERROR: {result['mint'][:16]} - {result['error']}")
                    except Exception as e:
                        error_count += 1
                        tqdm.write(f"WORKER ERROR: {e}")

        print("INFO: Building metadata...")
        if not file_class_map:
            # Nothing was written in this run: rebuild the maps from files
            # already present in the output directory (best effort).
            for f in sorted(output_dir.glob("sample_*.pt")):
                try:
                    # NOTE(review): weights_only=False unpickles arbitrary
                    # objects; only safe for cache files this tool produced.
                    cached = torch.load(f, map_location="cpu", weights_only=False)
                    file_class_map[f.name] = cached.get("class_id", 0)
                    if "labels" in cached and "labels_mask" in cached:
                        context_summary = summarize_context_window(cached.get("labels"), cached.get("labels_mask"))
                        file_context_bucket_map[f.name] = context_summary["context_bucket"]
                        file_context_summary_map[f.name] = context_summary
                except Exception as e:
                    # Keep going, but do not hide unreadable cache files.
                    print(f"WARNING: Could not read cached file {f.name}: {e}")

        with open(output_dir / "class_metadata.json", 'w') as f:
            json.dump({
                'file_class_map': file_class_map,
                'file_context_bucket_map': file_context_bucket_map,
                'file_context_summary_map': file_context_summary_map,
                'class_distribution': {str(k): v for k, v in class_distribution.items()},
                'context_distribution': {
                    str(k): dict(bucket_counts)
                    for k, bucket_counts in context_distribution.items()
                },
                'quota_plan': {str(k): v for k, v in quota_plan.items()},
                'accepted_counts': {str(k): v for k, v in accepted_counts.items()},
                'num_workers': args.num_workers,
            }, f, indent=2)

        if quota_plan:
            print("INFO: Accepted context counts:")
            for class_id, counts in sorted(accepted_counts.items()):
                print(
                    f"  Class {class_id}: total={counts['total']} "
                    f"good={counts['good']} bad={counts['bad']}"
                )
        print(f"\n--- Done ---\nSuccess: {success_count}, Skipped: {skipped_count}, Errors: {error_count}\nFiles: {len(file_class_map)}\nLocation: {output_dir.resolve()}")
    finally:
        clickhouse_client.disconnect()
        neo4j_driver.close()


if __name__ == "__main__":
    main()