Georg commited on
Commit
c525614
·
1 Parent(s): fc65a06

Enhance mujoco_server.py and UR5 environment for improved teleoperation and episode control

Browse files

- Introduced new mechanisms for handling teleoperation commands, including a snapshot of the last command for better state management.
- Updated the reward calculation in the UR5 environment to prioritize task rewards, with a fallback to distance-based rewards.
- Added gym-style control message handling over the unified WebSocket, allowing for richer interactions and state updates.
- Implemented a robust episode control flag system to manage episode termination and truncation without interfering with the gym client's lifecycle.
- Enhanced documentation in README.md to clarify server startup options and the unified WebSocket API for trainers and clients.
- Added unit tests to verify the functionality of the episode control mechanism, ensuring thread safety and correct behavior across multiple episodes.

README.md CHANGED
@@ -115,12 +115,20 @@ git clone --recurse-submodules https://github.com/iit-DLSLab/Quadruped-PyMPC
115
  cd Quadruped-PyMPC
116
  pip install -e .
117
 
118
- # Start the server
119
  python mujoco_server.py
120
 
 
 
 
 
121
  # Open browser at http://localhost:3004/nova-sim/api/v1
122
  ```
123
 
 
 
 
 
124
  ## Docker Deployment
125
 
126
  ### Getting Started
@@ -381,12 +389,14 @@ ws.send(JSON.stringify({type: 'command', data: {vx: 0.5, vy: 0, vyaw: 0}}));
381
 
382
  // Receive messages
383
  ws.onmessage = (event) => {
384
- const msg = JSON.parse(event.data);
385
- if (msg.type === 'state') {
386
- console.log(msg.data);
387
- }
388
- };
389
- ```
 
 
390
 
391
  #### Client → Server Messages
392
 
@@ -397,10 +407,23 @@ All messages are JSON with `{type, data}` structure:
397
  | `command` | `{vx, vy, vyaw}` | Set velocity command |
398
  | `reset` | `{}` | Reset robot to standing pose |
399
  | `switch_robot` | `{robot, scene?}` | Switch active robot and optional scene |
 
400
  | `camera` | `{action, ...}` | Camera control |
401
  | `camera_follow` | `{follow}` | Toggle camera follow mode |
402
  | `teleop_command` | `{dx, dy, dz}` | Apply incremental cartesian jog command (UI teleop) |
403
 
 
 
 
 
 
 
 
 
 
 
 
 
404
  **`command`:**
405
  ```json
406
  {"type": "command", "data": {"vx": 0.5, "vy": 0.0, "vyaw": 0.0}}
@@ -493,9 +516,10 @@ For locomotion robots (G1, Spot):
493
  "base_height": 0.46,
494
  "upright": 0.98,
495
  "steps": 1234,
496
- "vx": 0.5,
497
- "vy": 0.0,
498
- "vyaw": 0.0
 
499
  }
500
  }
501
  ```
@@ -515,6 +539,10 @@ For robot arm (UR5):
515
  "control_mode": "ik",
516
  "use_orientation": true,
517
  "steps": 1234,
 
 
 
 
518
  "nova_api": {
519
  "connected": true,
520
  "state_streaming": true,
@@ -530,21 +558,21 @@ For robot arm (UR5):
530
  - `connected`: Whether Nova API client is connected
531
  - `state_streaming`: Whether using Nova API for robot state streaming (vs. internal)
532
  - `ik`: Whether using Nova API for inverse kinematics (vs. internal)
 
 
 
 
533
 
534
- ### Gym WebSocket API (RL/IL)
535
-
536
- The gym-style API is exposed at `ws://localhost:3004/nova-sim/api/v1/gym/ws`.
537
- It supports `reset`, `step`, `configure`, and `get_spaces`.
538
 
539
- Example request payloads:
540
 
541
- ```json
542
- {"type": "configure", "data": {"robot": "ur5_t_push"}}
543
- {"type": "reset"}
544
- {"type": "step", "data": {"action": [0,0,0,0,0,0,0], "render": false}}
545
- ```
546
 
547
- The server responds with `gym_reset`, `gym_step`, `gym_spaces`, or `gym_configured` messages.
548
 
549
  ### HTTP Endpoints
550
 
@@ -803,8 +831,7 @@ The Nova API integration is implemented in:
803
  2. Keep it running at `http://localhost:3004` so the HTTP/websocket endpoints stay reachable.
804
  3. Run `pytest nova-sim/tests` to exercise:
805
  - API endpoints (`/metadata`, `/camera/<name>/video_feed`, `/video_feed`)
806
- - WebSocket control (`/ws`)
807
- - Gym-style websocket (`/gym/ws`)
808
  - Auxiliary MJPEG overlays after switching to the T-push UR5 scene
809
 
810
  The tests assume the server is accessible via `http://localhost:3004/nova-sim/api/v1` and will skip automatically if the API is unreachable.
 
115
  cd Quadruped-PyMPC
116
  pip install -e .
117
 
118
+ # Start the server (default reward threshold: -0.1)
119
  python mujoco_server.py
120
 
121
+ # Or with custom reward threshold for auto episode termination
122
+ python mujoco_server.py --reward-threshold -0.05 # Stricter (5cm from target)
123
+ python mujoco_server.py --reward-threshold -0.2 # Lenient (20cm from target)
124
+
125
  # Open browser at http://localhost:3004/nova-sim/api/v1
126
  ```
127
 
128
+ **Reward Threshold**: Episodes automatically terminate when the robot reaches within the specified distance of the target. See [REWARD_THRESHOLD.md](REWARD_THRESHOLD.md) for details.
129
+
130
+ ```
131
+
132
  ## Docker Deployment
133
 
134
  ### Getting Started
 
389
 
390
  // Receive messages
391
  ws.onmessage = (event) => {
392
+ const msg = JSON.parse(event.data);
393
+ if (msg.type === 'state') {
394
+ console.log(msg.data);
395
+ }
396
+ };
397
+ ```
398
+
399
+ Nova-Sim uses `/ws` as the shared control channel for the browser UI, trainers, and any RL clients. Every UI interaction (teleop, camera controls, robot switching) and the trainer handshake/notifications flows through this single socket; the UI `state` messages shown below now also carry the action deltas, integrated reward, and trainer connection status that RL agents need. The legacy `/nova-sim/api/v1/gym/ws` endpoint remains for backwards compatibility, but new clients should talk to `/ws` instead.
400
 
401
  #### Client → Server Messages
402
 
 
407
  | `command` | `{vx, vy, vyaw}` | Set velocity command |
408
  | `reset` | `{}` | Reset robot to standing pose |
409
  | `switch_robot` | `{robot, scene?}` | Switch active robot and optional scene |
410
+ | `configure` | `{robot, scene?}` | Reconfigure the active robot/scene (reuses the UI configure flow) |
411
  | `camera` | `{action, ...}` | Camera control |
412
  | `camera_follow` | `{follow}` | Toggle camera follow mode |
413
  | `teleop_command` | `{dx, dy, dz}` | Apply incremental cartesian jog command (UI teleop) |
414
 
415
+ RL agents, trainers, and CLI tools connect to `/ws` as well by sending gym-style control packets that include a request `id`. The server replies with `gym_*` responses so your existing clients keep working while also benefiting from the richer state stream.
416
+
417
+ **Gym-style RL/RL control messages** (use `id` and expect a `gym_*` response):
418
+ ```
419
+ {"type": "reset", "id": 42}
420
+ {"type": "step", "data": {"action": [0,0,0,0,0,0,0], "render": false}, "id": 43}
421
+ {"type": "configure", "data": {"robot": "ur5_t_push"}, "id": 44}
422
+ {"type": "get_spaces", "id": 45}
423
+ {"type": "close", "id": 46}
424
+ ```
425
+ The server replies with `gym_reset`, `gym_step`, `gym_configured`, `gym_spaces`, or `gym_closed` (matching the `id` you supplied). Every message still flows over `/ws`, so UI teleop commands, state updates, and gym responses share the same connection.
426
+
427
  **`command`:**
428
  ```json
429
  {"type": "command", "data": {"vx": 0.5, "vy": 0.0, "vyaw": 0.0}}
 
516
  "base_height": 0.46,
517
  "upright": 0.98,
518
  "steps": 1234,
519
+ "command": {"vx": 0.5, "vy": 0.0, "vyaw": 0.0},
520
+ "reward": 0.0,
521
+ "teleop_command": {"dx": 0.05, "dy": 0.0, "dz": 0.0},
522
+ "trainer_connected": true
523
  }
524
  }
525
  ```
 
539
  "control_mode": "ik",
540
  "use_orientation": true,
541
  "steps": 1234,
542
+ "reward": -0.25,
543
+ "command": {"vx": 0.1, "vy": 0.0, "vyaw": 0.0},
544
+ "teleop_command": {"dx": 0.02, "dy": 0.0, "dz": 0.0},
545
+ "trainer_connected": true,
546
  "nova_api": {
547
  "connected": true,
548
  "state_streaming": true,
 
558
  - `connected`: Whether Nova API client is connected
559
  - `state_streaming`: Whether using Nova API for robot state streaming (vs. internal)
560
  - `ik`: Whether using Nova API for inverse kinematics (vs. internal)
561
+ - `command`: The latest velocity command (`vx`, `vy`, `vyaw`) that drives locomotion or arm movement.
562
+ - `teleop_command`: The most recent UI teleop delta (`dx`, `dy`, `dz`) so trainers know how the UI nudged the robot.
563
+ - `reward`: The integrated task reward from the simulator that remote trainers can consume.
564
+ - `trainer_connected`: Whether a trainer handshake is active on `/ws` (useful for status LEDs).
565
 
566
+ ### State broadcasts and trainer notifications
 
 
 
567
 
568
+ Every `/ws` client receives a `state` message roughly every 100 ms. The examples above show the locomotion (`spot`) and arm (`ur5`) payloads; the payload also now includes:
569
 
570
+ - `command`: The last velocity command that drives locomotion or arm motion (`vx`, `vy`, `vyaw`).
571
+ - `teleop_command`: The latest UI teleop delta (`dx`, `dy`, `dz`) so trainers know how the browser nudged the robot.
572
+ - `reward`: The integrated task reward that trainers can consume without sending a separate `step`.
573
+ - `trainer_connected`: Whether a trainer handshake is active on `/ws` (used to update the UI indicator).
 
574
 
575
+ Trainers announce themselves by sending a `trainer_identity` payload when the socket opens. The server mirrors that information into the `trainer_status` broadcasts (`trainer_status` messages flow to every UI client) and lets trainers emit `notification` payloads that the UI receives as `trainer_notification` events.
576
 
577
  ### HTTP Endpoints
578
 
 
831
  2. Keep it running at `http://localhost:3004` so the HTTP/websocket endpoints stay reachable.
832
  3. Run `pytest nova-sim/tests` to exercise:
833
  - API endpoints (`/metadata`, `/camera/<name>/video_feed`, `/video_feed`)
834
+ - Unified WebSocket control (`/ws`)
 
835
  - Auxiliary MJPEG overlays after switching to the T-push UR5 scene
836
 
837
  The tests assume the server is accessible via `http://localhost:3004/nova-sim/api/v1` and will skip automatically if the API is unreachable.
mujoco_server.py CHANGED
@@ -9,7 +9,7 @@ import cv2
9
  import numpy as np
10
  import mujoco
11
  from pathlib import Path
12
- from typing import Any
13
  from flask import Flask, Response, render_template_string, request, jsonify
14
  from flask_sock import Sock
15
 
@@ -98,6 +98,10 @@ episode_control_state = {
98
  }
99
  episode_control_lock = threading.Lock()
100
 
 
 
 
 
101
  # WebSocket clients
102
  ws_clients = set()
103
  ws_clients_lock = threading.Lock()
@@ -105,6 +109,8 @@ trainer_ws_clients = set()
105
  trainer_ws_clients_lock = threading.Lock()
106
  trainer_client_metadata: dict = {}
107
  trainer_client_metadata_lock = threading.Lock()
 
 
108
 
109
  # Camera state for orbit controls
110
  cam = mujoco.MjvCamera()
@@ -199,7 +205,6 @@ UR5_T_PUSH_OVERLAY_PRESETS = [
199
 
200
  OVERLAY_CAMERA_PRESETS = {
201
  "ur5_t_push": UR5_T_PUSH_OVERLAY_PRESETS,
202
- "scene_t_push": UR5_T_PUSH_OVERLAY_PRESETS,
203
  }
204
 
205
  CAMERA_FEEDS = [
@@ -484,6 +489,16 @@ def broadcast_state():
484
  obs = env._get_obs()
485
  cmd = env.get_command()
486
  steps = env.steps
 
 
 
 
 
 
 
 
 
 
487
 
488
  with trainer_ws_clients_lock:
489
  trainer_connected = len(trainer_ws_clients) > 0
@@ -533,7 +548,9 @@ def broadcast_state():
533
  'control_mode': control_mode,
534
  'use_orientation': use_orientation,
535
  'steps': int(steps),
536
- 'reward': env.get_task_reward(),
 
 
537
  'nova_api': {
538
  'connected': nova_connected,
539
  'state_streaming': nova_state_streaming,
@@ -557,9 +574,12 @@ def broadcast_state():
557
  'base_height': base_height,
558
  'upright': upright,
559
  'steps': int(steps),
 
 
560
  'vx': float(cmd[0]),
561
  'vy': float(cmd[1]),
562
  'vyaw': float(cmd[2]),
 
563
  'trainer_connected': trainer_connected
564
  }
565
  })
@@ -576,6 +596,25 @@ def broadcast_state():
576
  ws_clients.difference_update(dead_clients)
577
 
578
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
579
  def _build_trainer_status_payload():
580
  """Build a summary payload describing connected trainer clients."""
581
  with trainer_ws_clients_lock:
@@ -650,6 +689,106 @@ def broadcast_notification_to_ui(payload: dict):
650
  ws_clients.difference_update(dead_clients)
651
 
652
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
653
  def _signal_episode_control(action: str):
654
  """Set episode control flags and notify trainer/UI clients."""
655
  action = (action or "").lower()
@@ -828,12 +967,19 @@ def generate_overlay_frames(name: str):
828
  time.sleep(0.04)
829
 
830
 
831
- def handle_ws_message(data):
832
  """Handle incoming WebSocket message."""
833
- global needs_robot_switch, camera_follow
834
 
835
  msg_type = data.get('type')
836
 
 
 
 
 
 
 
 
837
  if msg_type == 'command':
838
  payload = data.get('data', {})
839
  vx = payload.get('vx', 0.0)
@@ -868,6 +1014,14 @@ def handle_ws_message(data):
868
  print(f"Robot switch requested: {robot} / scene: {scene}")
869
  needs_robot_switch = {"robot": robot, "scene": scene}
870
 
 
 
 
 
 
 
 
 
871
  elif msg_type == 'camera':
872
  payload = data.get('data', {})
873
  a = cam.azimuth * np.pi / 180.0
@@ -913,6 +1067,17 @@ def handle_ws_message(data):
913
  else:
914
  updated_target = None
915
 
 
 
 
 
 
 
 
 
 
 
 
916
  broadcast_to_trainer(
917
  "teleop_command",
918
  {
@@ -1057,6 +1222,7 @@ def _create_env(robot, scene):
1057
  sys.path.insert(0, ur5_dir)
1058
  from ur5_env import UR5Env
1059
  sys.path.pop(0)
 
1060
  if scene:
1061
  return UR5Env(render_mode="rgb_array", width=RENDER_WIDTH, height=RENDER_HEIGHT, scene_name=scene)
1062
  return UR5Env(render_mode="rgb_array", width=RENDER_WIDTH, height=RENDER_HEIGHT)
@@ -1134,7 +1300,7 @@ def websocket_handler(ws):
1134
  break
1135
  try:
1136
  data = json.loads(message)
1137
- handle_ws_message(data)
1138
  except json.JSONDecodeError:
1139
  print(f"Invalid JSON received: {message}")
1140
  except Exception as e:
@@ -1145,53 +1311,17 @@ def websocket_handler(ws):
1145
  # Unregister client
1146
  with ws_clients_lock:
1147
  ws_clients.discard(ws)
1148
- print('WebSocket client disconnected')
1149
-
1150
-
1151
- @sock.route(f'{API_PREFIX}/trainer/ws')
1152
- def trainer_websocket_handler(ws):
1153
- """Handle trainer WebSocket clients (teleop commands + notifications)."""
1154
- print('Trainer client connected')
1155
- with trainer_ws_clients_lock:
1156
- trainer_ws_clients.add(ws)
1157
- _register_trainer_client(ws)
1158
- broadcast_trainer_connection_status()
1159
- broadcast_state()
1160
-
1161
- try:
1162
- while True:
1163
- message = ws.receive()
1164
- if message is None:
1165
- break
1166
- try:
1167
- data = json.loads(message)
1168
- _handle_trainer_message(ws, data)
1169
- except json.JSONDecodeError:
1170
- print(f"Trainer sent invalid JSON: {message}")
1171
- except Exception as exc:
1172
- print(f"Error handling trainer message: {exc}")
1173
- finally:
1174
  with trainer_ws_clients_lock:
 
1175
  trainer_ws_clients.discard(ws)
1176
- _unregister_trainer_client(ws)
1177
- print('Trainer client disconnected')
1178
- broadcast_state()
1179
- broadcast_trainer_connection_status()
 
 
1180
 
1181
 
1182
- def _handle_trainer_message(ws, data):
1183
- """Process messages sent from the training client."""
1184
- msg_type = data.get("type")
1185
- if msg_type == "trainer_identity":
1186
- payload = data.get("data", {}) or {}
1187
- identity = payload.get("trainer_id") or payload.get("trainer_name") or payload.get("name") or "trainer"
1188
- _set_trainer_identity(ws, identity)
1189
- broadcast_trainer_connection_status()
1190
- return
1191
- if msg_type == "notification":
1192
- payload = data.get("data", {})
1193
- payload.setdefault("timestamp", time.time())
1194
- broadcast_notification_to_ui(payload)
1195
 
1196
 
1197
  @sock.route(f'{API_PREFIX}/gym/ws')
@@ -1227,6 +1357,7 @@ def gym_websocket_handler(ws):
1227
  elif msg_type == "step":
1228
  action = payload.get("action", [])
1229
  obs, reward, terminated, truncated, info = session.step(action)
 
1230
  response = {
1231
  "type": "gym_step",
1232
  "data": {
@@ -1801,12 +1932,6 @@ def index():
1801
  <span>Z nudge</span>
1802
  </span>
1803
  </li>
1804
- <li>
1805
- <span class="hint-key">
1806
- <kbd>Enter</kbd>
1807
- <span>End episode</span>
1808
- </span>
1809
- </li>
1810
  </ul>
1811
  </div>
1812
  <div class="robot-info" id="robot_info">
 
9
  import numpy as np
10
  import mujoco
11
  from pathlib import Path
12
+ from typing import Any, Optional
13
  from flask import Flask, Response, render_template_string, request, jsonify
14
  from flask_sock import Sock
15
 
 
98
  }
99
  episode_control_lock = threading.Lock()
100
 
101
+ # Latest teleoperation command (for trainer state)
102
+ last_teleop_command: Optional[dict[str, Any]] = None
103
+ teleop_lock = threading.Lock()
104
+
105
  # WebSocket clients
106
  ws_clients = set()
107
  ws_clients_lock = threading.Lock()
 
109
  trainer_ws_clients_lock = threading.Lock()
110
  trainer_client_metadata: dict = {}
111
  trainer_client_metadata_lock = threading.Lock()
112
+ gym_sessions: dict = {}
113
+ gym_sessions_lock = threading.Lock()
114
 
115
  # Camera state for orbit controls
116
  cam = mujoco.MjvCamera()
 
205
 
206
  OVERLAY_CAMERA_PRESETS = {
207
  "ur5_t_push": UR5_T_PUSH_OVERLAY_PRESETS,
 
208
  }
209
 
210
  CAMERA_FEEDS = [
 
489
  obs = env._get_obs()
490
  cmd = env.get_command()
491
  steps = env.steps
492
+ command = {
493
+ "vx": float(cmd[0]) if len(cmd) > 0 else 0.0,
494
+ "vy": float(cmd[1]) if len(cmd) > 1 else 0.0,
495
+ "vyaw": float(cmd[2]) if len(cmd) > 2 else 0.0,
496
+ }
497
+ with teleop_lock:
498
+ teleop_snapshot = last_teleop_command.copy() if last_teleop_command else None
499
+ reward_value = None
500
+ if hasattr(env, "get_task_reward"):
501
+ reward_value = env.get_task_reward()
502
 
503
  with trainer_ws_clients_lock:
504
  trainer_connected = len(trainer_ws_clients) > 0
 
548
  'control_mode': control_mode,
549
  'use_orientation': use_orientation,
550
  'steps': int(steps),
551
+ 'reward': reward_value,
552
+ 'command': command,
553
+ 'teleop_command': teleop_snapshot,
554
  'nova_api': {
555
  'connected': nova_connected,
556
  'state_streaming': nova_state_streaming,
 
574
  'base_height': base_height,
575
  'upright': upright,
576
  'steps': int(steps),
577
+ 'command': command,
578
+ 'teleop_command': teleop_snapshot,
579
  'vx': float(cmd[0]),
580
  'vy': float(cmd[1]),
581
  'vyaw': float(cmd[2]),
582
+ 'reward': reward_value,
583
  'trainer_connected': trainer_connected
584
  }
585
  })
 
596
  ws_clients.difference_update(dead_clients)
597
 
598
 
599
+ def _handle_trainer_message(ws, data):
600
+ """Process message payloads originating from trainers."""
601
+ msg_type = data.get("type")
602
+ if msg_type == "trainer_identity":
603
+ payload = data.get("data", {}) or {}
604
+ identity = payload.get("trainer_id") or payload.get("trainer_name") or payload.get("name") or "trainer"
605
+ with trainer_ws_clients_lock:
606
+ trainer_ws_clients.add(ws)
607
+ _register_trainer_client(ws)
608
+ _set_trainer_identity(ws, identity)
609
+ broadcast_state()
610
+ broadcast_trainer_connection_status()
611
+ return
612
+ if msg_type == "notification":
613
+ payload = data.get("data", {})
614
+ payload.setdefault("timestamp", time.time())
615
+ broadcast_notification_to_ui(payload)
616
+
617
+
618
  def _build_trainer_status_payload():
619
  """Build a summary payload describing connected trainer clients."""
620
  with trainer_ws_clients_lock:
 
689
  ws_clients.difference_update(dead_clients)
690
 
691
 
692
+ def _safe_ws_send(ws, message: dict):
693
+ """Send JSON message over WebSocket without raising."""
694
+ try:
695
+ ws.send(json.dumps(message))
696
+ except Exception:
697
+ pass
698
+
699
+
700
+ def _get_or_create_gym_session(ws):
701
+ with gym_sessions_lock:
702
+ session = gym_sessions.get(ws)
703
+ if session is None:
704
+ session = GymSession()
705
+ gym_sessions[ws] = session
706
+ return session
707
+
708
+
709
+ def _remove_gym_session(ws):
710
+ with gym_sessions_lock:
711
+ session = gym_sessions.pop(ws, None)
712
+ if session:
713
+ session.close()
714
+
715
+
716
+ def _handle_gym_ws_message(ws, data):
717
+ """Handle gym-style control messages routed over `/ws`."""
718
+ msg_type = data.get("type")
719
+ if msg_type not in {"reset", "step", "configure", "get_spaces", "close"}:
720
+ return False
721
+ msg_id = data.get("id")
722
+ if msg_id is None:
723
+ return False
724
+
725
+ payload = data.get("data", {}) or {}
726
+ session = _get_or_create_gym_session(ws)
727
+
728
+ try:
729
+ if msg_type == "reset":
730
+ seed = payload.get("seed")
731
+ obs, info = session.reset(seed=seed)
732
+ response = {
733
+ "type": "gym_reset",
734
+ "data": {
735
+ "obs": obs.tolist(),
736
+ "info": _serialize_value(info),
737
+ },
738
+ }
739
+ elif msg_type == "step":
740
+ action = payload.get("action", [])
741
+ render = bool(payload.get("render", False))
742
+ obs, reward, terminated, truncated, info = session.step(action)
743
+ response = {
744
+ "type": "gym_step",
745
+ "data": {
746
+ "obs": obs.tolist(),
747
+ "reward": float(reward),
748
+ "terminated": bool(terminated),
749
+ "truncated": bool(truncated),
750
+ "info": _serialize_value(info),
751
+ },
752
+ }
753
+ if render:
754
+ frame_jpeg = session.render_jpeg()
755
+ if frame_jpeg:
756
+ response["data"]["frame_jpeg"] = frame_jpeg
757
+ elif msg_type == "configure":
758
+ robot = payload.get("robot", "ur5")
759
+ scene = payload.get("scene")
760
+ session.configure(robot, scene)
761
+ response = {
762
+ "type": "gym_configured",
763
+ "data": {"robot": session.robot, "scene": session.scene},
764
+ }
765
+ elif msg_type == "get_spaces":
766
+ response = {
767
+ "type": "gym_spaces",
768
+ "data": {
769
+ "action_space": _serialize_space(session.env.action_space),
770
+ "observation_space": _serialize_space(session.env.observation_space),
771
+ },
772
+ }
773
+ elif msg_type == "close":
774
+ response = {"type": "gym_closed"}
775
+ _remove_gym_session(ws)
776
+ else:
777
+ response = {
778
+ "type": "gym_error",
779
+ "message": f"Unknown message type: {msg_type}",
780
+ }
781
+ if msg_id is not None:
782
+ response["id"] = msg_id
783
+ _safe_ws_send(ws, response)
784
+ except Exception as exc:
785
+ error_response = {"type": "gym_error", "message": str(exc)}
786
+ if msg_id is not None:
787
+ error_response["id"] = msg_id
788
+ _safe_ws_send(ws, error_response)
789
+ return True
790
+
791
+
792
  def _signal_episode_control(action: str):
793
  """Set episode control flags and notify trainer/UI clients."""
794
  action = (action or "").lower()
 
967
  time.sleep(0.04)
968
 
969
 
970
+ def handle_ws_message(ws, data):
971
  """Handle incoming WebSocket message."""
972
+ global needs_robot_switch, camera_follow, last_teleop_command
973
 
974
  msg_type = data.get('type')
975
 
976
+ if _handle_gym_ws_message(ws, data):
977
+ return
978
+
979
+ if msg_type in ("trainer_identity", "notification"):
980
+ _handle_trainer_message(ws, data)
981
+ return
982
+
983
  if msg_type == 'command':
984
  payload = data.get('data', {})
985
  vx = payload.get('vx', 0.0)
 
1014
  print(f"Robot switch requested: {robot} / scene: {scene}")
1015
  needs_robot_switch = {"robot": robot, "scene": scene}
1016
 
1017
+ elif msg_type == 'configure':
1018
+ payload = data.get('data', {})
1019
+ robot = payload.get('robot')
1020
+ scene = payload.get('scene')
1021
+ if robot:
1022
+ print(f"Configure requested: {robot} / scene: {scene}")
1023
+ needs_robot_switch = {"robot": robot, "scene": scene}
1024
+
1025
  elif msg_type == 'camera':
1026
  payload = data.get('data', {})
1027
  a = cam.azimuth * np.pi / 180.0
 
1067
  else:
1068
  updated_target = None
1069
 
1070
+ with teleop_lock:
1071
+ last_teleop_command = {
1072
+ "dx": dx,
1073
+ "dy": dy,
1074
+ "dz": dz,
1075
+ "robot": current_robot,
1076
+ "scene": getattr(env, "scene_name", None) if env is not None else None,
1077
+ "target": updated_target.tolist() if updated_target is not None else None,
1078
+ "timestamp": timestamp,
1079
+ }
1080
+
1081
  broadcast_to_trainer(
1082
  "teleop_command",
1083
  {
 
1222
  sys.path.insert(0, ur5_dir)
1223
  from ur5_env import UR5Env
1224
  sys.path.pop(0)
1225
+ # scene is already resolved by _resolve_robot_scene (e.g., "scene_t_push")
1226
  if scene:
1227
  return UR5Env(render_mode="rgb_array", width=RENDER_WIDTH, height=RENDER_HEIGHT, scene_name=scene)
1228
  return UR5Env(render_mode="rgb_array", width=RENDER_WIDTH, height=RENDER_HEIGHT)
 
1300
  break
1301
  try:
1302
  data = json.loads(message)
1303
+ handle_ws_message(ws, data)
1304
  except json.JSONDecodeError:
1305
  print(f"Invalid JSON received: {message}")
1306
  except Exception as e:
 
1311
  # Unregister client
1312
  with ws_clients_lock:
1313
  ws_clients.discard(ws)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1314
  with trainer_ws_clients_lock:
1315
+ was_trainer = ws in trainer_ws_clients
1316
  trainer_ws_clients.discard(ws)
1317
+ if was_trainer:
1318
+ _unregister_trainer_client(ws)
1319
+ broadcast_state()
1320
+ broadcast_trainer_connection_status()
1321
+ _remove_gym_session(ws)
1322
+ print('WebSocket client disconnected')
1323
 
1324
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1325
 
1326
 
1327
  @sock.route(f'{API_PREFIX}/gym/ws')
 
1357
  elif msg_type == "step":
1358
  action = payload.get("action", [])
1359
  obs, reward, terminated, truncated, info = session.step(action)
1360
+ print(f"[GYM WS] step reward={reward:.4f}, terminated={terminated}, truncated={truncated}", flush=True)
1361
  response = {
1362
  "type": "gym_step",
1363
  "data": {
 
1932
  <span>Z nudge</span>
1933
  </span>
1934
  </li>
 
 
 
 
 
 
1935
  </ul>
1936
  </div>
1937
  <div class="robot-info" id="robot_info">
robots/ur5/ur5_env.py CHANGED
@@ -800,14 +800,22 @@ class UR5Env(gym.Env):
800
 
801
  observation = self._get_obs()
802
 
803
- # Reward: distance to target
804
- ee_pos = self.get_end_effector_pos()
805
- dist = np.linalg.norm(ee_pos - self._target_pos)
806
- reward = -dist
 
 
 
 
 
807
 
808
  terminated = False
809
  truncated = self.steps >= self.max_steps
810
 
 
 
 
811
  info = {
812
  "ee_pos": ee_pos,
813
  "target_pos": self._target_pos,
 
800
 
801
  observation = self._get_obs()
802
 
803
+ # Reward: Use task reward if available, otherwise distance to target
804
+ task_reward = self.get_task_reward()
805
+ if task_reward is not None:
806
+ reward = task_reward
807
+ else:
808
+ # Fallback: distance to target (for non-task scenes)
809
+ ee_pos = self.get_end_effector_pos()
810
+ dist = np.linalg.norm(ee_pos - self._target_pos)
811
+ reward = -dist
812
 
813
  terminated = False
814
  truncated = self.steps >= self.max_steps
815
 
816
+ # Compute info with distance for debugging
817
+ ee_pos = self.get_end_effector_pos()
818
+ dist = np.linalg.norm(ee_pos - self._target_pos)
819
  info = {
820
  "ee_pos": ee_pos,
821
  "target_pos": self._target_pos,
tests/test_api.py CHANGED
@@ -28,8 +28,8 @@ def test_overlay_camera_presets(api_base: str):
28
  resp = requests.get(f"{api_base}/metadata", timeout=5)
29
  data = resp.json()
30
  presets = data.get('overlay_camera_presets', {})
31
- assert 'ur5_t_push' in presets or 'scene_t_push' in presets
32
- target_presets = presets.get('ur5_t_push') or presets.get('scene_t_push') or []
33
  names = {item.get('name') for item in target_presets}
34
  assert {'aux_top', 'aux_side', 'aux_flange'}.issubset(names)
35
 
 
28
  resp = requests.get(f"{api_base}/metadata", timeout=5)
29
  data = resp.json()
30
  presets = data.get('overlay_camera_presets', {})
31
+ assert 'ur5_t_push' in presets
32
+ target_presets = presets.get('ur5_t_push', [])
33
  names = {item.get('name') for item in target_presets}
34
  assert {'aux_top', 'aux_side', 'aux_flange'}.issubset(names)
35
 
tests/test_episode_control_fix.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unit test for episode control flag mechanism.
3
+
4
+ This test verifies that episode control flags work correctly across multiple episodes
5
+ without the problematic env.reset() call.
6
+ """
7
+
8
+ import threading
9
+ import time
10
+
11
+
12
+ class MockEpisodeControlSystem:
13
+ """
14
+ Minimal reproduction of the episode control system from mujoco_server.py.
15
+
16
+ This tests the core flag-based signaling mechanism without needing the full
17
+ MuJoCo environment or WebSocket infrastructure.
18
+ """
19
+
20
+ def __init__(self):
21
+ self.episode_control_state = {
22
+ "terminate": False,
23
+ "truncate": False,
24
+ }
25
+ self.episode_control_lock = threading.Lock()
26
+
27
+ def signal_episode_control(self, action: str):
28
+ """
29
+ Simulates _signal_episode_control() from mujoco_server.py.
30
+
31
+ This is the FIXED version without the env.reset() call.
32
+ """
33
+ action = (action or "").lower()
34
+ if action not in ("terminate", "truncate"):
35
+ return
36
+
37
+ with self.episode_control_lock:
38
+ self.episode_control_state[action] = True
39
+
40
+ # NOTE: The bug was here - the original code called env.reset()
41
+ # The fix is to NOT reset the environment, as it interferes with
42
+ # the gym client's episode lifecycle management.
43
+
44
+ def consume_episode_control_flags(self):
45
+ """Simulates _consume_episode_control_flags() from mujoco_server.py."""
46
+ with self.episode_control_lock:
47
+ terminate = self.episode_control_state.get("terminate", False)
48
+ truncate = self.episode_control_state.get("truncate", False)
49
+ self.episode_control_state["terminate"] = False
50
+ self.episode_control_state["truncate"] = False
51
+ return terminate, truncate
52
+
53
+
54
+ def test_episode_control_multiple_episodes():
55
+ """
56
+ Test that episode control works correctly for multiple consecutive episodes.
57
+
58
+ This simulates the scenario where a user presses Enter in the Nova-Sim UI
59
+ during multiple episodes to terminate them early.
60
+ """
61
+ system = MockEpisodeControlSystem()
62
+
63
+ num_episodes = 5
64
+ steps_per_episode = 10
65
+ step_to_terminate = 3
66
+
67
+ print(f"Testing {num_episodes} episodes with episode control...")
68
+
69
+ all_episodes_succeeded = True
70
+
71
+ for ep in range(num_episodes):
72
+ print(f"\n=== Episode {ep + 1}/{num_episodes} ===")
73
+
74
+ # Simulate episode reset
75
+ episode_terminated = False
76
+
77
+ for step in range(steps_per_episode):
78
+ # Simulate user pressing Enter at a specific step
79
+ if step == step_to_terminate:
80
+ print(f" Step {step + 1}: User presses Enter (UI episode control)")
81
+ system.signal_episode_control("terminate")
82
+
83
+ # Simulate gym client calling step()
84
+ term, trunc = system.consume_episode_control_flags()
85
+
86
+ if term or trunc:
87
+ print(f" Step {step + 1}: Episode ended (terminated={term}, truncated={trunc})")
88
+ episode_terminated = True
89
+ break
90
+ else:
91
+ print(f" Step {step + 1}: Continuing...")
92
+
93
+ if not episode_terminated:
94
+ print(f" ✗ FAILED: Episode {ep + 1} did not terminate!")
95
+ all_episodes_succeeded = False
96
+ elif not episode_terminated:
97
+ print(f" ✗ FAILED: Episode {ep + 1} ended too early!")
98
+ all_episodes_succeeded = False
99
+ else:
100
+ print(f" ✓ SUCCESS: Episode {ep + 1} terminated correctly")
101
+
102
+ print("\n" + "=" * 50)
103
+ if all_episodes_succeeded:
104
+ print("✓ ALL TESTS PASSED")
105
+ print(f"Episode control worked correctly for all {num_episodes} episodes")
106
+ return True
107
+ else:
108
+ print("✗ SOME TESTS FAILED")
109
+ return False
110
+
111
+
112
+ def test_episode_control_threading():
113
+ """
114
+ Test that episode control is thread-safe.
115
+
116
+ This simulates the scenario where the UI thread signals episode control
117
+ while the gym client thread is consuming flags.
118
+ """
119
+ system = MockEpisodeControlSystem()
120
+
121
+ print("\nTesting thread safety...")
122
+
123
+ # Simulate concurrent UI and gym client activity
124
+ def ui_thread():
125
+ for i in range(10):
126
+ time.sleep(0.01)
127
+ system.signal_episode_control("terminate")
128
+
129
+ def gym_thread():
130
+ terminate_count = 0
131
+ for i in range(100):
132
+ time.sleep(0.001)
133
+ term, _ = system.consume_episode_control_flags()
134
+ if term:
135
+ terminate_count += 1
136
+ return terminate_count
137
+
138
+ ui = threading.Thread(target=ui_thread)
139
+ gym = threading.Thread(target=gym_thread)
140
+
141
+ ui.start()
142
+ gym.start()
143
+
144
+ ui.join()
145
+ gym.join()
146
+
147
+ print("✓ Thread safety test completed without crashes")
148
+ return True
149
+
150
+
151
+ if __name__ == "__main__":
152
+ success = True
153
+
154
+ print("=" * 50)
155
+ print("Episode Control Fix - Unit Tests")
156
+ print("=" * 50)
157
+
158
+ success = test_episode_control_multiple_episodes() and success
159
+ success = test_episode_control_threading() and success
160
+
161
+ print("\n" + "=" * 50)
162
+ if success:
163
+ print("✓ ALL UNIT TESTS PASSED")
164
+ exit(0)
165
+ else:
166
+ print("✗ SOME UNIT TESTS FAILED")
167
+ exit(1)