# op-test2/handler.py
import base64
import json
import os
import sys
from io import BytesIO
from typing import Any, Dict, List, Optional
import numpy as np
from PIL import Image
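# Make the vendored openpi checkout (expected at ./openpi/src) importable
# before importing its modules below.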
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "openpi", "src"))
from openpi.policies import policy_config
from openpi.training import config as train_config
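# HuggingFace Inference Endpoints looks for a handler.py defining an
# EndpointHandler class; __init__ runs once at startup and __call__ is
# invoked for each request.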
class EndpointHandler:
def __init__(self, path: str = ""):
"""
Initialize the handler for pi0 model inference using openpi infrastructure.
Args:
path: Path to the model weights directory
"""
# Set model path from environment variable or use provided path
model_path = os.environ.get("MODEL_PATH", path)
if not model_path:
model_path = "weights/pi0"
        # Load config.json (if present) to determine the model type.
        model_type = "pi0"
        config_path = os.path.join(model_path, "config.json")
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                model_config = json.load(f)
            model_type = model_config.get("type", "pi0")
        # Create the training config via the openpi config system.
        # Only "pi0" is supported here; unrecognized types fall back to it.
        if model_type != "pi0":
            print(f"Warning: unrecognized model type '{model_type}'; defaulting to pi0", file=sys.stderr)
        self.train_config = train_config.get_config("pi0")
        # Create the trained policy using the openpi infrastructure.
        # This handles all the model loading, preprocessing, etc.
        # Prefer an actual CUDA availability check over the CUDA_VISIBLE_DEVICES
        # heuristic; fall back to the env var when torch is not importable.
        try:
            import torch
            device = "cuda" if torch.cuda.is_available() else "cpu"
        except ImportError:
            device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
        self.policy = policy_config.create_trained_policy(
            self.train_config,
            model_path,
            pytorch_device=device,
        )
        # Default value for the num_actions request field (see __call__)
        self.default_num_steps = 50
def _decode_base64_image(self, base64_str: str) -> np.ndarray:
"""
Decode base64 image string to numpy array.
Args:
base64_str: Base64 encoded image string
Returns:
numpy array of shape (H, W, 3) with values in [0, 255]
"""
# Remove data URL prefix if present
if base64_str.startswith("data:image"):
base64_str = base64_str.split(",", 1)[1]
# Decode base64
image_bytes = base64.b64decode(base64_str)
# Convert to PIL Image and then to numpy array
image = Image.open(BytesIO(image_bytes)).convert("RGB")
image_array = np.array(image)
return image_array
    def _prepare_observation(self, images: Dict[str, str], state: List[float], prompt: Optional[str] = None) -> Dict[str, Any]:
"""
Prepare observation dictionary in the format expected by openpi.
Args:
images: Dictionary mapping camera names to base64 encoded images
state: List of robot state values
prompt: Optional text prompt
Returns:
Observation dictionary in openpi format
"""
# Decode and process images
processed_images = {}
# Map input camera names to expected openpi format
# Based on the config, pi0 expects specific camera names
camera_mapping = {
"camera0": "cam_high", # base camera
"camera1": "cam_left_wrist", # left wrist camera
"camera2": "cam_right_wrist", # right wrist camera
# Alternative mappings
"base_camera": "cam_high",
"left_wrist": "cam_left_wrist",
"right_wrist": "cam_right_wrist",
# Direct mappings
"cam_high": "cam_high",
"cam_left_wrist": "cam_left_wrist",
"cam_right_wrist": "cam_right_wrist"
}
for input_name, image_b64 in images.items():
# Map to openpi expected name
openpi_name = camera_mapping.get(input_name, input_name)
# Decode image
image_array = self._decode_base64_image(image_b64)
# Resize to expected resolution if needed
if image_array.shape[:2] != (224, 224):
image_pil = Image.fromarray(image_array)
image_resized = image_pil.resize((224, 224))
image_array = np.array(image_resized)
# Convert to format expected by openpi (H, W, C) with uint8
processed_images[openpi_name] = image_array.astype(np.uint8)
# Ensure we have the required cameras, create dummy ones if missing
required_cameras = ["cam_high", "cam_left_wrist", "cam_right_wrist"]
for cam_name in required_cameras:
if cam_name not in processed_images:
# Create a black dummy image
processed_images[cam_name] = np.zeros((224, 224, 3), dtype=np.uint8)
# Prepare state
state_array = np.array(state, dtype=np.float32)
# Create observation dict in openpi format
observation = {
"state": state_array,
"images": processed_images,
}
# Add prompt if provided
if prompt:
observation["prompt"] = prompt
return observation
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
Main inference function called by HuggingFace endpoint.
Args:
data: Input data dictionary containing:
- inputs: Dictionary with:
- images: Dict mapping camera names to base64 encoded images
- state: List of robot state values
- prompt: Optional text prompt
                - num_actions: Optional, requested number of actions (default: 50).
                  Note that the policy's configured action horizon determines how
                  many actions are actually returned.
- noise: Optional, noise array for sampling
Returns:
List containing prediction results
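        Example:
            A minimal request payload (the camera name is illustrative and the
            state length must match the robot's state dimension):
            {
                "inputs": {
                    "images": {"cam_high": "<base64-encoded PNG/JPEG>"},
                    "state": [0.0, 0.0, ...],
                    "prompt": "pick up the block"
                }
            }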
"""
try:
inputs = data.get("inputs", {})
# Extract inputs
images = inputs.get("images", {})
state = inputs.get("state", [])
prompt = inputs.get("prompt", "")
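            # NOTE: num_actions is read for API compatibility but is not
            # forwarded to the policy; openpi's action horizon is fixed by
            # the model config, so the returned chunk length may differ.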
num_actions = inputs.get("num_actions", self.default_num_steps)
noise_input = inputs.get("noise", None)
# Validate inputs
if not images:
raise ValueError("No images provided")
if not state:
raise ValueError("No state provided")
# Prepare observation using openpi format
observation = self._prepare_observation(images, state, prompt)
# Prepare noise if provided
noise = None
if noise_input is not None:
noise = np.array(noise_input, dtype=np.float32)
# Run inference using openpi policy
# This handles all the preprocessing, model inference, and postprocessing
result = self.policy.infer(observation, noise=noise)
# Extract actions from result
actions = result["actions"]
            # Convert to plain Python lists for JSON serialization (np.asarray
            # handles numpy arrays, JAX arrays, and nested lists alike)
            actions_list = np.asarray(actions).tolist()
# Return in expected format
return [{
"actions": actions_list,
"num_actions": len(actions_list),
"action_horizon": len(actions_list),
"action_dim": len(actions_list[0]) if actions_list else 0,
"success": True,
"metadata": {
"model_type": self.train_config.model.model_type.value,
"policy_metadata": getattr(self.policy, '_metadata', {})
}
}]
except Exception as e:
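            # Surface the error to the client rather than crashing the worker.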
return [{
"error": str(e),
"success": False
}]
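

if __name__ == "__main__":
    # Minimal local smoke test, not part of the endpoint contract. This is a
    # sketch under assumptions: model weights reachable via MODEL_PATH (or the
    # "weights/pi0" default) and an illustrative 14-dim zero state; adjust the
    # camera name and state dimension to match your robot.
    buffer = BytesIO()
    Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8)).save(buffer, format="PNG")
    dummy_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
    handler = EndpointHandler()
    response = handler({
        "inputs": {
            "images": {"cam_high": dummy_b64},
            "state": [0.0] * 14,
            "prompt": "pick up the block",
        }
    })
    print(str(response)[:1000])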