Spaces:
Paused
Paused
Update gradio_app.py
Browse files- gradio_app.py +69 -0
gradio_app.py
CHANGED
|
@@ -214,6 +214,62 @@ def update_foreground_ratio(img_proc, fr):
|
|
| 214 |
foreground_res,
|
| 215 |
gr.update(value=show_mask_img(foreground_res)),
|
| 216 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
with gr.Blocks() as demo:
|
| 218 |
img_proc_state = gr.State()
|
| 219 |
background_remove_state = gr.State()
|
|
@@ -340,4 +396,17 @@ with gr.Blocks() as demo:
|
|
| 340 |
hdr_row,
|
| 341 |
],
|
| 342 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
demo.queue().launch(share=False)
|
|
|
|
| 214 |
foreground_res,
|
| 215 |
gr.update(value=show_mask_img(foreground_res)),
|
| 216 |
)
|
| 217 |
+
# Clean API endpoint that doesn't depend on internal state
|
| 218 |
+
def api_generate(
|
| 219 |
+
input_image,
|
| 220 |
+
foreground_ratio: float = 0.85,
|
| 221 |
+
remesh_option: str = "None",
|
| 222 |
+
vertex_count: int = -1,
|
| 223 |
+
texture_size: int = 1024,
|
| 224 |
+
):
|
| 225 |
+
"""
|
| 226 |
+
API endpoint for generating 3D models.
|
| 227 |
+
This endpoint handles all preprocessing internally.
|
| 228 |
+
|
| 229 |
+
Args:
|
| 230 |
+
input_image: Input image (PIL Image)
|
| 231 |
+
foreground_ratio: Foreground ratio (0.5-1.0)
|
| 232 |
+
remesh_option: "None", "Triangle", or "Quad"
|
| 233 |
+
vertex_count: Target vertex count (-1 for auto)
|
| 234 |
+
texture_size: Texture size (512-2048)
|
| 235 |
+
|
| 236 |
+
Returns:
|
| 237 |
+
Path to generated GLB file
|
| 238 |
+
"""
|
| 239 |
+
if input_image is None:
|
| 240 |
+
raise ValueError("No image provided. Please upload an image.")
|
| 241 |
+
|
| 242 |
+
# Preprocess the image
|
| 243 |
+
if hasattr(input_image, 'mode') and input_image.mode == 'RGBA':
|
| 244 |
+
alpha_channel = np.array(input_image.getchannel("A"))
|
| 245 |
+
if alpha_channel.min() == 0:
|
| 246 |
+
# Already has transparency, just resize
|
| 247 |
+
processed_image = sf3d_utils.resize_foreground(
|
| 248 |
+
input_image, foreground_ratio, out_size=(COND_WIDTH, COND_HEIGHT)
|
| 249 |
+
)
|
| 250 |
+
else:
|
| 251 |
+
# Need to remove background first
|
| 252 |
+
rem_removed = remove_background(input_image)
|
| 253 |
+
processed_image = sf3d_utils.resize_foreground(
|
| 254 |
+
rem_removed, foreground_ratio, out_size=(COND_WIDTH, COND_HEIGHT)
|
| 255 |
+
)
|
| 256 |
+
else:
|
| 257 |
+
# Not RGBA, need to remove background
|
| 258 |
+
rem_removed = remove_background(input_image)
|
| 259 |
+
processed_image = sf3d_utils.resize_foreground(
|
| 260 |
+
rem_removed, foreground_ratio, out_size=(COND_WIDTH, COND_HEIGHT)
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
# Generate 3D model
|
| 264 |
+
if torch.cuda.is_available():
|
| 265 |
+
torch.cuda.reset_peak_memory_stats()
|
| 266 |
+
|
| 267 |
+
glb_file = run_model(processed_image, remesh_option.lower(), vertex_count, texture_size)
|
| 268 |
+
|
| 269 |
+
if torch.cuda.is_available():
|
| 270 |
+
print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
|
| 271 |
+
|
| 272 |
+
return glb_file
|
| 273 |
with gr.Blocks() as demo:
|
| 274 |
img_proc_state = gr.State()
|
| 275 |
background_remove_state = gr.State()
|
|
|
|
| 396 |
hdr_row,
|
| 397 |
],
|
| 398 |
)
|
| 399 |
+
# Register clean API endpoint
|
| 400 |
+
api_interface = gr.Interface(
|
| 401 |
+
fn=api_generate,
|
| 402 |
+
inputs=[
|
| 403 |
+
gr.Image(type="pil", label="Input Image"),
|
| 404 |
+
gr.Number(value=0.85, label="Foreground Ratio"),
|
| 405 |
+
gr.Textbox(value="None", label="Remesh Option"),
|
| 406 |
+
gr.Number(value=-1, label="Vertex Count"),
|
| 407 |
+
gr.Number(value=1024, label="Texture Size"),
|
| 408 |
+
],
|
| 409 |
+
outputs=gr.File(label="3D Model (GLB)"),
|
| 410 |
+
api_name="generate", # This creates /api/generate endpoint
|
| 411 |
+
)
|
| 412 |
demo.queue().launch(share=False)
|