Spaces:
Sleeping
Sleeping
| import asyncio | |
| import json | |
| import logging | |
| from collections.abc import Callable | |
| from urllib.parse import urlparse | |
| import aiohttp | |
| import websockets | |
# Module-level logger used by the client classes below.
logger = logging.getLogger(__name__)
| class RoboticsClientCore: | |
| """Base client for LeRobot Arena robotics API""" | |
| def __init__(self, base_url: str = "http://localhost:8000"): | |
| self.base_url = base_url.rstrip("/") | |
| self.api_base = f"{self.base_url}/robotics" | |
| # WebSocket connection | |
| self.websocket: websockets.WebSocketServerProtocol | None = None | |
| self.workspace_id: str | None = None | |
| self.room_id: str | None = None | |
| self.role: str | None = None | |
| self.participant_id: str | None = None | |
| self.connected = False | |
| # Background task for message handling | |
| self._message_task: asyncio.Task | None = None | |
| # ============= REST API METHODS ============= | |
| async def list_rooms(self, workspace_id: str) -> list[dict]: | |
| """List all available rooms in a workspace""" | |
| async with aiohttp.ClientSession() as session: | |
| async with session.get( | |
| f"{self.api_base}/workspaces/{workspace_id}/rooms" | |
| ) as response: | |
| response.raise_for_status() | |
| result = await response.json() | |
| # Extract the rooms list from the response | |
| return result.get("rooms", []) | |
| async def create_room( | |
| self, workspace_id: str | None = None, room_id: str | None = None | |
| ) -> tuple[str, str]: | |
| """Create a new room and return (workspace_id, room_id)""" | |
| # Generate workspace ID if not provided | |
| final_workspace_id = workspace_id or self._generate_workspace_id() | |
| payload = {} | |
| if room_id: | |
| payload["room_id"] = room_id | |
| async with aiohttp.ClientSession() as session: | |
| async with session.post( | |
| f"{self.api_base}/workspaces/{final_workspace_id}/rooms", json=payload | |
| ) as response: | |
| response.raise_for_status() | |
| result = await response.json() | |
| return result["workspace_id"], result["room_id"] | |
| async def delete_room(self, workspace_id: str, room_id: str) -> bool: | |
| """Delete a room""" | |
| async with aiohttp.ClientSession() as session: | |
| async with session.delete( | |
| f"{self.api_base}/workspaces/{workspace_id}/rooms/{room_id}" | |
| ) as response: | |
| if response.status == 404: | |
| return False | |
| response.raise_for_status() | |
| result = await response.json() | |
| return result["success"] | |
| async def get_room_state(self, workspace_id: str, room_id: str) -> dict: | |
| """Get current room state""" | |
| async with aiohttp.ClientSession() as session: | |
| async with session.get( | |
| f"{self.api_base}/workspaces/{workspace_id}/rooms/{room_id}/state" | |
| ) as response: | |
| response.raise_for_status() | |
| result = await response.json() | |
| # Extract the state from the response | |
| return result.get("state", {}) | |
| async def get_room_info(self, workspace_id: str, room_id: str) -> dict: | |
| """Get basic room information""" | |
| async with aiohttp.ClientSession() as session: | |
| async with session.get( | |
| f"{self.api_base}/workspaces/{workspace_id}/rooms/{room_id}" | |
| ) as response: | |
| response.raise_for_status() | |
| result = await response.json() | |
| # Extract the room data from the response | |
| return result.get("room", {}) | |
| # ============= WEBSOCKET CONNECTION ============= | |
| async def connect_to_room( | |
| self, | |
| workspace_id: str, | |
| room_id: str, | |
| role: str, | |
| participant_id: str | None = None, | |
| ) -> bool: | |
| """Connect to a room as producer or consumer""" | |
| if self.connected: | |
| await self.disconnect() | |
| self.workspace_id = workspace_id | |
| self.room_id = room_id | |
| self.role = role | |
| self.participant_id = participant_id or f"{role}_{id(self)}" | |
| # Convert HTTP URL to WebSocket URL | |
| parsed = urlparse(self.base_url) | |
| ws_scheme = "wss" if parsed.scheme == "https" else "ws" | |
| ws_url = f"{ws_scheme}://{parsed.netloc}/robotics/workspaces/{workspace_id}/rooms/{room_id}/ws" | |
| initial_state_sync = None | |
| try: | |
| self.websocket = await websockets.connect(ws_url) | |
| # Send join message | |
| join_message = {"participant_id": self.participant_id, "role": role} | |
| await self.websocket.send(json.dumps(join_message)) | |
| # Wait for server response to join message | |
| try: | |
| response_text = await asyncio.wait_for( | |
| self.websocket.recv(), timeout=5.0 | |
| ) | |
| response = json.loads(response_text) | |
| if response.get("type") == "error": | |
| logger.error( | |
| f"Server rejected connection: {response.get('message')}" | |
| ) | |
| await self.websocket.close() | |
| return False | |
| if response.get("type") == "state_sync": | |
| # Consumer receives initial state sync, store it and wait for joined message | |
| logger.debug("Received initial state sync") | |
| initial_state_sync = response | |
| # Wait for the joined message | |
| response_text = await asyncio.wait_for( | |
| self.websocket.recv(), timeout=5.0 | |
| ) | |
| response = json.loads(response_text) | |
| if response.get("type") == "joined": | |
| logger.info(f"Successfully joined room {room_id} as {role}") | |
| elif response.get("type") == "error": | |
| logger.error( | |
| f"Server rejected connection: {response.get('message')}" | |
| ) | |
| await self.websocket.close() | |
| return False | |
| else: | |
| logger.warning(f"Unexpected response from server: {response}") | |
| elif response.get("type") == "joined": | |
| logger.info(f"Successfully joined room {room_id} as {role}") | |
| # Connection successful, continue with setup | |
| else: | |
| logger.warning(f"Unexpected response from server: {response}") | |
| except TimeoutError: | |
| logger.error("Timeout waiting for server response") | |
| await self.websocket.close() | |
| return False | |
| except json.JSONDecodeError: | |
| logger.error("Invalid JSON response from server") | |
| await self.websocket.close() | |
| return False | |
| # Start message handling task | |
| self._message_task = asyncio.create_task(self._handle_messages()) | |
| self.connected = True | |
| logger.info(f"Connected to room {room_id} as {role}") | |
| await self._on_connected() | |
| # Process initial state sync if we received one | |
| if initial_state_sync: | |
| await self._process_message(initial_state_sync) | |
| return True | |
| except Exception as e: | |
| logger.error(f"Failed to connect to room {room_id}: {e}") | |
| return False | |
| async def disconnect(self): | |
| """Disconnect from current room""" | |
| if self._message_task: | |
| self._message_task.cancel() | |
| try: | |
| await self._message_task | |
| except asyncio.CancelledError: | |
| pass | |
| self._message_task = None | |
| if self.websocket: | |
| await self.websocket.close() | |
| self.websocket = None | |
| self.connected = False | |
| self.workspace_id = None | |
| self.room_id = None | |
| self.role = None | |
| self.participant_id = None | |
| await self._on_disconnected() | |
| logger.info("Disconnected from room") | |
| # ============= MESSAGE HANDLING ============= | |
| async def _handle_messages(self): | |
| """Handle incoming WebSocket messages""" | |
| try: | |
| async for message in self.websocket: | |
| try: | |
| data = json.loads(message) | |
| await self._process_message(data) | |
| except json.JSONDecodeError: | |
| logger.error(f"Invalid JSON received: {message}") | |
| except Exception as e: | |
| logger.error(f"Error processing message: {e}") | |
| except websockets.exceptions.ConnectionClosed: | |
| logger.info("WebSocket connection closed") | |
| except Exception as e: | |
| logger.error(f"WebSocket error: {e}") | |
| finally: | |
| self.connected = False | |
| await self._on_disconnected() | |
| async def _process_message(self, data: dict): | |
| """Process incoming message based on type - to be overridden by subclasses""" | |
| msg_type = data.get("type") | |
| if msg_type == "joined": | |
| logger.info( | |
| f"Successfully joined room {data.get('room_id')} as {data.get('role')}" | |
| ) | |
| elif msg_type == "heartbeat_ack": | |
| logger.debug("Heartbeat acknowledged") | |
| else: | |
| # Let subclasses handle specific message types | |
| await self._handle_role_specific_message(data) | |
| async def _handle_role_specific_message(self, data: dict): | |
| """Handle role-specific messages - to be overridden by subclasses""" | |
| # ============= UTILITY METHODS ============= | |
| async def send_heartbeat(self): | |
| """Send heartbeat to server""" | |
| if not self.connected: | |
| return | |
| message = {"type": "heartbeat"} | |
| await self.websocket.send(json.dumps(message)) | |
| def is_connected(self) -> bool: | |
| """Check if client is connected""" | |
| return self.connected | |
| def get_connection_info(self) -> dict: | |
| """Get current connection information""" | |
| return { | |
| "connected": self.connected, | |
| "workspace_id": self.workspace_id, | |
| "room_id": self.room_id, | |
| "role": self.role, | |
| "participant_id": self.participant_id, | |
| "base_url": self.base_url, | |
| } | |
| # ============= HOOKS FOR SUBCLASSES ============= | |
| async def _on_connected(self): | |
| """Called when connection is established - to be overridden by subclasses""" | |
| async def _on_disconnected(self): | |
| """Called when connection is lost - to be overridden by subclasses""" | |
| # ============= CONTEXT MANAGER SUPPORT ============= | |
| async def __aenter__(self): | |
| return self | |
| async def __aexit__(self, exc_type, exc_val, exc_tb): | |
| await self.disconnect() | |
| # ============= WORKSPACE HELPERS ============= | |
| def _generate_workspace_id(self) -> str: | |
| """Generate a UUID-like workspace ID""" | |
| import uuid | |
| return str(uuid.uuid4()) | |
class RoboticsProducer(RoboticsClientCore):
    """Producer client: sends joint commands and emergency stops to a room.

    Event callbacks are plain (synchronous) callables registered through the
    ``on_*`` methods; registering a new one replaces the previous one.
    """

    def __init__(self, base_url: str = "http://localhost:8000"):
        super().__init__(base_url)
        # Registered event callbacks; None means "not registered".
        self._on_error_callback: Callable[[str], None] | None = None
        self._on_connected_callback: Callable[[], None] | None = None
        self._on_disconnected_callback: Callable[[], None] | None = None

    async def connect(
        self, workspace_id: str, room_id: str, participant_id: str | None = None
    ) -> bool:
        """Join ``room_id`` in ``workspace_id`` with the ``"producer"`` role.

        Returns whatever :meth:`connect_to_room` returns.
        """
        return await self.connect_to_room(
            workspace_id, room_id, role="producer", participant_id=participant_id
        )

    # ============= PRODUCER METHODS =============

    async def _send_payload(self, payload: dict, not_connected_msg: str) -> None:
        """JSON-encode ``payload`` and send it, or raise ValueError if offline."""
        if not self.connected:
            raise ValueError(not_connected_msg)
        await self.websocket.send(json.dumps(payload))

    async def send_joint_update(self, joints: list[dict]):
        """Send a ``joint_update`` message whose ``data`` is ``joints``.

        Raises:
            ValueError: if the client is not connected.
        """
        await self._send_payload(
            {"type": "joint_update", "data": joints},
            "Must be connected to send joint updates",
        )

    async def send_state_sync(self, state: dict):
        """Send a ``{name: value}`` mapping as a joint update.

        Each mapping item becomes ``{"name": name, "value": value}``.
        """
        await self.send_joint_update(
            [{"name": joint, "value": position} for joint, position in state.items()]
        )

    async def send_emergency_stop(self, reason: str = "Emergency stop"):
        """Send an ``emergency_stop`` message carrying ``reason``.

        Raises:
            ValueError: if the client is not connected.
        """
        await self._send_payload(
            {"type": "emergency_stop", "reason": reason},
            "Must be connected to send emergency stop",
        )

    # ============= EVENT CALLBACKS =============

    def on_error(self, callback: Callable[[str], None]):
        """Register the callback invoked with an error description."""
        self._on_error_callback = callback

    def on_connected(self, callback: Callable[[], None]):
        """Register the callback invoked once the room connection is up."""
        self._on_connected_callback = callback

    def on_disconnected(self, callback: Callable[[], None]):
        """Register the callback invoked when the room connection ends."""
        self._on_disconnected_callback = callback

    # ============= OVERRIDDEN HOOKS =============

    async def _on_connected(self):
        callback = self._on_connected_callback
        if callback:
            callback()

    async def _on_disconnected(self):
        callback = self._on_disconnected_callback
        if callback:
            callback()

    async def _handle_role_specific_message(self, data: dict):
        """React to emergency stops and server errors; log anything else."""
        msg_type = data.get("type")

        if msg_type == "emergency_stop":
            reason = data.get("reason", "Unknown reason")
            logger.warning(f"🚨 Emergency stop: {reason}")
            if self._on_error_callback:
                self._on_error_callback(f"Emergency stop: {reason}")
            return

        if msg_type == "error":
            error_msg = data.get("message", "Unknown error")
            logger.error(f"Server error: {error_msg}")
            if self._on_error_callback:
                self._on_error_callback(error_msg)
            return

        logger.warning(f"Unknown message type for producer: {msg_type}")
| class RoboticsConsumer(RoboticsClientCore): | |
| """Consumer client for receiving robot commands""" | |
| def __init__(self, base_url: str = "http://localhost:8000"): | |
| super().__init__(base_url) | |
| self._on_state_sync_callback: Callable[[dict], None] | None = None | |
| self._on_joint_update_callback: Callable[[list], None] | None = None | |
| self._on_error_callback: Callable[[str], None] | None = None | |
| self._on_connected_callback: Callable[[], None] | None = None | |
| self._on_disconnected_callback: Callable[[], None] | None = None | |
| async def connect( | |
| self, workspace_id: str, room_id: str, participant_id: str | None = None | |
| ) -> bool: | |
| """Connect as consumer to a room""" | |
| return await self.connect_to_room( | |
| workspace_id, room_id, "consumer", participant_id | |
| ) | |
| # ============= CONSUMER METHODS ============= | |
| async def get_state_sync(self) -> dict: | |
| """Get current state synchronously""" | |
| if not self.workspace_id or not self.room_id: | |
| raise ValueError("Must be connected to a room") | |
| state = await self.get_room_state(self.workspace_id, self.room_id) | |
| return state.get("joints", {}) | |
| # ============= EVENT CALLBACKS ============= | |
| def on_state_sync(self, callback: Callable[[dict], None]): | |
| """Set callback for state synchronization events""" | |
| self._on_state_sync_callback = callback | |
| def on_joint_update(self, callback: Callable[[list], None]): | |
| """Set callback for joint update events""" | |
| self._on_joint_update_callback = callback | |
| def on_error(self, callback: Callable[[str], None]): | |
| """Set callback for error events""" | |
| self._on_error_callback = callback | |
| def on_connected(self, callback: Callable[[], None]): | |
| """Set callback for connection events""" | |
| self._on_connected_callback = callback | |
| def on_disconnected(self, callback: Callable[[], None]): | |
| """Set callback for disconnection events""" | |
| self._on_disconnected_callback = callback | |
| # ============= OVERRIDDEN HOOKS ============= | |
| async def _on_connected(self): | |
| if self._on_connected_callback: | |
| self._on_connected_callback() | |
| async def _on_disconnected(self): | |
| if self._on_disconnected_callback: | |
| self._on_disconnected_callback() | |
| async def _handle_role_specific_message(self, data: dict): | |
| """Handle consumer-specific messages""" | |
| msg_type = data.get("type") | |
| if msg_type == "state_sync": | |
| if self._on_state_sync_callback: | |
| self._on_state_sync_callback(data.get("data", {})) | |
| elif msg_type == "joint_update": | |
| if self._on_joint_update_callback: | |
| self._on_joint_update_callback(data.get("data", [])) | |
| elif msg_type == "emergency_stop": | |
| logger.warning(f"🚨 Emergency stop: {data.get('reason', 'Unknown reason')}") | |
| if self._on_error_callback: | |
| self._on_error_callback( | |
| f"Emergency stop: {data.get('reason', 'Unknown reason')}" | |
| ) | |
| elif msg_type == "error": | |
| error_msg = data.get("message", "Unknown error") | |
| logger.error(f"Server error: {error_msg}") | |
| if self._on_error_callback: | |
| self._on_error_callback(error_msg) | |
| else: | |
| logger.warning(f"Unknown message type for consumer: {msg_type}") | |
| # ============= FACTORY FUNCTIONS ============= | |
| def create_client(role: str, base_url: str = "http://localhost:8000"): | |
| """Factory function to create the appropriate client based on role""" | |
| if role == "producer": | |
| return RoboticsProducer(base_url) | |
| if role == "consumer": | |
| return RoboticsConsumer(base_url) | |
| raise ValueError(f"Invalid role: {role}. Must be 'producer' or 'consumer'") | |
async def create_producer_client(
    base_url: str = "http://localhost:8000",
    workspace_id: str | None = None,
    room_id: str | None = None,
) -> RoboticsProducer:
    """Create a room and return a producer client joined to it.

    The room is created via :meth:`RoboticsClientCore.create_room` (which
    generates a workspace ID when none is given). The boolean result of
    ``connect()`` is not checked, so callers should use ``is_connected()``
    to confirm the connection succeeded.
    """
    producer = RoboticsProducer(base_url)
    created = await producer.create_room(workspace_id, room_id)
    await producer.connect(*created)
    return producer
async def create_consumer_client(
    workspace_id: str, room_id: str, base_url: str = "http://localhost:8000"
) -> RoboticsConsumer:
    """Return a consumer client joined to an existing room.

    The boolean result of ``connect()`` is not checked, so callers should use
    ``is_connected()`` to confirm the connection succeeded.
    """
    consumer = RoboticsConsumer(base_url)
    await consumer.connect(workspace_id=workspace_id, room_id=room_id)
    return consumer