# Source: Hugging Face Hub upload by burtenshaw (HF Staff).
# Commit 42cc6d2 (verified): "Upload folder using huggingface_hub"
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Enhanced Android Environment Server Implementation with complete features.
This module wraps DeepMind's android_env with:
- Full gesture support (tap, swipe, scroll, etc.)
- ADB integration for text input and button presses
- Shared memory optimization for parallel training
- Gesture sequencing
"""
import base64
import io
import logging
import subprocess
import time
from multiprocessing import shared_memory
from typing import Any, Dict, List, Optional
from uuid import uuid4
import numpy as np
from android_env import loader
from android_env.components import config_classes
from android_env.proto import adb_pb2
from dm_env import specs
from PIL import Image
from core.env_server.interfaces import Environment
from core.env_server.types import State
from ..models import AndroidAction, AndroidObservation
from .gestures import ADBCommands, GestureBuilder
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class AndroidEnvironment(Environment):
    """
    Enhanced Android environment wrapper for OpenEnv.

    Wraps an ``android_env`` environment (created through ``loader.load``) and
    exposes it through ``reset`` / ``step`` / ``state`` / ``close``.

    Features:
    - Complete gesture support (swipe, scroll, long press, etc.)
    - ADB text input and button press
    - Gesture sequencing (multi-step gestures)
    - Optional shared memory for high-performance deployments
    - Action buffering for gesture composition
    """

    def __init__(
        self,
        task_path: str,
        avd_name: str,
        adb_path: str = "~/Android/Sdk/platform-tools/adb",
        emulator_path: str = "~/Android/Sdk/emulator/emulator",
        android_avd_home: str = "~/.android/avd",
        android_sdk_root: str = "~/Android/Sdk",
        run_headless: bool = True,
        image_format: str = "JPEG",
        image_quality: int = 85,
        use_shared_memory: bool = False,
        shared_memory_name: Optional[str] = None,
    ):
        """Initialize the Android environment.

        Loads the underlying android_env environment and performs one reset
        to read the screen dimensions from the first observation.

        Args:
            task_path: Path to the android_env task textproto file.
            avd_name: Name of the Android Virtual Device to use.
            adb_path: Path to the ADB executable.
            emulator_path: Path to the Android emulator executable.
            android_avd_home: Path to the AVD home directory.
            android_sdk_root: Path to the Android SDK root.
            run_headless: Whether to run the emulator in headless mode.
            image_format: Format for encoding screen images ("JPEG" or "PNG",
                case-insensitive).
            image_quality: Quality for JPEG encoding (1-100).
            use_shared_memory: Use shared memory for zero-copy observations.
            shared_memory_name: Name for shared memory segment. A random
                ``android_env_<hex>`` name is generated when omitted.
        """
        super().__init__()

        # Resource handles are defined before anything that can raise so that
        # close() / __del__ work on a partially constructed instance.
        self._android_env = None
        self._shared_mem = None
        self._shared_mem_name: Optional[str] = None
        self._latest_timestep = None

        self._task_path = task_path
        self._avd_name = avd_name
        self._adb_path = adb_path
        # Normalised so that "jpeg" / "png" are accepted as well.
        self._image_format = image_format.upper()
        self._image_quality = image_quality
        self._use_shared_memory = use_shared_memory

        # Gesture sequencing state
        self._gesture_queue: List[dict] = []
        self._executing_gesture = False

        # Create android_env configuration
        config = config_classes.AndroidEnvConfig(
            task=config_classes.FilesystemTaskConfig(path=task_path),
            simulator=config_classes.EmulatorConfig(
                emulator_launcher=config_classes.EmulatorLauncherConfig(
                    emulator_path=emulator_path,
                    android_sdk_root=android_sdk_root,
                    android_avd_home=android_avd_home,
                    avd_name=avd_name,
                    run_headless=run_headless,
                ),
                adb_controller=config_classes.AdbControllerConfig(adb_path=adb_path),
            ),
        )

        # Load the android_env environment
        logger.info("Loading Android environment with AVD: %s", avd_name)
        self._android_env = loader.load(config)

        # Get action and observation specs
        self._action_spec = self._android_env.action_spec()
        self._observation_spec = self._android_env.observation_spec()

        # Get screen dimensions from first observation
        initial_obs = self._android_env.reset().observation
        pixels = initial_obs.get("pixels")
        if pixels is not None:
            self._screen_height, self._screen_width, _ = pixels.shape
        else:
            self._screen_height, self._screen_width = 1920, 1080  # Default

        # Set up shared memory if requested
        if use_shared_memory:
            # One byte per channel, three channels (RGB).
            mem_size = self._screen_height * self._screen_width * 3
            self._shared_mem_name = shared_memory_name or f"android_env_{uuid4().hex[:8]}"
            try:
                self._shared_mem = shared_memory.SharedMemory(
                    name=self._shared_mem_name,
                    create=True,
                    size=mem_size,
                )
                logger.info("Created shared memory: %s", self._shared_mem_name)
            except Exception as e:
                # Deliberate best-effort: any failure here (e.g. the name is
                # already taken) degrades to base64-encoded observations.
                logger.warning(
                    "Could not create shared memory: %s. Falling back to encoding.", e
                )
                self._use_shared_memory = False

        # Initialize state
        self._state = State(episode_id=str(uuid4()), step_count=0)

        logger.info("Android environment initialized successfully")
        logger.info("Screen size: %sx%s", self._screen_width, self._screen_height)
        logger.info("Action spec: %s", list(self._action_spec.keys()))

    def reset(self) -> AndroidObservation:
        """Reset the Android environment for a new episode.

        Clears the gesture state, resets the wrapped android_env, and starts a
        new ``State`` with a fresh episode id and a step count of zero.

        Returns:
            The observation built from the timestep returned by the reset.
        """
        logger.info("Resetting Android environment...")

        # Clear gesture queue
        self._gesture_queue = []
        self._executing_gesture = False

        # Reset android_env
        self._latest_timestep = self._android_env.reset()

        # Update state
        self._state = State(episode_id=str(uuid4()), step_count=0)

        # Convert timestep to observation
        observation = self._convert_timestep_to_observation(self._latest_timestep)
        logger.info("Reset complete. Episode ID: %s", self._state.episode_id)
        return observation

    def step(self, action: AndroidAction) -> AndroidObservation:  # type: ignore[override]
        """Execute an action in the Android environment.

        The high-level action is expanded into one or more primitive gestures,
        each of which is stepped through the wrapped android_env. The whole
        sequence counts as a single step in ``state.step_count``.

        Rewards of all executed primitives are summed into the returned
        observation's ``reward``. Execution stops as soon as a primitive ends
        the episode, so the terminal timestep is the one reported.

        Args:
            action: The action to execute; ``tool_name`` selects the gesture
                and ``parameters`` supplies its arguments.

        Returns:
            The observation built from the last timestep that was produced.

        Raises:
            ValueError: If ``action.tool_name`` is unknown.
            KeyError: If a required gesture parameter is missing.
            RuntimeError: If no timestep is available (the action expanded to
                no primitives and the environment was never reset).
        """
        # Convert OpenEnv action to gesture sequence or direct action
        gesture_actions = self._convert_action_to_gestures(action)

        total_reward = 0.0
        executed = 0
        for gesture_action in gesture_actions:
            android_action = self._create_android_action(gesture_action)
            timestep = self._android_env.step(android_action)
            self._latest_timestep = timestep
            executed += 1
            if timestep.reward is not None:
                total_reward += float(timestep.reward)
            # Stepping past a terminal timestep would start a new episode in
            # the wrapped environment and hide the episode end from the caller.
            if timestep.last():
                break

        if self._latest_timestep is None:
            raise RuntimeError("No timestep available; call reset() before step()")

        # The whole gesture sequence counts as one step.
        if executed:
            self._state.step_count += 1

        # Convert final timestep to observation
        observation = self._convert_timestep_to_observation(self._latest_timestep)
        if executed:
            observation.reward = total_reward

        # Check if episode is done
        if self._latest_timestep.last():
            observation.done = True
            logger.info("Episode ended after %s steps", self._state.step_count)
        return observation

    @property
    def state(self) -> State:
        """Get the current environment state."""
        return self._state

    def close(self) -> None:
        """Clean up the Android environment.

        Closes the wrapped android_env and releases (closes and unlinks) the
        shared memory segment, if one was created. Safe to call more than once
        and on an instance whose ``__init__`` did not complete.
        """
        logger.info("Closing Android environment...")

        # Handles are detached from the instance before being released so a
        # second call (for example from __del__) finds nothing left to do.
        android_env = getattr(self, "_android_env", None)
        self._android_env = None
        shared_mem = getattr(self, "_shared_mem", None)
        self._shared_mem = None

        try:
            if android_env is not None:
                android_env.close()
        finally:
            # Release shared memory even if closing the environment failed.
            if shared_mem is not None:
                self._use_shared_memory = False
                try:
                    shared_mem.close()
                except Exception as e:
                    logger.warning("Could not close shared memory: %s", e)
                try:
                    shared_mem.unlink()
                except Exception as e:
                    logger.warning("Could not unlink shared memory: %s", e)

        logger.info("Android environment closed")

    def _convert_action_to_gestures(self, action: AndroidAction) -> List[dict]:
        """Convert high-level action to sequence of primitive gestures.

        Touch gestures are delegated to ``GestureBuilder``. ``type_text`` and
        ``press_button`` run their ADB command immediately and return a single
        no-op primitive so the environment is still stepped once.

        Args:
            action: Action whose ``tool_name`` is one of ``tap``, ``swipe``,
                ``long_press``, ``double_tap``, ``scroll_down``, ``scroll_up``,
                ``swipe_left``, ``swipe_right``, ``type_text`` or
                ``press_button``.

        Returns:
            A list of gesture primitives (dicts) to execute in order.

        Raises:
            ValueError: If ``tool_name`` is not one of the names above.
            KeyError: If a required entry is missing from ``parameters``.
        """
        tool_name = action.tool_name
        params = action.parameters

        # Use GestureBuilder for complex gestures
        if tool_name == "tap":
            return GestureBuilder.tap(params["x"], params["y"])
        if tool_name == "swipe":
            return GestureBuilder.swipe(
                params["x1"],
                params["y1"],
                params["x2"],
                params["y2"],
                params.get("duration_ms", 300),
            )
        if tool_name == "long_press":
            return GestureBuilder.long_press(
                params["x"],
                params["y"],
                params.get("duration_ms", 1000),
            )
        if tool_name == "double_tap":
            return GestureBuilder.double_tap(params["x"], params["y"])
        if tool_name == "scroll_down":
            return GestureBuilder.scroll_down(
                params.get("x", 0.5),
                params.get("distance", 0.5),
            )
        if tool_name == "scroll_up":
            return GestureBuilder.scroll_up(
                params.get("x", 0.5),
                params.get("distance", 0.5),
            )
        if tool_name == "swipe_left":
            return GestureBuilder.swipe_left(
                params.get("y", 0.5),
                params.get("distance", 0.5),
            )
        if tool_name == "swipe_right":
            return GestureBuilder.swipe_right(
                params.get("y", 0.5),
                params.get("distance", 0.5),
            )
        if tool_name == "type_text":
            # Execute ADB text input command
            self._execute_adb_text(params["text"])
            return self._noop_gesture()
        if tool_name == "press_button":
            # Execute ADB keyevent command
            self._execute_adb_button(params["button"])
            return self._noop_gesture()
        raise ValueError(f"Unknown action tool_name: {tool_name}")

    @staticmethod
    def _noop_gesture() -> List[dict]:
        """Return a fresh one-element sequence holding a no-op touch action."""
        # NOTE(review): action_type 2 is presumably android_env's REPEAT
        # action type - confirm against android_env's ActionType enum.
        return [{"action_type": 2, "x": 0.5, "y": 0.5, "duration_ms": 100}]

    def _create_android_action(self, gesture_action: dict) -> Dict[str, np.ndarray]:
        """Create android_env action from gesture primitive.

        Args:
            gesture_action: Primitive with ``action_type``, ``x`` and ``y``
                entries. ``x`` and ``y`` are clipped to the range [0.0, 1.0].

        Returns:
            A dict with one array per key of the environment's action spec.
            Keys other than ``action_type`` and ``touch_position`` are filled
            with zeros of the spec's dtype (and shape, for non-discrete specs).
        """
        action_type = gesture_action["action_type"]
        x = gesture_action["x"]
        y = gesture_action["y"]

        action: Dict[str, np.ndarray] = {}
        for key, spec in self._action_spec.items():
            if key == "action_type":
                action[key] = np.array(action_type, dtype=spec.dtype)
            elif key == "touch_position":
                action[key] = np.array(
                    [np.clip(x, 0.0, 1.0), np.clip(y, 0.0, 1.0)],
                    dtype=spec.dtype,
                )
            elif isinstance(spec, specs.DiscreteArray):
                # Fill other fields with defaults
                action[key] = np.array(0, dtype=spec.dtype)
            else:
                action[key] = np.zeros(spec.shape, dtype=spec.dtype)
        return action

    def _send_adb_command(self, cmd: Any) -> None:
        """Send a generic ADB command through the wrapped environment."""
        adb_request = adb_pb2.AdbRequest()
        adb_request.generic.command = cmd
        self._android_env.execute_adb_call(adb_request)

    def _execute_adb_text(self, text: str) -> None:
        """Execute ADB text input command.

        Failures are logged and not raised, so a failed text input does not
        abort the surrounding step.
        """
        try:
            self._send_adb_command(ADBCommands.text_input(text))
            # Only the first 20 characters are logged.
            logger.info("Executed ADB text input: %s...", text[:20])
        except Exception as e:
            logger.error("ADB text input failed: %s", e)

    def _execute_adb_button(self, button: str) -> None:
        """Execute ADB button press command.

        ``button`` is matched case-insensitively against a small set of common
        names; anything else is passed to ``ADBCommands.keyevent`` unchanged.
        Failures are logged and not raised.
        """
        try:
            # Map common button names to keycodes
            button_map = {
                "HOME": ADBCommands.KEYCODE_HOME,
                "BACK": ADBCommands.KEYCODE_BACK,
                "MENU": ADBCommands.KEYCODE_MENU,
                "ENTER": ADBCommands.KEYCODE_ENTER,
                "SEARCH": ADBCommands.KEYCODE_SEARCH,
                "DELETE": ADBCommands.KEYCODE_DEL,
                "TAB": ADBCommands.KEYCODE_TAB,
                "SPACE": ADBCommands.KEYCODE_SPACE,
            }
            keycode = button_map.get(button.upper(), button)
            self._send_adb_command(ADBCommands.keyevent(keycode))
            logger.info("Executed ADB button press: %s", button)
        except Exception as e:
            logger.error("ADB button press failed: %s", e)

    def _convert_timestep_to_observation(self, timestep: Any) -> AndroidObservation:
        """Convert android_env TimeStep to AndroidObservation.

        Args:
            timestep: Timestep whose ``observation`` mapping contains a
                ``pixels`` array with three dimensions.

        Returns:
            An observation carrying the screen (base64 image data, or an
            ``shm://<name>`` reference when shared memory is active), the
            remaining observation entries as ``extras``, and the timestep's
            reward (0.0 when the timestep has none) and done flag.

        Raises:
            ValueError: If the observation has no ``pixels`` entry.
        """
        obs_dict = timestep.observation
        pixels = obs_dict.get("pixels")
        if pixels is None:
            raise ValueError("No pixels found in android_env observation")
        height, width, channels = pixels.shape

        # Handle observation encoding
        if self._use_shared_memory and self._shared_mem:
            # Write pixels to shared memory
            screen_image_b64 = self._write_to_shared_memory(pixels)
        else:
            # Encode to base64
            screen_image_b64 = self._encode_image(pixels)

        # Extract extras
        extras = {k: v for k, v in obs_dict.items() if k != "pixels"}
        if hasattr(self._android_env, "task_extras"):
            extras["task_extras"] = self._android_env.task_extras(latest_only=True)

        reward = float(timestep.reward) if timestep.reward is not None else 0.0
        return AndroidObservation(
            screen_image=screen_image_b64,
            screen_width=width,
            screen_height=height,
            timestamp_ms=int(time.time() * 1000),
            orientation=0,
            pixels_shape=(height, width, channels),
            extras=extras,
            done=timestep.last(),
            reward=reward,
        )

    def _encode_image(self, pixels: np.ndarray) -> str:
        """Encode numpy pixel array to base64 string.

        Args:
            pixels: Image array; it is converted to ``uint8`` before encoding.

        Returns:
            The JPEG or PNG bytes, base64-encoded as a UTF-8 string.

        Raises:
            ValueError: If the configured image format is neither "JPEG" nor
                "PNG".
        """
        # copy=False avoids a copy when the array is already uint8.
        image = Image.fromarray(pixels.astype(np.uint8, copy=False))
        buffer = io.BytesIO()
        if self._image_format == "JPEG":
            image.save(buffer, format="JPEG", quality=self._image_quality)
        elif self._image_format == "PNG":
            image.save(buffer, format="PNG")
        else:
            raise ValueError(f"Unsupported image format: {self._image_format}")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def _write_to_shared_memory(self, pixels: np.ndarray) -> str:
        """Write pixels to shared memory and return memory name.

        Falls back to ``_encode_image`` when no segment exists or when the
        write fails (for instance because the frame does not fit).

        Returns:
            ``shm://<segment name>`` on success, otherwise base64 image data.
        """
        if not self._shared_mem:
            return self._encode_image(pixels)  # Fallback
        try:
            # View the segment as an array shaped like the frame, then copy
            # the frame into it.
            view = np.ndarray(
                pixels.shape,
                dtype=pixels.dtype,
                buffer=self._shared_mem.buf,
            )
            view[...] = pixels
            # Return shared memory name instead of image data
            return f"shm://{self._shared_mem_name}"
        except Exception as e:
            logger.error("Shared memory write failed: %s, falling back to encoding", e)
            return self._encode_image(pixels)

    def __del__(self):
        """Cleanup on deletion."""
        try:
            self.close()
        except Exception:
            # A finalizer must not raise; this can run during interpreter
            # shutdown when cleanup can no longer succeed.
            pass