Spaces:
Paused
Paused
Daksh C Jain Claude Sonnet 4.6 committed on
Commit ·
d7456d6
1
Parent(s): 3fca800
Upgrade to full production MARL masterclass app
Browse files
- 3-tab Gradio UI: Mission Control, Training Lab, Algorithm Guide
- Animated GIF replay with HUD overlay (step, reward, throttle bars)
- Side-by-side comparison GIF for multi-episode runs
- 4-panel mission overview: reward bars, 2D trajectory, cumulative reward, engine throttle
- 6-panel episode deep-dive: trajectory, altitude, angle, throttle timelines
- SAC fine-tuning in background thread with live metrics refresh
- Training dashboard: reward history, actor/critic loss, entropy coefficient
- Environment controls: gravity, wind, turbulence sliders
- Modular structure: core/mission.py, core/trainer.py, viz/charts.py, viz/replay.py
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- app.py +432 -300
- core/__init__.py +0 -0
- core/__pycache__/__init__.cpython-311.pyc +0 -0
- core/__pycache__/mission.cpython-311.pyc +0 -0
- core/__pycache__/trainer.cpython-311.pyc +0 -0
- core/mission.py +178 -0
- core/trainer.py +120 -0
- requirements.txt +2 -0
- viz/__init__.py +0 -0
- viz/__pycache__/__init__.cpython-311.pyc +0 -0
- viz/__pycache__/charts.cpython-311.pyc +0 -0
- viz/__pycache__/replay.cpython-311.pyc +0 -0
- viz/charts.py +246 -0
- viz/replay.py +160 -0
app.py
CHANGED
|
@@ -1,69 +1,194 @@
|
|
| 1 |
-
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import numpy as np
|
|
|
|
| 4 |
from stable_baselines3 import SAC
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
)
|
| 57 |
-
return summary
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
@import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;700;900&family=Share+Tech+Mono&family=Exo+2:wght@300;400;600&display=swap');
|
| 65 |
|
| 66 |
-
/* Reset & base */
|
| 67 |
*, *::before, *::after { box-sizing: border-box; }
|
| 68 |
|
| 69 |
body, .gradio-container {
|
|
@@ -72,296 +197,303 @@ body, .gradio-container {
|
|
| 72 |
font-family: 'Exo 2', sans-serif !important;
|
| 73 |
}
|
| 74 |
|
| 75 |
-
.gradio-container {
|
| 76 |
-
max-width: 860px !important;
|
| 77 |
-
margin: 0 auto !important;
|
| 78 |
-
padding: 0 1rem 3rem !important;
|
| 79 |
-
}
|
| 80 |
|
| 81 |
-
|
| 82 |
-
.
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
}
|
| 87 |
|
| 88 |
-
.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
font-family: 'Orbitron', monospace !important;
|
| 90 |
-
font-size: clamp(1.4rem,
|
| 91 |
-
font-weight: 900 !important;
|
| 92 |
-
|
| 93 |
-
color: #e8f4ff !important;
|
| 94 |
-
margin: 0 0 0.35rem !important;
|
| 95 |
-
text-transform: uppercase !important;
|
| 96 |
}
|
| 97 |
-
|
| 98 |
-
.mission-header .sub {
|
| 99 |
font-family: 'Share Tech Mono', monospace;
|
| 100 |
-
font-size: 0.
|
| 101 |
-
|
| 102 |
-
letter-spacing: 0.25em;
|
| 103 |
-
text-transform: uppercase;
|
| 104 |
}
|
| 105 |
|
| 106 |
-
.divider {
|
| 107 |
-
border: none;
|
| 108 |
-
border-top: 1px solid #0d2540;
|
| 109 |
-
margin: 1.5rem 0;
|
| 110 |
-
}
|
| 111 |
-
|
| 112 |
-
/* ββ Status badge strip ββ */
|
| 113 |
.status-strip {
|
| 114 |
-
display: flex;
|
| 115 |
-
|
| 116 |
-
justify-content: center;
|
| 117 |
-
flex-wrap: wrap;
|
| 118 |
-
margin: 1.2rem 0 2rem;
|
| 119 |
}
|
| 120 |
-
|
| 121 |
.badge {
|
| 122 |
font-family: 'Share Tech Mono', monospace;
|
| 123 |
-
font-size: 0.
|
| 124 |
-
|
| 125 |
-
padding: 5px 14px;
|
| 126 |
-
border-radius: 3px;
|
| 127 |
-
text-transform: uppercase;
|
| 128 |
-
}
|
| 129 |
-
|
| 130 |
-
.badge-green {
|
| 131 |
-
background: #041e12;
|
| 132 |
-
color: #2ddb7c;
|
| 133 |
-
border: 1px solid #0a5530;
|
| 134 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
-
.
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
background: #1a1002;
|
| 144 |
-
color: #f5a623;
|
| 145 |
-
border: 1px solid #5c3700;
|
| 146 |
}
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
border: 1px solid #0d2540;
|
| 152 |
-
border-radius: 6px;
|
| 153 |
-
padding: 1.5rem;
|
| 154 |
-
margin-bottom: 1rem;
|
| 155 |
}
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
font-
|
| 160 |
-
letter-spacing: 0.22em;
|
| 161 |
-
text-transform: uppercase;
|
| 162 |
-
color: #2d6a9f;
|
| 163 |
-
margin-bottom: 1rem;
|
| 164 |
}
|
| 165 |
|
| 166 |
-
|
| 167 |
-
.slider-wrap label,
|
| 168 |
-
.gradio-container label span {
|
| 169 |
font-family: 'Share Tech Mono', monospace !important;
|
| 170 |
-
font-size: 0.
|
| 171 |
-
|
| 172 |
-
text-transform: uppercase !important;
|
| 173 |
-
color: #4fb3ff !important;
|
| 174 |
}
|
| 175 |
-
|
| 176 |
input[type=range] {
|
| 177 |
-
-webkit-appearance: none;
|
| 178 |
-
|
| 179 |
-
width: 100%;
|
| 180 |
-
height: 3px;
|
| 181 |
-
background: #0d2540;
|
| 182 |
-
border-radius: 2px;
|
| 183 |
-
outline: none;
|
| 184 |
-
margin: 0.5rem 0;
|
| 185 |
}
|
| 186 |
-
|
| 187 |
input[type=range]::-webkit-slider-thumb {
|
| 188 |
-
-webkit-appearance: none;
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
border-radius: 50%;
|
| 192 |
-
background: #4fb3ff;
|
| 193 |
-
cursor: pointer;
|
| 194 |
-
border: 2px solid #030b1a;
|
| 195 |
-
box-shadow: 0 0 8px rgba(79,179,255,0.5);
|
| 196 |
}
|
| 197 |
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
height:
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
cursor: pointer;
|
| 204 |
-
border: 2px solid #030b1a;
|
| 205 |
}
|
| 206 |
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
text-transform: uppercase !important;
|
| 214 |
-
background: linear-gradient(135deg, #0a2a52 0%, #0d3a72 100%) !important;
|
| 215 |
-
color: #4fb3ff !important;
|
| 216 |
-
border: 1px solid #1a5a9e !important;
|
| 217 |
-
border-radius: 4px !important;
|
| 218 |
-
padding: 0.85rem 2rem !important;
|
| 219 |
-
cursor: pointer !important;
|
| 220 |
-
width: 100% !important;
|
| 221 |
-
transition: all 0.2s ease !important;
|
| 222 |
-
}
|
| 223 |
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
color: #a8d8ff !important;
|
| 228 |
-
transform: translateY(-1px) !important;
|
| 229 |
-
box-shadow: 0 4px 20px rgba(79,179,255,0.25) !important;
|
| 230 |
-
}
|
| 231 |
|
| 232 |
-
#
|
| 233 |
-
transform: translateY(0) !important;
|
| 234 |
-
}
|
| 235 |
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
#output-box textarea,
|
| 239 |
-
.gradio-container textarea {
|
| 240 |
-
font-family: 'Share Tech Mono', monospace !important;
|
| 241 |
-
font-size: 0.82rem !important;
|
| 242 |
-
line-height: 1.7 !important;
|
| 243 |
-
background: #020810 !important;
|
| 244 |
-
color: #7fcfff !important;
|
| 245 |
-
border: 1px solid #0d2540 !important;
|
| 246 |
-
border-radius: 4px !important;
|
| 247 |
-
padding: 1.2rem !important;
|
| 248 |
-
resize: none !important;
|
| 249 |
-
caret-color: #4fb3ff !important;
|
| 250 |
-
}
|
| 251 |
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
background: #0d2540 !important;
|
| 259 |
-
border-radius: 3px !important;
|
| 260 |
-
}
|
| 261 |
|
| 262 |
-
|
| 263 |
-
background: linear-gradient(90deg, #1150a0, #4fb3ff) !important;
|
| 264 |
-
border-radius: 3px !important;
|
| 265 |
-
}
|
| 266 |
|
| 267 |
-
|
| 268 |
-
.mission-footer {
|
| 269 |
-
text-align: center;
|
| 270 |
-
font-family: 'Share Tech Mono', monospace;
|
| 271 |
-
font-size: 0.65rem;
|
| 272 |
-
color: #1e3d5c;
|
| 273 |
-
letter-spacing: 0.2em;
|
| 274 |
-
text-transform: uppercase;
|
| 275 |
-
padding: 2rem 0 0;
|
| 276 |
-
}
|
| 277 |
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
}
|
| 285 |
|
| 286 |
-
|
| 287 |
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
background: transparent !important;
|
| 291 |
-
border: none !important;
|
| 292 |
-
}
|
| 293 |
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
}
|
| 297 |
-
"""
|
| 298 |
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
header_html = """
|
| 302 |
-
<div class="mission-header">
|
| 303 |
-
<div class="sub">Autonomous Flight Intelligence System Β· v2.0</div>
|
| 304 |
-
<h1>⬑ SpaceX Mission Control</h1>
|
| 305 |
-
<div class="sub">SAC Neural Lander Β· LunarLander-v3 Simulation</div>
|
| 306 |
-
</div>
|
| 307 |
-
<hr class="divider"/>
|
| 308 |
-
<div class="status-strip">
|
| 309 |
-
<span class="badge badge-green">β SAC MODEL LOADED</span>
|
| 310 |
-
<span class="badge badge-blue">β SIMULATION READY</span>
|
| 311 |
-
<span class="badge badge-amber">β AWAITING LAUNCH</span>
|
| 312 |
-
</div>
|
| 313 |
-
"""
|
| 314 |
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
</div>
|
| 319 |
-
"""
|
| 320 |
|
| 321 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
|
| 323 |
-
|
| 324 |
|
| 325 |
-
|
| 326 |
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
)
|
| 339 |
-
launch_btn
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
)
|
| 344 |
|
| 345 |
-
|
| 346 |
-
with gr.
|
| 347 |
-
gr.Markdown("**FLIGHT TELEMETRY Β· MISSION REPORT**", elem_classes=["panel-label"])
|
| 348 |
-
output = gr.Textbox(
|
| 349 |
-
label="",
|
| 350 |
-
lines=18,
|
| 351 |
-
max_lines=24,
|
| 352 |
-
placeholder="Awaiting telemetry data...\n\nPress INITIATE LAUNCH SEQUENCE to begin simulation.",
|
| 353 |
-
elem_id="output-box",
|
| 354 |
-
elem_classes=["telemetry"],
|
| 355 |
-
)
|
| 356 |
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
)
|
| 363 |
|
| 364 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
|
| 366 |
if __name__ == "__main__":
|
| 367 |
-
demo.launch()
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SpaceX Mission Control — SAC Rocket Lander
|
| 3 |
+
Production Gradio application: simulate, visualise, analyse, and train
|
| 4 |
+
a Soft Actor-Critic agent on the LunarLander-v3 continuous control task.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
import os
|
| 9 |
import numpy as np
|
| 10 |
+
import gradio as gr
|
| 11 |
from stable_baselines3 import SAC
|
| 12 |
+
|
| 13 |
+
from core.mission import run_mission, MissionResult
|
| 14 |
+
from core.trainer import TrainingState, start_training
|
| 15 |
+
from viz.charts import (
|
| 16 |
+
mission_overview, single_episode_detail,
|
| 17 |
+
training_dashboard, empty_figure,
|
| 18 |
+
)
|
| 19 |
+
from viz.replay import make_episode_gif, make_comparison_gif
|
| 20 |
+
|
| 21 |
+
# ββ Model loading βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 22 |
+
|
| 23 |
+
_MODEL_PATHS = ["sac_finetuned.zip", "sac_rocket_lander.zip"]
|
| 24 |
+
_model: SAC | None = None
|
| 25 |
+
|
| 26 |
+
def _load_model(path: str | None = None) -> tuple[SAC, str]:
|
| 27 |
+
candidates = ([path] if path else []) + _MODEL_PATHS
|
| 28 |
+
for p in candidates:
|
| 29 |
+
if p and os.path.exists(p):
|
| 30 |
+
try:
|
| 31 |
+
return SAC.load(p), p
|
| 32 |
+
except Exception:
|
| 33 |
+
continue
|
| 34 |
+
raise FileNotFoundError("No valid SAC checkpoint found.")
|
| 35 |
+
|
| 36 |
+
def _get_model() -> SAC:
|
| 37 |
+
global _model
|
| 38 |
+
if _model is None:
|
| 39 |
+
_model, _ = _load_model()
|
| 40 |
+
return _model
|
| 41 |
+
|
| 42 |
+
# ββ Global training state βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 43 |
+
_train_state = TrainingState()
|
| 44 |
+
|
| 45 |
+
# ββ Callbacks βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 46 |
+
|
| 47 |
+
def cb_run_mission(
|
| 48 |
+
n_episodes: int,
|
| 49 |
+
gravity: float,
|
| 50 |
+
enable_wind: bool,
|
| 51 |
+
wind_power: float,
|
| 52 |
+
turbulence: float,
|
| 53 |
+
render_gif: bool,
|
| 54 |
+
progress: gr.Progress = gr.Progress(),
|
| 55 |
+
) -> tuple:
|
| 56 |
+
try:
|
| 57 |
+
model = _get_model()
|
| 58 |
+
except FileNotFoundError as e:
|
| 59 |
+
empty = empty_figure(str(e))
|
| 60 |
+
return empty, None, empty, str(e), gr.update(choices=[])
|
| 61 |
+
|
| 62 |
+
mission, all_frames = run_mission(
|
| 63 |
+
model,
|
| 64 |
+
n_episodes=int(n_episodes),
|
| 65 |
+
gravity=float(gravity),
|
| 66 |
+
enable_wind=bool(enable_wind),
|
| 67 |
+
wind_power=float(wind_power),
|
| 68 |
+
turbulence_power=float(turbulence),
|
| 69 |
+
render=bool(render_gif),
|
| 70 |
+
progress_cb=progress,
|
| 71 |
)
|
|
|
|
| 72 |
|
| 73 |
+
overview_fig = mission_overview(mission)
|
| 74 |
+
|
| 75 |
+
gif_path = None
|
| 76 |
+
if render_gif and all_frames:
|
| 77 |
+
if n_episodes >= 2:
|
| 78 |
+
gif_path = make_comparison_gif(all_frames, mission.episodes, fps=15)
|
| 79 |
+
else:
|
| 80 |
+
gif_path = make_episode_gif(all_frames[0], mission.episodes[0], fps=15)
|
| 81 |
+
|
| 82 |
+
best = mission.best
|
| 83 |
+
detail_fig = single_episode_detail(best)
|
| 84 |
|
| 85 |
+
sr = mission.success_rate * 100
|
| 86 |
+
icon = "π" if mission.avg_reward >= 150 else "π₯"
|
| 87 |
+
stats_md = f"""
|
| 88 |
+
### {icon} Mission Complete
|
| 89 |
|
| 90 |
+
| Metric | Value |
|
| 91 |
+
|---|---|
|
| 92 |
+
| **Avg Reward** | `{mission.avg_reward:+.2f}` |
|
| 93 |
+
| **Best** | `{best.total_reward:+.2f}` ({best.status_emoji} Ep {best.episode_idx+1}) |
|
| 94 |
+
| **Worst** | `{mission.worst.total_reward:+.2f}` ({mission.worst.status_emoji} Ep {mission.worst.episode_idx+1}) |
|
| 95 |
+
| **Success Rate** | `{sr:.1f}%` |
|
| 96 |
+
| **Episodes** | `{len(mission.episodes)}` |
|
| 97 |
+
|
| 98 |
+
**Per-Episode Scores:**
|
| 99 |
+
"""
|
| 100 |
+
per_ep = "".join(
|
| 101 |
+
f"- `#{e.episode_idx+1}` {e.status_emoji} **{e.status}** β {e.total_reward:+.1f} ({len(e.steps)} steps)\n"
|
| 102 |
+
for e in mission.episodes
|
| 103 |
+
)
|
| 104 |
+
stats_md += per_ep
|
| 105 |
+
|
| 106 |
+
ep_choices = [
|
| 107 |
+
f"#{e.episode_idx+1} β {e.status_emoji} {e.status} ({e.total_reward:+.1f})"
|
| 108 |
+
for e in mission.episodes
|
| 109 |
+
]
|
| 110 |
+
|
| 111 |
+
_last_mission["data"] = mission
|
| 112 |
+
_last_mission["frames"] = all_frames
|
| 113 |
+
|
| 114 |
+
return overview_fig, gif_path, detail_fig, stats_md, gr.update(choices=ep_choices, value=ep_choices[0])
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
_last_mission: dict = {"data": None, "frames": []}
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def cb_select_episode(selection: str) -> tuple:
|
| 121 |
+
mission: MissionResult | None = _last_mission.get("data")
|
| 122 |
+
all_frames = _last_mission.get("frames", [])
|
| 123 |
+
if not mission or not selection:
|
| 124 |
+
return empty_figure("Run a mission first."), None
|
| 125 |
+
try:
|
| 126 |
+
idx = int(selection.split("#")[1].split(" ")[0]) - 1
|
| 127 |
+
except Exception:
|
| 128 |
+
idx = 0
|
| 129 |
+
ep = mission.episodes[idx]
|
| 130 |
+
fig = single_episode_detail(ep)
|
| 131 |
+
gif = None
|
| 132 |
+
if all_frames and idx < len(all_frames):
|
| 133 |
+
gif = make_episode_gif(all_frames[idx], ep, fps=15)
|
| 134 |
+
return fig, gif
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def cb_start_training(total_steps: int, lr: float, batch_size: int) -> str:
|
| 138 |
+
global _train_state
|
| 139 |
+
if _train_state.running:
|
| 140 |
+
return "Training already in progress."
|
| 141 |
+
_train_state = TrainingState()
|
| 142 |
+
start_training(
|
| 143 |
+
base_model_path="sac_rocket_lander.zip",
|
| 144 |
+
total_timesteps=int(total_steps),
|
| 145 |
+
learning_rate=float(lr),
|
| 146 |
+
batch_size=int(batch_size),
|
| 147 |
+
state=_train_state,
|
| 148 |
+
save_path="sac_finetuned.zip",
|
| 149 |
+
)
|
| 150 |
+
return "Training started. Click **Refresh** to update charts."
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def cb_stop_training() -> str:
|
| 154 |
+
_train_state.running = False
|
| 155 |
+
return "Stop signal sent."
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def cb_refresh_training() -> tuple:
|
| 159 |
+
fig = training_dashboard(_train_state)
|
| 160 |
+
n_ep = len(_train_state.episode_rewards)
|
| 161 |
+
rolling = float(np.mean(_train_state.episode_rewards[-20:])) if n_ep else 0.0
|
| 162 |
+
progress_pct = (_train_state.timestep / max(_train_state.total_timesteps, 1)) * 100
|
| 163 |
+
status_md = f"""
|
| 164 |
+
| | |
|
| 165 |
+
|---|---|
|
| 166 |
+
| **Status** | `{_train_state.status}` |
|
| 167 |
+
| **Progress** | `{progress_pct:.1f}%` |
|
| 168 |
+
| **Episodes** | `{n_ep}` |
|
| 169 |
+
| **Rolling Reward (20)** | `{rolling:+.1f}` |
|
| 170 |
+
| **Best Reward** | `{_train_state.best_reward:+.1f}` |
|
| 171 |
+
"""
|
| 172 |
+
return fig, status_md
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def cb_load_finetuned() -> str:
|
| 176 |
+
global _model
|
| 177 |
+
for path in _MODEL_PATHS:
|
| 178 |
+
if os.path.exists(path):
|
| 179 |
+
try:
|
| 180 |
+
_model = SAC.load(path)
|
| 181 |
+
return f"Model loaded from `{path}`."
|
| 182 |
+
except Exception as e:
|
| 183 |
+
return f"Failed to load `{path}`: {e}"
|
| 184 |
+
return "No checkpoint found."
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# ββ CSS βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 188 |
+
|
| 189 |
+
CSS = """
|
| 190 |
@import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;700;900&family=Share+Tech+Mono&family=Exo+2:wght@300;400;600&display=swap');
|
| 191 |
|
|
|
|
| 192 |
*, *::before, *::after { box-sizing: border-box; }
|
| 193 |
|
| 194 |
body, .gradio-container {
|
|
|
|
| 197 |
font-family: 'Exo 2', sans-serif !important;
|
| 198 |
}
|
| 199 |
|
| 200 |
+
.gradio-container { max-width: 1200px !important; margin: 0 auto !important; }
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
|
| 202 |
+
.tab-nav { background: #060f1e !important; border-bottom: 1px solid #0d2540 !important; }
|
| 203 |
+
.tab-nav button {
|
| 204 |
+
font-family: 'Share Tech Mono', monospace !important;
|
| 205 |
+
font-size: 0.72rem !important; letter-spacing: 0.18em !important;
|
| 206 |
+
color: #3a6080 !important; background: transparent !important;
|
| 207 |
+
border: none !important; text-transform: uppercase !important;
|
| 208 |
+
padding: 0.7rem 1.4rem !important;
|
| 209 |
+
}
|
| 210 |
+
.tab-nav button.selected {
|
| 211 |
+
color: #4fb3ff !important;
|
| 212 |
+
border-bottom: 2px solid #4fb3ff !important;
|
| 213 |
}
|
| 214 |
|
| 215 |
+
.mc-header {
|
| 216 |
+
text-align: center; padding: 2rem 1rem 1rem;
|
| 217 |
+
border-bottom: 1px solid #0d2540; margin-bottom: 1.5rem;
|
| 218 |
+
}
|
| 219 |
+
.mc-header h1 {
|
| 220 |
font-family: 'Orbitron', monospace !important;
|
| 221 |
+
font-size: clamp(1.4rem, 3.5vw, 2.2rem) !important;
|
| 222 |
+
font-weight: 900 !important; letter-spacing: 0.1em !important;
|
| 223 |
+
color: #e8f4ff !important; margin: 0 !important;
|
|
|
|
|
|
|
|
|
|
| 224 |
}
|
| 225 |
+
.mc-sub {
|
|
|
|
| 226 |
font-family: 'Share Tech Mono', monospace;
|
| 227 |
+
font-size: 0.72rem; color: #2d6a9f;
|
| 228 |
+
letter-spacing: 0.3em; text-transform: uppercase; margin-top: 0.3rem;
|
|
|
|
|
|
|
| 229 |
}
|
| 230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
.status-strip {
|
| 232 |
+
display: flex; gap: 0.6rem; justify-content: center;
|
| 233 |
+
flex-wrap: wrap; margin: 1rem 0;
|
|
|
|
|
|
|
|
|
|
| 234 |
}
|
|
|
|
| 235 |
.badge {
|
| 236 |
font-family: 'Share Tech Mono', monospace;
|
| 237 |
+
font-size: 0.68rem; letter-spacing: 0.15em;
|
| 238 |
+
padding: 4px 12px; border-radius: 3px; text-transform: uppercase;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
}
|
| 240 |
+
.badge-green { background:#041e12; color:#2ddb7c; border:1px solid #0a5530; }
|
| 241 |
+
.badge-blue { background:#020f20; color:#4fb3ff; border:1px solid #0b3362; }
|
| 242 |
+
.badge-amber { background:#1a1002; color:#f5a623; border:1px solid #5c3700; }
|
| 243 |
+
.badge-purple { background:#120920; color:#c77dff; border:1px solid #4a1a7a; }
|
| 244 |
|
| 245 |
+
button.primary {
|
| 246 |
+
font-family: 'Orbitron', monospace !important;
|
| 247 |
+
font-size: 0.82rem !important; font-weight: 700 !important;
|
| 248 |
+
letter-spacing: 0.15em !important; text-transform: uppercase !important;
|
| 249 |
+
background: linear-gradient(135deg,#0a2a52,#0d3a72) !important;
|
| 250 |
+
color: #4fb3ff !important; border: 1px solid #1a5a9e !important;
|
| 251 |
+
border-radius: 4px !important; transition: all 0.2s !important;
|
|
|
|
|
|
|
|
|
|
| 252 |
}
|
| 253 |
+
button.primary:hover {
|
| 254 |
+
background: linear-gradient(135deg,#0d3a72,#1150a0) !important;
|
| 255 |
+
border-color: #4fb3ff !important;
|
| 256 |
+
box-shadow: 0 4px 20px rgba(79,179,255,0.25) !important;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
}
|
| 258 |
+
button.stop {
|
| 259 |
+
background: linear-gradient(135deg,#2a0a0a,#4a1010) !important;
|
| 260 |
+
color: #ff4d6d !important; border: 1px solid #7a1a1a !important;
|
| 261 |
+
font-family: 'Share Tech Mono', monospace !important;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
}
|
| 263 |
|
| 264 |
+
label span, .gradio-container label {
|
|
|
|
|
|
|
| 265 |
font-family: 'Share Tech Mono', monospace !important;
|
| 266 |
+
font-size: 0.72rem !important; letter-spacing: 0.15em !important;
|
| 267 |
+
text-transform: uppercase !important; color: #4fb3ff !important;
|
|
|
|
|
|
|
| 268 |
}
|
|
|
|
| 269 |
input[type=range] {
|
| 270 |
+
-webkit-appearance: none; height: 3px;
|
| 271 |
+
background: #0d2540; border-radius: 2px; outline: none;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
}
|
|
|
|
| 273 |
input[type=range]::-webkit-slider-thumb {
|
| 274 |
+
-webkit-appearance: none; width: 16px; height: 16px;
|
| 275 |
+
border-radius: 50%; background: #4fb3ff; cursor: pointer;
|
| 276 |
+
border: 2px solid #030b1a; box-shadow: 0 0 8px rgba(79,179,255,0.5);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
}
|
| 278 |
|
| 279 |
+
textarea, .gradio-container textarea {
|
| 280 |
+
font-family: 'Share Tech Mono', monospace !important;
|
| 281 |
+
font-size: 0.82rem !important; line-height: 1.7 !important;
|
| 282 |
+
background: #020810 !important; color: #7fcfff !important;
|
| 283 |
+
border: 1px solid #0d2540 !important; border-radius: 4px !important;
|
|
|
|
|
|
|
| 284 |
}
|
| 285 |
|
| 286 |
+
table { width: 100%; border-collapse: collapse; }
|
| 287 |
+
th { background: #060f1e; color: #4fb3ff;
|
| 288 |
+
font-family: 'Share Tech Mono', monospace;
|
| 289 |
+
font-size: 0.7rem; letter-spacing: 0.1em; padding: 6px 10px; }
|
| 290 |
+
td { border-top: 1px solid #0d2540; padding: 6px 10px;
|
| 291 |
+
color: #c8ddf0; font-size: 0.85rem; }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
|
| 293 |
+
footer { display: none !important; }
|
| 294 |
+
.gradio-container .block { background: transparent !important; border: none !important; }
|
| 295 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
|
| 297 |
+
# ββ Theory ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
|
|
|
| 298 |
|
| 299 |
+
THEORY_MD = """
|
| 300 |
+
## Soft Actor-Critic (SAC)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
|
| 302 |
+
SAC is an **off-policy, maximum-entropy** deep RL algorithm for continuous
|
| 303 |
+
action spaces. It simultaneously maximises expected return *and* policy entropy,
|
| 304 |
+
encouraging exploration while converging to a stable policy.
|
| 305 |
|
| 306 |
+
### Objective
|
| 307 |
+
$$J(\\pi) = \\sum_t \\mathbb{E}_{(s_t,a_t)\\sim\\rho_\\pi}\\left[ r(s_t,a_t) + \\alpha\\,\\mathcal{H}(\\pi(\\cdot|s_t)) \\right]$$
|
|
|
|
|
|
|
|
|
|
| 308 |
|
| 309 |
+
The temperature $\\alpha$ is **auto-tuned** to a target entropy level.
|
|
|
|
|
|
|
|
|
|
| 310 |
|
| 311 |
+
### Architecture
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
|
| 313 |
+
| Component | Role |
|
| 314 |
+
|---|---|
|
| 315 |
+
| **Actor** $\\pi_\\phi(a\\|s)$ | Gaussian policy β outputs mean & log-std |
|
| 316 |
+
| **Critic 1** $Q_{\\theta_1}(s,a)$ | Q-value estimator |
|
| 317 |
+
| **Critic 2** $Q_{\\theta_2}(s,a)$ | Clipped double-Q: take min to reduce overestimation |
|
| 318 |
+
| **Target Critics** | Soft-updated copies ($\\tau=0.005$) for stable TD targets |
|
|
|
|
| 319 |
|
| 320 |
+
### Update Rules
|
| 321 |
|
| 322 |
+
**Critic** β minimise Bellman residual:
|
| 323 |
+
$$y = r + \\gamma\\left(\\min_i Q_{\\bar\\theta_i}(s',\\tilde a') - \\alpha\\log\\pi(\\tilde a'|s')\\right)$$
|
|
|
|
|
|
|
|
|
|
| 324 |
|
| 325 |
+
**Actor** β maximise Q + entropy:
|
| 326 |
+
$$\\mathcal{L}(\\phi) = \\mathbb{E}\\left[\\alpha\\log\\pi_\\phi(a|s) - \\min_i Q_{\\theta_i}(s,a)\\right]$$
|
|
|
|
|
|
|
| 327 |
|
| 328 |
+
**Temperature** β match target entropy $\\bar{\\mathcal{H}}$:
|
| 329 |
+
$$\\mathcal{L}(\\alpha) = \\mathbb{E}\\left[-\\alpha(\\log\\pi(a|s)+\\bar{\\mathcal{H}})\\right]$$
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
|
| 331 |
+
---
|
| 332 |
+
|
| 333 |
+
## LunarLander-v3 (Continuous)
|
|
|
|
|
|
|
| 334 |
|
| 335 |
+
| Property | Value |
|
| 336 |
+
|---|---|
|
| 337 |
+
| **State** | 8-dim: pos (x,y), vel (vx,vy), angle, angular vel, leg contacts |
|
| 338 |
+
| **Action** | 2-dim continuous: main throttle, lateral thrust ∈ [−1,1] |
|
| 339 |
+
| **Reward** | +10 each leg contact, +100 landing, −100 crash |
|
| 340 |
+
| **Solved** | Episode reward ≥ 200 |
|
| 341 |
|
| 342 |
+
---
|
| 343 |
|
| 344 |
+
## Model Hyperparameters
|
| 345 |
|
| 346 |
+
| Parameter | Value |
|
| 347 |
+
|---|---|
|
| 348 |
+
| `learning_rate` | 3×10⁻⁴ |
|
| 349 |
+
| `buffer_size` | 1,000,000 |
|
| 350 |
+
| `batch_size` | 256 |
|
| 351 |
+
| `tau` | 0.005 |
|
| 352 |
+
| `gamma` | 0.99 |
|
| 353 |
+
| `target_entropy` | −2.0 |
|
| 354 |
+
|
| 355 |
+
---
|
| 356 |
+
|
| 357 |
+
## Reading the Charts
|
| 358 |
+
|
| 359 |
+
- **Reward bars**: green ≥ 150, amber ≥ 0, red < 0
|
| 360 |
+
- **Trajectory plot**: `★` = successful landing, `×` = crash
|
| 361 |
+
- **Engine throttle**: main (blue) fires downward; lateral (amber) steers
|
| 362 |
+
- **Training reward**: smoothed line (solid) trends matter more than raw (faded)
|
| 363 |
+
- **Actor loss**: negative values normal — actor maximises Q, so loss = −Q
|
| 364 |
+
- **Entropy coef**: starts high, decreases as policy converges
|
| 365 |
+
"""
|
| 366 |
+
|
| 367 |
+
# ββ Build UI ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 368 |
+
|
| 369 |
+
with gr.Blocks(title="SpaceX Mission Control β SAC Rocket Lander") as demo:
|
| 370 |
+
|
| 371 |
+
gr.HTML("""
|
| 372 |
+
<div class="mc-header">
|
| 373 |
+
<div class="mc-sub">Autonomous Flight Intelligence System · SAC v2.0</div>
|
| 374 |
+
<h1>⬑ SpaceX Mission Control</h1>
|
| 375 |
+
<div class="mc-sub">Soft Actor-Critic · LunarLander-v3 · Continuous Control</div>
|
| 376 |
+
</div>
|
| 377 |
+
<div class="status-strip">
|
| 378 |
+
<span class="badge badge-green">β SAC MODEL LOADED</span>
|
| 379 |
+
<span class="badge badge-blue">β PHYSICS ENGINE READY</span>
|
| 380 |
+
<span class="badge badge-amber">β TELEMETRY ONLINE</span>
|
| 381 |
+
<span class="badge badge-purple">β TRAINING MODULE ARMED</span>
|
| 382 |
+
</div>
|
| 383 |
+
""")
|
| 384 |
+
|
| 385 |
+
with gr.Tabs():
|
| 386 |
+
|
| 387 |
+
# ββ Mission Control ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 388 |
+
with gr.Tab("π MISSION CONTROL"):
|
| 389 |
+
|
| 390 |
+
with gr.Row():
|
| 391 |
+
with gr.Column(scale=1, min_width=300):
|
| 392 |
+
gr.HTML('<div class="mc-sub" style="margin-bottom:0.8rem">MISSION PARAMETERS</div>')
|
| 393 |
+
|
| 394 |
+
n_episodes = gr.Slider(1, 10, value=3, step=1,
|
| 395 |
+
label="Landing Attempts")
|
| 396 |
+
gravity = gr.Slider(-20.0, -1.0, value=-10.0, step=0.5,
|
| 397 |
+
label="Gravity (m/s²)")
|
| 398 |
+
enable_wind = gr.Checkbox(label="Enable Wind Disturbance", value=False)
|
| 399 |
+
wind_power = gr.Slider(0.0, 20.0, value=5.0, step=0.5,
|
| 400 |
+
label="Wind Power", visible=False)
|
| 401 |
+
turbulence = gr.Slider(0.0, 2.0, value=0.5, step=0.1,
|
| 402 |
+
label="Turbulence Power", visible=False)
|
| 403 |
+
render_gif = gr.Checkbox(label="Render Animated Replay", value=True)
|
| 404 |
+
|
| 405 |
+
enable_wind.change(
|
| 406 |
+
lambda v: (gr.update(visible=v), gr.update(visible=v)),
|
| 407 |
+
inputs=enable_wind,
|
| 408 |
+
outputs=[wind_power, turbulence],
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
launch_btn = gr.Button("π INITIATE LAUNCH SEQUENCE", variant="primary")
|
| 412 |
+
|
| 413 |
+
gr.HTML('<div class="mc-sub" style="margin-top:1.2rem;margin-bottom:0.4rem">MODEL</div>')
|
| 414 |
+
load_btn = gr.Button("π Reload Checkpoint")
|
| 415 |
+
load_status = gr.Textbox(label="", lines=1, interactive=False,
|
| 416 |
+
placeholder="Model statusβ¦")
|
| 417 |
+
load_btn.click(cb_load_finetuned, outputs=load_status)
|
| 418 |
+
|
| 419 |
+
with gr.Column(scale=2):
|
| 420 |
+
stats_md = gr.Markdown("*Configure mission parameters and click Launch.*")
|
| 421 |
+
episode_selector = gr.Dropdown(
|
| 422 |
+
choices=[], label="Inspect Episode", interactive=True,
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
with gr.Row():
|
| 426 |
+
overview_plot = gr.Plot(label="Mission Overview Dashboard")
|
| 427 |
+
|
| 428 |
+
with gr.Row():
|
| 429 |
+
with gr.Column(scale=1):
|
| 430 |
+
detail_plot = gr.Plot(label="Episode Deep-Dive")
|
| 431 |
+
with gr.Column(scale=1):
|
| 432 |
+
replay_gif = gr.Image(
|
| 433 |
+
label="Episode Replay (GIF with HUD)",
|
| 434 |
+
type="filepath",
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
episode_selector.change(
|
| 438 |
+
cb_select_episode,
|
| 439 |
+
inputs=episode_selector,
|
| 440 |
+
outputs=[detail_plot, replay_gif],
|
| 441 |
)
|
| 442 |
+
launch_btn.click(
|
| 443 |
+
cb_run_mission,
|
| 444 |
+
inputs=[n_episodes, gravity, enable_wind, wind_power, turbulence, render_gif],
|
| 445 |
+
outputs=[overview_plot, replay_gif, detail_plot, stats_md, episode_selector],
|
| 446 |
)
|
| 447 |
|
| 448 |
+
# ββ Training Lab βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 449 |
+
with gr.Tab("🧪 TRAINING LAB"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 450 |
|
| 451 |
+
gr.Markdown("### Fine-tune the SAC agent in your browser")
|
| 452 |
+
gr.Markdown(
|
| 453 |
+
"Runs in a background thread β click **Refresh Metrics** to pull updates. "
|
| 454 |
+
"The fine-tuned model saves to `sac_finetuned.zip` and is used automatically."
|
| 455 |
+
)
|
|
|
|
| 456 |
|
| 457 |
+
with gr.Row():
|
| 458 |
+
with gr.Column(scale=1):
|
| 459 |
+
gr.HTML('<div class="mc-sub" style="margin-bottom:0.8rem">HYPERPARAMETERS</div>')
|
| 460 |
+
train_steps = gr.Slider(5_000, 200_000, value=20_000, step=5_000,
|
| 461 |
+
label="Total Timesteps")
|
| 462 |
+
train_lr = gr.Slider(1e-5, 1e-3, value=3e-4, step=1e-5,
|
| 463 |
+
label="Learning Rate")
|
| 464 |
+
train_batch = gr.Slider(64, 512, value=256, step=64,
|
| 465 |
+
label="Batch Size")
|
| 466 |
+
|
| 467 |
+
with gr.Row():
|
| 468 |
+
btn_train_start = gr.Button("▶ Start Training", variant="primary")
|
| 469 |
+
btn_train_stop = gr.Button("⏹ Stop", variant="stop")
|
| 470 |
+
btn_refresh = gr.Button("π Refresh Metrics")
|
| 471 |
+
train_msg = gr.Textbox(label="", lines=2, interactive=False)
|
| 472 |
+
|
| 473 |
+
with gr.Column(scale=2):
|
| 474 |
+
train_status_md = gr.Markdown("*Start training to see live metrics.*")
|
| 475 |
+
train_plot = gr.Plot(label="Live Training Dashboard")
|
| 476 |
+
|
| 477 |
+
btn_train_start.click(
|
| 478 |
+
cb_start_training,
|
| 479 |
+
inputs=[train_steps, train_lr, train_batch],
|
| 480 |
+
outputs=train_msg,
|
| 481 |
+
)
|
| 482 |
+
btn_train_stop.click(cb_stop_training, outputs=train_msg)
|
| 483 |
+
btn_refresh.click(cb_refresh_training, outputs=[train_plot, train_status_md])
|
| 484 |
+
|
| 485 |
+
# ββ Algorithm Guide ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 486 |
+
with gr.Tab("π ALGORITHM GUIDE"):
|
| 487 |
+
gr.Markdown(THEORY_MD)
|
| 488 |
+
|
| 489 |
+
gr.HTML("""
|
| 490 |
+
<div style="text-align:center;font-family:'Share Tech Mono',monospace;
|
| 491 |
+
font-size:0.65rem;color:#1e3d5c;letter-spacing:0.2em;
|
| 492 |
+
text-transform:uppercase;padding:2rem 0 1rem;">
|
| 493 |
+
Powered by Stable-Baselines3 · Soft Actor-Critic ·
|
| 494 |
+
Gymnasium LunarLander-v3 · Gradio
|
| 495 |
+
</div>
|
| 496 |
+
""")
|
| 497 |
|
| 498 |
if __name__ == "__main__":
|
| 499 |
+
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, css=CSS)
|
core/__init__.py
ADDED
|
File without changes
|
core/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (197 Bytes). View file
|
|
|
core/__pycache__/mission.cpython-311.pyc
ADDED
|
Binary file (10.4 kB). View file
|
|
|
core/__pycache__/trainer.cpython-311.pyc
ADDED
|
Binary file (7.17 kB). View file
|
|
|
core/mission.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Mission runner — executes SAC agent episodes, collects full telemetry.
|
| 3 |
+
Returns structured data for both the UI and the visualization layer.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
import numpy as np
|
| 8 |
+
import gymnasium as gym
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from stable_baselines3 import SAC
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# ── Telemetry data structures ────────────────────────────────────────────────
|
| 14 |
+
|
| 15 |
+
@dataclass
class StepData:
    """Telemetry recorded for one environment step.

    ``run_mission`` fills the first eight fields from entries 0-7 of the
    observation the agent acted on, and the remaining fields from the reward
    and the two-component action of that same step.
    """

    # Position (observation entries 0 and 1).
    x: float
    y: float
    # Linear velocity (observation entries 2 and 3).
    vx: float
    vy: float
    # Body angle (entry 4); EpisodeResult.angles converts it with np.degrees.
    angle: float
    # Angular velocity (entry 5).
    angular_vel: float
    # Leg-contact flags (entries 6 and 7), coerced to bool.
    left_leg: bool
    right_leg: bool
    # Reward returned by env.step for this step.
    reward: float
    action_main: float     # main engine throttle [-1, 1]
    action_lateral: float  # lateral thruster     [-1, 1]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
class EpisodeResult:
    """Telemetry and outcome of a single episode.

    ``steps`` holds one ``StepData`` per environment step and ``total_reward``
    is the running sum of their rewards. The properties below are read-only
    views over ``steps`` that the chart and replay code consume directly.
    """

    episode_idx: int
    steps: list[StepData] = field(default_factory=list)
    total_reward: float = 0.0
    landed: bool = False
    crashed: bool = False

    @property
    def status(self) -> str:
        """Return the outcome label derived from ``total_reward`` alone."""
        if self.total_reward >= 200:
            return "PERFECT"
        if self.total_reward >= 150:
            return "LANDED"
        if self.total_reward >= 0:
            return "PARTIAL"
        return "CRASHED"

    @property
    def status_emoji(self) -> str:
        """Return the emoji that accompanies ``status``."""
        # Written as escapes so the glyphs survive tooling that mis-decodes
        # UTF-8 source (the previous literal arrived garbled and split by a
        # stray line break).
        # NOTE(review): the PERFECT glyph could not be recovered from the
        # garbled original; the trophy is a stand-in - confirm against the
        # repository.
        return {
            "PERFECT": "\U0001F3C6",  # trophy
            "LANDED": "\u2705",        # white heavy check mark
            "PARTIAL": "\u26A0\uFE0F",  # warning sign (emoji presentation)
            "CRASHED": "\U0001F4A5",  # collision
        }[self.status]

    @property
    def xs(self) -> list[float]:
        """Return the x position at every step."""
        return [s.x for s in self.steps]

    @property
    def ys(self) -> list[float]:
        """Return the y position at every step."""
        return [s.y for s in self.steps]

    @property
    def cumulative_rewards(self) -> list[float]:
        """Return the running reward total after each step."""
        running = 0.0
        totals = []
        for s in self.steps:
            running += s.reward
            totals.append(running)
        return totals

    @property
    def main_throttle(self) -> list[float]:
        """Return the main-engine action component at every step."""
        return [s.action_main for s in self.steps]

    @property
    def lateral_throttle(self) -> list[float]:
        """Return the lateral-thruster action component at every step."""
        return [s.action_lateral for s in self.steps]

    @property
    def angles(self) -> list[float]:
        """Return the body angle at every step, converted with np.degrees."""
        return [np.degrees(s.angle) for s in self.steps]
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@dataclass
class MissionResult:
    """Aggregate view over all episodes of one mission run."""

    episodes: list[EpisodeResult] = field(default_factory=list)

    @property
    def rewards(self) -> list[float]:
        """Return each episode's total reward, in episode order."""
        return [episode.total_reward for episode in self.episodes]

    @property
    def success_rate(self) -> float:
        """Return the fraction of episodes scoring at least 150 (0.0 if none ran)."""
        count = len(self.episodes)
        if count == 0:
            return 0.0
        successes = [ep for ep in self.episodes if ep.total_reward >= 150]
        return len(successes) / count

    @property
    def avg_reward(self) -> float:
        """Return the mean total reward, or 0.0 when there are no episodes."""
        totals = self.rewards
        if not totals:
            return 0.0
        return float(np.mean(totals))

    @property
    def best(self) -> EpisodeResult:
        """Return the highest-scoring episode; raises ValueError if there are none."""
        return max(self.episodes, key=lambda episode: episode.total_reward)

    @property
    def worst(self) -> EpisodeResult:
        """Return the lowest-scoring episode; raises ValueError if there are none."""
        return min(self.episodes, key=lambda episode: episode.total_reward)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# ── Runner ───────────────────────────────────────────────────────────────────
|
| 110 |
+
|
| 111 |
+
def run_mission(
    model: SAC,
    n_episodes: int = 5,
    gravity: float = -10.0,
    enable_wind: bool = False,
    wind_power: float = 5.0,
    turbulence_power: float = 0.5,
    render: bool = True,
    progress_cb=None,
) -> tuple[MissionResult, list[list[np.ndarray]]]:
    """
    Run `n_episodes` of the lander with a deterministic policy.

    Args:
        model: Trained agent; only ``model.predict`` is used.
        n_episodes: Number of episodes to fly.
        gravity: Forwarded to the environment.
        enable_wind: Forwarded to the environment. When False, wind and
            turbulence power are forced to 0.0.
        wind_power: Wind strength, used only when ``enable_wind`` is True.
        turbulence_power: Turbulence strength, used only with wind enabled.
        render: When True the environment renders ``rgb_array`` frames and one
            frame is captured after every step.
        progress_cb: Optional ``callable(fraction, message)`` invoked before
            each episode and once more on completion.

    Returns:
        (MissionResult, list_of_frame_lists) - one frame list per episode
        (empty lists when ``render`` is False).
    """
    mission = MissionResult()
    all_frames: list[list[np.ndarray]] = []

    # One environment serves every episode: reset() starts a fresh episode,
    # so there is no need to rebuild it each time.
    env = gym.make(
        "LunarLander-v3",
        continuous=True,
        gravity=gravity,
        enable_wind=enable_wind,
        wind_power=wind_power if enable_wind else 0.0,
        turbulence_power=turbulence_power if enable_wind else 0.0,
        render_mode="rgb_array" if render else None,
    )

    try:
        for ep_idx in range(n_episodes):
            if progress_cb:
                progress_cb(
                    ep_idx / n_episodes,
                    f"Running mission {ep_idx + 1}/{n_episodes}…",
                )

            obs, _ = env.reset()
            result = EpisodeResult(episode_idx=ep_idx)
            frames: list[np.ndarray] = []

            done = False
            while not done:
                action, _ = model.predict(obs, deterministic=True)
                next_obs, reward, terminated, truncated, _ = env.step(action)

                # Telemetry pairs the observation the agent acted on with the
                # action it chose and the reward that action earned.
                result.steps.append(StepData(
                    x=float(obs[0]), y=float(obs[1]),
                    vx=float(obs[2]), vy=float(obs[3]),
                    angle=float(obs[4]), angular_vel=float(obs[5]),
                    left_leg=bool(obs[6]), right_leg=bool(obs[7]),
                    reward=float(reward),
                    action_main=float(action[0]),
                    action_lateral=float(action[1]),
                ))
                result.total_reward += float(reward)

                if render:
                    frame = env.render()
                    if frame is not None:
                        frames.append(frame)

                obs = next_obs
                done = terminated or truncated

                if terminated:
                    # LunarLander sets the terminal step's reward to -100 for
                    # a crash / leaving the viewport and to +100 when the
                    # lander comes to rest; a truncated episode is neither.
                    result.crashed = float(reward) <= -100.0
                    result.landed = float(reward) >= 100.0

            mission.episodes.append(result)
            all_frames.append(frames)
    finally:
        # Release the environment even if predict/step/render raised.
        env.close()

    if progress_cb:
        progress_cb(1.0, "Mission complete.")

    return mission, all_frames
|
core/trainer.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SAC training pipeline — fine-tune or train from scratch with live callbacks.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
import os
|
| 7 |
+
import threading
|
| 8 |
+
from dataclasses import dataclass, field
|
| 9 |
+
from stable_baselines3 import SAC
|
| 10 |
+
from stable_baselines3.common.callbacks import BaseCallback
|
| 11 |
+
import gymnasium as gym
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class TrainingState:
    """Mutable progress record shared between the training thread and the UI.

    The training callback writes to it while training runs; the dashboard
    reads it to draw charts. Setting ``running`` to False asks the callback
    to abort training at its next step.
    """

    # True while a training run is active; doubles as the stop flag.
    running: bool = False
    # Latest step count reported by the callback, and the requested budget.
    timestep: int = 0
    total_timesteps: int = 0
    # One entry per finished training episode.
    episode_rewards: list[float] = field(default_factory=list)
    # Parallel series, appended together at every logging interval.
    actor_losses: list[float] = field(default_factory=list)
    critic_losses: list[float] = field(default_factory=list)
    ent_coefs: list[float] = field(default_factory=list)
    log_steps: list[int] = field(default_factory=list)
    # Human-readable status line shown in the UI.
    status: str = "idle"
    # Highest episode reward seen so far (-inf until an episode finishes).
    best_reward: float = float("-inf")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class _LiveCallback(BaseCallback):
    """Mirror training progress into a shared ``TrainingState``.

    Every step it publishes the step count and any finished-episode rewards;
    every ``log_interval`` steps it samples the logger's latest actor loss,
    critic loss and entropy coefficient and refreshes the status line.
    Returning False from ``_on_step`` (when ``state.running`` has been
    cleared) aborts training.

    Step counts are measured relative to the start of the current ``learn()``
    call. The trainer calls ``learn(reset_num_timesteps=False)``, so for a
    loaded model ``num_timesteps`` resumes from the checkpoint's counter;
    comparing that absolute value with the new budget would report progress
    far beyond 100%.
    """

    def __init__(self, state: TrainingState, log_interval: int = 500):
        super().__init__()
        self._state = state
        # Guard against a zero/negative interval (modulo by zero below).
        self._log_interval = max(1, log_interval)
        self._ep_rewards: list[float] = []
        # num_timesteps value at the start of this learn() call.
        self._start_timestep = 0

    def _on_training_start(self) -> None:
        # BaseCallback has already synced num_timesteps with the model here.
        self._start_timestep = self.num_timesteps

    def _on_step(self) -> bool:
        state = self._state
        if not state.running:
            return False  # abort training

        steps_done = self.num_timesteps - self._start_timestep
        state.timestep = steps_done

        # Collect episode rewards from monitor wrapper: it adds an "episode"
        # entry to the info dict of the step that ends an episode.
        for info in self.locals.get("infos", []):
            if "episode" in info:
                r = float(info["episode"]["r"])
                self._ep_rewards.append(r)
                state.episode_rewards.append(r)
                if r > state.best_reward:
                    state.best_reward = r

        if steps_done % self._log_interval == 0:
            metrics = self.model.logger.name_to_value
            state.actor_losses.append(float(metrics.get("train/actor_loss", 0)))
            state.critic_losses.append(float(metrics.get("train/critic_loss", 0)))
            state.ent_coefs.append(float(metrics.get("train/ent_coef", 0)))
            state.log_steps.append(steps_done)

            pct = steps_done / max(state.total_timesteps, 1)
            rolling = float(np.mean(self._ep_rewards[-20:])) if self._ep_rewards else 0.0
            state.status = (
                f"Step {steps_done:,}/{state.total_timesteps:,} "
                f"({pct*100:.1f}%) | Rolling reward: {rolling:+.1f} | "
                f"Best: {state.best_reward:+.1f}"
            )
        return True

    def _on_training_end(self) -> None:
        steps_done = self.num_timesteps - self._start_timestep
        self._state.status = (
            f"Training complete — {steps_done:,} steps. "
            f"Best reward: {self._state.best_reward:+.1f}"
        )
        self._state.running = False
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def start_training(
    base_model_path: str,
    total_timesteps: int,
    learning_rate: float,
    batch_size: int,
    state: TrainingState,
    save_path: str = "sac_finetuned.zip",
) -> threading.Thread:
    """Launches training in a daemon thread. Progress written to `state`.

    If ``base_model_path`` exists the model is loaded and fine-tuned;
    otherwise a fresh ``MlpPolicy`` SAC agent is trained. The result is saved
    to ``save_path`` when ``learn()`` returns (including after a user stop).

    Args:
        base_model_path: Checkpoint to fine-tune, if present on disk.
        total_timesteps: Step budget for this run.
        learning_rate: Optimiser learning rate for this run.
        batch_size: Minibatch size for gradient updates.
        state: Shared progress record; ``state.running`` is True from the
            moment this function returns until the thread finishes.
        save_path: Where the trained model is written.

    Returns:
        The started daemon thread.
    """
    # UI sliders may hand over floats; the trainer needs whole numbers here.
    total_timesteps = int(total_timesteps)
    batch_size = int(batch_size)
    learning_rate = float(learning_rate)

    # Publish the run state before the thread starts so an immediate
    # refresh/stop from the caller never observes a not-yet-running state.
    state.running = True
    state.total_timesteps = total_timesteps
    state.status = "Initialising environment…"

    def _train():
        from stable_baselines3.common.monitor import Monitor

        env = None
        try:
            env = Monitor(gym.make("LunarLander-v3", continuous=True))

            if os.path.exists(base_model_path):
                # Overrides are passed to load() so they are applied before
                # the model is set up. Assigning model.learning_rate after
                # loading does not rebuild the learning-rate schedule, so
                # the requested rate would be ignored.
                model = SAC.load(
                    base_model_path,
                    env=env,
                    learning_rate=learning_rate,
                    batch_size=batch_size,
                )
            else:
                model = SAC(
                    "MlpPolicy", env,
                    learning_rate=learning_rate,
                    batch_size=batch_size,
                    verbose=0,
                )

            cb = _LiveCallback(state, log_interval=max(total_timesteps // 200, 200))
            model.learn(
                total_timesteps=total_timesteps,
                callback=cb,
                reset_num_timesteps=False,
                progress_bar=False,
                log_interval=1,
            )
            model.save(save_path)
        except Exception as exc:
            # Surface the failure in the UI, then re-raise so the traceback
            # still reaches the thread's exception hook / logs.
            state.status = f"Training failed: {exc}"
            raise
        finally:
            # Never leave the UI believing a dead run is still active, and
            # always release the environment.
            state.running = False
            if env is not None:
                env.close()

    thread = threading.Thread(target=_train, daemon=True)
    thread.start()
    return thread
|
requirements.txt
CHANGED
|
@@ -1,3 +1,5 @@
|
|
| 1 |
stable-baselines3[extra]
|
| 2 |
gymnasium[box2d]
|
| 3 |
shimmy
|
|
|
|
|
|
|
|
|
| 1 |
stable-baselines3[extra]
|
| 2 |
gymnasium[box2d]
|
| 3 |
shimmy
|
| 4 |
+
matplotlib
|
| 5 |
+
gradio>=6.0.0
|
viz/__init__.py
ADDED
|
File without changes
|
viz/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (196 Bytes). View file
|
|
|
viz/__pycache__/charts.cpython-311.pyc
ADDED
|
Binary file (18.5 kB). View file
|
|
|
viz/__pycache__/replay.cpython-311.pyc
ADDED
|
Binary file (8.43 kB). View file
|
|
|
viz/charts.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
All matplotlib figure generation for the dashboard.
|
| 3 |
+
Every function returns a plt.Figure — caller closes or passes to Gradio.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
import numpy as np
|
| 8 |
+
import matplotlib
|
| 9 |
+
matplotlib.use("Agg")
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import matplotlib.gridspec as gridspec
|
| 12 |
+
import matplotlib.patches as mpatches
|
| 13 |
+
from matplotlib.collections import LineCollection
|
| 14 |
+
|
| 15 |
+
from core.mission import MissionResult, EpisodeResult
|
| 16 |
+
|
| 17 |
+
# ── Palette ──────────────────────────────────────────────────────────────────
|
| 18 |
+
BG = "#030b1a"
|
| 19 |
+
BG2 = "#060f1e"
|
| 20 |
+
GRID = "#0d2540"
|
| 21 |
+
ACCENT = "#4fb3ff"
|
| 22 |
+
GREEN = "#2ddb7c"
|
| 23 |
+
AMBER = "#f5a623"
|
| 24 |
+
RED = "#ff4d6d"
|
| 25 |
+
PURPLE = "#c77dff"
|
| 26 |
+
TEXT = "#c8ddf0"
|
| 27 |
+
DIM = "#3a6080"
|
| 28 |
+
|
| 29 |
+
EP_COLORS = [ACCENT, GREEN, AMBER, PURPLE, "#ff9f1c", "#e9c46a", "#f4a261"]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _style_ax(ax, title: str = "", xlabel: str = "", ylabel: str = ""):
    """Apply the dashboard's dark theme to ``ax`` and set any given labels.

    Colours the panel background, ticks, spines and grid from the module
    palette. The title and axis labels are only set when non-empty.
    """
    ax.set_facecolor(BG2)
    ax.tick_params(colors=DIM, labelsize=8)
    for side in ax.spines.values():
        side.set_color(GRID)
    ax.grid(color=GRID, linewidth=0.5, linestyle="--", alpha=0.6)

    if title:
        ax.set_title(title, color=TEXT, fontsize=10, pad=8, fontfamily="monospace")
    # Both axis labels share one style; skip whichever was not supplied.
    for label_text, set_label in ((xlabel, ax.set_xlabel), (ylabel, ax.set_ylabel)):
        if label_text:
            set_label(label_text, color=DIM, fontsize=8)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def mission_overview(mission: MissionResult) -> plt.Figure:
    """4-panel summary: bar chart, trajectory, reward curves, throttle.

    Panels: per-episode total reward, x/y flight paths, cumulative reward per
    step, and the action components of the best-scoring episode. A mission
    with no episodes yields the same layout with empty panels.

    Returns:
        The figure; the caller owns it (close it or hand it to Gradio).
    """
    episodes = mission.episodes
    rewards = mission.rewards
    n = len(episodes)
    fig = plt.figure(figsize=(14, 9), facecolor=BG)
    gs = gridspec.GridSpec(2, 2, figure=fig, hspace=0.45, wspace=0.32,
                           left=0.07, right=0.97, top=0.90, bottom=0.08)

    # -- Panel 1: episode rewards bar ----------------------------------------
    ax1 = fig.add_subplot(gs[0, 0])
    _style_ax(ax1, "EPISODE REWARDS", "Episode", "Score")
    labels = [f"#{e.episode_idx+1}" for e in episodes]
    colors = [GREEN if r >= 150 else (AMBER if r >= 0 else RED) for r in rewards]
    bars = ax1.bar(labels, rewards, color=colors, edgecolor=BG, linewidth=0.8)
    ax1.axhline(200, color=GREEN, linestyle="--", linewidth=1, alpha=0.5, label="Perfect (200)")
    ax1.axhline(150, color=ACCENT, linestyle="--", linewidth=1, alpha=0.5, label="Success (150)")
    ax1.axhline(0, color=RED, linestyle="--", linewidth=1, alpha=0.3)
    ax1.legend(fontsize=7, facecolor=BG2, edgecolor=GRID, labelcolor=DIM)
    for bar, val in zip(bars, rewards):
        # Put the value just beyond the bar's tip: above a positive bar,
        # below a negative one (otherwise it is drawn inside the bar).
        offset, valign = (3, "bottom") if val >= 0 else (-3, "top")
        ax1.text(bar.get_x() + bar.get_width()/2, val + offset,
                 f"{val:.0f}", ha="center", va=valign, color=TEXT, fontsize=8)

    # -- Panel 2: 2-D flight trajectory --------------------------------------
    ax2 = fig.add_subplot(gs[0, 1])
    _style_ax(ax2, "FLIGHT TRAJECTORIES", "X Position", "Altitude")
    for i, ep in enumerate(episodes):
        if not ep.steps:
            continue  # nothing to draw, and no final point to mark
        col = EP_COLORS[i % len(EP_COLORS)]
        # One solid colour per episode, drawn as consecutive line segments.
        points = np.array([ep.xs, ep.ys]).T.reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)
        lc = LineCollection(segments, colors=col, linewidth=1.2, alpha=0.7)
        ax2.add_collection(lc)
        # Final-position marker: star for a success (>= 150), cross otherwise.
        ax2.scatter(ep.xs[-1], ep.ys[-1],
                    marker=("*" if ep.total_reward >= 150 else "x"),
                    s=80, color=col, zorder=5)
    # add_collection does not rescale the axes on its own.
    ax2.autoscale()
    ax2.axhline(0, color=GRID, linewidth=1)
    patches = [mpatches.Patch(color=EP_COLORS[i % len(EP_COLORS)],
                              label=f"#{e.episode_idx+1} {e.status_emoji}")
               for i, e in enumerate(episodes)]
    ax2.legend(handles=patches, fontsize=7, facecolor=BG2,
               edgecolor=GRID, labelcolor=DIM, loc="upper right")

    # -- Panel 3: cumulative reward over steps -------------------------------
    ax3 = fig.add_subplot(gs[1, 0])
    _style_ax(ax3, "CUMULATIVE REWARD", "Step", "Reward")
    for i, ep in enumerate(episodes):
        col = EP_COLORS[i % len(EP_COLORS)]
        ax3.plot(ep.cumulative_rewards, color=col, linewidth=1.5,
                 label=f"#{ep.episode_idx+1}", alpha=0.85)
    ax3.axhline(0, color=RED, linestyle="--", linewidth=0.8, alpha=0.4)
    if episodes:
        ax3.legend(fontsize=7, facecolor=BG2, edgecolor=GRID, labelcolor=DIM)

    # -- Panel 4: engine throttle timeline -----------------------------------
    ax4 = fig.add_subplot(gs[1, 1])
    _style_ax(ax4, "ENGINE THROTTLE — BEST EPISODE", "Step", "Throttle")
    if episodes:  # mission.best raises on an empty mission
        best = mission.best
        steps = range(len(best.steps))
        ax4.fill_between(steps, 0, best.main_throttle,
                         color=ACCENT, alpha=0.35, label="Main Engine")
        ax4.plot(steps, best.main_throttle, color=ACCENT, linewidth=1.2)
        ax4.fill_between(steps, 0, best.lateral_throttle,
                         color=AMBER, alpha=0.25, label="Lateral Thrusters")
        ax4.plot(steps, best.lateral_throttle, color=AMBER, linewidth=1.0)
        ax4.legend(fontsize=7, facecolor=BG2, edgecolor=GRID, labelcolor=DIM)
    ax4.axhline(0, color=GRID, linewidth=0.8)
    ax4.set_ylim(-1.1, 1.1)

    # -- Figure title ---------------------------------------------------------
    sr = mission.success_rate * 100
    fig.suptitle(
        f"MISSION REPORT  ·  {n} episodes  ·  "
        f"Avg {mission.avg_reward:+.1f}  ·  Success {sr:.0f}%",
        color=TEXT, fontsize=12, fontfamily="monospace", y=0.96,
    )
    return fig
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def single_episode_detail(ep: EpisodeResult) -> plt.Figure:
    """6-panel deep-dive for one episode.

    Panels: trajectory, cumulative reward, altitude, body angle, main-engine
    action and lateral-thruster action, each plotted against the step index
    (except the trajectory, which is x against y). An episode counts as a
    success - green accents, star end marker - when its total reward is at
    least 150.

    Returns:
        The figure; the caller owns it (close it or hand it to Gradio).
    """
    fig = plt.figure(figsize=(14, 8), facecolor=BG)
    gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.5, wspace=0.38,
                           left=0.07, right=0.97, top=0.88, bottom=0.08)

    steps = list(range(len(ep.steps)))
    xs, ys = ep.xs, ep.ys
    succeeded = ep.total_reward >= 150
    outcome_color = GREEN if succeeded else RED

    # Trajectory
    ax = fig.add_subplot(gs[0, 0])
    _style_ax(ax, "TRAJECTORY", "X", "Y")
    ax.plot(xs, ys, color=ACCENT, linewidth=1.5)
    ax.axhline(0, color=GRID, linewidth=1)
    if xs:  # start/end markers need at least one recorded step
        ax.scatter(xs[0], ys[0], s=60, color=GREEN, zorder=5, label="Start")
        ax.scatter(xs[-1], ys[-1], s=80,
                   marker="*" if succeeded else "x",
                   color=outcome_color, zorder=5, label="End")
        ax.legend(fontsize=7, facecolor=BG2, edgecolor=GRID, labelcolor=DIM)

    # Cumulative reward
    ax = fig.add_subplot(gs[0, 1])
    _style_ax(ax, "CUMULATIVE REWARD", "Step", "Reward")
    cum = ep.cumulative_rewards
    ax.fill_between(steps, 0, cum, color=outcome_color, alpha=0.2)
    ax.plot(steps, cum, color=outcome_color, linewidth=1.5)
    ax.axhline(0, color=GRID, linewidth=0.8, linestyle="--")

    # Altitude over time
    ax = fig.add_subplot(gs[0, 2])
    _style_ax(ax, "ALTITUDE", "Step", "Y")
    ax.fill_between(steps, 0, ys, color=ACCENT, alpha=0.15)
    ax.plot(steps, ys, color=ACCENT, linewidth=1.5)
    ax.axhline(0, color=RED, linewidth=1, linestyle="--", alpha=0.5)

    # Angle
    ax = fig.add_subplot(gs[1, 0])
    _style_ax(ax, "BODY ANGLE", "Step", "Degrees")
    angles = ep.angles
    ax.fill_between(steps, 0, angles, color=AMBER, alpha=0.2)
    ax.plot(steps, angles, color=AMBER, linewidth=1.3)
    ax.axhline(0, color=GRID, linewidth=0.8, linestyle="--")

    # Main throttle
    ax = fig.add_subplot(gs[1, 1])
    _style_ax(ax, "MAIN ENGINE", "Step", "Throttle")
    main = ep.main_throttle
    ax.fill_between(steps, 0, main, color=ACCENT, alpha=0.3)
    ax.plot(steps, main, color=ACCENT, linewidth=1.2)
    ax.set_ylim(-1.1, 1.1)
    ax.axhline(0, color=GRID, linewidth=0.8)

    # Lateral throttle
    ax = fig.add_subplot(gs[1, 2])
    _style_ax(ax, "LATERAL THRUSTERS", "Step", "Throttle")
    lateral = ep.lateral_throttle
    ax.fill_between(steps, 0, lateral, color=PURPLE, alpha=0.3)
    ax.plot(steps, lateral, color=PURPLE, linewidth=1.2)
    ax.set_ylim(-1.1, 1.1)
    ax.axhline(0, color=GRID, linewidth=0.8)

    fig.suptitle(
        f"EPISODE {ep.episode_idx+1} DEEP-DIVE  ·  "
        f"{ep.status_emoji} {ep.status}  ·  Score: {ep.total_reward:+.1f}  ·  "
        f"{len(ep.steps)} steps",
        color=TEXT, fontsize=11, fontfamily="monospace", y=0.95,
    )
    return fig
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def training_dashboard(state) -> plt.Figure:
    """Live training metrics: reward history + losses + entropy.

    ``state`` is the shared training record. It may be appended to by the
    training thread while this function runs, so every series is copied
    first and the loss/entropy series are trimmed to a common length;
    plotting x and y of different lengths would raise inside matplotlib.

    Returns:
        The figure; the caller owns it (close it or hand it to Gradio).
    """
    # Snapshot everything up front so the panels agree with each other.
    rewards = list(state.episode_rewards)
    log_steps = list(state.log_steps)
    actor = list(state.actor_losses)
    critic = list(state.critic_losses)
    ent = list(state.ent_coefs)
    n_logged = min(len(log_steps), len(actor), len(critic), len(ent))
    log_steps, actor = log_steps[:n_logged], actor[:n_logged]
    critic, ent = critic[:n_logged], ent[:n_logged]

    fig = plt.figure(figsize=(14, 5), facecolor=BG)
    gs = gridspec.GridSpec(1, 3, figure=fig, wspace=0.38,
                           left=0.06, right=0.97, top=0.85, bottom=0.12)

    # Reward curve
    ax = fig.add_subplot(gs[0])
    _style_ax(ax, "EPISODE REWARD", "Episode", "Reward")
    if rewards:
        eps = list(range(len(rewards)))
        ax.plot(eps, rewards, color=ACCENT, linewidth=0.8, alpha=0.4)
        if len(eps) > 20:
            # Moving average; "valid" mode drops the first k-1 points, so the
            # smoothed curve starts at episode k-1.
            k = max(5, len(eps) // 30)
            smooth = np.convolve(rewards, np.ones(k)/k, "valid")
            ax.plot(range(k-1, len(eps)), smooth, color=ACCENT, linewidth=2)
        ax.axhline(200, color=GREEN, linestyle="--", linewidth=1, alpha=0.5)
        ax.axhline(150, color=AMBER, linestyle="--", linewidth=1, alpha=0.5)

    # Losses
    ax2 = fig.add_subplot(gs[1])
    _style_ax(ax2, "ACTOR / CRITIC LOSS", "Log Step", "Loss")
    if log_steps:
        ax2.plot(log_steps, actor, color=ACCENT,
                 linewidth=1.5, label="Actor")
        ax2.plot(log_steps, critic, color=AMBER,
                 linewidth=1.5, label="Critic")
        ax2.legend(fontsize=7, facecolor=BG2, edgecolor=GRID, labelcolor=DIM)

    # Entropy coef
    ax3 = fig.add_subplot(gs[2])
    _style_ax(ax3, "ENTROPY COEFFICIENT", "Log Step", "α")
    if log_steps:
        ax3.plot(log_steps, ent, color=PURPLE, linewidth=1.5)
        ax3.axhline(0, color=GRID, linewidth=0.8, linestyle="--")

    n_ep = len(rewards)
    best = state.best_reward
    # best_reward stays at -inf until the first episode finishes.
    best_text = f"{best:+.1f}" if best > float("-inf") else "n/a"
    fig.suptitle(
        f"SAC TRAINING  ·  {state.timestep:,}/{state.total_timesteps:,} steps  ·  "
        f"{n_ep} episodes  ·  Best: {best_text}",
        color=TEXT, fontsize=10, fontfamily="monospace",
    )
    return fig
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def empty_figure(message: str = "Run a mission to see charts.") -> plt.Figure:
    """Return a themed placeholder figure showing ``message`` in its centre."""
    fig, ax = plt.subplots(figsize=(12, 5), facecolor=BG)
    fig.patch.set_facecolor(BG)
    ax.set_facecolor(BG2)

    # Axes-fraction coordinates: (0.5, 0.5) is the middle of the panel.
    text_style = dict(
        ha="center", va="center", color=DIM,
        fontsize=13, fontfamily="monospace",
    )
    ax.text(0.5, 0.5, message, transform=ax.transAxes, **text_style)

    ax.axis("off")
    return fig
|
viz/replay.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Animated GIF generation from raw RGB frames.
|
| 3 |
+
Adds HUD overlay (step, reward, throttle bars) using PIL drawing — no matplotlib overhead.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
import tempfile
|
| 8 |
+
import numpy as np
|
| 9 |
+
import PIL.Image
|
| 10 |
+
import PIL.ImageDraw
|
| 11 |
+
import PIL.ImageFont
|
| 12 |
+
|
| 13 |
+
from core.mission import EpisodeResult
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# ── HUD rendering ────────────────────────────────────────────────────────────
|
| 17 |
+
|
| 18 |
+
def _draw_hud(
    img: PIL.Image.Image,
    step: int,
    cumulative_reward: float,
    main_throttle: float,
    lateral_throttle: float,
    status: str,
) -> PIL.Image.Image:
    """Draw the telemetry overlay onto ``img`` in place and return it.

    The overlay is a translucent top bar with the step counter, cumulative
    reward (green when >= 0, red otherwise) and status label, plus two
    throttle bars along the bottom edge.

    Args:
        img: Frame to draw on; it is modified in place.
        step: Step number shown in the top bar.
        cumulative_reward: Running reward shown in the top bar.
        main_throttle: Main-engine action component. Negative values mean the
            engine is off, so they render as an empty bar.
        lateral_throttle: Lateral action component. The bar length is its
            magnitude; the colour switches to red for negative values.
        status: Short label drawn at the top right.
    """
    # Drawing in "RGBA" mode on an RGB image makes Pillow blend fills that
    # carry an alpha value; in the default mode the alpha is ignored and the
    # top bar would come out fully opaque.
    draw_mode = "RGBA" if img.mode == "RGB" else None
    draw = PIL.ImageDraw.Draw(img, draw_mode)
    W, H = img.size

    # Semi-transparent top bar
    draw.rectangle([(0, 0), (W, 22)], fill=(3, 11, 26, 200))

    # Step & reward text
    draw.text((6, 4), f"STEP {step:03d}", fill=(79, 179, 255), font=None)
    rcolor = (45, 219, 124) if cumulative_reward >= 0 else (255, 77, 109)
    draw.text((W//2 - 40, 4), f"REWARD {cumulative_reward:+.1f}", fill=rcolor, font=None)
    # Same amber as the lateral bar label below.
    draw.text((W - 80, 4), status, fill=(245, 166, 35), font=None)

    # Throttle bars at bottom
    BAR_H = 6
    BAR_Y = H - BAR_H - 4
    bar_max = W // 2 - 20

    # Main engine bar (blue). In the continuous LunarLander action space a
    # negative main command switches the engine off, so it is clamped to
    # [0, 1] rather than shown by magnitude.
    main_level = min(max(main_throttle, 0.0), 1.0)
    bar_w = int(main_level * bar_max)
    draw.rectangle([(10, BAR_Y), (10 + bar_max, BAR_Y + BAR_H)],
                   fill=(13, 37, 64))
    draw.rectangle([(10, BAR_Y), (10 + bar_w, BAR_Y + BAR_H)],
                   fill=(79, 179, 255))
    draw.text((10, BAR_Y - 11), "MAIN", fill=(79, 179, 255), font=None)

    # Lateral bar (amber): length is the magnitude (capped at the track
    # width), colour encodes the sign.
    lx = W // 2 + 10
    lat_w = int(min(abs(lateral_throttle), 1.0) * bar_max)
    draw.rectangle([(lx, BAR_Y), (lx + bar_max, BAR_Y + BAR_H)],
                   fill=(13, 37, 64))
    col = (245, 166, 35) if lateral_throttle >= 0 else (255, 77, 109)
    draw.rectangle([(lx, BAR_Y), (lx + lat_w, BAR_Y + BAR_H)], fill=col)
    draw.text((lx, BAR_Y - 11), "LATERAL", fill=(245, 166, 35), font=None)

    return img
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def make_episode_gif(
    frames: list[np.ndarray],
    episode: EpisodeResult,
    fps: int = 15,
) -> str:
    """Overlay HUD on every frame, save as animated GIF. Returns temp file path.

    Args:
        frames: Rendered frames as arrays, in playback order.
        episode: Telemetry for the same episode. Frame ``i`` is annotated
            with step ``i``'s data; frames beyond the recorded steps reuse
            the last step.
        fps: Playback rate; values below 1 are treated as 1.

    Returns:
        Path of the GIF written to a temporary file, or ``""`` when
        ``frames`` is empty. The file is not deleted automatically.
    """
    if not frames:
        return ""

    cum_rewards = episode.cumulative_rewards
    steps = episode.steps
    last_step = len(steps) - 1
    pil_frames: list[PIL.Image.Image] = []

    for i, frame in enumerate(frames):
        j = min(i, last_step)  # clamp: extra frames repeat the final step
        img = PIL.Image.fromarray(frame).convert("RGB")
        pil_frames.append(_draw_hud(
            img,
            step=i + 1,
            cumulative_reward=cum_rewards[j],
            main_throttle=steps[j].action_main,
            lateral_throttle=steps[j].action_lateral,
            status=episode.status,
        ))

    # Reserve a unique path, then close the handle before Pillow writes to
    # it: this avoids leaking the descriptor, and an open temp file cannot
    # be reopened by name on every platform.
    with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp:
        gif_path = tmp.name

    pil_frames[0].save(
        gif_path,
        save_all=True,
        append_images=pil_frames[1:],
        duration=int(1000 / max(fps, 1)),  # milliseconds per frame
        loop=0,
        optimize=False,
    )
    return gif_path
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def make_comparison_gif(
    all_frames: list[list[np.ndarray]],
    episodes: list[EpisodeResult],
    fps: int = 12,
    max_episodes: int = 4,
) -> str:
    """Build a side-by-side grid GIF comparing several episodes.

    At most ``max_episodes`` episodes are shown. Episodes shorter than the
    longest one hold their last frame until the animation ends. Each cell
    gets a label strip with the episode number, status and total reward.

    Args:
        all_frames: One list of rendered frames per episode.
        episodes: Episode records aligned with ``all_frames``; each must
            expose ``episode_idx``, ``status`` and ``total_reward``. Only
            episodes present in both sequences are compared.
        fps: Playback speed. Each GIF frame lasts ``int(1000 / fps)`` ms;
            non-positive values are treated as 1.
        max_episodes: Upper bound on the number of grid cells.

    Returns:
        Path of the GIF written to a temporary file (created with
        ``delete=False``, so the caller owns cleanup), or ``""`` when there
        is nothing to render.
    """
    # Bound by both inputs so a short `episodes` list cannot cause an
    # IndexError, and treat a non-positive limit as "nothing to show".
    n = min(len(all_frames), len(episodes), max_episodes)
    if n <= 0:
        return ""

    frame_lists = [list(fl) for fl in all_frames[:n]]
    max_len = max(len(fl) for fl in frame_lists)
    if max_len == 0:
        return ""

    # Cell size comes from the first episode that actually has frames, so an
    # empty leading episode no longer suppresses the whole comparison.
    first_frame = next(fl for fl in frame_lists if fl)[0]
    h, w = first_frame.shape[:2]

    # One row for up to two episodes, otherwise a two-column grid.
    cols = 2 if n > 2 else n
    rows = (n + cols - 1) // cols
    grid_w, grid_h = cols * w, rows * h

    # Label text and colour are fixed per episode; compute them once.
    labels = [
        (
            f"#{ep.episode_idx+1} {ep.status} {ep.total_reward:+.0f}",
            (45, 219, 124) if ep.total_reward >= 150 else (255, 77, 109),
        )
        for ep in episodes[:n]
    ]

    pil_frames: list[PIL.Image.Image] = []
    for step_i in range(max_len):
        canvas = PIL.Image.new("RGB", (grid_w, grid_h), (3, 11, 26))
        for ep_i, (fl, (text, colour)) in enumerate(zip(frame_lists, labels)):
            if not fl:
                # Episode without frames: leave its cell as background.
                continue
            # Past the end of a shorter episode, keep showing its last frame.
            cell = PIL.Image.fromarray(fl[min(step_i, len(fl) - 1)])
            draw = PIL.ImageDraw.Draw(cell)
            draw.rectangle([(0, 0), (cell.width, 16)], fill=(3, 11, 26))
            draw.text((4, 2), text, fill=colour, font=None)
            canvas.paste(cell, ((ep_i % cols) * w, (ep_i // cols) * h))
        pil_frames.append(canvas)

    # Reserve a unique path and close the handle right away; Pillow reopens
    # the file by name, so keeping our own descriptor open would only leak it.
    with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp:
        gif_path = tmp.name

    pil_frames[0].save(
        gif_path,
        save_all=True,
        append_images=pil_frames[1:],
        duration=int(1000 / (fps if fps > 0 else 1)),
        loop=0,
        optimize=False,
    )
    return gif_path
|