# trellis-text-endpoint / handler.py
# Commit ee54f9e by jsnavarroo: "New try, TRELLIS HF inference handler"
# handler.py
import os
import sys
import tempfile
# Put the directory containing this file (the repository root) at the front
# of the import path so in-repo packages such as `trellis`, imported below,
# resolve regardless of the process's working directory.
ROOT_DIR = os.path.split(os.path.abspath(__file__))[0]
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)
import torch
import trellis.pipelines as trellis_pipelines
from trellis.representations.mesh import SparseFeatures2Mesh
class EndpointHandler:
    """Hugging Face custom inference handler for TRELLIS.

    The handler is constructed once with the local path of the cloned
    repository. It is then called with each request payload and returns
    metadata about the generated mesh.
    """

    # Mesh-extraction resolution used when the request does not specify one.
    DEFAULT_RESOLUTION = 64

    def __init__(self, path: str):
        """Load the TRELLIS pipeline from the local repository checkout.

        Args:
            path: Local directory of the cloned repository (e.g.
                ``/repository``). Loading from this local path instead of a
                remote repo id keeps relative ``ckpts/...`` paths working.
        """
        # Use the GPU when CUDA is available, otherwise fall back to CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.pipeline = trellis_pipelines.from_pretrained(path)
        self.pipeline.to(self.device)

        # NOTE(review): confirm the pipeline object actually exposes `eval()`
        # (it may already put its sub-models in eval mode itself). The call
        # is guarded so a missing method does not abort endpoint startup.
        set_eval_mode = getattr(self.pipeline, "eval", None)
        if callable(set_eval_mode):
            set_eval_mode()

    def __call__(self, data: dict) -> dict:
        """Generate a mesh from a text prompt and save it as a temporary OBJ.

        Expected payload::

            {
                "inputs": "<prompt>",
                "resolution": 64
            }

        ``resolution`` is optional. It may also be passed HF-style as
        ``{"parameters": {"resolution": ...}}``; a top-level ``resolution``
        takes precedence.

        Args:
            data: Decoded request payload.

        Returns:
            A dict with:

            * ``mesh_path``: path of the OBJ file written by this call.
              Nothing in this handler deletes it.
            * ``vertices``: ``len(mesh.vertices)``.
            * ``faces``: ``len(mesh.faces)``.

            NOTE(review): the path points at the endpoint's own filesystem,
            so a remote client cannot read it. Consider returning the file
            contents instead.

        Raises:
            ValueError: If the payload is not a dict, ``inputs`` is missing
                or blank, or ``resolution`` is not a positive integer.
        """
        prompt, resolution = self._parse_request(data)

        with torch.no_grad():
            # Text -> 3D implicit field.
            # NOTE(review): confirm the pipeline is callable like this; some
            # TRELLIS pipelines expose generation through `.run(...)` instead.
            implicit_field = self.pipeline(prompt)

            # Implicit field -> mesh. Force the stable marching-cubes backend
            # (no Kaolin dependency).
            # NOTE(review): verify the `resolution`/`backend` keyword names and
            # the `.mesh` attribute against SparseFeatures2Mesh's signature.
            mesh_extractor = SparseFeatures2Mesh(
                resolution=resolution,
                backend="marching_cubes",
            )
            mesh = mesh_extractor(implicit_field).mesh

        mesh_path = self._new_obj_path()
        try:
            mesh.export(mesh_path)
        except Exception:
            # Do not leave an empty temp file behind when the export fails.
            try:
                os.remove(mesh_path)
            except OSError:
                pass
            raise

        return {
            "mesh_path": mesh_path,
            "vertices": len(mesh.vertices),
            "faces": len(mesh.faces),
        }

    @classmethod
    def _parse_request(cls, data: dict) -> tuple:
        """Validate a payload and return ``(prompt, resolution)``.

        The prompt is returned unchanged. The resolution is read from the
        top-level ``resolution`` key, then ``parameters.resolution``, then
        ``DEFAULT_RESOLUTION``, and is coerced with ``int()``.

        Raises:
            ValueError: On a non-dict payload, a missing/blank prompt, or a
                resolution that is not a positive integer.
        """
        if not isinstance(data, dict):
            raise ValueError("Request payload must be a JSON object.")

        prompt = data.get("inputs")
        if not isinstance(prompt, str) or not prompt.strip():
            raise ValueError("'inputs' must be a non-empty text prompt.")

        parameters = data.get("parameters")
        if not isinstance(parameters, dict):
            parameters = {}

        if "resolution" in data:
            raw_resolution = data["resolution"]
        else:
            raw_resolution = parameters.get("resolution", cls.DEFAULT_RESOLUTION)

        # JSON clients may send numbers as strings ("64"); accept those but
        # reject anything that is not a positive integer.
        try:
            resolution = int(raw_resolution)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"'resolution' must be a positive integer, got {raw_resolution!r}."
            ) from err
        if resolution <= 0:
            raise ValueError(
                f"'resolution' must be a positive integer, got {raw_resolution!r}."
            )

        return prompt, resolution

    @staticmethod
    def _new_obj_path() -> str:
        """Create an empty ``.obj`` temporary file and return its path.

        The file handle is closed before returning, so no descriptor is
        leaked per request and the exporter can reopen the path.
        ``delete=False`` keeps the file on disk after the handle closes.
        """
        with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as tmp:
            return tmp.name