hrmndev commited on
Commit
738c310
·
verified ·
1 Parent(s): adb6036

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +32 -27
handler.py CHANGED
@@ -1,44 +1,55 @@
1
  import os
 
 
2
  import torch
3
  import base64
4
  import io
5
  from PIL import Image
6
- from trellis.pipelines import TrellisImageTo3DPipeline
7
- from trellis.utils import postprocessing_utils
8
  from typing import Dict, Any
9
 
10
  class EndpointHandler:
11
  def __init__(self, model_dir: str):
12
- """
13
- Initialize the TRELLIS pipeline.
14
- """
15
- # Set algorithm to 'native' for faster startup on Inference Endpoints
16
- os.environ['SPCONV_ALGO'] = 'native'
17
 
18
- # Load the pipeline from the local directory or HF hub
19
- # 'microsoft/TRELLIS-image-large' is the standard model
 
 
 
 
 
20
  self.pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
21
  self.pipeline.cuda()
22
- self.pipeline.preprocess_image = self.pipeline.preprocess_image # Ensure visibility
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
25
- """
26
- Args:
27
- data (:obj:`Dict[str, Any]`):
28
- - "inputs": The image as a base64 string or URL.
29
- - "params": Dictionary of optional parameters (seed, steps, etc.)
30
- """
31
  inputs = data.pop("inputs", data)
32
  params = data.pop("params", {})
33
 
34
- # 1. Decode Image
35
  if isinstance(inputs, str):
36
  image = Image.open(io.BytesIO(base64.b64decode(inputs))).convert("RGB")
37
  else:
38
  image = inputs
39
 
40
- # 2. Run Pipeline
41
- # You can adjust 'sparse_structure_sampler_params' and 'slat_sampler_params' here
42
  outputs = self.pipeline.run(
43
  image,
44
  seed=params.get("seed", 42),
@@ -46,21 +57,15 @@ class EndpointHandler:
46
  slat_sampler_params=params.get("slat_params", {"steps": 12, "cfg_strength": 3.0})
47
  )
48
 
49
- # 3. Post-process to GLB
50
- # We extract the mesh and simplify it for export
51
- glb = postprocessing_utils.to_glb(
52
  outputs['gaussian'][0],
53
  outputs['mesh'][0],
54
  simplify=params.get("simplify", 0.95),
55
  texture_size=params.get("texture_size", 1024)
56
  )
57
 
58
- # 4. Encode to Base64 for response
59
  buffered = io.BytesIO()
60
  glb.export(buffered)
61
  glb_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
62
 
63
- return {
64
- "mesh_base64": glb_str,
65
- "format": "glb"
66
- }
 
1
  import os
2
+ import subprocess
3
+ import sys
4
  import torch
5
  import base64
6
  import io
7
  from PIL import Image
 
 
8
  from typing import Dict, Any
9
 
10
  class EndpointHandler:
11
  def __init__(self, model_dir: str):
12
+ # 1. Run custom installation for complex CUDA kernels
13
+ self._setup()
 
 
 
14
 
15
+ # 2. Now import Trellis (must be after _setup)
16
+ from trellis.pipelines import TrellisImageTo3DPipeline
17
+ from trellis.utils import postprocessing_utils
18
+ self.postprocessing_utils = postprocessing_utils
19
+
20
+ # 3. Initialize pipeline
21
+ os.environ['SPCONV_ALGO'] = 'native'
22
  self.pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
23
  self.pipeline.cuda()
24
+
25
+ def _setup(self):
26
+ """Install dependencies that require --no-build-isolation and specific order."""
27
+ try:
28
+ import trellis
29
+ import nvdiffrast
30
+ print("Dependencies already satisfied.")
31
+ except ImportError:
32
+ print("Installing custom CUDA extensions and TRELLIS...")
33
+ # Install nvdiffrast and trellis directly from GitHub
34
+ # --no-build-isolation is required so they can see the installed PyTorch
35
+ packages = [
36
+ "git+https://github.com/NVlabs/nvdiffrast.git",
37
+ "git+https://github.com/microsoft/TRELLIS.git"
38
+ ]
39
+ for pkg in packages:
40
+ subprocess.check_call([
41
+ sys.executable, "-m", "pip", "install", pkg, "--no-build-isolation"
42
+ ])
43
 
44
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
 
 
 
 
 
 
45
  inputs = data.pop("inputs", data)
46
  params = data.pop("params", {})
47
 
 
48
  if isinstance(inputs, str):
49
  image = Image.open(io.BytesIO(base64.b64decode(inputs))).convert("RGB")
50
  else:
51
  image = inputs
52
 
 
 
53
  outputs = self.pipeline.run(
54
  image,
55
  seed=params.get("seed", 42),
 
57
  slat_sampler_params=params.get("slat_params", {"steps": 12, "cfg_strength": 3.0})
58
  )
59
 
60
+ glb = self.postprocessing_utils.to_glb(
 
 
61
  outputs['gaussian'][0],
62
  outputs['mesh'][0],
63
  simplify=params.get("simplify", 0.95),
64
  texture_size=params.get("texture_size", 1024)
65
  )
66
 
 
67
  buffered = io.BytesIO()
68
  glb.export(buffered)
69
  glb_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
70
 
71
+ return {"mesh_base64": glb_str, "format": "glb"}