Spaces:
Runtime error
Runtime error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Enhanced Android Environment Server Implementation with complete features. | |
| This module wraps DeepMind's android_env with: | |
| - Full gesture support (tap, swipe, scroll, etc.) | |
| - ADB integration for text input and button presses | |
| - Shared memory optimization for parallel training | |
| - Gesture sequencing | |
| """ | |
| import base64 | |
| import io | |
| import logging | |
| import subprocess | |
| import time | |
| from multiprocessing import shared_memory | |
| from typing import Any, Dict, List, Optional | |
| from uuid import uuid4 | |
| import numpy as np | |
| from android_env import loader | |
| from android_env.components import config_classes | |
| from android_env.proto import adb_pb2 | |
| from dm_env import specs | |
| from PIL import Image | |
| from core.env_server.interfaces import Environment | |
| from core.env_server.types import State | |
| from ..models import AndroidAction, AndroidObservation | |
| from .gestures import ADBCommands, GestureBuilder | |
| logger = logging.getLogger(__name__) | |
| class AndroidEnvironment(Environment): | |
| """ | |
| Enhanced Android environment wrapper for OpenEnv. | |
| Features: | |
| - Complete gesture support (swipe, scroll, long press, etc.) | |
| - ADB text input and button press | |
| - Gesture sequencing (multi-step gestures) | |
| - Optional shared memory for high-performance deployments | |
| - Action buffering for gesture composition | |
| """ | |
| def __init__( | |
| self, | |
| task_path: str, | |
| avd_name: str, | |
| adb_path: str = "~/Android/Sdk/platform-tools/adb", | |
| emulator_path: str = "~/Android/Sdk/emulator/emulator", | |
| android_avd_home: str = "~/.android/avd", | |
| android_sdk_root: str = "~/Android/Sdk", | |
| run_headless: bool = True, | |
| image_format: str = "JPEG", | |
| image_quality: int = 85, | |
| use_shared_memory: bool = False, | |
| shared_memory_name: Optional[str] = None, | |
| ): | |
| """Initialize the Android environment. | |
| Args: | |
| task_path: Path to the android_env task textproto file. | |
| avd_name: Name of the Android Virtual Device to use. | |
| adb_path: Path to the ADB executable. | |
| emulator_path: Path to the Android emulator executable. | |
| android_avd_home: Path to the AVD home directory. | |
| android_sdk_root: Path to the Android SDK root. | |
| run_headless: Whether to run the emulator in headless mode. | |
| image_format: Format for encoding screen images ("JPEG" or "PNG"). | |
| image_quality: Quality for JPEG encoding (1-100). | |
| use_shared_memory: Use shared memory for zero-copy observations. | |
| shared_memory_name: Name for shared memory segment. | |
| """ | |
| super().__init__() | |
| self._task_path = task_path | |
| self._avd_name = avd_name | |
| self._adb_path = adb_path | |
| self._image_format = image_format | |
| self._image_quality = image_quality | |
| self._use_shared_memory = use_shared_memory | |
| # Gesture sequencing state | |
| self._gesture_queue: List[dict] = [] | |
| self._executing_gesture = False | |
| # Create android_env configuration | |
| config = config_classes.AndroidEnvConfig( | |
| task=config_classes.FilesystemTaskConfig(path=task_path), | |
| simulator=config_classes.EmulatorConfig( | |
| emulator_launcher=config_classes.EmulatorLauncherConfig( | |
| emulator_path=emulator_path, | |
| android_sdk_root=android_sdk_root, | |
| android_avd_home=android_avd_home, | |
| avd_name=avd_name, | |
| run_headless=run_headless, | |
| ), | |
| adb_controller=config_classes.AdbControllerConfig(adb_path=adb_path), | |
| ), | |
| ) | |
| # Load the android_env environment | |
| logger.info(f"Loading Android environment with AVD: {avd_name}") | |
| self._android_env = loader.load(config) | |
| # Get action and observation specs | |
| self._action_spec = self._android_env.action_spec() | |
| self._observation_spec = self._android_env.observation_spec() | |
| # Get screen dimensions from first observation | |
| initial_obs = self._android_env.reset().observation | |
| pixels = initial_obs.get("pixels") | |
| if pixels is not None: | |
| self._screen_height, self._screen_width, _ = pixels.shape | |
| else: | |
| self._screen_height, self._screen_width = 1920, 1080 # Default | |
| # Set up shared memory if requested | |
| self._shared_mem = None | |
| if use_shared_memory: | |
| mem_size = self._screen_height * self._screen_width * 3 # RGB | |
| self._shared_mem_name = shared_memory_name or f"android_env_{uuid4().hex[:8]}" | |
| try: | |
| self._shared_mem = shared_memory.SharedMemory( | |
| name=self._shared_mem_name, | |
| create=True, | |
| size=mem_size | |
| ) | |
| logger.info(f"Created shared memory: {self._shared_mem_name}") | |
| except Exception as e: | |
| logger.warning(f"Could not create shared memory: {e}. Falling back to encoding.") | |
| self._use_shared_memory = False | |
| # Initialize state | |
| self._state = State(episode_id=str(uuid4()), step_count=0) | |
| self._latest_timestep = None | |
| logger.info(f"Android environment initialized successfully") | |
| logger.info(f"Screen size: {self._screen_width}x{self._screen_height}") | |
| logger.info(f"Action spec: {list(self._action_spec.keys())}") | |
| def reset(self) -> AndroidObservation: | |
| """Reset the Android environment for a new episode.""" | |
| logger.info("Resetting Android environment...") | |
| # Clear gesture queue | |
| self._gesture_queue = [] | |
| self._executing_gesture = False | |
| # Reset android_env | |
| self._latest_timestep = self._android_env.reset() | |
| # Update state | |
| self._state = State(episode_id=str(uuid4()), step_count=0) | |
| # Convert timestep to observation | |
| observation = self._convert_timestep_to_observation(self._latest_timestep) | |
| logger.info(f"Reset complete. Episode ID: {self._state.episode_id}") | |
| return observation | |
| def step(self, action: AndroidAction) -> AndroidObservation: # type: ignore[override] | |
| """Execute an action in the Android environment.""" | |
| # Convert OpenEnv action to gesture sequence or direct action | |
| gesture_actions = self._convert_action_to_gestures(action) | |
| # Execute all actions in the gesture sequence | |
| for i, gesture_action in enumerate(gesture_actions): | |
| android_action = self._create_android_action(gesture_action) | |
| self._latest_timestep = self._android_env.step(android_action) | |
| # Update state on last action of sequence | |
| if i == len(gesture_actions) - 1: | |
| self._state.step_count += 1 | |
| # Convert final timestep to observation | |
| observation = self._convert_timestep_to_observation(self._latest_timestep) | |
| # Check if episode is done | |
| if self._latest_timestep.last(): | |
| observation.done = True | |
| logger.info(f"Episode ended after {self._state.step_count} steps") | |
| return observation | |
| def state(self) -> State: | |
| """Get the current environment state.""" | |
| return self._state | |
| def close(self) -> None: | |
| """Clean up the Android environment.""" | |
| logger.info("Closing Android environment...") | |
| if hasattr(self, "_android_env"): | |
| self._android_env.close() | |
| if self._shared_mem: | |
| try: | |
| self._shared_mem.close() | |
| self._shared_mem.unlink() | |
| except: | |
| pass | |
| logger.info("Android environment closed") | |
| def _convert_action_to_gestures(self, action: AndroidAction) -> List[dict]: | |
| """Convert high-level action to sequence of primitive gestures.""" | |
| tool_name = action.tool_name | |
| params = action.parameters | |
| # Use GestureBuilder for complex gestures | |
| if tool_name == "tap": | |
| return GestureBuilder.tap(params["x"], params["y"]) | |
| elif tool_name == "swipe": | |
| return GestureBuilder.swipe( | |
| params["x1"], params["y1"], | |
| params["x2"], params["y2"], | |
| params.get("duration_ms", 300) | |
| ) | |
| elif tool_name == "long_press": | |
| return GestureBuilder.long_press( | |
| params["x"], params["y"], | |
| params.get("duration_ms", 1000) | |
| ) | |
| elif tool_name == "double_tap": | |
| return GestureBuilder.double_tap(params["x"], params["y"]) | |
| elif tool_name == "scroll_down": | |
| return GestureBuilder.scroll_down( | |
| params.get("x", 0.5), | |
| params.get("distance", 0.5) | |
| ) | |
| elif tool_name == "scroll_up": | |
| return GestureBuilder.scroll_up( | |
| params.get("x", 0.5), | |
| params.get("distance", 0.5) | |
| ) | |
| elif tool_name == "swipe_left": | |
| return GestureBuilder.swipe_left( | |
| params.get("y", 0.5), | |
| params.get("distance", 0.5) | |
| ) | |
| elif tool_name == "swipe_right": | |
| return GestureBuilder.swipe_right( | |
| params.get("y", 0.5), | |
| params.get("distance", 0.5) | |
| ) | |
| elif tool_name == "type_text": | |
| # Execute ADB text input command | |
| self._execute_adb_text(params["text"]) | |
| # Return a no-op touch action | |
| return [{"action_type": 2, "x": 0.5, "y": 0.5, "duration_ms": 100}] | |
| elif tool_name == "press_button": | |
| # Execute ADB keyevent command | |
| self._execute_adb_button(params["button"]) | |
| # Return a no-op touch action | |
| return [{"action_type": 2, "x": 0.5, "y": 0.5, "duration_ms": 100}] | |
| else: | |
| raise ValueError(f"Unknown action tool_name: {tool_name}") | |
| def _create_android_action(self, gesture_action: dict) -> Dict[str, np.ndarray]: | |
| """Create android_env action from gesture primitive.""" | |
| action = {} | |
| action_type = gesture_action["action_type"] | |
| x = gesture_action["x"] | |
| y = gesture_action["y"] | |
| for key, spec in self._action_spec.items(): | |
| if key == "action_type": | |
| action[key] = np.array(action_type, dtype=spec.dtype) | |
| elif key == "touch_position": | |
| action[key] = np.array([np.clip(x, 0.0, 1.0), np.clip(y, 0.0, 1.0)], dtype=spec.dtype) | |
| else: | |
| # Fill other fields with defaults | |
| if isinstance(spec, specs.DiscreteArray): | |
| action[key] = np.array(0, dtype=spec.dtype) | |
| else: | |
| action[key] = np.zeros(spec.shape, dtype=spec.dtype) | |
| return action | |
| def _execute_adb_text(self, text: str) -> None: | |
| """Execute ADB text input command.""" | |
| try: | |
| cmd = ADBCommands.text_input(text) | |
| adb_request = adb_pb2.AdbRequest() | |
| adb_request.generic.command = cmd | |
| self._android_env.execute_adb_call(adb_request) | |
| logger.info(f"Executed ADB text input: {text[:20]}...") | |
| except Exception as e: | |
| logger.error(f"ADB text input failed: {e}") | |
| def _execute_adb_button(self, button: str) -> None: | |
| """Execute ADB button press command.""" | |
| try: | |
| # Map common button names to keycodes | |
| button_map = { | |
| "HOME": ADBCommands.KEYCODE_HOME, | |
| "BACK": ADBCommands.KEYCODE_BACK, | |
| "MENU": ADBCommands.KEYCODE_MENU, | |
| "ENTER": ADBCommands.KEYCODE_ENTER, | |
| "SEARCH": ADBCommands.KEYCODE_SEARCH, | |
| "DELETE": ADBCommands.KEYCODE_DEL, | |
| "TAB": ADBCommands.KEYCODE_TAB, | |
| "SPACE": ADBCommands.KEYCODE_SPACE, | |
| } | |
| keycode = button_map.get(button.upper(), button) | |
| cmd = ADBCommands.keyevent(keycode) | |
| adb_request = adb_pb2.AdbRequest() | |
| adb_request.generic.command = cmd | |
| self._android_env.execute_adb_call(adb_request) | |
| logger.info(f"Executed ADB button press: {button}") | |
| except Exception as e: | |
| logger.error(f"ADB button press failed: {e}") | |
| def _convert_timestep_to_observation(self, timestep: Any) -> AndroidObservation: | |
| """Convert android_env TimeStep to AndroidObservation.""" | |
| obs_dict = timestep.observation | |
| pixels = obs_dict.get("pixels") | |
| if pixels is None: | |
| raise ValueError("No pixels found in android_env observation") | |
| height, width, channels = pixels.shape | |
| # Handle observation encoding | |
| if self._use_shared_memory and self._shared_mem: | |
| # Write pixels to shared memory | |
| screen_image_b64 = self._write_to_shared_memory(pixels) | |
| else: | |
| # Encode to base64 | |
| screen_image_b64 = self._encode_image(pixels) | |
| # Extract extras | |
| extras = {k: v for k, v in obs_dict.items() if k != "pixels"} | |
| if hasattr(self._android_env, "task_extras"): | |
| task_extras = self._android_env.task_extras(latest_only=True) | |
| extras.update({"task_extras": task_extras}) | |
| observation = AndroidObservation( | |
| screen_image=screen_image_b64, | |
| screen_width=width, | |
| screen_height=height, | |
| timestamp_ms=int(time.time() * 1000), | |
| orientation=0, | |
| pixels_shape=(height, width, channels), | |
| extras=extras, | |
| done=timestep.last(), | |
| reward=float(timestep.reward) if timestep.reward is not None else 0.0, | |
| ) | |
| return observation | |
| def _encode_image(self, pixels: np.ndarray) -> str: | |
| """Encode numpy pixel array to base64 string.""" | |
| image = Image.fromarray(pixels.astype(np.uint8)) | |
| buffer = io.BytesIO() | |
| if self._image_format == "JPEG": | |
| image.save(buffer, format="JPEG", quality=self._image_quality) | |
| elif self._image_format == "PNG": | |
| image.save(buffer, format="PNG") | |
| else: | |
| raise ValueError(f"Unsupported image format: {self._image_format}") | |
| buffer.seek(0) | |
| image_bytes = buffer.read() | |
| return base64.b64encode(image_bytes).decode("utf-8") | |
| def _write_to_shared_memory(self, pixels: np.ndarray) -> str: | |
| """Write pixels to shared memory and return memory name.""" | |
| if not self._shared_mem: | |
| return self._encode_image(pixels) # Fallback | |
| try: | |
| # Write pixels directly to shared memory | |
| np_array = np.ndarray( | |
| pixels.shape, | |
| dtype=pixels.dtype, | |
| buffer=self._shared_mem.buf | |
| ) | |
| np_array[:] = pixels[:] | |
| # Return shared memory name instead of image data | |
| return f"shm://{self._shared_mem_name}" | |
| except Exception as e: | |
| logger.error(f"Shared memory write failed: {e}, falling back to encoding") | |
| return self._encode_image(pixels) | |
| def __del__(self): | |
| """Cleanup on deletion.""" | |
| self.close() | |