import base64
import json
import os
import sys
from io import BytesIO
from typing import Any, Dict, List, Optional

import numpy as np
from PIL import Image

# Make the vendored openpi package importable (expects ./openpi/src next to this file)
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "openpi", "src"))

from openpi.policies import policy_config
from openpi.training import config as train_config


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize the handler for pi0 model inference using openpi infrastructure.

        Args:
            path: Path to the model weights directory
        """
        # Set model path from environment variable or use provided path
        model_path = os.environ.get("MODEL_PATH", path)
        if not model_path:
            model_path = "weights/pi0"

        # Load the config.json to determine model type
        config_path = os.path.join(model_path, "config.json")
        with open(config_path, "r") as f:
            model_config = json.load(f)

        model_type = model_config.get("type", "pi0")

        # Create the training config via the openpi config system. Only the
        # "pi0" config is supported here, so an unrecognized type in
        # config.json also falls back to "pi0".
        self.train_config = train_config.get_config("pi0")

        # Create trained policy using openpi infrastructure
        # This handles all the model loading, preprocessing, etc.
        self.policy = policy_config.create_trained_policy(
            self.train_config,
            model_path,
            pytorch_device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
        )

        # Default number of inference steps
        self.default_num_steps = 50

    def _decode_base64_image(self, base64_str: str) -> np.ndarray:
        """
        Decode base64 image string to numpy array.

        Args:
            base64_str: Base64 encoded image string

        Returns:
            numpy array of shape (H, W, 3) with values in [0, 255]
        """
        # Remove data URL prefix if present
        if base64_str.startswith("data:image"):
            base64_str = base64_str.split(",", 1)[1]

        # Decode base64
        image_bytes = base64.b64decode(base64_str)

        # Convert to PIL Image and then to numpy array
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
        image_array = np.array(image)

        return image_array

    def _prepare_observation(self, images: Dict[str, str], state: List[float], prompt: Optional[str] = None) -> Dict[str, Any]:
        """
        Prepare observation dictionary in the format expected by openpi.

        Args:
            images: Dictionary mapping camera names to base64 encoded images
            state: List of robot state values
            prompt: Optional text prompt

        Returns:
            Observation dictionary in openpi format
        """
        # Decode and process images
        processed_images = {}

        # Map input camera names to expected openpi format
        # Based on the config, pi0 expects specific camera names
        camera_mapping = {
            "camera0": "cam_high",          # base camera
            "camera1": "cam_left_wrist",    # left wrist camera
            "camera2": "cam_right_wrist",   # right wrist camera
            # Alternative mappings
            "base_camera": "cam_high",
            "left_wrist": "cam_left_wrist",
            "right_wrist": "cam_right_wrist",
            # Direct mappings
            "cam_high": "cam_high",
            "cam_left_wrist": "cam_left_wrist",
            "cam_right_wrist": "cam_right_wrist"
        }

        for input_name, image_b64 in images.items():
            # Map to openpi expected name
            openpi_name = camera_mapping.get(input_name, input_name)

            # Decode image
            image_array = self._decode_base64_image(image_b64)

            # Resize to expected resolution if needed
            if image_array.shape[:2] != (224, 224):
                image_pil = Image.fromarray(image_array)
                image_resized = image_pil.resize((224, 224))
                image_array = np.array(image_resized)

            # Convert to format expected by openpi (H, W, C) with uint8
            processed_images[openpi_name] = image_array.astype(np.uint8)

        # Ensure all required cameras are present; create a black dummy image for any that are missing
        required_cameras = ["cam_high", "cam_left_wrist", "cam_right_wrist"]
        for cam_name in required_cameras:
            if cam_name not in processed_images:
                # Create a black dummy image
                processed_images[cam_name] = np.zeros((224, 224, 3), dtype=np.uint8)

        # Prepare state
        state_array = np.array(state, dtype=np.float32)

        # Create observation dict in openpi format
        observation = {
            "state": state_array,
            "images": processed_images,
        }

        # Add prompt if provided
        if prompt:
            observation["prompt"] = prompt

        return observation

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Main inference entry point called by the Hugging Face Inference Endpoint.

        Args:
            data: Input data dictionary containing:
                - inputs: Dictionary with:
                    - images: Dict mapping camera names to base64 encoded images
                    - state: List of robot state values
                    - prompt: Optional text prompt
                    - num_actions: Optional, maximum number of actions to return (default: 50)
                    - noise: Optional, noise array for sampling

        Returns:
            List containing prediction results
        """
        try:
            inputs = data.get("inputs", {})

            # Extract inputs
            images = inputs.get("images", {})
            state = inputs.get("state", [])
            prompt = inputs.get("prompt", "")
            num_actions = inputs.get("num_actions", self.default_num_steps)
            noise_input = inputs.get("noise", None)

            # Validate inputs
            if not images:
                raise ValueError("No images provided")
            if not state:
                raise ValueError("No state provided")

            # Prepare observation using openpi format
            observation = self._prepare_observation(images, state, prompt)

            # Prepare noise if provided
            noise = None
            if noise_input is not None:
                noise = np.array(noise_input, dtype=np.float32)

            # Run inference using openpi policy
            # This handles all the preprocessing, model inference, and postprocessing
            result = self.policy.infer(observation, noise=noise)

            # Extract actions from the result; shape is (action_horizon, action_dim)
            actions = result["actions"]

            # Convert to a nested list for JSON serialization
            if isinstance(actions, np.ndarray):
                actions_list = actions.tolist()
            else:
                actions_list = actions

            # The policy returns a fixed action horizon determined by the model
            # config, so honor num_actions by truncating the returned chunk.
            action_horizon = len(actions_list)
            if num_actions and num_actions < action_horizon:
                actions_list = actions_list[:num_actions]

            # Return in expected format
            return [{
                "actions": actions_list,
                "num_actions": len(actions_list),
                "action_horizon": action_horizon,
                "action_dim": len(actions_list[0]) if actions_list else 0,
                "success": True,
                "metadata": {
                    "model_type": self.train_config.model.model_type.value,
                    "policy_metadata": getattr(self.policy, '_metadata', {})
                }
            }]

        except Exception as e:
            return [{
                "error": str(e),
                "success": False
            }]
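

# Minimal local smoke test (not used by the Inference Endpoints runtime).
# This is a sketch under two assumptions: a checkpoint is available at
# MODEL_PATH (or the default weights/pi0), and a sample image exists at the
# hypothetical path "sample.png"; the 14-dim zero state is likewise only an
# illustrative placeholder for the robot's real proprioceptive state.
if __name__ == "__main__":
    with open("sample.png", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler()
    example_payload = {
        "inputs": {
            "images": {"cam_high": image_b64},
            "state": [0.0] * 14,
            "prompt": "pick up the cube",
            "num_actions": 10,
        }
    }
    response = handler(example_payload)
    # Print a truncated, JSON-serializable view of the first prediction.
    print(json.dumps(response[0], default=str)[:500])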