Spaces:
Paused
Paused
| #!/usr/bin/env python3 | |
| """ | |
| Gradio Space that trains an IQL BTC trading agent on zero-a10g (free GPU). | |
| Training runs in a background thread - the Space itself has GPU access. | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import time | |
| import threading | |
| import traceback | |
| from pathlib import Path | |
# Hugging Face Hub token read from the environment; None when it is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# ── State ───────────────────────────────────────────────────────────────────
# Shared between the background training thread and the Gradio callbacks;
# access it while holding ``status_lock``.
training_status = dict(
    running=False,    # a training run is currently executing
    done=False,       # the last run has finished (successfully or not)
    success=False,    # the last run trained, saved and uploaded its artifacts
    error=None,       # traceback text of the last failure, if any
    progress=[],      # log entries appended by log_progress()
    result=None,      # value returned by the trainer on success
    start_time=None,  # time.time() when the run started
    end_time=None,    # time.time() when the run finished
)
status_lock = threading.Lock()
def log_progress(msg, ptype="info", **extra):
    """Append one entry to the shared progress log.

    The entry is a dict with ``"msg"`` and ``"type"`` keys followed by any
    extra keyword fields (an extra ``type=`` keyword overrides ``ptype``).
    The append happens while holding ``status_lock``.
    """
    record = dict(msg=msg, type=ptype)
    record.update(extra)
    with status_lock:
        training_status["progress"].append(record)
def _json_default(obj):
    """Fallback serializer for ``json.dumps`` covering numpy-style values.

    Objects exposing a callable ``tolist()`` (numpy arrays and numpy scalars
    both do) are converted to plain Python lists/numbers. Anything else raises
    ``TypeError``, exactly as ``json.dumps`` does without a ``default``.
    """
    tolist = getattr(obj, "tolist", None)
    if callable(tolist):
        return tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def run_training():
    """Train the IQL BTC agent end to end and publish the artifacts to the Hub.

    Steps:
      1. download the dataset parquet and the ``rl_btc_v4`` code from the Hub;
      2. build the offline RL train/test datasets;
      3. train an ``IQLTrainer`` (per-epoch metrics go to the progress log);
      4. write the model, ``scaler.npz``, ``train_report.json`` and
         ``training_logs.json`` to ``/tmp/rl_btc_v4_artifacts``;
      5. upload that folder to the ``fbzu/rl_btc_v4_iql`` model repo.

    Meant to run in a background thread. Progress is reported via
    ``log_progress``. The outcome is stored in ``training_status``:
    ``success``/``result`` on success, or ``error`` (full traceback) on
    failure. ``done``/``running``/``end_time`` are updated however the
    function exits. If a run is already in progress, returns immediately
    without doing anything.
    """
    with status_lock:
        if training_status["running"]:
            return
        # Reset every per-run field so a repeated run never reports a stale
        # outcome (done/success/error/result) left over from a previous run.
        training_status.update(
            running=True, done=False, success=False, error=None, result=None,
            start_time=time.time(), end_time=None,
        )
        training_status["progress"].clear()

    try:
        log_progress("Downloading dataset...", "info")
        # Heavy dependencies are imported lazily so the UI can come up first;
        # importing them all here makes a missing package fail before any
        # download or training time is spent.
        from huggingface_hub import HfApi, hf_hub_download, snapshot_download
        import numpy as np
        import torch

        data_path = hf_hub_download(
            repo_id="fbzu/btc_updown_5m_augmented_v1",
            filename="btc_updown_5m_augmented_v1.parquet",
            repo_type="dataset",
            token=HF_TOKEN,
        )
        log_progress("Dataset downloaded", "info")

        log_progress("Downloading code...", "info")
        code_dir = snapshot_download(
            repo_id="fbzu/rl_btc_v4_iql",
            repo_type="model",
            token=HF_TOKEN,
            allow_patterns=["rl_btc_v4/*"],
        )
        # The snapshot directory holds the ``rl_btc_v4`` package; make it
        # importable, adding it only once even if training is started again.
        if code_dir not in sys.path:
            sys.path.insert(0, code_dir)

        log_progress("Importing modules...", "info")
        from rl_btc_v4.dataset import build_offline_rl_dataset
        from rl_btc_v4.iql_trainer import IQLTrainer, IQLConfig
        from rl_btc_v4.constants import N_ACTIONS

        # NOTE(review): on ZeroGPU hardware CUDA is typically attached only
        # inside ``@spaces.GPU``-decorated calls, so this background thread may
        # see no GPU and fall back to CPU below. Confirm on the running Space.
        cuda_ok = torch.cuda.is_available()
        gpu_info = f"PyTorch {torch.__version__}, CUDA: {cuda_ok}"
        if cuda_ok:
            gpu_info += f", GPU: {torch.cuda.get_device_name(0)}"
        log_progress(gpu_info, "info")
        device = "cuda" if cuda_ok else "cpu"

        log_progress("Building offline RL dataset...", "info")
        train_dataset, test_dataset = build_offline_rl_dataset(
            data_path=data_path,
            history_length=30, episode_span_days=30,
            episode_stride_days=15, risk_lambda=1.0,
            soft_dd_penalty=0.50, test_fraction=0.2, seed=42,
        )
        log_progress(
            f"Train: {train_dataset.n_transitions} transitions, "
            f"Test: {test_dataset.n_transitions}", "info"
        )
        state_dim = train_dataset.states.shape[1]
        log_progress(f"State dim: {state_dim}", "info")

        # Single source for the epoch count used by the config and the log.
        num_epochs = 100
        config = IQLConfig(
            hidden_dim=256, num_layers=2, dropout=0.1,
            expectile=0.7, temperature=3.0, gamma=0.99, tau=0.005,
            learning_rate=3e-4, batch_size=512, num_epochs=num_epochs,
            weight_decay=1e-4, device=device, seed=42,
        )
        trainer = IQLTrainer(state_dim=state_dim, action_dim=N_ACTIONS, config=config)

        t_start = time.time()

        def progress_fn(epoch, metrics):
            """Per-epoch callback passed to ``trainer.train``: log losses and elapsed time."""
            elapsed = time.time() - t_start
            log_progress(
                f"Epoch {epoch}: Q={metrics['q_loss']:.4f} V={metrics['v_loss']:.4f} "
                f"\u03c0={metrics['policy_loss']:.4f} Adv={metrics['advantage']:.4f} [{elapsed:.0f}s]",
                "epoch",
                epoch=epoch, q_loss=round(metrics["q_loss"], 6),
                v_loss=round(metrics["v_loss"], 6),
                policy_loss=round(metrics["policy_loss"], 6),
                advantage=round(metrics["advantage"], 6),
                elapsed_s=round(elapsed, 1),
            )

        log_progress(f"Starting IQL training ({num_epochs} epochs)...", "info")
        result = trainer.train(
            states=train_dataset.states, actions=train_dataset.actions,
            rewards=train_dataset.rewards, next_states=train_dataset.next_states,
            dones=train_dataset.dones, eval_states=test_dataset.states,
            eval_rewards=test_dataset.rewards, progress_fn=progress_fn,
        )
        t_elapsed = time.time() - t_start
        log_progress(f"Training complete in {t_elapsed:.1f}s", "success")

        # ── Save artifacts ──
        out_dir = Path("/tmp/rl_btc_v4_artifacts")
        out_dir.mkdir(parents=True, exist_ok=True)
        trainer.save(out_dir)
        # Dataset mean/std and the trainer's reward statistics are saved next
        # to the model (presumably needed to preprocess inputs at inference
        # time; confirm with the consumer of these artifacts).
        np.savez(
            out_dir / "scaler.npz",
            mean=train_dataset.mean, std=train_dataset.std,
            reward_mean=result["reward_mean"], reward_std=result["reward_std"],
        )
        report = {
            "algorithm": "IQL", "config": config.__dict__,
            "dataset": {"path": "fbzu/btc_updown_5m_augmented_v1"},
            "results": result, "training_time_seconds": t_elapsed, "device": device,
        }
        # ``result`` and the logged metrics may contain numpy values, which
        # plain ``json.dumps`` rejects; failing here would lose a finished run.
        (out_dir / "train_report.json").write_text(
            json.dumps(report, indent=2, default=_json_default)
        )
        # Snapshot the shared log under the lock, as the UI readers do.
        with status_lock:
            progress_snapshot = list(training_status["progress"])
        (out_dir / "training_logs.json").write_text(
            json.dumps(progress_snapshot, indent=2, default=_json_default)
        )

        log_progress("Uploading model to HF Hub...", "info")
        # One commit for the whole folder instead of one commit per file. This
        # also covers any sub-directories ``trainer.save`` may create, which
        # ``upload_file`` cannot upload.
        HfApi(token=HF_TOKEN).upload_folder(
            folder_path=str(out_dir),
            repo_id="fbzu/rl_btc_v4_iql",
            repo_type="model",
            commit_message="Upload IQL training artifacts",
        )
        log_progress("\u2705 Model uploaded to https://huggingface.co/fbzu/rl_btc_v4_iql", "success")

        with status_lock:
            training_status["success"] = True
            training_status["result"] = result
    except Exception as e:
        # Top-level boundary of the worker thread: record the failure for the UI.
        err = traceback.format_exc()
        log_progress(f"\u274c Error: {str(e)}", "error")
        with status_lock:
            training_status["error"] = err
    finally:
        # Always leave the status consistent, however the run ended.
        with status_lock:
            training_status["done"] = True
            training_status["running"] = False
            training_status["end_time"] = time.time()
# ── Start training immediately ─────────────────────────────────────────────
# Training is kicked off at import time, before the UI is built. The thread is
# a daemon, so it does not keep the process alive on interpreter shutdown.
_training_thread = threading.Thread(target=run_training, daemon=True)
_training_thread.start()
# ── Gradio UI ──────────────────────────────────────────────────────────────
| import gradio as gr | |
def get_status():
    """Build the text shown in the "Status" box.

    Emits one line per progress entry (an icon picked from the entry's type,
    then its message), followed by a summary that depends on the run state:
    initializing, in progress, completed (with duration and model link), or
    failed (with the recorded traceback). Shared state is copied under
    ``status_lock``; formatting happens outside the lock.
    """
    with status_lock:
        state = dict(training_status)
        entries = list(training_status["progress"])

    icons = {
        "info": "\u2139\ufe0f",
        "success": "\u2705",
        "error": "\u274c",
        "epoch": "\U0001f4ca",
    }
    out = [
        f"{icons.get(entry.get('type', 'info'), ' ')} {entry.get('msg', '')}"
        for entry in entries
    ]

    if not state["done"]:
        out.append("\u23f3 Training in progress..." if state["running"] else "\u23f3 Initializing...")
    elif state["success"]:
        finished = state["end_time"] or time.time()
        started = state["start_time"] or time.time()
        out.append(f"\n\U0001f389 Training complete in {finished - started:.1f}s")
        out.append("\n\U0001f4e6 Model: https://huggingface.co/fbzu/rl_btc_v4_iql")
    elif state["error"]:
        out.append(f"\n\u274c Training failed:\n{state['error']}")
    return "\n".join(out)
def get_logs():
    """Render the per-epoch training metrics as a text table.

    Only progress entries whose ``"type"`` is ``"epoch"`` are shown, one row
    per entry (epoch, Q/V/policy losses, advantage, elapsed seconds). Returns
    a placeholder message until the first epoch has been logged. The progress
    list is copied under ``status_lock``; formatting happens outside the lock.
    """
    with status_lock:
        progress = list(training_status["progress"])
    epoch_logs = [p for p in progress if p.get("type") == "epoch"]
    if not epoch_logs:
        return "Waiting for training to start..."

    # Header and rows share the same column widths (and the separator matches
    # the header length), so the columns line up in a fixed-width font.
    # Metrics are logged rounded to 6 decimals, so every metric, advantage
    # included, is printed with 6 decimals rather than padding with zeros.
    header = (
        f"{'Epoch':>5} | {'Q Loss':>10} | {'V Loss':>10} | "
        f"{'Policy Loss':>11} | {'Advantage':>10} | {'Time(s)':>7}"
    )
    lines = [header, "-" * len(header)]
    lines.extend(
        f"{log['epoch']:5d} | {log['q_loss']:10.6f} | {log['v_loss']:10.6f} | "
        f"{log['policy_loss']:11.6f} | {log['advantage']:10.6f} | {log['elapsed_s']:7.0f}"
        for log in epoch_logs
    )
    return "\n".join(lines)
with gr.Blocks(title="RL BTC v4 IQL Training") as demo:
    gr.Markdown("# \U0001f4c8 RL BTC v4 \u2014 Implicit Q-Learning Trading Agent")
    gr.Markdown("Training on zero-a10g (free GPU). Dataset: BTC 5m market data with risk-sensitive rewards.")
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Training Status")
            # Pass the function itself, not its result: Gradio calls a callable
            # ``value`` on every page load, so a freshly opened page shows the
            # current state instead of the state captured when the UI was built.
            status_box = gr.Textbox(value=get_status, lines=15, label="Status")
        with gr.Column():
            gr.Markdown("## Training Logs")
            logs_box = gr.Textbox(value=get_logs, lines=20, label="Logs")
    refresh_btn = gr.Button("\U0001f504 Refresh")
    refresh_btn.click(fn=get_status, outputs=status_box)
    refresh_btn.click(fn=get_logs, outputs=logs_box)
    # Blank lines keep each item a separate Markdown paragraph; the string is
    # flush-left so no line is rendered as an indented code block.
    gr.Markdown(
        """
**Config:** hidden=256, layers=2, dropout=0.1, expectile=0.7, temp=3.0,
gamma=0.99, lr=3e-4, batch=512, epochs=100

**Action space:** 8 actions (HOLD, FLAT, YES/NO at 10/25/50% exposure)

**Reward:** Risk-sensitive PnL with drawdown penalties
"""
    )

if __name__ == "__main__":
    demo.launch()