Vo Minh Vu committed on
Commit
50cb462
·
1 Parent(s): 3aa9a74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -151
app.py CHANGED
@@ -1,73 +1,38 @@
1
- import logging
2
  import os
3
  import shlex
4
  import subprocess
5
  import tempfile
6
- import time
 
 
 
 
 
7
 
8
- import gradio as gr
9
  import numpy as np
10
  import rembg
11
- import spaces
12
  import torch
13
  from PIL import Image
14
- from functools import partial
15
-
16
- subprocess.run(shlex.split('pip install wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl'))
17
 
18
  from tsr.system import TSR
19
  from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
20
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- HEADER = """
23
- # TripoSR Demo
24
- <table bgcolor="#1E2432" cellspacing="0" cellpadding="0" width="450">
25
- <tr style="height:50px;">
26
- <td style="text-align: center;">
27
- <a href="https://stability.ai">
28
- <img src="https://images.squarespace-cdn.com/content/v1/6213c340453c3f502425776e/6c9c4c25-5410-4547-bc26-dc621cdacb25/Stability+AI+logo.png" width="200" height="40" />
29
- </a>
30
- </td>
31
- <td style="text-align: center;">
32
- <a href="https://www.tripo3d.ai">
33
- <img src="https://tripo-public.cdn.bcebos.com/logo.png" width="40" height="40" />
34
- </a>
35
- </td>
36
- </tr>
37
- </table>
38
- <table bgcolor="#1E2432" cellspacing="0" cellpadding="0" width="450">
39
- <tr style="height:30px;">
40
- <td style="text-align: center;">
41
- <a href="https://huggingface.co/stabilityai/TripoSR"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange" height="20"></a>
42
- </td>
43
- <td style="text-align: center;">
44
- <a href="https://github.com/VAST-AI-Research/TripoSR"><img src="https://postimage.me/images/2024/03/04/GitHub_Logo_White.png" width="100" height="20"></a>
45
- </td>
46
- <td style="text-align: center; color: white;">
47
- <a href="https://arxiv.org/abs/2403.02151"><img src="https://img.shields.io/badge/arXiv-2403.02151-b31b1b.svg" height="20"></a>
48
- </td>
49
- </tr>
50
- </table>
51
-
52
- > Try our new model: **SF3D** with several improvements such as faster generation and more game-ready assets.
53
- >
54
- > The model is available [here](https://huggingface.co/stabilityai/stable-fast-3d) and we also have a [demo](https://huggingface.co/spaces/stabilityai/stable-fast-3d).
55
-
56
-
57
- **TripoSR** is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, developed in collaboration between [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
58
-
59
- **Tips:**
60
- 1. If you find the result is unsatisfied, please try to change the foreground ratio. It might improve the results.
61
- 2. It's better to disable "Remove Background" for the provided examples since they have been already preprocessed.
62
- 3. Otherwise, please disable "Remove Background" option only if your input image is RGBA with transparent background, image contents are centered and occupy more than 70% of image width or height.
63
- """
64
-
65
-
66
- if torch.cuda.is_available():
67
- device = "cuda:0"
68
- else:
69
- device = "cpu"
70
 
 
71
  model = TSR.from_pretrained(
72
  "stabilityai/TripoSR",
73
  config_name="config.yaml",
@@ -76,20 +41,26 @@ model = TSR.from_pretrained(
76
  model.renderer.set_chunk_size(131072)
77
  model.to(device)
78
 
 
79
  rembg_session = rembg.new_session()
80
 
81
 
82
- def check_input_image(input_image):
83
- if input_image is None:
84
- raise gr.Error("No image uploaded!")
85
 
86
 
87
- def preprocess(input_image, do_remove_background, foreground_ratio):
88
- def fill_background(image):
89
- image = np.array(image).astype(np.float32) / 255.0
90
- image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
91
- image = Image.fromarray((image * 255.0).astype(np.uint8))
92
- return image
 
 
 
 
 
93
 
94
  if do_remove_background:
95
  image = input_image.convert("RGB")
@@ -100,97 +71,101 @@ def preprocess(input_image, do_remove_background, foreground_ratio):
100
  image = input_image
101
  if image.mode == "RGBA":
102
  image = fill_background(image)
 
103
  return image
104
 
105
 
106
- @spaces.GPU
107
- def generate(image, mc_resolution, formats=["obj", "glb"]):
 
 
 
 
 
 
108
  scene_codes = model(image, device=device)
109
  mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
110
  mesh = to_gradio_3d_orientation(mesh)
111
 
112
- mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f".glb", delete=False)
113
- mesh.export(mesh_path_glb.name)
114
-
115
- mesh_path_obj = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False)
116
- mesh.apply_scale([-1, 1, 1]) # Otherwise the visualized .obj will be flipped
117
- mesh.export(mesh_path_obj.name)
118
-
119
- return mesh_path_obj.name, mesh_path_glb.name
120
-
121
- def run_example(image_pil):
122
- preprocessed = preprocess(image_pil, False, 0.9)
123
- mesh_name_obj, mesh_name_glb = generate(preprocessed, 256, ["obj", "glb"])
124
- return preprocessed, mesh_name_obj, mesh_name_glb
125
-
126
- with gr.Blocks() as demo:
127
- gr.Markdown(HEADER)
128
- with gr.Row(variant="panel"):
129
- with gr.Column():
130
- with gr.Row():
131
- input_image = gr.Image(
132
- label="Input Image",
133
- image_mode="RGBA",
134
- sources="upload",
135
- type="pil",
136
- elem_id="content_image",
137
- )
138
- processed_image = gr.Image(label="Processed Image", interactive=False)
139
- with gr.Row():
140
- with gr.Group():
141
- do_remove_background = gr.Checkbox(
142
- label="Remove Background", value=True
143
- )
144
- foreground_ratio = gr.Slider(
145
- label="Foreground Ratio",
146
- minimum=0.5,
147
- maximum=1.0,
148
- value=0.85,
149
- step=0.05,
150
- )
151
- mc_resolution = gr.Slider(
152
- label="Marching Cubes Resolution",
153
- minimum=32,
154
- maximum=320,
155
- value=256,
156
- step=32
157
- )
158
- with gr.Row():
159
- submit = gr.Button("Generate", elem_id="generate", variant="primary")
160
- with gr.Column():
161
- with gr.Tab("OBJ"):
162
- output_model_obj = gr.Model3D(
163
- label="Output Model (OBJ Format)",
164
- interactive=False,
165
- )
166
- gr.Markdown("Note: Downloaded object will be flipped in case of .obj export. Export .glb instead or manually flip it before usage.")
167
- with gr.Tab("GLB"):
168
- output_model_glb = gr.Model3D(
169
- label="Output Model (GLB Format)",
170
- interactive=False,
171
- )
172
- gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
173
- with gr.Row(variant="panel"):
174
- gr.Examples(
175
- examples=[
176
- os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
177
- ],
178
- inputs=[input_image],
179
- outputs=[processed_image, output_model_obj, output_model_glb],
180
- cache_examples=True,
181
- fn=partial(run_example),
182
- label="Examples",
183
- examples_per_page=20
184
- )
185
- submit.click(fn=check_input_image, inputs=[input_image]).success(
186
- fn=preprocess,
187
- inputs=[input_image, do_remove_background, foreground_ratio],
188
- outputs=[processed_image],
189
- ).success(
190
- fn=generate,
191
- inputs=[processed_image, mc_resolution],
192
- outputs=[output_model_obj, output_model_glb],
193
- )
194
-
195
- demo.queue(max_size=10)
196
- demo.launch()
 
1
+ import io
2
  import os
3
  import shlex
4
  import subprocess
5
  import tempfile
6
+ import zipfile
7
+ from functools import partial
8
+
9
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
10
+ from fastapi.responses import StreamingResponse
11
+ from fastapi.middleware.cors import CORSMiddleware
12
 
 
13
  import numpy as np
14
  import rembg
 
15
  import torch
16
  from PIL import Image
 
 
 
17
 
18
  from tsr.system import TSR
19
  from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
20
 
21
+ # ------------------------------------------------------------
22
+ # 1. Model & utils initialization (runs at startup)
23
+ # ------------------------------------------------------------
24
+ # Install any local wheels (if needed)
25
+ subprocess.run(
26
+ shlex.split(
27
+ "pip install wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl"
28
+ ),
29
+ check=False,
30
+ )
31
 
32
+ # device
33
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ # load model
36
  model = TSR.from_pretrained(
37
  "stabilityai/TripoSR",
38
  config_name="config.yaml",
 
41
  model.renderer.set_chunk_size(131072)
42
  model.to(device)
43
 
44
+ # background removal
45
  rembg_session = rembg.new_session()
46
 
47
 
48
def check_input_image(image: Image.Image):
    """Reject a missing image.

    Raises:
        HTTPException: status 400 with detail "No image uploaded!" when
            ``image`` is ``None``.
    """
    if image is not None:
        return
    raise HTTPException(status_code=400, detail="No image uploaded!")
51
 
52
 
53
+ def preprocess(
54
+ input_image: Image.Image, do_remove_background: bool, foreground_ratio: float
55
+ ) -> Image.Image:
56
+ """
57
+ Mimics the Gradio preprocess(...) function.
58
+ """
59
+ def fill_background(image: Image.Image) -> Image.Image:
60
+ arr = np.array(image).astype(np.float32) / 255.0
61
+ arr = arr[:, :, :3] * arr[:, :, 3:4] + (1 - arr[:, :, 3:4]) * 0.5
62
+ out = (arr * 255.0).astype(np.uint8)
63
+ return Image.fromarray(out)
64
 
65
  if do_remove_background:
66
  image = input_image.convert("RGB")
 
71
  image = input_image
72
  if image.mode == "RGBA":
73
  image = fill_background(image)
74
+
75
  return image
76
 
77
 
78
def generate(
    image: Image.Image, mc_resolution: int, formats=("obj", "glb")
) -> tuple[str, str]:
    """Reconstruct a mesh from ``image`` and export it as OBJ and GLB files.

    Args:
        image: Preprocessed input image passed straight to the model.
        mc_resolution: Resolution forwarded to ``model.extract_mesh``.
        formats: Kept for backward compatibility; the body does not read it
            and always writes both formats. The default is an immutable
            tuple so it cannot be shared and mutated across calls.

    Returns:
        ``(obj_path, glb_path)``: paths of the two exported files. They are
        created with ``delete=False``, so the caller must remove them.

    Raises:
        Any exception raised by the model or by the mesh export; temporary
        files created by this call are removed before it propagates.
    """

    def _reserve_path(suffix: str) -> str:
        # Close the handle right away: only the unique name is needed, and
        # keeping it open would leak one file descriptor per call.
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            return tmp.name

    # 1. inference
    scene_codes = model(image, device=device)
    mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
    mesh = to_gradio_3d_orientation(mesh)

    glb_path = _reserve_path(".glb")
    obj_path = _reserve_path(".obj")
    try:
        # 2. export GLB first, before the mesh is mirrored in place
        mesh.export(glb_path)

        # 3. export OBJ (flip x-axis so OBJ is not mirrored)
        mesh.apply_scale([-1, 1, 1])
        mesh.export(obj_path)
    except Exception:
        # Do not leave orphaned temp files behind when an export fails.
        for path in (glb_path, obj_path):
            try:
                os.remove(path)
            except OSError:
                pass  # best-effort cleanup; the export error matters more
        raise

    return obj_path, glb_path
100
+
101
+
102
+ # ------------------------------------------------------------
103
+ # 2. FastAPI app
104
+ # ------------------------------------------------------------
105
+ app = FastAPI(title="TripoSR FastAPI Demo")
106
+
107
+ # If you need CORS (e.g. calling from a browser-based front-end)
108
+ app.add_middleware(
109
+ CORSMiddleware,
110
+ allow_origins=["*"],
111
+ allow_methods=["POST", "GET", "OPTIONS"],
112
+ allow_headers=["*"],
113
+ )
114
+
115
+
116
@app.post("/generate", response_class=StreamingResponse)
async def generate_endpoint(
    image_file: UploadFile = File(...),
    do_remove_background: bool = Form(True),
    foreground_ratio: float = Form(0.85),
    mc_resolution: int = Form(256),
):
    """Turn an uploaded image into a ZIP of processed image + OBJ + GLB.

    Steps:
        1. Read and validate the uploaded image.
        2. Preprocess it with ``preprocess``.
        3. Generate the mesh files with ``generate``.
        4. Package ``processed.png``, the .obj and the .glb into an
           in-memory ZIP and stream it back as ``tripo_output.zip``.

    Raises:
        HTTPException: status 400 when the upload cannot be decoded as an
            image.
    """
    # NOTE(review): preprocess() and generate() are synchronous and run
    # directly inside this coroutine, so the event loop is blocked for the
    # whole inference. Consider a threadpool once it is confirmed that the
    # model tolerates concurrent calls.
    # NOTE(review): foreground_ratio and mc_resolution are forwarded
    # unchecked; confirm whether the service should enforce bounds.

    # 1) Read image bytes
    contents = await image_file.read()
    try:
        pil_img = Image.open(io.BytesIO(contents))
        # Image.open() is lazy; force a full decode here so truncated or
        # corrupt payloads are reported as 400 rather than failing later.
        pil_img.load()
    except Exception as exc:
        raise HTTPException(
            status_code=400, detail="Invalid image file"
        ) from exc

    check_input_image(pil_img)

    # 2) Preprocess
    processed = preprocess(pil_img, do_remove_background, foreground_ratio)

    # 3) Generate mesh
    obj_path, glb_path = generate(processed, mc_resolution)

    # 4) Create in-memory ZIP
    zip_buffer = io.BytesIO()
    try:
        with zipfile.ZipFile(zip_buffer, mode="w") as zf:
            # processed image
            png_buffer = io.BytesIO()
            processed.save(png_buffer, format="PNG")
            zf.writestr("processed.png", png_buffer.getvalue())

            # .obj and .glb, stored under their temp-file base names
            for mesh_path in (obj_path, glb_path):
                zf.write(mesh_path, arcname=os.path.basename(mesh_path))
    finally:
        # Always remove the temp files, even when packaging fails.
        for mesh_path in (obj_path, glb_path):
            try:
                os.remove(mesh_path)
            except OSError:
                pass  # best-effort cleanup; must not mask the real error

    zip_buffer.seek(0)

    headers = {
        "Content-Disposition": 'attachment; filename="tripo_output.zip"'
    }
    return StreamingResponse(
        zip_buffer, media_type="application/zip", headers=headers
    )