Spaces:
Sleeping
Sleeping
| import math | |
| import time | |
| from typing import Tuple | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as patches | |
| import numpy as np | |
| import pandas as pd | |
| # ============================================================ | |
| # Robot Learning to Walk with Reinforcement Learning | |
| # Gradio visual simulation / teaching dashboard | |
| # ------------------------------------------------------------ | |
| # This is an educational visualization, not a real MuJoCo PPO | |
| # training run. It simulates the visible behavior of RL training: | |
| # 1. Early episodes: robot falls | |
| # 2. Middle episodes: robot struggles and becomes unstable | |
| # 3. Later episodes: robot improves balance | |
| # 4. Final episodes: robot develops a stable walking gait | |
| # ============================================================ | |
| MAX_EPISODE = 25_000 | |
| # ----------------------------- | |
| # Utility functions | |
| # ----------------------------- | |
def clamp(value, low, high):
    """Limit ``value`` to the closed interval ``[low, high]``.

    Anything above ``high`` comes back as ``high``, anything below ``low``
    comes back as ``low``, and values in between are returned unchanged.
    """
    capped = value if value < high else high
    return capped if capped > low else low
def training_progress(episode: int, max_episode: int = MAX_EPISODE) -> float:
    """Return normalized training progress from 0.0 to 1.0.

    Args:
        episode: Current episode number. Values outside ``[0, max_episode]``
            are clamped to the ends of the range.
        max_episode: Episode count that corresponds to full progress.
            Must be positive.

    Returns:
        ``episode / max_episode`` clamped to ``[0.0, 1.0]``, always as a float.

    Raises:
        ValueError: If ``max_episode`` is zero or negative.
    """
    if max_episode <= 0:
        raise ValueError(f"max_episode must be positive, got {max_episode!r}")
    # clamp() hands back the integer bound itself when the ratio hits 0 or 1;
    # convert so the declared float return type always holds.
    return float(clamp(episode / max_episode, 0, 1))
def stage_from_progress(p: float) -> Tuple[str, str]:
    """Look up the learning stage for a normalized progress value.

    Returns a ``(stage name, description)`` pair. Each stage covers progress
    values strictly below its upper bound; anything at or above the last
    bound (0.85) is reported as the final "Natural Walk" stage.
    """
    bounded_stages = (
        (0.08, "Falling", "The policy has not learned balance yet."),
        (0.22, "Struggling", "The robot discovers pushing and limb movement."),
        (0.42, "Unstable", "The robot can stand briefly but tips often."),
        (0.65, "Improving Balance", "The policy begins coordinating feet and torso."),
        (0.85, "Mostly Stable", "The robot walks forward with occasional wobble."),
    )
    for upper_bound, name, summary in bounded_stages:
        if p < upper_bound:
            return name, summary
    return "Natural Walk", "The policy has learned a smooth, efficient gait."
def simulated_reward_curve(max_episode=MAX_EPISODE, points=350, seed=42):
    """Build a synthetic reward history: a rising logistic trend plus noise.

    Returns ``(episodes, reward)`` arrays of length ``points``. The noise is
    Gaussian with a standard deviation that shrinks linearly from 115 at the
    first episode to 30 at ``max_episode``; a fixed ``seed`` makes the curve
    reproducible.
    """
    generator = np.random.default_rng(seed)
    episodes = np.linspace(0, max_episode, points)
    fraction = episodes / max_episode

    # Logistic trend running from about -300 up towards about +950.
    sigmoid_denominator = 1 + np.exp(-7.2 * (fraction - 0.47))
    trend = -300 + 1250 / sigmoid_denominator

    noise_scale = 85 * (1 - fraction) + 30
    jitter = generator.normal(0, noise_scale, size=points)
    return episodes, trend + jitter
def simulated_stability_curve(max_episode=MAX_EPISODE, points=350, seed=7):
    """Build a synthetic stability history that rises over training.

    Returns ``(episodes, stability)`` arrays of length ``points``. The values
    follow a logistic trend with Gaussian noise that shrinks as training
    progresses, and are clipped to ``[0, 1]``. A fixed ``seed`` makes the
    curve reproducible.
    """
    generator = np.random.default_rng(seed)
    episodes = np.linspace(0, max_episode, points)
    fraction = episodes / max_episode

    sigmoid_denominator = 1 + np.exp(-7.5 * (fraction - 0.35))
    trend = 0.08 + 0.87 / sigmoid_denominator

    noise_scale = 0.08 * (1 - fraction) + 0.015
    jitter = generator.normal(0, noise_scale, size=points)
    return episodes, np.clip(trend + jitter, 0, 1)
def current_value_from_curve(x, y, episode):
    """Return the curve's value at ``episode`` using linear interpolation.

    The curves are drawn as straight segments between samples, so
    interpolating (instead of snapping to the next sample) keeps the
    "current episode" marker exactly on the plotted line.

    Args:
        x: Sample positions, sorted in ascending order.
        y: Sample values, same length as ``x``.
        episode: Position to evaluate. Positions outside the sampled range
            yield the first or last sample value respectively.

    Returns:
        The interpolated value as a plain ``float``.

    Raises:
        ValueError: If the curve has no samples.
    """
    if len(x) == 0 or len(y) == 0:
        raise ValueError("curve must contain at least one sample")
    return float(np.interp(episode, x, y))
# Precompute both curves once at import time. The default seeds make them
# identical on every run, and each dashboard refresh reads from these arrays
# instead of regenerating the data.
REWARD_X, REWARD_Y = simulated_reward_curve()
STABILITY_X, STABILITY_Y = simulated_stability_curve()
| # ----------------------------- | |
| # Drawing functions | |
| # ----------------------------- | |
def draw_floor_grid(ax, x_min=-3, x_max=3, y_min=-0.1, y_max=2.4):
    """Paint the dark backdrop, a faint grid and the ground line on ``ax``.

    The grid has 13 vertical and 8 horizontal lines spanning the given
    bounds; the ground is a brighter horizontal line at ``y = 0``.
    """
    grid_style = dict(color="#26384f", linewidth=0.6)

    ax.set_facecolor("#101b2a")
    for column in np.linspace(x_min, x_max, 13):
        ax.plot([column, column], [y_min, y_max], **grid_style)
    for row in np.linspace(y_min, y_max, 8):
        ax.plot([x_min, x_max], [row, row], **grid_style)
    ax.axhline(0, color="#93a4b8", linewidth=2)
def limb(ax, x1, y1, x2, y2, color="#e8eef7", lw=7):
    """Draw one rounded limb segment from ``(x1, y1)`` to ``(x2, y2)``.

    The segment is drawn twice: first a slightly wider dark stroke that acts
    as an outline, then the coloured stroke on top of it.
    """
    xs = [x1, x2]
    ys = [y1, y2]
    ax.plot(xs, ys, color="#1f2937", linewidth=lw + 3, solid_capstyle="round", alpha=0.9)
    ax.plot(xs, ys, color=color, linewidth=lw, solid_capstyle="round")
def joint(ax, x, y, radius=0.045, color="#4ea1ff"):
    """Add a filled circular joint marker centred at ``(x, y)`` to ``ax``."""
    ax.add_patch(
        patches.Circle(
            (x, y),
            radius,
            facecolor=color,
            edgecolor="#dbeafe",
            linewidth=1.2,
            zorder=5,
        )
    )
def draw_robot(ax, progress, phase=0, x_offset=0):
    """
    Render the stick-figure humanoid on ``ax`` in a pose chosen by training progress.

    progress: 0 to 1; the stage returned by stage_from_progress picks the pose
    phase: gait phase in radians; only the upright (walking) pose uses it
    x_offset: horizontal shift applied to the whole figure

    The "Falling" and "Struggling" stages use fixed poses. Every later stage
    uses a parametric upright pose whose height and stride grow with progress
    while the side-to-side wobble shrinks.
    """
    stage, _ = stage_from_progress(progress)

    sway = (1 - progress) * 0.35 * math.sin(phase * 1.8)
    stride = 0.15 + progress * 0.28
    lean = sway * 0.8
    hip_height = 0.75 + progress * 0.55

    def at(dx, y):
        # Point placed dx to the side of the robot's horizontal origin.
        return np.array([x_offset + dx, y])

    if stage == "Falling":
        # Collapsed on the ground.
        hip = at(-0.15, 0.18)
        chest = at(0.2, 0.33)
        head = at(0.42, 0.38)
        left_knee = at(-0.65, 0.12)
        left_foot = at(-1.0, 0.03)
        right_knee = at(0.0, 0.05)
        right_foot = at(0.35, 0.02)
        left_elbow = at(0.10, 0.12)
        left_hand = at(-0.25, 0.02)
        right_elbow = at(0.48, 0.18)
        right_hand = at(0.72, 0.03)
    elif stage == "Struggling":
        # Crouched with both hands on the floor.
        hip = at(-0.1, 0.45)
        chest = at(0.10, 0.75)
        head = at(0.20, 0.96)
        left_knee = at(-0.45, 0.22)
        left_foot = at(-0.78, 0.03)
        right_knee = at(0.25, 0.25)
        right_foot = at(0.55, 0.03)
        left_elbow = at(-0.20, 0.42)
        left_hand = at(-0.55, 0.03)
        right_elbow = at(0.42, 0.48)
        right_hand = at(0.70, 0.03)
    else:
        # Upright pose: the torso leans with the wobble, legs and arms swing
        # in opposite phase.
        torso_x = x_offset + lean
        hip = np.array([torso_x, hip_height])
        chest = np.array([torso_x + lean * 0.15, hip_height + 0.55])
        head = np.array([torso_x + lean * 0.25, hip_height + 0.83])

        left_leg_swing = math.sin(phase)
        right_leg_swing = math.sin(phase + math.pi)
        left_arm_swing = math.sin(phase + math.pi)
        right_arm_swing = math.sin(phase)

        left_foot = np.array([x_offset - 0.25 + stride * left_leg_swing, 0.03])
        right_foot = np.array([x_offset + 0.25 + stride * right_leg_swing, 0.03])

        knee_lift = 0.18 + 0.07 * progress
        left_knee = (hip + left_foot) / 2 + np.array([0.05 * math.cos(phase), knee_lift])
        right_knee = (hip + right_foot) / 2 + np.array([0.05 * math.cos(phase + math.pi), knee_lift])

        left_shoulder = chest + np.array([-0.18, -0.03])
        right_shoulder = chest + np.array([0.18, -0.03])
        left_hand = left_shoulder + np.array([0.18 * left_arm_swing, -0.55])
        right_hand = right_shoulder + np.array([0.18 * right_arm_swing, -0.55])
        left_elbow = (left_shoulder + left_hand) / 2 + np.array([-0.05, -0.02])
        right_elbow = (right_shoulder + right_hand) / 2 + np.array([0.05, -0.02])

    if stage in ("Falling", "Struggling"):
        # The fixed poses use slightly narrower shoulders.
        left_shoulder = chest + np.array([-0.15, -0.03])
        right_shoulder = chest + np.array([0.15, -0.03])

    # Legs first, then the torso, then the arms, so later strokes overlap
    # earlier ones in the same order every time.
    leg_segments = (
        (hip, left_knee),
        (left_knee, left_foot),
        (hip, right_knee),
        (right_knee, right_foot),
    )
    for start, end in leg_segments:
        limb(ax, start[0], start[1], end[0], end[1])

    limb(ax, hip[0], hip[1], chest[0], chest[1], color="#f8fafc", lw=9)

    arm_segments = (
        (left_shoulder, left_elbow),
        (left_elbow, left_hand),
        (right_shoulder, right_elbow),
        (right_elbow, right_hand),
    )
    for start, end in arm_segments:
        limb(ax, start[0], start[1], end[0], end[1], lw=6)

    # Head with a visor.
    ax.add_patch(
        patches.Circle((head[0], head[1]), 0.14, facecolor="#e5e7eb", edgecolor="#111827", linewidth=2, zorder=6)
    )
    ax.add_patch(
        patches.Ellipse((head[0] + 0.03, head[1] + 0.02), 0.15, 0.06, facecolor="#0f172a", edgecolor="#38bdf8", linewidth=1, zorder=7)
    )

    marked_points = (
        hip, chest,
        left_knee, right_knee,
        left_foot, right_foot,
        left_elbow, right_elbow,
        left_hand, right_hand,
    )
    for point in marked_points:
        joint(ax, point[0], point[1], radius=0.035)

    # "COM" marker just above the hip, plus a green bar on the floor under the robot.
    com_y = hip[1] + 0.18
    ax.scatter([hip[0]], [com_y], s=90, color="#54e26f", edgecolor="white", zorder=9)
    ax.text(hip[0] + 0.08, com_y, "COM", color="#54e26f", fontsize=8, va="center")
    ax.plot([x_offset - 0.45, x_offset + 0.45], [0.015, 0.015], color="#54e26f", linewidth=3, alpha=0.8)
def robot_figure(progress, episode, phase):
    """Build the "live simulation" figure for one dashboard refresh.

    Args:
        progress: Normalized training progress (0 to 1); selects the robot
            pose and the stage caption.
        episode: Episode number shown in the headline.
        phase: Gait phase in radians, forwarded to ``draw_robot``.

    Returns:
        A Matplotlib ``Figure`` holding the scene, captions and the
        "Observation State" box. The numbers in that box are computed
        directly from ``progress``, not read from the simulated curves.
    """
    # Build the figure without going through pyplot's global registry.
    # This function runs on every timer tick; figures created through
    # ``plt.subplots`` stay referenced by pyplot until explicitly closed and
    # would pile up in memory. It also avoids ``plt.tight_layout()`` acting
    # on whichever figure happens to be "current" in another handler.
    fig = plt.Figure(figsize=(9, 5))
    ax = fig.subplots()
    fig.patch.set_facecolor("#07111f")

    draw_floor_grid(ax)
    draw_robot(ax, progress, phase=phase)

    stage, description = stage_from_progress(progress)
    ax.text(-2.85, 2.25, f"LIVE SIMULATION | Episode {episode:,}", color="#4ea1ff", fontsize=13, fontweight="bold")
    ax.text(-2.85, 2.08, f"Stage: {stage}", color="#ffffff", fontsize=12, fontweight="bold")
    ax.text(-2.85, 1.93, description, color="#b8c7d9", fontsize=10)
    ax.text(
        1.15,
        2.10,
        "Observation State\n"
        f"Base height: {0.45 + progress * 0.75:.2f} m\n"
        f"Velocity: {progress * 1.2:.2f} m/s\n"
        f"Stability: {progress:.2f}\n"
        f"Wobble: {(1-progress)*100:.1f}%",
        color="#dbeafe",
        fontsize=9,
        bbox=dict(facecolor="#0c1b2d", edgecolor="#24425f", boxstyle="round,pad=0.55", alpha=0.95),
    )

    ax.set_xlim(-3, 3)
    ax.set_ylim(-0.1, 2.4)
    ax.axis("off")
    fig.tight_layout()
    return fig
def progression_figure():
    """Build the static strip showing one robot pose per learning stage.

    Each panel draws the robot at a representative progress value. The
    episode caption is derived from that progress value and ``MAX_EPISODE``
    so it matches what the episode slider shows for the same stage.

    Returns:
        A Matplotlib ``Figure`` with one panel per checkpoint.
    """
    # (representative progress, panel number, stage label)
    checkpoints = [
        (0.02, "1", "FALLING"),
        (0.14, "2", "STRUGGLING"),
        (0.32, "3", "UNSTABLE"),
        (0.55, "4", "IMPROVING BALANCE"),
        (0.75, "5", "MOSTLY STABLE"),
        (0.94, "6", "NATURAL WALK"),
    ]

    # Created without pyplot so the figure is not kept alive in pyplot's
    # global registry and layout is applied to this figure explicitly.
    fig = plt.Figure(figsize=(15, 2.7))
    axes = fig.subplots(1, len(checkpoints))
    fig.patch.set_facecolor("#07111f")

    for ax, (p, number, label) in zip(axes, checkpoints):
        draw_floor_grid(ax, x_min=-1.4, x_max=1.4, y_min=-0.05, y_max=1.8)
        draw_robot(ax, p, phase=1.2)
        ax.set_xlim(-1.35, 1.35)
        ax.set_ylim(-0.05, 1.85)
        ax.axis("off")
        ep_label = f"Episode ~{round(p * MAX_EPISODE):,}"
        ax.set_title(f"{number}. {label}\n{ep_label}", color="#ffffff", fontsize=9, fontweight="bold")

    fig.tight_layout()
    return fig
def plot_training_curve(x, y, episode, title, ylabel):
    """Plot a training curve with a marker at the current episode.

    Args:
        x: Episode positions of the curve samples.
        y: Curve values, same length as ``x``.
        episode: Current episode; a highlighted marker is drawn there.
        title: Plot title. Titles containing "Reward" are drawn in green,
            all others in blue.
        ylabel: Label for the vertical axis.

    Returns:
        A Matplotlib ``Figure`` styled for the dark dashboard theme.
    """
    # Built without pyplot: this runs twice per refresh, and pyplot would keep
    # every figure alive in its global registry until explicitly closed.
    fig = plt.Figure(figsize=(5.2, 3.2))
    ax = fig.subplots()
    fig.patch.set_facecolor("#07111f")
    ax.set_facecolor("#0c1b2d")

    ax.plot(x, y, linewidth=1.5, color="#54e26f" if "Reward" in title else "#4ea1ff")
    current = current_value_from_curve(x, y, episode)
    ax.scatter([episode], [current], s=80, zorder=4, color="#ffffff", edgecolor="#f59e0b")

    ax.set_title(title, color="white", fontsize=12, fontweight="bold")
    ax.set_xlabel("Episode", color="#dbeafe")
    ax.set_ylabel(ylabel, color="#dbeafe")
    ax.grid(True, alpha=0.25)
    ax.tick_params(colors="#dbeafe")
    for spine in ax.spines.values():
        spine.set_color("#24425f")

    fig.tight_layout()
    return fig
| # ----------------------------- | |
| # Simulation panels | |
| # ----------------------------- | |
def reward_function_dataframe(progress):
    """Tabulate the illustrative reward terms for the given training progress.

    Bonus terms grow with ``progress`` and penalty terms shrink with it. The
    last row holds the net reward rescaled to 0-1000. Returns a DataFrame
    with the columns "Reward Component", "Weight", "Purpose" and
    "Current Contribution".
    """
    remaining = 1 - progress

    # Terms that reward good behaviour.
    forward_velocity = progress * 1.0
    upright_bonus = progress * 0.5
    smooth_gait = progress * 0.2

    # Terms that penalize bad behaviour.
    torque_cost = remaining * 0.18
    orientation_penalty = remaining * 0.35
    joint_penalty = remaining * 0.15
    action_penalty = remaining * 0.08

    total = (
        forward_velocity + upright_bonus + smooth_gait
        - torque_cost - orientation_penalty - joint_penalty - action_penalty
    )
    normalized = clamp((total + 0.7) / 2.4, 0, 1) * 1000

    column_names = ["Reward Component", "Weight", "Purpose", "Current Contribution"]
    table_rows = [
        ["+ Forward velocity", 1.0, "Encourage moving forward", forward_velocity],
        ["+ Upright bonus", 0.5, "Encourage staying upright", upright_bonus],
        ["+ Smooth gait", 0.2, "Encourage coordinated walking", smooth_gait],
        ["− Torque cost", 0.001, "Discourage wasted energy", -torque_cost],
        ["− Orientation penalty", 0.5, "Penalize tipping over", -orientation_penalty],
        ["− Joint limit penalty", 1.0, "Avoid unnatural joint limits", -joint_penalty],
        ["− Action rate penalty", 0.001, "Encourage smooth actions", -action_penalty],
        ["TOTAL normalized reward", "", "Scaled to 0–1000", normalized],
    ]
    return pd.DataFrame(table_rows, columns=column_names)
def metric_html(episode, progress, reward, stability, algorithm, learning_rate, gamma):
    """Render the top row of metric cards and the training-status line as HTML.

    Args:
        episode: Current episode; formatted with thousands separators and
            multiplied by 8192 for the "Total Timesteps" figure.
        progress: Normalized progress, used only to look up the stage name
            and description.
        reward: Value shown on the "Average Reward" card (one decimal place).
        stability: Value shown on the "Stability" card (two decimal places).
        algorithm: Text inserted after "Algorithm =".
        learning_rate: Text inserted after "Learning Rate =".
        gamma: Text inserted after "Discount Factor γ =".

    Returns:
        An HTML fragment whose class names are styled by ``CUSTOM_CSS``.
    """
    stage, stage_desc = stage_from_progress(progress)
    return f"""
<div class="top-grid">
<div class="metric-card">
<div class="small-label">Status</div>
<div class="green-text">● TRAINING</div>
<div class="small-label">Environment: Humanoid-v1</div>
<div class="small-label">Simulator: Teaching Demo</div>
</div>
<div class="metric-card">
<div class="small-label">Episode</div>
<div class="blue-text big-number">{episode:,}</div>
<div class="small-label">of {MAX_EPISODE:,}</div>
</div>
<div class="metric-card">
<div class="small-label">Learning Stage</div>
<div class="orange-text stage-title">{stage}</div>
<div class="small-label">{stage_desc}</div>
</div>
<div class="metric-card">
<div class="small-label">Average Reward</div>
<div class="green-text big-number">{reward:,.1f}</div>
<div class="small-label">Higher is better</div>
</div>
<div class="metric-card">
<div class="small-label">Stability</div>
<div class="blue-text big-number">{stability:.2f}</div>
<div class="small-label">Target: 0.90+</div>
</div>
</div>
<div class="status-box">
<b>Training Status:</b>
Algorithm = {algorithm} | Policy Network = MLP, 3x256 | Learning Rate = {learning_rate} |
Discount Factor γ = {gamma} | Total Timesteps ≈ {episode * 8192:,}
</div>
"""
def rl_loop_html():
    """Return the static HTML for the state -> action -> reward -> policy-update diagram.

    The fragment takes no inputs; its class names are styled by ``CUSTOM_CSS``.
    """
    return """
<div class="rl-loop">
<div class="loop-box"><b>STATE</b><br><span>observe robot + world</span><br><code>s_t</code></div>
<div class="arrow">→</div>
<div class="loop-box"><b>ACTION</b><br><span>policy chooses motors</span><br><code>a_t</code></div>
<div class="arrow">→</div>
<div class="loop-box reward"><b>REWARD</b><br><span>score movement</span><br><code>r_t</code></div>
<div class="arrow">→</div>
<div class="loop-box update"><b>POLICY UPDATE</b><br><span>improve network</span><br><code>θ_{t+1}</code></div>
</div>
<div class="status-box">
The robot is not directly programmed with a walking motion. It learns by trying actions, receiving rewards or penalties,
and updating its policy across many simulated episodes.
</div>
"""
def pseudocode_text():
    """Return the RL training-loop pseudocode shown in the code panel.

    The text is display-only and is never executed. Block indentation is part
    of the string so the nesting (episode loop -> step loop -> policy update
    once per episode) is visible to the reader.
    """
    return """
initialize policy_network θ
for episode in range(total_episodes):
    state = environment.reset()
    done = False
    while not done:
        action = policy_network.select_action(state)
        next_state, reward, done, info = environment.step(action)
        store_transition(state, action, reward, next_state, done)
        state = next_state
    update policy_network θ using PPO/SAC/A2C
# Over many episodes:
# falling actions receive poor reward
# balanced walking actions receive higher reward
# the policy slowly learns a stable gait
"""
def update_dashboard(episode, animation_frame, algorithm, learning_rate, gamma):
    """Recompute every dashboard panel for the given control values.

    ``animation_frame`` (0-100) is mapped onto one full gait cycle in
    radians. Returns, in order: metrics HTML, live-simulation figure, reward
    plot, stability plot and the reward-breakdown DataFrame.
    """
    current_episode = int(episode)
    fraction = training_progress(current_episode)
    gait_phase = animation_frame / 100 * 2 * math.pi

    reward_now = current_value_from_curve(REWARD_X, REWARD_Y, current_episode)
    stability_now = current_value_from_curve(STABILITY_X, STABILITY_Y, current_episode)

    return (
        metric_html(current_episode, fraction, reward_now, stability_now, algorithm, learning_rate, gamma),
        robot_figure(fraction, current_episode, gait_phase),
        plot_training_curve(REWARD_X, REWARD_Y, current_episode, "Reward per Episode", "Average reward"),
        plot_training_curve(STABILITY_X, STABILITY_Y, current_episode, "Stability / Error Over Time", "Stability"),
        reward_function_dataframe(fraction),
    )
def jump_to_stage(stage_name):
    """Return a representative episode number for the named learning stage.

    Unrecognized names fall back to episode 15,680.
    """
    representative_episode = {
        "Falling": 0,
        "Struggling": 2_000,
        "Unstable": 6_000,
        "Improving Balance": 12_500,
        "Mostly Stable": 18_750,
        "Natural Walk": 24_000,
    }
    try:
        return representative_episode[stage_name]
    except KeyError:
        return 15_680
def live_training_step(episode, animation_frame, speed, algorithm, learning_rate, gamma):
    """Advance the training simulation by one live timer step.

    The episode counter grows by ``speed`` but is capped at ``MAX_EPISODE``
    so the fully trained state is always displayed for one tick; the tick
    after that restarts the run from episode 0. The animation frame advances
    by 7 and wraps within the 0-100 slider range.

    Returns:
        ``(next_episode, next_animation_frame, metrics, live, reward_plot,
        stability_plot, reward_df)`` - the two new slider values followed by
        the refreshed dashboard panels.
    """
    if episode >= MAX_EPISODE:
        # The final episode has already been shown; loop the demo.
        next_episode = 0
    else:
        # Cap instead of wrapping so an overshoot still lands on the last
        # episode rather than skipping straight back to the start.
        next_episode = min(int(episode + speed), MAX_EPISODE)

    next_animation_frame = int((animation_frame + 7) % 101)

    metrics, live, reward_plot, stability_plot, reward_df = update_dashboard(
        next_episode,
        next_animation_frame,
        algorithm,
        learning_rate,
        gamma,
    )
    return (
        next_episode,
        next_animation_frame,
        metrics,
        live,
        reward_plot,
        stability_plot,
        reward_df,
    )
def reset_training(algorithm, learning_rate, gamma):
    """Put the live simulation back at episode zero, animation frame zero.

    Returns the two reset slider values followed by the dashboard panels
    rendered for that starting state.
    """
    start_episode, start_frame = 0, 0
    panels = update_dashboard(start_episode, start_frame, algorithm, learning_rate, gamma)
    return (start_episode, start_frame, *panels)
| # ----------------------------- | |
| # Gradio app | |
| # ----------------------------- | |
# Dark-theme stylesheet handed to gr.Blocks(css=...). The class names here
# are the ones emitted by metric_html() (top-grid, metric-card, *-text,
# big-number, stage-title, status-box) and rl_loop_html() (rl-loop, loop-box,
# arrow).
CUSTOM_CSS = """
body, .gradio-container {
background: #07111f !important;
color: #e5e7eb !important;
}
.metric-card {
background: #0c1b2d;
border: 1px solid #24425f;
border-radius: 14px;
padding: 14px;
box-shadow: 0 0 12px rgba(40,120,255,.15);
min-height: 110px;
}
.top-grid {
display: grid;
grid-template-columns: repeat(5, minmax(140px, 1fr));
gap: 12px;
}
.small-label {
color: #b8c7d9;
font-size: 0.88rem;
}
.green-text {
color: #54e26f;
font-weight: 800;
}
.blue-text {
color: #4ea1ff;
font-weight: 800;
}
.orange-text {
color: #ff9f1c;
font-weight: 800;
}
.big-number {
font-size: 1.55rem;
}
.stage-title {
font-size: 1.15rem;
}
.status-box {
background: #0c1b2d;
border: 1px solid #24425f;
border-radius: 12px;
padding: 12px;
margin-top: 12px;
color: #dbeafe;
}
.rl-loop {
display: flex;
align-items: center;
justify-content: space-between;
gap: 8px;
flex-wrap: wrap;
background: #0c1b2d;
border: 1px solid #24425f;
border-radius: 14px;
padding: 16px;
}
.loop-box {
border: 1px solid #4ea1ff;
border-radius: 12px;
padding: 12px;
min-width: 150px;
text-align: center;
color: #dbeafe;
}
.loop-box b {
color: #4ea1ff;
}
.loop-box.reward {
border-color: #54e26f;
}
.loop-box.reward b {
color: #54e26f;
}
.loop-box.update {
border-color: #a855f7;
}
.loop-box.update b {
color: #c084fc;
}
.arrow {
font-size: 2rem;
color: #dbeafe;
}
code {
color: #ffffff;
background: #111827;
padding: 2px 5px;
border-radius: 5px;
}
.gr-button {
border-radius: 10px !important;
}
"""
# Build the UI and wire up the event handlers. Layout: a narrow control
# column on the left, the dashboard panels on the right.
with gr.Blocks(css=CUSTOM_CSS, title="RL Robot Walking Simulation") as demo:
    # Drives "live training" mode; inactive until the start button is pressed.
    live_timer = gr.Timer(value=0.5, active=False)

    gr.Markdown(
        """
        # 🤖 Robot Learning to Walk with Reinforcement Learning
        Training a humanoid robot in repeated simulation episodes until the policy discovers balance, coordination, and forward walking.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            episode_slider = gr.Slider(
                minimum=0,
                maximum=MAX_EPISODE,
                value=15_680,
                step=100,
                label="Training Episode",
            )
            animation_slider = gr.Slider(
                minimum=0,
                maximum=100,
                value=35,
                step=1,
                label="Walking Animation Frame",
            )
            algorithm_dropdown = gr.Dropdown(
                choices=["PPO", "SAC", "A2C", "DQN-style discrete demo"],
                value="PPO",
                label="Algorithm",
            )
            learning_rate_dropdown = gr.Dropdown(
                choices=["3e-4", "1e-4", "1e-3"],
                value="3e-4",
                label="Learning Rate",
            )
            gamma_dropdown = gr.Dropdown(
                choices=["0.99", "0.98", "0.95"],
                value="0.99",
                label="Discount Factor γ",
            )
            stage_dropdown = gr.Dropdown(
                choices=["Falling", "Struggling", "Unstable", "Improving Balance", "Mostly Stable", "Natural Walk"],
                value="Mostly Stable",
                label="Jump to Learning Stage",
            )
            jump_button = gr.Button("Jump to Stage")

            gr.Markdown("### Live Reinforcement Mode")
            speed_slider = gr.Slider(
                minimum=100,
                maximum=1500,
                value=500,
                step=100,
                label="Episodes Added per Tick",
            )
            with gr.Row():
                start_button = gr.Button("▶ Start Live Training", variant="primary")
                stop_button = gr.Button("⏸ Stop")
            reset_button = gr.Button("↺ Reset Training")

        with gr.Column(scale=4):
            metrics_html = gr.HTML()
            progression_plot = gr.Plot(value=progression_figure(), label="Learning Progression")
            with gr.Row():
                with gr.Column(scale=2):
                    live_plot = gr.Plot(label="Live Simulation")
                    gr.HTML(value=rl_loop_html())
                with gr.Column(scale=1):
                    reward_plot = gr.Plot(label="Reward per Episode")
                    stability_plot = gr.Plot(label="Stability Over Time")
            with gr.Row():
                with gr.Column(scale=1):
                    reward_table = gr.Dataframe(label="Reward Function: Weighted Sum", interactive=False)
                with gr.Column(scale=1):
                    gr.Code(value=pseudocode_text(), language="python", label="RL Pseudocode")

    gr.Markdown(
        """
        **Teaching note:** This app visualizes the logic of robot locomotion training.
        For a real physics-based version, connect the same dashboard idea to **Gymnasium + MuJoCo + Stable-Baselines3 PPO**.
        """
    )

    inputs = [episode_slider, animation_slider, algorithm_dropdown, learning_rate_dropdown, gamma_dropdown]
    outputs = [metrics_html, live_plot, reward_plot, stability_plot, reward_table]
    live_outputs = [episode_slider, animation_slider, metrics_html, live_plot, reward_plot, stability_plot, reward_table]

    # Refresh on *user* edits only. ``.change`` would also fire whenever the
    # timer, reset or jump handlers write new slider values, re-rendering
    # every panel a second and third time per tick. Each of those handlers
    # already refreshes the panels itself.
    for component in inputs:
        component.input(fn=update_dashboard, inputs=inputs, outputs=outputs)

    # Initial render when the page loads.
    demo.load(fn=update_dashboard, inputs=inputs, outputs=outputs)

    # Move the episode slider to the chosen stage, then redraw.
    jump_button.click(fn=jump_to_stage, inputs=stage_dropdown, outputs=episode_slider).then(
        fn=update_dashboard,
        inputs=inputs,
        outputs=outputs,
    )

    start_button.click(lambda: gr.Timer(active=True), inputs=None, outputs=live_timer)
    stop_button.click(lambda: gr.Timer(active=False), inputs=None, outputs=live_timer)
    reset_button.click(
        fn=reset_training,
        inputs=[algorithm_dropdown, learning_rate_dropdown, gamma_dropdown],
        outputs=live_outputs,
    )
    live_timer.tick(
        fn=live_training_step,
        inputs=[episode_slider, animation_slider, speed_slider, algorithm_dropdown, learning_rate_dropdown, gamma_dropdown],
        outputs=live_outputs,
    )
# Start the Gradio server only when this file is run as a script, not when
# it is imported.
if __name__ == "__main__":
    demo.launch()