Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -80,7 +80,7 @@ from src.models.briarmbg import BriaRMBG
|
|
| 80 |
|
| 81 |
# Constants
|
| 82 |
MAX_NUM_PARTS = 16
|
| 83 |
-
DEVICE = "cuda"
|
| 84 |
DTYPE = torch.float16
|
| 85 |
|
| 86 |
# Download and initialize models
|
|
@@ -95,22 +95,22 @@ pipe: PartCrafterPipeline = PartCrafterPipeline.from_pretrained(partcrafter_weig
|
|
| 95 |
|
| 96 |
@spaces.GPU()
|
| 97 |
@torch.no_grad()
|
| 98 |
-
def run_triposg(
|
| 99 |
-
num_parts: int,
|
| 100 |
-
seed: int,
|
| 101 |
-
num_tokens: int,
|
| 102 |
-
num_inference_steps: int,
|
| 103 |
-
guidance_scale: float,
|
| 104 |
-
max_num_expanded_coords: float,
|
| 105 |
-
use_flash_decoder: bool,
|
| 106 |
-
rmbg: bool):
|
| 107 |
"""
|
| 108 |
Generate 3D part meshes from an input image.
|
| 109 |
"""
|
| 110 |
if rmbg:
|
| 111 |
-
img_pil = prepare_image(
|
| 112 |
else:
|
| 113 |
-
img_pil =
|
| 114 |
|
| 115 |
set_seed(seed)
|
| 116 |
start_time = time.time()
|
|
@@ -159,7 +159,7 @@ def build_demo():
|
|
| 159 |
)
|
| 160 |
with gr.Row():
|
| 161 |
with gr.Column(scale=1):
|
| 162 |
-
input_image = gr.Image(type="
|
| 163 |
num_parts = gr.Slider(1, MAX_NUM_PARTS, value=4, step=1, label="Number of Parts")
|
| 164 |
seed = gr.Number(value=0, label="Random Seed", precision=0)
|
| 165 |
num_tokens = gr.Slider(256, 2048, value=1024, step=64, label="Num Tokens")
|
|
|
|
| 80 |
|
| 81 |
# Constants
|
| 82 |
MAX_NUM_PARTS = 16
|
| 83 |
+
DEVICE = "cuda"
|
| 84 |
DTYPE = torch.float16
|
| 85 |
|
| 86 |
# Download and initialize models
|
|
|
|
| 95 |
|
| 96 |
@spaces.GPU()
|
| 97 |
@torch.no_grad()
|
| 98 |
+
def run_triposg(image_path: str,
|
| 99 |
+
num_parts: int = 10,
|
| 100 |
+
seed: int = 123,
|
| 101 |
+
num_tokens: int = 1024,
|
| 102 |
+
num_inference_steps: int = 50,
|
| 103 |
+
guidance_scale: float = 7.0,
|
| 104 |
+
max_num_expanded_coords: float = 1e9,
|
| 105 |
+
use_flash_decoder: bool = False,
|
| 106 |
+
rmbg: bool = True):
|
| 107 |
"""
|
| 108 |
Generate 3D part meshes from an input image.
|
| 109 |
"""
|
| 110 |
if rmbg:
|
| 111 |
+
img_pil = prepare_image(image_path, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
|
| 112 |
else:
|
| 113 |
+
img_pil = Image.open(image_path_or_pil)
|
| 114 |
|
| 115 |
set_seed(seed)
|
| 116 |
start_time = time.time()
|
|
|
|
| 159 |
)
|
| 160 |
with gr.Row():
|
| 161 |
with gr.Column(scale=1):
|
| 162 |
+
input_image = gr.Image(type="filepath", label="Input Image")
|
| 163 |
num_parts = gr.Slider(1, MAX_NUM_PARTS, value=4, step=1, label="Number of Parts")
|
| 164 |
seed = gr.Number(value=0, label="Random Seed", precision=0)
|
| 165 |
num_tokens = gr.Slider(256, 2048, value=1024, step=64, label="Num Tokens")
|