nova-sim / tests /test_client.py
Georg
Implement blocking homing endpoint and refactor homing logic in mujoco_server.py
e9e2294
"""Lightweight WebSocket test client for nova-sim tests (no external dependencies)."""
import json
import time
import requests
from websockets.sync.client import connect
class NovaSimTestClient:
    """Simple WebSocket client for testing nova-sim.

    The client opens a WebSocket at ``<base_url>/ws``, identifies itself as
    ``test_client`` and caches the payload of the most recent message whose
    ``type`` is ``"state"`` in :attr:`latest_state`.  It can also call the
    HTTP homing endpoint under the same ``base_url``.

    It may be used as a context manager, which connects on entry (if not
    already connected) and always closes the socket on exit::

        with NovaSimTestClient() as client:
            client.send_message({...})
    """

    # Upper bound on frames consumed by one ``_receive_state`` call, so a
    # producer that outpaces this client cannot keep the drain loop spinning.
    _MAX_FRAMES_PER_POLL = 256

    def __init__(self, base_url: str = "http://localhost:3004/nova-sim/api/v1"):
        """Store the API base URL; no network activity happens here."""
        self.base_url = base_url
        self.ws = None  # WebSocket connection, set by connect()
        self.latest_state = {}  # payload of the newest "state" message seen

    def __enter__(self):
        """Connect (unless already connected) and return the client."""
        if self.ws is None:
            self.connect()
        return self

    def __exit__(self, exc_type, exc, tb):
        """Close the connection; exceptions from the body are not suppressed."""
        self.close()

    def connect(self):
        """Connect to nova-sim via WebSocket and read the initial state.

        Any connection left over from an earlier ``connect()`` call is closed
        first so that it is not leaked.
        """
        self.close()

        # The WebSocket endpoint lives under the same base URL as the HTTP API.
        ws_url = self.base_url.replace("http://", "ws://").replace("https://", "wss://") + "/ws"
        self.ws = connect(ws_url, open_timeout=10)

        # Send client identity
        self.ws.send(json.dumps({
            "type": "client_identity",
            "data": {"client_id": "test_client"},
        }))

        # Give the server a moment to push state before reading it.
        time.sleep(0.5)
        self._receive_state()

    def close(self):
        """Close the WebSocket connection (no-op when not connected)."""
        # Clear the handle first so the client is left in a consistent
        # "disconnected" state even if closing the socket raises.
        ws, self.ws = self.ws, None
        if ws is not None:
            ws.close()

    def _require_ws(self):
        """Return the open connection or raise ``RuntimeError`` if there is none."""
        if self.ws is None:
            raise RuntimeError("NovaSimTestClient is not connected; call connect() first")
        return self.ws

    def _apply_frame(self, raw):
        """Update ``latest_state`` from one raw frame if it is a state message.

        Frames that are empty, are not JSON objects, have a ``type`` other than
        ``"state"``, or carry a non-object ``data`` field are ignored.
        Malformed JSON still raises ``json.JSONDecodeError``.
        """
        if not raw:
            return
        message = json.loads(raw)
        if not isinstance(message, dict) or message.get("type") != "state":
            return
        payload = message.get("data", {})
        if isinstance(payload, dict):
            self.latest_state = payload

    def _receive_state(self):
        """Receive pending messages and keep the newest state.

        Waits up to 0.1 s for a first frame, then drains whatever is already
        buffered without waiting, so ``latest_state`` reflects the most recent
        state message rather than the oldest unread one.  Returns silently if
        nothing arrives in time.

        Raises:
            RuntimeError: if the client is not connected.
        """
        ws = self._require_ws()
        wait_s = 0.1
        for _ in range(self._MAX_FRAMES_PER_POLL):
            try:
                raw = ws.recv(timeout=wait_s)
            except TimeoutError:
                break
            # After the first frame only take what is already queued.
            wait_s = 0
            self._apply_frame(raw)

    def send_message(self, msg: dict):
        """Send ``msg`` as JSON and then refresh ``latest_state``.

        Raises:
            RuntimeError: if the client is not connected.
        """
        self._require_ws().send(json.dumps(msg))
        self._receive_state()

    def home_blocking(self, timeout_s: float = 30.0, tolerance: float = 0.01, poll_interval_s: float = 0.1):
        """Invoke the blocking homing endpoint.

        Issues ``GET <base_url>/homing`` with ``timeout_s``, ``tolerance`` and
        ``poll_interval_s`` as query parameters and returns the decoded JSON
        body.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        resp = requests.get(
            f"{self.base_url}/homing",
            params={
                "timeout_s": timeout_s,
                "tolerance": tolerance,
                "poll_interval_s": poll_interval_s,
            },
            # HTTP timeout is padded by 5 s beyond timeout_s (presumably so
            # the server-side wait can finish first); never below 5 s.
            timeout=max(5.0, timeout_s + 5.0),
        )
        resp.raise_for_status()
        return resp.json()

    def get_joint_positions(self):
        """Get current joint positions from latest state.

        Returns ``observation.joint_positions`` from ``latest_state``, or a
        list of six zeros when the observation or the positions are missing
        or null.
        """
        obs = self.latest_state.get("observation")
        if not isinstance(obs, dict):
            return [0, 0, 0, 0, 0, 0]
        positions = obs.get("joint_positions")
        if positions is None:
            return [0, 0, 0, 0, 0, 0]
        return positions

    def get_scene_objects(self):
        """Get scene objects from latest state (empty list when absent)."""
        return self.latest_state.get("scene_objects", [])