Daksh C Jain Claude Sonnet 4.6 committed on
Commit
d7456d6
·
1 Parent(s): 3fca800

Upgrade to full production MARL masterclass app

Browse files

- 3-tab Gradio UI: Mission Control, Training Lab, Algorithm Guide
- Animated GIF replay with HUD overlay (step, reward, throttle bars)
- Side-by-side comparison GIF for multi-episode runs
- 4-panel mission overview: reward bars, 2D trajectory, cumulative reward, engine throttle
- 6-panel episode deep-dive: trajectory, altitude, angle, throttle timelines
- SAC fine-tuning in background thread with live metrics refresh
- Training dashboard: reward history, actor/critic loss, entropy coefficient
- Environment controls: gravity, wind, turbulence sliders
- Modular structure: core/mission.py, core/trainer.py, viz/charts.py, viz/replay.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

app.py CHANGED
@@ -1,69 +1,194 @@
1
- import gradio as gr
2
- import gymnasium as gym
 
 
 
 
 
 
3
  import numpy as np
 
4
  from stable_baselines3 import SAC
5
- import time
6
-
7
- # Load the model
8
- model = SAC.load("sac_rocket_lander.zip")
9
-
10
- # ── Mission logic ──────────────────────────────────────────────────────────────
11
-
12
- def run_mission(episodes, progress=gr.Progress()):
13
- env = gym.make("LunarLander-v3", continuous=True)
14
- total_rewards = []
15
- episode_logs = []
16
-
17
- for i in range(int(episodes)):
18
- progress((i) / int(episodes), desc=f"Running landing attempt {i+1} of {int(episodes)}...")
19
- obs, _ = env.reset()
20
- done = False
21
- ep_reward = 0
22
- steps = 0
23
-
24
- while not done:
25
- action, _ = model.predict(obs, deterministic=True)
26
- obs, reward, terminated, truncated, _ = env.step(action)
27
- ep_reward += reward
28
- steps += 1
29
- done = terminated or truncated
30
-
31
- total_rewards.append(ep_reward)
32
- status = "βœ… LANDED" if ep_reward > 150 else ("⚠️ PARTIAL" if ep_reward > 0 else "πŸ’₯ CRASHED")
33
- episode_logs.append(f"Attempt {i+1:02d} | Score: {ep_reward:+.1f} | Steps: {steps:4d} | {status}")
34
-
35
- env.close()
36
- progress(1.0, desc="Mission complete.")
37
-
38
- avg = np.mean(total_rewards)
39
- best = np.max(total_rewards)
40
- worst = np.min(total_rewards)
41
- success_rate = sum(1 for r in total_rewards if r > 150) / len(total_rewards) * 100
42
-
43
- mission_status = "MISSION SUCCESS" if avg > 150 else "MISSION FAILURE"
44
- status_icon = "πŸš€" if avg > 150 else "πŸ’₯"
45
-
46
- log_output = "\n".join(episode_logs)
47
- summary = (
48
- f"{status_icon} {mission_status}\n\n"
49
- f"{'─'*38}\n"
50
- f" Average Score {avg:>+10.2f}\n"
51
- f" Best Landing {best:>+10.2f}\n"
52
- f" Worst Landing {worst:>+10.2f}\n"
53
- f" Success Rate {success_rate:>9.1f}%\n"
54
- f"{'─'*38}\n\n"
55
- f"FLIGHT LOG\n{log_output}"
 
 
 
 
 
 
 
 
56
  )
57
- return summary
58
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- # ── Custom CSS ──────────────────────���──────────────────────────────────────────
 
 
 
61
 
62
- custom_css = """
63
- /* ── Google Font: Space Grotesk not used β€” using Orbitron + Share Tech Mono ── */
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;700;900&family=Share+Tech+Mono&family=Exo+2:wght@300;400;600&display=swap');
65
 
66
- /* Reset & base */
67
  *, *::before, *::after { box-sizing: border-box; }
68
 
69
  body, .gradio-container {
@@ -72,296 +197,303 @@ body, .gradio-container {
72
  font-family: 'Exo 2', sans-serif !important;
73
  }
74
 
75
- .gradio-container {
76
- max-width: 860px !important;
77
- margin: 0 auto !important;
78
- padding: 0 1rem 3rem !important;
79
- }
80
 
81
- /* ── Header ── */
82
- .mission-header {
83
- text-align: center;
84
- padding: 2.5rem 1rem 1.5rem;
85
- position: relative;
 
 
 
 
 
 
86
  }
87
 
88
- .mission-header h1 {
 
 
 
 
89
  font-family: 'Orbitron', monospace !important;
90
- font-size: clamp(1.4rem, 4vw, 2.4rem) !important;
91
- font-weight: 900 !important;
92
- letter-spacing: 0.08em !important;
93
- color: #e8f4ff !important;
94
- margin: 0 0 0.35rem !important;
95
- text-transform: uppercase !important;
96
  }
97
-
98
- .mission-header .sub {
99
  font-family: 'Share Tech Mono', monospace;
100
- font-size: 0.78rem;
101
- color: #5b8fb5;
102
- letter-spacing: 0.25em;
103
- text-transform: uppercase;
104
  }
105
 
106
- .divider {
107
- border: none;
108
- border-top: 1px solid #0d2540;
109
- margin: 1.5rem 0;
110
- }
111
-
112
- /* ── Status badge strip ── */
113
  .status-strip {
114
- display: flex;
115
- gap: 0.75rem;
116
- justify-content: center;
117
- flex-wrap: wrap;
118
- margin: 1.2rem 0 2rem;
119
  }
120
-
121
  .badge {
122
  font-family: 'Share Tech Mono', monospace;
123
- font-size: 0.72rem;
124
- letter-spacing: 0.15em;
125
- padding: 5px 14px;
126
- border-radius: 3px;
127
- text-transform: uppercase;
128
- }
129
-
130
- .badge-green {
131
- background: #041e12;
132
- color: #2ddb7c;
133
- border: 1px solid #0a5530;
134
  }
 
 
 
 
135
 
136
- .badge-blue {
137
- background: #020f20;
138
- color: #4fb3ff;
139
- border: 1px solid #0b3362;
140
- }
141
-
142
- .badge-amber {
143
- background: #1a1002;
144
- color: #f5a623;
145
- border: 1px solid #5c3700;
146
  }
147
-
148
- /* ── Panel cards ── */
149
- .panel {
150
- background: #060f1e;
151
- border: 1px solid #0d2540;
152
- border-radius: 6px;
153
- padding: 1.5rem;
154
- margin-bottom: 1rem;
155
  }
156
-
157
- .panel-label {
158
- font-family: 'Share Tech Mono', monospace;
159
- font-size: 0.68rem;
160
- letter-spacing: 0.22em;
161
- text-transform: uppercase;
162
- color: #2d6a9f;
163
- margin-bottom: 1rem;
164
  }
165
 
166
- /* ── Slider ── */
167
- .slider-wrap label,
168
- .gradio-container label span {
169
  font-family: 'Share Tech Mono', monospace !important;
170
- font-size: 0.75rem !important;
171
- letter-spacing: 0.15em !important;
172
- text-transform: uppercase !important;
173
- color: #4fb3ff !important;
174
  }
175
-
176
  input[type=range] {
177
- -webkit-appearance: none;
178
- appearance: none;
179
- width: 100%;
180
- height: 3px;
181
- background: #0d2540;
182
- border-radius: 2px;
183
- outline: none;
184
- margin: 0.5rem 0;
185
  }
186
-
187
  input[type=range]::-webkit-slider-thumb {
188
- -webkit-appearance: none;
189
- width: 18px;
190
- height: 18px;
191
- border-radius: 50%;
192
- background: #4fb3ff;
193
- cursor: pointer;
194
- border: 2px solid #030b1a;
195
- box-shadow: 0 0 8px rgba(79,179,255,0.5);
196
  }
197
 
198
- input[type=range]::-moz-range-thumb {
199
- width: 18px;
200
- height: 18px;
201
- border-radius: 50%;
202
- background: #4fb3ff;
203
- cursor: pointer;
204
- border: 2px solid #030b1a;
205
  }
206
 
207
- /* ── Launch button ── */
208
- #launch-btn {
209
- font-family: 'Orbitron', monospace !important;
210
- font-size: 0.9rem !important;
211
- font-weight: 700 !important;
212
- letter-spacing: 0.18em !important;
213
- text-transform: uppercase !important;
214
- background: linear-gradient(135deg, #0a2a52 0%, #0d3a72 100%) !important;
215
- color: #4fb3ff !important;
216
- border: 1px solid #1a5a9e !important;
217
- border-radius: 4px !important;
218
- padding: 0.85rem 2rem !important;
219
- cursor: pointer !important;
220
- width: 100% !important;
221
- transition: all 0.2s ease !important;
222
- }
223
 
224
- #launch-btn:hover {
225
- background: linear-gradient(135deg, #0d3a72 0%, #1150a0 100%) !important;
226
- border-color: #4fb3ff !important;
227
- color: #a8d8ff !important;
228
- transform: translateY(-1px) !important;
229
- box-shadow: 0 4px 20px rgba(79,179,255,0.25) !important;
230
- }
231
 
232
- #launch-btn:active {
233
- transform: translateY(0) !important;
234
- }
235
 
236
- /* ── Output telemetry box ── */
237
- .telemetry textarea,
238
- #output-box textarea,
239
- .gradio-container textarea {
240
- font-family: 'Share Tech Mono', monospace !important;
241
- font-size: 0.82rem !important;
242
- line-height: 1.7 !important;
243
- background: #020810 !important;
244
- color: #7fcfff !important;
245
- border: 1px solid #0d2540 !important;
246
- border-radius: 4px !important;
247
- padding: 1.2rem !important;
248
- resize: none !important;
249
- caret-color: #4fb3ff !important;
250
- }
251
 
252
- .gradio-container textarea::selection {
253
- background: #0d3a72;
254
- }
255
 
256
- /* Progress bar */
257
- .progress-bar {
258
- background: #0d2540 !important;
259
- border-radius: 3px !important;
260
- }
261
 
262
- .progress-bar > div {
263
- background: linear-gradient(90deg, #1150a0, #4fb3ff) !important;
264
- border-radius: 3px !important;
265
- }
266
 
267
- /* ── Footer ── */
268
- .mission-footer {
269
- text-align: center;
270
- font-family: 'Share Tech Mono', monospace;
271
- font-size: 0.65rem;
272
- color: #1e3d5c;
273
- letter-spacing: 0.2em;
274
- text-transform: uppercase;
275
- padding: 2rem 0 0;
276
- }
277
 
278
- /* ── Misc Gradio overrides ── */
279
- .gradio-container .prose,
280
- .gradio-container p {
281
- color: #5b8fb5 !important;
282
- font-family: 'Exo 2', sans-serif !important;
283
- font-size: 0.85rem !important;
284
- }
285
 
286
- footer { display: none !important; }
287
 
288
- .gradio-container .output-class,
289
- .gradio-container .block {
290
- background: transparent !important;
291
- border: none !important;
292
- }
293
 
294
- .gradio-container .form {
295
- background: transparent !important;
296
- }
297
- """
298
 
299
- # ── Header HTML ───────────────────────────────────────────────────────────────
300
-
301
- header_html = """
302
- <div class="mission-header">
303
- <div class="sub">Autonomous Flight Intelligence System Β· v2.0</div>
304
- <h1>⬑ SpaceX Mission Control</h1>
305
- <div class="sub">SAC Neural Lander Β· LunarLander-v3 Simulation</div>
306
- </div>
307
- <hr class="divider"/>
308
- <div class="status-strip">
309
- <span class="badge badge-green">● SAC MODEL LOADED</span>
310
- <span class="badge badge-blue">● SIMULATION READY</span>
311
- <span class="badge badge-amber">β—ˆ AWAITING LAUNCH</span>
312
- </div>
313
- """
314
 
315
- footer_html = """
316
- <div class="mission-footer">
317
- Powered by Stable-Baselines3 Β· Soft Actor-Critic Β· Gymnasium LunarLander-v3
318
- </div>
319
- """
320
 
321
- # ── Build Gradio UI ───────────────────────────────────────────────────────────
 
 
 
 
 
322
 
323
- with gr.Blocks(title="SpaceX Mission Control") as demo:
324
 
325
- gr.HTML(header_html)
326
 
327
- with gr.Row():
328
- with gr.Column():
329
- gr.Markdown("**MISSION PARAMETERS**", elem_classes=["panel-label"])
330
- attempts = gr.Slider(
331
- minimum=1,
332
- maximum=5,
333
- value=1,
334
- step=1,
335
- label="Landing Attempts",
336
- info="Select the number of autonomous landing simulations to execute",
337
- elem_id="attempts-slider",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  )
339
- launch_btn = gr.Button(
340
- "πŸš€ INITIATE LAUNCH SEQUENCE",
341
- elem_id="launch-btn",
342
- variant="primary",
343
  )
344
 
345
- with gr.Row():
346
- with gr.Column():
347
- gr.Markdown("**FLIGHT TELEMETRY Β· MISSION REPORT**", elem_classes=["panel-label"])
348
- output = gr.Textbox(
349
- label="",
350
- lines=18,
351
- max_lines=24,
352
- placeholder="Awaiting telemetry data...\n\nPress INITIATE LAUNCH SEQUENCE to begin simulation.",
353
- elem_id="output-box",
354
- elem_classes=["telemetry"],
355
- )
356
 
357
- launch_btn.click(
358
- fn=run_mission,
359
- inputs=[attempts],
360
- outputs=[output],
361
- show_progress="full",
362
- )
363
 
364
- gr.HTML(footer_html)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
 
366
  if __name__ == "__main__":
367
- demo.launch()
 
1
+ """
2
+ SpaceX Mission Control — SAC Rocket Lander
3
+ Production Gradio application: simulate, visualise, analyse, and train
4
+ a Soft Actor-Critic agent on the LunarLander-v3 continuous control task.
5
+ """
6
+
7
+ from __future__ import annotations
8
+ import os
9
  import numpy as np
10
+ import gradio as gr
11
  from stable_baselines3 import SAC
12
+
13
+ from core.mission import run_mission, MissionResult
14
+ from core.trainer import TrainingState, start_training
15
+ from viz.charts import (
16
+ mission_overview, single_episode_detail,
17
+ training_dashboard, empty_figure,
18
+ )
19
+ from viz.replay import make_episode_gif, make_comparison_gif
20
+
21
+ # ── Model loading ─────────────────────────────────────────────────────────────
22
+
23
+ _MODEL_PATHS = ["sac_finetuned.zip", "sac_rocket_lander.zip"]
24
+ _model: SAC | None = None
25
+
26
+ def _load_model(path: str | None = None) -> tuple[SAC, str]:
27
+ candidates = ([path] if path else []) + _MODEL_PATHS
28
+ for p in candidates:
29
+ if p and os.path.exists(p):
30
+ try:
31
+ return SAC.load(p), p
32
+ except Exception:
33
+ continue
34
+ raise FileNotFoundError("No valid SAC checkpoint found.")
35
+
36
+ def _get_model() -> SAC:
37
+ global _model
38
+ if _model is None:
39
+ _model, _ = _load_model()
40
+ return _model
41
+
42
+ # ── Global training state ─────────────────────────────────────────────────────
43
+ _train_state = TrainingState()
44
+
45
+ # ── Callbacks ─────────────────────────────────────────────────────────────────
46
+
47
+ def cb_run_mission(
48
+ n_episodes: int,
49
+ gravity: float,
50
+ enable_wind: bool,
51
+ wind_power: float,
52
+ turbulence: float,
53
+ render_gif: bool,
54
+ progress: gr.Progress = gr.Progress(),
55
+ ) -> tuple:
56
+ try:
57
+ model = _get_model()
58
+ except FileNotFoundError as e:
59
+ empty = empty_figure(str(e))
60
+ return empty, None, empty, str(e), gr.update(choices=[])
61
+
62
+ mission, all_frames = run_mission(
63
+ model,
64
+ n_episodes=int(n_episodes),
65
+ gravity=float(gravity),
66
+ enable_wind=bool(enable_wind),
67
+ wind_power=float(wind_power),
68
+ turbulence_power=float(turbulence),
69
+ render=bool(render_gif),
70
+ progress_cb=progress,
71
  )
 
72
 
73
+ overview_fig = mission_overview(mission)
74
+
75
+ gif_path = None
76
+ if render_gif and all_frames:
77
+ if n_episodes >= 2:
78
+ gif_path = make_comparison_gif(all_frames, mission.episodes, fps=15)
79
+ else:
80
+ gif_path = make_episode_gif(all_frames[0], mission.episodes[0], fps=15)
81
+
82
+ best = mission.best
83
+ detail_fig = single_episode_detail(best)
84
 
85
+ sr = mission.success_rate * 100
86
+ icon = "πŸš€" if mission.avg_reward >= 150 else "πŸ’₯"
87
+ stats_md = f"""
88
+ ### {icon} Mission Complete
89
 
90
+ | Metric | Value |
91
+ |---|---|
92
+ | **Avg Reward** | `{mission.avg_reward:+.2f}` |
93
+ | **Best** | `{best.total_reward:+.2f}` ({best.status_emoji} Ep {best.episode_idx+1}) |
94
+ | **Worst** | `{mission.worst.total_reward:+.2f}` ({mission.worst.status_emoji} Ep {mission.worst.episode_idx+1}) |
95
+ | **Success Rate** | `{sr:.1f}%` |
96
+ | **Episodes** | `{len(mission.episodes)}` |
97
+
98
+ **Per-Episode Scores:**
99
+ """
100
+ per_ep = "".join(
101
+ f"- `#{e.episode_idx+1}` {e.status_emoji} **{e.status}** β€” {e.total_reward:+.1f} ({len(e.steps)} steps)\n"
102
+ for e in mission.episodes
103
+ )
104
+ stats_md += per_ep
105
+
106
+ ep_choices = [
107
+ f"#{e.episode_idx+1} β€” {e.status_emoji} {e.status} ({e.total_reward:+.1f})"
108
+ for e in mission.episodes
109
+ ]
110
+
111
+ _last_mission["data"] = mission
112
+ _last_mission["frames"] = all_frames
113
+
114
+ return overview_fig, gif_path, detail_fig, stats_md, gr.update(choices=ep_choices, value=ep_choices[0])
115
+
116
+
117
+ _last_mission: dict = {"data": None, "frames": []}
118
+
119
+
120
+ def cb_select_episode(selection: str) -> tuple:
121
+ mission: MissionResult | None = _last_mission.get("data")
122
+ all_frames = _last_mission.get("frames", [])
123
+ if not mission or not selection:
124
+ return empty_figure("Run a mission first."), None
125
+ try:
126
+ idx = int(selection.split("#")[1].split(" ")[0]) - 1
127
+ except Exception:
128
+ idx = 0
129
+ ep = mission.episodes[idx]
130
+ fig = single_episode_detail(ep)
131
+ gif = None
132
+ if all_frames and idx < len(all_frames):
133
+ gif = make_episode_gif(all_frames[idx], ep, fps=15)
134
+ return fig, gif
135
+
136
+
137
+ def cb_start_training(total_steps: int, lr: float, batch_size: int) -> str:
138
+ global _train_state
139
+ if _train_state.running:
140
+ return "Training already in progress."
141
+ _train_state = TrainingState()
142
+ start_training(
143
+ base_model_path="sac_rocket_lander.zip",
144
+ total_timesteps=int(total_steps),
145
+ learning_rate=float(lr),
146
+ batch_size=int(batch_size),
147
+ state=_train_state,
148
+ save_path="sac_finetuned.zip",
149
+ )
150
+ return "Training started. Click **Refresh** to update charts."
151
+
152
+
153
+ def cb_stop_training() -> str:
154
+ _train_state.running = False
155
+ return "Stop signal sent."
156
+
157
+
158
+ def cb_refresh_training() -> tuple:
159
+ fig = training_dashboard(_train_state)
160
+ n_ep = len(_train_state.episode_rewards)
161
+ rolling = float(np.mean(_train_state.episode_rewards[-20:])) if n_ep else 0.0
162
+ progress_pct = (_train_state.timestep / max(_train_state.total_timesteps, 1)) * 100
163
+ status_md = f"""
164
+ | | |
165
+ |---|---|
166
+ | **Status** | `{_train_state.status}` |
167
+ | **Progress** | `{progress_pct:.1f}%` |
168
+ | **Episodes** | `{n_ep}` |
169
+ | **Rolling Reward (20)** | `{rolling:+.1f}` |
170
+ | **Best Reward** | `{_train_state.best_reward:+.1f}` |
171
+ """
172
+ return fig, status_md
173
+
174
+
175
+ def cb_load_finetuned() -> str:
176
+ global _model
177
+ for path in _MODEL_PATHS:
178
+ if os.path.exists(path):
179
+ try:
180
+ _model = SAC.load(path)
181
+ return f"Model loaded from `{path}`."
182
+ except Exception as e:
183
+ return f"Failed to load `{path}`: {e}"
184
+ return "No checkpoint found."
185
+
186
+
187
+ # ── CSS ───────────────────────────────────────────────────────────────────────
188
+
189
+ CSS = """
190
  @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;700;900&family=Share+Tech+Mono&family=Exo+2:wght@300;400;600&display=swap');
191
 
 
192
  *, *::before, *::after { box-sizing: border-box; }
193
 
194
  body, .gradio-container {
 
197
  font-family: 'Exo 2', sans-serif !important;
198
  }
199
 
200
+ .gradio-container { max-width: 1200px !important; margin: 0 auto !important; }
 
 
 
 
201
 
202
+ .tab-nav { background: #060f1e !important; border-bottom: 1px solid #0d2540 !important; }
203
+ .tab-nav button {
204
+ font-family: 'Share Tech Mono', monospace !important;
205
+ font-size: 0.72rem !important; letter-spacing: 0.18em !important;
206
+ color: #3a6080 !important; background: transparent !important;
207
+ border: none !important; text-transform: uppercase !important;
208
+ padding: 0.7rem 1.4rem !important;
209
+ }
210
+ .tab-nav button.selected {
211
+ color: #4fb3ff !important;
212
+ border-bottom: 2px solid #4fb3ff !important;
213
  }
214
 
215
+ .mc-header {
216
+ text-align: center; padding: 2rem 1rem 1rem;
217
+ border-bottom: 1px solid #0d2540; margin-bottom: 1.5rem;
218
+ }
219
+ .mc-header h1 {
220
  font-family: 'Orbitron', monospace !important;
221
+ font-size: clamp(1.4rem, 3.5vw, 2.2rem) !important;
222
+ font-weight: 900 !important; letter-spacing: 0.1em !important;
223
+ color: #e8f4ff !important; margin: 0 !important;
 
 
 
224
  }
225
+ .mc-sub {
 
226
  font-family: 'Share Tech Mono', monospace;
227
+ font-size: 0.72rem; color: #2d6a9f;
228
+ letter-spacing: 0.3em; text-transform: uppercase; margin-top: 0.3rem;
 
 
229
  }
230
 
 
 
 
 
 
 
 
231
  .status-strip {
232
+ display: flex; gap: 0.6rem; justify-content: center;
233
+ flex-wrap: wrap; margin: 1rem 0;
 
 
 
234
  }
 
235
  .badge {
236
  font-family: 'Share Tech Mono', monospace;
237
+ font-size: 0.68rem; letter-spacing: 0.15em;
238
+ padding: 4px 12px; border-radius: 3px; text-transform: uppercase;
 
 
 
 
 
 
 
 
 
239
  }
240
+ .badge-green { background:#041e12; color:#2ddb7c; border:1px solid #0a5530; }
241
+ .badge-blue { background:#020f20; color:#4fb3ff; border:1px solid #0b3362; }
242
+ .badge-amber { background:#1a1002; color:#f5a623; border:1px solid #5c3700; }
243
+ .badge-purple { background:#120920; color:#c77dff; border:1px solid #4a1a7a; }
244
 
245
+ button.primary {
246
+ font-family: 'Orbitron', monospace !important;
247
+ font-size: 0.82rem !important; font-weight: 700 !important;
248
+ letter-spacing: 0.15em !important; text-transform: uppercase !important;
249
+ background: linear-gradient(135deg,#0a2a52,#0d3a72) !important;
250
+ color: #4fb3ff !important; border: 1px solid #1a5a9e !important;
251
+ border-radius: 4px !important; transition: all 0.2s !important;
 
 
 
252
  }
253
+ button.primary:hover {
254
+ background: linear-gradient(135deg,#0d3a72,#1150a0) !important;
255
+ border-color: #4fb3ff !important;
256
+ box-shadow: 0 4px 20px rgba(79,179,255,0.25) !important;
 
 
 
 
257
  }
258
+ button.stop {
259
+ background: linear-gradient(135deg,#2a0a0a,#4a1010) !important;
260
+ color: #ff4d6d !important; border: 1px solid #7a1a1a !important;
261
+ font-family: 'Share Tech Mono', monospace !important;
 
 
 
 
262
  }
263
 
264
+ label span, .gradio-container label {
 
 
265
  font-family: 'Share Tech Mono', monospace !important;
266
+ font-size: 0.72rem !important; letter-spacing: 0.15em !important;
267
+ text-transform: uppercase !important; color: #4fb3ff !important;
 
 
268
  }
 
269
  input[type=range] {
270
+ -webkit-appearance: none; height: 3px;
271
+ background: #0d2540; border-radius: 2px; outline: none;
 
 
 
 
 
 
272
  }
 
273
  input[type=range]::-webkit-slider-thumb {
274
+ -webkit-appearance: none; width: 16px; height: 16px;
275
+ border-radius: 50%; background: #4fb3ff; cursor: pointer;
276
+ border: 2px solid #030b1a; box-shadow: 0 0 8px rgba(79,179,255,0.5);
 
 
 
 
 
277
  }
278
 
279
+ textarea, .gradio-container textarea {
280
+ font-family: 'Share Tech Mono', monospace !important;
281
+ font-size: 0.82rem !important; line-height: 1.7 !important;
282
+ background: #020810 !important; color: #7fcfff !important;
283
+ border: 1px solid #0d2540 !important; border-radius: 4px !important;
 
 
284
  }
285
 
286
+ table { width: 100%; border-collapse: collapse; }
287
+ th { background: #060f1e; color: #4fb3ff;
288
+ font-family: 'Share Tech Mono', monospace;
289
+ font-size: 0.7rem; letter-spacing: 0.1em; padding: 6px 10px; }
290
+ td { border-top: 1px solid #0d2540; padding: 6px 10px;
291
+ color: #c8ddf0; font-size: 0.85rem; }
 
 
 
 
 
 
 
 
 
 
292
 
293
+ footer { display: none !important; }
294
+ .gradio-container .block { background: transparent !important; border: none !important; }
295
+ """
 
 
 
 
296
 
297
+ # ── Theory ────────────────────────────────────────────────────────────────────
 
 
298
 
299
+ THEORY_MD = """
300
+ ## Soft Actor-Critic (SAC)
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
302
+ SAC is an **off-policy, maximum-entropy** deep RL algorithm for continuous
303
+ action spaces. It simultaneously maximises expected return *and* policy entropy,
304
+ encouraging exploration while converging to a stable policy.
305
 
306
+ ### Objective
307
+ $$J(\\pi) = \\sum_t \\mathbb{E}_{(s_t,a_t)\\sim\\rho_\\pi}\\left[ r(s_t,a_t) + \\alpha\\,\\mathcal{H}(\\pi(\\cdot|s_t)) \\right]$$
 
 
 
308
 
309
+ The temperature $\\alpha$ is **auto-tuned** to a target entropy level.
 
 
 
310
 
311
+ ### Architecture
 
 
 
 
 
 
 
 
 
312
 
313
+ | Component | Role |
314
+ |---|---|
315
+ | **Actor** $\\pi_\\phi(a\\|s)$ | Gaussian policy — outputs mean & log-std |
316
+ | **Critic 1** $Q_{\\theta_1}(s,a)$ | Q-value estimator |
317
+ | **Critic 2** $Q_{\\theta_2}(s,a)$ | Clipped double-Q: take min to reduce overestimation |
318
+ | **Target Critics** | Soft-updated copies ($\\tau=0.005$) for stable TD targets |
 
319
 
320
+ ### Update Rules
321
 
322
+ **Critic** — minimise the Bellman residual against the soft TD target:
323
+ $$y = r + \\gamma\\left(\\min_i Q_{\\bar\\theta_i}(s',\\tilde a') - \\alpha\\log\\pi(\\tilde a'|s')\\right)$$
 
 
 
324
 
325
+ **Actor** — maximise Q + entropy:
326
+ $$\\mathcal{L}(\\phi) = \\mathbb{E}\\left[\\alpha\\log\\pi_\\phi(a|s) - \\min_i Q_{\\theta_i}(s,a)\\right]$$
 
 
327
 
328
+ **Temperature** — match target entropy $\\bar{\\mathcal{H}}$:
329
+ $$\\mathcal{L}(\\alpha) = \\mathbb{E}\\left[-\\alpha(\\log\\pi(a|s)+\\bar{\\mathcal{H}})\\right]$$
 
 
 
 
 
 
 
 
 
 
 
 
 
330
 
331
+ ---
332
+
333
+ ## LunarLander-v3 (Continuous)
 
 
334
 
335
+ | Property | Value |
336
+ |---|---|
337
+ | **State** | 8-dim: pos (x,y), vel (vx,vy), angle, angular vel, leg contacts |
338
+ | **Action** | 2-dim continuous: main throttle, lateral thrust ∈ [−1,1] |
339
+ | **Reward** | +10 each leg contact, +100 landing, −100 crash |
340
+ | **Solved** | Episode reward ≥ 200 |
341
 
342
+ ---
343
 
344
+ ## Model Hyperparameters
345
 
346
+ | Parameter | Value |
347
+ |---|---|
348
+ | `learning_rate` | 3×10⁻⁴ |
349
+ | `buffer_size` | 1,000,000 |
350
+ | `batch_size` | 256 |
351
+ | `tau` | 0.005 |
352
+ | `gamma` | 0.99 |
353
+ | `target_entropy` | −2.0 |
354
+
355
+ ---
356
+
357
+ ## Reading the Charts
358
+
359
+ - **Reward bars**: green ≥ 150, amber ≥ 0, red < 0
360
+ - **Trajectory plot**: `★` = successful landing, `×` = crash
361
+ - **Engine throttle**: main (blue) fires downward; lateral (amber) steers
362
+ - **Training reward**: smoothed line (solid) trends matter more than raw (faded)
363
+ - **Actor loss**: negative values are normal — the actor minimises α·log π − Q, so the loss is roughly −Q
364
+ - **Entropy coef**: starts high, decreases as policy converges
365
+ """
366
+
367
+ # ── Build UI ──────────────────────────────────────────────────────────────────
368
+
369
+ with gr.Blocks(title="SpaceX Mission Control β€” SAC Rocket Lander") as demo:
370
+
371
+ gr.HTML("""
372
+ <div class="mc-header">
373
+ <div class="mc-sub">Autonomous Flight Intelligence System Β· SAC v2.0</div>
374
+ <h1>⬑ SpaceX Mission Control</h1>
375
+ <div class="mc-sub">Soft Actor-Critic Β· LunarLander-v3 Β· Continuous Control</div>
376
+ </div>
377
+ <div class="status-strip">
378
+ <span class="badge badge-green">● SAC MODEL LOADED</span>
379
+ <span class="badge badge-blue">● PHYSICS ENGINE READY</span>
380
+ <span class="badge badge-amber">● TELEMETRY ONLINE</span>
381
+ <span class="badge badge-purple">● TRAINING MODULE ARMED</span>
382
+ </div>
383
+ """)
384
+
385
+ with gr.Tabs():
386
+
387
+ # ── Mission Control ────────────────────────────────────────────────
388
+ with gr.Tab("πŸš€ MISSION CONTROL"):
389
+
390
+ with gr.Row():
391
+ with gr.Column(scale=1, min_width=300):
392
+ gr.HTML('<div class="mc-sub" style="margin-bottom:0.8rem">MISSION PARAMETERS</div>')
393
+
394
+ n_episodes = gr.Slider(1, 10, value=3, step=1,
395
+ label="Landing Attempts")
396
+ gravity = gr.Slider(-20.0, -1.0, value=-10.0, step=0.5,
397
+ label="Gravity (m/sΒ²)")
398
+ enable_wind = gr.Checkbox(label="Enable Wind Disturbance", value=False)
399
+ wind_power = gr.Slider(0.0, 20.0, value=5.0, step=0.5,
400
+ label="Wind Power", visible=False)
401
+ turbulence = gr.Slider(0.0, 2.0, value=0.5, step=0.1,
402
+ label="Turbulence Power", visible=False)
403
+ render_gif = gr.Checkbox(label="Render Animated Replay", value=True)
404
+
405
+ enable_wind.change(
406
+ lambda v: (gr.update(visible=v), gr.update(visible=v)),
407
+ inputs=enable_wind,
408
+ outputs=[wind_power, turbulence],
409
+ )
410
+
411
+ launch_btn = gr.Button("πŸš€ INITIATE LAUNCH SEQUENCE", variant="primary")
412
+
413
+ gr.HTML('<div class="mc-sub" style="margin-top:1.2rem;margin-bottom:0.4rem">MODEL</div>')
414
+ load_btn = gr.Button("πŸ“‚ Reload Checkpoint")
415
+ load_status = gr.Textbox(label="", lines=1, interactive=False,
416
+ placeholder="Model status…")
417
+ load_btn.click(cb_load_finetuned, outputs=load_status)
418
+
419
+ with gr.Column(scale=2):
420
+ stats_md = gr.Markdown("*Configure mission parameters and click Launch.*")
421
+ episode_selector = gr.Dropdown(
422
+ choices=[], label="Inspect Episode", interactive=True,
423
+ )
424
+
425
+ with gr.Row():
426
+ overview_plot = gr.Plot(label="Mission Overview Dashboard")
427
+
428
+ with gr.Row():
429
+ with gr.Column(scale=1):
430
+ detail_plot = gr.Plot(label="Episode Deep-Dive")
431
+ with gr.Column(scale=1):
432
+ replay_gif = gr.Image(
433
+ label="Episode Replay (GIF with HUD)",
434
+ type="filepath",
435
+ )
436
+
437
+ episode_selector.change(
438
+ cb_select_episode,
439
+ inputs=episode_selector,
440
+ outputs=[detail_plot, replay_gif],
441
  )
442
+ launch_btn.click(
443
+ cb_run_mission,
444
+ inputs=[n_episodes, gravity, enable_wind, wind_power, turbulence, render_gif],
445
+ outputs=[overview_plot, replay_gif, detail_plot, stats_md, episode_selector],
446
  )
447
 
448
+ # ── Training Lab ───────────────────────────────────────────────────
449
+ with gr.Tab("πŸ§ͺ TRAINING LAB"):
 
 
 
 
 
 
 
 
 
450
 
451
+ gr.Markdown("### Fine-tune the SAC agent in your browser")
452
+ gr.Markdown(
453
+ "Runs in a background thread β€” click **Refresh Metrics** to pull updates. "
454
+ "The fine-tuned model saves to `sac_finetuned.zip` and is used automatically."
455
+ )
 
456
 
457
+ with gr.Row():
458
+ with gr.Column(scale=1):
459
+ gr.HTML('<div class="mc-sub" style="margin-bottom:0.8rem">HYPERPARAMETERS</div>')
460
+ train_steps = gr.Slider(5_000, 200_000, value=20_000, step=5_000,
461
+ label="Total Timesteps")
462
+ train_lr = gr.Slider(1e-5, 1e-3, value=3e-4, step=1e-5,
463
+ label="Learning Rate")
464
+ train_batch = gr.Slider(64, 512, value=256, step=64,
465
+ label="Batch Size")
466
+
467
+ with gr.Row():
468
+ btn_train_start = gr.Button("β–Ά Start Training", variant="primary")
469
+ btn_train_stop = gr.Button("⏹ Stop", variant="stop")
470
+ btn_refresh = gr.Button("πŸ”„ Refresh Metrics")
471
+ train_msg = gr.Textbox(label="", lines=2, interactive=False)
472
+
473
+ with gr.Column(scale=2):
474
+ train_status_md = gr.Markdown("*Start training to see live metrics.*")
475
+ train_plot = gr.Plot(label="Live Training Dashboard")
476
+
477
+ btn_train_start.click(
478
+ cb_start_training,
479
+ inputs=[train_steps, train_lr, train_batch],
480
+ outputs=train_msg,
481
+ )
482
+ btn_train_stop.click(cb_stop_training, outputs=train_msg)
483
+ btn_refresh.click(cb_refresh_training, outputs=[train_plot, train_status_md])
484
+
485
+ # ── Algorithm Guide ────────────────────────────────────────────────
486
+ with gr.Tab("πŸ“š ALGORITHM GUIDE"):
487
+ gr.Markdown(THEORY_MD)
488
+
489
+ gr.HTML("""
490
+ <div style="text-align:center;font-family:'Share Tech Mono',monospace;
491
+ font-size:0.65rem;color:#1e3d5c;letter-spacing:0.2em;
492
+ text-transform:uppercase;padding:2rem 0 1rem;">
493
+ Powered by Stable-Baselines3 Β· Soft Actor-Critic Β·
494
+ Gymnasium LunarLander-v3 Β· Gradio
495
+ </div>
496
+ """)
497
 
498
  if __name__ == "__main__":
499
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False, css=CSS)
core/__init__.py ADDED
File without changes
core/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (197 Bytes). View file
 
core/__pycache__/mission.cpython-311.pyc ADDED
Binary file (10.4 kB). View file
 
core/__pycache__/trainer.cpython-311.pyc ADDED
Binary file (7.17 kB). View file
 
core/mission.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mission runner — executes SAC agent episodes, collects full telemetry.
3
+ Returns structured data for both the UI and the visualization layer.
4
+ """
5
+
6
+ from __future__ import annotations
7
+ import numpy as np
8
+ import gymnasium as gym
9
+ from dataclasses import dataclass, field
10
+ from stable_baselines3 import SAC
11
+
12
+
13
+ # ── Telemetry data structures ─────────────────────────────────────────────────
14
+
15
@dataclass
class StepData:
    """Telemetry captured for a single environment step.

    ``run_mission`` fills the first eight fields from observation indices
    0-7 (in declaration order); ``reward`` and the two action components
    come from the same step.
    """

    # Position and velocity components (obs[0] .. obs[3]).
    x: float
    y: float
    vx: float
    vy: float
    # Body orientation and rotation rate (obs[4], obs[5]).
    angle: float
    angular_vel: float
    # Leg flags (obs[6], obs[7]) converted with bool().
    left_leg: bool
    right_leg: bool
    # Reward returned by env.step() for this transition.
    reward: float
    # Raw policy outputs action[0] and action[1].
    action_main: float     # main engine throttle [-1, 1]
    action_lateral: float  # lateral thruster [-1, 1]
28
+
29
+
30
@dataclass
class EpisodeResult:
    """Per-step telemetry and summary of one episode.

    ``status`` is derived purely from ``total_reward`` using the thresholds
    200 / 150 / 0. The ``landed`` and ``crashed`` flags are plain fields with
    a default of ``False``; ``run_mission`` in this module does not set them.
    """

    episode_idx: int
    steps: list[StepData] = field(default_factory=list)
    total_reward: float = 0.0
    landed: bool = False
    crashed: bool = False

    # Status -> emoji lookup. Un-annotated on purpose so the dataclass does
    # not treat it as a field. Written as escapes so the glyphs survive any
    # re-encoding of this file.
    _STATUS_EMOJI = {
        "PERFECT": "\U0001F3C6",   # trophy
        "LANDED": "\u2705",        # check mark
        "PARTIAL": "\u26a0\ufe0f", # warning sign
        "CRASHED": "\U0001F4A5",   # collision
    }

    @property
    def status(self) -> str:
        """Return PERFECT (>=200), LANDED (>=150), PARTIAL (>=0) or CRASHED."""
        if self.total_reward >= 200:
            return "PERFECT"
        if self.total_reward >= 150:
            return "LANDED"
        if self.total_reward >= 0:
            return "PARTIAL"
        return "CRASHED"

    @property
    def status_emoji(self) -> str:
        """Return the emoji that matches ``status``."""
        return self._STATUS_EMOJI[self.status]

    @property
    def xs(self) -> list[float]:
        """X position at every recorded step."""
        return [s.x for s in self.steps]

    @property
    def ys(self) -> list[float]:
        """Y position at every recorded step."""
        return [s.y for s in self.steps]

    @property
    def cumulative_rewards(self) -> list[float]:
        """Running sum of step rewards; same length as ``steps``."""
        total = 0.0
        out = []
        for s in self.steps:
            total += s.reward
            out.append(total)
        return out

    @property
    def main_throttle(self) -> list[float]:
        """First action component at every recorded step."""
        return [s.action_main for s in self.steps]

    @property
    def lateral_throttle(self) -> list[float]:
        """Second action component at every recorded step."""
        return [s.action_lateral for s in self.steps]

    @property
    def angles(self) -> list[float]:
        """Body angle at every recorded step, converted to degrees."""
        return [float(np.degrees(s.angle)) for s in self.steps]
80
+
81
+
82
@dataclass
class MissionResult:
    """Collection of the episodes flown in one mission, with summary stats."""

    episodes: list[EpisodeResult] = field(default_factory=list)

    @staticmethod
    def _score(episode) -> float:
        # Ranking key shared by ``best`` and ``worst``.
        return episode.total_reward

    @property
    def rewards(self) -> list[float]:
        """Total reward of each episode, in run order."""
        return [episode.total_reward for episode in self.episodes]

    @property
    def success_rate(self) -> float:
        """Fraction of episodes scoring at least 150 (0.0 when empty)."""
        flown = len(self.episodes)
        if flown == 0:
            return 0.0
        successes = sum(1 for episode in self.episodes if episode.total_reward >= 150)
        return successes / flown

    @property
    def avg_reward(self) -> float:
        """Mean total reward across episodes (0.0 when empty)."""
        scores = self.rewards
        if not scores:
            return 0.0
        return float(np.mean(scores))

    @property
    def best(self) -> EpisodeResult:
        """Episode with the highest total reward.

        Raises:
            ValueError: if the mission holds no episodes.
        """
        return max(self.episodes, key=self._score)

    @property
    def worst(self) -> EpisodeResult:
        """Episode with the lowest total reward.

        Raises:
            ValueError: if the mission holds no episodes.
        """
        return min(self.episodes, key=self._score)
107
+
108
+
109
+ # ── Runner ────────────────────────────────────────────────────────────────────
110
+
111
def run_mission(
    model: SAC,
    n_episodes: int = 5,
    gravity: float = -10.0,
    enable_wind: bool = False,
    wind_power: float = 5.0,
    turbulence_power: float = 0.5,
    render: bool = True,
    progress_cb=None,
) -> tuple[MissionResult, list[list[np.ndarray]]]:
    """Run ``n_episodes`` of the lander with a deterministic policy.

    A fresh ``LunarLander-v3`` environment is created for every episode and
    is always closed, even if prediction, stepping or rendering raises.

    Args:
        model: Trained agent; ``model.predict(obs, deterministic=True)`` is
            called once per step.
        n_episodes: Number of episodes to fly (coerced with ``int()``).
        gravity: Passed through to the environment.
        enable_wind: Passed through to the environment. When false,
            ``wind_power`` and ``turbulence_power`` are sent as 0.0.
        wind_power: Wind strength, used only when ``enable_wind`` is true.
        turbulence_power: Turbulence strength, used only when
            ``enable_wind`` is true.
        render: When true the env is created with ``render_mode="rgb_array"``
            and one frame is captured after every step.
        progress_cb: Optional ``callable(fraction, message)`` invoked before
            each episode and once more with ``1.0`` at the end.

    Returns:
        ``(MissionResult, list_of_frame_lists)`` with one frame list per
        episode (empty lists when ``render`` is false).
    """
    # A UI slider may hand over a float; range() needs an int.
    n_episodes = int(n_episodes)

    mission = MissionResult()
    all_frames: list[list[np.ndarray]] = []

    env_kwargs = dict(
        continuous=True,
        gravity=gravity,
        enable_wind=enable_wind,
        wind_power=wind_power if enable_wind else 0.0,
        turbulence_power=turbulence_power if enable_wind else 0.0,
        render_mode="rgb_array" if render else None,
    )

    for ep_idx in range(n_episodes):
        if progress_cb:
            progress_cb(ep_idx / n_episodes,
                        f"Running mission {ep_idx + 1}/{n_episodes}\u2026")

        result = EpisodeResult(episode_idx=ep_idx)
        frames: list[np.ndarray] = []

        env = gym.make("LunarLander-v3", **env_kwargs)
        try:
            obs, _ = env.reset()
            done = False
            while not done:
                action, _ = model.predict(obs, deterministic=True)
                next_obs, reward, terminated, truncated, _ = env.step(action)

                # Telemetry pairs the observation the policy acted on with
                # the action it chose and the reward that action earned.
                result.steps.append(StepData(
                    x=float(obs[0]), y=float(obs[1]),
                    vx=float(obs[2]), vy=float(obs[3]),
                    angle=float(obs[4]), angular_vel=float(obs[5]),
                    left_leg=bool(obs[6]), right_leg=bool(obs[7]),
                    reward=float(reward),
                    action_main=float(action[0]),
                    action_lateral=float(action[1]),
                ))
                result.total_reward += float(reward)

                if render:
                    frame = env.render()
                    if frame is not None:
                        frames.append(frame)

                obs = next_obs
                done = terminated or truncated
        finally:
            # Release the env (and its renderer) even when the episode fails.
            env.close()

        mission.episodes.append(result)
        all_frames.append(frames)

    if progress_cb:
        progress_cb(1.0, "Mission complete.")

    return mission, all_frames
core/trainer.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SAC training pipeline — fine-tune or train from scratch with live callbacks.
3
+ """
4
+
5
+ from __future__ import annotations
6
+ import os
7
+ import threading
8
+ from dataclasses import dataclass, field
9
+ from stable_baselines3 import SAC
10
+ from stable_baselines3.common.callbacks import BaseCallback
11
+ import gymnasium as gym
12
+ import numpy as np
13
+
14
+
15
@dataclass
class TrainingState:
    """Mutable record of training progress.

    ``_LiveCallback`` writes to it while training runs and
    ``start_training`` initialises ``running`` / ``total_timesteps``.
    """

    # True while a training run is active; the callback aborts when it
    # observes False.
    running: bool = False
    # Latest step counter reported by the callback, and the requested budget.
    timestep: int = 0
    total_timesteps: int = 0
    # One entry per finished episode.
    episode_rewards: list[float] = field(default_factory=list)
    # Parallel series, appended together at every logging interval.
    actor_losses: list[float] = field(default_factory=list)
    critic_losses: list[float] = field(default_factory=list)
    ent_coefs: list[float] = field(default_factory=list)
    log_steps: list[int] = field(default_factory=list)
    # Human-readable progress line.
    status: str = "idle"
    # Highest episode reward seen so far (-inf until the first episode ends).
    best_reward: float = float("-inf")
27
+
28
+
29
class _LiveCallback(BaseCallback):
    """Mirror training progress into a shared ``TrainingState``.

    Step counts are reported relative to the start of the current run.
    When a checkpoint is fine-tuned with ``reset_num_timesteps=False`` the
    model's ``num_timesteps`` continues from the checkpoint's total, so the
    raw counter would overshoot the requested budget immediately. For a
    model trained from scratch the two are identical.
    """

    def __init__(self, state: TrainingState, log_interval: int = 500):
        super().__init__()
        self._state = state
        self._log_interval = log_interval
        self._ep_rewards: list[float] = []
        # Model step count when this run began; set in _on_training_start.
        self._start_timesteps = 0

    def _on_training_start(self) -> None:
        """Remember where the model's step counter stood before this run."""
        self._start_timesteps = int(self.model.num_timesteps)

    def _elapsed(self) -> int:
        """Steps taken during this run only."""
        return self.num_timesteps - self._start_timesteps

    def _on_step(self) -> bool:
        """Record metrics for the current step; return False to abort."""
        if not self._state.running:
            return False  # abort training

        elapsed = self._elapsed()
        self._state.timestep = elapsed

        # Collect episode rewards from monitor wrapper
        infos = self.locals.get("infos", [])
        for info in infos:
            if "episode" in info:
                r = float(info["episode"]["r"])
                self._ep_rewards.append(r)
                self._state.episode_rewards.append(r)
                if r > self._state.best_reward:
                    self._state.best_reward = r

        if elapsed % self._log_interval == 0:
            losses = self.model.logger.name_to_value
            self._state.actor_losses.append(float(losses.get("train/actor_loss", 0)))
            self._state.critic_losses.append(float(losses.get("train/critic_loss", 0)))
            self._state.ent_coefs.append(float(losses.get("train/ent_coef", 0)))
            self._state.log_steps.append(elapsed)

        pct = elapsed / max(self._state.total_timesteps, 1)
        rolling = float(np.mean(self._ep_rewards[-20:])) if self._ep_rewards else 0.0
        self._state.status = (
            f"Step {elapsed:,}/{self._state.total_timesteps:,} "
            f"({pct*100:.1f}%) | Rolling reward: {rolling:+.1f} | "
            f"Best: {self._state.best_reward:+.1f}"
        )
        return True

    def _on_training_end(self) -> None:
        """Publish the final summary and mark the run as finished."""
        self._state.status = (
            f"Training complete \u2014 {self._elapsed():,} steps. "
            f"Best reward: {self._state.best_reward:+.1f}"
        )
        self._state.running = False
74
+
75
+
76
def start_training(
    base_model_path: str,
    total_timesteps: int,
    learning_rate: float,
    batch_size: int,
    state: TrainingState,
    save_path: str = "sac_finetuned.zip",
) -> threading.Thread:
    """Launch SAC training in a daemon thread; progress is written to `state`.

    If ``base_model_path`` exists the checkpoint is fine-tuned, otherwise a
    new ``MlpPolicy`` model is trained from scratch. The model is saved to
    ``save_path`` only when ``learn()`` returns normally.

    Args:
        base_model_path: Checkpoint to fine-tune, if it exists on disk.
        total_timesteps: Step budget for this run (coerced with ``int()``).
        learning_rate: Optimiser learning rate for this run.
        batch_size: Minibatch size for this run (coerced with ``int()``).
        state: Shared progress record, updated from the training thread.
        save_path: Where the trained model is written.

    Returns:
        The already-started daemon thread.
    """
    # Normalise once; UI widgets may deliver floats for integer settings.
    step_budget = int(total_timesteps)
    batch = int(batch_size)
    lr = float(learning_rate)

    def _train():
        from stable_baselines3.common.monitor import Monitor

        state.total_timesteps = step_budget
        state.status = "Initialising environment\u2026"

        env = None
        try:
            env = Monitor(gym.make("LunarLander-v3", continuous=True))

            if os.path.exists(base_model_path):
                # Override the stored hyperparameters while loading: the LR
                # schedule is built from `learning_rate` during load, so
                # assigning the attribute afterwards would not take effect.
                model = SAC.load(
                    base_model_path,
                    env=env,
                    custom_objects={"learning_rate": lr, "batch_size": batch},
                )
            else:
                model = SAC(
                    "MlpPolicy", env,
                    learning_rate=lr,
                    batch_size=batch,
                    verbose=0,
                )

            cb = _LiveCallback(state, log_interval=max(step_budget // 200, 200))
            model.learn(
                total_timesteps=step_budget,
                callback=cb,
                reset_num_timesteps=False,
                progress_bar=False,
                log_interval=1,
            )
            model.save(save_path)
        except Exception as exc:
            # Surface the failure to the UI, then let the thread's default
            # exception hook print the traceback.
            state.status = f"Training failed: {exc}"
            raise
        finally:
            # Never leave the state stuck in "running" or the env open.
            state.running = False
            if env is not None:
                env.close()

    # Flip the flag before the thread starts so a caller checking it right
    # after this function returns cannot launch a second run.
    state.running = True
    thread = threading.Thread(target=_train, daemon=True)
    thread.start()
    return thread
requirements.txt CHANGED
@@ -1,3 +1,5 @@
1
  stable-baselines3[extra]
2
  gymnasium[box2d]
3
  shimmy
 
 
 
1
  stable-baselines3[extra]
2
  gymnasium[box2d]
3
  shimmy
4
+ matplotlib
5
+ gradio>=6.0.0
viz/__init__.py ADDED
File without changes
viz/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (196 Bytes). View file
 
viz/__pycache__/charts.cpython-311.pyc ADDED
Binary file (18.5 kB). View file
 
viz/__pycache__/replay.cpython-311.pyc ADDED
Binary file (8.43 kB). View file
 
viz/charts.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ All matplotlib figure generation for the dashboard.
3
+ Every function returns a plt.Figure — caller closes or passes to Gradio.
4
+ """
5
+
6
+ from __future__ import annotations
7
+ import numpy as np
8
+ import matplotlib
9
+ matplotlib.use("Agg")
10
+ import matplotlib.pyplot as plt
11
+ import matplotlib.gridspec as gridspec
12
+ import matplotlib.patches as mpatches
13
+ from matplotlib.collections import LineCollection
14
+
15
+ from core.mission import MissionResult, EpisodeResult
16
+
17
+ # ── Palette ───────────────────────────────────────────────────────────────────
18
+ BG = "#030b1a"
19
+ BG2 = "#060f1e"
20
+ GRID = "#0d2540"
21
+ ACCENT = "#4fb3ff"
22
+ GREEN = "#2ddb7c"
23
+ AMBER = "#f5a623"
24
+ RED = "#ff4d6d"
25
+ PURPLE = "#c77dff"
26
+ TEXT = "#c8ddf0"
27
+ DIM = "#3a6080"
28
+
29
+ EP_COLORS = [ACCENT, GREEN, AMBER, PURPLE, "#ff9f1c", "#e9c46a", "#f4a261"]
30
+
31
+
32
def _style_ax(ax, title: str = "", xlabel: str = "", ylabel: str = ""):
    """Apply the dashboard's dark theme to *ax* and set any labels given.

    Empty ``title`` / ``xlabel`` / ``ylabel`` strings are skipped.
    """
    # Panel background, tick colours, frame and dashed grid.
    ax.set_facecolor(BG2)
    ax.tick_params(colors=DIM, labelsize=8)
    for edge in ax.spines.values():
        edge.set_color(GRID)
    ax.grid(color=GRID, linewidth=0.5, linestyle="--", alpha=0.6)

    if title:
        ax.set_title(title, color=TEXT, fontsize=10, pad=8, fontfamily="monospace")

    # Both axis labels share the same dim styling.
    for text, apply_label in ((xlabel, ax.set_xlabel), (ylabel, ax.set_ylabel)):
        if text:
            apply_label(text, color=DIM, fontsize=8)
44
+
45
+
46
def mission_overview(mission: MissionResult) -> plt.Figure:
    """Build the 4-panel mission summary figure.

    Panels: per-episode reward bars, 2-D flight trajectories, cumulative
    reward per step, and the engine throttle timeline of the best episode.
    A mission without episodes yields a placeholder figure instead of
    raising.

    Args:
        mission: Result of a mission run.

    Returns:
        A new matplotlib figure; the caller is responsible for closing it.
    """
    if not mission.episodes:
        # mission.best takes max() over the episodes and would raise here.
        return empty_figure("No episodes to display.")

    n = len(mission.episodes)
    rewards = mission.rewards
    legend_style = dict(fontsize=7, facecolor=BG2, edgecolor=GRID, labelcolor=DIM)
    # One colour per episode, cycling through the palette.
    ep_colors = [EP_COLORS[i % len(EP_COLORS)] for i in range(n)]

    fig = plt.figure(figsize=(14, 9), facecolor=BG)
    gs = gridspec.GridSpec(2, 2, figure=fig, hspace=0.45, wspace=0.32,
                           left=0.07, right=0.97, top=0.90, bottom=0.08)

    # -- Panel 1: episode rewards bar ---------------------------------------
    ax1 = fig.add_subplot(gs[0, 0])
    _style_ax(ax1, "EPISODE REWARDS", "Episode", "Score")
    labels = [f"#{e.episode_idx+1}" for e in mission.episodes]
    colors = [GREEN if r >= 150 else (AMBER if r >= 0 else RED) for r in rewards]
    bars = ax1.bar(labels, rewards, color=colors, edgecolor=BG, linewidth=0.8)
    ax1.axhline(200, color=GREEN, linestyle="--", linewidth=1, alpha=0.5, label="Perfect (200)")
    ax1.axhline(150, color=ACCENT, linestyle="--", linewidth=1, alpha=0.5, label="Success (150)")
    ax1.axhline(0, color=RED, linestyle="--", linewidth=1, alpha=0.3)
    ax1.legend(**legend_style)
    for bar, val in zip(bars, rewards):
        # Put the value outside the bar: above positive bars, below negative
        # ones, so the text never overlaps the bar itself.
        below = val < 0
        ax1.text(bar.get_x() + bar.get_width() / 2,
                 bar.get_height() + (-3 if below else 3),
                 f"{val:.0f}", ha="center",
                 va="top" if below else "bottom",
                 color=TEXT, fontsize=8)

    # -- Panel 2: 2-D flight trajectory -------------------------------------
    ax2 = fig.add_subplot(gs[0, 1])
    _style_ax(ax2, "FLIGHT TRAJECTORIES", "X Position", "Altitude")
    for col, ep in zip(ep_colors, mission.episodes):
        # Each trajectory is drawn as a single-colour segment collection.
        points = np.array([ep.xs, ep.ys]).T.reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)
        ax2.add_collection(
            LineCollection(segments, colors=col, linewidth=1.2, alpha=0.7))
        # Final-position marker: star for a success, cross otherwise.
        ax2.scatter(ep.xs[-1], ep.ys[-1],
                    marker=("*" if ep.total_reward >= 150 else "x"),
                    s=80, color=col, zorder=5)
    ax2.autoscale()
    ax2.axhline(0, color=GRID, linewidth=1)
    patches = [mpatches.Patch(color=col, label=f"#{e.episode_idx+1} {e.status_emoji}")
               for col, e in zip(ep_colors, mission.episodes)]
    ax2.legend(handles=patches, loc="upper right", **legend_style)

    # -- Panel 3: cumulative reward over steps ------------------------------
    ax3 = fig.add_subplot(gs[1, 0])
    _style_ax(ax3, "CUMULATIVE REWARD", "Step", "Reward")
    for col, ep in zip(ep_colors, mission.episodes):
        ax3.plot(ep.cumulative_rewards, color=col, linewidth=1.5,
                 label=f"#{ep.episode_idx+1}", alpha=0.85)
    ax3.axhline(0, color=RED, linestyle="--", linewidth=0.8, alpha=0.4)
    ax3.legend(**legend_style)

    # -- Panel 4: engine throttle timeline (best episode) -------------------
    ax4 = fig.add_subplot(gs[1, 1])
    _style_ax(ax4, "ENGINE THROTTLE \u2014 BEST EPISODE", "Step", "Throttle")
    best = mission.best
    steps = range(len(best.steps))
    main, lateral = best.main_throttle, best.lateral_throttle
    ax4.fill_between(steps, 0, main, color=ACCENT, alpha=0.35, label="Main Engine")
    ax4.plot(steps, main, color=ACCENT, linewidth=1.2)
    ax4.fill_between(steps, 0, lateral, color=AMBER, alpha=0.25,
                     label="Lateral Thrusters")
    ax4.plot(steps, lateral, color=AMBER, linewidth=1.0)
    ax4.axhline(0, color=GRID, linewidth=0.8)
    ax4.set_ylim(-1.1, 1.1)
    ax4.legend(**legend_style)

    # -- Figure title --------------------------------------------------------
    sr = mission.success_rate * 100
    fig.suptitle(
        f"MISSION REPORT \u00b7 {n} episodes \u00b7 "
        f"Avg {mission.avg_reward:+.1f} \u00b7 Success {sr:.0f}%",
        color=TEXT, fontsize=12, fontfamily="monospace", y=0.96,
    )
    return fig
123
+
124
+
125
def single_episode_detail(ep: EpisodeResult) -> plt.Figure:
    """Build the 6-panel deep-dive figure for one episode.

    Panels: trajectory, cumulative reward, altitude, body angle, main engine
    and lateral thrusters. An episode without recorded steps yields a
    placeholder figure instead of raising.

    Args:
        ep: Episode to visualise.

    Returns:
        A new matplotlib figure; the caller is responsible for closing it.
    """
    if not ep.steps:
        # The trajectory markers index xs[0] / xs[-1], which needs >= 1 step.
        return empty_figure("Episode has no recorded steps.")

    fig = plt.figure(figsize=(14, 8), facecolor=BG)
    gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.5, wspace=0.38,
                           left=0.07, right=0.97, top=0.88, bottom=0.08)

    steps = list(range(len(ep.steps)))
    xs, ys = ep.xs, ep.ys
    success = ep.total_reward >= 150
    outcome_color = GREEN if success else RED

    # Trajectory
    ax = fig.add_subplot(gs[0, 0])
    _style_ax(ax, "TRAJECTORY", "X", "Y")
    ax.plot(xs, ys, color=ACCENT, linewidth=1.5)
    ax.scatter(xs[0], ys[0], s=60, color=GREEN, zorder=5, label="Start")
    ax.scatter(xs[-1], ys[-1], s=80,
               marker="*" if success else "x",
               color=outcome_color, zorder=5, label="End")
    ax.axhline(0, color=GRID, linewidth=1)
    ax.legend(fontsize=7, facecolor=BG2, edgecolor=GRID, labelcolor=DIM)

    # Cumulative reward
    ax = fig.add_subplot(gs[0, 1])
    _style_ax(ax, "CUMULATIVE REWARD", "Step", "Reward")
    cum = ep.cumulative_rewards
    ax.fill_between(steps, 0, cum, color=outcome_color, alpha=0.2)
    ax.plot(steps, cum, color=outcome_color, linewidth=1.5)
    ax.axhline(0, color=GRID, linewidth=0.8, linestyle="--")

    # Altitude over time
    ax = fig.add_subplot(gs[0, 2])
    _style_ax(ax, "ALTITUDE", "Step", "Y")
    ax.fill_between(steps, 0, ys, color=ACCENT, alpha=0.15)
    ax.plot(steps, ys, color=ACCENT, linewidth=1.5)
    ax.axhline(0, color=RED, linewidth=1, linestyle="--", alpha=0.5)

    # Angle
    ax = fig.add_subplot(gs[1, 0])
    _style_ax(ax, "BODY ANGLE", "Step", "Degrees")
    angles = ep.angles
    ax.fill_between(steps, 0, angles, color=AMBER, alpha=0.2)
    ax.plot(steps, angles, color=AMBER, linewidth=1.3)
    ax.axhline(0, color=GRID, linewidth=0.8, linestyle="--")

    # Main throttle
    ax = fig.add_subplot(gs[1, 1])
    _style_ax(ax, "MAIN ENGINE", "Step", "Throttle")
    main = ep.main_throttle
    ax.fill_between(steps, 0, main, color=ACCENT, alpha=0.3)
    ax.plot(steps, main, color=ACCENT, linewidth=1.2)
    ax.set_ylim(-1.1, 1.1)
    ax.axhline(0, color=GRID, linewidth=0.8)

    # Lateral throttle
    ax = fig.add_subplot(gs[1, 2])
    _style_ax(ax, "LATERAL THRUSTERS", "Step", "Throttle")
    lateral = ep.lateral_throttle
    ax.fill_between(steps, 0, lateral, color=PURPLE, alpha=0.3)
    ax.plot(steps, lateral, color=PURPLE, linewidth=1.2)
    ax.set_ylim(-1.1, 1.1)
    ax.axhline(0, color=GRID, linewidth=0.8)

    fig.suptitle(
        f"EPISODE {ep.episode_idx+1} DEEP-DIVE \u00b7 "
        f"{ep.status_emoji} {ep.status} \u00b7 Score: {ep.total_reward:+.1f} \u00b7 "
        f"{len(ep.steps)} steps",
        color=TEXT, fontsize=11, fontfamily="monospace", y=0.95,
    )
    return fig
190
+
191
+
192
def training_dashboard(state) -> plt.Figure:
    """Build the live training figure: rewards, losses and entropy coefficient.

    ``state`` is read while a training thread may still be appending to it,
    so every series is copied first and the logged series are trimmed to a
    common length before plotting.

    Args:
        state: Object exposing ``episode_rewards``, ``log_steps``,
            ``actor_losses``, ``critic_losses``, ``ent_coefs``, ``timestep``,
            ``total_timesteps`` and ``best_reward``.

    Returns:
        A new matplotlib figure; the caller is responsible for closing it.
    """
    # Snapshot first: the producer appends to these lists one at a time, so
    # reading them live could give x/y series of different lengths.
    rewards = list(state.episode_rewards)
    log_steps = list(state.log_steps)
    actor = list(state.actor_losses)
    critic = list(state.critic_losses)
    ent = list(state.ent_coefs)
    n_logged = min(len(log_steps), len(actor), len(critic), len(ent))
    log_steps, actor, critic, ent = (
        series[:n_logged] for series in (log_steps, actor, critic, ent)
    )

    fig = plt.figure(figsize=(14, 5), facecolor=BG)
    gs = gridspec.GridSpec(1, 3, figure=fig, wspace=0.38,
                           left=0.06, right=0.97, top=0.85, bottom=0.12)

    # Reward curve
    ax = fig.add_subplot(gs[0])
    _style_ax(ax, "EPISODE REWARD", "Episode", "Reward")
    if rewards:
        eps = list(range(len(rewards)))
        ax.plot(eps, rewards, color=ACCENT, linewidth=0.8, alpha=0.4)
        if len(eps) > 20:
            # Moving average; window grows with the number of episodes.
            k = max(5, len(eps) // 30)
            smooth = np.convolve(rewards, np.ones(k) / k, "valid")
            ax.plot(range(k - 1, len(eps)), smooth, color=ACCENT, linewidth=2)
    ax.axhline(200, color=GREEN, linestyle="--", linewidth=1, alpha=0.5)
    ax.axhline(150, color=AMBER, linestyle="--", linewidth=1, alpha=0.5)

    # Losses
    ax2 = fig.add_subplot(gs[1])
    _style_ax(ax2, "ACTOR / CRITIC LOSS", "Log Step", "Loss")
    if log_steps:
        ax2.plot(log_steps, actor, color=ACCENT, linewidth=1.5, label="Actor")
        ax2.plot(log_steps, critic, color=AMBER, linewidth=1.5, label="Critic")
        ax2.legend(fontsize=7, facecolor=BG2, edgecolor=GRID, labelcolor=DIM)

    # Entropy coef
    ax3 = fig.add_subplot(gs[2])
    _style_ax(ax3, "ENTROPY COEFFICIENT", "Log Step", "\u03b1")
    if log_steps:
        ax3.plot(log_steps, ent, color=PURPLE, linewidth=1.5)
    ax3.axhline(0, color=GRID, linewidth=0.8, linestyle="--")

    n_ep = len(rewards)
    best = state.best_reward
    # best_reward starts at -inf; show a placeholder until a real value exists.
    best_txt = f"{best:+.1f}" if np.isfinite(best) else "n/a"
    fig.suptitle(
        f"SAC TRAINING \u00b7 {state.timestep:,}/{state.total_timesteps:,} steps \u00b7 "
        f"{n_ep} episodes \u00b7 Best: {best_txt}",
        color=TEXT, fontsize=10, fontfamily="monospace",
    )
    return fig
236
+
237
+
238
def empty_figure(message: str = "Run a mission to see charts.") -> plt.Figure:
    """Return a themed placeholder figure showing *message* centred on it."""
    figure, axes = plt.subplots(figsize=(12, 5), facecolor=BG)
    figure.patch.set_facecolor(BG)
    axes.set_facecolor(BG2)

    # Axes-relative coordinates keep the text centred at any figure size.
    text_style = dict(ha="center", va="center", color=DIM,
                      fontsize=13, fontfamily="monospace")
    axes.text(0.5, 0.5, message, transform=axes.transAxes, **text_style)

    axes.axis("off")
    return figure
viz/replay.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Animated GIF generation from raw RGB frames.
3
+ Adds HUD overlay (step, reward, throttle bars) using PIL drawing — no matplotlib overhead.
4
+ """
5
+
6
+ from __future__ import annotations
7
+ import tempfile
8
+ import numpy as np
9
+ import PIL.Image
10
+ import PIL.ImageDraw
11
+ import PIL.ImageFont
12
+
13
+ from core.mission import EpisodeResult
14
+
15
+
16
+ # ── HUD rendering ─────────────────────────────────────────────────────────────
17
+
18
def _draw_hud(
    img: PIL.Image.Image,
    step: int,
    cumulative_reward: float,
    main_throttle: float,
    lateral_throttle: float,
    status: str,
) -> PIL.Image.Image:
    """Draw the HUD onto *img* in place and return it.

    The HUD is a top bar with step counter, cumulative reward and status,
    plus two throttle bars along the bottom edge.

    Args:
        img: Frame to draw on (modified in place).
        step: Step number shown in the top bar.
        cumulative_reward: Reward so far; green when >= 0, red otherwise.
        main_throttle: Main-engine action. Only the positive part fills the
            bar, because LunarLander's continuous main engine is off for
            negative values.
        lateral_throttle: Lateral action. The bar shows its magnitude and is
            amber for non-negative values, red for negative ones.
        status: Text shown at the right of the top bar.
    """
    draw = PIL.ImageDraw.Draw(img)
    W, H = img.size

    blue = (79, 179, 255)
    amber = (245, 166, 35)
    red = (255, 77, 109)
    green = (45, 219, 124)
    track = (13, 37, 64)

    # Top bar. The fill carries an alpha component, but it can only blend on
    # images with an alpha channel; make_episode_gif passes RGB frames, so
    # there the bar is drawn opaque.
    draw.rectangle([(0, 0), (W, 22)], fill=(3, 11, 26, 200))

    # Step & reward text
    draw.text((6, 4), f"STEP {step:03d}", fill=blue, font=None)
    rcolor = green if cumulative_reward >= 0 else red
    draw.text((W // 2 - 40, 4), f"REWARD {cumulative_reward:+.1f}",
              fill=rcolor, font=None)
    draw.text((W - 80, 4), status, fill=amber, font=None)

    # Throttle bars at bottom
    BAR_H = 6
    BAR_Y = H - BAR_H - 4
    bar_max = W // 2 - 20

    # Main engine bar (blue): clamp to [0, 1] so "engine off" reads as empty
    # and the fill never overruns its track.
    main_fraction = min(max(main_throttle, 0.0), 1.0)
    bar_w = int(main_fraction * bar_max)
    draw.rectangle([(10, BAR_Y), (10 + bar_max, BAR_Y + BAR_H)], fill=track)
    draw.rectangle([(10, BAR_Y), (10 + bar_w, BAR_Y + BAR_H)], fill=blue)
    draw.text((10, BAR_Y - 11), "MAIN", fill=blue, font=None)

    # Lateral bar: magnitude as length, sign as colour.
    lx = W // 2 + 10
    lat_w = int(min(abs(lateral_throttle), 1.0) * bar_max)
    draw.rectangle([(lx, BAR_Y), (lx + bar_max, BAR_Y + BAR_H)], fill=track)
    col = amber if lateral_throttle >= 0 else red
    draw.rectangle([(lx, BAR_Y), (lx + lat_w, BAR_Y + BAR_H)], fill=col)
    draw.text((lx, BAR_Y - 11), "LATERAL", fill=amber, font=None)

    return img
61
+
62
+
63
def make_episode_gif(
    frames: list[np.ndarray],
    episode: EpisodeResult,
    fps: int = 15,
) -> str:
    """Overlay the HUD on every frame and save the result as an animated GIF.

    Args:
        frames: Rendered frames of the episode, one per step.
        episode: Telemetry for the same episode. When there are more frames
            than steps, the last step's values are reused.
        fps: Playback rate; converted to a per-frame duration in ms.

    Returns:
        Path of a temporary ``.gif`` file (not deleted automatically), or
        ``""`` when there is nothing to render.
    """
    if not frames or not episode.steps:
        return ""

    cum_rewards = episode.cumulative_rewards
    last_idx = len(episode.steps) - 1
    pil_frames: list[PIL.Image.Image] = []

    for i, frame in enumerate(frames):
        k = min(i, last_idx)  # clamp to the last recorded step
        step_data = episode.steps[k]
        img = _draw_hud(
            PIL.Image.fromarray(frame).convert("RGB"),
            step=i + 1,
            cumulative_reward=cum_rewards[k],
            main_throttle=step_data.action_main,
            lateral_throttle=step_data.action_lateral,
            status=episode.status,
        )
        pil_frames.append(img)

    # Reserve a unique path, then close the handle before PIL writes to it.
    # This avoids leaking the descriptor and works on platforms that refuse
    # a second open of a file that is still open.
    with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp:
        gif_path = tmp.name

    pil_frames[0].save(
        gif_path,
        save_all=True,
        append_images=pil_frames[1:],
        duration=int(1000 / max(fps, 1)),
        loop=0,
        optimize=False,
    )
    return gif_path
99
+
100
+
101
def make_comparison_gif(
    all_frames: list[list[np.ndarray]],
    episodes: list[EpisodeResult],
    fps: int = 12,
    max_episodes: int = 4,
) -> str:
    """Build a side-by-side grid GIF comparing up to `max_episodes` episodes.

    Shorter episodes are padded with their last frame so every cell lasts as
    long as the longest episode. Each cell is labelled with the episode
    number, status and total reward.

    Args:
        all_frames: One frame list per episode.
        episodes: Telemetry matching ``all_frames`` by index.
        fps: Playback rate; converted to a per-frame duration in ms.
        max_episodes: Upper bound on the number of grid cells.

    Returns:
        Path of a temporary ``.gif`` file (not deleted automatically), or
        ``""`` when there is nothing to render.
    """
    # Only compare episodes that have both frames and telemetry.
    n = min(len(all_frames), len(episodes), max_episodes)
    if n <= 0:
        return ""

    frame_lists = all_frames[:n]
    ep_list = episodes[:n]

    max_len = max(len(fl) for fl in frame_lists)
    # Pad each episode to max_len by repeating its final frame.
    padded = [fl + [fl[-1]] * (max_len - len(fl)) if fl else [] for fl in frame_lists]

    if not padded[0]:
        return ""

    # Cell size comes from the first episode's first frame.
    h, w = padded[0][0].shape[:2]
    cols = 2 if n > 2 else n
    rows = (n + cols - 1) // cols
    grid_w, grid_h = cols * w, rows * h

    pil_frames: list[PIL.Image.Image] = []
    for step_i in range(max_len):
        canvas = PIL.Image.new("RGB", (grid_w, grid_h), (3, 11, 26))
        for ep_i, (ep_frames, ep) in enumerate(zip(padded, ep_list)):
            if step_i >= len(ep_frames):
                continue  # episode without frames: leave its cell blank
            cell = PIL.Image.fromarray(ep_frames[step_i])
            # Label strip across the top of the cell.
            draw = PIL.ImageDraw.Draw(cell)
            draw.rectangle([(0, 0), (cell.width, 16)], fill=(3, 11, 26))
            col = (45, 219, 124) if ep.total_reward >= 150 else (255, 77, 109)
            draw.text((4, 2),
                      f"#{ep.episode_idx+1} {ep.status} {ep.total_reward:+.0f}",
                      fill=col, font=None)
            canvas.paste(cell, ((ep_i % cols) * w, (ep_i // cols) * h))
        pil_frames.append(canvas)

    # Reserve a unique path, then close the handle before PIL writes to it so
    # the descriptor is not leaked.
    with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp:
        gif_path = tmp.name

    pil_frames[0].save(
        gif_path,
        save_all=True,
        append_images=pil_frames[1:],
        duration=int(1000 / max(fps, 1)),
        loop=0,
        optimize=False,
    )
    return gif_path