"""Load the replay tape and build the train / OOS / OOS-symbol task splits. Shallow per-symbol history (free GeckoTerminal cap ~1000 bars) means breadth carries the GRPO variance: a task = (symbol, window). The symbol-holdout split is the real generalization test. """ from __future__ import annotations import glob import json import os from pathlib import Path import numpy as np import pandas as pd # Resolve the tape relative to THIS file so it works after the env is built into a # wheel and installed into site-packages on Prime's infra (where the repo's top-level # data/ dir does not exist). Packaged copy is authoritative; repo dir is a dev fallback. _PKG_DATA = Path(__file__).resolve().parent / "data" / "ohlcv" _REPO_DATA = Path(__file__).resolve().parents[3] / "data" / "ohlcv" DATA = _PKG_DATA if _PKG_DATA.exists() and any(_PKG_DATA.glob("*.parquet")) else _REPO_DATA # Quote-side / blue-chip pairs that aren't tradeable "strategy" targets — excluded. _EXCLUDE = {"USDC", "WETH", "CBBTC", "USDT", "DAI"} def _load_series() -> dict[str, np.ndarray]: series: dict[str, np.ndarray] = {} for f in sorted(glob.glob(str(DATA / "*.parquet"))): sym = os.path.basename(f).split("__")[0].upper() if sym in _EXCLUDE: continue close = pd.read_parquet(f)["c"].to_numpy(dtype=float) if len(close) >= 120 and np.all(np.isfinite(close)) and np.all(close > 0): # keep the most liquid pool per symbol (first sorted = arbitrary but stable) if sym not in series or len(close) > len(series[sym]): series[sym] = close return series def build_tasks(split: str = "train", oos_frac: float = 0.4, seed: int = 0) -> list[dict]: """Return a list of task dicts: {symbol, train_close, oos_close}. - train: symbols sorted, first 70% of symbols; in-sample = early bars, oos = late bars - oos: same symbols, evaluated on held-out late window - oos_symbols: the held-out 30% of symbols (never seen in training) """ series = _load_series() symbols = sorted(series.keys()) rng = np.random.default_rng(seed) rng.shuffle(symbols) cut = max(1, int(len(symbols) * 0.7)) train_syms, holdout_syms = symbols[:cut], symbols[cut:] use = holdout_syms if split == "oos_symbols" else train_syms tasks = [] for sym in use: close = series[sym] split_at = int(len(close) * (1 - oos_frac)) tasks.append({ "symbol": sym, "train_close": close[:split_at].tolist(), "oos_close": close[split_at:].tolist(), "n_bars": len(close), }) return tasks def universe_summary() -> dict: series = _load_series() return {"n_symbols": len(series), "symbols": sorted(series.keys()), "bars": {k: len(v) for k, v in series.items()}} if __name__ == "__main__": print(json.dumps(universe_summary(), indent=2)) for split in ("train", "oos_symbols"): t = build_tasks(split) print(f"{split}: {len(t)} tasks -> {[x['symbol'] for x in t]}")