import base64
import json
import os
import sys
from io import BytesIO
from typing import Any, Dict, List, Optional

import numpy as np
from PIL import Image

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "openpi", "src"))

from openpi.policies import policy_config
from openpi.training import config as train_config


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize the handler for pi0 model inference using openpi infrastructure.

        Args:
            path: Path to the model weights directory
        """
        # Resolve the model path from the environment variable, the constructor
        # argument, or the default location, in that order.
        model_path = os.environ.get("MODEL_PATH", path)
        if not model_path:
            model_path = "weights/pi0"

        # Load config.json to determine the model type.
        config_path = os.path.join(model_path, "config.json")
        with open(config_path, "r") as f:
            model_config = json.load(f)

        model_type = model_config.get("type", "pi0")

        # Create the training config through the openpi config system.
        if model_type == "pi0":
            self.train_config = train_config.get_config("pi0")
        else:
            # Default to pi0 if the type is not recognized.
            self.train_config = train_config.get_config("pi0")

        # Create the trained policy using openpi infrastructure, which
        # handles model loading, preprocessing, and postprocessing.
        self.policy = policy_config.create_trained_policy(
            self.train_config,
            model_path,
            pytorch_device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu",
        )

        # Default number of inference steps
        self.default_num_steps = 50

    def _decode_base64_image(self, base64_str: str) -> np.ndarray:
        """
        Decode a base64 image string to a numpy array.

        Args:
            base64_str: Base64-encoded image string

        Returns:
            numpy array of shape (H, W, 3) with values in [0, 255]
        """
        # Remove the data URL prefix if present.
        if base64_str.startswith("data:image"):
            base64_str = base64_str.split(",", 1)[1]

        image_bytes = base64.b64decode(base64_str)

        # Convert to a PIL Image, then to a numpy array.
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
        return np.array(image)

    def _prepare_observation(
        self,
        images: Dict[str, str],
        state: List[float],
        prompt: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Prepare an observation dictionary in the format expected by openpi.

        Args:
            images: Dictionary mapping camera names to base64-encoded images
            state: List of robot state values
            prompt: Optional text prompt

        Returns:
            Observation dictionary in openpi format
        """
        processed_images = {}

        # Map input camera names to the expected openpi format.
        # Based on the config, pi0 expects specific camera names.
        camera_mapping = {
            "camera0": "cam_high",          # base camera
            "camera1": "cam_left_wrist",    # left wrist camera
            "camera2": "cam_right_wrist",   # right wrist camera
            # Alternative mappings
            "base_camera": "cam_high",
            "left_wrist": "cam_left_wrist",
            "right_wrist": "cam_right_wrist",
            # Direct mappings
            "cam_high": "cam_high",
            "cam_left_wrist": "cam_left_wrist",
            "cam_right_wrist": "cam_right_wrist",
        }

        for input_name, image_b64 in images.items():
            # Map to the openpi-expected name; unknown names pass through unchanged.
            openpi_name = camera_mapping.get(input_name, input_name)

            image_array = self._decode_base64_image(image_b64)

            # Resize to the expected resolution if needed.
            if image_array.shape[:2] != (224, 224):
                image_pil = Image.fromarray(image_array)
                image_array = np.array(image_pil.resize((224, 224)))

            # Store as (H, W, C) uint8, the format expected by openpi.
            processed_images[openpi_name] = image_array.astype(np.uint8)

        # Ensure all required cameras are present; fill any missing ones
        # with black dummy images.
        required_cameras = ["cam_high", "cam_left_wrist", "cam_right_wrist"]
        for cam_name in required_cameras:
            if cam_name not in processed_images:
                processed_images[cam_name] = np.zeros((224, 224, 3), dtype=np.uint8)

        # Create the observation dict in openpi format.
        observation = {
            "state": np.array(state, dtype=np.float32),
            "images": processed_images,
        }

        # Add the prompt if provided.
        if prompt:
            observation["prompt"] = prompt

        return observation

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Main inference function called by the Hugging Face endpoint.

        Args:
            data: Input data dictionary containing:
                - inputs: Dictionary with:
                    - images: Dict mapping camera names to base64-encoded images
                    - state: List of robot state values
                    - prompt: Optional text prompt
                    - num_actions: Optional number of actions to predict (default: 50)
                    - noise: Optional noise array for sampling

        Returns:
            List containing prediction results
        """
        try:
            inputs = data.get("inputs", {})

            # Extract inputs.
            images = inputs.get("images", {})
            state = inputs.get("state", [])
            prompt = inputs.get("prompt", "")
            # NOTE: num_actions is accepted for API compatibility but is not
            # currently forwarded to the policy.
            num_actions = inputs.get("num_actions", self.default_num_steps)
            noise_input = inputs.get("noise", None)

            # Validate inputs.
            if not images:
                raise ValueError("No images provided")
            if not state:
                raise ValueError("No state provided")

            # Prepare the observation in openpi format.
            observation = self._prepare_observation(images, state, prompt)

            # Prepare noise if provided.
            noise = None
            if noise_input is not None:
                noise = np.array(noise_input, dtype=np.float32)

            # Run inference through the openpi policy, which handles
            # preprocessing, model inference, and postprocessing.
            result = self.policy.infer(observation, noise=noise)

            actions = result["actions"]

            # Convert to a plain list for JSON serialization.
            if isinstance(actions, np.ndarray):
                actions_list = actions.tolist()
            else:
                actions_list = actions

            # Return in the expected format.
            return [{
                "actions": actions_list,
                "num_actions": len(actions_list),
                "action_horizon": len(actions_list),
                "action_dim": len(actions_list[0]) if actions_list else 0,
                "success": True,
                "metadata": {
                    "model_type": self.train_config.model.model_type.value,
                    "policy_metadata": getattr(self.policy, "_metadata", {}),
                },
            }]
        except Exception as e:
            return [{
                "error": str(e),
                "success": False,
            }]
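

# A minimal local smoke test (a sketch, not part of the endpoint contract):
# it builds a dummy request with one black frame and a zero state vector and
# prints the handler's response. The 14-dim state width and the "cam_high"
# camera name are assumptions borrowed from the ALOHA-style defaults above;
# adjust both to match the deployed checkpoint.
if __name__ == "__main__":
    def _dummy_image_b64() -> str:
        # Encode a black 224x224 RGB frame as a base64 PNG string.
        buffer = BytesIO()
        Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8)).save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    handler = EndpointHandler()
    response = handler({
        "inputs": {
            "images": {"cam_high": _dummy_image_b64()},
            "state": [0.0] * 14,  # assumed state width; checkpoint-dependent
            "prompt": "pick up the cube",
        }
    })
    # Truncate the printed output: the actions array can be large.
    print(json.dumps(response, default=str)[:500])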