# oracle/scripts/analyze_hyperparams.py
# (Provenance: uploaded via huggingface_hub by zirobtc, commit e125fa3.)
import os
import argparse
from typing import List, Optional, Sequence, Tuple
from dotenv import load_dotenv
from clickhouse_driver import Client as ClickHouseClient
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Fast SQL-based hyperparameter analysis (trades-only) for seq_len + horizons."
)
parser.add_argument("--token_address", type=str, default=None, help="Analyze a single token address.")
parser.add_argument(
"--windows_min",
type=str,
default="5,10,30,60",
help="Comma-separated trade-count windows in minutes (e.g. '5,10,30,60').",
)
parser.add_argument(
"--min_price_usd",
type=float,
default=0.0,
help="Treat trades with price_usd <= min_price_usd as invalid (default: 0.0).",
)
return parser.parse_args()
def _parse_windows(windows_min: str) -> List[int]:
out: List[int] = []
for part in (windows_min or "").split(","):
part = part.strip()
if not part:
continue
out.append(int(part))
out = sorted(set([w for w in out if w > 0]))
if not out:
raise ValueError("No valid --windows_min provided.")
return out
def _connect_clickhouse_from_env() -> ClickHouseClient:
    """Build a ClickHouse native-protocol client from CLICKHOUSE_* env vars.

    Host/port default to localhost:9000. User, password and database are
    forwarded only when the corresponding variable is set and non-empty.
    """
    kwargs = {
        "host": os.getenv("CLICKHOUSE_HOST", "localhost"),
        "port": int(os.getenv("CLICKHOUSE_NATIVE_PORT", "9000")),
    }
    optional = {
        "user": os.getenv("CLICKHOUSE_USER", None),
        "password": os.getenv("CLICKHOUSE_PASSWORD", None),
        "database": os.getenv("CLICKHOUSE_DB", None),
    }
    # Skip unset/empty values so the driver falls back to its own defaults.
    kwargs.update({key: val for key, val in optional.items() if val})
    return ClickHouseClient(**kwargs)
def _quantile_levels() -> Sequence[float]:
# Keep these aligned with the printed labels below.
return (0.25, 0.5, 0.75, 0.90, 0.95, 0.99)
def _fmt_q_tuple(q: Tuple[float, ...]) -> str:
# Labels match _quantile_levels()
labels = ["25%", "50%", "75%", "90%", "95%", "99%"]
parts = []
for lbl, v in zip(labels, q):
parts.append(f"{lbl}: {float(v):.2f}")
return " | ".join(parts)
def _print_row(prefix: str, mean_v: float, q_tuple: Tuple[float, ...], max_v: float) -> None:
    """Print a three-line stat summary: header, mean/median/max, full quantiles.

    The median is taken from q_tuple[1] (the 50% level of _quantile_levels()).
    """
    header = f"[{prefix}]"
    summary = f" Mean: {float(mean_v):.2f} | Median: {float(q_tuple[1]):.2f} | Max: {float(max_v):.2f}"
    quantiles = f" {_fmt_q_tuple(q_tuple)}"
    for line in (header, summary, quantiles):
        print(line)
def fetch_aggregated_stats_sql(
    ch: ClickHouseClient,
    windows_min: List[int],
    min_price_usd: float,
    token_address: Optional[str] = None,
) -> List[tuple]:
    """
    One ClickHouse query that computes distribution statistics directly (no per-token loop in Python).
    Returns two groups:
        - grp='all'
        - grp='subset' where trades_full > 50 and lifespan_sec > 300 (5 minutes)

    Each returned row is laid out as:
        grp, tokens,
        trades_full  (mean, quantile-tuple, max),
        lifespan_min (mean, quantile-tuple, max),
        tta_min      (mean, quantile-tuple, max),
        then (mean, quantile-tuple, max) for every window in windows_min.
    """
    q_levels = _quantile_levels()
    # Comma-separated argument list for quantilesExact(...), e.g. "0.25, 0.5, ...".
    q_levels_sql = ", ".join(str(q) for q in q_levels)
    # Inner (per-token) SELECT expressions: valid-trade counts within the first
    # N minutes after mint, one countIf per requested window.
    per_token_window_exprs = []
    # Matching outer-aggregation expressions (mean / quantiles / max per window).
    agg_window_exprs = []
    for w in windows_min:
        sec = int(w) * 60
        per_token_window_exprs.append(
            f"countIf(is_valid AND (trade_ts - mint_ts) <= {sec}) AS trades_{w}m"
        )
        agg_window_exprs.append(
            f"avg(trades_{w}m) AS trades_{w}m_mean,"
            f" quantilesExact({q_levels_sql})(trades_{w}m) AS trades_{w}m_q,"
            f" max(trades_{w}m) AS trades_{w}m_max"
        )
    params = {"min_price": float(min_price_usd)}
    token_filter = ""
    if token_address:
        # Optional narrowing to one mint; the value is passed as a bound
        # parameter (%(token)s), never interpolated into the SQL text.
        token_filter = "AND m.mint_address = %(token)s"
        params["token"] = token_address
    # Note: we pre-filter trades to only minted tokens for speed.
    # The ARRAY JOIN over ['all', 'subset'] duplicates every per-token row into
    # both groups; the following WHERE keeps all rows under 'all' but only
    # high-activity (trades_full > 50), longer-lived (> 300s) tokens under
    # 'subset'. lifespan_sec / time_to_ath_sec are measured from mint time
    # using valid trades only (maxIf / argMaxIf with is_valid).
    query = f"""
        WITH
            per_token AS (
                SELECT
                    m.mint_address AS mint_address,
                    toUnixTimestamp(m.timestamp) AS mint_ts,
                    countIf(is_valid) AS trades_full,
                    (maxIf(trade_ts, is_valid) - mint_ts) AS lifespan_sec,
                    (toUnixTimestamp(argMaxIf(t.timestamp, t.price_usd, is_valid)) - mint_ts) AS time_to_ath_sec,
                    {", ".join(per_token_window_exprs)}
                FROM mints AS m
                INNER JOIN
                (
                    SELECT
                        base_address,
                        timestamp,
                        toUnixTimestamp(timestamp) AS trade_ts,
                        price_usd,
                        (price_usd > %(min_price)s) AS is_valid
                    FROM trades
                    WHERE base_address IN (SELECT mint_address FROM mints)
                ) AS t
                ON t.base_address = m.mint_address
                WHERE 1=1
                {token_filter}
                GROUP BY
                    mint_address,
                    mint_ts
                HAVING
                    trades_full > 0
            )
        SELECT
            grp,
            count() AS tokens,
            avg(trades_full) AS trades_full_mean,
            quantilesExact({q_levels_sql})(trades_full) AS trades_full_q,
            max(trades_full) AS trades_full_max,
            avg(lifespan_sec / 60.0) AS lifespan_min_mean,
            quantilesExact({q_levels_sql})(lifespan_sec / 60.0) AS lifespan_min_q,
            max(lifespan_sec / 60.0) AS lifespan_min_max,
            avg(time_to_ath_sec / 60.0) AS tta_min_mean,
            quantilesExact({q_levels_sql})(time_to_ath_sec / 60.0) AS tta_min_q,
            max(time_to_ath_sec / 60.0) AS tta_min_max,
            {", ".join(agg_window_exprs)}
        FROM per_token
        ARRAY JOIN ['all', 'subset'] AS grp
        WHERE (grp = 'all')
           OR (grp = 'subset' AND trades_full > 50 AND lifespan_sec > 300)
        GROUP BY grp
        ORDER BY grp
    """
    return ch.execute(query, params)
def fetch_single_token_sql(
    ch: ClickHouseClient,
    windows_min: List[int],
    min_price_usd: float,
    token_address: str,
) -> Optional[tuple]:
    """
    Fetch per-token stats for a single mint address.

    The row contains: mint_address, mint_ts, trades_full (valid trades only),
    lifespan_sec, time_to_ath_sec, then one trade count per requested window.
    Returns None when the token is missing or has no valid trades
    (HAVING trades_full > 0 filters it out).
    """
    # One countIf per window: valid trades within the first N minutes of mint.
    per_token_window_exprs = []
    for w in windows_min:
        sec = int(w) * 60
        per_token_window_exprs.append(
            f"countIf(is_valid AND (trade_ts - mint_ts) <= {sec}) AS trades_{w}m"
        )
    # Token address and price threshold are bound parameters, not interpolated.
    params = {"min_price": float(min_price_usd), "token": token_address}
    # Trades are pre-filtered to the requested token inside the join subquery
    # so the join touches as few rows as possible.
    query = f"""
        SELECT
            m.mint_address AS mint_address,
            toUnixTimestamp(m.timestamp) AS mint_ts,
            countIf(is_valid) AS trades_full,
            (maxIf(trade_ts, is_valid) - mint_ts) AS lifespan_sec,
            (toUnixTimestamp(argMaxIf(t.timestamp, t.price_usd, is_valid)) - mint_ts) AS time_to_ath_sec,
            {", ".join(per_token_window_exprs)}
        FROM mints AS m
        INNER JOIN
        (
            SELECT
                base_address,
                timestamp,
                toUnixTimestamp(timestamp) AS trade_ts,
                price_usd,
                (price_usd > %(min_price)s) AS is_valid
            FROM trades
            WHERE base_address = %(token)s
        ) AS t
        ON t.base_address = m.mint_address
        WHERE m.mint_address = %(token)s
        GROUP BY
            mint_address,
            mint_ts
        HAVING
            trades_full > 0
    """
    rows = ch.execute(query, params)
    return rows[0] if rows else None
def main() -> None:
    """Entry point: load env, parse args, then run either the single-token
    report (--token_address given) or the full distribution report."""
    load_dotenv()
    args = parse_args()
    windows_min = _parse_windows(args.windows_min)
    print("--- Hyperparameter Calibration Analysis (FAST SQL) ---")
    print(f"Windows (min): {windows_min}")
    print(f"Valid trade filter: price_usd > {float(args.min_price_usd)}")
    ch = _connect_clickhouse_from_env()
    if args.token_address:
        _report_single_token(ch, windows_min, float(args.min_price_usd), args.token_address)
    else:
        _report_distribution(ch, windows_min, float(args.min_price_usd))


def _report_single_token(
    ch: "ClickHouseClient",
    windows_min: List[int],
    min_price_usd: float,
    token_address: str,
) -> None:
    """Print stats (trade counts, lifespan, time-to-ATH) for one token."""
    row = fetch_single_token_sql(
        ch=ch,
        windows_min=windows_min,
        min_price_usd=min_price_usd,
        token_address=token_address,
    )
    if not row:
        print("Token not found (or no valid trades).")
        return
    mint_addr = row[0]
    trades_full = int(row[2])
    lifespan_min = float(row[3]) / 60.0
    tta_min = float(row[4]) / 60.0
    print("\n" + "=" * 40)
    print("RESULTS (SINGLE TOKEN)")
    print("=" * 40)
    print(f"Token: {mint_addr}")
    print(f"Valid trades: {trades_full}")
    print(f"Lifespan (min): {lifespan_min:.2f}")
    print(f"Time to ATH (min): {tta_min:.2f}")
    # Window counts start at index 5, in the same order as windows_min.
    for i, w in enumerate(windows_min):
        print(f"Trades in first {w}m: {int(row[5 + i])}")


def _report_distribution(
    ch: "ClickHouseClient",
    windows_min: List[int],
    min_price_usd: float,
) -> None:
    """Print distribution stats for the 'all' and 'subset' token groups."""
    rows = fetch_aggregated_stats_sql(
        ch=ch,
        windows_min=windows_min,
        min_price_usd=min_price_usd,
        token_address=None,
    )
    if not rows:
        print("No tokens found with valid trades.")
        return
    print("\n" + "=" * 40)
    print("RESULTS (DISTRIBUTION)")
    print("=" * 40)
    # Row layout (matches fetch_aggregated_stats_sql's SELECT order):
    #   grp, tokens,
    #   trades_full_mean, trades_full_q(tuple), trades_full_max,
    #   lifespan_min_mean, lifespan_min_q(tuple), lifespan_min_max,
    #   tta_min_mean, tta_min_q(tuple), tta_min_max,
    #   repeated for each window: mean, q(tuple), max
    for row in rows:
        grp = row[0]
        tokens = int(row[1])
        print(f"\n--- Group: {grp} (tokens={tokens}) ---")
        _print_row("Trades (Full History, Valid Only)", row[2], row[3], row[4])
        print("")
        _print_row("Token Lifespan (Minutes)", row[5], row[6], row[7])
        print("")
        _print_row("Time to ATH (Minutes)", row[8], row[9], row[10])
        # Per-window triples follow the three fixed stat groups.
        cursor = 11
        for w in windows_min:
            mean_v = row[cursor]
            q_v = row[cursor + 1]
            max_v = row[cursor + 2]
            cursor += 3
            print("")
            _print_row(f"Trades in First {w} Minutes (Valid Only)", mean_v, q_v, max_v)
    print("\nRecommendation Logic (Trades-only):")
    print("- Horizons: look at Time-to-ATH p90/p95 (all vs subset).")
    print("- Max seq len: look at Trades-in-first-(max horizon) p95/p99.")
    print("  Then add headroom for non-trade events (transfers/pool/liquidity/etc).")
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()