choephix commited on
Commit
7c33993
·
1 Parent(s): 61287af

add async job polling API

Browse files
Files changed (3) hide show
  1. README.md +12 -4
  2. job_store.py +15 -0
  3. server.py +249 -122
README.md CHANGED
@@ -13,13 +13,21 @@ short_description: High-fidelity 3D Generation from images
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
15
 
16
- API endpoint: `POST /v1/image-to-glb`
 
 
 
 
17
 
18
  Example:
19
 
20
  ```bash
21
- curl -X POST "http://localhost:7860/v1/image-to-glb" \
22
  -F 'image=@input.png' \
23
- -F 'request={"generation":{"resolution":"512"},"export":{"remesh":true}}' \
24
- --output output.glb
 
 
 
 
25
  ```
 
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
15
 
16
+ API endpoints:
17
+
18
+ - `POST /v1/image-to-glb` - create an async GLB generation job
19
+ - `GET /v1/jobs/{job_id}` - poll job status
20
+ - `GET /v1/jobs/{job_id}/result` - download the GLB when ready
21
 
22
  Example:
23
 
24
  ```bash
25
+ job_json=$(curl -s -X POST "http://localhost:7860/v1/image-to-glb" \
26
  -F 'image=@input.png' \
27
+ -F 'request={"generation":{"resolution":"512"},"export":{"remesh":true}}')
28
+
29
+ job_id=$(python3 -c 'import json,sys; print(json.load(sys.stdin)["job_id"])' <<<"$job_json")
30
+
31
# Poll the status URL until "status" is "succeeded", then download the result:
curl -s "http://localhost:7860/v1/jobs/$job_id"
curl -L "http://localhost:7860/v1/jobs/$job_id/result" --output output.glb
33
  ```
job_store.py CHANGED
@@ -35,6 +35,21 @@ class JobStore:
35
  self.write_json(job_dir / "request.json", request_payload)
36
  return job_id, job_dir
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def update_status(self, job_dir: Path, status: str, **extra: Any) -> None:
39
  meta_path = job_dir / "meta.json"
40
  meta = self.read_json(meta_path, default={})
 
35
  self.write_json(job_dir / "request.json", request_payload)
36
  return job_id, job_dir
37
 
38
    def job_dir(self, job_id: str) -> Path:
        """Return the directory that holds all artifacts for *job_id*."""
        return self.root_dir / job_id
40
+
41
    def exists(self, job_id: str) -> bool:
        """Return True if a job directory has been created for *job_id*."""
        return self.job_dir(job_id).is_dir()
43
+
44
    def read_meta(self, job_id: str) -> dict[str, Any] | None:
        """Return the job's parsed ``meta.json``, or ``None`` when absent."""
        return self.read_json(self.job_dir(job_id) / "meta.json")
46
+
47
    def read_request(self, job_id: str) -> dict[str, Any] | None:
        """Return the job's parsed ``request.json``, or ``None`` when absent."""
        return self.read_json(self.job_dir(job_id) / "request.json")
49
+
50
    def read_error(self, job_id: str) -> dict[str, Any] | None:
        """Return the job's parsed ``error.json``, or ``None`` if it did not fail."""
        return self.read_json(self.job_dir(job_id) / "error.json")
52
+
53
  def update_status(self, job_dir: Path, status: str, **extra: Any) -> None:
54
  meta_path = job_dir / "meta.json"
55
  meta = self.read_json(meta_path, default={})
server.py CHANGED
@@ -1,15 +1,16 @@
1
  from __future__ import annotations
2
 
3
  import asyncio
 
4
  import json
5
  import subprocess
6
  import sys
7
  import time
8
  from collections import deque
9
  from pathlib import Path
10
- from typing import Annotated
11
 
12
- from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
13
  from fastapi.responses import FileResponse, JSONResponse
14
  from PIL import Image, UnidentifiedImageError
15
 
@@ -31,29 +32,37 @@ JOB_DIR = TMP_DIR / "jobs"
31
  EXPORT_TIMEOUT_SECONDS = 45
32
 
33
 
34
- class RequestQueue:
35
  def __init__(self) -> None:
36
  self._condition = asyncio.Condition()
37
  self._waiting: deque[str] = deque()
38
  self._active_job_id: str | None = None
39
 
40
- async def acquire(self, job_id: str) -> int:
 
 
 
 
 
 
 
 
 
41
  async with self._condition:
42
  self._waiting.append(job_id)
43
- queued_ahead = len(self._waiting) - 1
44
- try:
45
- while self._waiting[0] != job_id or self._active_job_id is not None:
46
- await self._condition.wait()
47
- self._waiting.popleft()
48
- self._active_job_id = job_id
49
- return queued_ahead
50
- except BaseException:
51
- if job_id in self._waiting:
52
- self._waiting.remove(job_id)
53
- self._condition.notify_all()
54
- raise
55
-
56
- async def release(self, job_id: str) -> None:
57
  async with self._condition:
58
  if self._active_job_id == job_id:
59
  self._active_job_id = None
@@ -61,10 +70,14 @@ class RequestQueue:
61
  self._waiting.remove(job_id)
62
  self._condition.notify_all()
63
 
64
- def snapshot(self) -> dict[str, object]:
 
 
 
65
  return {
66
  "active_job_id": self._active_job_id,
67
  "queued_jobs": len(self._waiting),
 
68
  }
69
 
70
 
@@ -72,11 +85,12 @@ class ServiceState:
72
  def __init__(self) -> None:
73
  self.runtime = TrellisRuntime()
74
  self.job_store = JobStore(JOB_DIR)
75
- self.queue = RequestQueue()
 
76
 
77
 
78
  state = ServiceState()
79
- app = FastAPI(title="TRELLIS.2 API", version="1.0.0")
80
 
81
 
82
  def _parse_request_payload(payload: str) -> ImageToGlbRequest:
@@ -105,6 +119,12 @@ def _open_image(contents: bytes) -> Image.Image:
105
  ) from error
106
 
107
 
 
 
 
 
 
 
108
  def _run_export(job_dir: Path, request: ImageToGlbRequest) -> Path:
109
  output_path = job_dir / "result.glb"
110
  result_json = job_dir / "export_result.json"
@@ -162,111 +182,130 @@ def _run_export(job_dir: Path, request: ImageToGlbRequest) -> Path:
162
  )
163
 
164
 
165
- @app.on_event("startup")
166
- def startup() -> None:
167
- TMP_DIR.mkdir(parents=True, exist_ok=True)
168
- state.job_store.cleanup_expired()
169
- state.runtime.load()
170
-
171
-
172
- @app.get("/healthz")
173
- async def healthz() -> dict[str, object]:
174
- queue_state = state.queue.snapshot()
175
  return {
176
- "ok": state.runtime.is_healthy,
177
- "runtime_healthy": state.runtime.is_healthy,
178
- "busy": queue_state["active_job_id"] is not None,
179
- "queued_jobs": queue_state["queued_jobs"],
180
- "reason": state.runtime.unhealthy_reason,
181
  }
182
 
183
 
184
- @app.get("/")
185
- async def root() -> dict[str, object]:
186
- queue_state = state.queue.snapshot()
187
- return {
188
- "service": "TRELLIS.2 API",
189
- "version": app.version,
190
- "endpoint": "/v1/image-to-glb",
191
- "healthy": state.runtime.is_healthy,
192
- "busy": queue_state["active_job_id"] is not None,
 
 
 
 
 
193
  "queued_jobs": queue_state["queued_jobs"],
 
 
194
  }
195
-
196
-
197
- @app.post("/v1/image-to-glb")
198
- async def image_to_glb(
199
- http_request: Request,
200
- image: Annotated[UploadFile, File(...)],
201
- request_payload: str = Form("{}", alias="request"),
202
- ):
203
- job_id: str | None = None
204
- job_dir: Path | None = None
205
- queue_acquired = False
206
- try:
207
- started = time.perf_counter()
208
- request_model = _parse_request_payload(request_payload)
209
- state.job_store.cleanup_expired()
210
- contents = await image.read()
211
- pil_image = _open_image(contents)
212
- job_id, job_dir = state.job_store.create(
213
- request_model.model_dump(mode="json"), image.filename
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  )
 
215
 
216
- input_suffix = Path(image.filename or "input.png").suffix or ".png"
217
- input_path = job_dir / f"input{input_suffix}"
218
- save_input_image(pil_image, input_path)
219
-
220
- queued_at = time.perf_counter()
221
- state.job_store.update_status(job_dir, "queued")
222
- queued_ahead = await state.queue.acquire(job_id)
223
- queue_acquired = True
224
- queue_wait_ms = round((time.perf_counter() - queued_at) * 1000, 2)
225
- if await http_request.is_disconnected():
226
- raise ServiceError(
227
- stage="admission",
228
- error_code="client_disconnected",
229
- message="Client disconnected while waiting in queue",
230
- retryable=False,
231
- status_code=499,
232
- )
233
 
 
 
 
 
 
 
 
234
  state.job_store.update_status(
235
  job_dir,
236
  "preprocessing",
 
237
  queue_wait_ms=queue_wait_ms,
238
- queued_ahead=queued_ahead,
239
  )
240
- preprocessed = state.runtime.preprocess(pil_image, request_model)
 
 
241
 
242
- state.job_store.update_status(job_dir, "generating")
 
 
 
243
  payload = state.runtime.generate_export_payload(preprocessed, request_model)
244
  save_export_payload(job_dir, payload)
 
245
 
246
- state.job_store.update_status(job_dir, "exporting")
 
 
 
 
 
 
247
  output_path = _run_export(job_dir, request_model)
 
248
  duration_ms = round((time.perf_counter() - started) * 1000, 2)
249
  state.job_store.update_status(
250
  job_dir,
251
  "succeeded",
 
 
 
252
  duration_ms=duration_ms,
 
253
  output_path=str(output_path),
254
  )
255
- return FileResponse(
256
- path=output_path,
257
- media_type="model/gltf-binary",
258
- filename=f"{job_id}.glb",
259
- headers={
260
- "X-Job-Id": job_id,
261
- "X-Request-Id": job_id,
262
- "X-Duration-Ms": str(duration_ms),
263
- "X-Queue-Wait-Ms": str(queue_wait_ms),
264
- },
265
- )
266
- except HTTPException:
267
- raise
268
  except subprocess.TimeoutExpired as error:
269
- assert job_id is not None and job_dir is not None
270
  state.job_store.record_failure(
271
  job_dir,
272
  stage="export",
@@ -275,24 +314,7 @@ async def image_to_glb(
275
  retryable=True,
276
  details={"timeout_seconds": error.timeout},
277
  )
278
- return JSONResponse(
279
- status_code=504,
280
- content={
281
- "job_id": job_id,
282
- "stage": "export",
283
- "error_code": "timeout",
284
- "retryable": True,
285
- "message": "GLB export timed out",
286
- "details": {"timeout_seconds": error.timeout},
287
- },
288
- )
289
  except Exception as error:
290
- if job_id is None or job_dir is None:
291
- if isinstance(error, ServiceError):
292
- raise HTTPException(
293
- status_code=error.status_code, detail=error.message
294
- ) from error
295
- raise
296
  service_error = classify_runtime_error("export", error)
297
  if isinstance(error, ServiceError):
298
  service_error = error
@@ -304,10 +326,115 @@ async def image_to_glb(
304
  retryable=service_error.retryable,
305
  details=service_error.details,
306
  )
307
- return JSONResponse(
308
- status_code=service_error.status_code,
309
- content=service_error.to_dict(job_id),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  )
311
- finally:
312
- if queue_acquired and job_id is not None:
313
- await state.queue.release(job_id)
 
1
from __future__ import annotations

import asyncio
import contextlib
import json
import logging
import subprocess
import sys
import time
from collections import deque
from pathlib import Path
from typing import Annotated, Any

from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse, JSONResponse
from PIL import Image, UnidentifiedImageError
16
 
 
32
  EXPORT_TIMEOUT_SECONDS = 45
33
 
34
 
35
+ class JobQueue:
36
    def __init__(self) -> None:
        # Guards both collections below and wakes the worker in claim_next().
        self._condition = asyncio.Condition()
        # Job ids waiting to be claimed, FIFO order.
        self._waiting: deque[str] = deque()
        # Job currently being processed by the worker, if any.
        self._active_job_id: str | None = None
40
 
41
+ def _queue_position_unlocked(self, job_id: str) -> int | None:
42
+ if self._active_job_id == job_id:
43
+ return 0
44
+ try:
45
+ index = self._waiting.index(job_id)
46
+ except ValueError:
47
+ return None
48
+ return index + (1 if self._active_job_id is not None else 0)
49
+
50
    async def enqueue(self, job_id: str) -> int:
        """Append *job_id* to the waiting queue and return its initial position.

        Position 0 means the job runs next (or is already active); positions
        include the currently active job, if any.
        """
        async with self._condition:
            self._waiting.append(job_id)
            # Wake the worker blocked in claim_next().
            self._condition.notify_all()
            queue_position = self._queue_position_unlocked(job_id)
            return 0 if queue_position is None else queue_position
56
+
57
    async def claim_next(self) -> str:
        """Block until a job is waiting, mark it active, and return its id.

        Intended for the single worker loop; the returned job stays the
        active job until ``complete()`` is called for it.
        """
        async with self._condition:
            while not self._waiting:
                await self._condition.wait()
            job_id = self._waiting.popleft()
            self._active_job_id = job_id
            return job_id
64
+
65
+ async def complete(self, job_id: str) -> None:
 
66
  async with self._condition:
67
  if self._active_job_id == job_id:
68
  self._active_job_id = None
 
70
  self._waiting.remove(job_id)
71
  self._condition.notify_all()
72
 
73
    def snapshot(self, job_id: str | None = None) -> dict[str, Any]:
        """Return a point-in-time view of queue state for status endpoints.

        Reads shared state without taking the lock; this is safe only because
        all callers run on the same event loop and this method never awaits.
        ``queue_position`` is ``None`` unless *job_id* is queued or active.
        """
        queue_position = None
        if job_id is not None:
            queue_position = self._queue_position_unlocked(job_id)
        return {
            "active_job_id": self._active_job_id,
            "queued_jobs": len(self._waiting),
            "queue_position": queue_position,
        }
82
 
83
 
 
85
    def __init__(self) -> None:
        # Model runtime doing preprocess/generate work; load() runs on startup.
        self.runtime = TrellisRuntime()
        # Filesystem-backed store of per-job metadata and artifacts.
        self.job_store = JobStore(JOB_DIR)
        # FIFO queue feeding the single background worker.
        self.queue = JobQueue()
        # Handle to the background worker task; created on startup, cancelled on shutdown.
        self.worker_task: asyncio.Task[None] | None = None
90
 
91
 
92
  state = ServiceState()
93
+ app = FastAPI(title="TRELLIS.2 API", version="1.1.0")
94
 
95
 
96
  def _parse_request_payload(payload: str) -> ImageToGlbRequest:
 
119
  ) from error
120
 
121
 
122
def _open_image_path(path: Path) -> Image.Image:
    """Load an image from disk, fully decoded and detached from the file.

    ``load()`` forces decoding and ``copy()`` yields an image that no longer
    references the underlying file handle, so the result is safe to use
    after the ``with`` block closes the file.
    """
    with Image.open(path) as image:
        image.load()
        return image.copy()
126
+
127
+
128
  def _run_export(job_dir: Path, request: ImageToGlbRequest) -> Path:
129
  output_path = job_dir / "result.glb"
130
  result_json = job_dir / "export_result.json"
 
182
  )
183
 
184
 
185
+ def _job_urls(job_id: str) -> dict[str, str]:
 
 
 
 
 
 
 
 
 
186
  return {
187
+ "status_url": f"/v1/jobs/{job_id}",
188
+ "result_url": f"/v1/jobs/{job_id}/result",
 
 
 
189
  }
190
 
191
 
192
def _build_job_response(job_id: str) -> dict[str, Any]:
    """Assemble the public status payload for *job_id*.

    Merges the stored job metadata, any recorded failure, and the live
    queue snapshot into one JSON-serializable dict.

    Raises:
        HTTPException: 404 when no job directory exists for *job_id*.
    """
    if not state.job_store.exists(job_id):
        raise HTTPException(status_code=404, detail="Job not found")

    meta = state.job_store.read_meta(job_id) or {}
    failure = state.job_store.read_error(job_id)
    queue_state = state.queue.snapshot(job_id)
    response = {
        "job_id": job_id,
        "status": meta.get("status", "unknown"),
        "created_at": meta.get("created_at"),
        "updated_at": meta.get("updated_at"),
        "input_filename": meta.get("input_filename"),
        "runtime_healthy": state.runtime.is_healthy,
        "queued_jobs": queue_state["queued_jobs"],
        "active_job_id": queue_state["active_job_id"],
        **_job_urls(job_id),
    }
    # Only present while the job is queued or active.
    queue_position = queue_state.get("queue_position")
    if queue_position is not None:
        response["queue_position"] = queue_position

    # Timing/progress fields are copied through only once recorded in meta.
    for field in (
        "queue_wait_ms",
        "preprocess_ms",
        "generate_ms",
        "export_ms",
        "duration_ms",
        "queued_at",
        "started_at",
        "completed_at",
        "output_path",
    ):
        if field in meta:
            response[field] = meta[field]

    # Tell pollers whether the GLB is actually still on disk (it may have
    # been removed after success, e.g. by expired-job cleanup).
    if response["status"] == "succeeded" and meta.get("output_path"):
        response["result_ready"] = Path(meta["output_path"]).exists()

    if failure is not None:
        response["failure"] = failure

    return response
235
+
236
+
237
+ def _process_job_sync(job_id: str) -> None:
238
+ job_dir = state.job_store.job_dir(job_id)
239
+ meta = state.job_store.read_meta(job_id) or {}
240
+ request_payload = state.job_store.read_request(job_id)
241
+ if request_payload is None:
242
+ state.job_store.record_failure(
243
+ job_dir,
244
+ stage="admission",
245
+ error_code="missing_request",
246
+ message="Job request payload is missing",
247
+ retryable=False,
248
  )
249
+ return
250
 
251
+ request_model = ImageToGlbRequest.model_validate(request_payload)
252
+ input_path_str = meta.get("input_path")
253
+ if not input_path_str:
254
+ state.job_store.record_failure(
255
+ job_dir,
256
+ stage="admission",
257
+ error_code="missing_input",
258
+ message="Uploaded image is missing",
259
+ retryable=False,
260
+ )
261
+ return
 
 
 
 
 
 
262
 
263
+ input_path = Path(input_path_str)
264
+ started = time.perf_counter()
265
+ queue_wait_ms = round(
266
+ (time.time() - float(meta.get("enqueued_at_ts", time.time()))) * 1000, 2
267
+ )
268
+ preprocess_started = time.perf_counter()
269
+ try:
270
  state.job_store.update_status(
271
  job_dir,
272
  "preprocessing",
273
+ started_at=time.time(),
274
  queue_wait_ms=queue_wait_ms,
 
275
  )
276
+ image = _open_image_path(input_path)
277
+ preprocessed = state.runtime.preprocess(image, request_model)
278
+ preprocess_ms = round((time.perf_counter() - preprocess_started) * 1000, 2)
279
 
280
+ generate_started = time.perf_counter()
281
+ state.job_store.update_status(
282
+ job_dir, "generating", preprocess_ms=preprocess_ms
283
+ )
284
  payload = state.runtime.generate_export_payload(preprocessed, request_model)
285
  save_export_payload(job_dir, payload)
286
+ generate_ms = round((time.perf_counter() - generate_started) * 1000, 2)
287
 
288
+ export_started = time.perf_counter()
289
+ state.job_store.update_status(
290
+ job_dir,
291
+ "exporting",
292
+ preprocess_ms=preprocess_ms,
293
+ generate_ms=generate_ms,
294
+ )
295
  output_path = _run_export(job_dir, request_model)
296
+ export_ms = round((time.perf_counter() - export_started) * 1000, 2)
297
  duration_ms = round((time.perf_counter() - started) * 1000, 2)
298
  state.job_store.update_status(
299
  job_dir,
300
  "succeeded",
301
+ preprocess_ms=preprocess_ms,
302
+ generate_ms=generate_ms,
303
+ export_ms=export_ms,
304
  duration_ms=duration_ms,
305
+ completed_at=time.time(),
306
  output_path=str(output_path),
307
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  except subprocess.TimeoutExpired as error:
 
309
  state.job_store.record_failure(
310
  job_dir,
311
  stage="export",
 
314
  retryable=True,
315
  details={"timeout_seconds": error.timeout},
316
  )
 
 
 
 
 
 
 
 
 
 
 
317
  except Exception as error:
 
 
 
 
 
 
318
  service_error = classify_runtime_error("export", error)
319
  if isinstance(error, ServiceError):
320
  service_error = error
 
326
  retryable=service_error.retryable,
327
  details=service_error.details,
328
  )
329
+
330
+
331
async def _worker_loop() -> None:
    """Single-consumer worker: claim queued jobs and process them one at a time.

    Runs for the lifetime of the app (started on startup, cancelled on
    shutdown). Each job executes in a thread so the event loop stays
    responsive while the GPU/export work runs.
    """
    logger = logging.getLogger(__name__)
    while True:
        job_id = await state.queue.claim_next()
        try:
            await asyncio.to_thread(_process_job_sync, job_id)
        except asyncio.CancelledError:
            # Shutdown path: let cancellation propagate after releasing the job.
            raise
        except Exception:
            # _process_job_sync records per-job failures itself, so anything
            # escaping it is unexpected. Swallow (with a log) so one bad job
            # cannot kill the worker task and silently stall the queue forever.
            logger.exception("Unexpected error while processing job %s", job_id)
        finally:
            await state.queue.complete(job_id)
338
+
339
+
340
@app.on_event("startup")
async def startup() -> None:
    """Prepare temp dirs, prune stale jobs, load the model, start the worker."""
    TMP_DIR.mkdir(parents=True, exist_ok=True)
    state.job_store.cleanup_expired()
    state.runtime.load()
    # Start the single background worker exactly once, even if the startup
    # event fires more than once.
    if state.worker_task is None:
        state.worker_task = asyncio.create_task(_worker_loop())
347
+
348
+
349
@app.on_event("shutdown")
async def shutdown() -> None:
    """Cancel the background worker and wait for it to finish."""
    if state.worker_task is None:
        return
    state.worker_task.cancel()
    # Cancellation is the expected way the worker loop exits; suppress it.
    with contextlib.suppress(asyncio.CancelledError):
        await state.worker_task
    state.worker_task = None
357
+
358
+
359
@app.get("/healthz")
async def healthz() -> dict[str, object]:
    """Liveness probe: runtime health plus a snapshot of queue load."""
    queue_state = state.queue.snapshot()
    active = queue_state["active_job_id"]
    healthy = state.runtime.is_healthy
    return {
        "ok": healthy,
        "runtime_healthy": healthy,
        "busy": active is not None,
        "queued_jobs": queue_state["queued_jobs"],
        "active_job_id": active,
        "reason": state.runtime.unhealthy_reason,
    }
370
+
371
+
372
@app.get("/")
async def root() -> dict[str, object]:
    """Service banner: name, version, endpoint map, and current load."""
    queue_state = state.queue.snapshot()
    info: dict[str, object] = {
        "service": "TRELLIS.2 API",
        "version": app.version,
        "submit_endpoint": "/v1/image-to-glb",
        "status_endpoint": "/v1/jobs/{job_id}",
        "result_endpoint": "/v1/jobs/{job_id}/result",
    }
    info["healthy"] = state.runtime.is_healthy
    info["busy"] = queue_state["active_job_id"] is not None
    info["queued_jobs"] = queue_state["queued_jobs"]
    return info
385
+
386
+
387
@app.post("/v1/image-to-glb")
async def image_to_glb(
    image: Annotated[UploadFile, File(...)],
    request_payload: str = Form("{}", alias="request"),
):
    """Validate the upload, persist the job inputs, enqueue it, and return 202.

    The response body is the same status payload served by
    ``GET /v1/jobs/{job_id}`` so clients can start polling immediately.
    """
    request_model = _parse_request_payload(request_payload)
    state.job_store.cleanup_expired()
    contents = await image.read()
    pil_image = _open_image(contents)

    job_id, job_dir = state.job_store.create(
        request_model.model_dump(mode="json"), image.filename
    )
    input_suffix = Path(image.filename or "input.png").suffix or ".png"
    input_path = job_dir / f"input{input_suffix}"
    save_input_image(pil_image, input_path)

    # Persist the "queued" metadata (including input_path and the enqueue
    # timestamp) BEFORE handing the job to the queue: enqueue() may yield to
    # the event loop, letting the worker claim the job immediately. Writing
    # meta afterwards (as before) raced the worker — it could read meta
    # without input_path and fail, or have its status clobbered back to
    # "queued" by this handler's later write.
    queued_at = time.time()
    state.job_store.update_status(
        job_dir,
        "queued",
        enqueued_at=queued_at,
        enqueued_at_ts=queued_at,
        input_path=str(input_path),
    )
    queue_position = await state.queue.enqueue(job_id)

    response = _build_job_response(job_id)
    # Ensure the admission response always reports a position, even if the
    # worker finished with the job before the payload was built.
    response.setdefault("queue_position", queue_position)
    return JSONResponse(status_code=202, content=response)
417
+
418
+
419
@app.get("/v1/jobs/{job_id}")
async def get_job(job_id: str) -> dict[str, Any]:
    """Polling endpoint: return the current status payload (404 if unknown)."""
    return _build_job_response(job_id)
422
+
423
+
424
@app.get("/v1/jobs/{job_id}/result")
async def get_job_result(job_id: str):
    """Download the finished GLB, or report why it is not available yet.

    Responses:
        200: the GLB file (``model/gltf-binary``) once the job succeeded.
        404: unknown job, or the result file is missing from disk.
        409: the job failed; body is the status payload with failure details.
        202: still queued/running; body is the current status payload.
    """
    response = _build_job_response(job_id)
    status = response["status"]
    if status == "succeeded":
        output_path = response.get("output_path")
        # The artifact can disappear after success (e.g. expired-job cleanup).
        if not output_path or not Path(output_path).exists():
            raise HTTPException(status_code=404, detail="Job result is missing")
        return FileResponse(
            path=output_path,
            media_type="model/gltf-binary",
            filename=f"{job_id}.glb",
            headers={"X-Job-Id": job_id},
        )
    if status == "failed":
        return JSONResponse(status_code=409, content=response)
    return JSONResponse(status_code=202, content=response)