Files changed (1) hide show
  1. handler.py +66 -0
handler.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import base64
4
+ import io
5
+ from PIL import Image
6
+ from trellis.pipelines import TrellisImageTo3DPipeline
7
+ from trellis.utils import postprocessing_utils
8
+ from typing import Dict, Any
9
+
10
+ class EndpointHandler:
11
+ def __init__(self, model_dir: str):
12
+ """
13
+ Initialize the TRELLIS pipeline.
14
+ """
15
+ # Set algorithm to 'native' for faster startup on Inference Endpoints
16
+ os.environ['SPCONV_ALGO'] = 'native'
17
+
18
+ # Load the pipeline from the local directory or HF hub
19
+ # 'microsoft/TRELLIS-image-large' is the standard model
20
+ self.pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
21
+ self.pipeline.cuda()
22
+ self.pipeline.preprocess_image = self.pipeline.preprocess_image # Ensure visibility
23
+
24
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
25
+ """
26
+ Args:
27
+ data (:obj:`Dict[str, Any]`):
28
+ - "inputs": The image as a base64 string or URL.
29
+ - "params": Dictionary of optional parameters (seed, steps, etc.)
30
+ """
31
+ inputs = data.pop("inputs", data)
32
+ params = data.pop("params", {})
33
+
34
+ # 1. Decode Image
35
+ if isinstance(inputs, str):
36
+ image = Image.open(io.BytesIO(base64.b64decode(inputs))).convert("RGB")
37
+ else:
38
+ image = inputs
39
+
40
+ # 2. Run Pipeline
41
+ # You can adjust 'sparse_structure_sampler_params' and 'slat_sampler_params' here
42
+ outputs = self.pipeline.run(
43
+ image,
44
+ seed=params.get("seed", 42),
45
+ sparse_structure_sampler_params=params.get("sparse_params", {"steps": 12, "cfg_strength": 7.5}),
46
+ slat_sampler_params=params.get("slat_params", {"steps": 12, "cfg_strength": 3.0})
47
+ )
48
+
49
+ # 3. Post-process to GLB
50
+ # We extract the mesh and simplify it for export
51
+ glb = postprocessing_utils.to_glb(
52
+ outputs['gaussian'][0],
53
+ outputs['mesh'][0],
54
+ simplify=params.get("simplify", 0.95),
55
+ texture_size=params.get("texture_size", 1024)
56
+ )
57
+
58
+ # 4. Encode to Base64 for response
59
+ buffered = io.BytesIO()
60
+ glb.export(buffered)
61
+ glb_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
62
+
63
+ return {
64
+ "mesh_base64": glb_str,
65
+ "format": "glb"
66
+ }