from __future__ import annotations import traceback import json from pathlib import Path from typing import Any import runtime_env # noqa: F401 import numpy as np import torch from PIL import Image from schemas import ImageToGlbRequest from trellis2.pipelines import Trellis2ImageTo3DPipeline PIPELINE_ID = "microsoft/TRELLIS.2-4B" class ServiceError(Exception): def __init__( self, *, stage: str, error_code: str, message: str, retryable: bool, status_code: int = 500, details: dict[str, Any] | None = None, ): super().__init__(message) self.stage = stage self.error_code = error_code self.message = message self.retryable = retryable self.status_code = status_code self.details = details or {} def to_dict(self, job_id: str) -> dict[str, Any]: return { "job_id": job_id, "stage": self.stage, "error_code": self.error_code, "retryable": self.retryable, "message": self.message, "details": self.details, } def is_fatal_cuda_error(error: BaseException) -> bool: text = str(error).lower() needles = [ "illegal memory access", "device-side assert", "cuda error", "[cumesh] cuda error", ] return any(needle in text for needle in needles) def classify_runtime_error(stage: str, error: BaseException) -> ServiceError: if isinstance(error, ServiceError): return error retryable = stage == "export" or not is_fatal_cuda_error(error) error_code = f"{stage}_failed" status_code = 500 if is_fatal_cuda_error(error): error_code = f"{stage}_cuda_fatal" return ServiceError( stage=stage, error_code=error_code, message=f"{type(error).__name__}: {error}", retryable=retryable, status_code=status_code, details={"traceback": traceback.format_exc()}, ) class TrellisRuntime: def __init__(self) -> None: self.pipeline: Trellis2ImageTo3DPipeline | None = None self.unhealthy_reason: str | None = None @property def is_healthy(self) -> bool: return self.unhealthy_reason is None def load(self) -> None: if self.pipeline is not None: return pipeline = Trellis2ImageTo3DPipeline.from_pretrained(PIPELINE_ID) pipeline.low_vram = False pipeline.cuda() self.pipeline = pipeline def mark_unhealthy(self, reason: str) -> None: self.unhealthy_reason = reason def ensure_ready(self) -> Trellis2ImageTo3DPipeline: if not self.is_healthy: raise ServiceError( stage="generate", error_code="runtime_unhealthy", message=self.unhealthy_reason or "Runtime unavailable", retryable=False, status_code=503, ) self.load() assert self.pipeline is not None return self.pipeline def preprocess(self, image: Image.Image, request: ImageToGlbRequest) -> Image.Image: pipeline = self.ensure_ready() if request.preprocess.background_mode == "none": if image.mode == "RGBA": image_np = np.array(image).astype(np.float32) / 255.0 rgb = image_np[:, :, :3] * image_np[:, :, 3:4] return Image.fromarray((rgb * 255).astype(np.uint8), mode="RGB") return image.convert("RGB") try: return pipeline.preprocess_image(image) except Exception as error: raise classify_runtime_error("preprocess", error) from error def generate_export_payload( self, image: Image.Image, request: ImageToGlbRequest ) -> dict[str, Any]: pipeline = self.ensure_ready() generation = request.generation pipeline_type = { "512": "512", "1024": "1024_cascade", "1536": "1536_cascade", }[generation.resolution] try: outputs, latents = pipeline.run( image, seed=generation.seed, preprocess_image=False, sparse_structure_sampler_params={ "steps": generation.ss_sampling_steps, "guidance_strength": generation.ss_guidance_strength, "guidance_rescale": generation.ss_guidance_rescale, "rescale_t": generation.ss_rescale_t, }, shape_slat_sampler_params={ "steps": generation.shape_slat_sampling_steps, "guidance_strength": generation.shape_slat_guidance_strength, "guidance_rescale": generation.shape_slat_guidance_rescale, "rescale_t": generation.shape_slat_rescale_t, }, tex_slat_sampler_params={ "steps": generation.tex_slat_sampling_steps, "guidance_strength": generation.tex_slat_guidance_strength, "guidance_rescale": generation.tex_slat_guidance_rescale, "rescale_t": generation.tex_slat_rescale_t, }, pipeline_type=pipeline_type, return_latent=True, ) torch.cuda.synchronize() mesh = outputs[0] _, _, resolution = latents payload = self._mesh_to_payload(mesh, resolution) del outputs del latents del mesh torch.cuda.empty_cache() return payload except Exception as error: if is_fatal_cuda_error(error): self.mark_unhealthy(f"Fatal CUDA error during generation: {error}") raise classify_runtime_error("generate", error) from error @staticmethod def _mesh_to_payload(mesh: Any, resolution: int) -> dict[str, Any]: return { "vertices": mesh.vertices.detach().cpu().numpy().astype(np.float32), "faces": mesh.faces.detach().cpu().numpy().astype(np.int32), "attrs": mesh.attrs.detach().cpu().numpy().astype(np.float32), "coords": mesh.coords.detach().cpu().numpy().astype(np.int32), "resolution": int(resolution), "attr_layout": { key: {"start": value.start, "stop": value.stop} for key, value in mesh.layout.items() }, } def save_input_image(image: Image.Image, path: Path) -> None: image.save(path) def save_export_payload(job_dir: Path, payload: dict[str, Any]) -> tuple[Path, Path]: npz_path = job_dir / "export_payload.npz" meta_path = job_dir / "export_payload.json" np.savez_compressed( npz_path, vertices=payload["vertices"], faces=payload["faces"], attrs=payload["attrs"], coords=payload["coords"], ) meta_path.write_text( json.dumps( { "attr_layout": payload["attr_layout"], "resolution": payload["resolution"], "aabb": [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], }, indent=2, sort_keys=True, ), encoding="utf-8", ) return npz_path, meta_path