Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """Interactive cartpole control via OpenEnv. | |
| This example demonstrates using the dm_control OpenEnv client with | |
| the cartpole environment. Use arrow keys to control the cart. | |
| Controls: | |
| LEFT/RIGHT arrows: Apply force to move cart | |
| R: Reset environment | |
| ESC or Q: Quit | |
| Requirements: | |
| pip install pygame | |
| Usage: | |
| 1. Start the server: uvicorn server.app:app --host 0.0.0.0 --port 8000 | |
| 2. Run this script: python examples/cartpole_control.py | |
| For visual mode (requires working MuJoCo rendering): | |
| python examples/cartpole_control.py --visual | |
| """ | |
| import argparse | |
| import random | |
| import sys | |
| from pathlib import Path | |
| # Add parent directory to path for imports | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from client import DMControlEnv | |
| from models import DMControlAction | |
def run_headless(env: DMControlEnv, task: str = "balance", max_steps: int = 500):
    """Run cartpole control in headless mode with random actions.

    Demonstrates the OpenEnv reset/step/observation loop without any GUI.
    The cartpole episode is reset for ``task``. It is then stepped with
    uniformly random actions in [-1, 1] until the server reports ``done``
    or ``max_steps`` steps have been taken. Progress is printed every 50
    steps and a summary is printed at the end.

    Args:
        env: Connected dm_control OpenEnv client.
        task: Cartpole task name passed to ``env.reset``.
        max_steps: Upper bound on the number of environment steps.
    """
    print("\n=== Headless Mode (OpenEnv Step/Observation Pattern) ===")
    print("This mode demonstrates the OpenEnv API with the cartpole.\n")

    # Reset environment using OpenEnv pattern
    result = env.reset(domain_name="cartpole", task_name=task)
    print(f"Initial observations: {list(result.observation.observations.keys())}")
    print(f" position: {result.observation.observations.get('position', [])}")
    print(f" velocity: {result.observation.observations.get('velocity', [])}")

    total_reward = 0.0
    step_count = 0

    print("\nRunning with random actions to demonstrate step/observation pattern...\n")

    while not result.done and step_count < max_steps:
        # Random action in [-1, 1]
        action_value = random.uniform(-1.0, 1.0)

        # Step the environment using OpenEnv pattern
        action = DMControlAction(values=[action_value])
        result = env.step(action)

        # result.reward may be None. Normalize it once so that both the running
        # total and the ":.3f" progress format below work. Formatting a raw None
        # would raise TypeError.
        step_reward = result.reward or 0.0
        total_reward += step_reward
        step_count += 1

        # Print progress periodically
        if step_count % 50 == 0:
            pos = result.observation.observations.get("position", [])
            vel = result.observation.observations.get("velocity", [])
            print(
                f"Step {step_count}: reward={step_reward:.3f}, "
                f"total={total_reward:.2f}, done={result.done}"
            )
            print(f" position={pos}, velocity={vel}")

    print(f"\nEpisode finished: {step_count} steps, total reward: {total_reward:.2f}")
def run_interactive(env: DMControlEnv, task: str = "balance"):
    """Run interactive control with keyboard input via a small pygame window.

    Each frame, the held arrow key chooses the action: LEFT sends -1.0,
    RIGHT sends +1.0, and neither sends 0.0. One environment step is taken
    per frame, capped at 30 frames per second. Pressing R resets the
    episode, and episodes that report ``done`` are reset automatically.
    ESC, Q, or closing the window ends the session.

    pygame is always shut down on exit, even if a server call raises.

    Args:
        env: Connected dm_control OpenEnv client.
        task: Cartpole task name passed to ``env.reset``.
    """
    import pygame

    print("\n=== Interactive Mode (OpenEnv Step/Observation Pattern) ===")
    print("Use LEFT/RIGHT arrows to control cart, R to reset, ESC to quit.\n")

    # Reset environment using OpenEnv pattern
    result = env.reset(domain_name="cartpole", task_name=task)
    print(f"Initial observations: {list(result.observation.observations.keys())}")

    # Initialize pygame for keyboard input (minimal window)
    pygame.init()
    total_reward = 0.0
    try:
        screen = pygame.display.set_mode((400, 100))
        pygame.display.set_caption("Cartpole Control - Arrow keys to move, R to reset")
        clock = pygame.time.Clock()
        # Font for the status line (None selects pygame's default font)
        font = pygame.font.Font(None, 24)

        running = True
        step_count = 0

        print("\nControls:")
        print(" LEFT/RIGHT arrows: Move cart")
        print(" R: Reset environment")
        print(" ESC or Q: Quit\n")

        while running:
            # Handle discrete events (quit / reset)
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.KEYDOWN:
                    if event.key in (pygame.K_ESCAPE, pygame.K_q):
                        running = False
                    elif event.key == pygame.K_r:
                        result = env.reset(domain_name="cartpole", task_name=task)
                        total_reward = 0.0
                        step_count = 0
                        print("Environment reset")

            if not running:
                # Quit was requested: do not send one more step to the server.
                break

            # Check for held keys (for continuous control)
            keys = pygame.key.get_pressed()
            if keys[pygame.K_LEFT]:
                action_value = -1.0
            elif keys[pygame.K_RIGHT]:
                action_value = 1.0
            else:
                action_value = 0.0

            # Step the environment using OpenEnv pattern
            action = DMControlAction(values=[action_value])
            result = env.step(action)

            # Track reward from result (reward may be None)
            total_reward += result.reward or 0.0
            step_count += 1

            # Check if episode is done
            if result.done:
                print(
                    f"Episode finished! Steps: {step_count}, "
                    f"Total reward: {total_reward:.2f}"
                )
                # Auto-reset on done
                result = env.reset(domain_name="cartpole", task_name=task)
                total_reward = 0.0
                step_count = 0

            # Update display
            direction = (
                "<--" if action_value < 0 else ("-->" if action_value > 0 else "---")
            )
            screen.fill((30, 30, 30))
            text = font.render(
                f"Step: {step_count} | Reward: {total_reward:.1f} | {direction}",
                True,
                (255, 255, 255),
            )
            screen.blit(text, (10, 40))
            pygame.display.flip()

            # Print progress periodically
            if step_count % 200 == 0 and step_count > 0:
                print(f"Step {step_count}: Total reward: {total_reward:.2f}")

            # Cap at 30 FPS
            clock.tick(30)
    finally:
        # Always tear down the window, even if a server call raised.
        pygame.quit()

    print(f"Session ended. Final reward: {total_reward:.2f}")
def run_visual(env: DMControlEnv, task: str = "balance"):
    """Run keyboard control with pygame visualization of rendered frames.

    The environment is reset and stepped with ``render=True``. Each returned
    ``observation.pixels`` (a base64-encoded PNG) is decoded and drawn in a
    window sized to the first frame.

    Controls match interactive mode. Each frame, the held arrow key chooses
    the action: LEFT sends -1.0, RIGHT sends +1.0, and neither sends 0.0.
    Pressing R resets the episode, and ESC, Q, or closing the window quits.
    Episodes that report ``done`` are reset automatically, and the loop is
    capped at 30 FPS.

    If the initial reset returns no pixels, an explanation is printed and
    the process exits with status 1. pygame is always shut down on exit,
    even if a server call raises.

    Args:
        env: Connected dm_control OpenEnv client.
        task: Cartpole task name passed to ``env.reset``.
    """
    import base64
    import io

    import pygame

    def decode_frame(pixels):
        """Decode a base64-encoded PNG string into a pygame Surface."""
        png_data = base64.b64decode(pixels)
        # A file-like object has no filename, so pass a namehint to tell
        # pygame the data is PNG.
        return pygame.image.load(io.BytesIO(png_data), "frame.png")

    print("\n=== Visual Mode (OpenEnv Step/Observation Pattern) ===")

    # Reset environment with rendering enabled
    result = env.reset(domain_name="cartpole", task_name=task, render=True)
    print(f"Initial observations: {list(result.observation.observations.keys())}")

    # A first frame is required to size the window. Use the same truthiness
    # test as the per-frame check below, so an empty payload reports this
    # error instead of crashing inside the image decoder.
    if not result.observation.pixels:
        print("Error: Server did not return rendered pixels.")
        print("Make sure the server supports render=True")
        print("\nTry running in interactive mode (default) instead.")
        sys.exit(1)

    frame_size = decode_frame(result.observation.pixels).get_size()

    # Initialize pygame
    pygame.init()
    total_reward = 0.0
    try:
        screen = pygame.display.set_mode(frame_size)
        pygame.display.set_caption(
            "Cartpole (OpenEnv) - Arrow Keys to Move, R to Reset, ESC to Quit"
        )
        clock = pygame.time.Clock()

        print("Controls:")
        print(" LEFT/RIGHT arrows: Move cart")
        print(" R: Reset environment")
        print(" ESC or Q: Quit")

        running = True
        step_count = 0

        while running:
            # Handle discrete events (quit / reset)
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.KEYDOWN:
                    if event.key in (pygame.K_ESCAPE, pygame.K_q):
                        running = False
                    elif event.key == pygame.K_r:
                        result = env.reset(
                            domain_name="cartpole", task_name=task, render=True
                        )
                        total_reward = 0.0
                        step_count = 0
                        print("Environment reset")

            if not running:
                # Quit was requested: do not send one more step to the server.
                break

            # Check for held keys (for continuous control)
            keys = pygame.key.get_pressed()
            if keys[pygame.K_LEFT]:
                action_value = -1.0
            elif keys[pygame.K_RIGHT]:
                action_value = 1.0
            else:
                action_value = 0.0

            # Step the environment using OpenEnv pattern
            action = DMControlAction(values=[action_value])
            result = env.step(action, render=True)

            # Track reward from result (reward may be None)
            total_reward += result.reward or 0.0
            step_count += 1

            # Check if episode is done
            if result.done:
                print(
                    f"Episode finished! Steps: {step_count}, "
                    f"Total reward: {total_reward:.2f}"
                )
                result = env.reset(domain_name="cartpole", task_name=task, render=True)
                total_reward = 0.0
                step_count = 0

            # Render the frame from observation pixels. If a frame is missing,
            # the previous one stays on screen.
            if result.observation.pixels:
                screen.blit(decode_frame(result.observation.pixels), (0, 0))
                pygame.display.flip()

            # Print progress periodically
            if step_count % 200 == 0 and step_count > 0:
                print(f"Step {step_count}: Total reward: {total_reward:.2f}")

            # Cap at 30 FPS
            clock.tick(30)
    finally:
        # Always tear down the window, even if a server call raised.
        pygame.quit()

    print(f"Session ended. Final reward: {total_reward:.2f}")
def main():
    """Command-line entry point: parse options, connect, and run one mode.

    ``--headless`` takes precedence over ``--visual``. With neither flag,
    the keyboard-only interactive mode runs. ``--url`` selects the server
    and defaults to ``http://localhost:8000``, the previously hard-coded
    address.

    If a ``ConnectionError`` escapes the session, a hint on starting the
    server is printed and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description="Interactive cartpole control via OpenEnv"
    )
    parser.add_argument(
        "--visual",
        action="store_true",
        help="Enable pygame visualization with rendered frames",
    )
    parser.add_argument(
        "--headless",
        action="store_true",
        help="Run in headless mode (no pygame, automated control)",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=500,
        help="Maximum steps for headless mode (default: 500)",
    )
    parser.add_argument(
        "--task",
        type=str,
        default="balance",
        choices=["balance", "balance_sparse", "swingup", "swingup_sparse"],
        help="Cartpole task (default: balance)",
    )
    parser.add_argument(
        "--url",
        type=str,
        default="http://localhost:8000",
        help="Base URL of the OpenEnv server (default: http://localhost:8000)",
    )
    args = parser.parse_args()

    server_url = args.url
    print(f"Connecting to {server_url}...")

    try:
        with DMControlEnv(base_url=server_url) as env:
            print("Connected!")

            # Get environment state
            state = env.state()
            print(f"Domain: {state.domain_name}, Task: {state.task_name}")
            print(f"Action spec: {state.action_spec}")

            if args.headless:
                run_headless(env, task=args.task, max_steps=args.max_steps)
            elif args.visual:
                run_visual(env, task=args.task)
            else:
                run_interactive(env, task=args.task)
    # NOTE(review): this catches only the built-in ConnectionError. Confirm it
    # is what DMControlEnv raises when the server is unreachable. For example,
    # requests' ConnectionError is not a subclass of it. Also note that this
    # handler covers the whole session, not only the initial connect.
    except ConnectionError as e:
        print(f"Failed to connect: {e}")
        print("\nMake sure the server is running:")
        print(" cd OpenEnv")
        print(
            " PYTHONPATH=src:envs uvicorn envs.dm_control_env.server.app:app --port 8000"
        )
        sys.exit(1)
if __name__ == "__main__":
    # Launch the CLI only when executed as a script, never on import.
    main()