File size: 1,974 Bytes
ee54f9e
 
 
93e4a74
 
ee54f9e
 
 
 
 
 
 
 
 
93e4a74
 
 
ee54f9e
 
 
 
 
 
 
 
 
93e4a74
 
ee54f9e
 
 
93e4a74
ee54f9e
 
93e4a74
ee54f9e
 
 
 
 
 
 
 
93e4a74
ee54f9e
93e4a74
 
ee54f9e
 
93e4a74
ee54f9e
 
 
 
 
93e4a74
 
ee54f9e
 
 
 
93e4a74
ee54f9e
 
93e4a74
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# handler.py
import os
import sys
import tempfile

# Ensure the directory containing this file (the repository root) is on
# sys.path before the `trellis` imports below. It is inserted at index 0 so
# it is searched ahead of every other entry.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)

import torch

import trellis.pipelines as trellis_pipelines
from trellis.representations.mesh import SparseFeatures2Mesh


class EndpointHandler:
    """Hugging Face custom inference handler for TRELLIS.

    Loads a TRELLIS pipeline from the local repository once at start-up and,
    per request, turns a text prompt into a mesh written to a temporary
    ``.obj`` file.
    """

    # Mesh-extraction resolution used when the payload does not provide one.
    DEFAULT_RESOLUTION = 64

    def __init__(self, path: str) -> None:
        """Load the pipeline and move it to the best available device.

        Args:
            path: Local path of the cloned repository that Hugging Face
                passes in (e.g. ``/repository``).
        """
        # Use CUDA when torch reports it as available, otherwise the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load from the LOCAL path rather than a remote repo_id so that
        # relative checkpoint paths (ckpts/...) resolve correctly.
        self.pipeline = trellis_pipelines.from_pretrained(path)

        self.pipeline.to(self.device)
        # NOTE(review): confirm the installed trellis Pipeline class exposes
        # eval(); if it does not, this line raises AttributeError at start-up.
        self.pipeline.eval()

    def __call__(self, data: dict) -> dict:
        """Generate a mesh from a text prompt.

        Expected payload::

            {
              "inputs": "<prompt>",
              "resolution": 64
            }

        Args:
            data: Request payload. ``"inputs"`` is required. ``"resolution"``
                is optional; when it is missing or ``None``,
                ``DEFAULT_RESOLUTION`` is used.

        Returns:
            A dict with ``"mesh_path"`` (path of the exported ``.obj`` file on
            this machine; it is NOT deleted automatically, so the caller owns
            cleanup), plus the ``"vertices"`` and ``"faces"`` counts.

        Raises:
            KeyError: If ``"inputs"`` is missing from the payload.
            ValueError: If ``"resolution"`` is not a positive integer.
        """
        prompt = data["inputs"]

        raw_resolution = data.get("resolution")
        if raw_resolution is None:
            raw_resolution = self.DEFAULT_RESOLUTION
        try:
            # Clients may send the number as a string (e.g. "64").
            resolution = int(raw_resolution)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"'resolution' must be an integer, got {raw_resolution!r}"
            ) from err
        if resolution <= 0:
            raise ValueError(f"'resolution' must be positive, got {resolution}")

        # NOTE(review): the pipeline call, SparseFeatures2Mesh's keyword
        # arguments and the `.mesh` attribute on its result are kept as
        # written; verify them against the installed trellis version.
        # Force a stable extraction backend (no Kaolin).
        mesh_extractor = SparseFeatures2Mesh(
            resolution=resolution,
            backend="marching_cubes",
        )

        # Inference only: keep both generation and mesh extraction out of
        # autograd so no computation graph is recorded.
        with torch.no_grad():
            # Text -> implicit 3D field.
            implicit_field = self.pipeline(prompt)
            # Implicit field -> mesh.
            mesh_result = mesh_extractor(implicit_field)
        mesh = mesh_result.mesh

        # Reserve a unique .obj path and close our handle immediately, so no
        # file descriptor leaks per request and export() can open the path
        # itself. Reopening a still-open temp file by name is not portable.
        # delete=False keeps the file on disk for the caller.
        with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as tmp:
            mesh_path = tmp.name
        try:
            mesh.export(mesh_path)
        except Exception:
            # Do not leave an empty or partial file behind on failure.
            os.remove(mesh_path)
            raise

        return {
            "mesh_path": mesh_path,
            "vertices": int(len(mesh.vertices)),
            "faces": int(len(mesh.faces)),
        }