|
|
import os |
|
|
import torch |
|
|
import base64 |
|
|
import io |
|
|
from PIL import Image |
|
|
from trellis.pipelines import TrellisImageTo3DPipeline |
|
|
from trellis.utils import postprocessing_utils |
|
|
from typing import Dict, Any |
|
|
|
|
|
class EndpointHandler:
    """Inference Endpoints handler that turns a single image into a textured GLB mesh.

    Wraps a TRELLIS ``TrellisImageTo3DPipeline``. Each request decodes the input
    image, runs the pipeline, bakes the first result into a GLB with
    ``postprocessing_utils.to_glb``, and returns the GLB bytes base64-encoded.
    """

    # Hub repository loaded when ``model_dir`` does not contain a TRELLIS pipeline.
    DEFAULT_MODEL_ID = "microsoft/TRELLIS-image-large"

    def __init__(self, model_dir: str):
        """Initialize the TRELLIS pipeline and move it to the GPU.

        Args:
            model_dir: Local directory of the endpoint repository. If it contains a
                ``pipeline.json`` it is passed to ``from_pretrained`` directly
                (presumably loading the local weights instead of downloading them).
                Otherwise ``DEFAULT_MODEL_ID`` is loaded, matching the previous
                behavior.
        """
        # NOTE(review): TRELLIS reads SPCONV_ALGO when its sparse-conv module is
        # first imported. ``trellis.pipelines`` is imported at the top of this file
        # before this line runs, so this only takes effect if that module is loaded
        # lazily (e.g. during from_pretrained). Confirm, or move this assignment
        # above the trellis imports.
        os.environ['SPCONV_ALGO'] = 'native'

        config_path = os.path.join(model_dir, "pipeline.json") if model_dir else ""
        if config_path and os.path.isfile(config_path):
            source = model_dir
        else:
            source = self.DEFAULT_MODEL_ID

        self.pipeline = TrellisImageTo3DPipeline.from_pretrained(source)
        self.pipeline.cuda()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a GLB mesh from the request image.

        Args:
            data (:obj:`Dict[str, Any]`):
                - "inputs": required. The image, either as a base64 string (an
                  optional ``data:<mime>;base64,`` prefix is accepted) or as an
                  already-decoded ``PIL.Image.Image``. URLs are not fetched.
                - "params": optional dict. Recognized keys:
                  ``seed`` (default 42);
                  ``sparse_params`` (default ``{"steps": 12, "cfg_strength": 7.5}``);
                  ``slat_params`` (default ``{"steps": 12, "cfg_strength": 3.0}``);
                  ``simplify`` (default 0.95);
                  ``texture_size`` (default 1024).

        Returns:
            ``{"mesh_base64": <base64-encoded GLB>, "format": "glb"}``.

        Raises:
            ValueError: if "inputs" is missing or cannot be decoded as an image,
                or if "params" is present but not a dict.
        """
        # Read without popping so the caller's payload is left untouched.
        if "inputs" not in data:
            raise ValueError('Request body must contain an "inputs" image.')
        image = self._load_image(data["inputs"])
        params = self._read_params(data)

        outputs = self.pipeline.run(
            image,
            seed=params.get("seed", 42),
            sparse_structure_sampler_params=params.get("sparse_params", {"steps": 12, "cfg_strength": 7.5}),
            slat_sampler_params=params.get("slat_params", {"steps": 12, "cfg_strength": 3.0}),
        )

        glb = postprocessing_utils.to_glb(
            outputs['gaussian'][0],
            outputs['mesh'][0],
            simplify=params.get("simplify", 0.95),
            texture_size=params.get("texture_size", 1024),
        )

        return {
            "mesh_base64": self._encode_glb(glb),
            "format": "glb",
        }

    @staticmethod
    def _strip_data_uri(payload: str) -> str:
        """Return the bare base64 text, dropping a leading ``data:<mime>;base64,`` prefix."""
        # Non-validating b64decode discards ':', ';' and ',' but keeps the letters
        # of such a prefix, which would silently corrupt the decoded bytes.
        payload = payload.strip()
        if payload.startswith("data:"):
            _, _, payload = payload.partition(",")
        return payload

    @staticmethod
    def _read_params(data: Dict[str, Any]) -> Dict[str, Any]:
        """Return the optional "params" dict, treating a missing or null value as empty."""
        params = data.get("params")
        if params is None:
            return {}
        if not isinstance(params, dict):
            raise ValueError(f'"params" must be an object, got {type(params).__name__}')
        return params

    @staticmethod
    def _load_image(inputs: Any):
        """Decode ``inputs`` into a PIL image in RGB mode, or RGBA if it has transparency."""
        if isinstance(inputs, str):
            try:
                raw = base64.b64decode(EndpointHandler._strip_data_uri(inputs))
                image = Image.open(io.BytesIO(raw))
                # Image.open is lazy; force decoding so bad payloads fail here.
                image.load()
            except (ValueError, OSError) as err:
                raise ValueError('"inputs" is not a valid base64-encoded image') from err
        elif isinstance(inputs, Image.Image):
            image = inputs
        else:
            raise ValueError(
                f'"inputs" must be a base64 string or a PIL image, got {type(inputs).__name__}'
            )

        # Keep a user-supplied alpha channel instead of flattening to RGB.
        # NOTE(review): TRELLIS preprocess_image is expected to use a non-opaque
        # alpha channel as the foreground mask and run background removal only
        # otherwise. Confirm against the installed TRELLIS version.
        has_alpha = image.mode in ("RGBA", "LA", "PA") or "transparency" in image.info
        return image.convert("RGBA" if has_alpha else "RGB")

    @staticmethod
    def _encode_glb(glb: Any) -> str:
        """Serialize a trimesh-like object to binary GLB and return it base64-encoded."""
        buffered = io.BytesIO()
        # trimesh infers the format from a file *name*. An in-memory buffer has
        # none, so the format must be given explicitly or the export fails.
        glb.export(buffered, file_type="glb")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
|
|