s3y committed
Commit 5c55bd1 · verified · Parent: 1be5b40

Upload folder using huggingface_hub

Files changed (3)
  1. .idea/.gitignore +8 -0
  2. .idea/workspace.xml +12 -0
  3. handler.py +212 -0
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
+ # Default ignored files
+ /shelf/
+ /workspace.xml
+ # Editor-based HTTP Client requests
+ /httpRequests/
+ # Datasource local storage ignored files
+ /dataSources/
+ /dataSources.local.xml
.idea/workspace.xml ADDED
@@ -0,0 +1,12 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectViewState">
+     <option name="hideEmptyMiddlePackages" value="true" />
+     <option name="showLibraryContents" value="true" />
+   </component>
+   <component name="PropertiesComponent">{
+   &quot;keyToString&quot;: {
+     &quot;settings.editor.selected.configurable&quot;: &quot;dev.sweep.assistant.settings.SweepSettingsConfigurable&quot;
+   }
+ }</component>
+ </project>
handler.py ADDED
@@ -0,0 +1,212 @@
+ import base64
+ import json
+ import os
+ from io import BytesIO
+ from typing import Any, Dict, List, Optional
+
+ import numpy as np
+ from PIL import Image
+
+ from openpi.policies import policy_config
+ from openpi.training import config as train_config
+
+
+ class EndpointHandler:
+     def __init__(self, path: str = ""):
+         """
+         Initialize the handler for pi0 model inference using openpi infrastructure.
+
+         Args:
+             path: Path to the model weights directory
+         """
+         # Set the model path from the environment variable or use the provided path
+         model_path = os.environ.get("MODEL_PATH", path)
+         if not model_path:
+             model_path = "weights/pi0"
+
+         # Load config.json to determine the model type
+         config_path = os.path.join(model_path, "config.json")
+         with open(config_path, "r") as f:
+             model_config = json.load(f)
+
+         model_type = model_config.get("type", "pi0")
+         if model_type != "pi0":
+             # Fall back to pi0 if the type is not recognized
+             model_type = "pi0"
+
+         # Create the training config via the openpi config system
+         self.train_config = train_config.get_config(model_type)
+
+         # Create the trained policy using openpi infrastructure;
+         # this handles model loading, preprocessing, and postprocessing
+         self.policy = policy_config.create_trained_policy(
+             self.train_config,
+             model_path,
+             pytorch_device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu",
+         )
+
+         # Default number of inference steps
+         self.default_num_steps = 50
+
+     def _decode_base64_image(self, base64_str: str) -> np.ndarray:
+         """
+         Decode a base64 image string to a numpy array.
+
+         Args:
+             base64_str: Base64-encoded image string
+
+         Returns:
+             numpy array of shape (H, W, 3) with values in [0, 255]
+         """
+         # Remove the data URL prefix if present
+         if base64_str.startswith("data:image"):
+             base64_str = base64_str.split(",", 1)[1]
+
+         # Decode base64
+         image_bytes = base64.b64decode(base64_str)
+
+         # Convert to a PIL Image, then to a numpy array
+         image = Image.open(BytesIO(image_bytes)).convert("RGB")
+         image_array = np.array(image)
+
+         return image_array
+
+     def _prepare_observation(self, images: Dict[str, str], state: List[float], prompt: Optional[str] = None) -> Dict[str, Any]:
+         """
+         Prepare an observation dictionary in the format expected by openpi.
+
+         Args:
+             images: Dictionary mapping camera names to base64-encoded images
+             state: List of robot state values
+             prompt: Optional text prompt
+
+         Returns:
+             Observation dictionary in openpi format
+         """
+         # Decode and process the images
+         processed_images = {}
+
+         # Map input camera names to the names pi0 expects
+         camera_mapping = {
+             "camera0": "cam_high",  # base camera
+             "camera1": "cam_left_wrist",  # left wrist camera
+             "camera2": "cam_right_wrist",  # right wrist camera
+             # Alternative mappings
+             "base_camera": "cam_high",
+             "left_wrist": "cam_left_wrist",
+             "right_wrist": "cam_right_wrist",
+             # Direct mappings
+             "cam_high": "cam_high",
+             "cam_left_wrist": "cam_left_wrist",
+             "cam_right_wrist": "cam_right_wrist",
+         }
+
+         for input_name, image_b64 in images.items():
+             # Map to the openpi camera name; unknown names pass through unchanged
+             openpi_name = camera_mapping.get(input_name, input_name)
+
+             # Decode the image
+             image_array = self._decode_base64_image(image_b64)
+
+             # Resize to the expected resolution if needed
+             if image_array.shape[:2] != (224, 224):
+                 image_pil = Image.fromarray(image_array)
+                 image_resized = image_pil.resize((224, 224))
+                 image_array = np.array(image_resized)
+
+             # Convert to the format expected by openpi: (H, W, C) uint8
+             processed_images[openpi_name] = image_array.astype(np.uint8)
+
+         # Ensure the required cameras are present; create dummies if missing
+         required_cameras = ["cam_high", "cam_left_wrist", "cam_right_wrist"]
+         for cam_name in required_cameras:
+             if cam_name not in processed_images:
+                 # Create a black dummy image
+                 processed_images[cam_name] = np.zeros((224, 224, 3), dtype=np.uint8)
+
+         # Prepare the state
+         state_array = np.array(state, dtype=np.float32)
+
+         # Create the observation dict in openpi format
+         observation = {
+             "state": state_array,
+             "images": processed_images,
+         }
+
+         # Add the prompt if provided
+         if prompt:
+             observation["prompt"] = prompt
+
+         return observation
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         """
+         Main inference function called by the Hugging Face endpoint.
+
+         Args:
+             data: Input data dictionary containing:
+                 - inputs: Dictionary with:
+                     - images: Dict mapping camera names to base64-encoded images
+                     - state: List of robot state values
+                     - prompt: Optional text prompt
+                     - num_actions: Optional number of actions to predict (default: 50)
+                     - noise: Optional noise array for sampling
+
+         Returns:
+             List containing prediction results
+         """
+         try:
+             inputs = data.get("inputs", {})
+
+             # Extract the inputs
+             images = inputs.get("images", {})
+             state = inputs.get("state", [])
+             prompt = inputs.get("prompt", "")
+             num_actions = inputs.get("num_actions", self.default_num_steps)
+             noise_input = inputs.get("noise", None)
+
+             # Validate the inputs
+             if not images:
+                 raise ValueError("No images provided")
+             if not state:
+                 raise ValueError("No state provided")
+
+             # Prepare the observation in openpi format
+             observation = self._prepare_observation(images, state, prompt)
+
+             # Prepare noise if provided
+             noise = None
+             if noise_input is not None:
+                 noise = np.array(noise_input, dtype=np.float32)
+
+             # Run inference using the openpi policy, which handles
+             # preprocessing, model inference, and postprocessing
+             result = self.policy.infer(observation, noise=noise)
+
+             # Extract the actions from the result
+             actions = result["actions"]
+
+             # Convert to a list for JSON serialization
+             if isinstance(actions, np.ndarray):
+                 actions_list = actions.tolist()
+             else:
+                 actions_list = actions
+
+             # Return in the expected format
+             return [{
+                 "actions": actions_list,
+                 "num_actions": len(actions_list),
+                 "action_horizon": len(actions_list),
+                 "action_dim": len(actions_list[0]) if actions_list else 0,
+                 "success": True,
+                 "metadata": {
+                     "model_type": self.train_config.model.model_type.value,
+                     "policy_metadata": getattr(self.policy, "_metadata", {}),
+                 },
+             }]
+
+         except Exception as e:
+             return [{
+                 "error": str(e),
+                 "success": False,
+             }]
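
For reference, a minimal client-side sketch of how the deployed handler might be called. This is not part of the commit: the endpoint URL and token are placeholders, and the 14-dimensional state vector is an illustrative assumption (the real dimension depends on the embodiment the pi0 checkpoint was trained for).

import base64
from io import BytesIO

import numpy as np
import requests
from PIL import Image

# Placeholders -- substitute your own endpoint URL and access token.
ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
HF_TOKEN = "hf_..."


def encode_image(image: Image.Image) -> str:
    # Encode a PIL image as a base64 PNG string; the handler's
    # _decode_base64_image also accepts data-URL-prefixed strings.
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


# Dummy black 224x224 frames for the three cameras the handler expects.
frame = encode_image(Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8)))
payload = {
    "inputs": {
        "images": {
            "cam_high": frame,
            "cam_left_wrist": frame,
            "cam_right_wrist": frame,
        },
        # Assumption: a 14-dim state vector; the true dimension depends
        # on the robot the pi0 checkpoint was trained for.
        "state": [0.0] * 14,
        "prompt": "pick up the cup",
    }
}

response = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    json=payload,
)
result = response.json()[0]
if result.get("success"):
    print(f"{result['num_actions']} actions, dim {result['action_dim']}")
else:
    print("Inference failed:", result.get("error"))

The same payload dictionary can also be passed straight to EndpointHandler(path="weights/pi0")(payload) for local testing without a deployed endpoint.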