xinjie.wang committed on
Commit
54da04d
·
1 Parent(s): be013ba
Files changed (3) hide show
  1. app.py +1 -61
  2. common.py +6 -622
  3. embodied_gen/utils/monkey_patch/sam3d.py +4 -4
app.py CHANGED
@@ -471,67 +471,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
471
  inputs=image_seg_sam,
472
  outputs=generate_btn,
473
  )
474
-
475
- generate_btn.click(
476
- get_seed,
477
- inputs=[randomize_seed, seed],
478
- outputs=[seed],
479
- ).success(
480
- image_to_3d,
481
- inputs=[
482
- image_prompt,
483
- seed,
484
- ss_sampling_steps,
485
- slat_sampling_steps,
486
- raw_image_cache,
487
- ss_guidance_strength,
488
- slat_guidance_strength,
489
- image_seg_sam,
490
- is_samimage,
491
- ],
492
- outputs=[output_buf, video_output],
493
- ).success(
494
- extract_3d_representations_v3,
495
- inputs=[
496
- output_buf,
497
- project_delight,
498
- texture_size,
499
- ],
500
- outputs=[
501
- model_output_mesh,
502
- model_output_gs,
503
- model_output_obj,
504
- aligned_gs,
505
- ],
506
- ).success(
507
- lambda: gr.Button(interactive=True),
508
- outputs=[extract_urdf_btn],
509
- )
510
-
511
- extract_urdf_btn.click(
512
- extract_urdf,
513
- inputs=[
514
- aligned_gs,
515
- model_output_obj,
516
- asset_cat_text,
517
- height_range_text,
518
- mass_range_text,
519
- asset_version_text,
520
- ],
521
- outputs=[
522
- download_urdf,
523
- est_type_text,
524
- est_height_text,
525
- est_mass_text,
526
- est_mu_text,
527
- ],
528
- queue=True,
529
- show_progress="full",
530
- ).success(
531
- lambda: gr.Button(interactive=True),
532
- outputs=[download_urdf],
533
- )
534
-
535
 
536
  if __name__ == "__main__":
537
  demo.launch()
 
471
  inputs=image_seg_sam,
472
  outputs=generate_btn,
473
  )
474
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
475
 
476
  if __name__ == "__main__":
477
  demo.launch()
common.py CHANGED
@@ -15,10 +15,6 @@
15
  # permissions and limitations under the License.
16
 
17
  import spaces
18
- from embodied_gen.utils.monkey_patch.trellis import monkey_path_trellis
19
-
20
- monkey_path_trellis()
21
-
22
  import gc
23
  import logging
24
  import os
@@ -32,48 +28,21 @@ import gradio as gr
32
  import numpy as np
33
  import torch
34
  import trimesh
35
- from PIL import Image
36
- from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
37
- from embodied_gen.data.backproject_v3 import entrypoint as backproject_api_v3
38
- from embodied_gen.data.differentiable_render import entrypoint as render_api
39
- from embodied_gen.data.utils import trellis_preprocess, zip_files
40
- from embodied_gen.models.delight_model import DelightingModel
41
- from embodied_gen.models.gs_model import GaussianOperator
42
- from embodied_gen.models.sam3d import Sam3dInference
43
  from embodied_gen.models.segment_model import (
44
  BMGG14Remover,
45
  RembgRemover,
46
  SAMPredictor,
47
- )
48
- from embodied_gen.models.sr_model import ImageRealESRGAN, ImageStableSR
49
- from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
50
- from embodied_gen.scripts.render_mv import build_texture_gen_pipe, infer_pipe
51
- from embodied_gen.scripts.text2image import (
52
- build_text2img_ip_pipeline,
53
- build_text2img_pipeline,
54
- text2img_gen,
55
- )
56
  from embodied_gen.utils.gpt_clients import GPT_CLIENT
57
  from embodied_gen.utils.process_media import (
58
  filter_image_small_connected_components,
59
  keep_largest_connected_component,
60
  merge_images_video,
61
  )
62
- from embodied_gen.utils.tags import VERSION
63
- from embodied_gen.utils.trender import pack_state, render_video, unpack_state
64
- from embodied_gen.validators.quality_checkers import (
65
- BaseChecker,
66
- ImageAestheticChecker,
67
- ImageSegChecker,
68
- MeshGeoChecker,
69
- )
70
- from embodied_gen.validators.urdf_convertor import URDFGenerator
71
-
72
- current_file_path = os.path.abspath(__file__)
73
- current_dir = os.path.dirname(current_file_path)
74
- sys.path.append(os.path.join(current_dir, ".."))
75
- from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
76
- from thirdparty.TRELLIS.trellis.utils import postprocessing_utils
77
 
78
  logging.basicConfig(
79
  format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
@@ -83,67 +52,15 @@ logger = logging.getLogger(__name__)
83
  os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
84
  os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")
85
  MAX_SEED = 100000
86
-
87
- # DELIGHT = DelightingModel()
88
- # IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
89
- # IMAGESR_MODEL = ImageStableSR()
90
  if os.getenv("GRADIO_APP").startswith("imageto3d"):
91
  RBG_REMOVER = RembgRemover()
92
  RBG14_REMOVER = BMGG14Remover()
93
  SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
94
- # if "sam3d" in os.getenv("GRADIO_APP"):
95
- # PIPELINE = Sam3dInference()
96
- # else:
97
- # PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
98
- # "microsoft/TRELLIS-image-large"
99
- # )
100
- # # PIPELINE.cuda()
101
- # SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
102
- # GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
103
- # AESTHETIC_CHECKER = ImageAestheticChecker()
104
- # CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
105
  TMP_DIR = os.path.join(
106
  os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
107
  )
108
  os.makedirs(TMP_DIR, exist_ok=True)
109
- elif os.getenv("GRADIO_APP").startswith("textto3d"):
110
- RBG_REMOVER = RembgRemover()
111
- RBG14_REMOVER = BMGG14Remover()
112
- if "sam3d" in os.getenv("GRADIO_APP"):
113
- PIPELINE = Sam3dInference()
114
- else:
115
- PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
116
- "microsoft/TRELLIS-image-large"
117
- )
118
- # PIPELINE.cuda()
119
- text_model_dir = "weights/Kolors"
120
- PIPELINE_IMG_IP = build_text2img_ip_pipeline(text_model_dir, ref_scale=0.3)
121
- PIPELINE_IMG = build_text2img_pipeline(text_model_dir)
122
- SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
123
- GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
124
- AESTHETIC_CHECKER = ImageAestheticChecker()
125
- CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
126
- TMP_DIR = os.path.join(
127
- os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
128
- )
129
- os.makedirs(TMP_DIR, exist_ok=True)
130
- elif os.getenv("GRADIO_APP") == "texture_edit":
131
- DELIGHT = DelightingModel()
132
- IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
133
- PIPELINE_IP = build_texture_gen_pipe(
134
- base_ckpt_dir="./weights",
135
- ip_adapt_scale=0.7,
136
- device="cuda",
137
- )
138
- PIPELINE = build_texture_gen_pipe(
139
- base_ckpt_dir="./weights",
140
- ip_adapt_scale=0,
141
- device="cuda",
142
- )
143
- TMP_DIR = os.path.join(
144
- os.path.dirname(os.path.abspath(__file__)), "sessions/texture_edit"
145
- )
146
- os.makedirs(TMP_DIR, exist_ok=True)
147
 
148
 
149
  def start_session(req: gr.Request) -> None:
@@ -262,536 +179,3 @@ def select_point(
262
 
263
  return (image, masks), seg_image
264
 
265
-
266
- @spaces.GPU(duration=300)
267
- def image_to_3d(
268
- image: Image.Image,
269
- seed: int,
270
- ss_sampling_steps: int,
271
- slat_sampling_steps: int,
272
- raw_image_cache: Image.Image,
273
- ss_guidance_strength: float,
274
- slat_guidance_strength: float,
275
- sam_image: Image.Image = None,
276
- is_sam_image: bool = False,
277
- req: gr.Request = None,
278
- ) -> tuple[dict, str]:
279
- if is_sam_image:
280
- seg_image = filter_image_small_connected_components(sam_image)
281
- seg_image = Image.fromarray(seg_image, mode="RGBA")
282
- else:
283
- seg_image = image
284
-
285
- if isinstance(seg_image, np.ndarray):
286
- seg_image = Image.fromarray(seg_image)
287
-
288
- logger.info("Start generating 3D representation from image...")
289
- if isinstance(PIPELINE, Sam3dInference):
290
- outputs = PIPELINE.run(
291
- seg_image,
292
- seed=seed,
293
- stage1_inference_steps=ss_sampling_steps,
294
- stage2_inference_steps=slat_sampling_steps,
295
- )
296
- else:
297
- PIPELINE.cuda()
298
- seg_image = trellis_preprocess(seg_image)
299
- outputs = PIPELINE.run(
300
- seg_image,
301
- seed=seed,
302
- formats=["gaussian", "mesh"],
303
- preprocess_image=False,
304
- sparse_structure_sampler_params={
305
- "steps": ss_sampling_steps,
306
- "cfg_strength": ss_guidance_strength,
307
- },
308
- slat_sampler_params={
309
- "steps": slat_sampling_steps,
310
- "cfg_strength": slat_guidance_strength,
311
- },
312
- )
313
- # Set back to cpu for memory saving.
314
- PIPELINE.cpu()
315
-
316
- gs_model = outputs["gaussian"][0]
317
- mesh_model = outputs["mesh"][0]
318
- color_images = render_video(gs_model, r=1.85)["color"]
319
- normal_images = render_video(mesh_model, r=1.85)["normal"]
320
-
321
- output_root = os.path.join(TMP_DIR, str(req.session_hash))
322
- os.makedirs(output_root, exist_ok=True)
323
- seg_image.save(f"{output_root}/seg_image.png")
324
- raw_image_cache.save(f"{output_root}/raw_image.png")
325
-
326
- video_path = os.path.join(output_root, "gs_mesh.mp4")
327
- merge_images_video(color_images, normal_images, video_path)
328
- state = pack_state(gs_model, mesh_model)
329
-
330
- gc.collect()
331
- torch.cuda.empty_cache()
332
-
333
- return state, video_path
334
-
335
-
336
- def extract_3d_representations_v2(
337
- state: dict,
338
- enable_delight: bool,
339
- texture_size: int,
340
- req: gr.Request,
341
- ):
342
- """Back-Projection Version of Texture Super-Resolution."""
343
- output_root = TMP_DIR
344
- user_dir = os.path.join(output_root, str(req.session_hash))
345
- gs_model, mesh_model = unpack_state(state, device="cpu")
346
-
347
- filename = "sample"
348
- gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
349
- gs_model.save_ply(gs_path)
350
-
351
- # Rotate mesh and GS by 90 degrees around Z-axis.
352
- rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
353
- gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
354
- mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
355
-
356
- # Addtional rotation for GS to align mesh.
357
- gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
358
- pose = GaussianOperator.trans_to_quatpose(gs_rot)
359
- aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
360
- GaussianOperator.resave_ply(
361
- in_ply=gs_path,
362
- out_ply=aligned_gs_path,
363
- instance_pose=pose,
364
- device="cpu",
365
- )
366
- color_path = os.path.join(user_dir, "color.png")
367
- render_gs_api(
368
- input_gs=aligned_gs_path,
369
- output_path=color_path,
370
- elevation=[20, -10, 60, -50],
371
- num_images=12,
372
- )
373
-
374
- mesh = trimesh.Trimesh(
375
- vertices=mesh_model.vertices.cpu().numpy(),
376
- faces=mesh_model.faces.cpu().numpy(),
377
- )
378
- mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
379
- mesh.vertices = mesh.vertices @ np.array(rot_matrix)
380
-
381
- mesh_obj_path = os.path.join(user_dir, f"{filename}.obj")
382
- mesh.export(mesh_obj_path)
383
-
384
- mesh = backproject_api(
385
- delight_model=DELIGHT,
386
- imagesr_model=IMAGESR_MODEL,
387
- color_path=color_path,
388
- mesh_path=mesh_obj_path,
389
- output_path=mesh_obj_path,
390
- skip_fix_mesh=False,
391
- delight=enable_delight,
392
- texture_wh=[texture_size, texture_size],
393
- elevation=[20, -10, 60, -50],
394
- num_images=12,
395
- )
396
-
397
- mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
398
- mesh.export(mesh_glb_path)
399
-
400
- return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
401
-
402
-
403
- def extract_3d_representations_v3(
404
- state: dict,
405
- enable_delight: bool,
406
- texture_size: int,
407
- req: gr.Request,
408
- ):
409
- """Back-Projection Version with Optimization-Based."""
410
- output_root = TMP_DIR
411
- user_dir = os.path.join(output_root, str(req.session_hash))
412
- gs_model, mesh_model = unpack_state(state, device="cpu")
413
-
414
- filename = "sample"
415
- gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
416
- gs_model.save_ply(gs_path)
417
-
418
- # Rotate mesh and GS by 90 degrees around Z-axis.
419
- rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
420
- gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
421
- mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
422
-
423
- # Addtional rotation for GS to align mesh.
424
- gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
425
- pose = GaussianOperator.trans_to_quatpose(gs_rot)
426
- aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
427
- GaussianOperator.resave_ply(
428
- in_ply=gs_path,
429
- out_ply=aligned_gs_path,
430
- instance_pose=pose,
431
- device="cpu",
432
- )
433
-
434
- mesh = trimesh.Trimesh(
435
- vertices=mesh_model.vertices.cpu().numpy(),
436
- faces=mesh_model.faces.cpu().numpy(),
437
- )
438
- mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
439
- mesh.vertices = mesh.vertices @ np.array(rot_matrix)
440
-
441
- mesh_obj_path = os.path.join(user_dir, f"{filename}.obj")
442
- mesh.export(mesh_obj_path)
443
-
444
- mesh = backproject_api_v3(
445
- gs_path=aligned_gs_path,
446
- mesh_path=mesh_obj_path,
447
- output_path=mesh_obj_path,
448
- skip_fix_mesh=False,
449
- texture_size=texture_size,
450
- )
451
-
452
- mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
453
- mesh.export(mesh_glb_path)
454
-
455
- return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
456
-
457
-
458
- def extract_urdf(
459
- gs_path: str,
460
- mesh_obj_path: str,
461
- asset_cat_text: str,
462
- height_range_text: str,
463
- mass_range_text: str,
464
- asset_version_text: str,
465
- req: gr.Request = None,
466
- ):
467
- output_root = TMP_DIR
468
- if req is not None:
469
- output_root = os.path.join(output_root, str(req.session_hash))
470
-
471
- # Convert to URDF and recover attrs by GPT.
472
- filename = "sample"
473
- urdf_convertor = URDFGenerator(
474
- GPT_CLIENT, render_view_num=4, decompose_convex=True
475
- )
476
- asset_attrs = {
477
- "version": VERSION,
478
- "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
479
- }
480
- if asset_version_text:
481
- asset_attrs["version"] = asset_version_text
482
- if asset_cat_text:
483
- asset_attrs["category"] = asset_cat_text.lower()
484
- if height_range_text:
485
- try:
486
- min_height, max_height = map(float, height_range_text.split("-"))
487
- asset_attrs["min_height"] = min_height
488
- asset_attrs["max_height"] = max_height
489
- except ValueError:
490
- return "Invalid height input format. Use the format: min-max."
491
- if mass_range_text:
492
- try:
493
- min_mass, max_mass = map(float, mass_range_text.split("-"))
494
- asset_attrs["min_mass"] = min_mass
495
- asset_attrs["max_mass"] = max_mass
496
- except ValueError:
497
- return "Invalid mass input format. Use the format: min-max."
498
-
499
- urdf_path = urdf_convertor(
500
- mesh_path=mesh_obj_path,
501
- output_root=f"{output_root}/URDF_{filename}",
502
- **asset_attrs,
503
- )
504
-
505
- # Rescale GS and save to URDF/mesh folder.
506
- real_height = urdf_convertor.get_attr_from_urdf(
507
- urdf_path, attr_name="real_height"
508
- )
509
- out_gs = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}_gs.ply" # noqa
510
- GaussianOperator.resave_ply(
511
- in_ply=gs_path,
512
- out_ply=out_gs,
513
- real_height=real_height,
514
- device="cpu",
515
- )
516
-
517
- # Quality check and update .urdf file.
518
- mesh_out = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}.obj" # noqa
519
- trimesh.load(mesh_out).export(mesh_out.replace(".obj", ".glb"))
520
- # image_paths = render_asset3d(
521
- # mesh_path=mesh_out,
522
- # output_root=f"{output_root}/URDF_{filename}",
523
- # output_subdir="qa_renders",
524
- # num_images=8,
525
- # elevation=(30, -30),
526
- # distance=5.5,
527
- # )
528
-
529
- image_dir = f"{output_root}/URDF_{filename}/{urdf_convertor.output_render_dir}/image_color" # noqa
530
- image_paths = glob(f"{image_dir}/*.png")
531
- images_list = []
532
- for checker in CHECKERS:
533
- images = image_paths
534
- if isinstance(checker, ImageSegChecker):
535
- images = [
536
- f"{TMP_DIR}/{req.session_hash}/raw_image.png",
537
- f"{TMP_DIR}/{req.session_hash}/seg_image.png",
538
- ]
539
- images_list.append(images)
540
-
541
- results = BaseChecker.validate(CHECKERS, images_list)
542
- urdf_convertor.add_quality_tag(urdf_path, results)
543
-
544
- # Zip urdf files
545
- urdf_zip = zip_files(
546
- input_paths=[
547
- f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}",
548
- f"{output_root}/URDF_{filename}/{filename}.urdf",
549
- ],
550
- output_zip=f"{output_root}/urdf_{filename}.zip",
551
- )
552
-
553
- estimated_type = urdf_convertor.estimated_attrs["category"]
554
- estimated_height = urdf_convertor.estimated_attrs["height"]
555
- estimated_mass = urdf_convertor.estimated_attrs["mass"]
556
- estimated_mu = urdf_convertor.estimated_attrs["mu"]
557
-
558
- return (
559
- urdf_zip,
560
- estimated_type,
561
- estimated_height,
562
- estimated_mass,
563
- estimated_mu,
564
- )
565
-
566
-
567
- @spaces.GPU(duration=300)
568
- def text2image_fn(
569
- prompt: str,
570
- guidance_scale: float,
571
- infer_step: int = 50,
572
- ip_image: Image.Image | str = None,
573
- ip_adapt_scale: float = 0.3,
574
- image_wh: int | tuple[int, int] = [1024, 1024],
575
- rmbg_tag: str = "rembg",
576
- seed: int = None,
577
- enable_pre_resize: bool = True,
578
- n_sample: int = 3,
579
- req: gr.Request = None,
580
- ):
581
- if isinstance(image_wh, int):
582
- image_wh = (image_wh, image_wh)
583
- output_root = TMP_DIR
584
- if req is not None:
585
- output_root = os.path.join(output_root, str(req.session_hash))
586
- os.makedirs(output_root, exist_ok=True)
587
-
588
- pipeline = PIPELINE_IMG if ip_image is None else PIPELINE_IMG_IP
589
- if ip_image is not None:
590
- pipeline.set_ip_adapter_scale([ip_adapt_scale])
591
-
592
- images = text2img_gen(
593
- prompt=prompt,
594
- n_sample=n_sample,
595
- guidance_scale=guidance_scale,
596
- pipeline=pipeline,
597
- ip_image=ip_image,
598
- image_wh=image_wh,
599
- infer_step=infer_step,
600
- seed=seed,
601
- )
602
-
603
- for idx in range(len(images)):
604
- image = images[idx]
605
- images[idx], _ = preprocess_image_fn(
606
- image, rmbg_tag, enable_pre_resize
607
- )
608
-
609
- save_paths = []
610
- for idx, image in enumerate(images):
611
- save_path = f"{output_root}/sample_{idx}.png"
612
- image.save(save_path)
613
- save_paths.append(save_path)
614
-
615
- logger.info(f"Images saved to {output_root}")
616
-
617
- gc.collect()
618
- torch.cuda.empty_cache()
619
-
620
- return save_paths + save_paths
621
-
622
-
623
- @spaces.GPU(duration=120)
624
- def generate_condition(mesh_path: str, req: gr.Request, uuid: str = "sample"):
625
- output_root = os.path.join(TMP_DIR, str(req.session_hash))
626
-
627
- _ = render_api(
628
- mesh_path=mesh_path,
629
- output_root=f"{output_root}/condition",
630
- uuid=str(uuid),
631
- )
632
-
633
- gc.collect()
634
- torch.cuda.empty_cache()
635
-
636
- return None, None, None
637
-
638
-
639
- @spaces.GPU(duration=300)
640
- def generate_texture_mvimages(
641
- prompt: str,
642
- controlnet_cond_scale: float = 0.55,
643
- guidance_scale: float = 9,
644
- strength: float = 0.9,
645
- num_inference_steps: int = 50,
646
- seed: int = 0,
647
- ip_adapt_scale: float = 0,
648
- ip_img_path: str = None,
649
- uid: str = "sample",
650
- sub_idxs: tuple[tuple[int]] = ((0, 1, 2), (3, 4, 5)),
651
- req: gr.Request = None,
652
- ) -> list[str]:
653
- output_root = os.path.join(TMP_DIR, str(req.session_hash))
654
- use_ip_adapter = True if ip_img_path and ip_adapt_scale > 0 else False
655
- PIPELINE_IP.set_ip_adapter_scale([ip_adapt_scale])
656
- img_save_paths = infer_pipe(
657
- index_file=f"{output_root}/condition/index.json",
658
- controlnet_cond_scale=controlnet_cond_scale,
659
- guidance_scale=guidance_scale,
660
- strength=strength,
661
- num_inference_steps=num_inference_steps,
662
- ip_adapt_scale=ip_adapt_scale,
663
- ip_img_path=ip_img_path,
664
- uid=uid,
665
- prompt=prompt,
666
- save_dir=f"{output_root}/multi_view",
667
- sub_idxs=sub_idxs,
668
- pipeline=PIPELINE_IP if use_ip_adapter else PIPELINE,
669
- seed=seed,
670
- )
671
-
672
- gc.collect()
673
- torch.cuda.empty_cache()
674
-
675
- return img_save_paths + img_save_paths
676
-
677
-
678
- def backproject_texture(
679
- mesh_path: str,
680
- input_image: str,
681
- texture_size: int,
682
- uuid: str = "sample",
683
- req: gr.Request = None,
684
- ) -> str:
685
- output_root = os.path.join(TMP_DIR, str(req.session_hash))
686
- output_dir = os.path.join(output_root, "texture_mesh")
687
- os.makedirs(output_dir, exist_ok=True)
688
- command = [
689
- "backproject-cli",
690
- "--mesh_path",
691
- mesh_path,
692
- "--input_image",
693
- input_image,
694
- "--output_root",
695
- output_dir,
696
- "--uuid",
697
- f"{uuid}",
698
- "--texture_size",
699
- str(texture_size),
700
- "--skip_fix_mesh",
701
- ]
702
-
703
- _ = subprocess.run(
704
- command, capture_output=True, text=True, encoding="utf-8"
705
- )
706
- output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
707
- output_glb_mesh = os.path.join(output_dir, f"{uuid}.glb")
708
- _ = trimesh.load(output_obj_mesh).export(output_glb_mesh)
709
-
710
- zip_file = zip_files(
711
- input_paths=[
712
- output_glb_mesh,
713
- output_obj_mesh,
714
- os.path.join(output_dir, "material.mtl"),
715
- os.path.join(output_dir, "material_0.png"),
716
- ],
717
- output_zip=os.path.join(output_dir, f"{uuid}.zip"),
718
- )
719
-
720
- gc.collect()
721
- torch.cuda.empty_cache()
722
-
723
- return output_glb_mesh, output_obj_mesh, zip_file
724
-
725
-
726
- @spaces.GPU(duration=300)
727
- def backproject_texture_v2(
728
- mesh_path: str,
729
- input_image: str,
730
- texture_size: int,
731
- enable_delight: bool = True,
732
- fix_mesh: bool = False,
733
- no_mesh_post_process: bool = False,
734
- uuid: str = "sample",
735
- req: gr.Request = None,
736
- ) -> str:
737
- output_root = os.path.join(TMP_DIR, str(req.session_hash))
738
- output_dir = os.path.join(output_root, "texture_mesh")
739
- os.makedirs(output_dir, exist_ok=True)
740
-
741
- textured_mesh = backproject_api(
742
- delight_model=DELIGHT,
743
- imagesr_model=IMAGESR_MODEL,
744
- color_path=input_image,
745
- mesh_path=mesh_path,
746
- output_path=f"{output_dir}/{uuid}.obj",
747
- skip_fix_mesh=not fix_mesh,
748
- delight=enable_delight,
749
- texture_wh=[texture_size, texture_size],
750
- no_mesh_post_process=no_mesh_post_process,
751
- )
752
-
753
- output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
754
- output_glb_mesh = os.path.join(output_dir, f"{uuid}.glb")
755
- _ = textured_mesh.export(output_glb_mesh)
756
-
757
- zip_file = zip_files(
758
- input_paths=[
759
- output_glb_mesh,
760
- output_obj_mesh,
761
- os.path.join(output_dir, "material.mtl"),
762
- os.path.join(output_dir, "material_0.png"),
763
- ],
764
- output_zip=os.path.join(output_dir, f"{uuid}.zip"),
765
- )
766
-
767
- gc.collect()
768
- torch.cuda.empty_cache()
769
-
770
- return output_glb_mesh, output_obj_mesh, zip_file
771
-
772
-
773
- @spaces.GPU(duration=120)
774
- def render_result_video(
775
- mesh_path: str, video_size: int, req: gr.Request, uuid: str = ""
776
- ) -> str:
777
- output_root = os.path.join(TMP_DIR, str(req.session_hash))
778
- output_dir = os.path.join(output_root, "texture_mesh")
779
-
780
- _ = render_api(
781
- mesh_path=mesh_path,
782
- output_root=output_dir,
783
- num_images=90,
784
- elevation=[20],
785
- with_mtl=True,
786
- pbr_light_factor=1,
787
- uuid=str(uuid),
788
- gen_color_mp4=True,
789
- gen_glonormal_mp4=True,
790
- distance=5.5,
791
- resolution_hw=(video_size, video_size),
792
- )
793
-
794
- gc.collect()
795
- torch.cuda.empty_cache()
796
-
797
- return f"{output_dir}/color.mp4"
 
15
  # permissions and limitations under the License.
16
 
17
  import spaces
 
 
 
 
18
  import gc
19
  import logging
20
  import os
 
28
  import numpy as np
29
  import torch
30
  import trimesh
31
+ from PIL import Image
32
+ from embodied_gen.data.utils import trellis_preprocess, zip_files
 
 
 
 
 
 
33
  from embodied_gen.models.segment_model import (
34
  BMGG14Remover,
35
  RembgRemover,
36
  SAMPredictor,
37
+ )
 
 
 
 
 
 
 
 
38
  from embodied_gen.utils.gpt_clients import GPT_CLIENT
39
  from embodied_gen.utils.process_media import (
40
  filter_image_small_connected_components,
41
  keep_largest_connected_component,
42
  merge_images_video,
43
  )
44
+ from embodied_gen.utils.tags import VERSION
45
+
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  logging.basicConfig(
48
  format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
 
52
  os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
53
  os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")
54
  MAX_SEED = 100000
55
+
 
 
 
56
  if os.getenv("GRADIO_APP").startswith("imageto3d"):
57
  RBG_REMOVER = RembgRemover()
58
  RBG14_REMOVER = BMGG14Remover()
59
  SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
 
 
 
 
 
 
 
 
 
 
 
60
  TMP_DIR = os.path.join(
61
  os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
62
  )
63
  os.makedirs(TMP_DIR, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
 
66
  def start_session(req: gr.Request) -> None:
 
179
 
180
  return (image, masks), seg_image
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
embodied_gen/utils/monkey_patch/sam3d.py CHANGED
@@ -40,7 +40,7 @@ def monkey_patch_sam3d():
40
  if sam3d_root not in sys.path:
41
  sys.path.insert(0, sam3d_root)
42
 
43
- def # patch_pointmap_infer_pipeline():
44
  """Patches InferencePipelinePointMap.run to handle pointmap generation and 3D structure sampling."""
45
  try:
46
  from sam3d_objects.pipeline.inference_pipeline_pointmap import (
@@ -202,7 +202,7 @@ def monkey_patch_sam3d():
202
 
203
  InferencePipelinePointMap.run = patch_run
204
 
205
- def # patch_infer_init():
206
  """Patches InferencePipeline.__init__ to allow CPU offloading during model initialization."""
207
  import torch
208
 
@@ -380,7 +380,7 @@ def monkey_patch_sam3d():
380
 
381
  InferencePipeline.__init__ = patch_init
382
 
383
- # # patch_pointmap_infer_pipeline()
384
- # # patch_infer_init()
385
 
386
  return
 
40
  if sam3d_root not in sys.path:
41
  sys.path.insert(0, sam3d_root)
42
 
43
+ def patch_pointmap_infer_pipeline():
44
  """Patches InferencePipelinePointMap.run to handle pointmap generation and 3D structure sampling."""
45
  try:
46
  from sam3d_objects.pipeline.inference_pipeline_pointmap import (
 
202
 
203
  InferencePipelinePointMap.run = patch_run
204
 
205
+ def patch_infer_init():
206
  """Patches InferencePipeline.__init__ to allow CPU offloading during model initialization."""
207
  import torch
208
 
 
380
 
381
  InferencePipeline.__init__ = patch_init
382
 
383
+ # patch_pointmap_infer_pipeline()
384
+ # patch_infer_init()
385
 
386
  return