trellis3dtest / handler.py
hrmndev's picture
Update handler.py
a778497 verified
import torch
import base64
import os
from PIL import Image
from io import BytesIO
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.utils import postprocessing_utils
class EndpointHandler:
    """Hugging Face Inference Endpoints handler: single image -> textured GLB."""

    # Request parameters consumed by ``postprocessing_utils.to_glb``; all
    # other parameters are forwarded to ``pipeline.run``.
    _GLB_OPTION_KEYS = ("simplify", "texture_size")

    def __init__(self, model_dir=None, model_id="microsoft/TRELLIS-image-large"):
        """Load the TRELLIS image-to-3D pipeline and move it to the GPU.

        Args:
            model_dir: Local directory handed in by the endpoint runtime. If it
                contains a ``pipeline.json`` it is loaded as a local checkpoint;
                otherwise it is ignored.
            model_id: Hugging Face Hub repo id used when ``model_dir`` does not
                hold a pipeline checkpoint.
        """
        has_local_pipeline = bool(model_dir) and os.path.isfile(
            os.path.join(model_dir, "pipeline.json")
        )
        source = model_dir if has_local_pipeline else model_id
        self.pipeline = TrellisImageTo3DPipeline.from_pretrained(source)
        self.pipeline.cuda()

    @staticmethod
    def _decode_image_bytes(inputs):
        """Decode a base64 image payload (str or bytes) into raw image bytes.

        A leading ``data:<mime>;base64,`` prefix on string input is stripped.

        Raises:
            ValueError: if *inputs* is not str/bytes, or is not valid base64
                (``binascii.Error`` is a ``ValueError`` subclass).
        """
        if isinstance(inputs, str):
            # Plain base64 can never contain ':', so this only matches data URLs.
            if inputs.startswith("data:"):
                inputs = inputs.partition(",")[2]
        elif not isinstance(inputs, (bytes, bytearray)):
            raise ValueError(
                "'inputs' must be a base64-encoded image (str or bytes), "
                f"got {type(inputs).__name__}"
            )
        return base64.b64decode(inputs)

    @classmethod
    def _load_image(cls, inputs):
        """Return *inputs* as a PIL image, decoding base64 payloads if needed."""
        if isinstance(inputs, Image.Image):
            # The HF toolkit already decodes binary image requests to PIL.
            return inputs
        image = Image.open(BytesIO(cls._decode_image_bytes(inputs)))
        # Image.open is lazy; force the decode so a corrupt payload fails here.
        image.load()
        return image

    @classmethod
    def _split_params(cls, params):
        """Split request parameters into ``(run_kwargs, glb_kwargs)``.

        The caller's mapping is not modified. ``num_samples`` and ``formats``
        are always overridden: the handler exports exactly one asset, and
        ``to_glb`` needs both the Gaussian and the mesh representation.
        """
        run_kwargs = dict(params or {})
        glb_kwargs = {
            key: run_kwargs.pop(key)
            for key in cls._GLB_OPTION_KEYS
            if key in run_kwargs
        }
        run_kwargs["num_samples"] = 1
        run_kwargs["formats"] = ["gaussian", "mesh"]
        return run_kwargs, glb_kwargs

    def __call__(self, data):
        """Generate a textured GLB from a single image.

        Args:
            data (:obj:`dict`):
                - "inputs": base64-encoded image (str or bytes; a
                  ``data:<mime>;base64,`` prefix is accepted) or an already
                  decoded ``PIL.Image.Image``. Remote URLs are not supported.
                - "parameters": optional dict. ``simplify`` and
                  ``texture_size`` are passed to ``postprocessing_utils.to_glb``;
                  everything else (e.g. ``seed``,
                  ``sparse_structure_sampler_params``, ``slat_sampler_params``)
                  is passed to ``pipeline.run``.

        Returns:
            dict: ``{"mesh_base64": <base64-encoded GLB>, "format": "glb"}``.

        Raises:
            ValueError: if "inputs" is missing or not a supported image payload.
        """
        # Read with .get so the caller's request dict is left untouched.
        inputs = data.get("inputs", data)
        run_kwargs, glb_kwargs = self._split_params(data.get("parameters"))
        image = self._load_image(inputs)

        # ``run`` is TRELLIS's documented inference entry point.
        outputs = self.pipeline.run(image, **run_kwargs)

        # ``to_glb`` combines the Gaussian appearance with the extracted mesh
        # and returns a trimesh object that supports ``.export``.
        glb = postprocessing_utils.to_glb(
            outputs["gaussian"][0],
            outputs["mesh"][0],
            **glb_kwargs,
        )
        glb_io = BytesIO()
        glb.export(glb_io, file_type="glb")

        # Encode GLB to base64 for the JSON response.
        return {
            "mesh_base64": base64.b64encode(glb_io.getvalue()).decode("utf-8"),
            "format": "glb",
        }