xinjie.wang commited on
Commit
ff87343
·
1 Parent(s): 35d8d13
Files changed (3) hide show
  1. app.py +44 -17
  2. common.py +5 -4
  3. embodied_gen/utils/monkey_patch/gradio.py +41 -0
app.py CHANGED
@@ -39,19 +39,35 @@ from common import (
39
  start_session,
40
  )
41
 
 
42
  app_name = os.getenv("GRADIO_APP")
43
  if app_name == "imageto3d_sam3d":
44
- enable_pre_resize = False
45
  sample_step = 25
46
  bg_rm_model_name = "rembg" # "rembg", "rmbg14"
47
  elif app_name == "imageto3d":
48
- enable_pre_resize = True
49
  sample_step = 12
50
  bg_rm_model_name = "rembg" # "rembg", "rmbg14"
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
53
  gr.HTML(image_css, visible=False)
54
- gr.HTML(lighting_css, visible=False)
55
  gr.Markdown(
56
  """
57
  ## ***EmbodiedGen***: Image-to-3D Asset
@@ -77,7 +93,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
77
  ),
78
  elem_classes=["header"],
79
  )
80
-
81
  with gr.Row():
82
  with gr.Column(scale=3):
83
  with gr.Tabs() as input_tabs:
@@ -262,21 +278,21 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
262
  has quality inspection, open with an editor to view details.
263
  """
264
  )
265
- enable_pre_resize = gr.State(enable_pre_resize)
266
  with gr.Row() as single_image_example:
267
  examples = gr.Examples(
268
  label="Image Gallery",
269
  examples=[
270
  [image_path]
271
  for image_path in sorted(
272
- glob("apps/assets/example_image/*")
273
  )
274
  ],
275
- inputs=[image_prompt, rmbg_tag, enable_pre_resize],
276
- fn=preprocess_image_fn,
277
- outputs=[image_prompt, raw_image_cache],
278
  run_on_click=True,
279
  examples_per_page=10,
 
280
  )
281
 
282
  with gr.Row(visible=False) as single_sam_image_example:
@@ -285,7 +301,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
285
  examples=[
286
  [image_path]
287
  for image_path in sorted(
288
- glob("apps/assets/example_image/*")
289
  )
290
  ],
291
  inputs=[image_prompt_sam],
@@ -294,7 +310,6 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
294
  run_on_click=True,
295
  examples_per_page=10,
296
  )
297
-
298
  with gr.Column(scale=2):
299
  gr.Markdown("<br>")
300
  video_output = gr.Video(
@@ -340,10 +355,22 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
340
  )
341
 
342
  image_prompt.upload(
343
- preprocess_image_fn,
344
- inputs=[image_prompt, rmbg_tag, enable_pre_resize],
345
  outputs=[image_prompt, raw_image_cache],
 
 
 
 
 
346
  )
 
 
 
 
 
 
 
347
  image_prompt.change(
348
  lambda: tuple(
349
  [
@@ -382,10 +409,9 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
382
  est_mu_text,
383
  ],
384
  )
385
- image_prompt.change(
386
- active_btn_by_content,
387
- inputs=image_prompt,
388
- outputs=generate_btn,
389
  )
390
 
391
  image_prompt_sam.upload(
@@ -511,4 +537,5 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
511
 
512
 
513
  if __name__ == "__main__":
 
514
  demo.launch()
 
39
  start_session,
40
  )
41
 
42
+
43
  app_name = os.getenv("GRADIO_APP")
44
  if app_name == "imageto3d_sam3d":
45
+ _enable_pre_resize_default = False
46
  sample_step = 25
47
  bg_rm_model_name = "rembg" # "rembg", "rmbg14"
48
  elif app_name == "imageto3d":
49
+ _enable_pre_resize_default = True
50
  sample_step = 12
51
  bg_rm_model_name = "rembg" # "rembg", "rmbg14"
52
 
53
+ current_rmbg_tag = bg_rm_model_name
54
+ def set_current_rmbg_tag(rmbg: str) -> None:
55
+ global current_rmbg_tag
56
+ current_rmbg_tag = rmbg
57
+
58
+
59
+ def preprocess_example_image(
60
+ img: str,
61
+ ) -> tuple[object, object, gr.Button]:
62
+ image, image_cache = preprocess_image_fn(
63
+ img, current_rmbg_tag, _enable_pre_resize_default
64
+ )
65
+ return image, image_cache, gr.Button(interactive=True)
66
+
67
+
68
  with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
69
  gr.HTML(image_css, visible=False)
70
+ # gr.HTML(lighting_css, visible=False)
71
  gr.Markdown(
72
  """
73
  ## ***EmbodiedGen***: Image-to-3D Asset
 
93
  ),
94
  elem_classes=["header"],
95
  )
96
+ enable_pre_resize = gr.State(_enable_pre_resize_default)
97
  with gr.Row():
98
  with gr.Column(scale=3):
99
  with gr.Tabs() as input_tabs:
 
278
  has quality inspection, open with an editor to view details.
279
  """
280
  )
 
281
  with gr.Row() as single_image_example:
282
  examples = gr.Examples(
283
  label="Image Gallery",
284
  examples=[
285
  [image_path]
286
  for image_path in sorted(
287
+ glob("assets/example_image/*")
288
  )
289
  ],
290
+ inputs=[image_prompt],
291
+ fn=preprocess_example_image,
292
+ outputs=[image_prompt, raw_image_cache, generate_btn],
293
  run_on_click=True,
294
  examples_per_page=10,
295
+ cache_examples=False,
296
  )
297
 
298
  with gr.Row(visible=False) as single_sam_image_example:
 
301
  examples=[
302
  [image_path]
303
  for image_path in sorted(
304
+ glob("assets/example_image/*")
305
  )
306
  ],
307
  inputs=[image_prompt_sam],
 
310
  run_on_click=True,
311
  examples_per_page=10,
312
  )
 
313
  with gr.Column(scale=2):
314
  gr.Markdown("<br>")
315
  video_output = gr.Video(
 
355
  )
356
 
357
  image_prompt.upload(
358
+ lambda img, rmbg: preprocess_image_fn(img, rmbg, _enable_pre_resize_default),
359
+ inputs=[image_prompt, rmbg_tag],
360
  outputs=[image_prompt, raw_image_cache],
361
+ queue=False,
362
+ ).success(
363
+ active_btn_by_content,
364
+ inputs=image_prompt,
365
+ outputs=generate_btn,
366
  )
367
+
368
+ rmbg_tag.change(
369
+ set_current_rmbg_tag,
370
+ inputs=[rmbg_tag],
371
+ outputs=[],
372
+ )
373
+
374
  image_prompt.change(
375
  lambda: tuple(
376
  [
 
409
  est_mu_text,
410
  ],
411
  )
412
+ image_prompt.clear(
413
+ lambda: gr.Button(interactive=False),
414
+ outputs=[generate_btn],
 
415
  )
416
 
417
  image_prompt_sam.upload(
 
537
 
538
 
539
  if __name__ == "__main__":
540
+ # launch_demo()
541
  demo.launch()
common.py CHANGED
@@ -18,6 +18,8 @@ import spaces
18
  from embodied_gen.utils.monkey_patch.trellis import monkey_path_trellis
19
 
20
  monkey_path_trellis()
 
 
21
 
22
  import gc
23
  import logging
@@ -73,7 +75,6 @@ current_file_path = os.path.abspath(__file__)
73
  current_dir = os.path.dirname(current_file_path)
74
  sys.path.append(os.path.join(current_dir, ".."))
75
  from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
76
- from thirdparty.TRELLIS.trellis.utils import postprocessing_utils
77
 
78
  logging.basicConfig(
79
  format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
@@ -275,7 +276,7 @@ def image_to_3d(
275
  sam_image: Image.Image = None,
276
  is_sam_image: bool = False,
277
  req: gr.Request = None,
278
- ) -> tuple[dict, str]:
279
  if is_sam_image:
280
  seg_image = filter_image_small_connected_components(sam_image)
281
  seg_image = Image.fromarray(seg_image, mode="RGBA")
@@ -334,7 +335,7 @@ def image_to_3d(
334
 
335
 
336
  def extract_3d_representations_v2(
337
- state: dict,
338
  enable_delight: bool,
339
  texture_size: int,
340
  req: gr.Request,
@@ -401,7 +402,7 @@ def extract_3d_representations_v2(
401
 
402
 
403
  def extract_3d_representations_v3(
404
- state: dict,
405
  enable_delight: bool,
406
  texture_size: int,
407
  req: gr.Request,
 
18
  from embodied_gen.utils.monkey_patch.trellis import monkey_path_trellis
19
 
20
  monkey_path_trellis()
21
+ from embodied_gen.utils.monkey_patch.gradio import _patch_gradio_schema_bool_bug
22
+ _patch_gradio_schema_bool_bug()
23
 
24
  import gc
25
  import logging
 
75
  current_dir = os.path.dirname(current_file_path)
76
  sys.path.append(os.path.join(current_dir, ".."))
77
  from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
 
78
 
79
  logging.basicConfig(
80
  format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
 
276
  sam_image: Image.Image = None,
277
  is_sam_image: bool = False,
278
  req: gr.Request = None,
279
+ ) -> tuple[object, str]:
280
  if is_sam_image:
281
  seg_image = filter_image_small_connected_components(sam_image)
282
  seg_image = Image.fromarray(seg_image, mode="RGBA")
 
335
 
336
 
337
  def extract_3d_representations_v2(
338
+ state: object,
339
  enable_delight: bool,
340
  texture_size: int,
341
  req: gr.Request,
 
402
 
403
 
404
  def extract_3d_representations_v3(
405
+ state: object,
406
  enable_delight: bool,
407
  texture_size: int,
408
  req: gr.Request,
embodied_gen/utils/monkey_patch/gradio.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import gradio_client.utils as gradio_client_utils
19
+
20
+
21
+ def _patch_gradio_schema_bool_bug() -> None:
22
+ """Patch gradio_client schema parser for bool-style additionalProperties."""
23
+ original_get_type = gradio_client_utils.get_type
24
+ original_json_schema_to_python_type = (
25
+ gradio_client_utils._json_schema_to_python_type
26
+ )
27
+
28
+ def _safe_get_type(schema):
29
+ if isinstance(schema, bool):
30
+ return {}
31
+ return original_get_type(schema)
32
+
33
+ def _safe_json_schema_to_python_type(schema, defs):
34
+ if isinstance(schema, bool):
35
+ return "Any"
36
+ return original_json_schema_to_python_type(schema, defs)
37
+
38
+ gradio_client_utils.get_type = _safe_get_type
39
+ gradio_client_utils._json_schema_to_python_type = (
40
+ _safe_json_schema_to_python_type
41
+ )