| | |
| | import os |
| | import sys |
| | import tempfile |
| |
|
| | |
# Make the repository root importable so the local `trellis` package resolves
# regardless of the working directory HF launches the handler from.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)
| |
|
| | import torch |
| |
|
| | import trellis.pipelines as trellis_pipelines |
| | from trellis.representations.mesh import SparseFeatures2Mesh |
| |
|
| |
|
class EndpointHandler:
    """
    Hugging Face custom inference handler for TRELLIS.

    Loads a TRELLIS pipeline from a local model repository and turns text
    prompts into 3D meshes exported as Wavefront OBJ files.
    """

    def __init__(self, path: str):
        """
        Initialize the handler.

        Args:
            path: Local path of the cloned model repo — HF passes this in
                (e.g. ``/repository``).
        """
        # Prefer GPU when available; fall back to CPU otherwise.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # NOTE(review): assumes trellis.pipelines exposes a module-level
        # `from_pretrained` returning a pipeline object with .to()/.eval() —
        # confirm against the pinned TRELLIS version.
        self.pipeline = trellis_pipelines.from_pretrained(path)
        self.pipeline.to(self.device)
        self.pipeline.eval()

    def __call__(self, data: dict) -> dict:
        """
        Run one inference request.

        Expected payload::

            {
                "inputs": "<prompt>",
                "resolution": 64      # optional, defaults to 64
            }

        Returns:
            dict with ``mesh_path`` (path of the exported ``.obj`` file, left
            on disk for the caller to consume), ``vertices`` and ``faces``
            counts.

        Raises:
            ValueError: if the payload is missing the "inputs" key.
        """
        # Fail fast with a clear message instead of a bare KeyError.
        if "inputs" not in data:
            raise ValueError('Payload must contain an "inputs" key with the prompt.')
        prompt = data["inputs"]
        # Coerce so a JSON string like "64" doesn't break the extractor.
        resolution = int(data.get("resolution", 64))

        with torch.no_grad():
            # presumably returns an implicit field / sparse-feature
            # representation consumed by SparseFeatures2Mesh — verify.
            implicit_field = self.pipeline(prompt)

            mesh_extractor = SparseFeatures2Mesh(
                resolution=resolution,
                backend="marching_cubes",
            )
            mesh_result = mesh_extractor(implicit_field)
            mesh = mesh_result.mesh

        # Close the handle before exporting: keeping it open leaks the fd and
        # fails on platforms (e.g. Windows) that forbid reopening an open
        # NamedTemporaryFile. delete=False is intentional — the caller
        # consumes the file by path after this returns.
        tmp = tempfile.NamedTemporaryFile(suffix=".obj", delete=False)
        tmp.close()
        mesh.export(tmp.name)

        return {
            "mesh_path": tmp.name,
            "vertices": int(len(mesh.vertices)),
            "faces": int(len(mesh.faces)),
        }
| |
|