# TRELLIS.2 / export_worker.py
# choephix: "Add safe non-remesh fallback option to GLB export" (commit 44b4dd3)
from __future__ import annotations
import argparse
import json
import traceback
from pathlib import Path
import runtime_env # noqa: F401
import numpy as np
import torch
from glb_export import export_glb as _export_glb
def _deserialize_attr_layout(payload: dict[str, dict[str, int]]) -> dict[str, slice]:
return {key: slice(value["start"], value["stop"]) for key, value in payload.items()}
def export_glb(
    *,
    payload_npz: Path,
    payload_meta: Path,
    output_path: Path,
    decimation_target: int,
    texture_size: int,
    remesh: bool = True,
    safe_nonremesh_fallback: bool | None = None,
) -> None:
    """Load a serialized export payload, build a GLB from it and write it to disk.

    The payload is split across two files: an ``.npz`` archive holding the
    arrays ``"vertices"``, ``"faces"``, ``"attrs"`` and ``"coords"``, and a
    UTF-8 JSON file holding ``"attr_layout"``, ``"resolution"`` and ``"aabb"``.
    All four arrays are moved to the CUDA device and handed to
    ``glb_export.export_glb``; the object it returns is written with
    ``.export(str(output_path), extension_webp=True)``.

    Args:
        payload_npz: Path of the ``.npz`` archive with the array payload.
        payload_meta: Path of the JSON metadata file.
        output_path: Destination path for the exported GLB.
        decimation_target: Forwarded unchanged to ``glb_export.export_glb``.
        texture_size: Forwarded unchanged to ``glb_export.export_glb``.
        remesh: Forwarded unchanged to ``glb_export.export_glb``.
        safe_nonremesh_fallback: Forwarded unchanged to
            ``glb_export.export_glb``. ``None`` is passed through as-is; the
            CLI help text in this file describes that case as "use env var"
            (the resolution itself happens in ``glb_export``, not here).

    Raises:
        KeyError: If a required array or metadata key is missing.
        Exception: Anything raised while reading the payload, by the CUDA
            transfers, by ``glb_export.export_glb`` or by the final export
            propagates to the caller unchanged.
    """
    # np.load() on an .npz archive returns a lazily-reading NpzFile that keeps
    # the underlying file open until it is closed. Use it as a context manager
    # so the handle is released deterministically, even when a later step
    # raises. The statement order inside the block mirrors the original flow
    # (open archive -> read metadata -> pull arrays).
    with np.load(payload_npz) as arrays:
        meta = json.loads(payload_meta.read_text(encoding="utf-8"))
        attr_layout = _deserialize_attr_layout(meta["attr_layout"])
        resolution = int(meta["resolution"])
        aabb = meta["aabb"]

        # Indexing the archive reads each array into host memory and .cuda()
        # copies it to the device, so nothing below depends on the archive
        # still being open.
        vertices = torch.from_numpy(arrays["vertices"]).cuda()
        faces = torch.from_numpy(arrays["faces"]).cuda()
        attr_volume = torch.from_numpy(arrays["attrs"]).cuda()
        coords = torch.from_numpy(arrays["coords"]).cuda()

    # Wait for the host-to-device work to finish before starting the export.
    torch.cuda.synchronize()
    glb = _export_glb(
        vertices=vertices,
        faces=faces,
        attr_volume=attr_volume,
        coords=coords,
        attr_layout=attr_layout,
        grid_size=resolution,
        aabb=aabb,
        decimation_target=decimation_target,
        texture_size=texture_size,
        remesh=remesh,
        safe_nonremesh_fallback=safe_nonremesh_fallback,
        use_tqdm=False,
    )
    # Wait for any outstanding GPU work before serializing the result.
    torch.cuda.synchronize()
    glb.export(str(output_path), extension_webp=True)
def main() -> int:
    """Command-line entry point: run one GLB export and record the outcome.

    Parses the command line, calls ``export_glb`` and writes a JSON report to
    the ``--result-json`` path. On success the report is
    ``{"ok": true, "output_path": ...}``; on any ``Exception`` it is
    ``{"ok": false, "error_type": ..., "message": ..., "traceback": ...}``.

    Returns:
        ``0`` when the export and the report write succeeded, ``1`` when an
        exception was caught and reported.
    """
    cli = argparse.ArgumentParser()
    for flag in ("--payload-npz", "--payload-meta", "--output"):
        cli.add_argument(flag, required=True)
    for flag in ("--decimation-target", "--texture-size"):
        cli.add_argument(flag, type=int, required=True)
    cli.add_argument(
        "--remesh", type=int, default=1, help="1 = remesh (default), 0 = no remesh"
    )
    cli.add_argument(
        "--safe-nonremesh-fallback",
        type=int,
        default=-1,
        help="1 = safe fallback, 0 = upstream raw, -1 = use env var (default)",
    )
    cli.add_argument("--result-json", required=True)
    args = cli.parse_args()

    # Tri-state flag: a negative value means "not specified" and maps to None.
    fallback_flag = args.safe_nonremesh_fallback
    safe_val: bool | None = bool(fallback_flag) if fallback_flag >= 0 else None

    result_path = Path(args.result_json)

    def write_report(report: dict) -> None:
        # Single place that fixes the on-disk format of the result file.
        result_path.write_text(
            json.dumps(report, indent=2, sort_keys=True),
            encoding="utf-8",
        )

    try:
        export_glb(
            payload_npz=Path(args.payload_npz),
            payload_meta=Path(args.payload_meta),
            output_path=Path(args.output),
            decimation_target=args.decimation_target,
            texture_size=args.texture_size,
            remesh=bool(args.remesh),
            safe_nonremesh_fallback=safe_val,
        )
        # Kept inside the try so a failure to write the success report is
        # itself reported through the error path below.
        write_report({"ok": True, "output_path": args.output})
    except Exception as error:
        write_report(
            {
                "ok": False,
                "error_type": type(error).__name__,
                "message": str(error),
                "traceback": traceback.format_exc(),
            }
        )
        return 1
    return 0
if __name__ == "__main__":
    # Use main()'s return value (0 on success, 1 on failure) as the process
    # exit status.
    exit_status = main()
    raise SystemExit(exit_status)