fbzu's picture
Remove @spaces.GPU decorator, train in background thread directly
6a2c6e2 verified
#!/usr/bin/env python3
"""
Gradio Space that trains an IQL BTC trading agent on zero-a10g (free GPU).
Training runs in a background thread - the Space itself has GPU access.
"""
import os
import sys
import json
import time
import threading
import traceback
from pathlib import Path
# Hub access token; None when the HF_TOKEN environment variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
# ── State ───────────────────────────────────────────────────────────────────
# Shared between the training thread (which writes it) and the Gradio
# callbacks (which read it); access it while holding ``status_lock``.
training_status = dict(
    running=False,     # a training run is currently executing
    done=False,        # the most recent run has finished, successfully or not
    success=False,     # the most recent run completed every step
    error=None,        # formatted traceback string when the run failed
    progress=[],       # log-entry dicts appended by log_progress()
    result=None,       # value returned by trainer.train() on success
    start_time=None,   # time.time() when the run started
    end_time=None,     # time.time() when the run ended
)
status_lock = threading.Lock()
def log_progress(msg, ptype="info", **extra):
    """Append one entry to the shared progress log.

    The entry is a dict holding ``msg`` and ``type`` (from *ptype*), followed
    by any extra keyword fields (which override earlier keys of the same
    name). It is appended to ``training_status["progress"]`` under
    ``status_lock``.
    """
    record = dict(msg=msg, type=ptype)
    record.update(extra)
    with status_lock:
        training_status["progress"].append(record)
# Hub locations and local output directory used by the training pipeline.
_DATASET_REPO = "fbzu/btc_updown_5m_augmented_v1"
_DATASET_FILE = "btc_updown_5m_augmented_v1.parquet"
_MODEL_REPO = "fbzu/rl_btc_v4_iql"
_ARTIFACT_DIR = Path("/tmp/rl_btc_v4_artifacts")


def _json_default(obj):
    """``json.dumps`` fallback for array-like values (e.g. numpy scalars/arrays).

    NOTE(review): ``trainer.train()`` results may contain numpy values — TODO
    confirm. Objects exposing ``tolist()`` are converted; anything else still
    raises ``TypeError`` exactly like ``json.dumps`` would without a default.
    """
    if hasattr(obj, "tolist"):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _save_artifacts(out_dir, trainer, train_dataset, result, config, device, elapsed_s):
    """Write everything to publish into *out_dir*.

    Contents: whatever ``trainer.save`` writes, ``scaler.npz`` (dataset
    mean/std plus reward mean/std from *result*), ``train_report.json`` and
    ``training_logs.json`` (a snapshot of the progress log).
    """
    import numpy as np

    out_dir.mkdir(parents=True, exist_ok=True)
    trainer.save(out_dir)
    np.savez(
        out_dir / "scaler.npz",
        mean=train_dataset.mean, std=train_dataset.std,
        reward_mean=result["reward_mean"], reward_std=result["reward_std"],
    )
    report = {
        "algorithm": "IQL", "config": config.__dict__,
        "dataset": {"path": _DATASET_REPO},
        "results": result, "training_time_seconds": elapsed_s, "device": device,
    }
    (out_dir / "train_report.json").write_text(
        json.dumps(report, indent=2, default=_json_default)
    )
    # Snapshot under the lock: the progress list is shared with the UI thread.
    with status_lock:
        progress_snapshot = list(training_status["progress"])
    (out_dir / "training_logs.json").write_text(
        json.dumps(progress_snapshot, indent=2, default=_json_default)
    )


def _upload_artifacts(out_dir):
    """Upload the whole of *out_dir* (including any subfolders) to the model repo in one commit."""
    from huggingface_hub import HfApi

    HfApi(token=HF_TOKEN).upload_folder(
        folder_path=str(out_dir),
        repo_id=_MODEL_REPO,
        repo_type="model",
    )


def _run_pipeline():
    """Download inputs, train IQL, save and upload artifacts.

    Returns:
        The value returned by ``trainer.train()``.
    """
    log_progress("Downloading dataset...", "info")
    from huggingface_hub import hf_hub_download, snapshot_download
    import torch

    data_path = hf_hub_download(
        repo_id=_DATASET_REPO,
        filename=_DATASET_FILE,
        repo_type="dataset",
        token=HF_TOKEN,
    )
    log_progress("Dataset downloaded", "info")

    log_progress("Downloading code...", "info")
    code_dir = snapshot_download(
        repo_id=_MODEL_REPO,
        repo_type="model",
        token=HF_TOKEN,
        allow_patterns=["rl_btc_v4/*"],
    )
    # Make the downloaded ``rl_btc_v4`` package importable; avoid stacking
    # duplicate entries if the pipeline is ever run again.
    if code_dir not in sys.path:
        sys.path.insert(0, code_dir)

    log_progress("Importing modules...", "info")
    from rl_btc_v4.dataset import build_offline_rl_dataset
    from rl_btc_v4.iql_trainer import IQLTrainer, IQLConfig
    from rl_btc_v4.constants import N_ACTIONS

    cuda_ok = torch.cuda.is_available()
    gpu_info = f"PyTorch {torch.__version__}, CUDA: {cuda_ok}"
    if cuda_ok:
        gpu_info += f", GPU: {torch.cuda.get_device_name(0)}"
    log_progress(gpu_info, "info")
    device = "cuda" if cuda_ok else "cpu"
    if not cuda_ok:
        # Make a silent CPU fallback visible in the status panel.
        log_progress("\u26a0\ufe0f CUDA not available - training on CPU", "info")

    log_progress("Building offline RL dataset...", "info")
    train_dataset, test_dataset = build_offline_rl_dataset(
        data_path=data_path,
        history_length=30, episode_span_days=30,
        episode_stride_days=15, risk_lambda=1.0,
        soft_dd_penalty=0.50, test_fraction=0.2, seed=42,
    )
    log_progress(
        f"Train: {train_dataset.n_transitions} transitions, "
        f"Test: {test_dataset.n_transitions}", "info"
    )
    state_dim = train_dataset.states.shape[1]
    log_progress(f"State dim: {state_dim}", "info")

    config = IQLConfig(
        hidden_dim=256, num_layers=2, dropout=0.1,
        expectile=0.7, temperature=3.0, gamma=0.99, tau=0.005,
        learning_rate=3e-4, batch_size=512, num_epochs=100,
        weight_decay=1e-4, device=device, seed=42,
    )
    trainer = IQLTrainer(state_dim=state_dim, action_dim=N_ACTIONS, config=config)
    t_start = time.time()

    def progress_fn(epoch, metrics):
        # Called by trainer.train(); records one "epoch" entry for the log table.
        elapsed = time.time() - t_start
        log_progress(
            f"Epoch {epoch}: Q={metrics['q_loss']:.4f} V={metrics['v_loss']:.4f} "
            f"\u03c0={metrics['policy_loss']:.4f} Adv={metrics['advantage']:.4f} [{elapsed:.0f}s]",
            "epoch",
            epoch=epoch, q_loss=round(metrics["q_loss"], 6),
            v_loss=round(metrics["v_loss"], 6),
            policy_loss=round(metrics["policy_loss"], 6),
            advantage=round(metrics["advantage"], 6),
            elapsed_s=round(elapsed, 1),
        )

    log_progress("Starting IQL training (100 epochs)...", "info")
    result = trainer.train(
        states=train_dataset.states, actions=train_dataset.actions,
        rewards=train_dataset.rewards, next_states=train_dataset.next_states,
        dones=train_dataset.dones, eval_states=test_dataset.states,
        eval_rewards=test_dataset.rewards, progress_fn=progress_fn,
    )
    t_elapsed = time.time() - t_start
    log_progress(f"Training complete in {t_elapsed:.1f}s", "success")

    _save_artifacts(_ARTIFACT_DIR, trainer, train_dataset, result, config, device, t_elapsed)

    log_progress("Uploading model to HF Hub...", "info")
    _upload_artifacts(_ARTIFACT_DIR)
    log_progress(f"\u2705 Model uploaded to https://huggingface.co/{_MODEL_REPO}", "success")
    return result


def run_training():
    """Run the full train-and-publish pipeline, recording its outcome.

    Intended to be the target of a background thread. If a run is already
    in progress the call returns immediately. Otherwise every outcome field
    of ``training_status`` is reset, and the pipeline runs. On success,
    ``success``/``result`` are set. On an ``Exception``, ``error`` holds the
    formatted traceback and an "error" progress entry is logged. ``done``,
    ``running`` and ``end_time`` are always updated when the call ends.

    NOTE(review): the logged ``torch.cuda.is_available()`` value is the only
    evidence of GPU access here; on ZeroGPU hardware CUDA may be unavailable
    outside ``@spaces.GPU``-decorated calls — confirm.
    """
    with status_lock:
        if training_status["running"]:
            return
        # Reset every outcome field so a repeated run never reports stale
        # success/error/result values from an earlier attempt.
        training_status.update(
            running=True, done=False, success=False,
            error=None, result=None,
            start_time=time.time(), end_time=None,
        )
        training_status["progress"].clear()
    try:
        result = _run_pipeline()
    except Exception as e:
        err = traceback.format_exc()
        log_progress(f"\u274c Error: {str(e)}", "error")
        with status_lock:
            training_status["error"] = err
    else:
        with status_lock:
            training_status["success"] = True
            training_status["result"] = result
    finally:
        # Runs even on BaseException (e.g. SystemExit) so the UI can never
        # show a run permanently stuck in the "running" state.
        with status_lock:
            training_status["done"] = True
            training_status["running"] = False
            training_status["end_time"] = time.time()
# ── Start training immediately ────────────────────────────────────────────
# Launched at import time so training begins as soon as the app module loads.
# daemon=True keeps this thread from blocking interpreter shutdown.
_training_thread = threading.Thread(target=run_training, daemon=True)
_training_thread.start()
# ── Gradio UI ──────────────────────────────────────────────────────────────
import gradio as gr
def get_status():
    """Build the text shown in the "Status" box.

    Every progress entry becomes one line, prefixed by an emoji chosen by the
    entry's ``type``. Trailing line(s) then summarise the run's overall state.
    The shared state is copied under ``status_lock`` before any formatting.
    """
    icon_for_type = {
        "info": "\u2139\ufe0f",
        "success": "\u2705",
        "error": "\u274c",
        "epoch": "\U0001f4ca",
    }
    with status_lock:
        entries = list(training_status["progress"])
        is_done = training_status["done"]
        is_running = training_status["running"]
        succeeded = training_status["success"]
        failure = training_status["error"]
        started_at = training_status["start_time"]
        ended_at = training_status["end_time"]

    rendered = [
        f"{icon_for_type.get(entry.get('type', 'info'), ' ')} {entry.get('msg', '')}"
        for entry in entries
    ]

    if not is_done:
        rendered.append(
            "\u23f3 Training in progress..." if is_running else "\u23f3 Initializing..."
        )
    elif succeeded:
        duration = (ended_at or time.time()) - (started_at or time.time())
        rendered.append(f"\n\U0001f389 Training complete in {duration:.1f}s")
        rendered.append("\n\U0001f4e6 Model: https://huggingface.co/fbzu/rl_btc_v4_iql")
    elif failure:
        rendered.append(f"\n\u274c Training failed:\n{failure}")
    return "\n".join(rendered)
def format_epoch_table(progress):
    """Render the "epoch"-type entries of *progress* as a fixed-width text table.

    Args:
        progress: sequence of progress-entry dicts, as appended by
            ``log_progress``. Only entries whose ``type`` is ``"epoch"`` are
            used; each must carry ``epoch``, ``q_loss``, ``v_loss``,
            ``policy_loss``, ``advantage`` and ``elapsed_s``.

    Returns:
        A placeholder message when there are no epoch entries. Otherwise a
        header row, a dashed rule as wide as the header, and one row per
        epoch whose columns line up with the header.
    """
    epoch_logs = [p for p in progress if p.get("type") == "epoch"]
    if not epoch_logs:
        return "Waiting for training to start..."
    header = (
        f"{'Epoch':>5} | {'Q Loss':>10} | {'V Loss':>10} | "
        f"{'Policy Loss':>11} | {'Advantage':>10} | {'Time(s)':>7}"
    )
    lines = [header, "-" * len(header)]
    for log in epoch_logs:
        # Metrics are stored rounded to 6 decimals, so print exactly 6 (the
        # former 8-decimal Advantage column only showed padding zeros).
        # int() tolerates integral floats that the ":d" format would reject.
        lines.append(
            f"{int(log['epoch']):>5d} | {log['q_loss']:>10.6f} | {log['v_loss']:>10.6f} | "
            f"{log['policy_loss']:>11.6f} | {log['advantage']:>10.6f} | {log['elapsed_s']:>7.0f}"
        )
    return "\n".join(lines)


def get_logs():
    """Return the per-epoch metrics table for the "Logs" box.

    The progress log is copied under ``status_lock`` and then formatted
    outside the lock.
    """
    with status_lock:
        progress = list(training_status["progress"])
    return format_epoch_table(progress)
with gr.Blocks(title="RL BTC v4 IQL Training") as demo:
    gr.Markdown("# \U0001f4c8 RL BTC v4 \u2014 Implicit Q-Learning Trading Agent")
    gr.Markdown("Training on zero-a10g (free GPU). Dataset: BTC 5m market data with risk-sensitive rewards.")
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Training Status")
            # Pass the function itself, not its result: Gradio calls a callable
            # value on every page load. The call result alone would freeze
            # the text captured at import time.
            status_box = gr.Textbox(value=get_status, lines=15, label="Status")
        with gr.Column():
            gr.Markdown("## Training Logs")
            logs_box = gr.Textbox(value=get_logs, lines=20, label="Logs")
    refresh_btn = gr.Button("\U0001f504 Refresh")

    def refresh_all():
        """Return fresh (status text, logs table) for both text boxes."""
        return get_status(), get_logs()

    # A single event updates both boxes together instead of two separate requests.
    refresh_btn.click(fn=refresh_all, outputs=[status_box, logs_box])
    # Blank lines separate the items; without them Markdown merges the
    # consecutive lines into one run-on paragraph.
    gr.Markdown(
        "**Config:** hidden=256, layers=2, dropout=0.1, expectile=0.7, temp=3.0, "
        "gamma=0.99, lr=3e-4, batch=512, epochs=100\n\n"
        "**Action space:** 8 actions (HOLD, FLAT, YES/NO at 10/25/50% exposure)\n\n"
        "**Reward:** Risk-sensitive PnL with drawdown penalties"
    )

if __name__ == "__main__":
    demo.launch()