fbzu committed on
Commit
5c94ca4
·
verified ·
1 Parent(s): f6d278e

Add training app

Browse files
Files changed (1) hide show
  1. app.py +274 -0
app.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Gradio Space that trains an IQL BTC trading agent in the background.
4
+ Hardware: zero-a10g (free for Pro users).
5
+ """
6
+ import os
7
+ import sys
8
+ import json
9
+ import time
10
+ import threading
11
+ import traceback
12
+ from pathlib import Path
13
+
14
+ import numpy as np
15
+
16
+ # ── State ───────────────────────────────────────────────────────────────────
17
# Shared, mutable training state: written by the background training thread
# and read by the Gradio callbacks.
training_status = dict(
    running=False,     # True while run_training() is executing
    done=False,        # True once run_training() has finished (any outcome)
    success=False,     # True only when the run finished without an error
    error=None,        # formatted traceback text when the run failed
    progress=[],       # chronological list of message / per-epoch dicts
    result=None,       # value returned by the trainer on success
    start_time=None,   # time.time() at the start of the run
    end_time=None,     # time.time() at the end of the run
)
20
+
21
+ # ── Download dataset and code ─────────────────────────────────────────────
22
+ from huggingface_hub import hf_hub_download, snapshot_download
23
+
24
# Hub access token taken from the environment; None when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
25
+
26
def run_training():
    """Train the IQL agent end to end and publish the resulting artifacts.

    Meant to run in a background thread. The steps are: download the dataset
    parquet and the ``rl_btc_v4`` code from the Hugging Face Hub, build the
    offline RL dataset, train with ``IQLTrainer``, write the model, scaler,
    report and log files to ``/tmp/rl_btc_v4_artifacts``, and upload every
    file in that directory to the ``fbzu/rl_btc_v4_iql`` model repo.

    If ``iql_model.pt`` already exists in the artifact directory, the whole
    pipeline is skipped and the run is marked successful.

    Nothing is returned. Progress and the outcome are reported through the
    module-level ``training_status`` dict. Exceptions are caught and recorded
    there (``error`` holds the formatted traceback) instead of propagating.
    """
    dataset_repo = "fbzu/btc_updown_5m_augmented_v1"
    model_repo = "fbzu/rl_btc_v4_iql"
    progress = training_status["progress"]

    def log(msg, kind="info"):
        """Append a human-readable message to the shared progress list."""
        progress.append({"msg": msg, "type": kind})

    training_status["running"] = True
    training_status["start_time"] = time.time()

    try:
        # Skip the pipeline when a previous run already left a model here.
        out_dir = Path("/tmp/rl_btc_v4_artifacts")
        if (out_dir / "iql_model.pt").exists():
            log("Model already trained, loading...")
            training_status["success"] = True
            return  # the ``finally`` clause below finalizes the status

        log("Downloading dataset...")
        data_path = hf_hub_download(
            repo_id=dataset_repo,
            filename="btc_updown_5m_augmented_v1.parquet",
            repo_type="dataset",
            token=HF_TOKEN,
        )
        log("Dataset downloaded")

        log("Downloading code...")
        code_dir = snapshot_download(
            repo_id=model_repo,
            repo_type="model",
            token=HF_TOKEN,
            allow_patterns=["rl_btc_v4/*"],
        )
        # Make the downloaded ``rl_btc_v4`` package importable.
        sys.path.insert(0, code_dir)

        log("Importing modules...")
        from rl_btc_v4.dataset import build_offline_rl_dataset
        from rl_btc_v4.iql_trainer import IQLTrainer, IQLConfig
        from rl_btc_v4.constants import N_ACTIONS

        import torch
        cuda_available = torch.cuda.is_available()
        gpu_info = f"PyTorch {torch.__version__}, CUDA: {cuda_available}"
        if cuda_available:
            gpu_info += f", GPU: {torch.cuda.get_device_name(0)}"
        log(gpu_info)

        device = "cuda" if cuda_available else "cpu"

        # Build dataset
        log("Building offline RL dataset...")
        train_dataset, test_dataset = build_offline_rl_dataset(
            data_path=data_path,
            history_length=30,
            episode_span_days=30,
            episode_stride_days=15,
            risk_lambda=1.0,
            soft_dd_penalty=0.50,
            test_fraction=0.2,
            seed=42,
        )

        state_dim = train_dataset.states.shape[1]
        log(
            f"Train: {train_dataset.n_transitions} transitions, "
            f"Test: {test_dataset.n_transitions}"
        )
        log(f"State dim: {state_dim}")

        # Train
        config = IQLConfig(
            hidden_dim=256,
            num_layers=2,
            dropout=0.1,
            expectile=0.7,
            temperature=3.0,
            gamma=0.99,
            tau=0.005,
            learning_rate=3e-4,
            batch_size=512,
            num_epochs=100,
            weight_decay=1e-4,
            device=device,
            seed=42,
        )

        trainer = IQLTrainer(state_dim=state_dim, action_dim=N_ACTIONS, config=config)
        t_start = time.time()

        def progress_fn(epoch, metrics):
            """Record one epoch's metrics; passed to the trainer as a callback."""
            progress.append({
                "epoch": epoch,
                "elapsed_s": round(time.time() - t_start, 1),
                "q_loss": round(metrics["q_loss"], 6),
                "v_loss": round(metrics["v_loss"], 6),
                "policy_loss": round(metrics["policy_loss"], 6),
                "advantage": round(metrics["advantage"], 6),
                "type": "epoch",
            })

        log("Starting IQL training...")

        result = trainer.train(
            states=train_dataset.states,
            actions=train_dataset.actions,
            rewards=train_dataset.rewards,
            next_states=train_dataset.next_states,
            dones=train_dataset.dones,
            eval_states=test_dataset.states,
            eval_rewards=test_dataset.rewards,
            progress_fn=progress_fn,
        )

        t_elapsed = time.time() - t_start
        log(f"Training complete in {t_elapsed:.1f}s", "success")

        # Save artifacts
        out_dir.mkdir(parents=True, exist_ok=True)
        trainer.save(out_dir)

        np.savez(
            out_dir / "scaler.npz",
            mean=train_dataset.mean,
            std=train_dataset.std,
            reward_mean=result["reward_mean"],
            reward_std=result["reward_std"],
        )

        report = {
            "algorithm": "IQL",
            "config": config.__dict__,
            "dataset": {"path": dataset_repo},
            "results": result,
            "training_time_seconds": t_elapsed,
            "device": device,
        }
        # NOTE(review): the trainer result and metrics may contain NumPy
        # scalars/arrays (assumption - confirm against IQLTrainer). The
        # ``default`` hook is only consulted for values json cannot encode
        # natively, so plain-Python results serialize exactly as before.
        (out_dir / "train_report.json").write_text(
            json.dumps(report, indent=2, default=_json_default)
        )
        (out_dir / "training_logs.json").write_text(
            json.dumps(progress, indent=2, default=_json_default)
        )

        # Upload to HF Hub
        log("Uploading model to HF Hub...")
        from huggingface_hub import HfApi
        hf_api = HfApi(token=HF_TOKEN)
        for f in out_dir.iterdir():
            hf_api.upload_file(
                path_or_fileobj=str(f),
                path_in_repo=f.name,
                repo_id=model_repo,
                repo_type="model",
            )

        log(f"✅ Model uploaded to https://huggingface.co/{model_repo}", "success")

        training_status["success"] = True
        training_status["result"] = result

    except Exception as e:
        # Top-level boundary of the worker thread: record the failure for the
        # UI instead of letting the thread die silently.
        training_status["error"] = traceback.format_exc()
        log(f"❌ Error: {e}", "error")

    finally:
        # Runs on every exit path, including the early return above. Set
        # end_time before ``done`` so a reader that observes done=True never
        # sees a missing end_time.
        training_status["end_time"] = time.time()
        training_status["done"] = True
        training_status["running"] = False


def _json_default(value):
    """Convert NumPy scalars and arrays to plain Python for ``json.dumps``.

    Raises:
        TypeError: for any other unsupported object, matching the standard
            ``json`` failure mode.
    """
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(
        f"Object of type {type(value).__name__} is not JSON serializable"
    )
197
+
198
+
199
+ # ── Start training in background ───────────────────────────────────────────
200
# Kick off training as soon as the module is loaded. The thread is a daemon,
# so it does not keep the process alive on shutdown.
_training_thread = threading.Thread(target=run_training, daemon=True)
_training_thread.start()
201
+
202
+ # ── Gradio UI ──────────────────────────────────────────────────────────────
203
+ import gradio as gr
204
+
205
def get_status():
    """Render ``training_status`` as the text shown in the Status box.

    Emits one line per progress entry, prefixed with an icon chosen by the
    entry's ``type``, followed by a summary of the current phase
    (initializing, in progress, complete, or failed).

    Returns:
        str: the newline-joined status text.
    """
    icons = {"info": "ℹ️", "success": "✅", "error": "❌", "epoch": "📊"}

    lines = []
    # Iterate over a snapshot: the training thread appends concurrently.
    for entry in list(training_status["progress"]):
        kind = entry.get("type", "info")
        text = entry.get("msg", "")
        if kind == "epoch" and not text:
            # Epoch entries carry metrics rather than a "msg"; without this
            # they would render as a bare icon on an otherwise empty line.
            text = (
                f"Epoch {entry.get('epoch', '?')} "
                f"({entry.get('elapsed_s', '?')}s elapsed)"
            )
        lines.append(f"{icons.get(kind, ' ')} {text}")

    if not training_status["done"]:
        if training_status["running"]:
            lines.append("⏳ Training in progress...")
        else:
            lines.append("⏳ Initializing...")
    elif training_status["success"]:
        started = training_status["start_time"]
        ended = training_status["end_time"]
        # Either timestamp can still be None (e.g. a run that finished
        # without recording end_time); do not crash the UI callback.
        if started is not None and ended is not None:
            lines.append(f"\n🎉 Training complete in {ended - started:.1f}s")
        else:
            lines.append("\n🎉 Training complete")
        lines.append("\n📦 Model: https://huggingface.co/fbzu/rl_btc_v4_iql")
    elif training_status["error"]:
        lines.append(f"\n❌ Training failed:\n{training_status['error']}")

    return "\n".join(lines)
225
+
226
+
227
def get_logs():
    """Build the per-epoch metrics table shown in the Logs box.

    Returns:
        str: a placeholder message while no epoch entry exists, otherwise a
        header, a separator line, and one formatted row per recorded epoch.
    """
    epoch_entries = []
    for entry in training_status["progress"]:
        if entry.get("type") == "epoch":
            epoch_entries.append(entry)

    if len(epoch_entries) == 0:
        return "Waiting for training to start..."

    table = [
        "Epoch | Q Loss | V Loss | Policy Loss | Advantage | Time(s)",
        "-" * 80,
    ]
    for row in epoch_entries:
        left = f"{row['epoch']:5d} | {row['q_loss']:.6f} | {row['v_loss']:.6f} | "
        right = f"{row['policy_loss']:.6f} | {row['advantage']:.8f} | {row['elapsed_s']:.0f}"
        table.append(left + right)
    return "\n".join(table)
240
+
241
+
242
# Dashboard layout: a status panel and a per-epoch log panel side by side,
# a manual refresh button, and a periodic refresh registered on page load.
with gr.Blocks(title="RL BTC v4 IQL Training") as demo:
    gr.Markdown("# 📈 RL BTC v4 — Implicit Q-Learning Trading Agent")
    gr.Markdown("Training on zero-a10g (free GPU). Dataset: BTC 5m market data with risk-sensitive rewards.")

    with gr.Row():
        with gr.Column():
            gr.Markdown("## Training Status")
            status_box = gr.Textbox(value=get_status(), lines=15, label="Status")

        with gr.Column():
            gr.Markdown("## Training Logs")
            logs_box = gr.Textbox(value=get_logs(), lines=20, label="Logs")

    # One click refreshes both panels (two handlers on the same button).
    refresh_btn = gr.Button("🔄 Refresh")
    refresh_btn.click(fn=get_status, outputs=status_box)
    refresh_btn.click(fn=get_logs, outputs=logs_box)

    # Auto-refresh every 30s
    demo.load(fn=get_status, outputs=status_box, every=30)
    demo.load(fn=get_logs, outputs=logs_box, every=30)

    gr.Markdown("""
    **Config:** hidden=256, layers=2, dropout=0.1, expectile=0.7, temp=3.0,
    gamma=0.99, lr=3e-4, batch=512, epochs=100

    **Action space:** 8 actions (HOLD, FLAT, YES/NO at 10/25/50% exposure)

    **Reward:** Risk-sensitive PnL with drawdown penalties
    """)


if __name__ == "__main__":
    # Listen on all interfaces on port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)