import base64
import io
import os
from typing import Any, Dict

# TRELLIS reads SPCONV_ALGO when its sparse-convolution modules are imported
# (its README sets the variable before importing trellis for this reason), so
# it must be set ahead of the trellis imports below; assigning it inside
# __init__, after trellis is already imported, has no effect. 'native' skips
# spconv's start-up benchmarking, which shortens cold starts. setdefault lets
# an operator override it through the endpoint's environment variables.
os.environ.setdefault("SPCONV_ALGO", "native")

import torch  # noqa: E402,F401
from PIL import Image  # noqa: E402
from trellis.pipelines import TrellisImageTo3DPipeline  # noqa: E402
from trellis.utils import postprocessing_utils  # noqa: E402

# Public checkpoint used when the endpoint repository does not ship its own
# TRELLIS pipeline files.
DEFAULT_MODEL_ID = "microsoft/TRELLIS-image-large"


class EndpointHandler:
    """Inference Endpoints handler that converts one image into a GLB mesh."""

    def __init__(self, model_dir: str):
        """Load the TRELLIS image-to-3D pipeline and move it to the GPU.

        Args:
            model_dir: Local repository directory passed in by the endpoint
                runtime. If it contains a TRELLIS ``pipeline.json`` the
                pipeline is loaded from there; otherwise DEFAULT_MODEL_ID is
                fetched from the Hugging Face Hub.
        """
        source = DEFAULT_MODEL_ID
        if model_dir and os.path.isfile(os.path.join(model_dir, "pipeline.json")):
            source = model_dir
        self.pipeline = TrellisImageTo3DPipeline.from_pretrained(source)
        self.pipeline.cuda()

    @staticmethod
    def _decode_image(inputs: Any) -> "Image.Image":
        """Convert the request's ``inputs`` into a PIL image, keeping transparency.

        Accepts a ``PIL.Image.Image``, raw image bytes, or a base64 string
        (optionally wrapped in a ``data:<mime>;base64,`` URL).

        Raises:
            ValueError: if ``inputs`` has an unsupported type, is not valid
                base64, or does not decode to an image.
        """
        if isinstance(inputs, Image.Image):
            image = inputs
        else:
            if isinstance(inputs, str):
                # Strip a data-URL prefix such as "data:image/png;base64,".
                payload = inputs.partition(",")[2] if inputs.startswith("data:") else inputs
                try:
                    raw = base64.b64decode(payload)
                except ValueError as err:  # binascii.Error subclasses ValueError
                    raise ValueError("'inputs' is not valid base64 data") from err
            elif isinstance(inputs, (bytes, bytearray)):
                raw = bytes(inputs)
            else:
                raise ValueError(
                    "'inputs' must be a base64-encoded image string, image bytes "
                    f"or a PIL image, got {type(inputs).__name__}"
                )
            try:
                image = Image.open(io.BytesIO(raw))
                image.load()  # force decoding now so corrupt data fails here
            except OSError as err:  # includes PIL.UnidentifiedImageError
                raise ValueError("'inputs' could not be decoded as an image") from err

        # TRELLIS's preprocess_image uses a non-opaque alpha channel as the
        # foreground mask and only falls back to background removal otherwise,
        # so flattening to RGB would discard a mask the client supplied.
        has_alpha = image.mode in ("RGBA", "LA", "PA") or "transparency" in image.info
        target_mode = "RGBA" if has_alpha else "RGB"
        if image.mode != target_mode:
            image = image.convert(target_mode)
        return image

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a textured GLB mesh from a single image.

        Args:
            data: Request payload. The caller's dict is not modified.
                - "inputs": the image, as a base64 string (optionally a
                  ``data:`` URL), raw image bytes, or a PIL image. If the key
                  is absent the whole payload is treated as the input, as in
                  the standard handler template.
                - "params": optional dict with ``seed`` (default 42),
                  ``sparse_params`` / ``slat_params`` (sampler settings passed
                  through to TRELLIS), ``simplify`` (default 0.95) and
                  ``texture_size`` (default 1024).

        Returns:
            ``{"mesh_base64": <base64-encoded GLB>, "format": "glb"}``.

        Raises:
            ValueError: if ``inputs`` cannot be decoded into an image.
        """
        inputs = data.get("inputs", data)
        # "params": null in the JSON body should behave like an omitted key.
        params = data.get("params") or {}

        # 1. Decode image
        image = self._decode_image(inputs)

        # 2. Run pipeline
        outputs = self.pipeline.run(
            image,
            seed=params.get("seed", 42),
            sparse_structure_sampler_params=params.get(
                "sparse_params", {"steps": 12, "cfg_strength": 7.5}
            ),
            slat_sampler_params=params.get(
                "slat_params", {"steps": 12, "cfg_strength": 3.0}
            ),
            # to_glb only needs the gaussian and mesh decodings; skipping the
            # radiance-field decoder saves GPU time and memory.
            formats=["gaussian", "mesh"],
        )

        # 3. Post-process to GLB: bake the gaussian appearance onto a
        # simplified mesh.
        glb = postprocessing_utils.to_glb(
            outputs["gaussian"][0],
            outputs["mesh"][0],
            simplify=params.get("simplify", 0.95),
            texture_size=params.get("texture_size", 1024),
        )

        # 4. Encode to base64. trimesh cannot infer a format from an in-memory
        # buffer, so request GLB explicitly; with no file object the encoded
        # bytes are returned directly.
        glb_bytes = glb.export(file_type="glb")
        return {
            "mesh_base64": base64.b64encode(glb_bytes).decode("utf-8"),
            "format": "glb",
        }