File size: 2,387 Bytes
35a4589 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import os
import torch
import base64
import io
from PIL import Image
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.utils import postprocessing_utils
from typing import Dict, Any
class EndpointHandler:
    """Custom Inference Endpoints handler: one input image -> base64-encoded GLB.

    The handler loads the TRELLIS image-to-3D pipeline once at start-up and,
    for every request, runs the pipeline, converts the result to a GLB mesh
    and returns it base64-encoded in a JSON-serialisable dict.
    """

    def __init__(self, model_dir: str):
        """Load the TRELLIS pipeline and move it to the GPU.

        Args:
            model_dir: Path supplied by the serving runtime. It is currently
                unused; the weights are always fetched by the hub id below.
        """
        # NOTE(review): `trellis` is imported at module top, i.e. before this
        # assignment runs. If trellis reads SPCONV_ALGO at import time, this
        # has no effect and must move above the trellis imports -- confirm.
        os.environ['SPCONV_ALGO'] = 'native'
        self.pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
        self.pipeline.cuda()

    @staticmethod
    def _resolve_params(data: Dict[str, Any]) -> Dict[str, Any]:
        """Return the optional parameter dict from a request payload ({} if absent)."""
        # "params" is this handler's original key; "parameters" is accepted as a
        # fallback because it is the key commonly used by Inference Endpoints clients.
        for key in ("params", "parameters"):
            value = data.get(key)
            if value is None:
                continue
            if not isinstance(value, dict):
                raise ValueError(f'"{key}" must be a JSON object, got {type(value).__name__}')
            return value
        return {}

    @staticmethod
    def _decode_base64(payload: str) -> bytes:
        """Decode a base64 string, tolerating a leading ``data:...;base64,`` prefix."""
        if payload.startswith("data:"):
            _, separator, encoded = payload.partition(",")
            if not separator:
                raise ValueError("Malformed data URI: missing ',' separator")
            payload = encoded
        return base64.b64decode(payload)

    @classmethod
    def _load_image(cls, inputs: Any) -> Any:
        """Turn the request's "inputs" value into an image for the pipeline.

        Strings are treated as base64 (optionally a data URI), bytes as a raw
        encoded image file; both are opened with PIL and converted to RGB.
        Any other non-dict value is passed through unchanged (presumably an
        already-decoded PIL image from the serving runtime -- TODO confirm).
        """
        if isinstance(inputs, str):
            return Image.open(io.BytesIO(cls._decode_base64(inputs))).convert("RGB")
        if isinstance(inputs, (bytes, bytearray)):
            return Image.open(io.BytesIO(bytes(inputs))).convert("RGB")
        if isinstance(inputs, dict):
            # Happens when the payload has no "inputs" key at all.
            raise ValueError('Request payload must contain an "inputs" image')
        return inputs

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a GLB mesh from the image in ``data``.

        Args:
            data: Request payload with
                - "inputs": the image as a base64 string (a ``data:`` URI is
                  accepted), raw image bytes, or an already-decoded image.
                  URLs are NOT fetched.
                - "params" (or "parameters"): optional dict with any of
                  ``seed``, ``sparse_params``, ``slat_params``, ``simplify``,
                  ``texture_size``.

        Returns:
            ``{"mesh_base64": <base64 GLB>, "format": "glb"}``.

        Raises:
            ValueError: if "inputs" is missing, the parameters are not a dict,
                or the base64 payload is malformed.
        """
        # Read without pop() so the caller's payload dict is left untouched.
        inputs = data.get("inputs", data)
        params = self._resolve_params(data)

        # 1. Decode image
        image = self._load_image(inputs)

        # 2. Run pipeline; sampler settings can be overridden per request.
        outputs = self.pipeline.run(
            image,
            seed=params.get("seed", 42),
            sparse_structure_sampler_params=params.get("sparse_params", {"steps": 12, "cfg_strength": 7.5}),
            slat_sampler_params=params.get("slat_params", {"steps": 12, "cfg_strength": 3.0}),
        )

        # 3. Post-process the first gaussian/mesh pair into a GLB object.
        glb = postprocessing_utils.to_glb(
            outputs['gaussian'][0],
            outputs['mesh'][0],
            simplify=params.get("simplify", 0.95),
            texture_size=params.get("texture_size", 1024),
        )

        # 4. Serialise and base64-encode. The format must be given explicitly:
        # with no file name there is nothing to infer it from, and exporting
        # without a file object returns the encoded bytes directly.
        glb_bytes = glb.export(file_type="glb")
        glb_str = base64.b64encode(glb_bytes).decode("utf-8")
        return {
            "mesh_base64": glb_str,
            "format": "glb",
        }
|