# oracle/scripts/cache_dataset.py
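"""Cache per-token context windows to ``.pt`` files with exact per-class quotas.

Candidate contexts are oversampled per token, balanced between positive and
negative representative returns, annotated with quality/context scores, and
written to ``--output_dir`` alongside ``class_metadata.json``. Resuming a run
counts existing files toward the quotas. A typical invocation (hosts and paths
are illustrative):

    python scripts/cache_dataset.py --output_dir data/cache --num_workers 1
"""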
import os
import sys
import argparse
import datetime
import torch
import json
from pathlib import Path
from tqdm import tqdm
from dotenv import load_dotenv
import huggingface_hub
import logging
import multiprocessing as mp
from collections import Counter, defaultdict
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from scripts.analyze_distribution import get_return_class_map, compute_p99_clamps
from scripts.compute_quality_score import get_token_quality_scores
from data.data_loader import summarize_context_window
from data.quant_ohlc_feature_schema import FEATURE_VERSION
from clickhouse_driver import Client as ClickHouseClient
from neo4j import GraphDatabase
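# Per-process worker state, populated by _init_worker under the spawn start method.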
_worker_dataset = None
_worker_return_class_map = None
_worker_quality_scores_map = None
_worker_encoder = None
def _to_int_list(values):
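    """Coerce a tensor or iterable to a plain list of ints; None becomes []."""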
if values is None:
return []
if isinstance(values, torch.Tensor):
return [int(v) for v in values.tolist()]
return [int(v) for v in values]
def _to_float_list(values):
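    """Coerce a tensor or iterable to a plain list of floats; None becomes []."""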
if values is None:
return []
if isinstance(values, torch.Tensor):
return [float(v) for v in values.tolist()]
return [float(v) for v in values]
def _representative_context_polarity(context):
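    """Label a context "positive" if any unmasked horizon return exceeds zero."""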
labels = _to_float_list(context.get("labels"))
mask = _to_int_list(context.get("labels_mask"))
valid_returns = [label for label, keep in zip(labels, mask) if keep > 0]
if not valid_returns:
return "negative"
return "positive" if max(valid_returns) > 0.0 else "negative"
def _class_polarity_targets(class_id, target_contexts_per_class, positive_balance_min_class, positive_ratio):
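    """Split a class's context quota into positive/negative targets.

    Classes below ``positive_balance_min_class`` get a negative-only quota.
    """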
if class_id >= positive_balance_min_class:
positive_target = int(round(target_contexts_per_class * positive_ratio))
positive_target = min(max(positive_target, 0), target_contexts_per_class)
else:
positive_target = 0
return {
"positive": positive_target,
"negative": max(0, target_contexts_per_class - positive_target),
}
def _remaining_polarity_targets(class_id, accepted_counts, target_contexts_per_class, positive_balance_min_class, positive_ratio):
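    """Return how many positive/negative contexts class ``class_id`` still needs."""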
targets = _class_polarity_targets(
class_id=class_id,
target_contexts_per_class=target_contexts_per_class,
positive_balance_min_class=positive_balance_min_class,
positive_ratio=positive_ratio,
)
class_counts = accepted_counts[class_id]
return {
"positive": max(0, targets["positive"] - class_counts["positive"]),
"negative": max(0, targets["negative"] - class_counts["negative"]),
}
def _select_contexts_by_polarity(contexts, max_keep, desired_positive=None, desired_negative=None):
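    """Select up to ``max_keep`` contexts, balancing positive/negative polarity.

    Explicit ``desired_*`` quotas are filled first; leftover slots alternate
    between buckets, starting with the larger one. Returns the selected
    contexts and a polarity-count dict. Example:

    >>> ctxs = [
    ...     {"labels": [0.2], "labels_mask": [1]},
    ...     {"labels": [-0.1], "labels_mask": [1]},
    ...     {"labels": [0.5], "labels_mask": [1]},
    ... ]
    >>> _, counts = _select_contexts_by_polarity(ctxs, 2, desired_positive=1, desired_negative=1)
    >>> counts
    {'positive': 1, 'negative': 1}
    """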
if len(contexts) <= max_keep:
polarity_counts = {}
for context in contexts:
polarity = _representative_context_polarity(context)
polarity_counts[polarity] = polarity_counts.get(polarity, 0) + 1
context["representative_context_polarity"] = polarity
return contexts, polarity_counts
positive_bucket = []
negative_bucket = []
for context in contexts:
polarity = _representative_context_polarity(context)
context["representative_context_polarity"] = polarity
if polarity == "positive":
positive_bucket.append(context)
else:
negative_bucket.append(context)
selected = []
polarity_counts = {"positive": 0, "negative": 0}
desired_positive = max(0, int(desired_positive)) if desired_positive is not None else None
desired_negative = max(0, int(desired_negative)) if desired_negative is not None else None
if desired_positive is not None or desired_negative is not None:
target_positive = min(desired_positive or 0, max_keep, len(positive_bucket))
target_negative = min(desired_negative or 0, max_keep - target_positive, len(negative_bucket))
while polarity_counts["positive"] < target_positive and positive_bucket:
selected.append(positive_bucket.pop())
polarity_counts["positive"] += 1
while polarity_counts["negative"] < target_negative and negative_bucket:
selected.append(negative_bucket.pop())
polarity_counts["negative"] += 1
prefer_positive = len(positive_bucket) >= len(negative_bucket)
while len(selected) < max_keep and (positive_bucket or negative_bucket):
if prefer_positive and positive_bucket:
selected.append(positive_bucket.pop())
polarity_counts["positive"] += 1
elif not prefer_positive and negative_bucket:
selected.append(negative_bucket.pop())
polarity_counts["negative"] += 1
elif positive_bucket:
selected.append(positive_bucket.pop())
polarity_counts["positive"] += 1
elif negative_bucket:
selected.append(negative_bucket.pop())
polarity_counts["negative"] += 1
prefer_positive = not prefer_positive
return selected[:max_keep], polarity_counts
def _init_worker(db_config, dataset_config, return_class_map, quality_scores_map):
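    """Build per-process state: DB clients, the GPU encoder, and the dataset.

    Results are stored in the module-level ``_worker_*`` globals that
    ``_process_single_token_context`` reads.
    """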
    global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map, _worker_encoder
from data.data_loader import OracleDataset
from data.data_fetcher import DataFetcher
clickhouse_client = ClickHouseClient(host=db_config['clickhouse_host'], port=db_config['clickhouse_port'])
neo4j_driver = GraphDatabase.driver(db_config['neo4j_uri'], auth=(db_config['neo4j_user'], db_config['neo4j_password']))
data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
    # Initialize the multimodal encoder on GPU; float16 keeps memory use modest.
    from models.multi_modal_processor import MultiModalEncoder
try:
_worker_encoder = MultiModalEncoder(
model_id="google/siglip-so400m-patch16-256-i18n",
device="cuda",
dtype=torch.float16
)
except Exception as e:
print(f"WARN: Failed to initialize MultiModalEncoder on worker: {e}")
_worker_encoder = None
_worker_dataset = OracleDataset(
data_fetcher=data_fetcher,
start_date=dataset_config['start_date'],
horizons_seconds=dataset_config['horizons_seconds'],
quantiles=dataset_config['quantiles'],
min_trade_usd=dataset_config['min_trade_usd'],
max_seq_len=dataset_config['max_seq_len'],
p99_clamps=dataset_config.get('p99_clamps')
)
_worker_dataset.sampled_mints = dataset_config['sampled_mints']
_worker_return_class_map = return_class_map
_worker_quality_scores_map = quality_scores_map
def _process_single_token_context(args):
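    """Fetch oversampled candidate contexts for one token, then trim by polarity.

    Returns a dict whose ``status`` is ``success`` (with contexts and metadata),
    ``skipped`` (with a reason), or ``error`` (with a traceback).
    """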
idx, mint_addr, samples_per_token, oversample_factor, desired_positive, desired_negative = args
global _worker_dataset, _worker_return_class_map, _worker_quality_scores_map
try:
class_id = _worker_return_class_map.get(mint_addr)
if class_id is None:
return {'status': 'skipped', 'reason': 'not in class map', 'mint': mint_addr}
        # Use the worker-global encoder (if initialized) to pre-compute embeddings.
        encoder = _worker_encoder
if encoder is None:
print(f"ERROR: Worker encoder is None for mint {mint_addr}!", flush=True)
candidate_contexts = _worker_dataset.__cacheitem_context__(
idx,
            num_samples_per_token=samples_per_token * max(1, oversample_factor),
encoder=encoder,
)
contexts, polarity_counts = _select_contexts_by_polarity(
candidate_contexts,
samples_per_token,
desired_positive=desired_positive,
desired_negative=desired_negative,
)
if not contexts:
return {'status': 'skipped', 'reason': 'no valid contexts', 'mint': mint_addr}
q_score = _worker_quality_scores_map.get(mint_addr)
if q_score is None:
return {'status': 'skipped', 'reason': 'no quality score', 'mint': mint_addr}
return {
'status': 'success',
'mint': mint_addr,
'class_id': class_id,
'q_score': q_score,
'n_contexts': len(contexts),
'n_events': len(contexts[0].get('event_sequence', [])) if contexts else 0,
'contexts': contexts,
'polarity_counts': polarity_counts,
}
except Exception as e:
import traceback
return {'status': 'error', 'mint': mint_addr, 'error': str(e), 'traceback': traceback.format_exc()}
def main():
load_dotenv()
mp.set_start_method('spawn', force=True)
hf_token = os.getenv("HF_TOKEN")
if hf_token:
print(f"INFO: Logging in to Hugging Face...")
huggingface_hub.login(token=hf_token)
parser = argparse.ArgumentParser()
parser.add_argument("--output_dir", type=str, default="data/cache")
parser.add_argument("--start_date", type=str, default=None)
parser.add_argument("--min_trade_usd", type=float, default=0.0)
parser.add_argument("--context_length", type=int, default=8192)
parser.add_argument("--min_trades", type=int, default=10)
parser.add_argument("--samples_per_token", type=int, default=1)
parser.add_argument("--target_contexts_per_class", type=int, default=2500)
parser.add_argument("--context_oversample_factor", type=int, default=4)
parser.add_argument("--positive_balance_min_class", type=int, default=2)
parser.add_argument("--positive_context_ratio", type=float, default=0.5)
parser.add_argument("--horizons_seconds", type=int, nargs="+", default=[30, 60, 120, 240, 420])
parser.add_argument("--quantiles", type=float, nargs="+", default=[0.1, 0.5, 0.9])
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--clickhouse_host", type=str, default=os.getenv("CLICKHOUSE_HOST", "localhost"))
parser.add_argument("--clickhouse_port", type=int, default=int(os.getenv("CLICKHOUSE_PORT", 9000)))
parser.add_argument("--neo4j_uri", type=str, default=os.getenv("NEO4J_URI", "bolt://localhost:7687"))
parser.add_argument("--neo4j_user", type=str, default=os.getenv("NEO4J_USER", "neo4j"))
parser.add_argument("--neo4j_password", type=str, default=os.getenv("NEO4J_PASSWORD", "password"))
args = parser.parse_args()
    if args.num_workers == 0:
        args.num_workers = max(1, mp.cpu_count() - 4)
    # Class-quota counters live in this single process; parallel workers would
    # race on them, so only single-worker mode is supported.
    if args.num_workers != 1:
        raise RuntimeError("Quota-based caching requires --num_workers 1 so class counters remain exact.")
if args.target_contexts_per_class <= 0:
raise RuntimeError("--target_contexts_per_class must be positive.")
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
start_date_dt = datetime.datetime.strptime(args.start_date, "%Y-%m-%d") if args.start_date else None
print(f"INFO: Initializing DB Connections...")
clickhouse_client = ClickHouseClient(host=args.clickhouse_host, port=args.clickhouse_port)
neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=(args.neo4j_user, args.neo4j_password))
try:
from data.data_loader import OracleDataset
from data.data_fetcher import DataFetcher
data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
print("INFO: Fetching Return Classification Map...")
return_class_map, _ = get_return_class_map(clickhouse_client)
print(f"INFO: Loaded {len(return_class_map)} classified tokens.")
print("INFO: Computing P99 clamp values...")
p99_clamps = compute_p99_clamps(clickhouse_client)
print("INFO: Fetching Quality Scores...")
quality_scores_map = get_token_quality_scores(clickhouse_client)
print(f"INFO: Loaded {len(quality_scores_map)} quality scores.")
dataset = OracleDataset(data_fetcher=data_fetcher, start_date=start_date_dt, horizons_seconds=args.horizons_seconds, quantiles=args.quantiles, min_trade_usd=args.min_trade_usd, max_seq_len=args.context_length, p99_clamps=p99_clamps)
if len(dataset) == 0:
print("WARNING: No samples. Exiting.")
return
# Filter mints by return_class_map
original_size = len(dataset.sampled_mints)
filtered_mints = [m for m in dataset.sampled_mints if m['mint_address'] in return_class_map]
print(f"INFO: Filtered by class map: {original_size} -> {len(filtered_mints)} tokens")
# Pre-filter: only keep tokens with >= min_trades trades (fast ClickHouse count query)
print(f"INFO: Pre-filtering tokens by trade count (>= {args.min_trades} trades)...")
trade_counts = clickhouse_client.execute("""
SELECT base_address, count() as cnt
FROM trades
GROUP BY base_address
HAVING cnt >= %(min_trades)s
""", {'min_trades': args.min_trades})
valid_tokens = {row[0] for row in trade_counts}
pre_filter_size = len(filtered_mints)
filtered_mints = [m for m in filtered_mints if m['mint_address'] in valid_tokens]
print(f"INFO: Pre-filtered by trade count: {pre_filter_size} -> {len(filtered_mints)} tokens (removed {pre_filter_size - len(filtered_mints)} with < {args.min_trades} trades)")
# Also filter by quality score availability
pre_quality_size = len(filtered_mints)
filtered_mints = [m for m in filtered_mints if m['mint_address'] in quality_scores_map]
print(f"INFO: Filtered by quality score: {pre_quality_size} -> {len(filtered_mints)} tokens")
if len(filtered_mints) == 0:
print("WARNING: No tokens after filtering.")
return
print(f"INFO: Workers: {args.num_workers}")
db_config = {'clickhouse_host': args.clickhouse_host, 'clickhouse_port': args.clickhouse_port, 'neo4j_uri': args.neo4j_uri, 'neo4j_user': args.neo4j_user, 'neo4j_password': args.neo4j_password}
dataset_config = {'start_date': start_date_dt, 'horizons_seconds': args.horizons_seconds, 'quantiles': args.quantiles, 'min_trade_usd': args.min_trade_usd, 'max_seq_len': args.context_length, 'sampled_mints': filtered_mints, 'p99_clamps': p99_clamps}
import random
eligible_class_counts = Counter()
for i, m in enumerate(filtered_mints):
cid = return_class_map.get(m['mint_address'])
if cid is not None:
eligible_class_counts[cid] += 1
print(f"INFO: Eligible tokens per class: {dict(sorted(eligible_class_counts.items()))}")
class_targets = {
int(class_id): int(args.target_contexts_per_class)
for class_id in sorted(eligible_class_counts.keys())
}
class_polarity_targets = {
class_id: _class_polarity_targets(
class_id=class_id,
target_contexts_per_class=args.target_contexts_per_class,
positive_balance_min_class=args.positive_balance_min_class,
positive_ratio=args.positive_context_ratio,
)
for class_id in class_targets
}
target_total = args.target_contexts_per_class * len(class_targets)
print(f"INFO: Target total: {target_total}, Target per class: {args.target_contexts_per_class}")
print(f"INFO: Exact class targets: {dict(sorted(class_targets.items()))}")
print(f"INFO: Class polarity targets: {dict(sorted(class_polarity_targets.items()))}")
tasks = list(enumerate(filtered_mints))
random.shuffle(tasks)
print(f"INFO: Total candidate tokens: {len(tasks)}")
success_count, skipped_count, error_count = 0, 0, 0
class_distribution = defaultdict(int)
polarity_distribution = defaultdict(int)
file_class_map = {}
file_context_bucket_map = {}
file_context_summary_map = {}
accepted_counts = defaultdict(lambda: {"total": 0, "positive": 0, "negative": 0})
# Resume support: count existing files toward quotas.
        existing_files = sorted(output_dir.glob("sample_*.pt"))
        if existing_files:
            already_cached = 0
            for f in existing_files:
try:
cached = torch.load(f, map_location="cpu", weights_only=False)
except Exception:
continue
class_id = cached.get("class_id")
if class_id is None or int(class_id) not in class_targets:
continue
class_id = int(class_id)
context_summary = summarize_context_window(cached.get("labels"), cached.get("labels_mask"))
polarity = _representative_context_polarity(cached)
file_class_map[f.name] = class_id
file_context_bucket_map[f.name] = context_summary["context_bucket"]
file_context_summary_map[f.name] = context_summary
class_distribution[class_id] += 1
polarity_distribution[polarity] += 1
accepted_counts[class_id]["total"] += 1
accepted_counts[class_id][polarity] += 1
already_cached += 1
print(f"INFO: Resume: counted {already_cached} cached contexts toward quotas.")
print(f"INFO: Starting to cache {len(tasks)} tokens...")
process_fn = _process_single_token_context
import time as _time
def _log_progress(task_num, total, start_time, recent_times, success_count, skipped_count, error_count):
"""Print progress with rolling ETA every 10 tokens."""
if (task_num + 1) % 10 == 0 and recent_times:
avg_time = sum(recent_times) / len(recent_times)
remaining = total - (task_num + 1)
eta_seconds = avg_time * remaining
eta_hours = eta_seconds / 3600
wall_elapsed = _time.perf_counter() - start_time
speed = (task_num + 1) / wall_elapsed
tqdm.write(
f" [PROGRESS] {task_num+1}/{total} | "
f"Speed: {speed:.1f} tok/s ({speed*60:.0f} tok/min) | "
f"Avg: {avg_time:.1f}s/tok | "
f"ETA: {eta_hours:.1f}h | "
f"OK: {success_count} Skip: {skipped_count} Err: {error_count}"
)
# Error log file for diagnosing failures
error_log_path = Path(args.output_dir) / "cache_errors.log"
        error_samples = []  # First 20 error records (message + traceback)
print("INFO: Single-threaded mode...")
_init_worker(db_config, dataset_config, return_class_map, quality_scores_map)
start_time = _time.perf_counter()
recent_times = []
completed_classes = set()
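        # Greedy pass over shuffled tokens: fill each class's quota (respecting
        # the per-polarity split) and stop once every class is full.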
for task_num, (idx, mint_record) in enumerate(tqdm(tasks, desc="Caching", unit="tok")):
mint_addr = mint_record["mint_address"]
class_id = return_class_map.get(mint_addr)
if class_id is None:
skipped_count += 1
_log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
continue
if accepted_counts[class_id]["total"] >= class_targets[class_id]:
if class_id not in completed_classes:
completed_classes.add(class_id)
tqdm.write(f"INFO: Class {class_id} quota filled. Skipping remaining tokens for this class.")
skipped_count += 1
_log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
continue
remaining_total = class_targets[class_id] - accepted_counts[class_id]["total"]
remaining_polarity = _remaining_polarity_targets(
class_id=class_id,
accepted_counts=accepted_counts,
target_contexts_per_class=args.target_contexts_per_class,
positive_balance_min_class=args.positive_balance_min_class,
positive_ratio=args.positive_context_ratio,
)
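            # Cap this token's request by both polarity and total remaining quota.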
desired_positive = min(remaining_polarity["positive"], args.samples_per_token, remaining_total)
desired_negative = min(
remaining_polarity["negative"],
max(0, min(args.samples_per_token, remaining_total) - desired_positive),
)
samples_to_keep = min(args.samples_per_token, remaining_total)
task = (
idx,
mint_addr,
samples_to_keep,
args.context_oversample_factor,
desired_positive,
desired_negative,
)
t0 = _time.perf_counter()
result = process_fn(task)
elapsed = _time.perf_counter() - t0
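            # Rolling window of the last 50 per-token timings feeds the ETA.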
recent_times.append(elapsed)
if len(recent_times) > 50:
recent_times.pop(0)
if result["status"] == "success":
saved_contexts = 0
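                # Accept contexts until the class quota (and polarity split) fills.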
for ctx in result.get("contexts", []):
if accepted_counts[class_id]["total"] >= class_targets[class_id]:
break
polarity = _representative_context_polarity(ctx)
remaining_polarity = _remaining_polarity_targets(
class_id=class_id,
accepted_counts=accepted_counts,
target_contexts_per_class=args.target_contexts_per_class,
positive_balance_min_class=args.positive_balance_min_class,
positive_ratio=args.positive_context_ratio,
)
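                    # This polarity is full but the other still has demand;
                    # skip so a later context can fill the remaining slot.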
other_polarity = "negative" if polarity == "positive" else "positive"
if remaining_polarity[polarity] <= 0 and remaining_polarity[other_polarity] > 0:
continue
ctx["quality_score"] = result["q_score"]
ctx["class_id"] = class_id
ctx["source_token"] = mint_addr
context_summary = summarize_context_window(ctx.get("labels"), ctx.get("labels_mask"))
ctx["context_bucket"] = context_summary["context_bucket"]
ctx["context_score"] = context_summary["context_score"]
file_idx = accepted_counts[class_id]["total"]
filename = f"sample_{mint_addr[:16]}_{file_idx}.pt"
output_path = output_dir / filename
torch.save(ctx, output_path)
file_class_map[filename] = class_id
file_context_bucket_map[filename] = context_summary["context_bucket"]
file_context_summary_map[filename] = context_summary
class_distribution[class_id] += 1
polarity_distribution[polarity] += 1
accepted_counts[class_id]["total"] += 1
accepted_counts[class_id][polarity] += 1
saved_contexts += 1
if saved_contexts > 0:
success_count += 1
else:
skipped_count += 1
elif result["status"] == "skipped":
skipped_count += 1
else:
error_count += 1
err_msg = result.get("error", "unknown")
tqdm.write(f"ERROR: {result['mint'][:16]} - {err_msg}")
if len(error_samples) < 20:
error_samples.append({'mint': result.get('mint'), 'error': err_msg, 'traceback': result.get('traceback', '')})
if all(accepted_counts[cid]["total"] >= class_targets[cid] for cid in class_targets):
tqdm.write("INFO: All class quotas filled. Stopping early.")
_log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
break
_log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
# Write error log
if error_samples:
with open(error_log_path, 'w') as ef:
for i, es in enumerate(error_samples):
ef.write(f"=== Error {i+1} === Token: {es['mint']}\n")
ef.write(f"Error: {es['error']}\n")
ef.write(f"Traceback:\n{es['traceback']}\n\n")
print(f"INFO: First {len(error_samples)} error tracebacks saved to {error_log_path}")
print("INFO: Building metadata...")
with open(output_dir / "class_metadata.json", 'w') as f:
json.dump({
'file_class_map': file_class_map,
'file_context_bucket_map': file_context_bucket_map,
'file_context_summary_map': file_context_summary_map,
'class_distribution': {str(k): v for k, v in class_distribution.items()},
'num_workers': args.num_workers,
'horizons_seconds': args.horizons_seconds,
'quantiles': args.quantiles,
'target_total': target_total,
'target_contexts_per_class': args.target_contexts_per_class,
'context_polarity_distribution': polarity_distribution,
'class_targets': {str(k): v for k, v in class_targets.items()},
'class_polarity_targets': {str(k): v for k, v in class_polarity_targets.items()},
'accepted_counts': {str(k): v for k, v in accepted_counts.items()},
'positive_balance_min_class': args.positive_balance_min_class,
'positive_context_ratio': args.positive_context_ratio,
'quant_feature_version': FEATURE_VERSION,
}, f, indent=2)
print(f"\n--- Done ---\nSuccess: {success_count}, Skipped: {skipped_count}, Errors: {error_count}\nFiles: {len(file_class_map)}\nLocation: {output_dir.resolve()}")
finally:
clickhouse_client.disconnect()
neo4j_driver.close()
if __name__ == "__main__":
main()