from __future__ import annotations import asyncio import contextlib import json import multiprocessing import os import subprocess import sys import time from collections import deque from concurrent.futures import ProcessPoolExecutor from pathlib import Path from typing import Annotated, Any from fastapi import FastAPI, File, Form, HTTPException, UploadFile from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse, JSONResponse from PIL import Image, UnidentifiedImageError import runtime_env # noqa: F401 from job_store import JobStore from schemas import ImageToGlbRequest from service_runtime import ( ServiceError, TrellisRuntime, classify_runtime_error, save_export_payload, save_input_image, ) ROOT_DIR = Path(__file__).resolve().parent TMP_DIR = ROOT_DIR / "tmp" JOB_DIR = TMP_DIR / "jobs" EXPORT_TIMEOUT_SECONDS = 45 EXPORT_MAX_ATTEMPTS = 3 EXPORT_RETRY_DELAY_SECONDS = 1.5 DEFAULT_MAX_CONCURRENT_JOBS = 8 DEFAULT_JOB_CLEANUP_INTERVAL_SECONDS = 300 def _read_positive_int_env(name: str, default: int) -> int: raw_value = os.getenv(name) if raw_value is None: return default try: value = int(raw_value) except ValueError: return default return value if value > 0 else default MAX_CONCURRENT_JOBS = _read_positive_int_env( "MAX_CONCURRENT_JOBS", DEFAULT_MAX_CONCURRENT_JOBS ) JOB_CLEANUP_INTERVAL_SECONDS = _read_positive_int_env( "JOB_CLEANUP_INTERVAL_SECONDS", DEFAULT_JOB_CLEANUP_INTERVAL_SECONDS ) _WORKER_RUNTIME: TrellisRuntime | None = None class JobQueue: def __init__(self) -> None: self._condition = asyncio.Condition() self._waiting: deque[str] = deque() self._starting_job_ids: set[str] = set() self._active_job_ids: set[str] = set() def _queue_position_unlocked(self, job_id: str) -> int | None: if job_id in self._starting_job_ids or job_id in self._active_job_ids: return 0 try: index = self._waiting.index(job_id) except ValueError: return None return index + len(self._starting_job_ids) + len(self._active_job_ids) async def enqueue(self, job_id: str) -> int: async with self._condition: self._waiting.append(job_id) self._condition.notify_all() queue_position = self._queue_position_unlocked(job_id) return 0 if queue_position is None else queue_position async def claim_next(self) -> str: async with self._condition: while not self._waiting: await self._condition.wait() return self._waiting.popleft() async def mark_starting(self, job_id: str) -> None: async with self._condition: self._starting_job_ids.add(job_id) async def mark_active(self, job_id: str) -> None: async with self._condition: self._starting_job_ids.discard(job_id) self._active_job_ids.add(job_id) async def complete(self, job_id: str) -> None: async with self._condition: self._starting_job_ids.discard(job_id) self._active_job_ids.discard(job_id) if job_id in self._waiting: self._waiting.remove(job_id) self._condition.notify_all() async def snapshot(self, job_id: str | None = None) -> dict[str, Any]: async with self._condition: queue_position = None if job_id is not None: queue_position = self._queue_position_unlocked(job_id) starting_job_ids = sorted(self._starting_job_ids) running_job_ids = sorted(self._active_job_ids) active_job_ids = sorted(self._starting_job_ids | self._active_job_ids) return { "starting_job_ids": starting_job_ids, "running_job_ids": running_job_ids, "active_job_ids": active_job_ids, "queued_job_ids": list(self._waiting), "starting_jobs": len(self._starting_job_ids), "running_jobs": len(self._active_job_ids), "active_jobs": len(active_job_ids), "queued_jobs": len(self._waiting), "queue_position": queue_position, "is_starting": job_id in self._starting_job_ids if job_id else False, "is_running": job_id in self._active_job_ids if job_id else False, "is_active": ( job_id in self._starting_job_ids or job_id in self._active_job_ids ) if job_id else False, } class ServiceState: def __init__(self) -> None: self.job_store = JobStore(JOB_DIR) self.queue = JobQueue() self.executor: ProcessPoolExecutor | None = None self.worker_tasks: list[asyncio.Task[None]] = [] self.cleanup_task: asyncio.Task[None] | None = None self.worker_pool_ready = False self.worker_pool_error: str | None = None self.warmed_worker_pids: list[int] = [] state = ServiceState() app = FastAPI(title="TRELLIS.2 API", version="1.2.0") app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=False, allow_methods=["*"], allow_headers=["*"], ) def _parse_request_payload(payload: str) -> ImageToGlbRequest: try: raw = json.loads(payload) if payload else {} except json.JSONDecodeError as error: raise HTTPException( status_code=400, detail=f"Invalid request JSON: {error}" ) from error try: return ImageToGlbRequest.model_validate(raw) except Exception as error: raise HTTPException(status_code=422, detail=str(error)) from error def _open_image(contents: bytes) -> Image.Image: from io import BytesIO try: image = Image.open(BytesIO(contents)) image.load() return image except UnidentifiedImageError as error: raise HTTPException( status_code=400, detail="Uploaded file is not a supported image" ) from error def _open_image_path(path: Path) -> Image.Image: with Image.open(path) as image: image.load() return image.copy() def _resolve_input_suffix(filename: str | None) -> str: return Path(filename or "input.png").suffix or ".png" def _prepare_job_sync( request_payload: dict[str, Any], filename: str | None, contents: bytes ) -> tuple[str, Path, str]: pil_image = _open_image(contents) job_id, job_dir = state.job_store.create(request_payload, filename) input_path = job_dir / f"input{_resolve_input_suffix(filename)}" save_input_image(pil_image, input_path) return job_id, job_dir, str(input_path) def _mark_job_queued( job_dir: Path, *, queued_at: float, input_path: str, queue_position: int ) -> None: state.job_store.update_status( job_dir, "queued", enqueued_at=queued_at, enqueued_at_ts=queued_at, input_path=input_path, initial_queue_position=queue_position, ) def _mark_job_starting(job_id: str) -> None: state.job_store.update_status( state.job_store.job_dir(job_id), "starting", dispatched_at=time.time(), ) def _worker_process_init() -> None: global _WORKER_RUNTIME runtime = TrellisRuntime() runtime.load() _WORKER_RUNTIME = runtime def _get_worker_runtime() -> TrellisRuntime: global _WORKER_RUNTIME if _WORKER_RUNTIME is None: _worker_process_init() assert _WORKER_RUNTIME is not None return _WORKER_RUNTIME def _run_export(job_dir: Path, request: ImageToGlbRequest) -> Path: output_path = job_dir / "result.glb" result_json = job_dir / "export_result.json" payload_npz = job_dir / "export_payload.npz" payload_meta = job_dir / "export_payload.json" for path in (output_path, result_json): if path.exists(): path.unlink() command = [ sys.executable, str(ROOT_DIR / "export_worker.py"), "--payload-npz", str(payload_npz), "--payload-meta", str(payload_meta), "--output", str(output_path), "--decimation-target", str(request.export.decimation_target), "--texture-size", str(request.export.texture_size), "--remesh", str(int(request.export.remesh)), "--safe-nonremesh-fallback", str(int(request.export.safe_nonremesh_fallback)), "--result-json", str(result_json), ] completed = subprocess.run( command, cwd=str(ROOT_DIR), capture_output=True, text=True, timeout=EXPORT_TIMEOUT_SECONDS, ) if completed.returncode == 0 and output_path.exists(): return output_path if result_json.exists(): result = json.loads(result_json.read_text(encoding="utf-8")) message = result.get("message", "GLB export failed") details = { "traceback": result.get("traceback"), "stdout": completed.stdout, "stderr": completed.stderr, } else: message = "GLB export worker crashed before returning a result" details = { "stdout": completed.stdout, "stderr": completed.stderr, "returncode": completed.returncode, } raise ServiceError( stage="export", error_code="export_failed", message=message, retryable=True, status_code=500, details=details, ) def _run_export_with_retries( job_store: JobStore, job_dir: Path, request: ImageToGlbRequest ) -> tuple[Path, int]: last_error: ServiceError | None = None for attempt in range(1, EXPORT_MAX_ATTEMPTS + 1): job_store.update_status( job_dir, "exporting", export_attempt=attempt, export_max_attempts=EXPORT_MAX_ATTEMPTS, ) try: output_path = _run_export(job_dir, request) return output_path, attempt except subprocess.TimeoutExpired as error: last_error = ServiceError( stage="export", error_code="timeout", message="GLB export timed out", retryable=True, status_code=504, details={ "timeout_seconds": error.timeout, "attempt": attempt, "max_attempts": EXPORT_MAX_ATTEMPTS, }, ) except ServiceError as error: last_error = error last_error.details = { **last_error.details, "attempt": attempt, "max_attempts": EXPORT_MAX_ATTEMPTS, } job_store.update_status( job_dir, "export_retrying" if attempt < EXPORT_MAX_ATTEMPTS else "exporting", export_attempt=attempt, export_max_attempts=EXPORT_MAX_ATTEMPTS, last_export_error=last_error.message if last_error is not None else None, ) if ( last_error is None or not last_error.retryable or attempt >= EXPORT_MAX_ATTEMPTS ): break time.sleep(EXPORT_RETRY_DELAY_SECONDS * attempt) assert last_error is not None raise last_error def _warm_worker_process(_: int) -> int: _get_worker_runtime() return os.getpid() async def _warm_worker_pool() -> list[int]: if state.executor is None: raise RuntimeError("Worker pool is not initialized") loop = asyncio.get_running_loop() warmed_pids = await asyncio.gather( *[ loop.run_in_executor(state.executor, _warm_worker_process, index) for index in range(MAX_CONCURRENT_JOBS) ] ) return sorted(set(warmed_pids)) async def _cleanup_loop() -> None: while True: await asyncio.sleep(JOB_CLEANUP_INTERVAL_SECONDS) await asyncio.to_thread(state.job_store.cleanup_expired) def _job_urls(job_id: str) -> dict[str, str]: return { "status_url": f"/v1/jobs/{job_id}", "result_url": f"/v1/jobs/{job_id}/result", } def _effective_job_status(meta: dict[str, Any], queue_state: dict[str, Any]) -> str: status = str(meta.get("status", "unknown")) if ( queue_state.get("is_starting") or queue_state.get("is_running") ) and status in {"received", "queued", "unknown"}: return "starting" return status def _build_job_response_sync(job_id: str, queue_state: dict[str, Any]) -> dict[str, Any]: if not state.job_store.exists(job_id): raise HTTPException(status_code=404, detail="Job not found") meta = state.job_store.read_meta(job_id) or {} failure = state.job_store.read_error(job_id) response = { "job_id": job_id, "status": _effective_job_status(meta, queue_state), "created_at": meta.get("created_at"), "updated_at": meta.get("updated_at"), "input_filename": meta.get("input_filename"), "worker_pool_ready": state.worker_pool_ready, "worker_pool_error": state.worker_pool_error, "max_concurrent_jobs": MAX_CONCURRENT_JOBS, "queued_jobs": queue_state["queued_jobs"], "active_jobs": queue_state["active_jobs"], "active_job_ids": queue_state["active_job_ids"], "starting_jobs": queue_state["starting_jobs"], "starting_job_ids": queue_state["starting_job_ids"], "running_jobs": queue_state["running_jobs"], "running_job_ids": queue_state["running_job_ids"], **_job_urls(job_id), } queue_position = queue_state.get("queue_position") if queue_position is not None: response["queue_position"] = queue_position for field in ( "queue_wait_ms", "preprocess_ms", "generate_ms", "export_ms", "export_attempt", "export_max_attempts", "duration_ms", "queued_at", "started_at", "completed_at", "output_path", "last_export_error", ): if field in meta: response[field] = meta[field] if response["status"] == "succeeded" and meta.get("output_path"): response["result_ready"] = Path(meta["output_path"]).exists() if failure is not None: response["failure"] = failure return response async def _build_job_response(job_id: str) -> dict[str, Any]: queue_state = await state.queue.snapshot(job_id) return await asyncio.to_thread(_build_job_response_sync, job_id, queue_state) def _build_job_summary( job_id: str, *, queue_position: int | None = None, is_starting: bool = False, is_running: bool = False, ) -> dict[str, Any]: meta = state.job_store.read_meta(job_id) or {} request = state.job_store.read_request(job_id) or {} queue_state = {"is_starting": is_starting, "is_running": is_running} summary = { "job_id": job_id, "status": _effective_job_status(meta, queue_state), "input_filename": meta.get("input_filename"), "created_at": meta.get("created_at"), "updated_at": meta.get("updated_at"), "started_at": meta.get("started_at"), "completed_at": meta.get("completed_at"), "queue_wait_ms": meta.get("queue_wait_ms"), "preprocess_ms": meta.get("preprocess_ms"), "generate_ms": meta.get("generate_ms"), "export_ms": meta.get("export_ms"), "duration_ms": meta.get("duration_ms"), "export_attempt": meta.get("export_attempt"), "export_max_attempts": meta.get("export_max_attempts"), "is_starting": is_starting, "is_running": is_running, "request": request, "meta": meta, } if queue_position is not None: summary["queue_position"] = queue_position return summary def _build_status_response_sync(queue_state: dict[str, Any]) -> dict[str, Any]: starting_job_ids = set(queue_state["starting_job_ids"]) running_job_ids = set(queue_state["running_job_ids"]) active_jobs_detail = [ _build_job_summary( job_id, is_starting=job_id in starting_job_ids, is_running=job_id in running_job_ids, ) for job_id in queue_state["active_job_ids"] ] queued_jobs_detail = [ _build_job_summary( job_id, queue_position=index + len(queue_state["active_job_ids"]), ) for index, job_id in enumerate(queue_state["queued_job_ids"]) ] return { "service": "TRELLIS.2 API", "version": app.version, "submit_endpoint": "/v1/image-to-glb", "status_endpoint": "/v1/jobs/{job_id}", "result_endpoint": "/v1/jobs/{job_id}/result", "ok": state.worker_pool_ready, "worker_pool_ready": state.worker_pool_ready, "worker_pool_error": state.worker_pool_error, "max_concurrent_jobs": MAX_CONCURRENT_JOBS, "active_job_ids": queue_state["active_job_ids"], "starting_job_ids": queue_state["starting_job_ids"], "running_job_ids": queue_state["running_job_ids"], "queued_job_ids": queue_state["queued_job_ids"], "active_jobs": queue_state["active_jobs"], "starting_jobs": queue_state["starting_jobs"], "running_jobs": queue_state["running_jobs"], "queued_jobs": queue_state["queued_jobs"], "active_jobs_detail": active_jobs_detail, "queued_jobs_detail": queued_jobs_detail, } async def _build_status_response() -> dict[str, Any]: queue_state = await state.queue.snapshot() return await asyncio.to_thread(_build_status_response_sync, queue_state) def _process_job_sync_worker(job_id: str) -> None: job_store = JobStore(JOB_DIR) job_dir = job_store.job_dir(job_id) meta = job_store.read_meta(job_id) or {} request_payload = job_store.read_request(job_id) if request_payload is None: job_store.record_failure( job_dir, stage="admission", error_code="missing_request", message="Job request payload is missing", retryable=False, ) return request_model = ImageToGlbRequest.model_validate(request_payload) input_path_str = meta.get("input_path") if not input_path_str: job_store.record_failure( job_dir, stage="admission", error_code="missing_input", message="Uploaded image is missing", retryable=False, ) return input_path = Path(input_path_str) started = time.perf_counter() queue_wait_ms = round( (time.time() - float(meta.get("enqueued_at_ts", time.time()))) * 1000, 2 ) try: job_store.update_status( job_dir, "starting", started_at=time.time(), queue_wait_ms=queue_wait_ms, ) runtime = _get_worker_runtime() preprocess_started = time.perf_counter() job_store.update_status(job_dir, "preprocessing") image = _open_image_path(input_path) preprocessed = runtime.preprocess(image, request_model) preprocess_ms = round((time.perf_counter() - preprocess_started) * 1000, 2) generate_started = time.perf_counter() job_store.update_status(job_dir, "generating", preprocess_ms=preprocess_ms) payload = runtime.generate_export_payload(preprocessed, request_model) save_export_payload(job_dir, payload) generate_ms = round((time.perf_counter() - generate_started) * 1000, 2) export_started = time.perf_counter() job_store.update_status( job_dir, "exporting", preprocess_ms=preprocess_ms, generate_ms=generate_ms, export_attempt=0, export_max_attempts=EXPORT_MAX_ATTEMPTS, ) output_path, export_attempt = _run_export_with_retries( job_store, job_dir, request_model ) export_ms = round((time.perf_counter() - export_started) * 1000, 2) duration_ms = round((time.perf_counter() - started) * 1000, 2) job_store.update_status( job_dir, "succeeded", preprocess_ms=preprocess_ms, generate_ms=generate_ms, export_ms=export_ms, export_attempt=export_attempt, export_max_attempts=EXPORT_MAX_ATTEMPTS, duration_ms=duration_ms, completed_at=time.time(), output_path=str(output_path), ) except subprocess.TimeoutExpired as error: job_store.record_failure( job_dir, stage="export", error_code="timeout", message="GLB export timed out", retryable=True, details={"timeout_seconds": error.timeout}, ) except Exception as error: service_error = classify_runtime_error("export", error) if isinstance(error, ServiceError): service_error = error job_store.record_failure( job_dir, stage=service_error.stage, error_code=service_error.error_code, message=service_error.message, retryable=service_error.retryable, details=service_error.details, ) async def _worker_loop() -> None: loop = asyncio.get_running_loop() while True: job_id = await state.queue.claim_next() try: if state.executor is None: raise RuntimeError("Worker pool is not initialized") await state.queue.mark_starting(job_id) await asyncio.to_thread(_mark_job_starting, job_id) await state.queue.mark_active(job_id) await loop.run_in_executor(state.executor, _process_job_sync_worker, job_id) finally: await state.queue.complete(job_id) @app.on_event("startup") async def startup() -> None: TMP_DIR.mkdir(parents=True, exist_ok=True) await asyncio.to_thread(state.job_store.cleanup_expired) if state.executor is None: state.executor = ProcessPoolExecutor( max_workers=MAX_CONCURRENT_JOBS, mp_context=multiprocessing.get_context("spawn"), initializer=_worker_process_init, ) state.worker_pool_ready = False state.worker_pool_error = None try: state.warmed_worker_pids = await _warm_worker_pool() state.worker_pool_ready = True except Exception as error: state.worker_pool_error = f"{type(error).__name__}: {error}" if not state.worker_tasks: state.worker_tasks = [ asyncio.create_task(_worker_loop()) for _ in range(MAX_CONCURRENT_JOBS) ] if state.cleanup_task is None: state.cleanup_task = asyncio.create_task(_cleanup_loop()) @app.on_event("shutdown") async def shutdown() -> None: if state.cleanup_task is not None: state.cleanup_task.cancel() with contextlib.suppress(asyncio.CancelledError): await state.cleanup_task state.cleanup_task = None for task in state.worker_tasks: task.cancel() for task in state.worker_tasks: with contextlib.suppress(asyncio.CancelledError): await task state.worker_tasks = [] if state.executor is not None: state.executor.shutdown(wait=False, cancel_futures=True) state.executor = None state.worker_pool_ready = False state.worker_pool_error = None state.warmed_worker_pids = [] @app.get("/healthz") async def healthz() -> dict[str, object]: queue_state = await state.queue.snapshot() return { "ok": state.worker_pool_ready, "version": app.version, "worker_pool_ready": state.worker_pool_ready, "worker_pool_error": state.worker_pool_error, "max_concurrent_jobs": MAX_CONCURRENT_JOBS, "active_jobs": queue_state["active_jobs"], "active_job_ids": queue_state["active_job_ids"], "starting_jobs": queue_state["starting_jobs"], "starting_job_ids": queue_state["starting_job_ids"], "running_jobs": queue_state["running_jobs"], "running_job_ids": queue_state["running_job_ids"], "queued_jobs": queue_state["queued_jobs"], "warmed_worker_pids": state.warmed_worker_pids, } @app.get("/api/status") async def api_status() -> dict[str, Any]: return await _build_status_response() @app.get("/") async def root(): return FileResponse(ROOT_DIR / "frontend" / "index.html") @app.post("/v1/image-to-glb") async def image_to_glb( image: Annotated[UploadFile, File(...)], request_payload: str = Form("{}", alias="request"), ): if not state.worker_pool_ready: raise HTTPException( status_code=503, detail=state.worker_pool_error or "Worker pool is still warming up", ) request_model = _parse_request_payload(request_payload) contents = await image.read() job_id, job_dir, input_path = await asyncio.to_thread( _prepare_job_sync, request_model.model_dump(mode="json"), image.filename, contents, ) queued_at = time.time() queue_position = await state.queue.enqueue(job_id) await asyncio.to_thread( _mark_job_queued, job_dir, queued_at=queued_at, input_path=input_path, queue_position=queue_position, ) response = await _build_job_response(job_id) return JSONResponse(status_code=202, content=response) @app.get("/v1/jobs/{job_id}") async def get_job(job_id: str) -> dict[str, Any]: return await _build_job_response(job_id) @app.get("/v1/jobs/{job_id}/result") async def get_job_result(job_id: str): response = await _build_job_response(job_id) status = response["status"] if status == "succeeded": output_path = response.get("output_path") if not output_path or not Path(output_path).exists(): raise HTTPException(status_code=404, detail="Job result is missing") return FileResponse( path=output_path, media_type="model/gltf-binary", filename=f"{job_id}.glb", headers={"X-Job-Id": job_id}, ) if status == "failed": return JSONResponse(status_code=409, content=response) return JSONResponse(status_code=202, content=response)