SegviGen / app.py
fenghora's picture
git init
c3524cd
import os

# Environment knobs for libraries imported further down (directly or via
# inference_full). They are exported before those imports so the libraries
# see them when they initialise: OpenEXR support in OpenCV, the CUDA
# caching-allocator mode in PyTorch, and the attention backend (presumably
# read by the model code behind inference_full -- confirm).
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["ATTN_BACKEND"] = "flash_attn_3"
import urllib.request

# Model checkpoints are fetched from the Hugging Face Hub on first start.
os.makedirs("pretrained_model", exist_ok=True)
CKPT_FULL_SEG = "pretrained_model/full_seg.ckpt"
CKPT_W_2D_MAP = "pretrained_model/full_seg_w_2d_map.ckpt"


def _download_if_missing(url: str, dest: str) -> None:
    """Download *url* to *dest* unless *dest* already exists.

    The payload is first written to ``<dest>.part`` and only renamed onto
    *dest* after the transfer finished. An interrupted download therefore
    never leaves a truncated checkpoint at *dest*, which later startups would
    otherwise mistake for a valid file and never re-download.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises (e.g. ``URLError``,
        ``ContentTooShortError``); the partial file is removed first.
    """
    if os.path.exists(dest):
        return
    tmp_path = dest + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        # Atomic on the same filesystem: dest is either absent or complete.
        os.replace(tmp_path, dest)
    except BaseException:
        # Also covers KeyboardInterrupt: drop the partial file, then propagate.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


_download_if_missing(
    "https://huggingface.co/fenghora/SegviGen/resolve/main/full_seg.ckpt",
    CKPT_FULL_SEG,
)
_download_if_missing(
    "https://huggingface.co/fenghora/SegviGen/resolve/main/full_seg_w_2d_map.ckpt",
    CKPT_W_2D_MAP,
)
import shutil
import traceback
from datetime import datetime
from pathlib import Path
from typing import List
import inference_full as inf
import split as splitter
# Resolve app data against this file's directory rather than the process CWD,
# so launching the app from another directory still finds transforms.json.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
TRANSFORMS_JSON = os.path.join(ROOT_DIR, "data_toolkit", "transforms.json")
# Per-run working directories (see run_seg) and Gradio's temp/example caches.
TMP_DIR = os.path.join(ROOT_DIR, "_tmp_gradio_seg")
EXAMPLES_CACHE_DIR = os.path.join(TMP_DIR, "examples_cache")
os.makedirs(TMP_DIR, exist_ok=True)
os.makedirs(EXAMPLES_CACHE_DIR, exist_ok=True)
# Exported before `import gradio` so Gradio sees them however early it reads them.
os.environ["GRADIO_TEMP_DIR"] = TMP_DIR
os.environ["GRADIO_EXAMPLES_CACHE"] = EXAMPLES_CACHE_DIR
import gradio as gr
# NOTE(review): intentionally left CWD-relative. Example files are handed to
# gr.Examples, and anchoring them outside the working directory may trip
# Gradio's file-access restrictions; the value is also shown verbatim in the
# "No examples found" message. Confirm before changing.
EXAMPLES_DIR = "examples"
def _ensure_dir(p: str):
os.makedirs(p, exist_ok=True)
def _normalize_path(x):
"""
Compatible with different Gradio versions: File/Model3D might be str / dict / object
"""
if x is None:
return None
if isinstance(x, str):
return x
if isinstance(x, dict):
return x.get("name") or x.get("path") or x.get("data")
return getattr(x, "name", None) or getattr(x, "path", None) or None
def _raise_user_error(msg: str):
    """
    Abort the current handler with a message intended for the end user.

    Raises ``gr.Error`` when this Gradio version provides it, otherwise a
    plain ``RuntimeError``. Never returns normally.
    """
    if not hasattr(gr, "Error"):
        raise RuntimeError(msg)
    raise gr.Error(msg)
def _collect_examples(example_dir: str) -> List[List[str]]:
"""
Scan example_dir for pairs: <name>.glb + <name>.png
Return a list of examples: [[glb_path, png_path], ...]
"""
d = Path(example_dir)
if not d.is_dir():
return []
examples: List[List[str]] = []
# Search recursively in case you add subfolders later
glb_files = sorted(d.rglob("*.glb"))
for glb_path in glb_files:
png_path = glb_path.with_suffix(".png")
if png_path.is_file():
examples.append([str(glb_path), str(png_path)])
# If png is missing, skip to keep examples consistent (2 inputs required)
return examples
# Scanned once at import time; the Examples widget in the layout reads this list.
FULL_SEG_EXAMPLES: List[List[str]] = _collect_examples(EXAMPLES_DIR)
def _update_img_box(mode: str):
    """
    Reconfigure the 2D-map image box after the map-mode radio changes.

    In "Generate..." mode the map is produced by the app, so the box becomes
    read-only; in any other mode the user may upload one. The current value
    is cleared in both cases.
    """
    allow_upload = not str(mode).startswith("Generate")
    return gr.update(
        interactive=allow_upload,
        label="2D Segmentation Map",
        value=None,
    )
def run_seg(glb_in, map_mode, img_in):
    """
    Handler for the "Process" button: segment the whole input GLB.

    The 2D segmentation map that conditions the model is chosen as follows:
      * mode starting with "Generate": a map is produced from the GLB by
        ``inf.generate_2d_map_from_glb`` and used as if it had been uploaded;
      * any other mode (the UI offers "Upload"): the uploaded image is used
        when one is present.
    With a map, the CKPT_W_2D_MAP checkpoint is used. Without one, the model
    falls back to CKPT_FULL_SEG, which is given TRANSFORMS_JSON and a render
    path instead of an image.

    All inputs are copied into, and all outputs written to, a fresh
    timestamped directory under TMP_DIR.

    Returns:
        ``(segmented_glb_path, segmented_glb_path, preview_image_or_None)``,
        feeding the output viewer, the state holder and the image box.

    Raises:
        gr.Error (via _raise_user_error) on invalid input or missing outputs;
        every exception is printed with its traceback and then re-raised.
    """
    try:
        source_glb = _normalize_path(glb_in)
        uploaded_map = _normalize_path(img_in)
        if source_glb is None or not os.path.isfile(source_glb):
            _raise_user_error("Please upload a valid .glb file.")

        # One timestamped working directory per request keeps runs isolated.
        _ensure_dir(TMP_DIR)
        run_dir = os.path.join(TMP_DIR, datetime.now().strftime("%Y%m%d_%H%M%S_%f"))
        _ensure_dir(run_dir)

        def in_run(filename):
            return os.path.join(run_dir, filename)

        in_glb = in_run("input.glb")
        shutil.copy(source_glb, in_glb)
        out_glb = in_run("segmented.glb")
        in_vxz = in_run("input.vxz")

        map_path = None
        if str(map_mode).startswith("Generate"):
            generated = in_run("2d_map_generated.png")
            inf.generate_2d_map_from_glb(
                glb_path=in_glb,
                transforms_path=TRANSFORMS_JSON,
                out_img_path=generated,
                render_img_path=in_run("render.png"),
            )
            if not os.path.isfile(generated):
                _raise_user_error("2D map generation failed: generated image not found.")
            map_path = generated
        elif uploaded_map is not None and os.path.isfile(uploaded_map):
            map_path = in_run("2d_map.png")
            shutil.copy(uploaded_map, map_path)

        use_map = map_path is not None and os.path.isfile(map_path)
        if use_map:
            ckpt = CKPT_W_2D_MAP
            item = {
                "2d_map": True,
                "glb": in_glb,
                "input_vxz": in_vxz,
                "img": map_path,
                "export_glb": out_glb,
            }
        else:
            ckpt = CKPT_FULL_SEG
            item = {
                "2d_map": False,
                "glb": in_glb,
                "input_vxz": in_vxz,
                "transforms": TRANSFORMS_JSON,
                "img": in_run("render.png"),
                "export_glb": out_glb,
            }

        inf.inference_with_loaded_models(ckpt, item)
        if not os.path.isfile(out_glb):
            _raise_user_error("Export failed: output glb not found.")
        # NOTE: an X-90 degree root rotation fix (_apply_root_x90_rotation_glb)
        # was once applied to out_glb here; it is currently disabled.
        return out_glb, out_glb, (map_path if use_map else None)
    except Exception as exc:
        print("".join(traceback.format_exception(type(exc), exc, exc.__traceback__)))
        raise
def run_refine_segmentation(
    seg_glb_path_state,
    color_quant_step,
    palette_sample_pixels,
    palette_min_pixels,
    palette_max_colors,
    palette_merge_dist,
    samples_per_face,
    flip_v,
    uv_wrap_repeat,
    transition_conf_thresh,
    transition_prop_iters,
    transition_neighbor_min,
    small_component_action,
    small_component_min_faces,
    postprocess_iters,
    min_faces_per_part,
    bake_transforms,
):
    """
    Handler for the "Segment" button: split the processed GLB into parts.

    Takes the GLB path produced by ``run_seg`` (held in a ``gr.State``) and
    passes it, together with the advanced options from the UI, to
    ``splitter.split_glb_by_texture_palette_rgb``. The result is written as
    ``segmented_parts.glb`` in the same directory as the input.

    Returns:
        Path of the split GLB, shown in the "Segmented GLB" viewer.

    Raises:
        gr.Error (via _raise_user_error) if segmentation has not been run or
        the splitter produced no output; every exception is printed with its
        traceback and then re-raised.
    """
    try:
        seg_glb_path = seg_glb_path_state if isinstance(seg_glb_path_state, str) else None
        if (seg_glb_path is None) or (not os.path.isfile(seg_glb_path)):
            _raise_user_error("Please run Segmentation first (the segmented GLB is missing).")
        out_dir = os.path.dirname(seg_glb_path)
        out_parts_glb = os.path.join(out_dir, "segmented_parts.glb")
        # The output name is fixed per run directory, so repeated clicks reuse
        # it. Remove the previous result first; otherwise a split that fails
        # to write its output would pass the existence check below and return
        # the stale file as if it were new.
        if os.path.isfile(out_parts_glb):
            os.remove(out_parts_glb)
        splitter.split_glb_by_texture_palette_rgb(
            in_glb_path=seg_glb_path,
            out_glb_path=out_parts_glb,
            min_faces_per_part=min_faces_per_part,
            bake_transforms=bool(bake_transforms),
            color_quant_step=color_quant_step,
            palette_sample_pixels=palette_sample_pixels,
            palette_min_pixels=palette_min_pixels,
            palette_max_colors=palette_max_colors,
            palette_merge_dist=palette_merge_dist,
            samples_per_face=samples_per_face,
            flip_v=flip_v,
            uv_wrap_repeat=uv_wrap_repeat,
            transition_conf_thresh=transition_conf_thresh,
            transition_prop_iters=transition_prop_iters,
            transition_neighbor_min=transition_neighbor_min,
            small_component_action=small_component_action,
            small_component_min_faces=small_component_min_faces,
            postprocess_iters=postprocess_iters,
            debug_print=True,
        )
        if not os.path.isfile(out_parts_glb):
            _raise_user_error("Split failed: output parts glb not found.")
        # NOTE: an X-90 degree rotation fix for bake_transforms=False outputs
        # (_apply_root_x90_rotation_glb) was once applied here; it is disabled.
        return out_parts_glb
    except Exception as e:
        err = "".join(traceback.format_exception(type(e), e, e.__traceback__))
        print(err)
        raise
# Raw <style> block, injected into the page via gr.HTML(CSS_TEXT) in the layout.
# It pins the three 3D viewers and the 2D-map image box to a common 520px
# height; the selectors match the elem_id values given to those components.
CSS_TEXT = """
<style>
#in_glb { height: 520px !important; }
#seg_glb { height: 520px !important; }
#part_glb{ height: 520px !important; }
#img { height: 520px !important; }
</style>
"""
# Gradio UI: a 2x2 grid (input GLB | processed GLB, options | split result)
# plus the event wiring to run_seg / run_refine_segmentation.
with gr.Blocks() as demo:
    # Inject the fixed-height stylesheet defined in CSS_TEXT.
    gr.HTML(CSS_TEXT)
    gr.Markdown(
        """
# SegviGen: Repurposing 3D Generative Model for Part Segmentation
"""
    )
    # ---------------- 2x2 Layout ----------------
    # Top row: input mesh | result of "Process".
    with gr.Row():
        with gr.Column(scale=1, min_width=260):
            in_glb = gr.Model3D(label="Input GLB", elem_id="in_glb")
        with gr.Column(scale=1, min_width=260):
            seg_glb = gr.Model3D(label="Processed GLB", elem_id="seg_glb")
    # Bottom row: 2D map / examples / split options | split result.
    with gr.Row():
        with gr.Column(scale=1, min_width=260):
            # Optional 2D segmentation map, either uploaded or generated by the
            # app; _update_img_box toggles the image box when the mode changes.
            with gr.Accordion("2D Segmentation Map (Optional)", open=False):
                map_mode = gr.Radio(
                    choices=["Upload", "Generate (Use FLUX.2 to generate segmentation map)"],
                    value="Upload",
                    label="2D Map Mode",
                )
                # type="filepath": handlers receive a path on disk, not pixels.
                in_img = gr.Image(
                    label="2D Segmentation Map",
                    type="filepath",
                    elem_id="img",
                    interactive=True,
                )
            seg_btn = gr.Button("Process", variant="primary")
            # Examples sit directly under the Process button; each example fills
            # both the GLB input and the 2D map (pairs built by _collect_examples).
            if FULL_SEG_EXAMPLES:
                gr.Examples(
                    examples=FULL_SEG_EXAMPLES,
                    inputs=[in_glb, in_img],
                    label="Examples",
                    examples_per_page=3,
                    cache_examples=False,
                )
            else:
                gr.Markdown(f"**No examples found** in: `{EXAMPLES_DIR}` (expected: `*.glb` + same-name `*.png`).")
            # Options passed to splitter.split_glb_by_texture_palette_rgb by
            # run_refine_segmentation (see refine_btn.click below); each
            # control is followed by a short help note.
            with gr.Accordion("Advanced segmentation options", open=False):
                # Default each control to the splitter module's constant of the
                # same name when it defines one, else to the literal fallback.
                def _g(name, default):
                    return getattr(splitter, name, default)
                color_quant_step = gr.Slider(
                    1, 64, value=_g("COLOR_QUANT_STEP", 16), step=1, label="COLOR_QUANT_STEP"
                )
                gr.Markdown(
                    "*COLOR_QUANT_STEP controls the RGB quantization step, where a larger value merges similar colors more aggressively and a smaller value preserves finer color differences.*"
                )
                palette_sample_pixels = gr.Number(
                    value=_g("PALETTE_SAMPLE_PIXELS", 2_000_000), precision=0, label="PALETTE_SAMPLE_PIXELS"
                )
                gr.Markdown(
                    "*PALETTE_SAMPLE_PIXELS sets the maximum number of sampled pixels used to estimate the palette, where more samples improve stability but increase runtime.*"
                )
                palette_min_pixels = gr.Number(
                    value=_g("PALETTE_MIN_PIXELS", 500), precision=0, label="PALETTE_MIN_PIXELS"
                )
                gr.Markdown(
                    "*PALETTE_MIN_PIXELS specifies the minimum pixel count required to keep a color in the palette, where a higher threshold suppresses noise but may discard small parts.*"
                )
                palette_max_colors = gr.Number(
                    value=_g("PALETTE_MAX_COLORS", 256), precision=0, label="PALETTE_MAX_COLORS"
                )
                gr.Markdown(
                    "*PALETTE_MAX_COLORS limits the maximum number of colors retained in the palette, where a larger limit yields finer partitions and a smaller limit enforces stronger merging.*"
                )
                palette_merge_dist = gr.Number(
                    value=_g("PALETTE_MERGE_DIST", 32), precision=0, label="PALETTE_MERGE_DIST"
                )
                gr.Markdown(
                    "*PALETTE_MERGE_DIST defines the distance threshold for merging nearby palette colors in RGB space, where a larger threshold merges near duplicates more often and a smaller threshold keeps colors distinct.*"
                )
                samples_per_face = gr.Dropdown(
                    choices=[1, 4], value=_g("SAMPLES_PER_FACE", 4), label="SAMPLES_PER_FACE"
                )
                gr.Markdown(
                    "*SAMPLES_PER_FACE sets the number of UV samples per triangle used for label voting, where more samples improve robustness near boundaries but increase computation.*"
                )
                flip_v = gr.Checkbox(value=_g("FLIP_V", True), label="FLIP_V")
                gr.Markdown(
                    "*FLIP_V toggles whether the V coordinate is flipped to match common glTF texture conventions, and you should disable it only if the texture appears vertically inverted.*"
                )
                uv_wrap_repeat = gr.Checkbox(value=_g("UV_WRAP_REPEAT", True), label="UV_WRAP_REPEAT")
                gr.Markdown(
                    "*UV_WRAP_REPEAT selects how out of range UVs are handled by either repeating via modulo or clamping to the unit interval, and repeating is typically preferred for tiled textures.*"
                )
                transition_conf_thresh = gr.Slider(
                    0.25, 1.0, value=float(_g("TRANSITION_CONF_THRESH", 1.0)), step=0.25, label="TRANSITION_CONF_THRESH"
                )
                gr.Markdown(
                    "*TRANSITION_CONF_THRESH sets the confidence threshold for transition handling, where a higher value makes refinement more conservative and a lower value enables more aggressive smoothing.*"
                )
                transition_prop_iters = gr.Number(
                    value=_g("TRANSITION_PROP_ITERS", 6), precision=0, label="TRANSITION_PROP_ITERS"
                )
                gr.Markdown(
                    "*TRANSITION_PROP_ITERS specifies the number of propagation iterations used in transition refinement, where more iterations strengthen diffusion effects but increase runtime.*"
                )
                transition_neighbor_min = gr.Number(
                    value=_g("TRANSITION_NEIGHBOR_MIN", 1), precision=0, label="TRANSITION_NEIGHBOR_MIN"
                )
                gr.Markdown(
                    "*TRANSITION_NEIGHBOR_MIN requires a minimum number of supporting neighbors to propagate a label, where a higher requirement is more conservative and a lower requirement is more permissive.*"
                )
                small_component_action = gr.Dropdown(
                    choices=["reassign", "drop"], value=_g("SMALL_COMPONENT_ACTION", "reassign"), label="SMALL_COMPONENT_ACTION"
                )
                gr.Markdown(
                    "*SMALL_COMPONENT_ACTION determines how small connected components are handled by either reassigning them to neighboring labels or dropping them entirely.*"
                )
                small_component_min_faces = gr.Number(
                    value=_g("SMALL_COMPONENT_MIN_FACES", 50), precision=0, label="SMALL_COMPONENT_MIN_FACES"
                )
                gr.Markdown(
                    "*SMALL_COMPONENT_MIN_FACES defines the face count threshold used to classify a component as small, where a higher threshold merges or removes more fragments and a lower threshold preserves more small parts.*"
                )
                postprocess_iters = gr.Number(
                    value=_g("POSTPROCESS_ITERS", 3), precision=0, label="POSTPROCESS_ITERS"
                )
                gr.Markdown(
                    "*POSTPROCESS_ITERS sets the number of post processing iterations, where more iterations produce stronger cleanup at the cost of additional computation.*"
                )
                min_faces_per_part = gr.Number(
                    value=_g("MIN_FACES_PER_PART", 1), precision=0, label="MIN_FACES_PER_PART"
                )
                gr.Markdown(
                    "*MIN_FACES_PER_PART enforces a minimum number of faces per exported part, where a larger value filters tiny outputs and a smaller value retains fine components.*"
                )
                bake_transforms = gr.Checkbox(value=_g("BAKE_TRANSFORMS", True), label="BAKE_TRANSFORMS")
                gr.Markdown(
                    "*BAKE_TRANSFORMS controls whether scene graph transforms are baked into geometry before splitting, where enabling it improves consistency in world space and disabling it preserves node transforms.*"
                )
        with gr.Column(scale=1, min_width=260):
            refine_btn = gr.Button("Segment", variant="secondary")
            part_glb = gr.Model3D(label="Segmented GLB", elem_id="part_glb")
    # Path of the last GLB produced by run_seg; consumed by the "Segment" step.
    seg_glb_state = gr.State(None)
    # Switching the map mode resets the image box and toggles its interactivity.
    map_mode.change(
        fn=_update_img_box,
        inputs=[map_mode],
        outputs=[in_img],
    )
    # "Process": segment the whole GLB; the third output writes the 2D map that
    # was actually used (or None) back into the image box.
    seg_btn.click(
        fn=run_seg,
        inputs=[in_glb, map_mode, in_img],
        outputs=[seg_glb, seg_glb_state, in_img],
    )
    # "Segment": split the processed GLB into parts using the advanced options.
    refine_btn.click(
        fn=run_refine_segmentation,
        inputs=[
            seg_glb_state,
            color_quant_step,
            palette_sample_pixels,
            palette_min_pixels,
            palette_max_colors,
            palette_merge_dist,
            samples_per_face,
            flip_v,
            uv_wrap_repeat,
            transition_conf_thresh,
            transition_prop_iters,
            transition_neighbor_min,
            small_component_action,
            small_component_min_faces,
            postprocess_iters,
            min_faces_per_part,
            bake_transforms
        ],
        outputs=[part_glb],
    )
def _main():
    """Load the models, preload the 2D-map checkpoint, then start the Gradio server."""
    inf.PIPE.load_all_models()
    # Preload the checkpoint used whenever a 2D map is uploaded or generated.
    inf.PIPE.load_ckpt_if_needed(CKPT_W_2D_MAP)
    demo.launch()


if __name__ == "__main__":
    _main()