Edoruin committed on
Commit
764c047
·
verified ·
1 Parent(s): f4e3326

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -98
app.py CHANGED
@@ -4,93 +4,40 @@ import shlex
4
  import subprocess
5
  import tempfile
6
  import time
7
-
8
- import gradio as gr
 
9
  import numpy as np
10
  import rembg
11
  import spaces
12
- import torch
13
  from PIL import Image
14
  from functools import partial
15
-
16
- subprocess.run(shlex.split('pip install wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl'))
17
-
18
- from tsr.system import TSR
19
- from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
20
-
21
- from torchmcubes import marching_cubes
22
- import sys
23
- import types
24
- import torch
25
 
26
- # 1. try import mcubes
27
  try:
28
  import mcubes
 
 
 
 
 
 
29
  except ImportError:
30
  print("Error: PyMCubes no está en requirements.txt")
31
 
32
- # 2.parched mock
33
- # false library for run
34
- mock_torchmcubes = types.ModuleType("torchmcubes")
35
-
36
- def marching_cubes_cpu(vertices, threshold):
37
- # turning torch to numpy
38
- v, f = mcubes.marching_cubes(vertices.detach().cpu().numpy(), threshold)
39
- return torch.from_numpy(v.astype("float32")), torch.from_numpy(f.astype("int64"))
40
-
41
- # send to function the false library
42
- mock_torchmcubes.marching_cubes = marching_cubes_cpu
43
- # register for TRIPOSD found the false module
44
- sys.modules["torchmcubes"] = mock_torchmcubes
45
-
46
-
47
-
48
-
49
-
50
- HEADER = """
51
- # TripoSR Demo
52
- <table bgcolor="#1E2432" cellspacing="0" cellpadding="0" width="450">
53
- <tr style="height:50px;">
54
- <td style="text-align: center;">
55
- <a href="https://stability.ai">
56
- <img src="https://images.squarespace-cdn.com/content/v1/6213c340453c3f502425776e/6c9c4c25-5410-4547-bc26-dc621cdacb25/Stability+AI+logo.png" width="200" height="40" />
57
- </a>
58
- </td>
59
- <td style="text-align: center;">
60
- <a href="https://www.tripo3d.ai">
61
- <img src="https://tripo-public.cdn.bcebos.com/logo.png" width="40" height="40" />
62
- </a>
63
- </td>
64
- </tr>
65
- </table>
66
- <table bgcolor="#1E2432" cellspacing="0" cellpadding="0" width="450">
67
- <tr style="height:30px;">
68
- <td style="text-align: center;">
69
- <a href="https://huggingface.co/stabilityai/TripoSR"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange" height="20"></a>
70
- </td>
71
- <td style="text-align: center;">
72
- <a href="https://github.com/VAST-AI-Research/TripoSR"><img src="https://postimage.me/images/2024/03/04/GitHub_Logo_White.png" width="100" height="20"></a>
73
- </td>
74
- <td style="text-align: center; color: white;">
75
- <a href="https://arxiv.org/abs/2403.02151"><img src="https://img.shields.io/badge/arXiv-2403.02151-b31b1b.svg" height="20"></a>
76
- </td>
77
- </tr>
78
- </table>
79
 
 
 
 
80
  > Try our new model: **SF3D** with several improvements such as faster generation and more game-ready assets.
81
- >
82
- > The model is available [here](https://huggingface.co/stabilityai/stable-fast-3d) and we also have a [demo](https://huggingface.co/spaces/stabilityai/stable-fast-3d).
83
-
84
-
85
- **TripoSR** is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, developed in collaboration between [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
86
-
87
- **Tips:**
88
- 1. If you find the result is unsatisfied, please try to change the foreground ratio. It might improve the results.
89
- 2. It's better to disable "Remove Background" for the provided examples since they have been already preprocessed.
90
- 3. Otherwise, please disable "Remove Background" option only if your input image is RGBA with transparent background, image contents are centered and occupy more than 70% of image width or height.
91
  """
92
 
93
-
94
  if torch.cuda.is_available():
95
  device = "cuda:0"
96
  else:
@@ -103,22 +50,18 @@ model = TSR.from_pretrained(
103
  )
104
  model.renderer.set_chunk_size(131072)
105
  model.to(device)
106
-
107
  rembg_session = rembg.new_session()
108
 
109
-
110
  def check_input_image(input_image):
111
  if input_image is None:
112
  raise gr.Error("No image uploaded!")
113
 
114
-
115
  def preprocess(input_image, do_remove_background, foreground_ratio):
116
  def fill_background(image):
117
  image = np.array(image).astype(np.float32) / 255.0
118
  image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
119
  image = Image.fromarray((image * 255.0).astype(np.uint8))
120
  return image
121
-
122
  if do_remove_background:
123
  image = input_image.convert("RGB")
124
  image = remove_background(image, rembg_session)
@@ -130,20 +73,16 @@ def preprocess(input_image, do_remove_background, foreground_ratio):
130
  image = fill_background(image)
131
  return image
132
 
133
-
134
  @spaces.GPU
135
  def generate(image, mc_resolution, formats=["obj", "glb"]):
136
  scene_codes = model(image, device=device)
137
  mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
138
  mesh = to_gradio_3d_orientation(mesh)
139
-
140
  mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f".glb", delete=False)
141
  mesh.export(mesh_path_glb.name)
142
-
143
  mesh_path_obj = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False)
144
- mesh.apply_scale([-1, 1, 1]) # Otherwise the visualized .obj will be flipped
145
  mesh.export(mesh_path_obj.name)
146
-
147
  return mesh_path_obj.name, mesh_path_glb.name
148
 
149
  def run_example(image_pil):
@@ -181,8 +120,8 @@ with gr.Blocks() as demo:
181
  minimum=32,
182
  maximum=320,
183
  value=256,
184
- step=32
185
- )
186
  with gr.Row():
187
  submit = gr.Button("Generate", elem_id="generate", variant="primary")
188
  with gr.Column():
@@ -191,25 +130,26 @@ with gr.Blocks() as demo:
191
  label="Output Model (OBJ Format)",
192
  interactive=False,
193
  )
194
- gr.Markdown("Note: Downloaded object will be flipped in case of .obj export. Export .glb instead or manually flip it before usage.")
195
  with gr.Tab("GLB"):
196
  output_model_glb = gr.Model3D(
197
  label="Output Model (GLB Format)",
198
  interactive=False,
199
  )
200
- gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
201
- with gr.Row(variant="panel"):
202
- gr.Examples(
203
- examples=[
204
- os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
205
- ],
206
- inputs=[input_image],
207
- outputs=[processed_image, output_model_obj, output_model_glb],
208
- cache_examples=True,
209
- fn=partial(run_example),
210
- label="Examples",
211
- examples_per_page=20
212
- )
 
 
213
  submit.click(fn=check_input_image, inputs=[input_image]).success(
214
  fn=preprocess,
215
  inputs=[input_image, do_remove_background, foreground_ratio],
@@ -221,4 +161,4 @@ with gr.Blocks() as demo:
221
  )
222
 
223
  demo.queue(max_size=10)
224
- demo.launch()
 
4
  import subprocess
5
  import tempfile
6
  import time
7
+ import sys
8
+ import types
9
+ import torch
10
  import numpy as np
11
  import rembg
12
  import spaces
13
+ import gradio as gr
14
  from PIL import Image
15
  from functools import partial
 
 
 
 
 
 
 
 
 
 
16
 
17
+ # --- CPU PATCH (MUST BE APPLIED BEFORE IMPORTING TSR) ---
18
  try:
19
  import mcubes
20
+ mock_torchmcubes = types.ModuleType("torchmcubes")
21
+ def marching_cubes_cpu(vertices, threshold):
22
+ v, f = mcubes.marching_cubes(vertices.detach().cpu().numpy(), threshold)
23
+ return torch.from_numpy(v.astype("float32")), torch.from_numpy(f.astype("int64"))
24
+ mock_torchmcubes.marching_cubes = marching_cubes_cpu
25
+ sys.modules["torchmcubes"] = mock_torchmcubes
26
  except ImportError:
27
  print("Error: PyMCubes no está en requirements.txt")
28
 
29
+ # --- TSR IMPORTS ---
30
+ from tsr.system import TSR
31
+ from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ HEADER = """# TripoSR Demo
34
+ <table bgcolor="#1E2432" cellspacing="0" cellpadding="0" width="450"><tr style="height:50px;"><td style="text-align: center;"><a href="https://stability.ai"><img src="https://images.squarespace-cdn.com/content/v1/6213c340453c3f502425776e/6c9c4c25-5410-4547-bc26-dc621cdacb25/Stability+AI+logo.png" width="200" height="40" /></a></td><td style="text-align: center;"><a href="https://www.tripo3d.ai"><img src="https://tripo-public.cdn.bcebos.com/logo.png" width="40" height="40" /></a></td></tr></table>
35
+ <table bgcolor="#1E2432" cellspacing="0" cellpadding="0" width="450"><tr style="height:30px;"><td style="text-align: center;"><a href="https://huggingface.co/stabilityai/TripoSR"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange" height="20"></a></td><td style="text-align: center;"><a href="https://github.com/VAST-AI-Research/TripoSR"><img src="https://postimage.me/images/2024/03/04/GitHub_Logo_White.png" width="100" height="20"></a></td><td style="text-align: center; color: white;"><a href="https://arxiv.org/abs/2403.02151"><img src="https://img.shields.io/badge/arXiv-2403.02151-b31b1b.svg" height="20"></a></td></tr></table>
36
  > Try our new model: **SF3D** with several improvements such as faster generation and more game-ready assets.
37
+ > The model is available [here](https://huggingface.co/stabilityai/stable-fast-3d) and we also have a [demo](https://huggingface.co/spaces/stabilityai/stable-fast-3d).
38
+ **TripoSR** is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image.
 
 
 
 
 
 
 
 
39
  """
40
 
 
41
  if torch.cuda.is_available():
42
  device = "cuda:0"
43
  else:
 
50
  )
51
  model.renderer.set_chunk_size(131072)
52
  model.to(device)
 
53
  rembg_session = rembg.new_session()
54
 
 
55
  def check_input_image(input_image):
56
  if input_image is None:
57
  raise gr.Error("No image uploaded!")
58
 
 
59
  def preprocess(input_image, do_remove_background, foreground_ratio):
60
  def fill_background(image):
61
  image = np.array(image).astype(np.float32) / 255.0
62
  image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
63
  image = Image.fromarray((image * 255.0).astype(np.uint8))
64
  return image
 
65
  if do_remove_background:
66
  image = input_image.convert("RGB")
67
  image = remove_background(image, rembg_session)
 
73
  image = fill_background(image)
74
  return image
75
 
 
76
  @spaces.GPU
77
  def generate(image, mc_resolution, formats=["obj", "glb"]):
78
  scene_codes = model(image, device=device)
79
  mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
80
  mesh = to_gradio_3d_orientation(mesh)
 
81
  mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f".glb", delete=False)
82
  mesh.export(mesh_path_glb.name)
 
83
  mesh_path_obj = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False)
84
+ mesh.apply_scale([-1, 1, 1])
85
  mesh.export(mesh_path_obj.name)
 
86
  return mesh_path_obj.name, mesh_path_glb.name
87
 
88
  def run_example(image_pil):
 
120
  minimum=32,
121
  maximum=320,
122
  value=256,
123
+ step=32
124
+ )
125
  with gr.Row():
126
  submit = gr.Button("Generate", elem_id="generate", variant="primary")
127
  with gr.Column():
 
130
  label="Output Model (OBJ Format)",
131
  interactive=False,
132
  )
 
133
  with gr.Tab("GLB"):
134
  output_model_glb = gr.Model3D(
135
  label="Output Model (GLB Format)",
136
  interactive=False,
137
  )
138
+
139
+ if os.path.exists("examples"):
140
+ with gr.Row(variant="panel"):
141
+ gr.Examples(
142
+ examples=[
143
+ os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
144
+ ] if os.path.exists("examples") else [],
145
+ inputs=[input_image],
146
+ outputs=[processed_image, output_model_obj, output_model_glb],
147
+ cache_examples=False,
148
+ fn=partial(run_example),
149
+ label="Examples",
150
+ examples_per_page=20
151
+ )
152
+
153
  submit.click(fn=check_input_image, inputs=[input_image]).success(
154
  fn=preprocess,
155
  inputs=[input_image, do_remove_background, foreground_ratio],
 
161
  )
162
 
163
  demo.queue(max_size=10)
164
+ demo.launch()