xinjie.wang committed on
Commit
f3347e3
·
1 Parent(s): f82f044
Files changed (2) hide show
  1. common.py +34 -34
  2. requirements.txt +1 -1
common.py CHANGED
@@ -41,7 +41,7 @@ from embodied_gen.data.differentiable_render import entrypoint as render_api
41
  from embodied_gen.data.utils import trellis_preprocess, zip_files
42
  from embodied_gen.models.delight_model import DelightingModel
43
  from embodied_gen.models.gs_model import GaussianOperator
44
- from embodied_gen.models.sam3d import Sam3dInference
45
  from embodied_gen.models.segment_model import (
46
  BMGG14Remover,
47
  RembgRemover,
@@ -92,13 +92,13 @@ if os.getenv("GRADIO_APP").startswith("imageto3d"):
92
  RBG_REMOVER = RembgRemover()
93
  RBG14_REMOVER = BMGG14Remover()
94
  SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
95
- if "sam3d" in os.getenv("GRADIO_APP"):
96
- PIPELINE = Sam3dInference()
97
- else:
98
- PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
99
- "microsoft/TRELLIS-image-large"
100
- )
101
- # PIPELINE.cuda()
102
  SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
103
  GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
104
  AESTHETIC_CHECKER = ImageAestheticChecker()
@@ -287,32 +287,32 @@ def image_to_3d(
287
  seg_image = Image.fromarray(seg_image)
288
 
289
  logger.info("Start generating 3D representation from image...")
290
- if isinstance(PIPELINE, Sam3dInference):
291
- outputs = PIPELINE.run(
292
- seg_image,
293
- seed=seed,
294
- stage1_inference_steps=ss_sampling_steps,
295
- stage2_inference_steps=slat_sampling_steps,
296
- )
297
- else:
298
- PIPELINE.cuda()
299
- seg_image = trellis_preprocess(seg_image)
300
- outputs = PIPELINE.run(
301
- seg_image,
302
- seed=seed,
303
- formats=["gaussian", "mesh"],
304
- preprocess_image=False,
305
- sparse_structure_sampler_params={
306
- "steps": ss_sampling_steps,
307
- "cfg_strength": ss_guidance_strength,
308
- },
309
- slat_sampler_params={
310
- "steps": slat_sampling_steps,
311
- "cfg_strength": slat_guidance_strength,
312
- },
313
- )
314
- # Set back to cpu for memory saving.
315
- PIPELINE.cpu()
316
 
317
  gs_model = outputs["gaussian"][0]
318
  mesh_model = outputs["mesh"][0]
 
41
  from embodied_gen.data.utils import trellis_preprocess, zip_files
42
  from embodied_gen.models.delight_model import DelightingModel
43
  from embodied_gen.models.gs_model import GaussianOperator
44
+ # from embodied_gen.models.sam3d import Sam3dInference
45
  from embodied_gen.models.segment_model import (
46
  BMGG14Remover,
47
  RembgRemover,
 
92
  RBG_REMOVER = RembgRemover()
93
  RBG14_REMOVER = BMGG14Remover()
94
  SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
95
+ # if "sam3d" in os.getenv("GRADIO_APP"):
96
+ # PIPELINE = Sam3dInference()
97
+ # else:
98
+ PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
99
+ "microsoft/TRELLIS-image-large"
100
+ )
101
+ # PIPELINE.cuda()
102
  SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
103
  GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
104
  AESTHETIC_CHECKER = ImageAestheticChecker()
 
287
  seg_image = Image.fromarray(seg_image)
288
 
289
  logger.info("Start generating 3D representation from image...")
290
+ # if isinstance(PIPELINE, Sam3dInference):
291
+ # outputs = PIPELINE.run(
292
+ # seg_image,
293
+ # seed=seed,
294
+ # stage1_inference_steps=ss_sampling_steps,
295
+ # stage2_inference_steps=slat_sampling_steps,
296
+ # )
297
+ # else:
298
+ PIPELINE.cuda()
299
+ seg_image = trellis_preprocess(seg_image)
300
+ outputs = PIPELINE.run(
301
+ seg_image,
302
+ seed=seed,
303
+ formats=["gaussian", "mesh"],
304
+ preprocess_image=False,
305
+ sparse_structure_sampler_params={
306
+ "steps": ss_sampling_steps,
307
+ "cfg_strength": ss_guidance_strength,
308
+ },
309
+ slat_sampler_params={
310
+ "steps": slat_sampling_steps,
311
+ "cfg_strength": slat_guidance_strength,
312
+ },
313
+ )
314
+ # Set back to cpu for memory saving.
315
+ PIPELINE.cpu()
316
 
317
  gs_model = outputs["gaussian"][0]
318
  mesh_model = outputs["mesh"][0]
requirements.txt CHANGED
@@ -61,7 +61,7 @@ MoGe@git+https://github.com/microsoft/MoGe.git@a8c3734
61
 
62
 
63
  # git+https://github.com/facebookresearch/pytorch3d.git@stable
64
- https://huggingface.co/xinjjj/RoboAssetGen/resolve/main/wheel_cu121/pytorch3d-0.7.8-cp310-cp310-linux_x86_64.whl
65
  # git+https://github.com/nerfstudio-project/gsplat.git@v1.5.3
66
  https://github.com/nerfstudio-project/gsplat/releases/download/v1.5.0/gsplat-1.5.0+pt24cu121-cp310-cp310-linux_x86_64.whl
67
  # flash-attn==2.7.0.post2
 
61
 
62
 
63
  # git+https://github.com/facebookresearch/pytorch3d.git@stable
64
+ # https://huggingface.co/xinjjj/RoboAssetGen/resolve/main/wheel_cu121/pytorch3d-0.7.8-cp310-cp310-linux_x86_64.whl
65
  # git+https://github.com/nerfstudio-project/gsplat.git@v1.5.3
66
  https://github.com/nerfstudio-project/gsplat/releases/download/v1.5.0/gsplat-1.5.0+pt24cu121-cp310-cp310-linux_x86_64.whl
67
  # flash-attn==2.7.0.post2