Spaces:
Runtime error
Runtime error
FrozenBurning
committed on
Commit
·
6cf1b17
1
Parent(s):
eb61402
Update app.py
Browse files
app.py
CHANGED
|
@@ -74,9 +74,21 @@ config.model.pop("latent_std")
|
|
| 74 |
model_primx = load_from_config(config.model)
|
| 75 |
# load rembg
|
| 76 |
rembg_session = rembg.new_session()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
# process function
|
| 79 |
-
def process(
|
| 80 |
# seed
|
| 81 |
torch.manual_seed(input_seed)
|
| 82 |
|
|
@@ -91,16 +103,8 @@ def process(input_image, input_num_steps, input_seed=42, input_cfg=6.0):
|
|
| 91 |
fwd_fn = model.forward_with_cfg
|
| 92 |
|
| 93 |
# text-conditioned
|
| 94 |
-
if
|
| 95 |
raise NotImplementedError
|
| 96 |
-
# image-conditioned (may also input text, but no text usually works too)
|
| 97 |
-
else:
|
| 98 |
-
input_image = remove_background(input_image, rembg_session)
|
| 99 |
-
input_image = resize_foreground(input_image, 0.85)
|
| 100 |
-
raw_image = np.array(input_image)
|
| 101 |
-
mask = (raw_image[..., -1][..., None] > 0) * 1
|
| 102 |
-
raw_image = raw_image[..., :3] * mask
|
| 103 |
-
input_cond = torch.from_numpy(np.array(raw_image)[None, ...]).to(device)
|
| 104 |
|
| 105 |
with torch.no_grad():
|
| 106 |
latent = torch.randn(1, config.model.num_prims, 1, 4, 4, 4)
|
|
@@ -178,8 +182,11 @@ with block:
|
|
| 178 |
|
| 179 |
with gr.Row(variant='panel'):
|
| 180 |
with gr.Column(scale=1):
|
| 181 |
-
|
| 182 |
-
|
|
|
|
|
|
|
|
|
|
| 183 |
# inference steps
|
| 184 |
input_num_steps = gr.Radio(choices=[25, 50, 100, 200], label="DDIM steps", value=25)
|
| 185 |
# random seed
|
|
@@ -187,7 +194,7 @@ with block:
|
|
| 187 |
# random seed
|
| 188 |
input_seed = gr.Slider(label="random seed", minimum=0, maximum=10000, step=1, value=42, info="Try different seed if the result is not satisfying as this is a generative model!")
|
| 189 |
# gen button
|
| 190 |
-
button_gen = gr.Button("Generate")
|
| 191 |
export_glb_btn = gr.Button(value="Export GLB", interactive=False)
|
| 192 |
|
| 193 |
with gr.Column(scale=1):
|
|
@@ -231,15 +238,16 @@ with block:
|
|
| 231 |
outputs=[output_glb],
|
| 232 |
)
|
| 233 |
|
| 234 |
-
|
|
|
|
|
|
|
| 235 |
|
| 236 |
export_glb_btn.click(export_mesh, inputs=[], outputs=[output_glb, hdr_row])
|
| 237 |
|
| 238 |
gr.Examples(
|
| 239 |
examples=[
|
| 240 |
-
"assets/examples
|
| 241 |
-
"assets/examples
|
| 242 |
-
"assets/examples/shuai_panda_notail.png",
|
| 243 |
],
|
| 244 |
inputs=[input_image],
|
| 245 |
outputs=[output_rgb_video, output_prim_video, output_mat_video, export_glb_btn],
|
|
|
|
| 74 |
model_primx = load_from_config(config.model)
|
| 75 |
# load rembg
|
| 76 |
rembg_session = rembg.new_session()
|
| 77 |
+
current_fg_state = None
|
| 78 |
+
|
| 79 |
+
# background removal function
|
| 80 |
+
def background_remove_process(input_image):
|
| 81 |
+
input_image = remove_background(input_image, rembg_session)
|
| 82 |
+
input_image = resize_foreground(input_image, 0.85)
|
| 83 |
+
input_cond_preview_pil = input_image
|
| 84 |
+
raw_image = np.array(input_image)
|
| 85 |
+
mask = (raw_image[..., -1][..., None] > 0) * 1
|
| 86 |
+
raw_image = raw_image[..., :3] * mask
|
| 87 |
+
input_cond = torch.from_numpy(np.array(raw_image)[None, ...]).to(device)
|
| 88 |
+
return gr.update(interactive=True), input_cond, input_cond_preview_pil
|
| 89 |
|
| 90 |
# process function
|
| 91 |
+
def process(input_cond, input_num_steps, input_seed=42, input_cfg=6.0):
|
| 92 |
# seed
|
| 93 |
torch.manual_seed(input_seed)
|
| 94 |
|
|
|
|
| 103 |
fwd_fn = model.forward_with_cfg
|
| 104 |
|
| 105 |
# text-conditioned
|
| 106 |
+
if input_cond is None:
|
| 107 |
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
with torch.no_grad():
|
| 110 |
latent = torch.randn(1, config.model.num_prims, 1, 4, 4, 4)
|
|
|
|
| 182 |
|
| 183 |
with gr.Row(variant='panel'):
|
| 184 |
with gr.Column(scale=1):
|
| 185 |
+
with gr.Row():
|
| 186 |
+
# input image
|
| 187 |
+
input_image = gr.Image(label="image", type='pil')
|
| 188 |
+
# background removal
|
| 189 |
+
removal_previewer = gr.Image(label="Background Removal Preview", type='pil', interactive=False)
|
| 190 |
# inference steps
|
| 191 |
input_num_steps = gr.Radio(choices=[25, 50, 100, 200], label="DDIM steps", value=25)
|
| 192 |
# random seed
|
|
|
|
| 194 |
# random seed
|
| 195 |
input_seed = gr.Slider(label="random seed", minimum=0, maximum=10000, step=1, value=42, info="Try different seed if the result is not satisfying as this is a generative model!")
|
| 196 |
# gen button
|
| 197 |
+
button_gen = gr.Button(value="Generate", interactive=False)
|
| 198 |
export_glb_btn = gr.Button(value="Export GLB", interactive=False)
|
| 199 |
|
| 200 |
with gr.Column(scale=1):
|
|
|
|
| 238 |
outputs=[output_glb],
|
| 239 |
)
|
| 240 |
|
| 241 |
+
input_image.change(background_remove_process, inputs=[input_image], outputs=[button_gen, current_fg_state, removal_previewer])
|
| 242 |
+
|
| 243 |
+
button_gen.click(process, inputs=[current_fg_state, input_num_steps, input_seed, input_cfg], outputs=[output_rgb_video, output_prim_video, output_mat_video, export_glb_btn])
|
| 244 |
|
| 245 |
export_glb_btn.click(export_mesh, inputs=[], outputs=[output_glb, hdr_row])
|
| 246 |
|
| 247 |
gr.Examples(
|
| 248 |
examples=[
|
| 249 |
+
os.path.join("assets/examples", f)
|
| 250 |
+
for f in os.listdir("assets/examples")
|
|
|
|
| 251 |
],
|
| 252 |
inputs=[input_image],
|
| 253 |
outputs=[output_rgb_video, output_prim_video, output_mat_video, export_glb_btn],
|