# handler.py
import os
import sys
import tempfile
# Guarantee the repository root is importable regardless of the CWD
# the endpoint process starts in.
ROOT_DIR = os.path.abspath(os.path.dirname(__file__))
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)
import torch
import trellis.pipelines as trellis_pipelines
from trellis.representations.mesh import SparseFeatures2Mesh
class EndpointHandler:
    """
    Hugging Face custom inference handler for TRELLIS.

    Loads the text-to-3D pipeline once at construction time, then turns
    each request's prompt into an OBJ mesh on disk.
    """

    def __init__(self, path: str):
        """
        Initialize the TRELLIS pipeline.

        Args:
            path: Local path of the cloned repo (e.g. /repository) —
                HF Inference Endpoints pass this automatically.
        """
        # Device: HF will use the GPU automatically when one is available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # KEY: use the LOCAL path, not a remote repo_id, so relative
        # checkpoint paths like ckpts/... resolve correctly.
        self.pipeline = trellis_pipelines.from_pretrained(path)
        self.pipeline.to(self.device)
        self.pipeline.eval()

    def __call__(self, data: dict) -> dict:
        """
        Run text-to-mesh inference.

        Expected payload:
            {
                "inputs": "<prompt>",
                "resolution": 64        # optional, defaults to 64
            }

        Returns:
            dict with keys:
                "mesh_path": path to a temporary .obj file (caller must
                    delete it when done — it is intentionally persistent),
                "vertices":  vertex count,
                "faces":     face count.

        Raises:
            KeyError: if "inputs" is missing from the payload.
        """
        prompt = data["inputs"]
        resolution = data.get("resolution", 64)

        with torch.no_grad():
            # Text -> implicit 3D field
            # NOTE(review): assumes the pipeline object is directly
            # callable with a prompt string — confirm against the
            # trellis.pipelines API.
            implicit_field = self.pipeline(prompt)

            # Implicit field -> mesh.
            # Force the stable backend (no Kaolin dependency).
            mesh_extractor = SparseFeatures2Mesh(
                resolution=resolution,
                backend="marching_cubes",
            )
            mesh_result = mesh_extractor(implicit_field)
            mesh = mesh_result.mesh

        # Save the mesh as a persistent temporary OBJ file.
        # mkstemp + os.close avoids the open-handle leak of
        # NamedTemporaryFile(delete=False), whose file object was never
        # closed; mesh.export reopens the path itself.
        fd, mesh_path = tempfile.mkstemp(suffix=".obj")
        os.close(fd)
        mesh.export(mesh_path)

        return {
            "mesh_path": mesh_path,
            "vertices": int(len(mesh.vertices)),
            "faces": int(len(mesh.faces)),
        }
|