# Source: oracle/data/preprocess_distribution.py
# Provenance: uploaded by zirobtc via huggingface_hub (commit 858826c)
#!/usr/bin/env python3
"""
Preprocess distribution statistics for OHLC normalization and token history coverage.
This script:
1. Computes global mean/std figures for price/volume so downstream code can normalize.
2. Prints descriptive stats about how much price history (in seconds) each token has,
helping decide which horizons are realistic.
All configuration is done via environment variables (see below).
"""
import os
import pathlib
import sys
from typing import List
import numpy as np
import clickhouse_connect
# --- Configuration (override via env vars if needed) ---
# ClickHouse connection settings (8123 is the default HTTP interface port).
CLICKHOUSE_HOST = os.getenv("CLICKHOUSE_HOST", "localhost")
CLICKHOUSE_PORT = int(os.getenv("CLICKHOUSE_PORT", "8123"))
CLICKHOUSE_USERNAME = os.getenv("CLICKHOUSE_USERNAME", "default")
CLICKHOUSE_PASSWORD = os.getenv("CLICKHOUSE_PASSWORD", "")
CLICKHOUSE_DATABASE = os.getenv("CLICKHOUSE_DATABASE", "default")
# Destination for the computed normalization statistics (.npz archive).
OUTPUT_PATH = pathlib.Path(os.getenv("OHLC_STATS_PATH", "ohlc_stats.npz"))
# Trades at or below these thresholds are excluded (strict > comparison in SQL),
# so the 0.0 defaults already drop zero-price / zero-volume trades.
MIN_PRICE_USD = float(os.getenv("OHLC_MIN_PRICE_USD", "0.0"))
MIN_VOLUME_USD = float(os.getenv("OHLC_MIN_VOLUME_USD", "0.0"))
# Optional comma-separated allowlist of token (base/mint) addresses.
# Blank entries are discarded; an empty result collapses to None (= all tokens).
TOKEN_ADDRESSES_ENV = os.getenv("OHLC_TOKEN_ADDRESSES", "")
TOKEN_ADDRESSES = tuple(addr.strip() for addr in TOKEN_ADDRESSES_ENV.split(",") if addr.strip()) or None
def build_where_clause(token_addresses=None) -> List[str]:
    """Build the list of SQL WHERE clauses shared by the stats and history queries.

    Args:
        token_addresses: Optional sequence of token (base/mint) addresses to
            restrict the queries to. Defaults to the module-level
            ``TOKEN_ADDRESSES`` (parsed from ``OHLC_TOKEN_ADDRESSES``) so
            existing no-argument callers keep their behavior.

    Returns:
        Clause strings (joined with AND by the caller) using ``%(name)s``
        placeholders that are bound through the query ``parameters`` dict.
    """
    if token_addresses is None:
        # Backward-compatible fallback to the env-var-derived module global.
        token_addresses = TOKEN_ADDRESSES
    # Strict > comparisons: zero-price / zero-volume trades are always excluded.
    clauses = ["t.price_usd > %(min_price)s", "t.total_usd > %(min_vol)s"]
    if token_addresses:
        clauses.append("t.base_address IN %(token_addresses)s")
    return clauses
def build_stats_query(where_sql: str) -> str:
    """Return SQL computing global mean/std for USD price, native price and trade value.

    The join against ``mints`` restricts aggregation to trades whose base
    address is a known mint; ``where_sql`` is the pre-joined AND filter.
    """
    return f"""
    SELECT
        AVG(t.price_usd) AS mean_price_usd,
        stddevPop(t.price_usd) AS std_price_usd,
        AVG(t.price) AS mean_price_native,
        stddevPop(t.price) AS std_price_native,
        AVG(t.total_usd) AS mean_trade_value_usd,
        stddevPop(t.total_usd) AS std_trade_value_usd
    FROM trades AS t
    INNER JOIN mints AS m
        ON m.mint_address = t.base_address
    WHERE {where_sql}
    """
def build_history_query(where_sql: str) -> str:
    """Return SQL yielding, per token, first/last trade timestamps and their span.

    ``history_seconds`` is the difference between the newest and oldest trade
    (Unix seconds) for each qualifying token; ``where_sql`` is the pre-joined
    AND filter shared with the stats query.
    """
    return f"""
    SELECT
        t.base_address AS token_address,
        toUnixTimestamp(min(t.timestamp)) AS first_ts,
        toUnixTimestamp(max(t.timestamp)) AS last_ts,
        toUnixTimestamp(max(t.timestamp)) - toUnixTimestamp(min(t.timestamp)) AS history_seconds
    FROM trades AS t
    INNER JOIN mints AS m
        ON m.mint_address = t.base_address
    WHERE {where_sql}
    GROUP BY token_address
    """
def summarize_histories(histories: np.ndarray) -> None:
    """Print descriptive statistics for per-token history spans (in seconds)."""
    if histories.size == 0:
        print("No token history stats available (no qualifying trades).")
        return

    def _fmt(seconds: float) -> str:
        # Render a duration as raw seconds plus hour/day equivalents.
        as_hours = seconds / 3600.0
        return f"{seconds:.0f}s ({as_hours:.2f}h / {as_hours / 24.0:.2f}d)"

    print("\nToken history coverage (seconds):")
    print(f" Tokens analyzed: {int(histories.size)}")
    print(f" Min history: {_fmt(histories.min())}")
    print(f" Median history: {_fmt(float(np.median(histories)))}")
    print(f" Mean history: {_fmt(histories.mean())}")
    print(f" 90th percentile: {_fmt(float(np.percentile(histories, 90)))}")
    print(f" Max history: {_fmt(histories.max())}")
def main() -> int:
    """Compute and persist OHLC normalization stats, then print history coverage.

    Connects to ClickHouse using the module-level configuration, writes the
    mean/std figures to OUTPUT_PATH as an .npz archive, and prints per-token
    history statistics to stdout.

    Returns:
        0 on success, 1 if the stats query returns no rows.
    """
    where_clauses = build_where_clause()
    # Defensive fallback: "WHERE 1" (always true) if there were no clauses.
    # In practice build_where_clause always returns at least two clauses.
    where_sql = " AND ".join(where_clauses) if where_clauses else "1"
    params: dict[str, object] = {
        # Clamp negatives to 0 so the strict > filters stay meaningful.
        "min_price": max(MIN_PRICE_USD, 0.0),
        "min_vol": max(MIN_VOLUME_USD, 0.0),
    }
    if TOKEN_ADDRESSES:
        params["token_addresses"] = TOKEN_ADDRESSES
    client = clickhouse_connect.get_client(
        host=CLICKHOUSE_HOST,
        port=CLICKHOUSE_PORT,
        username=CLICKHOUSE_USERNAME,
        password=CLICKHOUSE_PASSWORD,
        database=CLICKHOUSE_DATABASE,
    )
    # --- Price/volume stats ---
    stats_query = build_stats_query(where_sql)
    stats_result = client.query(stats_query, parameters=params)
    if not stats_result.result_rows:
        print("ERROR: Stats query returned no rows. Check filters / connectivity.", file=sys.stderr)
        return 1
    # Single aggregate row; unpack in the same column order as the SELECT list.
    (
        mean_price_usd,
        std_price_usd,
        mean_price_native,
        std_price_native,
        mean_trade_value_usd,
        std_trade_value_usd,
    ) = stats_result.result_rows[0]
    # NULL/zero aggregates fall back to mean 0 / std 1 so downstream
    # (x - mean) / std normalization stays well defined.
    stats = {
        "mean_price_usd": float(mean_price_usd or 0.0),
        "std_price_usd": float(std_price_usd or 1.0),
        "mean_price_native": float(mean_price_native or 0.0),
        "std_price_native": float(std_price_native or 1.0),
        "mean_trade_value_usd": float(mean_trade_value_usd or 0.0),
        "std_trade_value_usd": float(std_trade_value_usd or 1.0),
    }
    OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    np.savez(OUTPUT_PATH, **stats)
    print(f"Saved stats to {OUTPUT_PATH.resolve()}:")
    for key, value in stats.items():
        print(f" {key}: {value:.6f}")
    # --- Token history coverage ---
    history_query = build_history_query(where_sql)
    history_result = client.query(history_query, parameters=params)
    # row[3] is history_seconds; drop NULLs before building the array.
    history_seconds = np.array(
        [float(row[3]) for row in history_result.result_rows if row[3] is not None],
        dtype=np.float64
    )
    summarize_histories(history_seconds)
    return 0
if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit status.
    raise SystemExit(main())