import os import sys import argparse import datetime import torch import json from pathlib import Path from tqdm import tqdm from dotenv import load_dotenv import huggingface_hub import logging import multiprocessing as mp from collections import Counter, defaultdict logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("transformers").setLevel(logging.ERROR) logging.getLogger("huggingface_hub").setLevel(logging.WARNING) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from scripts.analyze_distribution import get_return_class_map, compute_p99_clamps from scripts.compute_quality_score import get_token_quality_scores from data.data_loader import summarize_context_window from data.quant_ohlc_feature_schema import FEATURE_VERSION from clickhouse_driver import Client as ClickHouseClient from neo4j import GraphDatabase _worker_dataset = None _worker_return_class_map = None _worker_quality_scores_map = None _worker_encoder = None def _to_int_list(values): if values is None: return [] if isinstance(values, torch.Tensor): return [int(v) for v in values.tolist()] return [int(v) for v in values] def _to_float_list(values): if values is None: return [] if isinstance(values, torch.Tensor): return [float(v) for v in values.tolist()] return [float(v) for v in values] def _representative_context_polarity(context): labels = _to_float_list(context.get("labels")) mask = _to_int_list(context.get("labels_mask")) valid_returns = [label for label, keep in zip(labels, mask) if keep > 0] if not valid_returns: return "negative" return "positive" if max(valid_returns) > 0.0 else "negative" def _class_polarity_targets(class_id, target_contexts_per_class, positive_balance_min_class, positive_ratio): if class_id >= positive_balance_min_class: positive_target = int(round(target_contexts_per_class * positive_ratio)) positive_target = min(max(positive_target, 0), target_contexts_per_class) else: positive_target = 0 return { "positive": positive_target, "negative": max(0, target_contexts_per_class - positive_target), } def _remaining_polarity_targets(class_id, accepted_counts, target_contexts_per_class, positive_balance_min_class, positive_ratio): targets = _class_polarity_targets( class_id=class_id, target_contexts_per_class=target_contexts_per_class, positive_balance_min_class=positive_balance_min_class, positive_ratio=positive_ratio, ) class_counts = accepted_counts[class_id] return { "positive": max(0, targets["positive"] - class_counts["positive"]), "negative": max(0, targets["negative"] - class_counts["negative"]), } def _select_contexts_by_polarity(contexts, max_keep, desired_positive=None, desired_negative=None): if len(contexts) <= max_keep: polarity_counts = {} for context in contexts: polarity = _representative_context_polarity(context) polarity_counts[polarity] = polarity_counts.get(polarity, 0) + 1 context["representative_context_polarity"] = polarity return contexts, polarity_counts positive_bucket = [] negative_bucket = [] for context in contexts: polarity = _representative_context_polarity(context) context["representative_context_polarity"] = polarity if polarity == "positive": positive_bucket.append(context) else: negative_bucket.append(context) selected = [] polarity_counts = {"positive": 0, "negative": 0} desired_positive = max(0, int(desired_positive)) if desired_positive is not None else None desired_negative = max(0, int(desired_negative)) if desired_negative is not None else None if desired_positive is not None or desired_negative is not None: target_positive = min(desired_positive or 0, max_keep, len(positive_bucket)) target_negative = min(desired_negative or 0, max_keep - target_positive, len(negative_bucket)) while polarity_counts["positive"] < target_positive and positive_bucket: selected.append(positive_bucket.pop()) polarity_counts["positive"] += 1 while polarity_counts["negative"] < target_negative and negative_bucket: selected.append(negative_bucket.pop()) polarity_counts["negative"] += 1 prefer_positive = len(positive_bucket) >= len(negative_bucket) while len(selected) < max_keep and (positive_bucket or negative_bucket): if prefer_positive and positive_bucket: selected.append(positive_bucket.pop()) polarity_counts["positive"] += 1 elif not prefer_positive and negative_bucket: selected.append(negative_bucket.pop()) polarity_counts["negative"] += 1 elif positive_bucket: selected.append(positive_bucket.pop()) polarity_counts["positive"] += 1 elif negative_bucket: selected.append(negative_bucket.pop()) polarity_counts["negative"] += 1 prefer_positive = not prefer_positive return selected[:max_keep], polarity_counts def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map): global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map from data.data_loader import OracleDataset from data.data_fetcher import DataFetcher clickhouse_client = ClickHouseClient(host=db_config['clickhouse_host'], port=db_config['clickhouse_port']) neo4j_driver = GraphDatabase.driver(db_config['neo4j_uri'], auth=(db_config['neo4j_user'], db_config['neo4j_password'])) data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver) # --- NEW: Init Encoder on GPU --- from models.multi_modal_processor import MultiModalEncoder # Using float16 for efficiency on GPU global _worker_encoder try: _worker_encoder = MultiModalEncoder( model_id="google/siglip-so400m-patch16-256-i18n", device="cuda", dtype=torch.float16 ) except Exception as e: print(f"WARN: Failed to initialize MultiModalEncoder on worker: {e}") _worker_encoder = None _worker_dataset = OracleDataset( data_fetcher=data_fetcher, start_date=dataset_config['start_date'], horizons_seconds=dataset_config['horizons_seconds'], quantiles=dataset_config['quantiles'], min_trade_usd=dataset_config['min_trade_usd'], max_seq_len=dataset_config['max_seq_len'], p99_clamps=dataset_config.get('p99_clamps') ) _worker_dataset.sampled_mints = dataset_config['sampled_mints'] _worker_return_class_map = return_class_map _worker_quality_scores_map = quality_scores_map def _process_single_token_context(args): idx, mint_addr, samples_per_token, oversample_factor, desired_positive, desired_negative = args global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map try: class_id = _worker_return_class_map.get(mint_addr) if class_id is None: return {'status': 'skipped', 'reason': 'not in class map', 'mint': mint_addr} # Pass the global encoder (if initialized) to pre-compute embeddings global _worker_encoder encoder = _worker_encoder # print(f"DEBUG: Worker encoder status: {type(encoder)}", flush=True) # Commented out to reduce noise if it works if encoder is None: print(f"ERROR: Worker encoder is None for mint {mint_addr}!", flush=True) candidate_contexts = _worker_dataset.__cacheitem_context__( idx, num_samples_per_token=max(samples_per_token, samples_per_token * max(1, oversample_factor)), encoder=encoder, ) contexts, polarity_counts = _select_contexts_by_polarity( candidate_contexts, samples_per_token, desired_positive=desired_positive, desired_negative=desired_negative, ) if not contexts: return {'status': 'skipped', 'reason': 'no valid contexts', 'mint': mint_addr} q_score = _worker_quality_scores_map.get(mint_addr) if q_score is None: return {'status': 'skipped', 'reason': 'no quality score', 'mint': mint_addr} return { 'status': 'success', 'mint': mint_addr, 'class_id': class_id, 'q_score': q_score, 'n_contexts': len(contexts), 'n_events': len(contexts[0].get('event_sequence', [])) if contexts else 0, 'contexts': contexts, 'polarity_counts': polarity_counts, } except Exception as e: import traceback return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()} def main(): load_dotenv() mp.set_start_method('spawn', force=True) hf_token = os.getenv("HF_TOKEN") if hf_token: print(f"INFO: Logging in to Hugging Face...") huggingface_hub.login(token=hf_token) parser = argparse.ArgumentParser() parser.add_argument("--output_dir", type=str, default="data/cache") parser.add_argument("--start_date", type=str, default=None) parser.add_argument("--min_trade_usd", type=float, default=0.0) parser.add_argument("--context_length", type=int, default=8192) parser.add_argument("--min_trades", type=int, default=10) parser.add_argument("--samples_per_token", type=int, default=1) parser.add_argument("--target_contexts_per_class", type=int, default=2500) parser.add_argument("--context_oversample_factor", type=int, default=4) parser.add_argument("--positive_balance_min_class", type=int, default=2) parser.add_argument("--positive_context_ratio", type=float, default=0.5) parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420]) parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9]) parser.add_argument("--num_workers", type=int, default=1) parser.add_argument("--clickhouse_host", type=str, default=os.getenv("CLICKHOUSE_HOST", "localhost")) parser.add_argument("--clickhouse_port", type=int, default=int(os.getenv("CLICKHOUSE_PORT", 9000))) parser.add_argument("--neo4j_uri", type=str, default=os.getenv("NEO4J_URI", "bolt://localhost:7687")) parser.add_argument("--neo4j_user", type=str, default=os.getenv("NEO4J_USER", "neo4j")) parser.add_argument("--neo4j_password", type=str, default=os.getenv("NEO4J_PASSWORD", "password")) args = parser.parse_args() if args.num_workers == 0: args.num_workers = max(1, mp.cpu_count() - 4) if args.num_workers != 1: raise RuntimeError("Quota-based caching requires --num_workers 1 so class counters remain exact.") if args.target_contexts_per_class <= 0: raise RuntimeError("--target_contexts_per_class must be positive.") output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) start_date_dt = datetime.datetime.strptime(args.start_date, "%Y-%m-%d") if args.start_date else None print(f"INFO: Initializing DB Connections...") clickhouse_client = ClickHouseClient(host=args.clickhouse_host, port=args.clickhouse_port) neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=(args.neo4j_user, args.neo4j_password)) try: from data.data_loader import OracleDataset from data.data_fetcher import DataFetcher data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver) print("INFO: Fetching Return Classification Map...") return_class_map, _ = get_return_class_map(clickhouse_client) print(f"INFO: Loaded {len(return_class_map)} classified tokens.") print("INFO: Computing P99 clamp values...") p99_clamps = compute_p99_clamps(clickhouse_client) print("INFO: Fetching Quality Scores...") quality_scores_map = get_token_quality_scores(clickhouse_client) print(f"INFO: Loaded {len(quality_scores_map)} quality scores.") dataset = OracleDataset(data_fetcher=data_fetcher, start_date=start_date_dt, horizons_seconds=args.horizons_seconds, quantiles=args.quantiles, min_trade_usd=args.min_trade_usd, max_seq_len=args.context_length, p99_clamps=p99_clamps) if len(dataset) == 0: print("WARNING: No samples. Exiting.") return # Filter mints by return_class_map original_size = len(dataset.sampled_mints) filtered_mints = [m for m in dataset.sampled_mints if m['mint_address'] in return_class_map] print(f"INFO: Filtered by class map: {original_size} -> {len(filtered_mints)} tokens") # Pre-filter: only keep tokens with >= min_trades trades (fast ClickHouse count query) print(f"INFO: Pre-filtering tokens by trade count (>= {args.min_trades} trades)...") trade_counts = clickhouse_client.execute(""" SELECT base_address, count() as cnt FROM trades GROUP BY base_address HAVING cnt >= %(min_trades)s """, {'min_trades': args.min_trades}) valid_tokens = {row[0] for row in trade_counts} pre_filter_size = len(filtered_mints) filtered_mints = [m for m in filtered_mints if m['mint_address'] in valid_tokens] print(f"INFO: Pre-filtered by trade count: {pre_filter_size} -> {len(filtered_mints)} tokens (removed {pre_filter_size - len(filtered_mints)} with < {args.min_trades} trades)") # Also filter by quality score availability pre_quality_size = len(filtered_mints) filtered_mints = [m for m in filtered_mints if m['mint_address'] in quality_scores_map] print(f"INFO: Filtered by quality score: {pre_quality_size} -> {len(filtered_mints)} tokens") if len(filtered_mints) == 0: print("WARNING: No tokens after filtering.") return print(f"INFO: Workers: {args.num_workers}") db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password} dataset_config = {'start_date': start_date_dt, 'horizons_seconds': args.horizons_seconds, 'quantiles': args.quantiles, 'min_trade_usd': args.min_trade_usd, 'max_seq_len': args.context_length, 'sampled_mints': filtered_mints, 'p99_clamps': p99_clamps} import random eligible_class_counts = Counter() for i, m in enumerate(filtered_mints): cid = return_class_map.get(m['mint_address']) if cid is not None: eligible_class_counts[cid] += 1 print(f"INFO: Eligible tokens per class: {dict(sorted(eligible_class_counts.items()))}") class_targets = { int(class_id): int(args.target_contexts_per_class) for class_id in sorted(eligible_class_counts.keys()) } class_polarity_targets = { class_id: _class_polarity_targets( class_id=class_id, target_contexts_per_class=args.target_contexts_per_class, positive_balance_min_class=args.positive_balance_min_class, positive_ratio=args.positive_context_ratio, ) for class_id in class_targets } target_total = args.target_contexts_per_class * len(class_targets) print(f"INFO: Target total: {target_total}, Target per class: {args.target_contexts_per_class}") print(f"INFO: Exact class targets: {dict(sorted(class_targets.items()))}") print(f"INFO: Class polarity targets: {dict(sorted(class_polarity_targets.items()))}") tasks = list(enumerate(filtered_mints)) random.shuffle(tasks) print(f"INFO: Total candidate tokens: {len(tasks)}") success_count, skipped_count, error_count = 0, 0, 0 class_distribution = defaultdict(int) polarity_distribution = defaultdict(int) file_class_map = {} file_context_bucket_map = {} file_context_summary_map = {} accepted_counts = defaultdict(lambda: {"total": 0, "positive": 0, "negative": 0}) # Resume support: count existing files toward quotas. existing_files = set(f.name for f in output_dir.glob("sample_*.pt")) if existing_files: already_cached = 0 for f in sorted(output_dir.glob("sample_*.pt")): try: cached = torch.load(f, map_location="cpu", weights_only=False) except Exception: continue class_id = cached.get("class_id") if class_id is None or int(class_id) not in class_targets: continue class_id = int(class_id) context_summary = summarize_context_window(cached.get("labels"), cached.get("labels_mask")) polarity = _representative_context_polarity(cached) file_class_map[f.name] = class_id file_context_bucket_map[f.name] = context_summary["context_bucket"] file_context_summary_map[f.name] = context_summary class_distribution[class_id] += 1 polarity_distribution[polarity] += 1 accepted_counts[class_id]["total"] += 1 accepted_counts[class_id][polarity] += 1 already_cached += 1 print(f"INFO: Resume: counted {already_cached} cached contexts toward quotas.") print(f"INFO: Starting to cache {len(tasks)} tokens...") process_fn = _process_single_token_context import time as _time def _log_progress(task_num, total, start_time, recent_times, success_count, skipped_count, error_count): """Print progress with rolling ETA every 10 tokens.""" if (task_num + 1) % 10 == 0 and recent_times: avg_time = sum(recent_times) / len(recent_times) remaining = total - (task_num + 1) eta_seconds = avg_time * remaining eta_hours = eta_seconds / 3600 wall_elapsed = _time.perf_counter() - start_time speed = (task_num + 1) / wall_elapsed tqdm.write( f" [PROGRESS] {task_num+1}/{total} | " f"Speed: {speed:.1f} tok/s ({speed*60:.0f} tok/min) | " f"Avg: {avg_time:.1f}s/tok | " f"ETA: {eta_hours:.1f}h | " f"OK: {success_count} Skip: {skipped_count} Err: {error_count}" ) # Error log file for diagnosing failures error_log_path = Path(args.output_dir) / "cache_errors.log" error_samples = [] # First 20 unique error messages print("INFO: Single-threaded mode...") _init_worker(db_config, dataset_config, return_class_map, quality_scores_map) start_time = _time.perf_counter() recent_times = [] completed_classes = set() for task_num, (idx, mint_record) in enumerate(tqdm(tasks, desc="Caching", unit="tok")): mint_addr = mint_record["mint_address"] class_id = return_class_map.get(mint_addr) if class_id is None: skipped_count += 1 _log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count) continue if accepted_counts[class_id]["total"] >= class_targets[class_id]: if class_id not in completed_classes: completed_classes.add(class_id) tqdm.write(f"INFO: Class {class_id} quota filled. Skipping remaining tokens for this class.") skipped_count += 1 _log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count) continue remaining_total = class_targets[class_id] - accepted_counts[class_id]["total"] remaining_polarity = _remaining_polarity_targets( class_id=class_id, accepted_counts=accepted_counts, target_contexts_per_class=args.target_contexts_per_class, positive_balance_min_class=args.positive_balance_min_class, positive_ratio=args.positive_context_ratio, ) desired_positive = min(remaining_polarity["positive"], args.samples_per_token, remaining_total) desired_negative = min( remaining_polarity["negative"], max(0, min(args.samples_per_token, remaining_total) - desired_positive), ) samples_to_keep = min(args.samples_per_token, remaining_total) task = ( idx, mint_addr, samples_to_keep, args.context_oversample_factor, desired_positive, desired_negative, ) t0 = _time.perf_counter() result = process_fn(task) elapsed = _time.perf_counter() - t0 recent_times.append(elapsed) if len(recent_times) > 50: recent_times.pop(0) if result["status"] == "success": saved_contexts = 0 for ctx in result.get("contexts", []): if accepted_counts[class_id]["total"] >= class_targets[class_id]: break polarity = _representative_context_polarity(ctx) remaining_polarity = _remaining_polarity_targets( class_id=class_id, accepted_counts=accepted_counts, target_contexts_per_class=args.target_contexts_per_class, positive_balance_min_class=args.positive_balance_min_class, positive_ratio=args.positive_context_ratio, ) other_polarity = "negative" if polarity == "positive" else "positive" if remaining_polarity[polarity] <= 0 and remaining_polarity[other_polarity] > 0: continue ctx["quality_score"] = result["q_score"] ctx["class_id"] = class_id ctx["source_token"] = mint_addr context_summary = summarize_context_window(ctx.get("labels"), ctx.get("labels_mask")) ctx["context_bucket"] = context_summary["context_bucket"] ctx["context_score"] = context_summary["context_score"] file_idx = accepted_counts[class_id]["total"] filename = f"sample_{mint_addr[:16]}_{file_idx}.pt" output_path = output_dir / filename torch.save(ctx, output_path) file_class_map[filename] = class_id file_context_bucket_map[filename] = context_summary["context_bucket"] file_context_summary_map[filename] = context_summary class_distribution[class_id] += 1 polarity_distribution[polarity] += 1 accepted_counts[class_id]["total"] += 1 accepted_counts[class_id][polarity] += 1 saved_contexts += 1 if saved_contexts > 0: success_count += 1 else: skipped_count += 1 elif result["status"] == "skipped": skipped_count += 1 else: error_count += 1 err_msg = result.get("error", "unknown") tqdm.write(f"ERROR: {result['mint'][:16]} - {err_msg}") if len(error_samples) < 20: error_samples.append({'mint': result.get('mint'), 'error': err_msg, 'traceback': result.get('traceback', '')}) if all(accepted_counts[cid]["total"] >= class_targets[cid] for cid in class_targets): tqdm.write("INFO: All class quotas filled. Stopping early.") _log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count) break _log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count) # Write error log if error_samples: with open(error_log_path, 'w') as ef: for i, es in enumerate(error_samples): ef.write(f"=== Error {i+1} === Token: {es['mint']}\n") ef.write(f"Error: {es['error']}\n") ef.write(f"Traceback:\n{es['traceback']}\n\n") print(f"INFO: First {len(error_samples)} error tracebacks saved to {error_log_path}") print("INFO: Building metadata...") with open(output_dir / "class_metadata.json", 'w') as f: json.dump({ 'file_class_map': file_class_map, 'file_context_bucket_map': file_context_bucket_map, 'file_context_summary_map': file_context_summary_map, 'class_distribution': {str(k): v for k, v in class_distribution.items()}, 'num_workers': args.num_workers, 'horizons_seconds': args.horizons_seconds, 'quantiles': args.quantiles, 'target_total': target_total, 'target_contexts_per_class': args.target_contexts_per_class, 'context_polarity_distribution': polarity_distribution, 'class_targets': {str(k): v for k, v in class_targets.items()}, 'class_polarity_targets': {str(k): v for k, v in class_polarity_targets.items()}, 'accepted_counts': {str(k): v for k, v in accepted_counts.items()}, 'positive_balance_min_class': args.positive_balance_min_class, 'positive_context_ratio': args.positive_context_ratio, 'quant_feature_version': FEATURE_VERSION, }, f, indent=2) print(f"\n--- Done ---\nSuccess: {success_count}, Skipped: {skipped_count}, Errors: {error_count}\nFiles: {len(file_class_map)}\nLocation: {output_dir.resolve()}") finally: clickhouse_client.disconnect() neo4j_driver.close() if __name__ == "__main__": main()