Enhance mujoco_server.py and UR5 environment for improved teleoperation and episode control
- Introduced new mechanisms for handling teleoperation commands, including a snapshot of the last command for better state management.
- Updated the reward calculation in the UR5 environment to prioritize task rewards, with a fallback to distance-based rewards.
- Added gym-style control message handling over the unified WebSocket, allowing for richer interactions and state updates.
- Implemented a robust episode control flag system to manage episode termination and truncation without interfering with the gym client's lifecycle.
- Enhanced documentation in README.md to clarify server startup options and the unified WebSocket API for trainers and clients.
- Added unit tests to verify the functionality of the episode control mechanism, ensuring thread safety and correct behavior across multiple episodes.
- README.md +50 -23
- mujoco_server.py +180 -55
- robots/ur5/ur5_env.py +12 -4
- tests/test_api.py +2 -2
- tests/test_episode_control_fix.py +167 -0
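
The episode control flag system described in the summary above can be pictured with a minimal sketch. The class and method names (`EpisodeControl`, `signal`, `consume`) are hypothetical and are not taken from `mujoco_server.py`; only the `terminate`/`truncate` flags and the lock mirror what the new test file shows. The consume-and-clear behavior is an assumption about how flags stay correct across episodes.

```python
import threading


class EpisodeControl:
    """Hypothetical flag holder; names are illustrative, not the server's API."""

    def __init__(self):
        self._state = {"terminate": False, "truncate": False}
        self._lock = threading.Lock()

    def signal(self, action: str) -> bool:
        """Set the flag named by `action`; return False for unknown actions."""
        action = (action or "").lower()
        with self._lock:
            if action not in self._state:
                return False
            self._state[action] = True
            return True

    def consume(self) -> tuple[bool, bool]:
        """Read both flags and clear them so the next episode starts clean."""
        with self._lock:
            terminated = self._state["terminate"]
            truncated = self._state["truncate"]
            self._state["terminate"] = False
            self._state["truncate"] = False
        return terminated, truncated


control = EpisodeControl()
control.signal("terminate")
first = control.consume()   # flags observed by the current step
second = control.consume()  # flags already cleared for the next episode
```

Reading and clearing under the same lock means a flag set from another thread is either seen by the current step or left intact for the next one, never lost in between.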
|
@@ -115,12 +115,20 @@ git clone --recurse-submodules https://github.com/iit-DLSLab/Quadruped-PyMPC
|
|
| 115 |
cd Quadruped-PyMPC
|
| 116 |
pip install -e .
|
| 117 |
|
| 118 |
-
# Start the server
|
| 119 |
python mujoco_server.py
|
| 120 |
|
| 121 |
# Open browser at http://localhost:3004/nova-sim/api/v1
|
| 122 |
```
|
| 123 |
|
| 124 |
## Docker Deployment
|
| 125 |
|
| 126 |
### Getting Started
|
|
@@ -381,12 +389,14 @@ ws.send(JSON.stringify({type: 'command', data: {vx: 0.5, vy: 0, vyaw: 0}}));
|
|
| 381 |
|
| 382 |
// Receive messages
|
| 383 |
ws.onmessage = (event) => {
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
};
|
| 389 |
-
```
|
| 390 |
|
| 391 |
#### Client → Server Messages
|
| 392 |
|
|
@@ -397,10 +407,23 @@ All messages are JSON with `{type, data}` structure:
|
|
| 397 |
| `command` | `{vx, vy, vyaw}` | Set velocity command |
|
| 398 |
| `reset` | `{}` | Reset robot to standing pose |
|
| 399 |
| `switch_robot` | `{robot, scene?}` | Switch active robot and optional scene |
|
|
|
|
| 400 |
| `camera` | `{action, ...}` | Camera control |
|
| 401 |
| `camera_follow` | `{follow}` | Toggle camera follow mode |
|
| 402 |
| `teleop_command` | `{dx, dy, dz}` | Apply incremental cartesian jog command (UI teleop) |
|
| 403 |
|
| 404 |
**`command`:**
|
| 405 |
```json
|
| 406 |
{"type": "command", "data": {"vx": 0.5, "vy": 0.0, "vyaw": 0.0}}
|
|
@@ -493,9 +516,10 @@ For locomotion robots (G1, Spot):
|
|
| 493 |
"base_height": 0.46,
|
| 494 |
"upright": 0.98,
|
| 495 |
"steps": 1234,
|
| 496 |
-
"vx": 0.5,
|
| 497 |
-
"
|
| 498 |
-
"
|
|
|
|
| 499 |
}
|
| 500 |
}
|
| 501 |
```
|
|
@@ -515,6 +539,10 @@ For robot arm (UR5):
|
|
| 515 |
"control_mode": "ik",
|
| 516 |
"use_orientation": true,
|
| 517 |
"steps": 1234,
|
| 518 |
"nova_api": {
|
| 519 |
"connected": true,
|
| 520 |
"state_streaming": true,
|
|
@@ -530,21 +558,21 @@ For robot arm (UR5):
|
|
| 530 |
- `connected`: Whether Nova API client is connected
|
| 531 |
- `state_streaming`: Whether using Nova API for robot state streaming (vs. internal)
|
| 532 |
- `ik`: Whether using Nova API for inverse kinematics (vs. internal)
|
| 533 |
|
| 534 |
-
###
|
| 535 |
-
|
| 536 |
-
The gym-style API is exposed at `ws://localhost:3004/nova-sim/api/v1/gym/ws`.
|
| 537 |
-
It supports `reset`, `step`, `configure`, and `get_spaces`.
|
| 538 |
|
| 539 |
-
|
| 540 |
|
| 541 |
-
```
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
```
|
| 546 |
|
| 547 |
-
The server
|
| 548 |
|
| 549 |
### HTTP Endpoints
|
| 550 |
|
|
@@ -803,8 +831,7 @@ The Nova API integration is implemented in:
|
|
| 803 |
2. Keep it running at `http://localhost:3004` so the HTTP/websocket endpoints stay reachable.
|
| 804 |
3. Run `pytest nova-sim/tests` to exercise:
|
| 805 |
- API endpoints (`/metadata`, `/camera/<name>/video_feed`, `/video_feed`)
|
| 806 |
-
- WebSocket control (`/ws`)
|
| 807 |
-
- Gym-style websocket (`/gym/ws`)
|
| 808 |
- Auxiliary MJPEG overlays after switching to the T-push UR5 scene
|
| 809 |
|
| 810 |
The tests assume the server is accessible via `http://localhost:3004/nova-sim/api/v1` and will skip automatically if the API is unreachable.
|
|
|
|
| 115 |
cd Quadruped-PyMPC
|
| 116 |
pip install -e .
|
| 117 |
|
| 118 |
+
# Start the server (default reward threshold: -0.1)
|
| 119 |
python mujoco_server.py
|
| 120 |
|
| 121 |
+
# Or with custom reward threshold for auto episode termination
|
| 122 |
+
python mujoco_server.py --reward-threshold -0.05 # Stricter (5cm from target)
|
| 123 |
+
python mujoco_server.py --reward-threshold -0.2 # Lenient (20cm from target)
|
| 124 |
+
|
| 125 |
# Open browser at http://localhost:3004/nova-sim/api/v1
|
| 126 |
```
|
| 127 |
|
| 128 |
+
**Reward Threshold**: Episodes automatically terminate when the robot reaches within the specified distance of the target. See [REWARD_THRESHOLD.md](REWARD_THRESHOLD.md) for details.
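
How the threshold relates to distance can be sketched as follows, assuming the reward is the negative distance to the target in metres (as in the UR5 fallback reward) and that an episode ends once the reward reaches the threshold. The function name `should_terminate` is hypothetical; REWARD_THRESHOLD.md remains the reference for the actual rule.

```python
def should_terminate(reward: float, reward_threshold: float = -0.1) -> bool:
    """Hypothetical check: end the episode once reward >= threshold.

    Assumes reward = -distance_to_target (metres), so a threshold of -0.05
    corresponds to being within 5 cm of the target.
    """
    return reward >= reward_threshold


near = should_terminate(reward=-0.04, reward_threshold=-0.05)  # 4 cm away
far = should_terminate(reward=-0.15, reward_threshold=-0.1)    # 15 cm away
```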
|
| 129 |
+
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
## Docker Deployment
|
| 133 |
|
| 134 |
### Getting Started
|
|
|
|
| 389 |
|
| 390 |
// Receive messages
|
| 391 |
ws.onmessage = (event) => {
|
| 392 |
+
const msg = JSON.parse(event.data);
|
| 393 |
+
if (msg.type === 'state') {
|
| 394 |
+
console.log(msg.data);
|
| 395 |
+
}
|
| 396 |
+
};
|
| 397 |
+
```
|
| 398 |
+
|
| 399 |
+
Nova-Sim uses `/ws` as the shared control channel for the browser UI, trainers, and any RL clients. All UI interactions (teleop, camera controls, robot switching) and the trainer handshake and notifications flow through this single socket; the UI `state` messages shown below now also carry the action deltas, integrated reward, and trainer connection status that RL agents need. The legacy `/nova-sim/api/v1/gym/ws` endpoint remains for backwards compatibility, but new clients should use `/ws` instead.
|
| 400 |
|
| 401 |
#### Client → Server Messages
|
| 402 |
|
|
|
|
| 407 |
| `command` | `{vx, vy, vyaw}` | Set velocity command |
|
| 408 |
| `reset` | `{}` | Reset robot to standing pose |
|
| 409 |
| `switch_robot` | `{robot, scene?}` | Switch active robot and optional scene |
|
| 410 |
+
| `configure` | `{robot, scene?}` | Reconfigure the active robot/scene (reuses the UI configure flow) |
|
| 411 |
| `camera` | `{action, ...}` | Camera control |
|
| 412 |
| `camera_follow` | `{follow}` | Toggle camera follow mode |
|
| 413 |
| `teleop_command` | `{dx, dy, dz}` | Apply incremental cartesian jog command (UI teleop) |
|
| 414 |
|
| 415 |
+
RL agents, trainers, and CLI tools connect to `/ws` as well by sending gym-style control packets that include a request `id`. The server replies with `gym_*` responses so your existing clients keep working while also benefiting from the richer state stream.
|
| 416 |
+
|
| 417 |
+
**Gym-style RL control messages** (include an `id` and expect a `gym_*` response):
|
| 418 |
+
```
|
| 419 |
+
{"type": "reset", "id": 42}
|
| 420 |
+
{"type": "step", "data": {"action": [0,0,0,0,0,0,0], "render": false}, "id": 43}
|
| 421 |
+
{"type": "configure", "data": {"robot": "ur5_t_push"}, "id": 44}
|
| 422 |
+
{"type": "get_spaces", "id": 45}
|
| 423 |
+
{"type": "close", "id": 46}
|
| 424 |
+
```
|
| 425 |
+
The server replies with `gym_reset`, `gym_step`, `gym_configured`, `gym_spaces`, or `gym_closed` (matching the `id` you supplied). Every message still flows over `/ws`, so UI teleop commands, state updates, and gym responses share the same connection.
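
Because state broadcasts and gym replies share one socket, a client has to pick out the reply that belongs to its request. The sketch below shows one way to do that by matching on `id`; the helper names `make_request` and `matches` are hypothetical, and no network connection is made here.

```python
import itertools
import json
from typing import Optional

_ids = itertools.count(1)


def make_request(msg_type: str, data: Optional[dict] = None) -> tuple[int, str]:
    """Build a gym-style request with a fresh `id`; returns (id, JSON text)."""
    msg_id = next(_ids)
    message = {"type": msg_type, "id": msg_id}
    if data is not None:
        message["data"] = data
    return msg_id, json.dumps(message)


def matches(response_text: str, msg_id: int) -> bool:
    """True if the message is a gym_* reply carrying the id we sent."""
    response = json.loads(response_text)
    is_gym = str(response.get("type", "")).startswith("gym_")
    return is_gym and response.get("id") == msg_id


step_id, step_text = make_request(
    "step", {"action": [0, 0, 0, 0, 0, 0, 0], "render": False}
)
# A broadcast `state` message carries no matching id, so it is skipped.
state_msg = json.dumps({"type": "state", "data": {"steps": 1}})
reply_msg = json.dumps({"type": "gym_step", "id": step_id, "data": {"reward": 0.0}})
```

A real client would loop over incoming messages, hand `state` messages to its UI or logging path, and return the first message for which `matches` is true.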
|
| 426 |
+
|
| 427 |
**`command`:**
|
| 428 |
```json
|
| 429 |
{"type": "command", "data": {"vx": 0.5, "vy": 0.0, "vyaw": 0.0}}
|
|
|
|
| 516 |
"base_height": 0.46,
|
| 517 |
"upright": 0.98,
|
| 518 |
"steps": 1234,
|
| 519 |
+
"command": {"vx": 0.5, "vy": 0.0, "vyaw": 0.0},
|
| 520 |
+
"reward": 0.0,
|
| 521 |
+
"teleop_command": {"dx": 0.05, "dy": 0.0, "dz": 0.0},
|
| 522 |
+
"trainer_connected": true
|
| 523 |
}
|
| 524 |
}
|
| 525 |
```
|
|
|
|
| 539 |
"control_mode": "ik",
|
| 540 |
"use_orientation": true,
|
| 541 |
"steps": 1234,
|
| 542 |
+
"reward": -0.25,
|
| 543 |
+
"command": {"vx": 0.1, "vy": 0.0, "vyaw": 0.0},
|
| 544 |
+
"teleop_command": {"dx": 0.02, "dy": 0.0, "dz": 0.0},
|
| 545 |
+
"trainer_connected": true,
|
| 546 |
"nova_api": {
|
| 547 |
"connected": true,
|
| 548 |
"state_streaming": true,
|
|
|
|
| 558 |
- `connected`: Whether Nova API client is connected
|
| 559 |
- `state_streaming`: Whether using Nova API for robot state streaming (vs. internal)
|
| 560 |
- `ik`: Whether using Nova API for inverse kinematics (vs. internal)
|
| 561 |
+
- `command`: The latest velocity command (`vx`, `vy`, `vyaw`) that drives locomotion or arm movement.
|
| 562 |
+
- `teleop_command`: The most recent UI teleop delta (`dx`, `dy`, `dz`) so trainers know how the UI nudged the robot.
|
| 563 |
+
- `reward`: The integrated task reward from the simulator that remote trainers can consume.
|
| 564 |
+
- `trainer_connected`: Whether a trainer handshake is active on `/ws` (useful for status LEDs).
|
| 565 |
|
| 566 |
+
### State broadcasts and trainer notifications
|
| 567 |
|
| 568 |
+
Every `/ws` client receives a `state` message roughly every 100 ms. The examples above show the locomotion (`spot`) and arm (`ur5`) payloads; the payload now also includes:
|
| 569 |
|
| 570 |
+
- `command`: The last velocity command that drives locomotion or arm motion (`vx`, `vy`, `vyaw`).
|
| 571 |
+
- `teleop_command`: The latest UI teleop delta (`dx`, `dy`, `dz`) so trainers know how the browser nudged the robot.
|
| 572 |
+
- `reward`: The integrated task reward that trainers can consume without sending a separate `step`.
|
| 573 |
+
- `trainer_connected`: Whether a trainer handshake is active on `/ws` (used to update the UI indicator).
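
A trainer might read these fields from a `state` message roughly as follows. This is a sketch under the payload shape shown in the examples above; the function name `summarize_state` is hypothetical.

```python
import json


def summarize_state(message_text: str):
    """Return (reward, teleop delta, trainer flag) from a `state` message, else None."""
    msg = json.loads(message_text)
    if msg.get("type") != "state":
        return None
    data = msg.get("data", {})
    # `teleop_command` may be null before any teleop input has been received.
    teleop = data.get("teleop_command") or {}
    delta = (teleop.get("dx", 0.0), teleop.get("dy", 0.0), teleop.get("dz", 0.0))
    return data.get("reward"), delta, bool(data.get("trainer_connected", False))


example = json.dumps({
    "type": "state",
    "data": {
        "reward": -0.25,
        "teleop_command": {"dx": 0.02, "dy": 0.0, "dz": 0.0},
        "trainer_connected": True,
    },
})
summary = summarize_state(example)
```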
|
|
|
|
| 574 |
|
| 575 |
+
Trainers announce themselves by sending a `trainer_identity` payload when the socket opens. The server mirrors that information into the `trainer_status` broadcasts (`trainer_status` messages flow to every UI client) and lets trainers emit `notification` payloads that the UI receives as `trainer_notification` events.
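
The two trainer messages could be built as in the sketch below. The helper names are hypothetical. The identity keys follow the server diff, which reads `trainer_id`, `trainer_name`, or `name`; the `message` field inside the notification is an assumption, since the payload schema is not spelled out here.

```python
import json
import time


def trainer_identity(trainer_id: str) -> str:
    """Announce a trainer right after the socket opens."""
    return json.dumps({"type": "trainer_identity", "data": {"trainer_id": trainer_id}})


def notification(text: str) -> str:
    """Build a `notification`; the `message` field name is an assumption."""
    return json.dumps({
        "type": "notification",
        "data": {"message": text, "timestamp": time.time()},
    })


hello = json.loads(trainer_identity("ppo-run-1"))
note = json.loads(notification("episode 3 finished"))
```

If `timestamp` is omitted, the server diff shows it filling in the current time before forwarding the payload to UI clients.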
|
| 576 |
|
| 577 |
### HTTP Endpoints
|
| 578 |
|
|
|
|
| 831 |
2. Keep it running at `http://localhost:3004` so the HTTP/websocket endpoints stay reachable.
|
| 832 |
3. Run `pytest nova-sim/tests` to exercise:
|
| 833 |
- API endpoints (`/metadata`, `/camera/<name>/video_feed`, `/video_feed`)
|
| 834 |
+
- Unified WebSocket control (`/ws`)
|
|
|
|
| 835 |
- Auxiliary MJPEG overlays after switching to the T-push UR5 scene
|
| 836 |
|
| 837 |
The tests assume the server is accessible via `http://localhost:3004/nova-sim/api/v1` and will skip automatically if the API is unreachable.
|
|
@@ -9,7 +9,7 @@ import cv2
|
|
| 9 |
import numpy as np
|
| 10 |
import mujoco
|
| 11 |
from pathlib import Path
|
| 12 |
-
from typing import Any
|
| 13 |
from flask import Flask, Response, render_template_string, request, jsonify
|
| 14 |
from flask_sock import Sock
|
| 15 |
|
|
@@ -98,6 +98,10 @@ episode_control_state = {
|
|
| 98 |
}
|
| 99 |
episode_control_lock = threading.Lock()
|
| 100 |
|
| 101 |
# WebSocket clients
|
| 102 |
ws_clients = set()
|
| 103 |
ws_clients_lock = threading.Lock()
|
|
@@ -105,6 +109,8 @@ trainer_ws_clients = set()
|
|
| 105 |
trainer_ws_clients_lock = threading.Lock()
|
| 106 |
trainer_client_metadata: dict = {}
|
| 107 |
trainer_client_metadata_lock = threading.Lock()
|
| 108 |
|
| 109 |
# Camera state for orbit controls
|
| 110 |
cam = mujoco.MjvCamera()
|
|
@@ -199,7 +205,6 @@ UR5_T_PUSH_OVERLAY_PRESETS = [
|
|
| 199 |
|
| 200 |
OVERLAY_CAMERA_PRESETS = {
|
| 201 |
"ur5_t_push": UR5_T_PUSH_OVERLAY_PRESETS,
|
| 202 |
-
"scene_t_push": UR5_T_PUSH_OVERLAY_PRESETS,
|
| 203 |
}
|
| 204 |
|
| 205 |
CAMERA_FEEDS = [
|
|
@@ -484,6 +489,16 @@ def broadcast_state():
|
|
| 484 |
obs = env._get_obs()
|
| 485 |
cmd = env.get_command()
|
| 486 |
steps = env.steps
|
| 487 |
|
| 488 |
with trainer_ws_clients_lock:
|
| 489 |
trainer_connected = len(trainer_ws_clients) > 0
|
|
@@ -533,7 +548,9 @@ def broadcast_state():
|
|
| 533 |
'control_mode': control_mode,
|
| 534 |
'use_orientation': use_orientation,
|
| 535 |
'steps': int(steps),
|
| 536 |
-
'reward':
|
| 537 |
'nova_api': {
|
| 538 |
'connected': nova_connected,
|
| 539 |
'state_streaming': nova_state_streaming,
|
|
@@ -557,9 +574,12 @@ def broadcast_state():
|
|
| 557 |
'base_height': base_height,
|
| 558 |
'upright': upright,
|
| 559 |
'steps': int(steps),
|
| 560 |
'vx': float(cmd[0]),
|
| 561 |
'vy': float(cmd[1]),
|
| 562 |
'vyaw': float(cmd[2]),
|
|
|
|
| 563 |
'trainer_connected': trainer_connected
|
| 564 |
}
|
| 565 |
})
|
|
@@ -576,6 +596,25 @@ def broadcast_state():
|
|
| 576 |
ws_clients.difference_update(dead_clients)
|
| 577 |
|
| 578 |
|
| 579 |
def _build_trainer_status_payload():
|
| 580 |
"""Build a summary payload describing connected trainer clients."""
|
| 581 |
with trainer_ws_clients_lock:
|
|
@@ -650,6 +689,106 @@ def broadcast_notification_to_ui(payload: dict):
|
|
| 650 |
ws_clients.difference_update(dead_clients)
|
| 651 |
|
| 652 |
|
| 653 |
def _signal_episode_control(action: str):
|
| 654 |
"""Set episode control flags and notify trainer/UI clients."""
|
| 655 |
action = (action or "").lower()
|
|
@@ -828,12 +967,19 @@ def generate_overlay_frames(name: str):
|
|
| 828 |
time.sleep(0.04)
|
| 829 |
|
| 830 |
|
| 831 |
-
def handle_ws_message(data):
|
| 832 |
"""Handle incoming WebSocket message."""
|
| 833 |
-
global needs_robot_switch, camera_follow
|
| 834 |
|
| 835 |
msg_type = data.get('type')
|
| 836 |
|
| 837 |
if msg_type == 'command':
|
| 838 |
payload = data.get('data', {})
|
| 839 |
vx = payload.get('vx', 0.0)
|
|
@@ -868,6 +1014,14 @@ def handle_ws_message(data):
|
|
| 868 |
print(f"Robot switch requested: {robot} / scene: {scene}")
|
| 869 |
needs_robot_switch = {"robot": robot, "scene": scene}
|
| 870 |
|
| 871 |
elif msg_type == 'camera':
|
| 872 |
payload = data.get('data', {})
|
| 873 |
a = cam.azimuth * np.pi / 180.0
|
|
@@ -913,6 +1067,17 @@ def handle_ws_message(data):
|
|
| 913 |
else:
|
| 914 |
updated_target = None
|
| 915 |
|
| 916 |
broadcast_to_trainer(
|
| 917 |
"teleop_command",
|
| 918 |
{
|
|
@@ -1057,6 +1222,7 @@ def _create_env(robot, scene):
|
|
| 1057 |
sys.path.insert(0, ur5_dir)
|
| 1058 |
from ur5_env import UR5Env
|
| 1059 |
sys.path.pop(0)
|
|
|
|
| 1060 |
if scene:
|
| 1061 |
return UR5Env(render_mode="rgb_array", width=RENDER_WIDTH, height=RENDER_HEIGHT, scene_name=scene)
|
| 1062 |
return UR5Env(render_mode="rgb_array", width=RENDER_WIDTH, height=RENDER_HEIGHT)
|
|
@@ -1134,7 +1300,7 @@ def websocket_handler(ws):
|
|
| 1134 |
break
|
| 1135 |
try:
|
| 1136 |
data = json.loads(message)
|
| 1137 |
-
handle_ws_message(data)
|
| 1138 |
except json.JSONDecodeError:
|
| 1139 |
print(f"Invalid JSON received: {message}")
|
| 1140 |
except Exception as e:
|
|
@@ -1145,53 +1311,17 @@ def websocket_handler(ws):
|
|
| 1145 |
# Unregister client
|
| 1146 |
with ws_clients_lock:
|
| 1147 |
ws_clients.discard(ws)
|
| 1148 |
-
print('WebSocket client disconnected')
|
| 1149 |
-
|
| 1150 |
-
|
| 1151 |
-
@sock.route(f'{API_PREFIX}/trainer/ws')
|
| 1152 |
-
def trainer_websocket_handler(ws):
|
| 1153 |
-
"""Handle trainer WebSocket clients (teleop commands + notifications)."""
|
| 1154 |
-
print('Trainer client connected')
|
| 1155 |
-
with trainer_ws_clients_lock:
|
| 1156 |
-
trainer_ws_clients.add(ws)
|
| 1157 |
-
_register_trainer_client(ws)
|
| 1158 |
-
broadcast_trainer_connection_status()
|
| 1159 |
-
broadcast_state()
|
| 1160 |
-
|
| 1161 |
-
try:
|
| 1162 |
-
while True:
|
| 1163 |
-
message = ws.receive()
|
| 1164 |
-
if message is None:
|
| 1165 |
-
break
|
| 1166 |
-
try:
|
| 1167 |
-
data = json.loads(message)
|
| 1168 |
-
_handle_trainer_message(ws, data)
|
| 1169 |
-
except json.JSONDecodeError:
|
| 1170 |
-
print(f"Trainer sent invalid JSON: {message}")
|
| 1171 |
-
except Exception as exc:
|
| 1172 |
-
print(f"Error handling trainer message: {exc}")
|
| 1173 |
-
finally:
|
| 1174 |
with trainer_ws_clients_lock:
|
|
|
|
| 1175 |
trainer_ws_clients.discard(ws)
|
| 1176 |
-
|
| 1177 |
-
|
| 1178 |
-
|
| 1179 |
-
|
| 1180 |
|
| 1181 |
|
| 1182 |
-
def _handle_trainer_message(ws, data):
|
| 1183 |
-
"""Process messages sent from the training client."""
|
| 1184 |
-
msg_type = data.get("type")
|
| 1185 |
-
if msg_type == "trainer_identity":
|
| 1186 |
-
payload = data.get("data", {}) or {}
|
| 1187 |
-
identity = payload.get("trainer_id") or payload.get("trainer_name") or payload.get("name") or "trainer"
|
| 1188 |
-
_set_trainer_identity(ws, identity)
|
| 1189 |
-
broadcast_trainer_connection_status()
|
| 1190 |
-
return
|
| 1191 |
-
if msg_type == "notification":
|
| 1192 |
-
payload = data.get("data", {})
|
| 1193 |
-
payload.setdefault("timestamp", time.time())
|
| 1194 |
-
broadcast_notification_to_ui(payload)
|
| 1195 |
|
| 1196 |
|
| 1197 |
@sock.route(f'{API_PREFIX}/gym/ws')
|
|
@@ -1227,6 +1357,7 @@ def gym_websocket_handler(ws):
|
|
| 1227 |
elif msg_type == "step":
|
| 1228 |
action = payload.get("action", [])
|
| 1229 |
obs, reward, terminated, truncated, info = session.step(action)
|
|
|
|
| 1230 |
response = {
|
| 1231 |
"type": "gym_step",
|
| 1232 |
"data": {
|
|
@@ -1801,12 +1932,6 @@ def index():
|
|
| 1801 |
<span>Z nudge</span>
|
| 1802 |
</span>
|
| 1803 |
</li>
|
| 1804 |
-
<li>
|
| 1805 |
-
<span class="hint-key">
|
| 1806 |
-
<kbd>Enter</kbd>
|
| 1807 |
-
<span>End episode</span>
|
| 1808 |
-
</span>
|
| 1809 |
-
</li>
|
| 1810 |
</ul>
|
| 1811 |
</div>
|
| 1812 |
<div class="robot-info" id="robot_info">
|
|
|
|
| 9 |
import numpy as np
|
| 10 |
import mujoco
|
| 11 |
from pathlib import Path
|
| 12 |
+
from typing import Any, Optional
|
| 13 |
from flask import Flask, Response, render_template_string, request, jsonify
|
| 14 |
from flask_sock import Sock
|
| 15 |
|
|
|
|
| 98 |
}
|
| 99 |
episode_control_lock = threading.Lock()
|
| 100 |
|
| 101 |
+
# Latest teleoperation command (for trainer state)
|
| 102 |
+
last_teleop_command: Optional[dict[str, Any]] = None
|
| 103 |
+
teleop_lock = threading.Lock()
|
| 104 |
+
|
| 105 |
# WebSocket clients
|
| 106 |
ws_clients = set()
|
| 107 |
ws_clients_lock = threading.Lock()
|
|
|
|
| 109 |
trainer_ws_clients_lock = threading.Lock()
|
| 110 |
trainer_client_metadata: dict = {}
|
| 111 |
trainer_client_metadata_lock = threading.Lock()
|
| 112 |
+
gym_sessions: dict = {}
|
| 113 |
+
gym_sessions_lock = threading.Lock()
|
| 114 |
|
| 115 |
# Camera state for orbit controls
|
| 116 |
cam = mujoco.MjvCamera()
|
|
|
|
| 205 |
|
| 206 |
OVERLAY_CAMERA_PRESETS = {
|
| 207 |
"ur5_t_push": UR5_T_PUSH_OVERLAY_PRESETS,
|
|
|
|
| 208 |
}
|
| 209 |
|
| 210 |
CAMERA_FEEDS = [
|
|
|
|
| 489 |
obs = env._get_obs()
|
| 490 |
cmd = env.get_command()
|
| 491 |
steps = env.steps
|
| 492 |
+
command = {
|
| 493 |
+
"vx": float(cmd[0]) if len(cmd) > 0 else 0.0,
|
| 494 |
+
"vy": float(cmd[1]) if len(cmd) > 1 else 0.0,
|
| 495 |
+
"vyaw": float(cmd[2]) if len(cmd) > 2 else 0.0,
|
| 496 |
+
}
|
| 497 |
+
with teleop_lock:
|
| 498 |
+
teleop_snapshot = last_teleop_command.copy() if last_teleop_command else None
|
| 499 |
+
reward_value = None
|
| 500 |
+
if hasattr(env, "get_task_reward"):
|
| 501 |
+
reward_value = env.get_task_reward()
|
| 502 |
|
| 503 |
with trainer_ws_clients_lock:
|
| 504 |
trainer_connected = len(trainer_ws_clients) > 0
|
|
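
The teleop snapshot taken in `broadcast_state()` above follows a common copy-under-lock pattern. A minimal standalone sketch, using a hypothetical `TeleopState` class instead of the module-level globals in the diff:

```python
import threading
from typing import Any, Optional


class TeleopState:
    """Hypothetical wrapper around the last teleop command and its lock."""

    def __init__(self):
        self._last: Optional[dict[str, Any]] = None
        self._lock = threading.Lock()

    def record(self, dx: float, dy: float, dz: float) -> None:
        """Writer side: replace the shared command while holding the lock."""
        with self._lock:
            self._last = {"dx": dx, "dy": dy, "dz": dz}

    def snapshot(self) -> Optional[dict[str, Any]]:
        """Reader side: return a shallow copy so later writes are not visible."""
        with self._lock:
            return self._last.copy() if self._last else None


state = TeleopState()
before = state.snapshot()
state.record(0.05, 0.0, 0.0)
snap = state.snapshot()
state.record(0.0, 0.1, 0.0)  # does not affect the earlier snapshot
```

The copy is shallow, which is enough here because the stored values are plain numbers; nested mutable values would need a deep copy.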
|
|
| 548 |
'control_mode': control_mode,
|
| 549 |
'use_orientation': use_orientation,
|
| 550 |
'steps': int(steps),
|
| 551 |
+
'reward': reward_value,
|
| 552 |
+
'command': command,
|
| 553 |
+
'teleop_command': teleop_snapshot,
|
| 554 |
'nova_api': {
|
| 555 |
'connected': nova_connected,
|
| 556 |
'state_streaming': nova_state_streaming,
|
|
|
|
| 574 |
'base_height': base_height,
|
| 575 |
'upright': upright,
|
| 576 |
'steps': int(steps),
|
| 577 |
+
'command': command,
|
| 578 |
+
'teleop_command': teleop_snapshot,
|
| 579 |
'vx': float(cmd[0]),
|
| 580 |
'vy': float(cmd[1]),
|
| 581 |
'vyaw': float(cmd[2]),
|
| 582 |
+
'reward': reward_value,
|
| 583 |
'trainer_connected': trainer_connected
|
| 584 |
}
|
| 585 |
})
|
|
|
|
| 596 |
ws_clients.difference_update(dead_clients)
|
| 597 |
|
| 598 |
|
| 599 |
+
def _handle_trainer_message(ws, data):
|
| 600 |
+
"""Process message payloads originating from trainers."""
|
| 601 |
+
msg_type = data.get("type")
|
| 602 |
+
if msg_type == "trainer_identity":
|
| 603 |
+
payload = data.get("data", {}) or {}
|
| 604 |
+
identity = payload.get("trainer_id") or payload.get("trainer_name") or payload.get("name") or "trainer"
|
| 605 |
+
with trainer_ws_clients_lock:
|
| 606 |
+
trainer_ws_clients.add(ws)
|
| 607 |
+
_register_trainer_client(ws)
|
| 608 |
+
_set_trainer_identity(ws, identity)
|
| 609 |
+
broadcast_state()
|
| 610 |
+
broadcast_trainer_connection_status()
|
| 611 |
+
return
|
| 612 |
+
if msg_type == "notification":
|
| 613 |
+
payload = data.get("data", {})
|
| 614 |
+
payload.setdefault("timestamp", time.time())
|
| 615 |
+
broadcast_notification_to_ui(payload)
|
| 616 |
+
|
| 617 |
+
|
| 618 |
def _build_trainer_status_payload():
|
| 619 |
"""Build a summary payload describing connected trainer clients."""
|
| 620 |
with trainer_ws_clients_lock:
|
|
|
|
| 689 |
ws_clients.difference_update(dead_clients)
|
| 690 |
|
| 691 |
|
| 692 |
+
def _safe_ws_send(ws, message: dict):
|
| 693 |
+
"""Send JSON message over WebSocket without raising."""
|
| 694 |
+
try:
|
| 695 |
+
ws.send(json.dumps(message))
|
| 696 |
+
except Exception:
|
| 697 |
+
pass
|
| 698 |
+
|
| 699 |
+
|
| 700 |
+
def _get_or_create_gym_session(ws):
|
| 701 |
+
with gym_sessions_lock:
|
| 702 |
+
session = gym_sessions.get(ws)
|
| 703 |
+
if session is None:
|
| 704 |
+
session = GymSession()
|
| 705 |
+
gym_sessions[ws] = session
|
| 706 |
+
return session
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
def _remove_gym_session(ws):
|
| 710 |
+
with gym_sessions_lock:
|
| 711 |
+
session = gym_sessions.pop(ws, None)
|
| 712 |
+
if session:
|
| 713 |
+
session.close()
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
def _handle_gym_ws_message(ws, data):
|
| 717 |
+
"""Handle gym-style control messages routed over `/ws`."""
|
| 718 |
+
msg_type = data.get("type")
|
| 719 |
+
if msg_type not in {"reset", "step", "configure", "get_spaces", "close"}:
|
| 720 |
+
return False
|
| 721 |
+
msg_id = data.get("id")
|
| 722 |
+
if msg_id is None:
|
| 723 |
+
return False
|
| 724 |
+
|
| 725 |
+
payload = data.get("data", {}) or {}
|
| 726 |
+
session = _get_or_create_gym_session(ws)
|
| 727 |
+
|
| 728 |
+
try:
|
| 729 |
+
if msg_type == "reset":
|
| 730 |
+
seed = payload.get("seed")
|
| 731 |
+
obs, info = session.reset(seed=seed)
|
| 732 |
+
response = {
|
| 733 |
+
"type": "gym_reset",
|
| 734 |
+
"data": {
|
| 735 |
+
"obs": obs.tolist(),
|
| 736 |
+
"info": _serialize_value(info),
|
| 737 |
+
},
|
| 738 |
+
}
|
| 739 |
+
elif msg_type == "step":
|
| 740 |
+
action = payload.get("action", [])
|
| 741 |
+
render = bool(payload.get("render", False))
|
| 742 |
+
obs, reward, terminated, truncated, info = session.step(action)
|
| 743 |
+
response = {
|
| 744 |
+
"type": "gym_step",
|
| 745 |
+
"data": {
|
| 746 |
+
"obs": obs.tolist(),
|
| 747 |
+
"reward": float(reward),
|
| 748 |
+
"terminated": bool(terminated),
|
| 749 |
+
"truncated": bool(truncated),
|
| 750 |
+
"info": _serialize_value(info),
|
| 751 |
+
},
|
| 752 |
+
}
|
| 753 |
+
if render:
|
| 754 |
+
frame_jpeg = session.render_jpeg()
|
| 755 |
+
if frame_jpeg:
|
| 756 |
+
response["data"]["frame_jpeg"] = frame_jpeg
|
| 757 |
+
elif msg_type == "configure":
|
| 758 |
+
robot = payload.get("robot", "ur5")
|
| 759 |
+
scene = payload.get("scene")
|
| 760 |
+
session.configure(robot, scene)
|
| 761 |
+
response = {
|
| 762 |
+
"type": "gym_configured",
|
| 763 |
+
"data": {"robot": session.robot, "scene": session.scene},
|
| 764 |
+
}
|
| 765 |
+
elif msg_type == "get_spaces":
|
| 766 |
+
response = {
|
| 767 |
+
"type": "gym_spaces",
|
| 768 |
+
"data": {
|
| 769 |
+
"action_space": _serialize_space(session.env.action_space),
|
| 770 |
+
"observation_space": _serialize_space(session.env.observation_space),
|
| 771 |
+
},
|
| 772 |
+
}
|
| 773 |
+
elif msg_type == "close":
|
| 774 |
+
response = {"type": "gym_closed"}
|
| 775 |
+
_remove_gym_session(ws)
|
| 776 |
+
else:
|
| 777 |
+
response = {
|
| 778 |
+
"type": "gym_error",
|
| 779 |
+
"message": f"Unknown message type: {msg_type}",
|
| 780 |
+
}
|
| 781 |
+
if msg_id is not None:
|
| 782 |
+
response["id"] = msg_id
|
| 783 |
+
_safe_ws_send(ws, response)
|
| 784 |
+
except Exception as exc:
|
| 785 |
+
error_response = {"type": "gym_error", "message": str(exc)}
|
| 786 |
+
if msg_id is not None:
|
| 787 |
+
error_response["id"] = msg_id
|
| 788 |
+
_safe_ws_send(ws, error_response)
|
| 789 |
+
return True
|
| 790 |
+
|
| 791 |
+
|
| 792 |
def _signal_episode_control(action: str):
|
| 793 |
"""Set episode control flags and notify trainer/UI clients."""
|
| 794 |
action = (action or "").lower()
|
|
|
|
| 967 |
time.sleep(0.04)
|
| 968 |
|
| 969 |
|
| 970 |
+
def handle_ws_message(ws, data):
|
| 971 |
"""Handle incoming WebSocket message."""
|
| 972 |
+
global needs_robot_switch, camera_follow, last_teleop_command
|
| 973 |
|
| 974 |
msg_type = data.get('type')
|
| 975 |
|
| 976 |
+
if _handle_gym_ws_message(ws, data):
|
| 977 |
+
return
|
| 978 |
+
|
| 979 |
+
if msg_type in ("trainer_identity", "notification"):
|
| 980 |
+
_handle_trainer_message(ws, data)
|
| 981 |
+
return
|
| 982 |
+
|
| 983 |
if msg_type == 'command':
|
| 984 |
payload = data.get('data', {})
|
| 985 |
vx = payload.get('vx', 0.0)
|
|
|
|
| 1014 |
print(f"Robot switch requested: {robot} / scene: {scene}")
|
| 1015 |
needs_robot_switch = {"robot": robot, "scene": scene}
|
| 1016 |
|
| 1017 |
+
elif msg_type == 'configure':
|
| 1018 |
+
payload = data.get('data', {})
|
| 1019 |
+
robot = payload.get('robot')
|
| 1020 |
+
scene = payload.get('scene')
|
| 1021 |
+
if robot:
|
| 1022 |
+
print(f"Configure requested: {robot} / scene: {scene}")
|
| 1023 |
+
needs_robot_switch = {"robot": robot, "scene": scene}
|
| 1024 |
+
|
| 1025 |
elif msg_type == 'camera':
|
| 1026 |
payload = data.get('data', {})
|
| 1027 |
a = cam.azimuth * np.pi / 180.0
|
|
|
|
| 1067 |
else:
|
| 1068 |
updated_target = None
|
| 1069 |
|
| 1070 |
+
with teleop_lock:
|
| 1071 |
+
last_teleop_command = {
|
| 1072 |
+
"dx": dx,
|
| 1073 |
+
"dy": dy,
|
| 1074 |
+
"dz": dz,
|
| 1075 |
+
"robot": current_robot,
|
| 1076 |
+
"scene": getattr(env, "scene_name", None) if env is not None else None,
|
| 1077 |
+
"target": updated_target.tolist() if updated_target is not None else None,
|
| 1078 |
+
"timestamp": timestamp,
|
| 1079 |
+
}
|
| 1080 |
+
|
| 1081 |
broadcast_to_trainer(
|
| 1082 |
"teleop_command",
|
| 1083 |
{
|
|
|
|
| 1222 |
sys.path.insert(0, ur5_dir)
|
| 1223 |
from ur5_env import UR5Env
|
| 1224 |
sys.path.pop(0)
|
| 1225 |
+
# scene is already resolved by _resolve_robot_scene (e.g., "scene_t_push")
|
| 1226 |
if scene:
|
| 1227 |
return UR5Env(render_mode="rgb_array", width=RENDER_WIDTH, height=RENDER_HEIGHT, scene_name=scene)
|
| 1228 |
return UR5Env(render_mode="rgb_array", width=RENDER_WIDTH, height=RENDER_HEIGHT)
|
|
|
|
| 1300 |
break
|
| 1301 |
try:
|
| 1302 |
data = json.loads(message)
|
| 1303 |
+
handle_ws_message(ws, data)
|
| 1304 |
except json.JSONDecodeError:
|
| 1305 |
print(f"Invalid JSON received: {message}")
|
| 1306 |
except Exception as e:
|
|
|
|
| 1311 |
# Unregister client
|
| 1312 |
with ws_clients_lock:
|
| 1313 |
ws_clients.discard(ws)
|
| 1314 |
with trainer_ws_clients_lock:
|
| 1315 |
+
was_trainer = ws in trainer_ws_clients
|
| 1316 |
trainer_ws_clients.discard(ws)
|
| 1317 |
+
if was_trainer:
|
| 1318 |
+
_unregister_trainer_client(ws)
|
| 1319 |
+
broadcast_state()
|
| 1320 |
+
broadcast_trainer_connection_status()
|
| 1321 |
+
_remove_gym_session(ws)
|
| 1322 |
+
print('WebSocket client disconnected')
|
| 1323 |
|
| 1324 |
|
| 1325 |
|
| 1326 |
|
| 1327 |
@sock.route(f'{API_PREFIX}/gym/ws')
|
|
|
|
| 1357 |
elif msg_type == "step":
|
| 1358 |
action = payload.get("action", [])
|
| 1359 |
obs, reward, terminated, truncated, info = session.step(action)
|
| 1360 |
+
print(f"[GYM WS] step reward={reward:.4f}, terminated={terminated}, truncated={truncated}", flush=True)
|
| 1361 |
response = {
|
| 1362 |
"type": "gym_step",
|
| 1363 |
"data": {
|
|
|
|
| 1932 |
<span>Z nudge</span>
|
| 1933 |
</span>
|
| 1934 |
</li>
|
| 1935 |
</ul>
|
| 1936 |
</div>
|
| 1937 |
<div class="robot-info" id="robot_info">
|
|
@@ -800,14 +800,22 @@ class UR5Env(gym.Env):
|
|
| 800 |
|
| 801 |
observation = self._get_obs()
|
| 802 |
|
| 803 |
-
# Reward: distance to target
|
| 804 |
-
|
| 805 |
-
|
| 806 |
-
|
| 807 |
|
| 808 |
terminated = False
|
| 809 |
truncated = self.steps >= self.max_steps
|
| 810 |
|
| 811 |
info = {
|
| 812 |
"ee_pos": ee_pos,
|
| 813 |
"target_pos": self._target_pos,
|
|
|
|
| 800 |
|
| 801 |
observation = self._get_obs()
|
| 802 |
|
| 803 |
+
# Reward: Use task reward if available, otherwise distance to target
|
| 804 |
+
task_reward = self.get_task_reward()
|
| 805 |
+
if task_reward is not None:
|
| 806 |
+
reward = task_reward
|
| 807 |
+
else:
|
| 808 |
+
# Fallback: distance to target (for non-task scenes)
|
| 809 |
+
ee_pos = self.get_end_effector_pos()
|
| 810 |
+
dist = np.linalg.norm(ee_pos - self._target_pos)
|
| 811 |
+
reward = -dist
|
| 812 |
|
| 813 |
terminated = False
|
| 814 |
truncated = self.steps >= self.max_steps
|
| 815 |
|
| 816 |
+
# Compute info with distance for debugging
|
| 817 |
+
ee_pos = self.get_end_effector_pos()
|
| 818 |
+
dist = np.linalg.norm(ee_pos - self._target_pos)
|
| 819 |
info = {
|
| 820 |
"ee_pos": ee_pos,
|
| 821 |
"target_pos": self._target_pos,
|
|
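
The reward selection in the `UR5Env` hunk above can be isolated as a small pure function. This is a sketch, not the environment's code: `compute_reward` is a hypothetical name, and it uses `math.dist` from the standard library in place of `np.linalg.norm` so that it runs without NumPy.

```python
import math
from typing import Optional, Sequence


def compute_reward(
    task_reward: Optional[float],
    ee_pos: Sequence[float],
    target_pos: Sequence[float],
) -> float:
    """Prefer the task reward; otherwise use negative Euclidean distance."""
    if task_reward is not None:
        # `is not None` matters: a task reward of 0.0 must not trigger the fallback.
        return float(task_reward)
    return -math.dist(ee_pos, target_pos)


r_task = compute_reward(0.0, (0.0, 0.0, 0.0), (1.0, 0.0, 0.0))
r_fallback = compute_reward(None, (0.0, 0.0, 0.0), (0.3, 0.4, 0.0))
```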
@@ -28,8 +28,8 @@ def test_overlay_camera_presets(api_base: str):
|
|
| 28 |
resp = requests.get(f"{api_base}/metadata", timeout=5)
|
| 29 |
data = resp.json()
|
| 30 |
presets = data.get('overlay_camera_presets', {})
|
| 31 |
-
assert 'ur5_t_push' in presets
|
| 32 |
-
target_presets = presets.get('ur5_t_push'
|
| 33 |
names = {item.get('name') for item in target_presets}
|
| 34 |
assert {'aux_top', 'aux_side', 'aux_flange'}.issubset(names)
|
| 35 |
|
|
|
|
| 28 |
resp = requests.get(f"{api_base}/metadata", timeout=5)
|
| 29 |
data = resp.json()
|
| 30 |
presets = data.get('overlay_camera_presets', {})
|
| 31 |
+
assert 'ur5_t_push' in presets
|
| 32 |
+
target_presets = presets.get('ur5_t_push', [])
|
| 33 |
names = {item.get('name') for item in target_presets}
|
| 34 |
assert {'aux_top', 'aux_side', 'aux_flange'}.issubset(names)
|
| 35 |
|
|
@@ -0,0 +1,167 @@
|
|
| 1 |
+
"""
|
| 2 |
+
Unit test for episode control flag mechanism.
|
| 3 |
+
|
| 4 |
+
This test verifies that episode control flags work correctly across multiple episodes
|
| 5 |
+
without the problematic env.reset() call.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import threading
|
| 9 |
+
import time
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class MockEpisodeControlSystem:
|
| 13 |
+
"""
|
| 14 |
+
Minimal reproduction of the episode control system from mujoco_server.py.
|
| 15 |
+
|
| 16 |
+
This tests the core flag-based signaling mechanism without needing the full
|
| 17 |
+
MuJoCo environment or WebSocket infrastructure.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(self):
|
| 21 |
+
self.episode_control_state = {
|
| 22 |
+
"terminate": False,
|
| 23 |
+
"truncate": False,
|
| 24 |
+
}
|
| 25 |
+
self.episode_control_lock = threading.Lock()
|
| 26 |
+
|
| 27 |
+
def signal_episode_control(self, action: str):
|
| 28 |
+
"""
|
| 29 |
+
Simulates _signal_episode_control() from mujoco_server.py.
|
| 30 |
+
|
| 31 |
+
This is the FIXED version without the env.reset() call.
|
| 32 |
+
"""
|
| 33 |
+
action = (action or "").lower()
|
| 34 |
+
if action not in ("terminate", "truncate"):
|
| 35 |
+
return
|
| 36 |
+
|
| 37 |
+
with self.episode_control_lock:
|
| 38 |
+
self.episode_control_state[action] = True
|
| 39 |
+
|
| 40 |
+
# NOTE: The bug was here - the original code called env.reset()
|
| 41 |
+
# The fix is to NOT reset the environment, as it interferes with
|
| 42 |
+
# the gym client's episode lifecycle management.
|
| 43 |
+
|
| 44 |
+
def consume_episode_control_flags(self):
|
| 45 |
+
"""Simulates _consume_episode_control_flags() from mujoco_server.py."""
|
| 46 |
+
with self.episode_control_lock:
|
| 47 |
+
terminate = self.episode_control_state.get("terminate", False)
|
| 48 |
+
truncate = self.episode_control_state.get("truncate", False)
|
| 49 |
+
self.episode_control_state["terminate"] = False
|
| 50 |
+
self.episode_control_state["truncate"] = False
|
| 51 |
+
return terminate, truncate
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def test_episode_control_multiple_episodes():
|
| 55 |
+
"""
|
| 56 |
+
Test that episode control works correctly for multiple consecutive episodes.
|
| 57 |
+
|
| 58 |
+
This simulates the scenario where a user presses Enter in the Nova-Sim UI
|
| 59 |
+
during multiple episodes to terminate them early.
|
| 60 |
+
"""
|
| 61 |
+
system = MockEpisodeControlSystem()
|
| 62 |
+
|
| 63 |
+
num_episodes = 5
|
| 64 |
+
steps_per_episode = 10
|
| 65 |
+
step_to_terminate = 3
|
| 66 |
+
|
| 67 |
+
print(f"Testing {num_episodes} episodes with episode control...")
|
| 68 |
+
|
| 69 |
+
all_episodes_succeeded = True
|
| 70 |
+
|
| 71 |
+
for ep in range(num_episodes):
|
| 72 |
+
print(f"\n=== Episode {ep + 1}/{num_episodes} ===")
|
| 73 |
+
|
| 74 |
+
# Simulate episode reset
|
| 75 |
+
episode_terminated = False
|
| 76 |
+
|
| 77 |
+
for step in range(steps_per_episode):
|
| 78 |
+
# Simulate user pressing Enter at a specific step
|
| 79 |
+
if step == step_to_terminate:
|
| 80 |
+
print(f" Step {step + 1}: User presses Enter (UI episode control)")
|
| 81 |
+
system.signal_episode_control("terminate")
|
| 82 |
+
|
| 83 |
+
# Simulate gym client calling step()
|
| 84 |
+
term, trunc = system.consume_episode_control_flags()
|
| 85 |
+
|
| 86 |
+
if term or trunc:
|
| 87 |
+
print(f" Step {step + 1}: Episode ended (terminated={term}, truncated={trunc})")
|
| 88 |
+
episode_terminated = True
|
| 89 |
+
break
|
| 90 |
+
else:
|
| 91 |
+
print(f" Step {step + 1}: Continuing...")
|
| 92 |
+
|
| 93 |
+
if not episode_terminated:
|
| 94 |
+
print(f" ✗ FAILED: Episode {ep + 1} did not terminate!")
|
| 95 |
+
all_episodes_succeeded = False
|
| 96 |
+
elif step != step_to_terminate:
|
| 97 |
+
print(f" ✗ FAILED: Episode {ep + 1} ended too early!")
|
| 98 |
+
all_episodes_succeeded = False
|
| 99 |
+
else:
|
| 100 |
+
print(f" ✓ SUCCESS: Episode {ep + 1} terminated correctly")
|
| 101 |
+
|
| 102 |
+
print("\n" + "=" * 50)
|
| 103 |
+
if all_episodes_succeeded:
|
| 104 |
+
print("✓ ALL TESTS PASSED")
|
| 105 |
+
print(f"Episode control worked correctly for all {num_episodes} episodes")
|
| 106 |
+
return True
|
| 107 |
+
else:
|
| 108 |
+
print("✗ SOME TESTS FAILED")
|
| 109 |
+
return False
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def test_episode_control_threading():
|
| 113 |
+
"""
|
| 114 |
+
Test that episode control is thread-safe.
|
| 115 |
+
|
| 116 |
+
This simulates the scenario where the UI thread signals episode control
|
| 117 |
+
while the gym client thread is consuming flags.
|
| 118 |
+
"""
|
| 119 |
+
system = MockEpisodeControlSystem()
|
| 120 |
+
|
| 121 |
+
print("\nTesting thread safety...")
|
| 122 |
+
|
| 123 |
+
# Simulate concurrent UI and gym client activity
|
| 124 |
+
def ui_thread():
|
| 125 |
+
for i in range(10):
|
| 126 |
+
time.sleep(0.01)
|
| 127 |
+
system.signal_episode_control("terminate")
|
| 128 |
+
|
| 129 |
+
def gym_thread():
|
| 130 |
+
terminate_count = 0
|
| 131 |
+
for i in range(100):
|
| 132 |
+
time.sleep(0.001)
|
| 133 |
+
term, _ = system.consume_episode_control_flags()
|
| 134 |
+
if term:
|
| 135 |
+
terminate_count += 1
|
| 136 |
+
return terminate_count
|
| 137 |
+
|
| 138 |
+
ui = threading.Thread(target=ui_thread)
|
| 139 |
+
gym = threading.Thread(target=gym_thread)
|
| 140 |
+
|
| 141 |
+
ui.start()
|
| 142 |
+
gym.start()
|
| 143 |
+
|
| 144 |
+
ui.join()
|
| 145 |
+
gym.join()
|
| 146 |
+
|
| 147 |
+
print("✓ Thread safety test completed without crashes")
|
| 148 |
+
return True
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
if __name__ == "__main__":
|
| 152 |
+
success = True
|
| 153 |
+
|
| 154 |
+
print("=" * 50)
|
| 155 |
+
print("Episode Control Fix - Unit Tests")
|
| 156 |
+
print("=" * 50)
|
| 157 |
+
|
| 158 |
+
success = test_episode_control_multiple_episodes() and success
|
| 159 |
+
success = test_episode_control_threading() and success
|
| 160 |
+
|
| 161 |
+
print("\n" + "=" * 50)
|
| 162 |
+
if success:
|
| 163 |
+
print("✓ ALL UNIT TESTS PASSED")
|
| 164 |
+
exit(0)
|
| 165 |
+
else:
|
| 166 |
+
print("✗ SOME UNIT TESTS FAILED")
|
| 167 |
+
exit(1)
|
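The test functions in this file report results through return values and prints, and `test_episode_control_threading` discards the `terminate_count` computed in `gym_thread`. An assertion-based variant may be easier for a test runner to check. The following is a minimal sketch, not the code from this commit: `EpisodeFlags` and `count_terminations` are hypothetical names that mirror `MockEpisodeControlSystem`, and the sleep intervals are arbitrary.

```python
import threading
import time
from typing import List, Tuple


class EpisodeFlags:
    """Hypothetical stand-in that mirrors MockEpisodeControlSystem."""

    def __init__(self) -> None:
        self._state = {"terminate": False, "truncate": False}
        self._lock = threading.Lock()

    def signal(self, action: str) -> None:
        action = (action or "").lower()
        if action not in self._state:
            return  # Ignore unknown or empty actions.
        with self._lock:
            self._state[action] = True

    def consume(self) -> Tuple[bool, bool]:
        # Read and clear both flags under one lock acquisition, so a
        # pending signal is reported by exactly one consume() call.
        with self._lock:
            result = (self._state["terminate"], self._state["truncate"])
            self._state["terminate"] = False
            self._state["truncate"] = False
        return result


def count_terminations(flags: EpisodeFlags, signals: int = 10, polls: int = 100) -> int:
    """Signal from one thread, poll from another, return terminations seen."""
    seen: List[bool] = []  # Collects results, since Thread discards return values.

    def ui_thread() -> None:
        for _ in range(signals):
            time.sleep(0.001)
            flags.signal("terminate")

    def gym_thread() -> None:
        for _ in range(polls):
            time.sleep(0.0005)
            term, _trunc = flags.consume()
            if term:
                seen.append(True)

    threads = [threading.Thread(target=ui_thread), threading.Thread(target=gym_thread)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    # Pick up a flag that was set after the last poll.
    term, _trunc = flags.consume()
    return len(seen) + (1 if term else 0)


def test_flags_are_cleared_after_consume() -> None:
    flags = EpisodeFlags()
    flags.signal("terminate")
    assert flags.consume() == (True, False)
    assert flags.consume() == (False, False)


def test_concurrent_signals_are_not_lost_or_duplicated() -> None:
    count = count_terminations(EpisodeFlags(), signals=10)
    # Repeated signals may coalesce into one flag, but at least one
    # must be observed and never more than were sent.
    assert 1 <= count <= 10
```

The bounds in the threading test are deliberately loose: because the flag is a boolean rather than a counter, several signals that arrive between two polls collapse into a single termination, so only the range can be asserted, not an exact count.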