Spaces:
Sleeping
Sleeping
| """ | |
| Gradio 5 UI for the ARB-MAX trainer Space. | |
| Responsibilities: | |
| - Read env vars (ASSET, MODEL_VARIANT, N_TRIALS, N_FOLDS, AUTO_START, HF_TOKEN). | |
| - Set up DATA_DIR/logs/cache/output. | |
| - On boot: start background worker that waits for the dataset to be ready on | |
| `commanderzee/15m-crypto` (poll every 5 min), then calls train.run_training. | |
| - UI: status, tail of logs (filtered), disk usage, refresh, manual Start button. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import os | |
| import shutil | |
| import threading | |
| import time | |
| import traceback | |
| from pathlib import Path | |
| from typing import Optional | |
| import gradio as gr | |
| from huggingface_hub import HfApi | |
| from huggingface_hub.utils import HfHubHTTPError | |
| # --------------------------------------------------------------------------- | |
| # Config | |
| # --------------------------------------------------------------------------- | |
# Runtime knobs, read once from the environment at import time.
ASSET = os.getenv("ASSET", "btc").strip().lower()
MODEL_VARIANT = os.getenv("MODEL_VARIANT", "v1_arb").strip()
N_TRIALS = int(os.getenv("N_TRIALS", "150"))
N_FOLDS = int(os.getenv("N_FOLDS", "5"))
# Only the exact string "1" (after stripping whitespace) enables auto-start.
AUTO_START = os.getenv("AUTO_START", "1").strip() == "1"
HF_TOKEN = os.getenv("HF_TOKEN", "")
DATASET_REPO = "commanderzee/15m-crypto"

# Use /data when that directory exists; otherwise fall back to /tmp.
DATA_DIR = Path("/data")
if not DATA_DIR.exists():
    DATA_DIR = Path("/tmp")

LOG_DIR = DATA_DIR.joinpath("logs")
CACHE_DIR = DATA_DIR.joinpath("cache")
OUTPUT_DIR = DATA_DIR.joinpath("output")
LOG_FILE = LOG_DIR.joinpath("train.log")

# Make sure every working directory exists before anything writes to it.
for d in LOG_DIR, CACHE_DIR, OUTPUT_DIR:
    d.mkdir(exist_ok=True, parents=True)
| # --------------------------------------------------------------------------- | |
| # Logger | |
| # --------------------------------------------------------------------------- | |
def _build_logger() -> logging.Logger:
    """Create (or rebuild) the "arbmax_trainer" logger.

    The logger is set to INFO, writes to both ``LOG_FILE`` (append mode,
    UTF-8) and a stream handler, and does not propagate to the root logger.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger("arbmax_trainer")
    logger.setLevel(logging.INFO)

    # Avoid duplicate handlers on reload. Detaching alone would leave the old
    # FileHandler's descriptor open, so each stale handler is closed as well.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()

    fmt = logging.Formatter("%(asctime)s %(levelname)s %(message)s")

    file_handler = logging.FileHandler(LOG_FILE, mode="a", encoding="utf-8")
    file_handler.setFormatter(fmt)
    logger.addHandler(file_handler)

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(fmt)
    logger.addHandler(stream_handler)

    logger.propagate = False
    return logger
_logger = _build_logger()


def _log(msg: str) -> None:
    """Record ``msg`` at INFO level without ever raising.

    The module logger is tried first. If that raises, the raw message is
    appended to ``LOG_FILE`` directly; a failure there is ignored as well.
    """
    try:
        _logger.info(msg)
        return
    except Exception:
        # Fall through to the last-ditch raw write below.
        pass

    try:
        with LOG_FILE.open(mode="a", encoding="utf-8") as sink:
            sink.write(msg + "\n")
    except Exception:
        # Deliberately best-effort: logging must not take the caller down.
        pass
| # --------------------------------------------------------------------------- | |
| # Worker state | |
| # --------------------------------------------------------------------------- | |
# Guards creation/replacement of the worker thread.
_worker_lock = threading.Lock()

# Shared, mutable snapshot of the background worker, read by the UI.
_worker_state = dict(
    thread=None,       # Optional[threading.Thread]
    status="idle",     # idle | waiting_for_data | training | done | error
    detail="",
    started_at=None,
    finished_at=None,
)
| # --------------------------------------------------------------------------- | |
| # Dataset readiness check | |
| # --------------------------------------------------------------------------- | |
def _dataset_files_ready(hf_token: str, asset: str) -> bool:
    """Report whether both parquet files for ``asset`` exist in DATASET_REPO.

    The files looked for are ``book_snapshot_5/{asset}.parquet`` and
    ``ohlcv_1s/{asset}.parquet``.

    Args:
        hf_token: Hugging Face access token; an empty string is treated as
            "no token supplied".
        asset: Asset name used to build the two expected file names.

    Returns:
        True when both files are listed in the repo, False otherwise
        (including when the repo lookup fails for any reason).
    """
    # HF_TOKEN defaults to "", and an empty string would be forwarded as an
    # explicit (invalid) token. None lets the hub client resolve credentials
    # on its own instead.
    token = hf_token or None
    api = HfApi(token=token)
    try:
        info = api.repo_info(repo_id=DATASET_REPO, repo_type="dataset", token=token)
    except Exception as e:  # noqa: BLE001
        # Still "not ready", but leave a trace so a bad token or a missing
        # repo is visible in the log instead of an endless silent wait.
        _log(f"[worker] dataset readiness check failed: {e!r}")
        return False

    siblings = getattr(info, "siblings", None) or []
    names = {getattr(s, "rfilename", None) for s in siblings}
    needed = {f"book_snapshot_5/{asset}.parquet", f"ohlcv_1s/{asset}.parquet"}
    return needed.issubset(names)
| # --------------------------------------------------------------------------- | |
| # Worker body | |
| # --------------------------------------------------------------------------- | |
def _worker_body() -> None:
    """Run one full worker cycle: import trainer, wait for data, train.

    Progress is published through ``_worker_state`` ("status", "detail",
    "started_at", "finished_at") and through ``_log``. The function never
    raises for import or training failures; those are recorded as status
    "error" with the exception repr in "detail".
    """
    # _worker_state is reused across manual restarts, so stamp the new start
    # and clear the previous run's finish time up front. Otherwise the status
    # panel would show a stale finished_at while the new run is in progress.
    _worker_state["started_at"] = time.time()
    _worker_state["finished_at"] = None

    # Deferred import so app boots fast even if heavy libs are slow.
    try:
        from train import run_training
    except Exception as e:  # noqa: BLE001
        _worker_state["status"] = "error"
        _worker_state["detail"] = f"import error: {e!r}"
        _worker_state["finished_at"] = time.time()
        _log(f"[worker] import error: {e!r}")
        _log(traceback.format_exc())
        return

    _worker_state["status"] = "waiting_for_data"
    _worker_state["detail"] = f"waiting for {ASSET} parquet files on {DATASET_REPO}"
    _log(f"[worker] waiting for {ASSET} parquet files on {DATASET_REPO}")

    # Poll until both parquet files are present; errors count as "not ready".
    while True:
        try:
            ready = _dataset_files_ready(HF_TOKEN, ASSET)
        except Exception as e:  # noqa: BLE001
            _log(f"[worker] readiness check error: {e!r}")
            ready = False
        if ready:
            break
        time.sleep(300)  # 5 min

    _worker_state["status"] = "training"
    _worker_state["detail"] = "calling run_training"
    _log(f"[worker] dataset ready, starting training for {ASSET} ({MODEL_VARIANT})")

    try:
        result = run_training(
            asset=ASSET,
            hf_token=HF_TOKEN,
            variant=MODEL_VARIANT,
            n_trials=N_TRIALS,
            n_folds=N_FOLDS,
            log=_log,
            cache_dir=str(CACHE_DIR),
            output_dir=str(OUTPUT_DIR),
        )
        _worker_state["status"] = "done"
        _worker_state["detail"] = f"training completed: {result}"
        _worker_state["finished_at"] = time.time()
        _log(f"[worker] done: {result}")
    except Exception as e:  # noqa: BLE001
        _worker_state["status"] = "error"
        _worker_state["detail"] = f"training error: {e!r}"
        _worker_state["finished_at"] = time.time()
        _log(f"[worker] training error: {e!r}")
        _log(traceback.format_exc())
def _start_worker_idempotent() -> str:
    """Start the background worker unless one is already alive.

    Returns:
        A short human-readable message: either that a worker is already
        running (with its current status) or that a new one was started.
    """
    with _worker_lock:
        current: Optional[threading.Thread] = _worker_state.get("thread")  # type: ignore[assignment]
        already_running = current is not None and current.is_alive()
        if already_running:
            return f"Worker already running (status={_worker_state['status']})"

        # Daemon thread, so it does not keep the process alive on shutdown.
        worker = threading.Thread(
            name="arbmax-worker",
            target=_worker_body,
            daemon=True,
        )
        _worker_state["thread"] = worker
        worker.start()
        return "Worker started"
| # --------------------------------------------------------------------------- | |
| # Log tailing (filtered) | |
| # --------------------------------------------------------------------------- | |
# Substrings that mark a log line as HTTP/UI noise; such lines are hidden.
_LOG_NOISE_FRAGMENTS = (
    "HTTP Request:",
    "/gradio_api/",
    "/_app/",
    "/assets/",
    "/?logs=",
    "/?__theme=",
)


def _tail_log(max_bytes: int = 500 * 1024) -> str:
    """Return the filtered tail of ``LOG_FILE`` as a single string.

    At most the last ``max_bytes`` bytes of the file are read. Lines that
    contain any of ``_LOG_NOISE_FRAGMENTS`` are dropped and at most the last
    2000 remaining lines are returned.

    Args:
        max_bytes: Size of the window read from the end of the file.

    Returns:
        The joined lines, "(log empty)" when the file is missing or nothing
        survives filtering, or "(log read error: ...)" on any exception.
    """
    try:
        if not LOG_FILE.exists():
            return "(log empty)"

        size = LOG_FILE.stat().st_size
        read_from = max(0, size - max_bytes)

        with LOG_FILE.open("rb") as f:
            if read_from > 0:
                # Peek at the byte just before the window to learn whether
                # the window begins exactly at the start of a line.
                f.seek(read_from - 1)
                cut_mid_line = f.read(1) != b"\n"
            else:
                cut_mid_line = False
            raw = f.read()

        if cut_mid_line:
            # The window starts inside a line (possibly inside a multi-byte
            # UTF-8 sequence): drop that fragment so only whole lines show.
            # If the window holds no newline at all, keep what we have.
            _fragment, sep, rest = raw.partition(b"\n")
            if sep:
                raw = rest

        lines = raw.decode("utf-8", errors="replace").splitlines()
        filtered = [
            line
            for line in lines
            if not any(frag in line for frag in _LOG_NOISE_FRAGMENTS)
        ]
        return "\n".join(filtered[-2000:]) or "(log empty)"
    except Exception as e:  # noqa: BLE001
        return f"(log read error: {e!r})"
def _disk_usage() -> str:
    """Describe total/used/free space for ``DATA_DIR`` on one line.

    Returns:
        A string with the directory and the three figures, each divided by
        1024**3 and printed with two decimals, or "(disk usage error: ...)"
        when the lookup raises.
    """
    try:
        total, used, free = shutil.disk_usage(str(DATA_DIR))
        unit = 1024 ** 3
        parts = [
            f"DATA_DIR={DATA_DIR}",
            f"total={total / unit:.2f} GB",
            f"used={used / unit:.2f} GB",
            f"free={free / unit:.2f} GB",
        ]
        return " ".join(parts)
    except Exception as e:  # noqa: BLE001
        return f"(disk usage error: {e!r})"
def _status_text() -> str:
    """Render the run configuration and worker state as a five-line string.

    Timestamps are formatted in UTC; a missing (falsy) timestamp is shown
    as "-".
    """

    def _fmt_utc(ts) -> str:
        # Falsy covers both None (never set) and 0.
        if not ts:
            return "-"
        return time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(ts))

    state = _worker_state
    started_str = _fmt_utc(state.get("started_at"))
    finished_str = _fmt_utc(state.get("finished_at"))

    rows = [
        f"asset={ASSET} variant={MODEL_VARIANT} trials={N_TRIALS} folds={N_FOLDS}",
        f"status={state['status']}",
        f"detail={state['detail']}",
        f"started_at_utc={started_str}",
        f"finished_at_utc={finished_str}",
    ]
    return "\n".join(rows)
def _refresh_all():
    """Collect fresh values for the status, log and disk widgets.

    Returns:
        A ``(status_text, log_tail, disk_usage)`` tuple, in that order.
    """
    status = _status_text()
    log_tail = _tail_log()
    disk = _disk_usage()
    return status, log_tail, disk
def _on_start_click():
    """Handle the "Start worker" button.

    Starts the worker if none is alive, logs the outcome, and returns the
    same ``(status_text, log_tail, disk_usage)`` tuple a refresh produces.
    """
    outcome = _start_worker_idempotent()
    _log(f"[ui] start click: {outcome}")
    return _refresh_all()
| # --------------------------------------------------------------------------- | |
| # Build Gradio app | |
| # --------------------------------------------------------------------------- | |
def build_app() -> gr.Blocks:
    """Assemble the Gradio Blocks UI.

    Layout: a header, then one row with a narrow column (status box, disk
    box, Refresh / Start buttons) and a wider column holding the log view.
    Both buttons and a 5-second timer update the three display widgets.

    Returns:
        The constructed ``gr.Blocks`` application (not yet launched).
    """
    page_title = f"ARB-MAX Trainer — {ASSET}/{MODEL_VARIANT}"
    header_md = (
        f"# ARB-MAX Trainer\n"
        f"**asset** = `{ASSET}` · **variant** = `{MODEL_VARIANT}` · "
        f"**trials** = `{N_TRIALS}` · **folds** = `{N_FOLDS}` · "
        f"**auto_start** = `{AUTO_START}`"
    )

    with gr.Blocks(title=page_title) as app:
        gr.Markdown(header_md)

        with gr.Row():
            # Left column: state summary and controls.
            with gr.Column(scale=1):
                status_box = gr.Textbox(
                    label="Status",
                    value=_status_text(),
                    lines=5,
                    interactive=False,
                )
                disk_box = gr.Textbox(
                    label="Disk",
                    value=_disk_usage(),
                    lines=2,
                    interactive=False,
                )
                with gr.Row():
                    refresh_btn = gr.Button("Refresh")
                    start_btn = gr.Button("Start worker", variant="primary")

            # Right column: filtered tail of the training log.
            with gr.Column(scale=2):
                log_box = gr.Code(
                    label="Training log (tail, filtered)",
                    value=_tail_log(),
                    language=None,
                    lines=28,
                )

        # Every handler returns (status, log, disk), in this order.
        targets = [status_box, log_box, disk_box]
        refresh_btn.click(_refresh_all, outputs=targets)
        start_btn.click(_on_start_click, outputs=targets)

        # 5s auto-refresh
        ticker = gr.Timer(5.0, active=True)
        ticker.tick(_refresh_all, outputs=targets)

    return app
| # --------------------------------------------------------------------------- | |
| # Entry point | |
| # --------------------------------------------------------------------------- | |
# Kick off the background worker at import time when auto-start is enabled.
# A failure here is logged and must not prevent the UI from being built.
if AUTO_START:
    try:
        _start_worker_idempotent()
    except Exception as boot_err:  # noqa: BLE001
        _log(f"[boot] auto-start error: {boot_err!r}")

# Module-level app object, built on import.
demo = build_app()

if __name__ == "__main__":
    # Direct execution: enable the queue and listen on all interfaces.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)