Vo Minh Vu commited on
Commit
f716caf
·
1 Parent(s): e6fd99c

refactor: move run.py code to app.py

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +24 -200
README.md CHANGED
@@ -75,7 +75,7 @@ with open("model.obj", "wb") as f:
75
  To run the application locally:
76
  ```bash
77
  pip install -r requirements.txt
78
- python run.py
79
  ```
80
 
81
  The API will be available at `http://localhost:7860` with automatic documentation at `http://localhost:7860/docs`.
 
75
  To run the application locally:
76
  ```bash
77
  pip install -r requirements.txt
78
+ python app.py
79
  ```
80
 
81
  The API will be available at `http://localhost:7860` with automatic documentation at `http://localhost:7860/docs`.
app.py CHANGED
@@ -1,205 +1,29 @@
1
- import logging
2
- import os
3
- import shlex
4
- import subprocess
5
- import tempfile
6
- import time
7
-
8
- import numpy as np
9
- import rembg
10
- import torch
11
- from PIL import Image
12
- from functools import partial
13
-
14
- # Install torchmcubes
15
- subprocess.run(shlex.split('pip install git+https://github.com/tatsy/torchmcubes.git'))
16
-
17
- from tsr.system import TSR
18
- from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
19
-
20
- # Import gradio after installing dependencies
21
- import gradio as gr
22
- import spaces
23
-
24
- # Import gradio components directly to avoid audio dependencies
25
- from gradio.blocks import Blocks
26
- from gradio.components import Image as GradioImage
27
- from gradio.components import Model3D, Markdown, Slider, Checkbox, Button
28
- from gradio.layouts import Row, Column, Group
29
- from gradio.themes import Base
30
-
31
- subprocess.run(shlex.split('pip install wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl'))
32
-
33
- from tsr.system import TSR
34
- from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
35
-
36
-
37
- HEADER = """
38
- # TripoSR Demo
39
- <table bgcolor="#1E2432" cellspacing="0" cellpadding="0" width="450">
40
- <tr style="height:50px;">
41
- <td style="text-align: center;">
42
- <a href="https://stability.ai">
43
- <img src="https://images.squarespace-cdn.com/content/v1/6213c340453c3f502425776e/6c9c4c25-5410-4547-bc26-dc621cdacb25/Stability+AI+logo.png" width="200" height="40" />
44
- </a>
45
- </td>
46
- <td style="text-align: center;">
47
- <a href="https://www.tripo3d.ai">
48
- <img src="https://tripo-public.cdn.bcebos.com/logo.png" width="40" height="40" />
49
- </a>
50
- </td>
51
- </tr>
52
- </table>
53
- <table bgcolor="#1E2432" cellspacing="0" cellpadding="0" width="450">
54
- <tr style="height:30px;">
55
- <td style="text-align: center;">
56
- <a href="https://huggingface.co/stabilityai/TripoSR"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange" height="20"></a>
57
- </td>
58
- <td style="text-align: center;">
59
- <a href="https://github.com/VAST-AI-Research/TripoSR"><img src="https://postimage.me/images/2024/03/04/GitHub_Logo_White.png" width="100" height="20"></a>
60
- </td>
61
- <td style="text-align: center; color: white;">
62
- <a href="https://arxiv.org/abs/2403.02151"><img src="https://img.shields.io/badge/arXiv-2403.02151-b31b1b.svg" height="20"></a>
63
- </td>
64
- </tr>
65
- </table>
66
-
67
- > Try our new model: **SF3D** with several improvements such as faster generation and more game-ready assets.
68
- >
69
- > The model is available [here](https://huggingface.co/stabilityai/stable-fast-3d) and we also have a [demo](https://huggingface.co/spaces/stabilityai/stable-fast-3d).
70
-
71
-
72
- **TripoSR** is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, developed in collaboration between [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
73
-
74
- **Tips:**
75
- 1. If you find the result is unsatisfied, please try to change the foreground ratio. It might improve the results.
76
- 2. It's better to disable "Remove Background" for the provided examples since they have been already preprocessed.
77
- 3. Otherwise, please disable "Remove Background" option only if your input image is RGBA with transparent background, image contents are centered and occupy more than 70% of image width or height.
78
- """
79
-
80
-
81
- if torch.cuda.is_available():
82
- device = "cuda:0"
83
- else:
84
- device = "cpu"
85
-
86
- model = TSR.from_pretrained(
87
- "stabilityai/TripoSR",
88
- config_name="config.yaml",
89
- weight_name="model.ckpt",
90
  )
91
- model.renderer.set_chunk_size(131072)
92
- model.to(device)
93
-
94
- rembg_session = rembg.new_session()
95
-
96
-
97
- def check_input_image(input_image):
98
- if input_image is None:
99
- raise gr.Error("No image uploaded!")
100
-
101
-
102
- def preprocess(input_image, do_remove_background, foreground_ratio):
103
- def fill_background(image):
104
- image = np.array(image).astype(np.float32) / 255.0
105
- image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
106
- image = Image.fromarray((image * 255.0).astype(np.uint8))
107
- return image
108
 
109
- if do_remove_background:
110
- image = input_image.convert("RGB")
111
- image = remove_background(image, rembg_session)
112
- image = resize_foreground(image, foreground_ratio)
113
- image = fill_background(image)
114
- else:
115
- image = input_image
116
- if image.mode == "RGBA":
117
- image = fill_background(image)
118
- return image
119
-
120
-
121
- @spaces.GPU
122
- def generate(image, mc_resolution, formats=["obj", "glb"]):
123
- scene_codes = model(image, device=device)
124
- mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
125
- mesh = to_gradio_3d_orientation(mesh)
126
-
127
- mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f".glb", delete=False)
128
- mesh.export(mesh_path_glb.name)
129
-
130
- # Export GLTF
131
- mesh_path_gltf = tempfile.NamedTemporaryFile(suffix=f".gltf", delete=False)
132
- mesh.export(mesh_path_gltf.name)
133
-
134
- mesh_path_obj = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False)
135
- mesh.apply_scale([-1, 1, 1]) # Otherwise the visualized .obj will be flipped
136
- mesh.export(mesh_path_obj.name)
137
-
138
- return mesh_path_obj.name, mesh_path_glb.name, mesh_path_gltf.name
139
 
140
- with gr.Blocks() as demo:
141
- gr.Markdown(HEADER)
142
- with gr.Row(variant="panel"):
143
- with gr.Column():
144
- with gr.Row():
145
- input_image = gr.Image(
146
- label="Input Image",
147
- image_mode="RGBA",
148
- sources="upload",
149
- type="pil",
150
- elem_id="content_image",
151
- )
152
- processed_image = gr.Image(label="Processed Image", interactive=False)
153
- with gr.Row():
154
- with gr.Group():
155
- do_remove_background = gr.Checkbox(
156
- label="Remove Background", value=True
157
- )
158
- foreground_ratio = gr.Slider(
159
- label="Foreground Ratio",
160
- minimum=0.5,
161
- maximum=1.0,
162
- value=0.85,
163
- step=0.05,
164
- )
165
- mc_resolution = gr.Slider(
166
- label="Marching Cubes Resolution",
167
- minimum=32,
168
- maximum=320,
169
- value=256,
170
- step=32
171
- )
172
- with gr.Row():
173
- submit = gr.Button("Generate", elem_id="generate", variant="primary")
174
- with gr.Column():
175
- with gr.Tab("OBJ"):
176
- output_model_obj = gr.Model3D(
177
- label="Output Model (OBJ Format)",
178
- interactive=False,
179
- )
180
- gr.Markdown("Note: Downloaded object will be flipped in case of .obj export. Export .glb instead or manually flip it before usage.")
181
- with gr.Tab("GLB"):
182
- output_model_glb = gr.Model3D(
183
- label="Output Model (GLB Format)",
184
- interactive=False,
185
- )
186
- gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
187
- with gr.Tab("GLTF"):
188
- output_model_gltf = gr.Model3D(
189
- label="Output Model (GLTF Format)",
190
- interactive=False,
191
- )
192
- gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
193
 
194
- submit.click(fn=check_input_image, inputs=[input_image]).success(
195
- fn=preprocess,
196
- inputs=[input_image, do_remove_background, foreground_ratio],
197
- outputs=[processed_image],
198
- ).success(
199
- fn=generate,
200
- inputs=[processed_image, mc_resolution],
201
- outputs=[output_model_obj, output_model_glb, output_model_gltf],
202
- )
203
 
204
- demo.queue(max_size=10)
205
- demo.launch()
 
1
+ import uvicorn
2
+ from fastapi import FastAPI
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+
5
+ app = FastAPI(
6
+ title="TripoSR API",
7
+ description="API for TripoSR 3D reconstruction from single images",
8
+ version="1.0.0"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ # Configure CORS
12
+ app.add_middleware(
13
+ CORSMiddleware,
14
+ allow_origins=["*"], # Allows all origins
15
+ allow_credentials=True,
16
+ allow_methods=["*"], # Allows all methods
17
+ allow_headers=["*"], # Allows all headers
18
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
+ # Import and include routers
21
+ from app.api.endpoints import router as api_router
22
+ app.include_router(api_router, prefix="/api/v1")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
+ @app.get("/")
25
+ async def root():
26
+ return {"message": "Welcome to TripoSR API"}
 
 
 
 
 
 
27
 
28
+ if __name__ == "__main__":
29
+ uvicorn.run("app:app", host="0.0.0.0", port=7860)