# Tarang_v2 / data_loader.py
# Committed by unknownfriend00007: "Update data_loader.py" (revision 4936080, verified)
import os
import glob
import json
import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import List, Tuple
from config import config
@dataclass()
class AssetSeries:
    """Price history of one tradable instrument plus its identity.

    Attributes:
        symbol: instrument symbol.
        asset_type: market category, either "stock" or "crypto".
        df: OHLCV frame with columns date, open, high, low, close, volume.
    """

    symbol: str
    asset_type: str  # "stock" or "crypto"
    df: pd.DataFrame  # columns: date, open, high, low, close, volume
def _parse_mixed_date(series: pd.Series) -> pd.Series:
s = series.astype(str).str.strip()
iso_mask = s.str.match(r"^\d{4}-\d{2}-\d{2}$", na=False)
out = pd.Series(pd.NaT, index=series.index)
out.loc[iso_mask] = pd.to_datetime(s.loc[iso_mask], errors="coerce", dayfirst=False)
out.loc[~iso_mask] = pd.to_datetime(s.loc[~iso_mask], errors="coerce", dayfirst=True)
return out
def _read_any_ohlcv_csv(path: str) -> pd.DataFrame:
df = pd.read_csv(path, sep=None, engine="python")
cols = [str(c).strip().lower() for c in df.columns]
# Binance format (timestamp_ms)
if "timestamp_ms" in cols:
df.columns = cols
df["date"] = pd.to_datetime(df["timestamp_ms"], unit="ms", errors="coerce", utc=True).dt.tz_convert(None)
out = df[["date", "open", "high", "low", "close", "volume"]].copy()
for c in ["open", "high", "low", "close", "volume"]:
out[c] = pd.to_numeric(out[c], errors="coerce")
out = out.dropna().sort_values("date").reset_index(drop=True)
return out
# Stooq / generic OHLCV
if "date" in cols and "open" in cols and "close" in cols:
df.columns = cols
out = df[["date", "open", "high", "low", "close", "volume"]].copy()
else:
out = df.iloc[:, :6].copy()
out.columns = ["date", "open", "high", "low", "close", "volume"]
out["date"] = _parse_mixed_date(out["date"])
out = out.dropna(subset=["date"]).copy()
for c in ["open", "high", "low", "close", "volume"]:
out[c] = pd.to_numeric(out[c], errors="coerce")
out = out.dropna().sort_values("date").reset_index(drop=True)
return out
def discover_assets(data_dir: str) -> List[Tuple[str, str, str]]:
    """Find per-asset CSV files for the configured interval.

    Looks for ``<data_dir>/stocks/stooq/<SYMBOL>/<config.INTERVAL>.csv`` and
    ``<data_dir>/crypto/binance/<SYMBOL>/<config.INTERVAL>.csv``. The symbol
    is the name of the file's parent directory.

    Returns:
        (symbol, asset_type, path) tuples sorted by asset_type then symbol,
        truncated to ``config.MAX_ASSETS`` entries.
    """
    sources = (
        ("stocks", "stooq", "stock"),
        ("crypto", "binance", "crypto"),
    )
    found: List[Tuple[str, str, str]] = []
    for top_dir, provider, asset_type in sources:
        pattern = os.path.join(data_dir, top_dir, provider, "*", f"{config.INTERVAL}.csv")
        for csv_path in glob.glob(pattern):
            symbol = os.path.basename(os.path.dirname(csv_path))
            found.append((symbol, asset_type, csv_path))
    found.sort(key=lambda entry: (entry[1], entry[0]))
    return found[: config.MAX_ASSETS]
def load_asset_series() -> List[AssetSeries]:
    """Load every discovered asset that has enough history for windowing.

    An asset is kept only if its cleaned frame has at least
    ``config.WINDOW + config.HORIZON_DAYS + 5`` rows.
    """
    loaded: List[AssetSeries] = []
    for symbol, asset_type, csv_path in discover_assets(config.DATA_DIR):
        frame = _read_any_ohlcv_csv(csv_path)
        if len(frame) < config.WINDOW + config.HORIZON_DAYS + 5:
            continue  # too short to produce training windows
        loaded.append(AssetSeries(symbol=symbol, asset_type=asset_type, df=frame))
    return loaded
def make_features(df: pd.DataFrame) -> pd.DataFrame:
    """Derive per-bar model features from an OHLCV frame.

    Adds four columns to a copy of ``df``:

    * ``log_return``: natural log of close / previous close.
    * ``hl_range``: (high - low) / close.
    * ``oc_return``: (close - open) / open.
    * ``vol_log``: log10 of volume (clipped at 0) plus 1.

    A 1e-12 epsilon guards the two ratio denominators.

    Drops rows where any column is NaN or any feature is +/-inf. For example,
    a zero close makes the log return -inf on that bar and +inf on the next.
    The first row is always dropped because it has no previous close. The
    index is reset.
    """
    feature_cols = ["log_return", "hl_range", "oc_return", "vol_log"]
    out = df.copy()
    # Non-positive closes produce -inf/inf/NaN log returns. Those rows are
    # filtered out below, so numpy's divide/invalid warnings are just noise.
    with np.errstate(divide="ignore", invalid="ignore"):
        out["log_return"] = np.log(out["close"] / out["close"].shift(1))
    out["hl_range"] = (out["high"] - out["low"]) / (out["close"] + 1e-12)
    out["oc_return"] = (out["close"] - out["open"]) / (out["open"] + 1e-12)
    out["vol_log"] = np.log10(out["volume"].clip(lower=0) + 1.0)
    # dropna() alone keeps +/-inf, so filter non-finite feature rows explicitly.
    finite = np.isfinite(out[feature_cols].to_numpy(dtype=np.float64)).all(axis=1)
    out = out.loc[finite].dropna().reset_index(drop=True)
    return out
def build_windows(feats: pd.DataFrame, window: int, horizon: int):
    """Slice a feature frame into (input window, target) training samples.

    For each anchor index ``i`` in ``range(window, len(feats) - horizon)``:

    * ``X[k]`` holds rows ``i - window`` .. ``i - 1`` of the columns
      log_return, hl_range, oc_return and vol_log.
    * ``y[k]`` is the ``log_return`` at row ``i + horizon``.
    * ``ts[k]`` is the ``date`` at row ``i + horizon``.

    NOTE(review): the target therefore sits ``horizon + 1`` rows after the
    last input row. Confirm this gap matches the intended meaning of
    ``horizon``.

    Returns:
        ``(X, y, ts)``: a float32 array of shape (n, window, 4), a float32
        array of shape (n,), and a list of n ``pd.Timestamp``. If the frame is
        too short to yield a sample, n == 0 and empty arrays are returned
        instead of raising.
    """
    values = feats[["log_return", "hl_range", "oc_return", "vol_log"]].values.astype(np.float32)
    dates = feats["date"].values
    anchors = range(window, len(values) - horizon)
    if len(anchors) == 0:
        # np.stack([]) raises ValueError; hand back correctly-shaped empties.
        return (
            np.empty((0, max(window, 0), values.shape[1]), dtype=np.float32),
            np.empty(0, dtype=np.float32),
            [],
        )
    X = np.stack([values[i - window:i] for i in anchors])
    y = np.array([values[i + horizon, 0] for i in anchors], dtype=np.float32)
    ts = [pd.Timestamp(dates[i + horizon]) for i in anchors]
    return X, y, ts
def save_manifest(series: List[AssetSeries]):
    """Write a JSON summary of the loaded assets and return its path.

    The file is ``<config.ARTIFACT_DIR>/manifest.json``; the directory is
    created if needed. Each entry records symbol, asset_type and row count.
    """
    os.makedirs(config.ARTIFACT_DIR, exist_ok=True)
    manifest_path = os.path.join(config.ARTIFACT_DIR, "manifest.json")
    entries = [
        {"symbol": item.symbol, "asset_type": item.asset_type, "rows": int(len(item.df))}
        for item in series
    ]
    with open(manifest_path, "w", encoding="utf-8") as handle:
        json.dump(entries, handle, indent=2)
    return manifest_path