Georg committed on
Commit
703d3c2
·
1 Parent(s): e15abf5

Fix API endpoints with FastAPI integration

Browse files

- Replace Gradio-only app with FastAPI + Gradio hybrid
- Add /api/initialize and /api/estimate REST endpoints using FastAPI
- Add FastAPI, uvicorn, pydantic dependencies
- Keep Gradio UI for web interface
- Properly expose REST API for robot-ml training integration

Files changed (4) hide show
  1. .gitignore +1 -0
  2. app.py +60 -338
  3. app_simple.py +290 -0
  4. requirements.txt +5 -0
.gitignore CHANGED
@@ -44,3 +44,4 @@ flagged/
44
  # Test images
45
  test_images/
46
  reference_images/
 
 
44
  # Test images
45
  test_images/
46
  reference_images/
47
+ app_old.py
app.py CHANGED
@@ -1,28 +1,22 @@
1
  """
2
- FoundationPose Inference Server with ZeroGPU Support
3
 
4
- This Gradio app provides an API for 6D object pose estimation using FoundationPose.
5
- It's designed to be called from the robot-ml training pipeline via HTTP requests.
6
-
7
- API Endpoints:
8
- - /api/initialize: Set up tracking for an object with reference images
9
- - /api/estimate: Estimate 6D pose from a query image
10
  """
11
 
12
  import base64
13
- import io
14
  import json
15
  import logging
16
  import os
17
- from pathlib import Path
18
- from typing import Dict, List, Optional
19
 
20
  import cv2
21
  import gradio as gr
22
  import numpy as np
23
  import spaces
24
  import torch
25
- from PIL import Image
 
26
 
27
  logging.basicConfig(
28
  level=logging.INFO,
@@ -31,7 +25,7 @@ logging.basicConfig(
31
  logger = logging.getLogger(__name__)
32
 
33
  # Check if running in real FoundationPose mode or placeholder mode
34
- USE_REAL_MODEL = os.environ.get("USE_REAL_MODEL", "false").lower() == "true"
35
 
36
 
37
  class FoundationPoseInference:
@@ -82,27 +76,16 @@ class FoundationPoseInference:
82
  self,
83
  object_id: str,
84
  reference_images: List[np.ndarray],
85
- camera_intrinsics: Optional[Dict] = None,
86
- mesh_path: Optional[str] = None
87
  ) -> bool:
88
- """Register an object for tracking with reference images.
89
-
90
- Args:
91
- object_id: Unique identifier for the object
92
- reference_images: List of RGB images (numpy arrays) showing the object from different angles
93
- camera_intrinsics: Camera parameters (fx, fy, cx, cy)
94
- mesh_path: Optional path to CAD mesh file
95
-
96
- Returns:
97
- True if registration successful
98
- """
99
  if not self.initialized:
100
  self.initialize_model()
101
 
102
  logger.info(f"Registering object '{object_id}' with {len(reference_images)} reference images")
103
 
104
  if self.use_real_model and self.model is not None:
105
- # Use real FoundationPose model
106
  try:
107
  success = self.model.register_object(
108
  object_id=object_id,
@@ -121,7 +104,6 @@ class FoundationPoseInference:
121
  logger.error(f"Registration failed: {e}", exc_info=True)
122
  return False
123
  else:
124
- # Placeholder mode
125
  self.tracked_objects[object_id] = {
126
  "num_references": len(reference_images),
127
  "camera_intrinsics": camera_intrinsics,
@@ -130,39 +112,16 @@ class FoundationPoseInference:
130
  logger.info(f"✓ Object '{object_id}' registered (placeholder mode)")
131
  return True
132
 
133
- @spaces.GPU(duration=10) # Allocate GPU for 10 seconds per inference
134
  def estimate_pose(
135
  self,
136
  object_id: str,
137
  query_image: np.ndarray,
138
- camera_intrinsics: Optional[Dict] = None,
139
- depth_image: Optional[np.ndarray] = None,
140
- mask: Optional[np.ndarray] = None
141
  ) -> Dict:
142
- """Estimate 6D pose of an object in a query image.
143
-
144
- Args:
145
- object_id: ID of object to detect
146
- query_image: RGB query image as numpy array
147
- camera_intrinsics: Optional camera parameters
148
- depth_image: Optional depth map
149
- mask: Optional object segmentation mask
150
-
151
- Returns:
152
- Dictionary with pose estimation results:
153
- {
154
- "success": bool,
155
- "poses": [
156
- {
157
- "object_id": str,
158
- "position": {"x": float, "y": float, "z": float},
159
- "orientation": {"w": float, "x": float, "y": float, "z": float},
160
- "confidence": float,
161
- "dimensions": [float, float, float]
162
- }
163
- ]
164
- }
165
- """
166
  if not self.initialized:
167
  return {"success": False, "error": "Model not initialized"}
168
 
@@ -172,7 +131,6 @@ class FoundationPoseInference:
172
  logger.info(f"Estimating pose for object '{object_id}'")
173
 
174
  if self.use_real_model and self.model is not None:
175
- # Use real FoundationPose model
176
  try:
177
  pose_result = self.model.estimate_pose(
178
  object_id=object_id,
@@ -198,7 +156,6 @@ class FoundationPoseInference:
198
  logger.error(f"Pose estimation error: {e}", exc_info=True)
199
  return {"success": False, "error": str(e), "poses": []}
200
  else:
201
- # Placeholder mode - return empty poses
202
  logger.info("Placeholder mode: returning empty pose result")
203
  return {
204
  "success": True,
@@ -211,37 +168,33 @@ class FoundationPoseInference:
211
  pose_estimator = FoundationPoseInference()
212
 
213
 
214
- def initialize_api(request: gr.Request) -> Dict:
215
- """API endpoint for initializing object tracking.
 
 
 
 
216
 
217
- Request body:
218
- {
219
- "object_id": str,
220
- "reference_images_b64": [str, ...],
221
- "camera_intrinsics": str (JSON),
222
- "mesh_path": str (optional)
223
- }
224
 
225
- Returns:
226
- {"success": bool, "message": str}
227
- """
228
- try:
229
- data = request.json() if hasattr(request, 'json') else {}
 
230
 
231
- object_id = data.get("object_id")
232
- reference_images_b64 = data.get("reference_images_b64", [])
233
- camera_intrinsics_str = data.get("camera_intrinsics")
234
- mesh_path = data.get("mesh_path")
235
 
236
- if not object_id:
237
- return {"success": False, "error": "Missing object_id"}
238
 
239
- if not reference_images_b64:
240
- return {"success": False, "error": "Missing reference_images_b64"}
241
 
 
 
 
 
242
  # Decode reference images
243
  reference_images = []
244
- for img_b64 in reference_images_b64:
245
  img_bytes = base64.b64decode(img_b64)
246
  img_array = np.frombuffer(img_bytes, dtype=np.uint8)
247
  img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
@@ -249,81 +202,55 @@ def initialize_api(request: gr.Request) -> Dict:
249
  reference_images.append(img)
250
 
251
  # Parse camera intrinsics
252
- intrinsics = json.loads(camera_intrinsics_str) if camera_intrinsics_str else None
253
 
254
  # Register object
255
  success = pose_estimator.register_object(
256
- object_id=object_id,
257
  reference_images=reference_images,
258
  camera_intrinsics=intrinsics,
259
- mesh_path=mesh_path
260
  )
261
 
262
  return {
263
  "success": success,
264
- "message": f"Object '{object_id}' registered with {len(reference_images)} reference images"
265
  }
266
 
267
  except Exception as e:
268
  logger.error(f"Initialization error: {e}", exc_info=True)
269
- return {"success": False, "error": str(e)}
270
-
271
 
272
- def estimate_api(request: gr.Request) -> Dict:
273
- """API endpoint for pose estimation.
274
 
275
- Request body:
276
- {
277
- "object_id": str,
278
- "query_image_b64": str,
279
- "camera_intrinsics": str (JSON),
280
- "depth_image_b64": str (optional),
281
- "mask_b64": str (optional)
282
- }
283
-
284
- Returns:
285
- Pose estimation results
286
- """
287
  try:
288
- data = request.json() if hasattr(request, 'json') else {}
289
-
290
- object_id = data.get("object_id")
291
- query_image_b64 = data.get("query_image_b64")
292
- camera_intrinsics_str = data.get("camera_intrinsics")
293
- depth_image_b64 = data.get("depth_image_b64")
294
- mask_b64 = data.get("mask_b64")
295
-
296
- if not object_id:
297
- return {"success": False, "error": "Missing object_id"}
298
-
299
- if not query_image_b64:
300
- return {"success": False, "error": "Missing query_image_b64"}
301
-
302
  # Decode query image
303
- img_bytes = base64.b64decode(query_image_b64)
304
  img_array = np.frombuffer(img_bytes, dtype=np.uint8)
305
  img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
306
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
307
 
308
  # Decode optional depth image
309
  depth = None
310
- if depth_image_b64:
311
- depth_bytes = base64.b64decode(depth_image_b64)
312
  depth = np.frombuffer(depth_bytes, dtype=np.float32)
313
 
314
  # Decode optional mask
315
  mask = None
316
- if mask_b64:
317
- mask_bytes = base64.b64decode(mask_b64)
318
  mask_array = np.frombuffer(mask_bytes, dtype=np.uint8)
319
  mask = cv2.imdecode(mask_array, cv2.IMREAD_GRAYSCALE)
320
 
321
  # Parse camera intrinsics
322
- intrinsics = json.loads(camera_intrinsics_str) if camera_intrinsics_str else None
323
 
324
  # Estimate pose
325
  result = pose_estimator.estimate_pose(
326
- object_id=object_id,
327
  query_image=img,
328
  camera_intrinsics=intrinsics,
329
  depth_image=depth,
@@ -334,235 +261,30 @@ def estimate_api(request: gr.Request) -> Dict:
334
 
335
  except Exception as e:
336
  logger.error(f"Estimation error: {e}", exc_info=True)
337
- return {"success": False, "error": str(e)}
338
-
339
-
340
- # Gradio UI for testing
341
- def test_initialization(object_id: str, reference_images: List):
342
- """Test UI for initialization."""
343
- if not object_id:
344
- return "❌ Please enter an object ID"
345
-
346
- if not reference_images:
347
- return "❌ Please upload reference images"
348
-
349
- try:
350
- # Convert PIL images to numpy arrays
351
- ref_imgs = []
352
- for img in reference_images:
353
- ref_imgs.append(np.array(img))
354
-
355
- success = pose_estimator.register_object(object_id, ref_imgs, None)
356
-
357
- if success:
358
- return f"✅ Object '{object_id}' registered with {len(ref_imgs)} images"
359
- else:
360
- return "❌ Registration failed"
361
-
362
- except Exception as e:
363
- logger.error(f"Test initialization error: {e}", exc_info=True)
364
- return f"❌ Error: {str(e)}"
365
-
366
-
367
- def test_estimation(object_id: str, query_image):
368
- """Test UI for pose estimation."""
369
- if not object_id:
370
- return "❌ Please enter an object ID", None
371
-
372
- if query_image is None:
373
- return "❌ Please upload a query image", None
374
-
375
- try:
376
- query_img = np.array(query_image)
377
- result = pose_estimator.estimate_pose(object_id, query_img, None)
378
-
379
- if result["success"]:
380
- num_poses = len(result["poses"])
381
- output_text = f"✅ Detection complete: {num_poses} pose(s) detected\n\n"
382
 
383
- if num_poses == 0:
384
- output_text += "Note: " + result.get("note", "No poses detected")
385
- else:
386
- for i, pose in enumerate(result["poses"]):
387
- output_text += f"Pose {i+1}:\n"
388
- output_text += f" Position: ({pose['position']['x']:.3f}, {pose['position']['y']:.3f}, {pose['position']['z']:.3f})\n"
389
- output_text += f" Confidence: {pose['confidence']:.3f}\n\n"
390
 
391
- # TODO: Visualize detected pose on image
392
- output_image = query_image
393
-
394
- return output_text, output_image
395
- else:
396
- return f"❌ Detection failed: {result.get('error', 'Unknown error')}", None
397
-
398
- except Exception as e:
399
- logger.error(f"Test estimation error: {e}", exc_info=True)
400
- return f"❌ Error: {str(e)}", None
401
-
402
-
403
- # Build Gradio interface
404
- with gr.Blocks(title="FoundationPose Inference", theme=gr.themes.Soft()) as demo:
405
  gr.Markdown("# 🎯 FoundationPose 6D Object Pose Estimation")
406
 
407
  mode_indicator = gr.Markdown(
408
- f"**Mode:** {'🟢 Real FoundationPose' if USE_REAL_MODEL else '🟡 Placeholder (set USE_REAL_MODEL=true)'}",
409
  elem_id="mode"
410
  )
411
 
412
  gr.Markdown("""
413
- This service provides 6D object pose estimation using FoundationPose.
414
-
415
- **Usage:**
416
- 1. Register an object with reference images using the Initialize tab
417
- 2. Estimate poses in query images using the Estimate tab
418
 
419
- **API Endpoints:**
420
- - POST `/api/initialize` - Register object with reference images
421
- - POST `/api/estimate` - Estimate 6D pose from query image
422
  """)
423
 
424
- with gr.Tab("🔧 Initialize Object"):
425
- gr.Markdown("### Register an object for tracking")
426
- with gr.Row():
427
- with gr.Column():
428
- init_object_id = gr.Textbox(
429
- label="Object ID",
430
- placeholder="e.g., target_cube",
431
- info="Unique identifier for the object"
432
- )
433
- init_ref_images = gr.File(
434
- label="Reference Images (16-20 recommended)",
435
- file_count="multiple",
436
- file_types=["image"],
437
- type="filepath"
438
- )
439
- init_button = gr.Button("Register Object", variant="primary", size="lg")
440
- with gr.Column():
441
- init_output = gr.Textbox(label="Result", lines=8)
442
-
443
- gr.Markdown("""
444
- **Tips:**
445
- - Capture 16-20 images from different viewpoints
446
- - Include various angles and distances
447
- - Ensure good lighting and sharp focus
448
- """)
449
-
450
- init_button.click(
451
- fn=test_initialization,
452
- inputs=[init_object_id, init_ref_images],
453
- outputs=init_output
454
- )
455
-
456
- with gr.Tab("🔍 Estimate Pose"):
457
- gr.Markdown("### Detect object pose in a query image")
458
- with gr.Row():
459
- with gr.Column():
460
- est_object_id = gr.Textbox(
461
- label="Object ID",
462
- placeholder="e.g., target_cube",
463
- info="Must match an initialized object"
464
- )
465
- est_query_image = gr.Image(
466
- label="Query Image",
467
- type="pil",
468
- sources=["upload", "webcam"]
469
- )
470
- est_button = gr.Button("Estimate Pose", variant="primary", size="lg")
471
- with gr.Column():
472
- est_output_text = gr.Textbox(label="Detection Results", lines=15)
473
- est_output_image = gr.Image(label="Visualization (coming soon)")
474
-
475
- est_button.click(
476
- fn=test_estimation,
477
- inputs=[est_object_id, est_query_image],
478
- outputs=[est_output_text, est_output_image]
479
- )
480
-
481
- with gr.Tab("📖 API Documentation"):
482
- gr.Markdown("""
483
- ### HTTP API
484
-
485
- #### Initialize Object
486
- ```bash
487
- curl -X POST https://gpue-foundationpose.hf.space/api/initialize \\
488
- -H "Content-Type: application/json" \\
489
- -d '{
490
- "object_id": "target_cube",
491
- "reference_images_b64": ["<base64-encoded-jpeg>", ...],
492
- "camera_intrinsics": "{\\"fx\\": 500, \\"fy\\": 500, \\"cx\\": 320, \\"cy\\": 240}"
493
- }'
494
- ```
495
-
496
- #### Estimate Pose
497
- ```bash
498
- curl -X POST https://gpue-foundationpose.hf.space/api/estimate \\
499
- -H "Content-Type: application/json" \\
500
- -d '{
501
- "object_id": "target_cube",
502
- "query_image_b64": "<base64-encoded-jpeg>",
503
- "camera_intrinsics": "{\\"fx\\": 500, \\"fy\\": 500, \\"cx\\": 320, \\"cy\\": 240}"
504
- }'
505
- ```
506
-
507
- **Response Format:**
508
- ```json
509
- {
510
- "success": true,
511
- "poses": [
512
- {
513
- "object_id": "target_cube",
514
- "position": {"x": 0.5, "y": 0.3, "z": 0.1},
515
- "orientation": {"w": 1.0, "x": 0.0, "y": 0.0, "z": 0.0},
516
- "confidence": 0.95,
517
- "dimensions": [0.1, 0.1, 0.1]
518
- }
519
- ]
520
- }
521
- ```
522
-
523
- ### Integration with robot-ml
524
-
525
- ```python
526
- from foundationpose.client import FoundationPoseClient
527
-
528
- client = FoundationPoseClient("https://gpue-foundationpose.hf.space")
529
-
530
- # Load reference images
531
- ref_images = load_reference_images("./perception/reference/target_cube")
532
 
533
- # Initialize object
534
- client.initialize("target_cube", ref_images)
535
 
536
- # Estimate pose
537
- poses = client.estimate_pose("target_cube", query_image)
538
- ```
539
- """)
540
-
541
- gr.Markdown("""
542
- ---
543
- **Citation:**
544
- ```bibtex
545
- @inproceedings{wen2023foundationpose,
546
- title={FoundationPose: Unified 6D Pose Estimation and Tracking of Novel Objects},
547
- author={Wen, Bowen and Yang, Wei and Kautz, Jan and Birchfield, Stan},
548
- booktitle={CVPR},
549
- year={2024}
550
- }
551
- ```
552
-
553
- [GitHub](https://github.com/NVlabs/FoundationPose) | [Paper](https://arxiv.org/abs/2312.08344)
554
- """)
555
-
556
-
557
- # Launch app
558
  if __name__ == "__main__":
559
- logger.info("=" * 60)
560
- logger.info("FoundationPose Inference Server Starting")
561
- logger.info(f"Mode: {'Real Model' if USE_REAL_MODEL else 'Placeholder'}")
562
- logger.info("=" * 60)
563
-
564
- demo.launch(
565
- server_name="0.0.0.0",
566
- server_port=7860,
567
- share=False
568
- )
 
1
  """
2
+ Simple FoundationPose API server using FastAPI + Gradio
3
 
4
+ This version uses FastAPI for clean REST API endpoints alongside Gradio UI.
 
 
 
 
 
5
  """
6
 
7
  import base64
 
8
  import json
9
  import logging
10
  import os
11
+ from typing import Dict, List
 
12
 
13
  import cv2
14
  import gradio as gr
15
  import numpy as np
16
  import spaces
17
  import torch
18
+ from fastapi import FastAPI, HTTPException
19
+ from pydantic import BaseModel
20
 
21
  logging.basicConfig(
22
  level=logging.INFO,
 
25
  logger = logging.getLogger(__name__)
26
 
27
  # Check if running in real FoundationPose mode or placeholder mode
28
+ USE_REAL_MODEL = os.environ.get("USE_REAL_MODEL", "false").lower() == "true"
29
 
30
 
31
  class FoundationPoseInference:
 
76
  self,
77
  object_id: str,
78
  reference_images: List[np.ndarray],
79
+ camera_intrinsics: Dict = None,
80
+ mesh_path: str = None
81
  ) -> bool:
82
+ """Register an object for tracking with reference images."""
 
 
 
 
 
 
 
 
 
 
83
  if not self.initialized:
84
  self.initialize_model()
85
 
86
  logger.info(f"Registering object '{object_id}' with {len(reference_images)} reference images")
87
 
88
  if self.use_real_model and self.model is not None:
 
89
  try:
90
  success = self.model.register_object(
91
  object_id=object_id,
 
104
  logger.error(f"Registration failed: {e}", exc_info=True)
105
  return False
106
  else:
 
107
  self.tracked_objects[object_id] = {
108
  "num_references": len(reference_images),
109
  "camera_intrinsics": camera_intrinsics,
 
112
  logger.info(f"✓ Object '{object_id}' registered (placeholder mode)")
113
  return True
114
 
115
+ @spaces.GPU(duration=10)
116
  def estimate_pose(
117
  self,
118
  object_id: str,
119
  query_image: np.ndarray,
120
+ camera_intrinsics: Dict = None,
121
+ depth_image: np.ndarray = None,
122
+ mask: np.ndarray = None
123
  ) -> Dict:
124
+ """Estimate 6D pose of an object in a query image."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  if not self.initialized:
126
  return {"success": False, "error": "Model not initialized"}
127
 
 
131
  logger.info(f"Estimating pose for object '{object_id}'")
132
 
133
  if self.use_real_model and self.model is not None:
 
134
  try:
135
  pose_result = self.model.estimate_pose(
136
  object_id=object_id,
 
156
  logger.error(f"Pose estimation error: {e}", exc_info=True)
157
  return {"success": False, "error": str(e), "poses": []}
158
  else:
 
159
  logger.info("Placeholder mode: returning empty pose result")
160
  return {
161
  "success": True,
 
168
  pose_estimator = FoundationPoseInference()
169
 
170
 
171
+ # Pydantic models for API
172
+ class InitializeRequest(BaseModel):
173
+ object_id: str
174
+ reference_images_b64: List[str]
175
+ camera_intrinsics: str = None
176
+ mesh_path: str = None
177
 
 
 
 
 
 
 
 
178
 
179
+ class EstimateRequest(BaseModel):
180
+ object_id: str
181
+ query_image_b64: str
182
+ camera_intrinsics: str = None
183
+ depth_image_b64: str = None
184
+ mask_b64: str = None
185
 
 
 
 
 
186
 
187
+ # Create FastAPI app
188
+ app = FastAPI()
189
 
 
 
190
 
191
+ @app.post("/api/initialize")
192
+ async def api_initialize(request: InitializeRequest):
193
+ """Initialize object tracking with reference images."""
194
+ try:
195
  # Decode reference images
196
  reference_images = []
197
+ for img_b64 in request.reference_images_b64:
198
  img_bytes = base64.b64decode(img_b64)
199
  img_array = np.frombuffer(img_bytes, dtype=np.uint8)
200
  img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
 
202
  reference_images.append(img)
203
 
204
  # Parse camera intrinsics
205
+ intrinsics = json.loads(request.camera_intrinsics) if request.camera_intrinsics else None
206
 
207
  # Register object
208
  success = pose_estimator.register_object(
209
+ object_id=request.object_id,
210
  reference_images=reference_images,
211
  camera_intrinsics=intrinsics,
212
+ mesh_path=request.mesh_path
213
  )
214
 
215
  return {
216
  "success": success,
217
+ "message": f"Object '{request.object_id}' registered with {len(reference_images)} reference images"
218
  }
219
 
220
  except Exception as e:
221
  logger.error(f"Initialization error: {e}", exc_info=True)
222
+ raise HTTPException(status_code=500, detail=str(e))
 
223
 
 
 
224
 
225
+ @app.post("/api/estimate")
226
+ async def api_estimate(request: EstimateRequest):
227
+ """Estimate 6D pose from query image."""
 
 
 
 
 
 
 
 
 
228
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  # Decode query image
230
+ img_bytes = base64.b64decode(request.query_image_b64)
231
  img_array = np.frombuffer(img_bytes, dtype=np.uint8)
232
  img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
233
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
234
 
235
  # Decode optional depth image
236
  depth = None
237
+ if request.depth_image_b64:
238
+ depth_bytes = base64.b64decode(request.depth_image_b64)
239
  depth = np.frombuffer(depth_bytes, dtype=np.float32)
240
 
241
  # Decode optional mask
242
  mask = None
243
+ if request.mask_b64:
244
+ mask_bytes = base64.b64decode(request.mask_b64)
245
  mask_array = np.frombuffer(mask_bytes, dtype=np.uint8)
246
  mask = cv2.imdecode(mask_array, cv2.IMREAD_GRAYSCALE)
247
 
248
  # Parse camera intrinsics
249
+ intrinsics = json.loads(request.camera_intrinsics) if request.camera_intrinsics else None
250
 
251
  # Estimate pose
252
  result = pose_estimator.estimate_pose(
253
+ object_id=request.object_id,
254
  query_image=img,
255
  camera_intrinsics=intrinsics,
256
  depth_image=depth,
 
261
 
262
  except Exception as e:
263
  logger.error(f"Estimation error: {e}", exc_info=True)
264
+ raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
 
 
 
 
 
 
 
 
266
 
267
+ # Gradio UI (simplified)
268
+ with gr.Blocks(title="FoundationPose Inference", theme=gr.themes.Soft()) as gradio_app:
 
 
 
 
 
 
 
 
 
 
 
 
269
  gr.Markdown("# 🎯 FoundationPose 6D Object Pose Estimation")
270
 
271
  mode_indicator = gr.Markdown(
272
+ f"**Mode:** {'🟢 Real FoundationPose' if USE_REAL_MODEL else '🟡 Placeholder'}",
273
  elem_id="mode"
274
  )
275
 
276
  gr.Markdown("""
277
+ API Endpoints:
278
+ - POST `/api/initialize` - Register object
279
+ - POST `/api/estimate` - Estimate pose
 
 
280
 
281
+ See documentation for usage examples.
 
 
282
  """)
283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
 
285
+ # Mount Gradio to FastAPI
286
+ app = gr.mount_gradio_app(app, gradio_app, path="/")
287
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  if __name__ == "__main__":
289
+ import uvicorn
290
+ uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
 
app_simple.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple FoundationPose API server using FastAPI + Gradio
3
+
4
+ This version uses FastAPI for clean REST API endpoints alongside Gradio UI.
5
+ """
6
+
7
+ import base64
8
+ import json
9
+ import logging
10
+ import os
11
+ from typing import Dict, List
12
+
13
+ import cv2
14
+ import gradio as gr
15
+ import numpy as np
16
+ import spaces
17
+ import torch
18
+ from fastapi import FastAPI, HTTPException
19
+ from pydantic import BaseModel
20
+
21
+ logging.basicConfig(
22
+ level=logging.INFO,
23
+ format="[%(asctime)s] %(levelname)s: %(message)s"
24
+ )
25
+ logger = logging.getLogger(__name__)
26
+
27
+ # Check if running in real FoundationPose mode or placeholder mode
28
+ USE_REAL_MODEL = os.environ.get("USE_REAL_MODEL", "false").lower() == "true"
29
+
30
+
31
+ class FoundationPoseInference:
32
+ """Wrapper for FoundationPose model inference."""
33
+
34
+ def __init__(self):
35
+ self.model = None
36
+ self.device = None
37
+ self.initialized = False
38
+ self.tracked_objects = {}
39
+ self.use_real_model = USE_REAL_MODEL
40
+
41
+ @spaces.GPU(duration=120) # Allocate GPU for 120 seconds (includes model loading)
42
+ def initialize_model(self):
43
+ """Initialize the FoundationPose model on GPU."""
44
+ if self.initialized:
45
+ logger.info("Model already initialized")
46
+ return
47
+
48
+ logger.info("Initializing FoundationPose model...")
49
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
+ logger.info(f"Using device: {self.device}")
51
+
52
+ if self.use_real_model:
53
+ try:
54
+ logger.info("Loading real FoundationPose model...")
55
+ from estimator import FoundationPoseEstimator
56
+
57
+ self.model = FoundationPoseEstimator(
58
+ device=str(self.device),
59
+ weights_dir="weights"
60
+ )
61
+ logger.info("✓ Real FoundationPose model initialized successfully")
62
+
63
+ except Exception as e:
64
+ logger.error(f"Failed to initialize real model: {e}", exc_info=True)
65
+ logger.warning("Falling back to placeholder mode")
66
+ self.use_real_model = False
67
+ self.model = None
68
+ else:
69
+ logger.info("Using placeholder mode (set USE_REAL_MODEL=true for real inference)")
70
+ self.model = None
71
+
72
+ self.initialized = True
73
+ logger.info("FoundationPose inference ready")
74
+
75
+ def register_object(
76
+ self,
77
+ object_id: str,
78
+ reference_images: List[np.ndarray],
79
+ camera_intrinsics: Dict = None,
80
+ mesh_path: str = None
81
+ ) -> bool:
82
+ """Register an object for tracking with reference images."""
83
+ if not self.initialized:
84
+ self.initialize_model()
85
+
86
+ logger.info(f"Registering object '{object_id}' with {len(reference_images)} reference images")
87
+
88
+ if self.use_real_model and self.model is not None:
89
+ try:
90
+ success = self.model.register_object(
91
+ object_id=object_id,
92
+ reference_images=reference_images,
93
+ camera_intrinsics=camera_intrinsics,
94
+ mesh_path=mesh_path
95
+ )
96
+ if success:
97
+ self.tracked_objects[object_id] = {
98
+ "num_references": len(reference_images),
99
+ "camera_intrinsics": camera_intrinsics,
100
+ "mesh_path": mesh_path
101
+ }
102
+ return success
103
+ except Exception as e:
104
+ logger.error(f"Registration failed: {e}", exc_info=True)
105
+ return False
106
+ else:
107
+ self.tracked_objects[object_id] = {
108
+ "num_references": len(reference_images),
109
+ "camera_intrinsics": camera_intrinsics,
110
+ "mesh_path": mesh_path
111
+ }
112
+ logger.info(f"✓ Object '{object_id}' registered (placeholder mode)")
113
+ return True
114
+
115
+ @spaces.GPU(duration=10)
116
+ def estimate_pose(
117
+ self,
118
+ object_id: str,
119
+ query_image: np.ndarray,
120
+ camera_intrinsics: Dict = None,
121
+ depth_image: np.ndarray = None,
122
+ mask: np.ndarray = None
123
+ ) -> Dict:
124
+ """Estimate 6D pose of an object in a query image."""
125
+ if not self.initialized:
126
+ return {"success": False, "error": "Model not initialized"}
127
+
128
+ if object_id not in self.tracked_objects:
129
+ return {"success": False, "error": f"Object '{object_id}' not registered"}
130
+
131
+ logger.info(f"Estimating pose for object '{object_id}'")
132
+
133
+ if self.use_real_model and self.model is not None:
134
+ try:
135
+ pose_result = self.model.estimate_pose(
136
+ object_id=object_id,
137
+ rgb_image=query_image,
138
+ depth_image=depth_image,
139
+ mask=mask,
140
+ camera_intrinsics=camera_intrinsics
141
+ )
142
+
143
+ if pose_result is None:
144
+ return {
145
+ "success": False,
146
+ "error": "Pose estimation returned None",
147
+ "poses": []
148
+ }
149
+
150
+ return {
151
+ "success": True,
152
+ "poses": [pose_result]
153
+ }
154
+
155
+ except Exception as e:
156
+ logger.error(f"Pose estimation error: {e}", exc_info=True)
157
+ return {"success": False, "error": str(e), "poses": []}
158
+ else:
159
+ logger.info("Placeholder mode: returning empty pose result")
160
+ return {
161
+ "success": True,
162
+ "poses": [],
163
+ "note": "Placeholder mode - set USE_REAL_MODEL=true for real inference"
164
+ }
165
+
166
+
167
+ # Global model instance
168
+ pose_estimator = FoundationPoseInference()
169
+
170
+
171
+ # Pydantic models for API
172
+ class InitializeRequest(BaseModel):
173
+ object_id: str
174
+ reference_images_b64: List[str]
175
+ camera_intrinsics: str = None
176
+ mesh_path: str = None
177
+
178
+
179
+ class EstimateRequest(BaseModel):
180
+ object_id: str
181
+ query_image_b64: str
182
+ camera_intrinsics: str = None
183
+ depth_image_b64: str = None
184
+ mask_b64: str = None
185
+
186
+
187
+ # Create FastAPI app
188
+ app = FastAPI()
189
+
190
+
191
+ @app.post("/api/initialize")
192
+ async def api_initialize(request: InitializeRequest):
193
+ """Initialize object tracking with reference images."""
194
+ try:
195
+ # Decode reference images
196
+ reference_images = []
197
+ for img_b64 in request.reference_images_b64:
198
+ img_bytes = base64.b64decode(img_b64)
199
+ img_array = np.frombuffer(img_bytes, dtype=np.uint8)
200
+ img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
201
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
202
+ reference_images.append(img)
203
+
204
+ # Parse camera intrinsics
205
+ intrinsics = json.loads(request.camera_intrinsics) if request.camera_intrinsics else None
206
+
207
+ # Register object
208
+ success = pose_estimator.register_object(
209
+ object_id=request.object_id,
210
+ reference_images=reference_images,
211
+ camera_intrinsics=intrinsics,
212
+ mesh_path=request.mesh_path
213
+ )
214
+
215
+ return {
216
+ "success": success,
217
+ "message": f"Object '{request.object_id}' registered with {len(reference_images)} reference images"
218
+ }
219
+
220
+ except Exception as e:
221
+ logger.error(f"Initialization error: {e}", exc_info=True)
222
+ raise HTTPException(status_code=500, detail=str(e))
223
+
224
+
225
+ @app.post("/api/estimate")
226
+ async def api_estimate(request: EstimateRequest):
227
+ """Estimate 6D pose from query image."""
228
+ try:
229
+ # Decode query image
230
+ img_bytes = base64.b64decode(request.query_image_b64)
231
+ img_array = np.frombuffer(img_bytes, dtype=np.uint8)
232
+ img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
233
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
234
+
235
+ # Decode optional depth image
236
+ depth = None
237
+ if request.depth_image_b64:
238
+ depth_bytes = base64.b64decode(request.depth_image_b64)
239
+ depth = np.frombuffer(depth_bytes, dtype=np.float32)
240
+
241
+ # Decode optional mask
242
+ mask = None
243
+ if request.mask_b64:
244
+ mask_bytes = base64.b64decode(request.mask_b64)
245
+ mask_array = np.frombuffer(mask_bytes, dtype=np.uint8)
246
+ mask = cv2.imdecode(mask_array, cv2.IMREAD_GRAYSCALE)
247
+
248
+ # Parse camera intrinsics
249
+ intrinsics = json.loads(request.camera_intrinsics) if request.camera_intrinsics else None
250
+
251
+ # Estimate pose
252
+ result = pose_estimator.estimate_pose(
253
+ object_id=request.object_id,
254
+ query_image=img,
255
+ camera_intrinsics=intrinsics,
256
+ depth_image=depth,
257
+ mask=mask
258
+ )
259
+
260
+ return result
261
+
262
+ except Exception as e:
263
+ logger.error(f"Estimation error: {e}", exc_info=True)
264
+ raise HTTPException(status_code=500, detail=str(e))
265
+
266
+
267
+ # Gradio UI (simplified)
268
+ with gr.Blocks(title="FoundationPose Inference", theme=gr.themes.Soft()) as gradio_app:
269
+ gr.Markdown("# 🎯 FoundationPose 6D Object Pose Estimation")
270
+
271
+ mode_indicator = gr.Markdown(
272
+ f"**Mode:** {'🟢 Real FoundationPose' if USE_REAL_MODEL else '🟡 Placeholder'}",
273
+ elem_id="mode"
274
+ )
275
+
276
+ gr.Markdown("""
277
+ API Endpoints:
278
+ - POST `/api/initialize` - Register object
279
+ - POST `/api/estimate` - Estimate pose
280
+
281
+ See documentation for usage examples.
282
+ """)
283
+
284
+
285
+ # Mount Gradio to FastAPI
286
+ app = gr.mount_gradio_app(app, gradio_app, path="/")
287
+
288
+ if __name__ == "__main__":
289
+ import uvicorn
290
+ uvicorn.run(app, host="0.0.0.0", port=7860)
requirements.txt CHANGED
@@ -5,6 +5,11 @@ numpy>=1.24.0
5
  opencv-python>=4.8.0
6
  Pillow>=10.0.0
7
 
 
 
 
 
 
8
  # Hugging Face
9
  huggingface_hub>=0.20.0
10
 
 
5
  opencv-python>=4.8.0
6
  Pillow>=10.0.0
7
 
8
+ # FastAPI for REST API endpoints
9
+ fastapi>=0.109.0
10
+ uvicorn>=0.27.0
11
+ pydantic>=2.0.0
12
+
13
  # Hugging Face
14
  huggingface_hub>=0.20.0
15