Spaces:
Running
on
Zero
Running
on
Zero
| # Project EmbodiedGen | |
| # | |
| # Copyright (c) 2025 Horizon Robotics. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |
| # implied. See the License for the specific language governing | |
| # permissions and limitations under the License. | |
from embodied_gen.utils.monkey_patch.trellis import monkey_path_trellis

monkey_path_trellis()
import random
from typing import Any

import torch
from PIL import Image

from embodied_gen.data.utils import trellis_preprocess
from embodied_gen.models.sam3d import Sam3dInference
from embodied_gen.utils.trender import pack_state, unpack_state
from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
# Public API of this module.
__all__ = ["image3d_model_infer"]
def image3d_model_infer(
    pipe: TrellisImageTo3DPipeline | Sam3dInference,
    seg_image: Image.Image,
    seed: int | None = None,
    **kwargs: Any,
) -> dict[str, Any]:
    """Execute 3D generation using a TRELLIS or SAM3D pipeline on an image.

    Args:
        pipe: The generation pipeline. A ``TrellisImageTo3DPipeline`` is
            moved to CUDA for the duration of the call and moved back to
            CPU afterwards (also when generation fails). A
            ``Sam3dInference`` is used without any device moves.
        seg_image: Input image. For TRELLIS it is passed through
            ``trellis_preprocess`` and ``pipe.run`` is told not to
            preprocess it again.
        seed: Seed forwarded to ``pipe.run``. When ``None``, a seed is drawn
            with ``random.randint(0, 100000)``.
        **kwargs: Extra keyword arguments forwarded verbatim to ``pipe.run``.
            Examples for TRELLIS:
            ``sparse_structure_sampler_params={"steps": 12, "cfg_strength": 7.5}``,
            ``slat_sampler_params={"steps": 12, "cfg_strength": 3}``.
            Examples for SAM3D: ``stage1_inference_steps=25``,
            ``stage2_inference_steps=25``.

    Returns:
        The outputs dict returned by ``pipe.run``. For SAM3D, entry 0 of
        ``outputs["gaussian"]`` is replaced by a ``pack_state`` /
        ``unpack_state`` round-trip (onto ``"cuda"``) so that it matches the
        TRELLIS gaussian format.

    Raises:
        ValueError: If ``pipe`` is neither supported pipeline type.
    """
    if not isinstance(pipe, (TrellisImageTo3DPipeline, Sam3dInference)):
        raise ValueError(f"Unsupported pipeline type: {type(pipe)}")

    # Resolve the seed once instead of repeating the expression per branch.
    if seed is None:
        seed = random.randint(0, 100000)

    try:
        if isinstance(pipe, TrellisImageTo3DPipeline):
            pipe.cuda()
            try:
                outputs = pipe.run(
                    trellis_preprocess(seg_image),
                    preprocess_image=False,
                    seed=seed,
                    **kwargs,
                )
            finally:
                # Always move the model back off the GPU, even when
                # preprocessing or generation raises, so a failed call does
                # not leave it resident in GPU memory.
                pipe.cpu()
        else:
            outputs = pipe.run(seg_image, seed=seed, **kwargs)
            # Align GS3D from SAM3D with TRELLIS format.
            state = pack_state(outputs["gaussian"][0], outputs["mesh"][0])
            outputs["gaussian"][0], _ = unpack_state(state, device="cuda")
    finally:
        # Release cached CUDA blocks on both success and failure paths.
        torch.cuda.empty_cache()

    return outputs