hrmndev commited on
Commit
a778497
·
verified ·
1 Parent(s): 738c310

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +35 -59
handler.py CHANGED
@@ -1,71 +1,47 @@
1
- import os
2
- import subprocess
3
- import sys
4
  import torch
5
  import base64
6
- import io
7
  from PIL import Image
8
- from typing import Dict, Any
 
 
9
 
10
  class EndpointHandler:
11
- def __init__(self, model_dir: str):
12
- # 1. Run custom installation for complex CUDA kernels
13
- self._setup()
14
-
15
- # 2. Now import Trellis (must be after _setup)
16
- from trellis.pipelines import TrellisImageTo3DPipeline
17
- from trellis.utils import postprocessing_utils
18
- self.postprocessing_utils = postprocessing_utils
19
-
20
- # 3. Initialize pipeline
21
- os.environ['SPCONV_ALGO'] = 'native'
22
  self.pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
23
  self.pipeline.cuda()
24
-
25
- def _setup(self):
26
- """Install dependencies that require --no-build-isolation and specific order."""
27
- try:
28
- import trellis
29
- import nvdiffrast
30
- print("Dependencies already satisfied.")
31
- except ImportError:
32
- print("Installing custom CUDA extensions and TRELLIS...")
33
- # Install nvdiffrast and trellis directly from GitHub
34
- # --no-build-isolation is required so they can see the installed PyTorch
35
- packages = [
36
- "git+https://github.com/NVlabs/nvdiffrast.git",
37
- "git+https://github.com/microsoft/TRELLIS.git"
38
- ]
39
- for pkg in packages:
40
- subprocess.check_call([
41
- sys.executable, "-m", "pip", "install", pkg, "--no-build-isolation"
42
- ])
43
-
44
- def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
45
  inputs = data.pop("inputs", data)
46
- params = data.pop("params", {})
47
-
48
- if isinstance(inputs, str):
49
- image = Image.open(io.BytesIO(base64.b64decode(inputs))).convert("RGB")
50
- else:
51
- image = inputs
52
 
53
- outputs = self.pipeline.run(
 
 
 
 
 
54
  image,
55
- seed=params.get("seed", 42),
56
- sparse_structure_sampler_params=params.get("sparse_params", {"steps": 12, "cfg_strength": 7.5}),
57
- slat_sampler_params=params.get("slat_params", {"steps": 12, "cfg_strength": 3.0})
58
  )
59
 
60
- glb = self.postprocessing_utils.to_glb(
61
- outputs['gaussian'][0],
62
- outputs['mesh'][0],
63
- simplify=params.get("simplify", 0.95),
64
- texture_size=params.get("texture_size", 1024)
65
- )
66
-
67
- buffered = io.BytesIO()
68
- glb.export(buffered)
69
- glb_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
70
-
71
- return {"mesh_base64": glb_str, "format": "glb"}
 
 
 
 
1
  import torch
2
  import base64
3
+ import os
4
  from PIL import Image
5
+ from io import BytesIO
6
+ from trellis.pipelines import TrellisImageTo3DPipeline
7
+ from trellis.utils import postprocessing_utils
8
 
9
  class EndpointHandler:
10
+ def __init__(self, model_dir):
11
+ # Load the pipeline from the local directory
 
 
 
 
 
 
 
 
 
12
  self.pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
13
  self.pipeline.cuda()
14
+
15
+ def __call__(self, data):
16
+ """
17
+ Args:
18
+ data (:obj:`dict`):
19
+ - "inputs": The base64 encoded image or URL.
20
+ - "params": Dictionary of generation parameters (optional).
21
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  inputs = data.pop("inputs", data)
23
+ params = data.pop("parameters", {})
 
 
 
 
 
24
 
25
+ # Decode image
26
+ image = Image.open(BytesIO(base64.b64decode(inputs)))
27
+
28
+ # Run Inference
29
+ # Note: You can adjust 'steps' or 'cfg' via params
30
+ outputs = self.pipeline(
31
  image,
32
+ num_samples=1,
33
+ return_flags=["mesh"],
34
+ **params
35
  )
36
 
37
+ # Process mesh to GLB
38
+ mesh = outputs['mesh'][0]
39
+ glb_io = BytesIO()
40
+ mesh.export(glb_io, file_type='glb')
41
+ glb_io.seek(0)
42
+
43
+ # Encode GLB to base64 for the response
44
+ return {
45
+ "mesh_base64": base64.b64encode(glb_io.getvalue()).decode("utf-8"),
46
+ "format": "glb"
47
+ }