# sol-arb-trainer / app.py
# commanderzee's picture
# v1_arb trainer code for sol
# 9a6e556 verified
"""
Gradio 5 UI for the ARB-MAX trainer Space.
Responsibilities:
- Read env vars (ASSET, MODEL_VARIANT, N_TRIALS, N_FOLDS, AUTO_START, HF_TOKEN).
- Set up DATA_DIR/logs/cache/output.
- On boot: start background worker that waits for the dataset to be ready on
`commanderzee/15m-crypto` (poll every 5 min), then calls train.run_training.
- UI: status, tail of logs (filtered), disk usage, refresh, manual Start button.
"""
from __future__ import annotations
import logging
import os
import shutil
import threading
import time
import traceback
from pathlib import Path
from typing import Optional
import gradio as gr
from huggingface_hub import HfApi
from huggingface_hub.utils import HfHubHTTPError
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
def _env(name: str, default: str) -> str:
    """Return environment variable *name*, or *default* when it is unset."""
    return os.environ.get(name, default)


# Training parameters, all overridable through the environment.
ASSET = _env("ASSET", "btc").strip().lower()
MODEL_VARIANT = _env("MODEL_VARIANT", "v1_arb").strip()
N_TRIALS = int(_env("N_TRIALS", "150"))
N_FOLDS = int(_env("N_FOLDS", "5"))
# Only the exact string "1" (after stripping whitespace) enables auto-start.
AUTO_START = _env("AUTO_START", "1").strip() == "1"
HF_TOKEN = _env("HF_TOKEN", "")

DATASET_REPO = "commanderzee/15m-crypto"

# Use /data when that path exists, otherwise fall back to /tmp.
_data_root = Path("/data")
DATA_DIR = _data_root if _data_root.exists() else Path("/tmp")

LOG_DIR = DATA_DIR / "logs"
CACHE_DIR = DATA_DIR / "cache"
OUTPUT_DIR = DATA_DIR / "output"
LOG_FILE = LOG_DIR / "train.log"

# Create the working directories up front; existing ones are left untouched.
for directory in (LOG_DIR, CACHE_DIR, OUTPUT_DIR):
    directory.mkdir(parents=True, exist_ok=True)
# ---------------------------------------------------------------------------
# Logger
# ---------------------------------------------------------------------------
def _build_logger() -> logging.Logger:
    """Create (or rebuild) the ``arbmax_trainer`` logger.

    The logger writes INFO-and-above records to ``LOG_FILE`` (append mode,
    UTF-8) and to a default ``StreamHandler``, both using the same
    ``"%(asctime)s %(levelname)s %(message)s"`` format. Propagation to the
    root logger is disabled.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger("arbmax_trainer")
    logger.setLevel(logging.INFO)
    logger.propagate = False

    # Avoid duplicate handlers on reload. Close each stale handler as well as
    # detaching it; otherwise the previous FileHandler keeps its file open.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        try:
            stale.close()
        except Exception:  # noqa: BLE001
            # Closing is best-effort; a failure must not block logger setup.
            pass

    fmt = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
    # File handler first, then stream handler (same order records are emitted).
    for handler in (
        logging.FileHandler(LOG_FILE, mode="a", encoding="utf-8"),
        logging.StreamHandler(),
    ):
        handler.setFormatter(fmt)
        logger.addHandler(handler)
    return logger
_logger = _build_logger()


def _log(msg: str) -> None:
    """Log *msg* at INFO level without ever raising.

    If the logger call raises, the message is appended directly to
    ``LOG_FILE`` instead; if that also fails the message is dropped.
    """
    try:
        _logger.info(msg)
    except Exception:
        pass
    else:
        return

    # last-ditch fallback: bypass the logging machinery and append raw text
    try:
        with LOG_FILE.open("a", encoding="utf-8") as fallback:
            fallback.write(msg + "\n")
    except Exception:
        pass
# ---------------------------------------------------------------------------
# Worker state
# ---------------------------------------------------------------------------
# Guards creation/replacement of the worker thread.
_worker_lock = threading.Lock()

# Shared, mutable snapshot of the background worker, read by the UI.
_worker_state = dict(
    thread=None,  # Optional[threading.Thread]
    status="idle",  # idle | waiting_for_data | training | done | error
    detail="",
    started_at=None,  # epoch seconds, or None
    finished_at=None,  # epoch seconds, or None
)
# ---------------------------------------------------------------------------
# Dataset readiness check
# ---------------------------------------------------------------------------
def _dataset_files_ready(hf_token: str, asset: str) -> bool:
    """Return True when both parquet files for *asset* exist in the dataset repo.

    The required files are ``book_snapshot_5/{asset}.parquet`` and
    ``ohlcv_1s/{asset}.parquet`` in ``DATASET_REPO``.

    Args:
        hf_token: Token passed to ``HfApi`` and to ``repo_info``.
        asset: Asset name used to build the two expected file names.

    Returns:
        True if both files are listed among the repo's siblings. False if
        either is missing or the repo lookup fails for any reason; lookup
        failures are logged rather than raised so the caller can keep polling.
    """
    api = HfApi(token=hf_token)
    try:
        info = api.repo_info(repo_id=DATASET_REPO, repo_type="dataset", token=hf_token)
    except HfHubHTTPError as e:
        # Log instead of failing silently so a persistent Hub error is visible
        # in the log rather than looking like "data not uploaded yet".
        _log(f"[dataset] hub HTTP error querying {DATASET_REPO}: {e!r}")
        return False
    except Exception as e:  # noqa: BLE001
        _log(f"[dataset] unexpected error querying {DATASET_REPO}: {e!r}")
        return False

    siblings = getattr(info, "siblings", None) or []
    present = {getattr(entry, "rfilename", None) for entry in siblings}
    needed = {f"book_snapshot_5/{asset}.parquet", f"ohlcv_1s/{asset}.parquet"}
    return needed <= present
# ---------------------------------------------------------------------------
# Worker body
# ---------------------------------------------------------------------------
def _worker_body(poll_seconds: float = 300.0) -> None:
    """Background worker: wait for the dataset, then run training.

    Steps:
      1. Import ``train.run_training`` (deferred so the app boots fast even if
         heavy libs are slow).
      2. Poll ``_dataset_files_ready`` until it reports True, sleeping
         *poll_seconds* between attempts.
      3. Call ``run_training`` and record the outcome.

    Progress is published through ``_worker_state`` (``status``, ``detail``,
    ``started_at``, ``finished_at``) and written to the log via ``_log``.
    Exceptions are caught and recorded as ``status == "error"``; nothing is
    raised to the thread machinery.

    Args:
        poll_seconds: Delay between dataset readiness checks. Defaults to
            300 seconds (5 min).
    """
    # Reset both timestamps at the start of every run so that a restarted
    # worker never displays the previous run's finish time.
    _worker_state["started_at"] = time.time()
    _worker_state["finished_at"] = None

    try:
        from train import run_training
    except Exception as e:  # noqa: BLE001
        _worker_state["status"] = "error"
        _worker_state["detail"] = f"import error: {e!r}"
        _worker_state["finished_at"] = time.time()
        _log(f"[worker] import error: {e!r}")
        _log(traceback.format_exc())
        return

    _worker_state["status"] = "waiting_for_data"
    _worker_state["detail"] = f"waiting for {ASSET} parquet files on {DATASET_REPO}"
    _log(f"[worker] waiting for {ASSET} parquet files on {DATASET_REPO}")

    while True:
        try:
            ready = _dataset_files_ready(HF_TOKEN, ASSET)
        except Exception as e:  # noqa: BLE001
            # Treat any failure of the check itself as "not ready yet".
            _log(f"[worker] readiness check error: {e!r}")
            ready = False
        if ready:
            break
        time.sleep(poll_seconds)

    _worker_state["status"] = "training"
    _worker_state["detail"] = "calling run_training"
    _log(f"[worker] dataset ready, starting training for {ASSET} ({MODEL_VARIANT})")

    try:
        result = run_training(
            asset=ASSET,
            hf_token=HF_TOKEN,
            variant=MODEL_VARIANT,
            n_trials=N_TRIALS,
            n_folds=N_FOLDS,
            log=_log,
            cache_dir=str(CACHE_DIR),
            output_dir=str(OUTPUT_DIR),
        )
    except Exception as e:  # noqa: BLE001
        _worker_state["status"] = "error"
        _worker_state["detail"] = f"training error: {e!r}"
        _worker_state["finished_at"] = time.time()
        _log(f"[worker] training error: {e!r}")
        _log(traceback.format_exc())
    else:
        _worker_state["status"] = "done"
        _worker_state["detail"] = f"training completed: {result}"
        _worker_state["finished_at"] = time.time()
        _log(f"[worker] done: {result}")
def _start_worker_idempotent() -> str:
    """Start the background worker unless one is already alive.

    The check and the thread creation happen under ``_worker_lock``. The new
    thread is a daemon named ``arbmax-worker`` and is stored in
    ``_worker_state["thread"]`` before it is started.

    Returns:
        A short human-readable message describing what happened.
    """
    with _worker_lock:
        current: Optional[threading.Thread] = _worker_state.get("thread")  # type: ignore[assignment]
        if current is None or not current.is_alive():
            worker = threading.Thread(target=_worker_body, daemon=True, name="arbmax-worker")
            _worker_state["thread"] = worker
            worker.start()
            return "Worker started"
        return f"Worker already running (status={_worker_state['status']})"
# ---------------------------------------------------------------------------
# Log tailing (filtered)
# ---------------------------------------------------------------------------
# Substrings that mark a log line as noise; any line containing one of these
# is hidden from the UI log tail.
_LOG_NOISE_FRAGMENTS = (
    # NOTE(review): presumably HTTP client request lines; confirm the emitter.
    "HTTP Request:",
    # NOTE(review): these look like web UI / static asset request paths.
    "/gradio_api/",
    "/_app/",
    "/assets/",
    "/?logs=",
    "/?__theme=",
)
def _tail_log(max_bytes: int = 500 * 1024, max_lines: int = 2000) -> str:
    """Return the filtered tail of ``LOG_FILE`` as a single string.

    At most the last *max_bytes* bytes of the file are read. Lines containing
    any entry of ``_LOG_NOISE_FRAGMENTS`` are removed, then the last
    *max_lines* remaining lines are joined with newlines.

    Args:
        max_bytes: Size of the window read from the end of the file.
        max_lines: Maximum number of lines returned after filtering.

    Returns:
        The filtered tail, ``"(log empty)"`` when the file is missing or
        nothing remains after filtering, or ``"(log read error: ...)"`` if
        reading fails. This function never raises.
    """
    try:
        if not LOG_FILE.exists():
            return "(log empty)"
        size = LOG_FILE.stat().st_size
        start = max(0, size - max_bytes)
        with LOG_FILE.open("rb") as f:
            if start > 0:
                # The window usually begins in the middle of a line (and maybe
                # in the middle of a multi-byte character). Step back one byte
                # and discard through the next newline: if the window happens
                # to start exactly on a line boundary only that newline is
                # consumed, otherwise the truncated line is dropped.
                f.seek(start - 1)
                f.readline()
            raw = f.read().decode("utf-8", errors="replace")
        kept = [
            line
            for line in raw.splitlines()
            if not any(frag in line for frag in _LOG_NOISE_FRAGMENTS)
        ]
        # A slice of [-0:] would return everything, so handle <= 0 explicitly.
        tail = kept[-max_lines:] if max_lines > 0 else []
        return "\n".join(tail) or "(log empty)"
    except Exception as e:  # noqa: BLE001
        return f"(log read error: {e!r})"
def _disk_usage() -> str:
    """Return a one-line summary of disk usage for ``DATA_DIR``.

    Sizes are reported in units of 1024**3 bytes with two decimals. Any
    failure is reported as a ``"(disk usage error: ...)"`` string instead of
    being raised.
    """
    try:
        total, used, free = shutil.disk_usage(str(DATA_DIR))
        unit = 1024 ** 3
        parts = [
            f"DATA_DIR={DATA_DIR}",
            f"total={total / unit:.2f} GB",
            f"used={used / unit:.2f} GB",
            f"free={free / unit:.2f} GB",
        ]
        return " ".join(parts)
    except Exception as e:  # noqa: BLE001
        return f"(disk usage error: {e!r})"
def _status_text() -> str:
    """Render the run configuration and current worker state as text.

    Produces five newline-separated lines: the configuration summary, the
    status, the detail message, and the start/finish times formatted in UTC
    (``-`` when a timestamp is not set).
    """

    def _fmt_utc(ts) -> str:
        # Falsy timestamps (None, 0) are shown as a dash.
        if not ts:
            return "-"
        return time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(ts))

    state = _worker_state
    started_str = _fmt_utc(state.get("started_at"))
    finished_str = _fmt_utc(state.get("finished_at"))
    rows = [
        f"asset={ASSET} variant={MODEL_VARIANT} trials={N_TRIALS} folds={N_FOLDS}",
        f"status={state['status']}",
        f"detail={state['detail']}",
        f"started_at_utc={started_str}",
        f"finished_at_utc={finished_str}",
    ]
    return "\n".join(rows)
def _refresh_all():
    """Return fresh ``(status, log tail, disk usage)`` strings for the UI."""
    status = _status_text()
    log_tail = _tail_log()
    disk = _disk_usage()
    return status, log_tail, disk
def _on_start_click():
    """Handle the Start button: start the worker, log the outcome, refresh the UI.

    Returns the same ``(status, log tail, disk usage)`` tuple as
    ``_refresh_all``.
    """
    outcome = _start_worker_idempotent()
    _log(f"[ui] start click: {outcome}")
    return _refresh_all()
# ---------------------------------------------------------------------------
# Build Gradio app
# ---------------------------------------------------------------------------
def build_app() -> gr.Blocks:
    """Build the Gradio Blocks UI.

    Layout: a Markdown header, then a row with a narrow column (status box,
    disk box, Refresh / Start buttons) and a wide column (log viewer). Both
    buttons and a 5-second timer update the status, log and disk components.

    Returns:
        The assembled ``gr.Blocks`` instance (not launched).
    """
    header_md = (
        f"# ARB-MAX Trainer\n"
        f"**asset** = `{ASSET}` · **variant** = `{MODEL_VARIANT}` · "
        f"**trials** = `{N_TRIALS}` · **folds** = `{N_FOLDS}` · "
        f"**auto_start** = `{AUTO_START}`"
    )

    with gr.Blocks(title=f"ARB-MAX Trainer — {ASSET}/{MODEL_VARIANT}") as demo:
        gr.Markdown(header_md)

        with gr.Row():
            # Left: state summary and controls.
            with gr.Column(scale=1):
                status_box = gr.Textbox(
                    label="Status", value=_status_text(), lines=5, interactive=False
                )
                disk_box = gr.Textbox(
                    label="Disk", value=_disk_usage(), lines=2, interactive=False
                )
                with gr.Row():
                    refresh_btn = gr.Button("Refresh")
                    start_btn = gr.Button("Start worker", variant="primary")
            # Right: filtered log tail.
            with gr.Column(scale=2):
                log_box = gr.Code(
                    label="Training log (tail, filtered)",
                    value=_tail_log(),
                    language=None,
                    lines=28,
                )

        # Every trigger refreshes the same three components, in this order.
        for trigger, handler in (
            (refresh_btn.click, _refresh_all),
            (start_btn.click, _on_start_click),
        ):
            trigger(handler, outputs=[status_box, log_box, disk_box])

        # 5s auto-refresh
        timer = gr.Timer(5.0, active=True)
        timer.tick(_refresh_all, outputs=[status_box, log_box, disk_box])

    return demo
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
# Start the background worker at import time when auto-start is enabled.
# A failure here is logged and must not prevent the UI from being built.
if AUTO_START:
    try:
        _start_worker_idempotent()
    except Exception as boot_err:  # noqa: BLE001
        _log(f"[boot] auto-start error: {boot_err!r}")

# The UI is built at import time and exposed as module-level `demo`.
demo = build_app()

# Direct execution: enable the queue and serve on all interfaces, port 7860.
if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)