xinjie.wang commited on
Commit
7e484a7
·
1 Parent(s): 74fb66c
app.bk2.py DELETED
@@ -1,473 +0,0 @@
1
- # Project EmbodiedGen
2
- #
3
- # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
- # implied. See the License for the specific language governing
15
- # permissions and limitations under the License.
16
-
17
-
18
- import os
19
-
20
- # GRADIO_APP == "imageto3d_sam3d", sam3d object model, by default.
21
- # GRADIO_APP == "imageto3d", TRELLIS model.
22
- os.environ["GRADIO_APP"] = "imageto3d_sam3d"
23
- from glob import glob
24
-
25
- import gradio as gr
26
- from app_style import custom_theme, image_css, lighting_css
27
- from common import (
28
- MAX_SEED,
29
- VERSION,
30
- active_btn_by_content,
31
- end_session,
32
- preprocess_image_fn,
33
- preprocess_sam_image_fn,
34
- select_point,
35
- start_session,
36
- )
37
-
38
- app_name = os.getenv("GRADIO_APP")
39
- if app_name == "imageto3d_sam3d":
40
- _enable_pre_resize_default = False
41
- sample_step = 25
42
- bg_rm_model_name = "rembg" # "rembg", "rmbg14"
43
- elif app_name == "imageto3d":
44
- _enable_pre_resize_default = True
45
- sample_step = 12
46
- bg_rm_model_name = "rembg" # "rembg", "rmbg14"
47
-
48
- current_rmbg_tag = bg_rm_model_name
49
- def set_current_rmbg_tag(rmbg: str) -> None:
50
- global current_rmbg_tag
51
- current_rmbg_tag = rmbg
52
-
53
-
54
- def preprocess_example_image(
55
- img: str,
56
- ) -> tuple[object, object, gr.Button]:
57
- image, image_cache = preprocess_image_fn(
58
- img, current_rmbg_tag, _enable_pre_resize_default
59
- )
60
- return image, image_cache, gr.Button(interactive=True)
61
-
62
- with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
63
- gr.HTML(image_css, visible=False)
64
- # gr.HTML(lighting_css, visible=False)
65
- gr.Markdown(
66
- """
67
- ## ***EmbodiedGen***: Image-to-3D Asset
68
- **🔖 Version**: {VERSION}
69
- <p style="display: flex; gap: 10px; flex-wrap: nowrap;">
70
- <a href="https://horizonrobotics.github.io/EmbodiedGen">
71
- <img alt="📖 Documentation" src="https://img.shields.io/badge/📖-Documentation-blue">
72
- </a>
73
- <a href="https://arxiv.org/abs/2506.10600">
74
- <img alt="📄 arXiv" src="https://img.shields.io/badge/📄-arXiv-b31b1b">
75
- </a>
76
- <a href="https://github.com/HorizonRobotics/EmbodiedGen">
77
- <img alt="💻 GitHub" src="https://img.shields.io/badge/GitHub-000000?logo=github">
78
- </a>
79
- <a href="https://www.youtube.com/watch?v=rG4odybuJRk">
80
- <img alt="🎥 Video" src="https://img.shields.io/badge/🎥-Video-red">
81
- </a>
82
- </p>
83
-
84
- 🖼️ Generate physically plausible 3D asset from single input image.
85
- """.format(
86
- VERSION=VERSION
87
- ),
88
- elem_classes=["header"],
89
- )
90
- enable_pre_resize = gr.State(_enable_pre_resize_default)
91
- with gr.Row():
92
- with gr.Column(scale=3):
93
- with gr.Tabs() as input_tabs:
94
- with gr.Tab(
95
- label="Image(auto seg)", id=0
96
- ) as single_image_input_tab:
97
- raw_image_cache = gr.Image(
98
- format="png",
99
- image_mode="RGB",
100
- type="pil",
101
- visible=False,
102
- )
103
- image_prompt = gr.Image(
104
- label="Input Image",
105
- format="png",
106
- image_mode="RGBA",
107
- type="pil",
108
- height=400,
109
- elem_classes=["image_fit"],
110
- )
111
- gr.Markdown(
112
- """
113
- If you are not satisfied with the auto segmentation
114
- result, please switch to the `Image(SAM seg)` tab."""
115
- )
116
- with gr.Tab(
117
- label="Image(SAM seg)", id=1
118
- ) as samimage_input_tab:
119
- with gr.Row():
120
- with gr.Column(scale=1):
121
- image_prompt_sam = gr.Image(
122
- label="Input Image",
123
- type="numpy",
124
- height=400,
125
- elem_classes=["image_fit"],
126
- )
127
- image_seg_sam = gr.Image(
128
- label="SAM Seg Image",
129
- image_mode="RGBA",
130
- type="pil",
131
- height=400,
132
- visible=False,
133
- )
134
- with gr.Column(scale=1):
135
- image_mask_sam = gr.AnnotatedImage(
136
- elem_classes=["image_fit"]
137
- )
138
-
139
- fg_bg_radio = gr.Radio(
140
- ["foreground_point", "background_point"],
141
- label="Select foreground(green) or background(red) points, by default foreground", # noqa
142
- value="foreground_point",
143
- )
144
- gr.Markdown(
145
- """ Click the `Input Image` to select SAM points,
146
- after get the satisified segmentation, click `Generate`
147
- button to generate the 3D asset. \n
148
- Note: If the segmented foreground is too small relative
149
- to the entire image area, the generation will fail.
150
- """
151
- )
152
-
153
- with gr.Accordion(label="Generation Settings", open=False):
154
- with gr.Row():
155
- seed = gr.Slider(
156
- 0, MAX_SEED, label="Seed", value=0, step=1
157
- )
158
- texture_size = gr.Slider(
159
- 1024,
160
- 4096,
161
- label="UV texture size",
162
- value=2048,
163
- step=256,
164
- )
165
- rmbg_tag = gr.Radio(
166
- choices=["rembg", "rmbg14"],
167
- value=bg_rm_model_name,
168
- label="Background Removal Model",
169
- )
170
- with gr.Row():
171
- randomize_seed = gr.Checkbox(
172
- label="Randomize Seed", value=False
173
- )
174
- project_delight = gr.Checkbox(
175
- label="Back-project Delight",
176
- value=True,
177
- )
178
- gr.Markdown("Geo Structure Generation")
179
- with gr.Row():
180
- ss_guidance_strength = gr.Slider(
181
- 0.0,
182
- 10.0,
183
- label="Guidance Strength",
184
- value=7.5,
185
- step=0.1,
186
- )
187
- ss_sampling_steps = gr.Slider(
188
- 1,
189
- 50,
190
- label="Sampling Steps",
191
- value=sample_step,
192
- step=1,
193
- )
194
- gr.Markdown("Visual Appearance Generation")
195
- with gr.Row():
196
- slat_guidance_strength = gr.Slider(
197
- 0.0,
198
- 10.0,
199
- label="Guidance Strength",
200
- value=3.0,
201
- step=0.1,
202
- )
203
- slat_sampling_steps = gr.Slider(
204
- 1,
205
- 50,
206
- label="Sampling Steps",
207
- value=sample_step,
208
- step=1,
209
- )
210
-
211
- generate_btn = gr.Button(
212
- "🚀 1. Generate(~2 mins)",
213
- variant="primary",
214
- interactive=False,
215
- )
216
- model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
217
- # with gr.Row():
218
- # extract_rep3d_btn = gr.Button(
219
- # "🔍 2. Extract 3D Representation(~2 mins)",
220
- # variant="primary",
221
- # interactive=False,
222
- # )
223
- with gr.Accordion(
224
- label="Enter Asset Attributes(optional)", open=False
225
- ):
226
- asset_cat_text = gr.Textbox(
227
- label="Enter Asset Category (e.g., chair)"
228
- )
229
- height_range_text = gr.Textbox(
230
- label="Enter **Height Range** in meter (e.g., 0.5-0.6)"
231
- )
232
- mass_range_text = gr.Textbox(
233
- label="Enter **Mass Range** in kg (e.g., 1.1-1.2)"
234
- )
235
- asset_version_text = gr.Textbox(
236
- label=f"Enter version (e.g., {VERSION})"
237
- )
238
- with gr.Row():
239
- extract_urdf_btn = gr.Button(
240
- "🧩 2. Extract URDF with physics(~1 mins)",
241
- variant="primary",
242
- interactive=False,
243
- )
244
- with gr.Row():
245
- gr.Markdown(
246
- "#### Estimated Asset 3D Attributes(No input required)"
247
- )
248
- with gr.Row():
249
- est_type_text = gr.Textbox(
250
- label="Asset category", interactive=False
251
- )
252
- est_height_text = gr.Textbox(
253
- label="Real height(.m)", interactive=False
254
- )
255
- est_mass_text = gr.Textbox(
256
- label="Mass(.kg)", interactive=False
257
- )
258
- est_mu_text = gr.Textbox(
259
- label="Friction coefficient", interactive=False
260
- )
261
- with gr.Row():
262
- download_urdf = gr.DownloadButton(
263
- label="⬇️ 3. Download URDF",
264
- variant="primary",
265
- interactive=False,
266
- )
267
-
268
- gr.Markdown(
269
- """ NOTE: If `Asset Attributes` are provided, it will guide
270
- GPT to perform physical attributes restoration. \n
271
- The `Download URDF` file is restored to the real scale and
272
- has quality inspection, open with an editor to view details.
273
- """
274
- )
275
- with gr.Row() as single_image_example:
276
- examples = gr.Examples(
277
- label="Image Gallery",
278
- examples=[
279
- [image_path]
280
- for image_path in sorted(
281
- glob("assets/example_image/*")
282
- )
283
- ],
284
- inputs=[image_prompt],
285
- fn=preprocess_example_image,
286
- outputs=[image_prompt, raw_image_cache, generate_btn],
287
- run_on_click=True,
288
- examples_per_page=10,
289
- cache_examples=False,
290
- )
291
-
292
- with gr.Row(visible=False) as single_sam_image_example:
293
- examples = gr.Examples(
294
- label="Image Gallery",
295
- examples=[
296
- [image_path]
297
- for image_path in sorted(
298
- glob("assets/example_image/*")
299
- )
300
- ],
301
- inputs=[image_prompt_sam],
302
- fn=preprocess_sam_image_fn,
303
- outputs=[image_prompt_sam, raw_image_cache],
304
- run_on_click=True,
305
- examples_per_page=10,
306
- )
307
- with gr.Column(scale=2):
308
- gr.Markdown("<br>")
309
- video_output = gr.Video(
310
- label="Generated 3D Asset",
311
- autoplay=True,
312
- loop=True,
313
- height=400,
314
- )
315
- model_output_gs = gr.Model3D(
316
- label="Gaussian Representation", height=350, interactive=False
317
- )
318
- aligned_gs = gr.Textbox(visible=False)
319
- gr.Markdown(
320
- """ The rendering of `Gaussian Representation` takes additional 10s. """ # noqa
321
- )
322
- with gr.Row():
323
- model_output_mesh = gr.Model3D(
324
- label="Mesh Representation",
325
- height=350,
326
- interactive=False,
327
- clear_color=[0, 0, 0, 1],
328
- elem_id="lighter_mesh",
329
- )
330
-
331
- is_samimage = gr.State(False)
332
- output_buf = gr.State()
333
- selected_points = gr.State(value=[])
334
-
335
- demo.load(start_session)
336
- demo.unload(end_session)
337
-
338
- single_image_input_tab.select(
339
- lambda: tuple(
340
- [False, gr.Row.update(visible=True), gr.Row.update(visible=False)]
341
- ),
342
- outputs=[is_samimage, single_image_example, single_sam_image_example],
343
- )
344
- samimage_input_tab.select(
345
- lambda: tuple(
346
- [True, gr.Row.update(visible=True), gr.Row.update(visible=False)]
347
- ),
348
- outputs=[is_samimage, single_sam_image_example, single_image_example],
349
- )
350
-
351
- image_prompt.upload(
352
- lambda img, rmbg: preprocess_image_fn(img, rmbg, _enable_pre_resize_default),
353
- inputs=[image_prompt, rmbg_tag],
354
- outputs=[image_prompt, raw_image_cache],
355
- queue=False,
356
- ).success(
357
- active_btn_by_content,
358
- inputs=image_prompt,
359
- outputs=generate_btn,
360
- )
361
- rmbg_tag.change(
362
- set_current_rmbg_tag,
363
- inputs=[rmbg_tag],
364
- outputs=[],
365
- )
366
-
367
- image_prompt.change(
368
- lambda: tuple(
369
- [
370
- # gr.Button(interactive=False),
371
- gr.Button(interactive=False),
372
- gr.Button(interactive=False),
373
- None,
374
- "",
375
- None,
376
- None,
377
- "",
378
- "",
379
- "",
380
- "",
381
- "",
382
- "",
383
- "",
384
- "",
385
- ]
386
- ),
387
- outputs=[
388
- # extract_rep3d_btn,
389
- extract_urdf_btn,
390
- download_urdf,
391
- model_output_gs,
392
- aligned_gs,
393
- model_output_mesh,
394
- video_output,
395
- asset_cat_text,
396
- height_range_text,
397
- mass_range_text,
398
- asset_version_text,
399
- est_type_text,
400
- est_height_text,
401
- est_mass_text,
402
- est_mu_text,
403
- ],
404
- )
405
- image_prompt.clear(
406
- lambda: gr.Button(interactive=False),
407
- outputs=[generate_btn],
408
- )
409
-
410
- image_prompt_sam.upload(
411
- preprocess_sam_image_fn,
412
- inputs=[image_prompt_sam],
413
- outputs=[image_prompt_sam, raw_image_cache],
414
- )
415
- image_prompt_sam.change(
416
- lambda: tuple(
417
- [
418
- # gr.Button(interactive=False),
419
- gr.Button(interactive=False),
420
- gr.Button(interactive=False),
421
- None,
422
- None,
423
- None,
424
- "",
425
- "",
426
- "",
427
- "",
428
- "",
429
- "",
430
- "",
431
- "",
432
- None,
433
- [],
434
- ]
435
- ),
436
- outputs=[
437
- # extract_rep3d_btn,
438
- extract_urdf_btn,
439
- download_urdf,
440
- model_output_gs,
441
- model_output_mesh,
442
- video_output,
443
- asset_cat_text,
444
- height_range_text,
445
- mass_range_text,
446
- asset_version_text,
447
- est_type_text,
448
- est_height_text,
449
- est_mass_text,
450
- est_mu_text,
451
- image_mask_sam,
452
- selected_points,
453
- ],
454
- )
455
-
456
- image_prompt_sam.select(
457
- select_point,
458
- [
459
- image_prompt_sam,
460
- selected_points,
461
- fg_bg_radio,
462
- ],
463
- [image_mask_sam, image_seg_sam],
464
- )
465
- image_seg_sam.change(
466
- active_btn_by_content,
467
- inputs=image_seg_sam,
468
- outputs=generate_btn,
469
- )
470
-
471
-
472
- if __name__ == "__main__":
473
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -19,7 +19,6 @@ import os
19
 
20
  # GRADIO_APP == "imageto3d_sam3d", sam3d object model, by default.
21
  # GRADIO_APP == "imageto3d", TRELLIS model.
22
- # os.environ["GRADIO_APP"] = "imageto3d_sam3d"
23
  os.environ["GRADIO_APP"] = "imageto3d"
24
  from glob import glob
25
 
 
19
 
20
  # GRADIO_APP == "imageto3d_sam3d", sam3d object model, by default.
21
  # GRADIO_APP == "imageto3d", TRELLIS model.
 
22
  os.environ["GRADIO_APP"] = "imageto3d"
23
  from glob import glob
24
 
common.bk2.py DELETED
@@ -1,181 +0,0 @@
1
- # Project EmbodiedGen
2
- #
3
- # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
- # implied. See the License for the specific language governing
15
- # permissions and limitations under the License.
16
-
17
- import spaces
18
- import gc
19
- import logging
20
- import os
21
- import shutil
22
- import subprocess
23
- import sys
24
- from glob import glob
25
-
26
- import cv2
27
- import gradio as gr
28
- import numpy as np
29
- import torch
30
- import trimesh
31
- from PIL import Image
32
- from embodied_gen.data.utils import trellis_preprocess, zip_files
33
- from embodied_gen.models.segment_model import (
34
- BMGG14Remover,
35
- RembgRemover,
36
- SAMPredictor,
37
- )
38
- from embodied_gen.utils.gpt_clients import GPT_CLIENT
39
- from embodied_gen.utils.process_media import (
40
- filter_image_small_connected_components,
41
- keep_largest_connected_component,
42
- merge_images_video,
43
- )
44
- from embodied_gen.utils.tags import VERSION
45
-
46
-
47
- logging.basicConfig(
48
- format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
49
- )
50
- logger = logging.getLogger(__name__)
51
-
52
- os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
53
- os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")
54
- MAX_SEED = 100000
55
-
56
- if os.getenv("GRADIO_APP").startswith("imageto3d"):
57
- RBG_REMOVER = RembgRemover()
58
- RBG14_REMOVER = BMGG14Remover()
59
- SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cuda")
60
- TMP_DIR = os.path.join(
61
- os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
62
- )
63
- os.makedirs(TMP_DIR, exist_ok=True)
64
-
65
-
66
- def start_session(req: gr.Request) -> None:
67
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
68
- os.makedirs(user_dir, exist_ok=True)
69
-
70
-
71
- def end_session(req: gr.Request) -> None:
72
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
73
- if os.path.exists(user_dir):
74
- shutil.rmtree(user_dir)
75
-
76
- @spaces.GPU()
77
- def preprocess_image_fn(
78
- image: str | np.ndarray | Image.Image,
79
- rmbg_tag: str = "rembg",
80
- preprocess: bool = True,
81
- ) -> tuple[Image.Image, Image.Image]:
82
- if isinstance(image, str):
83
- image = Image.open(image)
84
- elif isinstance(image, np.ndarray):
85
- image = Image.fromarray(image)
86
-
87
- image_cache = image.copy() # resize_pil(image.copy(), 1024)
88
-
89
- bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER
90
- image = bg_remover(image)
91
- image = keep_largest_connected_component(image)
92
-
93
- if preprocess:
94
- image = trellis_preprocess(image)
95
-
96
- return image, image_cache
97
-
98
-
99
- def preprocess_sam_image_fn(
100
- image: Image.Image,
101
- ) -> tuple[Image.Image, Image.Image]:
102
- if isinstance(image, np.ndarray):
103
- image = Image.fromarray(image)
104
-
105
- sam_image = SAM_PREDICTOR.preprocess_image(image)
106
- image_cache = sam_image.copy()
107
- SAM_PREDICTOR.predictor.set_image(sam_image)
108
-
109
- return sam_image, image_cache
110
-
111
-
112
- def active_btn_by_content(content: gr.Image) -> gr.Button:
113
- interactive = True if content is not None else False
114
-
115
- return gr.Button(interactive=interactive)
116
-
117
-
118
- def active_btn_by_text_content(content: gr.Textbox) -> gr.Button:
119
- if content is not None and len(content) > 0:
120
- interactive = True
121
- else:
122
- interactive = False
123
-
124
- return gr.Button(interactive=interactive)
125
-
126
-
127
- def get_selected_image(
128
- choice: str, sample1: str, sample2: str, sample3: str
129
- ) -> str:
130
- if choice == "sample1":
131
- return sample1
132
- elif choice == "sample2":
133
- return sample2
134
- elif choice == "sample3":
135
- return sample3
136
- else:
137
- raise ValueError(f"Invalid choice: {choice}")
138
-
139
-
140
- def get_cached_image(image_path: str) -> Image.Image:
141
- if isinstance(image_path, Image.Image):
142
- return image_path
143
- return Image.open(image_path).resize((512, 512))
144
-
145
-
146
- def get_seed(randomize_seed: bool, seed: int, max_seed: int = MAX_SEED) -> int:
147
- return np.random.randint(0, max_seed) if randomize_seed else seed
148
-
149
-
150
- def select_point(
151
- image: np.ndarray,
152
- sel_pix: list,
153
- point_type: str,
154
- evt: gr.SelectData,
155
- ):
156
- if point_type == "foreground_point":
157
- sel_pix.append((evt.index, 1)) # append the foreground_point
158
- elif point_type == "background_point":
159
- sel_pix.append((evt.index, 0)) # append the background_point
160
- else:
161
- sel_pix.append((evt.index, 1)) # default foreground_point
162
-
163
- masks = SAM_PREDICTOR.generate_masks(image, sel_pix)
164
- seg_image = SAM_PREDICTOR.get_segmented_image(image, masks)
165
-
166
- for point, label in sel_pix:
167
- color = (255, 0, 0) if label == 0 else (0, 255, 0)
168
- marker_type = 1 if label == 0 else 5
169
- cv2.drawMarker(
170
- image,
171
- point,
172
- color,
173
- markerType=marker_type,
174
- markerSize=15,
175
- thickness=10,
176
- )
177
-
178
- torch.cuda.empty_cache()
179
-
180
- return (image, masks), seg_image
181
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
common.py CHANGED
@@ -18,8 +18,9 @@ import spaces
18
  from embodied_gen.utils.monkey_patch.trellis import monkey_path_trellis
19
 
20
  monkey_path_trellis()
21
- from embodied_gen.utils.monkey_patch.gradio import _patch_gradio_schema_bool_bug
22
  _patch_gradio_schema_bool_bug()
 
23
 
24
  import gc
25
  import logging
@@ -41,7 +42,7 @@ from embodied_gen.data.differentiable_render import entrypoint as render_api
41
  from embodied_gen.data.utils import trellis_preprocess, zip_files
42
  from embodied_gen.models.delight_model import DelightingModel
43
  from embodied_gen.models.gs_model import GaussianOperator
44
- # from embodied_gen.models.sam3d import Sam3dInference
45
  from embodied_gen.models.segment_model import (
46
  BMGG14Remover,
47
  RembgRemover,
@@ -92,13 +93,13 @@ if os.getenv("GRADIO_APP").startswith("imageto3d"):
92
  RBG_REMOVER = RembgRemover()
93
  RBG14_REMOVER = BMGG14Remover()
94
  SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
95
- # if "sam3d" in os.getenv("GRADIO_APP"):
96
- # PIPELINE = Sam3dInference()
97
- # else:
98
- PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
99
- "microsoft/TRELLIS-image-large"
100
- )
101
- # PIPELINE.cuda()
102
  SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
103
  GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
104
  AESTHETIC_CHECKER = ImageAestheticChecker()
@@ -107,44 +108,44 @@ if os.getenv("GRADIO_APP").startswith("imageto3d"):
107
  os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
108
  )
109
  os.makedirs(TMP_DIR, exist_ok=True)
110
- # elif os.getenv("GRADIO_APP").startswith("textto3d"):
111
- # RBG_REMOVER = RembgRemover()
112
- # RBG14_REMOVER = BMGG14Remover()
113
- # if "sam3d" in os.getenv("GRADIO_APP"):
114
- # PIPELINE = Sam3dInference()
115
- # else:
116
- # PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
117
- # "microsoft/TRELLIS-image-large"
118
- # )
119
- # # PIPELINE.cuda()
120
- # text_model_dir = "weights/Kolors"
121
- # PIPELINE_IMG_IP = build_text2img_ip_pipeline(text_model_dir, ref_scale=0.3)
122
- # PIPELINE_IMG = build_text2img_pipeline(text_model_dir)
123
- # SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
124
- # GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
125
- # AESTHETIC_CHECKER = ImageAestheticChecker()
126
- # CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
127
- # TMP_DIR = os.path.join(
128
- # os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
129
- # )
130
- # os.makedirs(TMP_DIR, exist_ok=True)
131
- # elif os.getenv("GRADIO_APP") == "texture_edit":
132
- # DELIGHT = DelightingModel()
133
- # IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
134
- # PIPELINE_IP = build_texture_gen_pipe(
135
- # base_ckpt_dir="./weights",
136
- # ip_adapt_scale=0.7,
137
- # device="cuda",
138
- # )
139
- # PIPELINE = build_texture_gen_pipe(
140
- # base_ckpt_dir="./weights",
141
- # ip_adapt_scale=0,
142
- # device="cuda",
143
- # )
144
- # TMP_DIR = os.path.join(
145
- # os.path.dirname(os.path.abspath(__file__)), "sessions/texture_edit"
146
- # )
147
- # os.makedirs(TMP_DIR, exist_ok=True)
148
 
149
 
150
  def start_session(req: gr.Request) -> None:
@@ -287,32 +288,32 @@ def image_to_3d(
287
  seg_image = Image.fromarray(seg_image)
288
 
289
  logger.info("Start generating 3D representation from image...")
290
- # if isinstance(PIPELINE, Sam3dInference):
291
- # outputs = PIPELINE.run(
292
- # seg_image,
293
- # seed=seed,
294
- # stage1_inference_steps=ss_sampling_steps,
295
- # stage2_inference_steps=slat_sampling_steps,
296
- # )
297
- # else:
298
- PIPELINE.cuda()
299
- seg_image = trellis_preprocess(seg_image)
300
- outputs = PIPELINE.run(
301
- seg_image,
302
- seed=seed,
303
- formats=["gaussian", "mesh"],
304
- preprocess_image=False,
305
- sparse_structure_sampler_params={
306
- "steps": ss_sampling_steps,
307
- "cfg_strength": ss_guidance_strength,
308
- },
309
- slat_sampler_params={
310
- "steps": slat_sampling_steps,
311
- "cfg_strength": slat_guidance_strength,
312
- },
313
- )
314
- # Set back to cpu for memory saving.
315
- PIPELINE.cpu()
316
 
317
  gs_model = outputs["gaussian"][0]
318
  mesh_model = outputs["mesh"][0]
 
18
  from embodied_gen.utils.monkey_patch.trellis import monkey_path_trellis
19
 
20
  monkey_path_trellis()
21
+ from embodied_gen.utils.monkey_patch.gradio import _patch_gradio_schema_bool_bug, _patch_open3d_cuda_device_count_bug
22
  _patch_gradio_schema_bool_bug()
23
+ _patch_open3d_cuda_device_count_bug()
24
 
25
  import gc
26
  import logging
 
42
  from embodied_gen.data.utils import trellis_preprocess, zip_files
43
  from embodied_gen.models.delight_model import DelightingModel
44
  from embodied_gen.models.gs_model import GaussianOperator
45
+ from embodied_gen.models.sam3d import Sam3dInference
46
  from embodied_gen.models.segment_model import (
47
  BMGG14Remover,
48
  RembgRemover,
 
93
  RBG_REMOVER = RembgRemover()
94
  RBG14_REMOVER = BMGG14Remover()
95
  SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
96
+ if "sam3d" in os.getenv("GRADIO_APP"):
97
+ PIPELINE = Sam3dInference()
98
+ else:
99
+ PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
100
+ "microsoft/TRELLIS-image-large"
101
+ )
102
+ # PIPELINE.cuda()
103
  SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
104
  GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
105
  AESTHETIC_CHECKER = ImageAestheticChecker()
 
108
  os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
109
  )
110
  os.makedirs(TMP_DIR, exist_ok=True)
111
+ elif os.getenv("GRADIO_APP").startswith("textto3d"):
112
+ RBG_REMOVER = RembgRemover()
113
+ RBG14_REMOVER = BMGG14Remover()
114
+ if "sam3d" in os.getenv("GRADIO_APP"):
115
+ PIPELINE = Sam3dInference()
116
+ else:
117
+ PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
118
+ "microsoft/TRELLIS-image-large"
119
+ )
120
+ # PIPELINE.cuda()
121
+ text_model_dir = "weights/Kolors"
122
+ PIPELINE_IMG_IP = build_text2img_ip_pipeline(text_model_dir, ref_scale=0.3)
123
+ PIPELINE_IMG = build_text2img_pipeline(text_model_dir)
124
+ SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
125
+ GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
126
+ AESTHETIC_CHECKER = ImageAestheticChecker()
127
+ CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
128
+ TMP_DIR = os.path.join(
129
+ os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
130
+ )
131
+ os.makedirs(TMP_DIR, exist_ok=True)
132
+ elif os.getenv("GRADIO_APP") == "texture_edit":
133
+ DELIGHT = DelightingModel()
134
+ IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
135
+ PIPELINE_IP = build_texture_gen_pipe(
136
+ base_ckpt_dir="./weights",
137
+ ip_adapt_scale=0.7,
138
+ device="cuda",
139
+ )
140
+ PIPELINE = build_texture_gen_pipe(
141
+ base_ckpt_dir="./weights",
142
+ ip_adapt_scale=0,
143
+ device="cuda",
144
+ )
145
+ TMP_DIR = os.path.join(
146
+ os.path.dirname(os.path.abspath(__file__)), "sessions/texture_edit"
147
+ )
148
+ os.makedirs(TMP_DIR, exist_ok=True)
149
 
150
 
151
  def start_session(req: gr.Request) -> None:
 
288
  seg_image = Image.fromarray(seg_image)
289
 
290
  logger.info("Start generating 3D representation from image...")
291
+ if isinstance(PIPELINE, Sam3dInference):
292
+ outputs = PIPELINE.run(
293
+ seg_image,
294
+ seed=seed,
295
+ stage1_inference_steps=ss_sampling_steps,
296
+ stage2_inference_steps=slat_sampling_steps,
297
+ )
298
+ else:
299
+ PIPELINE.cuda()
300
+ seg_image = trellis_preprocess(seg_image)
301
+ outputs = PIPELINE.run(
302
+ seg_image,
303
+ seed=seed,
304
+ formats=["gaussian", "mesh"],
305
+ preprocess_image=False,
306
+ sparse_structure_sampler_params={
307
+ "steps": ss_sampling_steps,
308
+ "cfg_strength": ss_guidance_strength,
309
+ },
310
+ slat_sampler_params={
311
+ "steps": slat_sampling_steps,
312
+ "cfg_strength": slat_guidance_strength,
313
+ },
314
+ )
315
+ # Set back to cpu for memory saving.
316
+ PIPELINE.cpu()
317
 
318
  gs_model = outputs["gaussian"][0]
319
  mesh_model = outputs["mesh"][0]
embodied_gen/utils/monkey_patch/gradio.py CHANGED
@@ -16,7 +16,9 @@
16
 
17
 
18
  import gradio_client.utils as gradio_client_utils
19
-
 
 
20
 
21
  def _patch_gradio_schema_bool_bug() -> None:
22
  """Patch gradio_client schema parser for bool-style additionalProperties."""
@@ -38,4 +40,19 @@ def _patch_gradio_schema_bool_bug() -> None:
38
  gradio_client_utils.get_type = _safe_get_type
39
  gradio_client_utils._json_schema_to_python_type = (
40
  _safe_json_schema_to_python_type
41
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
  import gradio_client.utils as gradio_client_utils
19
+ import fileinput
20
+ import site
21
+ from pathlib import Path
22
 
23
  def _patch_gradio_schema_bool_bug() -> None:
24
  """Patch gradio_client schema parser for bool-style additionalProperties."""
 
40
  gradio_client_utils.get_type = _safe_get_type
41
  gradio_client_utils._json_schema_to_python_type = (
42
  _safe_json_schema_to_python_type
43
+ )
44
+
45
+
46
+ def _patch_open3d_cuda_device_count_bug() -> None:
47
+ """Patch open3d to avoid cuda device count bug."""
48
+ with fileinput.FileInput(
49
+ f'{site.getsitepackages()[0]}/open3d/__init__.py', inplace=True
50
+ ) as file:
51
+ for line in file:
52
+ print(
53
+ line.replace(
54
+ '_pybind_cuda.open3d_core_cuda_device_count()',
55
+ '1'
56
+ ),
57
+ end=''
58
+ )
embodied_gen/utils/monkey_patch/trellis.py CHANGED
@@ -37,7 +37,7 @@ def monkey_path_trellis():
37
  os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
38
  "~/.cache/torch_extensions"
39
  )
40
- os.environ["SPCONV_ALGO"] = "auto" # Can be 'native' or 'auto'
41
  os.environ['ATTN_BACKEND'] = (
42
  "xformers" # Can be 'flash-attn' or 'xformers'
43
  )
 
37
  os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
38
  "~/.cache/torch_extensions"
39
  )
40
+ os.environ["SPCONV_ALGO"] = "native" # Can be 'native' or 'auto'
41
  os.environ['ATTN_BACKEND'] = (
42
  "xformers" # Can be 'flash-attn' or 'xformers'
43
  )
requirements.txt CHANGED
@@ -56,12 +56,12 @@ seaborn
56
  hydra-core
57
  modelscope
58
  timm
59
- # open3d
60
  MoGe@git+https://github.com/microsoft/MoGe.git@a8c3734
61
 
62
 
63
  # git+https://github.com/facebookresearch/pytorch3d.git@stable
64
- # https://huggingface.co/xinjjj/RoboAssetGen/resolve/main/wheel_cu121/pytorch3d-0.7.8-cp310-cp310-linux_x86_64.whl
65
  # git+https://github.com/nerfstudio-project/gsplat.git@v1.5.3
66
  https://github.com/nerfstudio-project/gsplat/releases/download/v1.5.0/gsplat-1.5.0+pt24cu121-cp310-cp310-linux_x86_64.whl
67
  # flash-attn==2.7.0.post2
 
56
  hydra-core
57
  modelscope
58
  timm
59
+ open3d
60
  MoGe@git+https://github.com/microsoft/MoGe.git@a8c3734
61
 
62
 
63
  # git+https://github.com/facebookresearch/pytorch3d.git@stable
64
+ https://huggingface.co/xinjjj/RoboAssetGen/resolve/main/wheel_cu121/pytorch3d-0.7.8-cp310-cp310-linux_x86_64.whl
65
  # git+https://github.com/nerfstudio-project/gsplat.git@v1.5.3
66
  https://github.com/nerfstudio-project/gsplat/releases/download/v1.5.0/gsplat-1.5.0+pt24cu121-cp310-cp310-linux_x86_64.whl
67
  # flash-attn==2.7.0.post2