choephix commited on
Commit
06eee21
·
1 Parent(s): 3f83257

queue API requests serially

Browse files
Files changed (1) hide show
  1. server.py +77 -27
server.py CHANGED
@@ -1,14 +1,15 @@
1
  from __future__ import annotations
2
 
 
3
  import json
4
  import subprocess
5
  import sys
6
- import threading
7
  import time
 
8
  from pathlib import Path
9
  from typing import Annotated
10
 
11
- from fastapi import FastAPI, File, Form, HTTPException, UploadFile
12
  from fastapi.responses import FileResponse, JSONResponse
13
  from PIL import Image, UnidentifiedImageError
14
 
@@ -27,15 +28,51 @@ from service_runtime import (
27
  ROOT_DIR = Path(__file__).resolve().parent
28
  TMP_DIR = ROOT_DIR / "tmp"
29
  JOB_DIR = TMP_DIR / "jobs"
30
- BUSY_RETRY_AFTER_SECONDS = 15
31
  EXPORT_TIMEOUT_SECONDS = 45
32
 
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  class ServiceState:
35
  def __init__(self) -> None:
36
  self.runtime = TrellisRuntime()
37
  self.job_store = JobStore(JOB_DIR)
38
- self.lock = threading.Lock()
39
 
40
 
41
  state = ServiceState()
@@ -68,19 +105,6 @@ def _open_image(contents: bytes) -> Image.Image:
68
  ) from error
69
 
70
 
71
- def _busy_response() -> JSONResponse:
72
- return JSONResponse(
73
- status_code=429,
74
- headers={"Retry-After": str(BUSY_RETRY_AFTER_SECONDS)},
75
- content={
76
- "stage": "admission",
77
- "error_code": "busy",
78
- "retryable": True,
79
- "message": "Another request is already running. Retry shortly.",
80
- },
81
- )
82
-
83
-
84
  def _run_export(job_dir: Path, request: ImageToGlbRequest) -> Path:
85
  output_path = job_dir / "result.glb"
86
  result_json = job_dir / "export_result.json"
@@ -146,38 +170,42 @@ def startup() -> None:
146
 
147
 
148
  @app.get("/healthz")
149
- def healthz() -> dict[str, object]:
 
150
  return {
151
  "ok": state.runtime.is_healthy,
152
  "runtime_healthy": state.runtime.is_healthy,
153
- "busy": state.lock.locked(),
 
154
  "reason": state.runtime.unhealthy_reason,
155
  }
156
 
157
 
158
  @app.get("/")
159
- def root() -> dict[str, object]:
 
160
  return {
161
  "service": "TRELLIS.2 API",
162
  "version": app.version,
163
  "endpoint": "/v1/image-to-glb",
164
  "healthy": state.runtime.is_healthy,
 
 
165
  }
166
 
167
 
168
  @app.post("/v1/image-to-glb")
169
  async def image_to_glb(
 
170
  image: Annotated[UploadFile, File(...)],
171
- request: str = Form("{}"),
172
  ):
173
- if not state.lock.acquire(blocking=False):
174
- return _busy_response()
175
-
176
  job_id: str | None = None
177
  job_dir: Path | None = None
 
178
  try:
179
  started = time.perf_counter()
180
- request_model = _parse_request_payload(request)
181
  state.job_store.cleanup_expired()
182
  contents = await image.read()
183
  pil_image = _open_image(contents)
@@ -188,7 +216,27 @@ async def image_to_glb(
188
  input_suffix = Path(image.filename or "input.png").suffix or ".png"
189
  input_path = job_dir / f"input{input_suffix}"
190
  save_input_image(pil_image, input_path)
191
- state.job_store.update_status(job_dir, "preprocessing")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  preprocessed = state.runtime.preprocess(pil_image, request_model)
193
 
194
  state.job_store.update_status(job_dir, "generating")
@@ -212,6 +260,7 @@ async def image_to_glb(
212
  "X-Job-Id": job_id,
213
  "X-Request-Id": job_id,
214
  "X-Duration-Ms": str(duration_ms),
 
215
  },
216
  )
217
  except HTTPException:
@@ -260,4 +309,5 @@ async def image_to_glb(
260
  content=service_error.to_dict(job_id),
261
  )
262
  finally:
263
- state.lock.release()
 
 
1
  from __future__ import annotations
2
 
3
+ import asyncio
4
  import json
5
  import subprocess
6
  import sys
 
7
  import time
8
+ from collections import deque
9
  from pathlib import Path
10
  from typing import Annotated
11
 
12
+ from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
13
  from fastapi.responses import FileResponse, JSONResponse
14
  from PIL import Image, UnidentifiedImageError
15
 
 
28
  ROOT_DIR = Path(__file__).resolve().parent
29
  TMP_DIR = ROOT_DIR / "tmp"
30
  JOB_DIR = TMP_DIR / "jobs"
 
31
  EXPORT_TIMEOUT_SECONDS = 45
32
 
33
 
34
+ class RequestQueue:
35
+ def __init__(self) -> None:
36
+ self._condition = asyncio.Condition()
37
+ self._waiting: deque[str] = deque()
38
+ self._active_job_id: str | None = None
39
+
40
+ async def acquire(self, job_id: str) -> int:
41
+ async with self._condition:
42
+ self._waiting.append(job_id)
43
+ queued_ahead = len(self._waiting) - 1
44
+ try:
45
+ while self._waiting[0] != job_id or self._active_job_id is not None:
46
+ await self._condition.wait()
47
+ self._waiting.popleft()
48
+ self._active_job_id = job_id
49
+ return queued_ahead
50
+ except BaseException:
51
+ if job_id in self._waiting:
52
+ self._waiting.remove(job_id)
53
+ self._condition.notify_all()
54
+ raise
55
+
56
+ async def release(self, job_id: str) -> None:
57
+ async with self._condition:
58
+ if self._active_job_id == job_id:
59
+ self._active_job_id = None
60
+ elif job_id in self._waiting:
61
+ self._waiting.remove(job_id)
62
+ self._condition.notify_all()
63
+
64
+ def snapshot(self) -> dict[str, object]:
65
+ return {
66
+ "active_job_id": self._active_job_id,
67
+ "queued_jobs": len(self._waiting),
68
+ }
69
+
70
+
71
  class ServiceState:
72
  def __init__(self) -> None:
73
  self.runtime = TrellisRuntime()
74
  self.job_store = JobStore(JOB_DIR)
75
+ self.queue = RequestQueue()
76
 
77
 
78
  state = ServiceState()
 
105
  ) from error
106
 
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  def _run_export(job_dir: Path, request: ImageToGlbRequest) -> Path:
109
  output_path = job_dir / "result.glb"
110
  result_json = job_dir / "export_result.json"
 
170
 
171
 
172
  @app.get("/healthz")
173
+ async def healthz() -> dict[str, object]:
174
+ queue_state = state.queue.snapshot()
175
  return {
176
  "ok": state.runtime.is_healthy,
177
  "runtime_healthy": state.runtime.is_healthy,
178
+ "busy": queue_state["active_job_id"] is not None,
179
+ "queued_jobs": queue_state["queued_jobs"],
180
  "reason": state.runtime.unhealthy_reason,
181
  }
182
 
183
 
184
  @app.get("/")
185
+ async def root() -> dict[str, object]:
186
+ queue_state = state.queue.snapshot()
187
  return {
188
  "service": "TRELLIS.2 API",
189
  "version": app.version,
190
  "endpoint": "/v1/image-to-glb",
191
  "healthy": state.runtime.is_healthy,
192
+ "busy": queue_state["active_job_id"] is not None,
193
+ "queued_jobs": queue_state["queued_jobs"],
194
  }
195
 
196
 
197
  @app.post("/v1/image-to-glb")
198
  async def image_to_glb(
199
+ http_request: Request,
200
  image: Annotated[UploadFile, File(...)],
201
+ request_payload: str = Form("{}", alias="request"),
202
  ):
 
 
 
203
  job_id: str | None = None
204
  job_dir: Path | None = None
205
+ queue_acquired = False
206
  try:
207
  started = time.perf_counter()
208
+ request_model = _parse_request_payload(request_payload)
209
  state.job_store.cleanup_expired()
210
  contents = await image.read()
211
  pil_image = _open_image(contents)
 
216
  input_suffix = Path(image.filename or "input.png").suffix or ".png"
217
  input_path = job_dir / f"input{input_suffix}"
218
  save_input_image(pil_image, input_path)
219
+
220
+ queued_at = time.perf_counter()
221
+ state.job_store.update_status(job_dir, "queued")
222
+ queued_ahead = await state.queue.acquire(job_id)
223
+ queue_acquired = True
224
+ queue_wait_ms = round((time.perf_counter() - queued_at) * 1000, 2)
225
+ if await http_request.is_disconnected():
226
+ raise ServiceError(
227
+ stage="admission",
228
+ error_code="client_disconnected",
229
+ message="Client disconnected while waiting in queue",
230
+ retryable=False,
231
+ status_code=499,
232
+ )
233
+
234
+ state.job_store.update_status(
235
+ job_dir,
236
+ "preprocessing",
237
+ queue_wait_ms=queue_wait_ms,
238
+ queued_ahead=queued_ahead,
239
+ )
240
  preprocessed = state.runtime.preprocess(pil_image, request_model)
241
 
242
  state.job_store.update_status(job_dir, "generating")
 
260
  "X-Job-Id": job_id,
261
  "X-Request-Id": job_id,
262
  "X-Duration-Ms": str(duration_ms),
263
+ "X-Queue-Wait-Ms": str(queue_wait_ms),
264
  },
265
  )
266
  except HTTPException:
 
309
  content=service_error.to_dict(job_id),
310
  )
311
  finally:
312
+ if queue_acquired and job_id is not None:
313
+ await state.queue.release(job_id)