tosi-n's picture
Upload folder using huggingface_hub
ce6b50a verified
Raw
History Blame Contribute Delete
3.07 kB
"""Load the replay tape and build the train / OOS / OOS-symbol task splits.
Shallow per-symbol history (free GeckoTerminal cap ~1000 bars) means breadth carries
the GRPO variance: a task = (symbol, window). The symbol-holdout split is the real
generalization test.
"""
from __future__ import annotations
import glob
import json
import os
from pathlib import Path
import numpy as np
import pandas as pd
# Resolve the tape relative to THIS file so it works after the env is built into a
# wheel and installed into site-packages on Prime's infra (where the repo's top-level
# data/ dir does not exist). Packaged copy is authoritative; repo dir is a dev fallback.
_HERE = Path(__file__).resolve().parent
_PKG_DATA = _HERE / "data" / "ohlcv"
# The repo checkout keeps data/ three directories above this module's folder. Indexing
# ``parents`` past the filesystem root raises IndexError, which would break the import
# when the module lives at a shallow path, so fall back to the packaged location then.
_REPO_DATA = (
    _HERE.parents[2] / "data" / "ohlcv" if len(_HERE.parents) > 2 else _PKG_DATA
)
# Use the packaged tape only when it actually contains parquet files.
DATA = _PKG_DATA if _PKG_DATA.exists() and any(_PKG_DATA.glob("*.parquet")) else _REPO_DATA
# Quote-side / blue-chip pairs that aren't tradeable "strategy" targets — excluded.
_EXCLUDE = {"USDC", "WETH", "CBBTC", "USDT", "DAI"}
def _load_series(min_bars: int = 120) -> dict[str, np.ndarray]:
    """Read the parquet files under ``DATA`` and return one close series per symbol.

    The symbol is the upper-cased part of the file name before the first ``__``.
    Symbols listed in ``_EXCLUDE`` are skipped, as are series that are shorter
    than ``min_bars`` or that contain any non-finite or non-positive close.

    When several files map to the same symbol, the one with the most bars is
    kept; on a tie the file that sorts first by path wins.

    Args:
        min_bars: Minimum number of close values a file must have to be used.

    Returns:
        Mapping of symbol to a float array read from the file's ``c`` column.
    """
    series: dict[str, np.ndarray] = {}
    for f in sorted(glob.glob(str(DATA / "*.parquet"))):
        sym = os.path.basename(f).split("__")[0].upper()
        if sym in _EXCLUDE:
            continue
        # Only the close column is used, so avoid loading the other columns.
        close = pd.read_parquet(f, columns=["c"])["c"].to_numpy(dtype=float)
        if len(close) < min_bars:
            continue
        if not (np.all(np.isfinite(close)) and np.all(close > 0)):
            continue
        # Keep the longest history per symbol. The strict ">" means the first
        # file in sorted order wins a tie, which is arbitrary but stable.
        if sym not in series or len(close) > len(series[sym]):
            series[sym] = close
    return series
def build_tasks(split: str = "train", oos_frac: float = 0.4, seed: int = 0) -> list[dict]:
    """Build one task dict per symbol for the requested split.

    Symbols are sorted, shuffled with a generator seeded by ``seed`` and then
    partitioned: the first 70% (at least one symbol) are the training symbols
    and the remainder are the held-out symbols.

    - ``"oos_symbols"`` selects the held-out symbols.
    - Any other value (including ``"train"`` and ``"oos"``) selects the
      training symbols. Every task carries both windows, so the caller decides
      whether to read ``train_close`` or ``oos_close``.

    Args:
        split: Name of the split, see above.
        oos_frac: Fraction of each series, taken from the end, that goes into
            ``oos_close``. Must lie within [0, 1].
        seed: Seed for the symbol shuffle; the same seed gives the same
            partition for the same set of symbols.

    Returns:
        A list of dicts with keys ``symbol``, ``train_close`` (early bars),
        ``oos_close`` (late bars) and ``n_bars`` (total bars in the series).

    Raises:
        ValueError: If ``oos_frac`` is outside [0, 1] (or is NaN).
    """
    # Validate before loading anything: a fraction above 1 would make the split
    # index negative, which slices from the wrong end instead of failing.
    if not 0.0 <= oos_frac <= 1.0:
        raise ValueError(f"oos_frac must be within [0, 1], got {oos_frac!r}")

    series = _load_series()
    # Sort before shuffling so the partition depends only on the seed and the
    # set of symbols, not on the order in which files were read.
    symbols = sorted(series)
    rng = np.random.default_rng(seed)
    rng.shuffle(symbols)
    cut = max(1, int(len(symbols) * 0.7))
    train_syms, holdout_syms = symbols[:cut], symbols[cut:]
    use = holdout_syms if split == "oos_symbols" else train_syms

    tasks = []
    for sym in use:
        close = series[sym]
        split_at = int(len(close) * (1 - oos_frac))
        tasks.append({
            "symbol": sym,
            "train_close": close[:split_at].tolist(),
            "oos_close": close[split_at:].tolist(),
            "n_bars": len(close),
        })
    return tasks
def universe_summary() -> dict:
    """Describe the loaded universe.

    Returns:
        A dict with ``n_symbols`` (number of loaded symbols), ``symbols``
        (their names, sorted) and ``bars`` (symbol -> number of bars).
    """
    loaded = _load_series()
    bar_counts = {}
    for name, closes in loaded.items():
        bar_counts[name] = len(closes)
    return {
        "n_symbols": len(loaded),
        "symbols": sorted(loaded),
        "bars": bar_counts,
    }
if __name__ == "__main__":
    # Manual smoke check: dump the universe, then list the symbols in each split.
    summary_json = json.dumps(universe_summary(), indent=2)
    print(summary_json)
    split_names = ("train", "oos_symbols")
    for split in split_names:
        t = build_tasks(split=split)
        print(f"{split}: {len(t)} tasks -> {[x['symbol'] for x in t]}")