# TRELLIS.2 / server.py
# Author: choephix
# Commit c08e8d5: Update default maximum concurrent jobs to 8 for improved performance
from __future__ import annotations
import asyncio
import contextlib
import json
import multiprocessing
import os
import subprocess
import sys
import time
from collections import deque
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import Annotated, Any
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from PIL import Image, UnidentifiedImageError
import runtime_env # noqa: F401
from job_store import JobStore
from schemas import ImageToGlbRequest
from service_runtime import (
ServiceError,
TrellisRuntime,
classify_runtime_error,
save_export_payload,
save_input_image,
)
# Filesystem layout: all transient artifacts live under ./tmp next to this file.
ROOT_DIR = Path(__file__).resolve().parent
TMP_DIR = ROOT_DIR / "tmp"
JOB_DIR = TMP_DIR / "jobs"
# GLB export subprocess limits: per-attempt timeout, attempt count, and the
# base delay for the linear backoff (delay * attempt seconds between retries).
EXPORT_TIMEOUT_SECONDS = 45
EXPORT_MAX_ATTEMPTS = 3
EXPORT_RETRY_DELAY_SECONDS = 1.5
# Defaults, overridable via same-named environment variables (see below).
DEFAULT_MAX_CONCURRENT_JOBS = 8
DEFAULT_JOB_CLEANUP_INTERVAL_SECONDS = 300
def _read_positive_int_env(name: str, default: int) -> int:
raw_value = os.getenv(name)
if raw_value is None:
return default
try:
value = int(raw_value)
except ValueError:
return default
return value if value > 0 else default
# Effective settings: environment overrides win when they parse as positive ints.
MAX_CONCURRENT_JOBS = _read_positive_int_env(
    "MAX_CONCURRENT_JOBS", DEFAULT_MAX_CONCURRENT_JOBS
)
JOB_CLEANUP_INTERVAL_SECONDS = _read_positive_int_env(
    "JOB_CLEANUP_INTERVAL_SECONDS", DEFAULT_JOB_CLEANUP_INTERVAL_SECONDS
)
# Per-process TRELLIS runtime handle. Populated lazily in each pool worker by
# _worker_process_init / _get_worker_runtime; stays None in the parent process.
_WORKER_RUNTIME: TrellisRuntime | None = None
class JobQueue:
    """Asyncio-safe FIFO of job ids with three-phase bookkeeping.

    Jobs move through: waiting (FIFO deque) -> starting (claimed by a
    worker task, not yet in the process pool) -> active (running in the
    pool). All mutation happens while holding one asyncio.Condition, so
    snapshots are internally consistent.
    """

    def __init__(self) -> None:
        self._condition = asyncio.Condition()
        # FIFO of job ids not yet claimed by a worker.
        self._waiting: deque[str] = deque()
        # Claimed but not yet dispatched to the process pool.
        self._starting_job_ids: set[str] = set()
        # Currently executing in the process pool.
        self._active_job_ids: set[str] = set()

    def _queue_position_unlocked(self, job_id: str) -> int | None:
        """Return the job's 0-based queue position, or None if unknown.

        Caller must hold self._condition. Starting/active jobs report 0;
        a waiting job counts every starting/active job as ahead of it.
        """
        if job_id in self._starting_job_ids or job_id in self._active_job_ids:
            return 0
        try:
            index = self._waiting.index(job_id)
        except ValueError:
            return None
        return index + len(self._starting_job_ids) + len(self._active_job_ids)

    async def enqueue(self, job_id: str) -> int:
        """Append *job_id* to the waiting FIFO and return its queue position."""
        async with self._condition:
            self._waiting.append(job_id)
            # Wake any worker blocked in claim_next().
            self._condition.notify_all()
            queue_position = self._queue_position_unlocked(job_id)
        return 0 if queue_position is None else queue_position

    async def claim_next(self) -> str:
        """Block until a job is waiting, then pop and return its id."""
        async with self._condition:
            while not self._waiting:
                await self._condition.wait()
            return self._waiting.popleft()

    async def mark_starting(self, job_id: str) -> None:
        """Record that a worker claimed *job_id* (pre-dispatch phase)."""
        async with self._condition:
            self._starting_job_ids.add(job_id)

    async def mark_active(self, job_id: str) -> None:
        """Move *job_id* from the starting set to the active set."""
        async with self._condition:
            self._starting_job_ids.discard(job_id)
            self._active_job_ids.add(job_id)

    async def complete(self, job_id: str) -> None:
        """Remove *job_id* from every phase, whichever it is currently in."""
        async with self._condition:
            self._starting_job_ids.discard(job_id)
            self._active_job_ids.discard(job_id)
            if job_id in self._waiting:
                self._waiting.remove(job_id)
            self._condition.notify_all()

    async def snapshot(self, job_id: str | None = None) -> dict[str, Any]:
        """Return a consistent view of queue state, optionally job-relative.

        When *job_id* is given, the snapshot also reports that job's
        queue position and is_starting/is_running/is_active flags.
        """
        async with self._condition:
            queue_position = None
            if job_id is not None:
                queue_position = self._queue_position_unlocked(job_id)
            starting_job_ids = sorted(self._starting_job_ids)
            running_job_ids = sorted(self._active_job_ids)
            active_job_ids = sorted(self._starting_job_ids | self._active_job_ids)
            return {
                "starting_job_ids": starting_job_ids,
                "running_job_ids": running_job_ids,
                "active_job_ids": active_job_ids,
                "queued_job_ids": list(self._waiting),
                "starting_jobs": len(self._starting_job_ids),
                "running_jobs": len(self._active_job_ids),
                "active_jobs": len(active_job_ids),
                "queued_jobs": len(self._waiting),
                "queue_position": queue_position,
                "is_starting": job_id in self._starting_job_ids if job_id else False,
                "is_running": job_id in self._active_job_ids if job_id else False,
                "is_active": (
                    job_id in self._starting_job_ids or job_id in self._active_job_ids
                )
                if job_id
                else False,
            }
class ServiceState:
    """Mutable container for the server's process-wide runtime resources."""

    def __init__(self) -> None:
        # Persistent per-job metadata/artifacts on disk.
        self.job_store = JobStore(JOB_DIR)
        # In-memory scheduling state (waiting/starting/active ids).
        self.queue = JobQueue()
        # Process pool for GPU/CPU-heavy job execution; created at startup.
        self.executor: ProcessPoolExecutor | None = None
        # One consumer task per pool slot, draining the queue.
        self.worker_tasks: list[asyncio.Task[None]] = []
        # Periodic expired-job cleanup task.
        self.cleanup_task: asyncio.Task[None] | None = None
        # Readiness flag flipped after the pool warm-up succeeds.
        self.worker_pool_ready = False
        # Human-readable warm-up failure, surfaced via health endpoints.
        self.worker_pool_error: str | None = None
        # PIDs observed during warm-up (diagnostics only).
        self.warmed_worker_pids: list[int] = []
# Module-level singletons: shared mutable state and the FastAPI application.
state = ServiceState()
app = FastAPI(title="TRELLIS.2 API", version="1.2.0")
# Fully permissive CORS (any origin/method/header); credentials disabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
def _parse_request_payload(payload: str) -> ImageToGlbRequest:
    """Decode the multipart 'request' field into an ImageToGlbRequest.

    An empty payload means "all defaults". Malformed JSON maps to HTTP
    400; JSON that fails schema validation maps to HTTP 422.
    """
    if not payload:
        decoded = {}
    else:
        try:
            decoded = json.loads(payload)
        except json.JSONDecodeError as error:
            raise HTTPException(
                status_code=400, detail=f"Invalid request JSON: {error}"
            ) from error
    try:
        return ImageToGlbRequest.model_validate(decoded)
    except Exception as error:
        raise HTTPException(status_code=422, detail=str(error)) from error
def _open_image(contents: bytes) -> Image.Image:
    """Decode uploaded bytes into a fully-loaded PIL image.

    Raises HTTPException 400 when PIL cannot identify the format.
    """
    from io import BytesIO

    buffer = BytesIO(contents)
    try:
        decoded = Image.open(buffer)
        decoded.load()
    except UnidentifiedImageError as error:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a supported image"
        ) from error
    return decoded
def _open_image_path(path: Path) -> Image.Image:
    """Load the image at *path* and return a detached copy.

    Copying lets the context manager close the underlying file handle
    while the caller keeps a usable in-memory image.
    """
    with Image.open(path) as source:
        source.load()
        detached = source.copy()
    return detached
def _resolve_input_suffix(filename: str | None) -> str:
return Path(filename or "input.png").suffix or ".png"
def _prepare_job_sync(
    request_payload: dict[str, Any], filename: str | None, contents: bytes
) -> tuple[str, Path, str]:
    """Validate the upload, create a job directory, and persist the input.

    Blocking helper (run via asyncio.to_thread). Returns
    (job_id, job_dir, input image path as str).
    """
    pil_image = _open_image(contents)
    job_id, job_dir = state.job_store.create(request_payload, filename)
    destination = job_dir / f"input{_resolve_input_suffix(filename)}"
    save_input_image(pil_image, destination)
    return job_id, job_dir, str(destination)
def _mark_job_queued(
    job_dir: Path, *, queued_at: float, input_path: str, queue_position: int
) -> None:
    """Persist the 'queued' status plus enqueue-time metadata for the job."""
    extra_fields = {
        "enqueued_at": queued_at,
        "enqueued_at_ts": queued_at,
        "input_path": input_path,
        "initial_queue_position": queue_position,
    }
    state.job_store.update_status(job_dir, "queued", **extra_fields)
def _mark_job_starting(job_id: str) -> None:
    """Persist the 'starting' status with the dispatch timestamp."""
    target_dir = state.job_store.job_dir(job_id)
    state.job_store.update_status(target_dir, "starting", dispatched_at=time.time())
def _worker_process_init() -> None:
    """Process-pool initializer: build and load the runtime once per worker."""
    global _WORKER_RUNTIME
    fresh_runtime = TrellisRuntime()
    fresh_runtime.load()
    _WORKER_RUNTIME = fresh_runtime
def _get_worker_runtime() -> TrellisRuntime:
    """Return this process's runtime, initializing it lazily if missing."""
    if _WORKER_RUNTIME is None:
        # Covers the case where the pool initializer did not run (e.g. the
        # function executes in the parent process).
        _worker_process_init()
    assert _WORKER_RUNTIME is not None
    return _WORKER_RUNTIME
def _run_export(job_dir: Path, request: ImageToGlbRequest) -> Path:
    """Run the GLB export worker subprocess for a previously saved payload.

    Reads export_payload.npz/.json from *job_dir* and writes result.glb.
    Returns the output path on success.

    Raises:
        ServiceError: the worker reported a failure, crashed, or left a
            corrupt result file.
        subprocess.TimeoutExpired: the worker exceeded
            EXPORT_TIMEOUT_SECONDS (handled by the retry wrapper).
    """
    output_path = job_dir / "result.glb"
    result_json = job_dir / "export_result.json"
    payload_npz = job_dir / "export_payload.npz"
    payload_meta = job_dir / "export_payload.json"
    # Remove stale artifacts so a previous attempt cannot mask this one's failure.
    for path in (output_path, result_json):
        if path.exists():
            path.unlink()
    command = [
        sys.executable,
        str(ROOT_DIR / "export_worker.py"),
        "--payload-npz",
        str(payload_npz),
        "--payload-meta",
        str(payload_meta),
        "--output",
        str(output_path),
        "--decimation-target",
        str(request.export.decimation_target),
        "--texture-size",
        str(request.export.texture_size),
        "--remesh",
        str(int(request.export.remesh)),
        "--safe-nonremesh-fallback",
        str(int(request.export.safe_nonremesh_fallback)),
        "--result-json",
        str(result_json),
    ]
    completed = subprocess.run(
        command,
        cwd=str(ROOT_DIR),
        capture_output=True,
        text=True,
        timeout=EXPORT_TIMEOUT_SECONDS,
    )
    if completed.returncode == 0 and output_path.exists():
        return output_path
    # Default to the "crashed" diagnosis; refine it if the worker managed to
    # write a structured result file.
    message = "GLB export worker crashed before returning a result"
    details: dict[str, Any] = {
        "stdout": completed.stdout,
        "stderr": completed.stderr,
        "returncode": completed.returncode,
    }
    if result_json.exists():
        # Fix: a worker killed mid-write leaves a truncated/corrupt JSON file;
        # previously json.loads would raise JSONDecodeError here instead of
        # the ServiceError callers expect. Treat unreadable results as crashes.
        try:
            result = json.loads(result_json.read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError):
            result = None
        if isinstance(result, dict):
            message = result.get("message", "GLB export failed")
            details = {
                "traceback": result.get("traceback"),
                "stdout": completed.stdout,
                "stderr": completed.stderr,
            }
    raise ServiceError(
        stage="export",
        error_code="export_failed",
        message=message,
        retryable=True,
        status_code=500,
        details=details,
    )
def _run_export_with_retries(
    job_store: JobStore, job_dir: Path, request: ImageToGlbRequest
) -> tuple[Path, int]:
    """Run the GLB export, retrying retryable failures with linear backoff.

    Returns (output_path, successful_attempt_number). On exhaustion (or a
    non-retryable error) re-raises the last ServiceError; subprocess
    timeouts are converted into retryable ServiceErrors (HTTP 504).
    """
    last_error: ServiceError | None = None
    for attempt in range(1, EXPORT_MAX_ATTEMPTS + 1):
        job_store.update_status(
            job_dir,
            "exporting",
            export_attempt=attempt,
            export_max_attempts=EXPORT_MAX_ATTEMPTS,
        )
        try:
            output_path = _run_export(job_dir, request)
            return output_path, attempt
        except subprocess.TimeoutExpired as error:
            last_error = ServiceError(
                stage="export",
                error_code="timeout",
                message="GLB export timed out",
                retryable=True,
                status_code=504,
                details={
                    "timeout_seconds": error.timeout,
                    "attempt": attempt,
                    "max_attempts": EXPORT_MAX_ATTEMPTS,
                },
            )
        except ServiceError as error:
            last_error = error
            # Annotate the error with retry bookkeeping for clients/logs.
            last_error.details = {
                **last_error.details,
                "attempt": attempt,
                "max_attempts": EXPORT_MAX_ATTEMPTS,
            }
        # Only reached on failure (success returns above). Status shows
        # "export_retrying" unless this was the final attempt.
        job_store.update_status(
            job_dir,
            "export_retrying" if attempt < EXPORT_MAX_ATTEMPTS else "exporting",
            export_attempt=attempt,
            export_max_attempts=EXPORT_MAX_ATTEMPTS,
            last_export_error=last_error.message if last_error is not None else None,
        )
        if (
            last_error is None
            or not last_error.retryable
            or attempt >= EXPORT_MAX_ATTEMPTS
        ):
            break
        # Linear backoff: delay grows with the attempt number.
        time.sleep(EXPORT_RETRY_DELAY_SECONDS * attempt)
    assert last_error is not None
    raise last_error
def _warm_worker_process(_: int) -> int:
    """Pool task forcing runtime initialization; returns the worker's PID.

    The int argument is an ignored fan-out index (one task per pool slot).
    """
    _get_worker_runtime()
    return os.getpid()
async def _warm_worker_pool() -> list[int]:
    """Fan one warm-up task out per pool slot; return the unique PIDs sorted.

    Raises RuntimeError when the executor has not been created yet.
    """
    executor = state.executor
    if executor is None:
        raise RuntimeError("Worker pool is not initialized")
    loop = asyncio.get_running_loop()
    warmup_futures = [
        loop.run_in_executor(executor, _warm_worker_process, index)
        for index in range(MAX_CONCURRENT_JOBS)
    ]
    observed_pids = await asyncio.gather(*warmup_futures)
    return sorted(set(observed_pids))
async def _cleanup_loop() -> None:
    """Background task: purge expired job directories on a fixed interval.

    Runs until cancelled at shutdown; the disk work happens in a thread so
    the event loop is not blocked.
    """
    while True:
        await asyncio.sleep(JOB_CLEANUP_INTERVAL_SECONDS)
        await asyncio.to_thread(state.job_store.cleanup_expired)
def _job_urls(job_id: str) -> dict[str, str]:
return {
"status_url": f"/v1/jobs/{job_id}",
"result_url": f"/v1/jobs/{job_id}/result",
}
def _effective_job_status(meta: dict[str, Any], queue_state: dict[str, Any]) -> str:
status = str(meta.get("status", "unknown"))
if (
queue_state.get("is_starting") or queue_state.get("is_running")
) and status in {"received", "queued", "unknown"}:
return "starting"
return status
def _build_job_response_sync(job_id: str, queue_state: dict[str, Any]) -> dict[str, Any]:
    """Assemble the public status document for one job (blocking disk I/O).

    Merges on-disk metadata and any recorded failure with the queue
    snapshot captured by the caller. Raises HTTPException 404 when the
    job directory does not exist. Intended to run via asyncio.to_thread.
    """
    if not state.job_store.exists(job_id):
        raise HTTPException(status_code=404, detail="Job not found")
    meta = state.job_store.read_meta(job_id) or {}
    failure = state.job_store.read_error(job_id)
    response = {
        "job_id": job_id,
        "status": _effective_job_status(meta, queue_state),
        "created_at": meta.get("created_at"),
        "updated_at": meta.get("updated_at"),
        "input_filename": meta.get("input_filename"),
        "worker_pool_ready": state.worker_pool_ready,
        "worker_pool_error": state.worker_pool_error,
        "max_concurrent_jobs": MAX_CONCURRENT_JOBS,
        "queued_jobs": queue_state["queued_jobs"],
        "active_jobs": queue_state["active_jobs"],
        "active_job_ids": queue_state["active_job_ids"],
        "starting_jobs": queue_state["starting_jobs"],
        "starting_job_ids": queue_state["starting_job_ids"],
        "running_jobs": queue_state["running_jobs"],
        "running_job_ids": queue_state["running_job_ids"],
        **_job_urls(job_id),
    }
    # Only expose a queue position while the queue still tracks the job.
    queue_position = queue_state.get("queue_position")
    if queue_position is not None:
        response["queue_position"] = queue_position
    # Timing/bookkeeping fields are copied through only once they exist.
    for field in (
        "queue_wait_ms",
        "preprocess_ms",
        "generate_ms",
        "export_ms",
        "export_attempt",
        "export_max_attempts",
        "duration_ms",
        "queued_at",
        "started_at",
        "completed_at",
        "output_path",
        "last_export_error",
    ):
        if field in meta:
            response[field] = meta[field]
    # result_ready tells clients whether the GLB is actually on disk.
    if response["status"] == "succeeded" and meta.get("output_path"):
        response["result_ready"] = Path(meta["output_path"]).exists()
    if failure is not None:
        response["failure"] = failure
    return response
async def _build_job_response(job_id: str) -> dict[str, Any]:
    """Snapshot the queue, then build the job document off the event loop."""
    snapshot = await state.queue.snapshot(job_id)
    document = await asyncio.to_thread(_build_job_response_sync, job_id, snapshot)
    return document
def _build_job_summary(
    job_id: str,
    *,
    queue_position: int | None = None,
    is_starting: bool = False,
    is_running: bool = False,
) -> dict[str, Any]:
    """Build the compact per-job entry used in the status listing.

    Reads metadata and the original request from disk; queue flags are
    supplied by the caller from a queue snapshot.
    """
    meta = state.job_store.read_meta(job_id) or {}
    request = state.job_store.read_request(job_id) or {}
    queue_state = {"is_starting": is_starting, "is_running": is_running}
    summary: dict[str, Any] = {
        "job_id": job_id,
        "status": _effective_job_status(meta, queue_state),
    }
    # Meta fields exposed directly on the summary (order matters for clients
    # that serialize the dict as-is).
    for field in (
        "input_filename",
        "created_at",
        "updated_at",
        "started_at",
        "completed_at",
        "queue_wait_ms",
        "preprocess_ms",
        "generate_ms",
        "export_ms",
        "duration_ms",
        "export_attempt",
        "export_max_attempts",
    ):
        summary[field] = meta.get(field)
    summary["is_starting"] = is_starting
    summary["is_running"] = is_running
    summary["request"] = request
    summary["meta"] = meta
    if queue_position is not None:
        summary["queue_position"] = queue_position
    return summary
def _build_status_response_sync(queue_state: dict[str, Any]) -> dict[str, Any]:
    """Assemble the full service status document (blocking disk I/O).

    Expands the queue snapshot into per-job summaries for both active and
    waiting jobs. Intended to run via asyncio.to_thread.
    """
    starting_job_ids = set(queue_state["starting_job_ids"])
    running_job_ids = set(queue_state["running_job_ids"])
    active_jobs_detail = [
        _build_job_summary(
            job_id,
            is_starting=job_id in starting_job_ids,
            is_running=job_id in running_job_ids,
        )
        for job_id in queue_state["active_job_ids"]
    ]
    # Waiting jobs count every active job as ahead of them in the queue.
    queued_jobs_detail = [
        _build_job_summary(
            job_id,
            queue_position=index + len(queue_state["active_job_ids"]),
        )
        for index, job_id in enumerate(queue_state["queued_job_ids"])
    ]
    return {
        "service": "TRELLIS.2 API",
        "version": app.version,
        "submit_endpoint": "/v1/image-to-glb",
        "status_endpoint": "/v1/jobs/{job_id}",
        "result_endpoint": "/v1/jobs/{job_id}/result",
        "ok": state.worker_pool_ready,
        "worker_pool_ready": state.worker_pool_ready,
        "worker_pool_error": state.worker_pool_error,
        "max_concurrent_jobs": MAX_CONCURRENT_JOBS,
        "active_job_ids": queue_state["active_job_ids"],
        "starting_job_ids": queue_state["starting_job_ids"],
        "running_job_ids": queue_state["running_job_ids"],
        "queued_job_ids": queue_state["queued_job_ids"],
        "active_jobs": queue_state["active_jobs"],
        "starting_jobs": queue_state["starting_jobs"],
        "running_jobs": queue_state["running_jobs"],
        "queued_jobs": queue_state["queued_jobs"],
        "active_jobs_detail": active_jobs_detail,
        "queued_jobs_detail": queued_jobs_detail,
    }
async def _build_status_response() -> dict[str, Any]:
    """Snapshot the queue, then build the service status off the event loop."""
    snapshot = await state.queue.snapshot()
    return await asyncio.to_thread(_build_status_response_sync, snapshot)
def _process_job_sync_worker(job_id: str) -> None:
    """Execute one job end-to-end inside a pool worker process.

    Pipeline: load input -> preprocess -> generate export payload ->
    export GLB (with retries). Persists per-stage status and timing via a
    worker-local JobStore (this runs in a separate process, so it cannot
    share the parent's instance). All failures are recorded to the job
    directory rather than raised.
    """
    job_store = JobStore(JOB_DIR)
    job_dir = job_store.job_dir(job_id)
    meta = job_store.read_meta(job_id) or {}
    request_payload = job_store.read_request(job_id)
    if request_payload is None:
        job_store.record_failure(
            job_dir,
            stage="admission",
            error_code="missing_request",
            message="Job request payload is missing",
            retryable=False,
        )
        return
    request_model = ImageToGlbRequest.model_validate(request_payload)
    input_path_str = meta.get("input_path")
    if not input_path_str:
        job_store.record_failure(
            job_dir,
            stage="admission",
            error_code="missing_input",
            message="Uploaded image is missing",
            retryable=False,
        )
        return
    input_path = Path(input_path_str)
    started = time.perf_counter()
    # Queue wait measured from the enqueue wall-clock timestamp; falls back
    # to "now" (i.e. 0 ms) when the field is missing.
    queue_wait_ms = round(
        (time.time() - float(meta.get("enqueued_at_ts", time.time()))) * 1000, 2
    )
    try:
        job_store.update_status(
            job_dir,
            "starting",
            started_at=time.time(),
            queue_wait_ms=queue_wait_ms,
        )
        runtime = _get_worker_runtime()
        # Stage 1: preprocess the input image.
        preprocess_started = time.perf_counter()
        job_store.update_status(job_dir, "preprocessing")
        image = _open_image_path(input_path)
        preprocessed = runtime.preprocess(image, request_model)
        preprocess_ms = round((time.perf_counter() - preprocess_started) * 1000, 2)
        # Stage 2: generate the export payload and persist it for the
        # export subprocess.
        generate_started = time.perf_counter()
        job_store.update_status(job_dir, "generating", preprocess_ms=preprocess_ms)
        payload = runtime.generate_export_payload(preprocessed, request_model)
        save_export_payload(job_dir, payload)
        generate_ms = round((time.perf_counter() - generate_started) * 1000, 2)
        # Stage 3: export the GLB via subprocess, with retries.
        export_started = time.perf_counter()
        job_store.update_status(
            job_dir,
            "exporting",
            preprocess_ms=preprocess_ms,
            generate_ms=generate_ms,
            export_attempt=0,
            export_max_attempts=EXPORT_MAX_ATTEMPTS,
        )
        output_path, export_attempt = _run_export_with_retries(
            job_store, job_dir, request_model
        )
        export_ms = round((time.perf_counter() - export_started) * 1000, 2)
        duration_ms = round((time.perf_counter() - started) * 1000, 2)
        job_store.update_status(
            job_dir,
            "succeeded",
            preprocess_ms=preprocess_ms,
            generate_ms=generate_ms,
            export_ms=export_ms,
            export_attempt=export_attempt,
            export_max_attempts=EXPORT_MAX_ATTEMPTS,
            duration_ms=duration_ms,
            completed_at=time.time(),
            output_path=str(output_path),
        )
    except subprocess.TimeoutExpired as error:
        # Timeout that escaped the retry wrapper (e.g. final attempt).
        job_store.record_failure(
            job_dir,
            stage="export",
            error_code="timeout",
            message="GLB export timed out",
            retryable=True,
            details={"timeout_seconds": error.timeout},
        )
    except Exception as error:
        # Classify unknown errors; a ServiceError carries its own metadata.
        service_error = classify_runtime_error("export", error)
        if isinstance(error, ServiceError):
            service_error = error
        job_store.record_failure(
            job_dir,
            stage=service_error.stage,
            error_code=service_error.error_code,
            message=service_error.message,
            retryable=service_error.retryable,
            details=service_error.details,
        )
async def _worker_loop() -> None:
    """Queue consumer: claim jobs and run them in the process pool, forever.

    One of these tasks exists per pool slot. Previously, any exception
    escaping the dispatch steps or ``run_in_executor`` (for example a
    ``BrokenProcessPool`` after a worker dies) killed this task outright,
    silently shrinking effective concurrency and leaving the job without a
    recorded failure. We now catch such errors, best-effort record the
    failure, and keep consuming.
    """
    loop = asyncio.get_running_loop()
    while True:
        job_id = await state.queue.claim_next()
        try:
            if state.executor is None:
                raise RuntimeError("Worker pool is not initialized")
            await state.queue.mark_starting(job_id)
            await asyncio.to_thread(_mark_job_starting, job_id)
            await state.queue.mark_active(job_id)
            await loop.run_in_executor(state.executor, _process_job_sync_worker, job_id)
        except asyncio.CancelledError:
            # Shutdown path: let cancellation propagate after cleanup.
            raise
        except Exception as error:
            # Keep this consumer alive; record the failure for the client.
            # suppress(Exception): disk recording must never kill the loop.
            with contextlib.suppress(Exception):
                await asyncio.to_thread(
                    state.job_store.record_failure,
                    state.job_store.job_dir(job_id),
                    stage="dispatch",
                    error_code="worker_crashed",
                    message=f"{type(error).__name__}: {error}",
                    retryable=True,
                )
        finally:
            await state.queue.complete(job_id)
# NOTE(review): FastAPI's on_event hooks are deprecated in favor of lifespan
# handlers; kept as-is to preserve the existing startup contract.
@app.on_event("startup")
async def startup() -> None:
    """Create and warm the worker pool, then launch background tasks."""
    TMP_DIR.mkdir(parents=True, exist_ok=True)
    await asyncio.to_thread(state.job_store.cleanup_expired)
    if state.executor is None:
        state.executor = ProcessPoolExecutor(
            max_workers=MAX_CONCURRENT_JOBS,
            # spawn gives each worker a fresh interpreter (presumably required
            # by the GPU runtime, which is often fork-unsafe — confirm).
            mp_context=multiprocessing.get_context("spawn"),
            initializer=_worker_process_init,
        )
    state.worker_pool_ready = False
    state.worker_pool_error = None
    try:
        state.warmed_worker_pids = await _warm_worker_pool()
        state.worker_pool_ready = True
    except Exception as error:
        # Warm-up failed: keep serving so health endpoints can report why.
        state.worker_pool_error = f"{type(error).__name__}: {error}"
    if not state.worker_tasks:
        # One async consumer per pool slot drains the queue concurrently.
        state.worker_tasks = [
            asyncio.create_task(_worker_loop()) for _ in range(MAX_CONCURRENT_JOBS)
        ]
    if state.cleanup_task is None:
        state.cleanup_task = asyncio.create_task(_cleanup_loop())
@app.on_event("shutdown")
async def shutdown() -> None:
    """Tear down background tasks and the worker pool, then reset state."""
    cleanup = state.cleanup_task
    if cleanup is not None:
        cleanup.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await cleanup
        state.cleanup_task = None
    # Cancel every worker first, then await them, so they wind down together.
    pending_workers = state.worker_tasks
    for worker in pending_workers:
        worker.cancel()
    for worker in pending_workers:
        with contextlib.suppress(asyncio.CancelledError):
            await worker
    state.worker_tasks = []
    executor = state.executor
    if executor is not None:
        executor.shutdown(wait=False, cancel_futures=True)
        state.executor = None
    state.worker_pool_ready = False
    state.worker_pool_error = None
    state.warmed_worker_pids = []
@app.get("/healthz")
async def healthz() -> dict[str, object]:
    """Liveness/readiness probe with queue and worker-pool counters."""
    snapshot = await state.queue.snapshot()
    payload: dict[str, object] = {
        "ok": state.worker_pool_ready,
        "version": app.version,
        "worker_pool_ready": state.worker_pool_ready,
        "worker_pool_error": state.worker_pool_error,
        "max_concurrent_jobs": MAX_CONCURRENT_JOBS,
    }
    # Copy queue counters through in a fixed order (clients may rely on it).
    for key in (
        "active_jobs",
        "active_job_ids",
        "starting_jobs",
        "starting_job_ids",
        "running_jobs",
        "running_job_ids",
        "queued_jobs",
    ):
        payload[key] = snapshot[key]
    payload["warmed_worker_pids"] = state.warmed_worker_pids
    return payload
@app.get("/api/status")
async def api_status() -> dict[str, Any]:
    """Full service status, including per-job detail, for the dashboard."""
    status_document = await _build_status_response()
    return status_document
@app.get("/")
async def root():
    """Serve the bundled single-page frontend."""
    index_path = ROOT_DIR / "frontend" / "index.html"
    return FileResponse(index_path)
@app.post("/v1/image-to-glb")
async def image_to_glb(
    image: Annotated[UploadFile, File(...)],
    request_payload: str = Form("{}", alias="request"),
):
    """Accept an image upload, enqueue a GLB conversion job, return 202.

    Rejects with 503 while the worker pool is warming (or failed to warm).
    The optional 'request' form field carries a JSON-encoded
    ImageToGlbRequest: invalid JSON -> 400, schema violations -> 422,
    unreadable images -> 400. The 202 body is the initial job status
    document including status/result URLs.
    """
    if not state.worker_pool_ready:
        raise HTTPException(
            status_code=503,
            detail=state.worker_pool_error or "Worker pool is still warming up",
        )
    request_model = _parse_request_payload(request_payload)
    contents = await image.read()
    # Decode/validate the image and persist it into a fresh job directory
    # (blocking work runs off the event loop).
    job_id, job_dir, input_path = await asyncio.to_thread(
        _prepare_job_sync,
        request_model.model_dump(mode="json"),
        image.filename,
        contents,
    )
    queued_at = time.time()
    queue_position = await state.queue.enqueue(job_id)
    # Persist queue metadata after enqueue so the position reflects reality.
    await asyncio.to_thread(
        _mark_job_queued,
        job_dir,
        queued_at=queued_at,
        input_path=input_path,
        queue_position=queue_position,
    )
    response = await _build_job_response(job_id)
    return JSONResponse(status_code=202, content=response)
@app.get("/v1/jobs/{job_id}")
async def get_job(job_id: str) -> dict[str, Any]:
    """Return the current status document for *job_id* (404 if unknown)."""
    job_document = await _build_job_response(job_id)
    return job_document
@app.get("/v1/jobs/{job_id}/result")
async def get_job_result(job_id: str):
    """Stream the finished GLB, or report job state while pending/failed.

    succeeded -> GLB file (404 if the artifact vanished); failed -> 409
    with the status document; anything else -> 202 with the document.
    """
    response = await _build_job_response(job_id)
    status = response["status"]
    if status == "failed":
        return JSONResponse(status_code=409, content=response)
    if status != "succeeded":
        return JSONResponse(status_code=202, content=response)
    output_path = response.get("output_path")
    if not output_path or not Path(output_path).exists():
        raise HTTPException(status_code=404, detail="Job result is missing")
    return FileResponse(
        path=output_path,
        media_type="model/gltf-binary",
        filename=f"{job_id}.glb",
        headers={"X-Job-Id": job_id},
    )