# IScene-demo / interactive_demo.py
# Last change: commit 7f44940 ("update") by LuLing.
"""Interactive I-Scene demo.
Run from the repository root:
python interactive_demo.py
"""
from __future__ import annotations
import argparse
import os
import uuid
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any
os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
import gradio as gr
import numpy as np
import torch
from gradio_image_prompter import ImagePrompter
from gradio_litmodel3d import LitModel3D
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor
from iscene.inference.inferencer import ISceneInferencer
# Repository root: the directory containing this script.
REPO_ROOT = Path(__file__).resolve().parent
# Default I-Scene model id. MODEL_ID / BASE_MODEL_ID are reassigned by main() from the CLI flags.
DEFAULT_MODEL = "LuLing/IScene"
MODEL_ID = DEFAULT_MODEL
BASE_MODEL_ID: str | None = None
# Initial values of the "Generation settings" controls.
DEFAULT_SEED = 43
DEFAULT_SIMPLIFY = 0.95
# Generated scenes go under outputs/demo/; uploaded/resized images and SAM masks under outputs/demo/_uploads/.
DEFAULT_OUTPUT_ROOT = REPO_ROOT / "outputs" / "demo"
UPLOAD_ROOT = DEFAULT_OUTPUT_ROOT / "_uploads"
# (width, height) that resize_prompt_image() resizes every uploaded image to.
TARGET_SIZE = (512, 512)
# Device/dtype the SAM model is moved to: bfloat16 on CUDA, float32 on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
# SAM dropdown label -> Hugging Face model id.
SAM_MODELS = {
    "sam-vit-huge (best quality, 636M)": "facebook/sam-vit-huge",
    "sam-vit-large (balanced, 308M)": "facebook/sam-vit-large",
    "sam-vit-base (fastest, 91M)": "facebook/sam-vit-base",
}
# Markdown header rendered at the top of the demo page.
MARKDOWN = """
# I-Scene Interactive Demo
Generate a 3D scene from one image.
**We notice some instability problems caused by huggingface space. We suggest serious users run this demo locally.**
Workflow:
1. Pick an example, or upload an image and draw boxes around objects.
2. Use the example mask, or click **Run SAM Segmentation** to create a mask.
3. Click **Generate Gaussian Splatting Preview** to create and preview `scene_pred.ply`.
4. Click **Generate GLB** only when you need mesh assets.
5. To save each instance in the scene, run the inference code with the same RGB/mask; `run_inference.py` writes per-instance assets alongside the scene output.
Note: The first run may be slow because the model checkpoint needs to be downloaded and cached.
"""
# Example RGB images (relative to examples/) in display order. Entries whose file
# or matching *_seg / *-seg mask is missing are skipped by _discover_examples().
EXAMPLE_ORDER = [
    "Scenethesis/SAM-3D-testing-case_rgb.png",
    "Gen3DSR/Gen3DSR_scene1_rgb.png",
    "MIDI-example/cartoon_style_07_rgb.png",
    "Scenethesis/children_playroom2_rgb.png",
    "Scenethesis/scenethesis-reading-corner-rgb.png",
    "DL3DV/DL3DV-garden-rgb.png",
    "DL3DV/DL3DV-table-chair-set-rgb.png",
    "DL3DV/DL3DV-tables-rgb.png",
    "outdoor/scene_beach2_rgb.png",
]
def _discover_examples() -> list[tuple[str, Path, Path]]:
    """Collect the bundled examples as ``(label, rgb_path, seg_path)`` triples.

    Walks EXAMPLE_ORDER under ``<repo>/examples``. An entry is kept only when
    the RGB file exists and a sibling mask exists. The mask name is derived by
    swapping ``_rgb`` -> ``_seg`` (checked first) or ``-rgb`` -> ``-seg``.
    Labels look like ``"<folder> / <case name>"``.
    """
    root = REPO_ROOT / "examples"
    found: list[tuple[str, Path, Path]] = []
    for relative in EXAMPLE_ORDER:
        image = root / relative
        if not image.exists():
            continue
        name = image.name
        if "_rgb" in name:
            mask = image.with_name(name.replace("_rgb", "_seg"))
        elif "-rgb" in name:
            mask = image.with_name(name.replace("-rgb", "-seg"))
        else:
            # No recognizable RGB suffix, so no mask can be derived.
            continue
        if not mask.exists():
            continue
        folder = image.relative_to(root).parent.as_posix()
        case = image.stem.replace("_rgb", "").replace("-rgb", "")
        found.append((f"{folder} / {case}", image, mask))
    return found
# (label, rgb_path, seg_path) for every bundled example found on disk; computed once at import time.
EXAMPLES = _discover_examples()
# Rows for gr.Examples: [rgb_path, mask_path] as strings.
EXAMPLE_ROWS = [[str(rgb), str(mask)] for _, rgb, mask in EXAMPLES]
@dataclass
class DemoRunState:
    """Inputs of the last Gaussian-splatting preview, reused by the GLB export.

    run_gaussian_preview() creates one and returns it into the ``run_state``
    gr.State. run_glb_export() reads it back to re-run inference on the same
    files and output directory.
    """

    rgb_path: str  # saved RGB image handed to the inferencer
    mask_path: str  # instance-segmentation mask handed to the inferencer
    output_dir: str  # directory the preview (and later the GLB) is written to
    seed: int  # seed of the preview; reused by the GLB export
    simplify: float  # ratio at preview time (the GLB export uses the current slider value instead)
# Process-wide model caches so repeated requests reuse already-loaded models.
# There is no locking, so concurrent first requests may each load a model.
# SAM: Hugging Face model id -> (processor, model).
_sam_cache: dict[str, tuple[AutoProcessor, AutoModelForMaskGeneration]] = {}
# I-Scene: (MODEL_ID, BASE_MODEL_ID or "") -> inferencer.
_inferencer_cache: dict[tuple[str, str], ISceneInferencer] = {}
def _make_session_dir(request: gr.Request | None, root: Path = UPLOAD_ROOT) -> Path:
    """Create (if needed) and return ``root/<session id>``.

    The Gradio session hash is used when available. Otherwise a fresh random
    10-hex-character id is generated, so request-less calls get a new
    directory every time.
    """
    session_id = getattr(request, "session_hash", None)
    if not session_id:
        session_id = uuid.uuid4().hex[:10]
    target = root / session_id
    target.mkdir(parents=True, exist_ok=True)
    return target
def _timestamped_output_dir(request: gr.Request | None) -> Path:
    """Return (without creating it) ``DEFAULT_OUTPUT_ROOT/<YYYYmmdd_HHMMSS>_<session id>``.

    Falls back to a random 10-hex-character id when the request has no
    session hash.
    """
    session_id = getattr(request, "session_hash", None)
    if not session_id:
        session_id = uuid.uuid4().hex[:10]
    stamp = f"{datetime.now():%Y%m%d_%H%M%S}"
    return DEFAULT_OUTPUT_ROOT / "_".join((stamp, session_id))
def _get_prompt_image(image_prompts: Any) -> Image.Image | None:
    """Extract the RGB image from an ImagePrompter value.

    Args:
        image_prompts: Either the ImagePrompter payload dict (image under the
            ``"image"`` key) or a bare image. The image may be a PIL image, a
            numpy array, or anything ``PIL.Image.open`` accepts (e.g. a path).

    Returns:
        An RGB copy of the image, or None when no image is present.
    """
    if image_prompts is None:
        return None
    image = image_prompts.get("image") if isinstance(image_prompts, dict) else image_prompts
    if image is None:
        return None
    if isinstance(image, Image.Image):
        return image.convert("RGB")
    if isinstance(image, np.ndarray):
        # Accept decoded arrays too (e.g. if the component is switched to type="numpy").
        return Image.fromarray(image).convert("RGB")
    # Close the file handle deterministically; convert() returns an independent,
    # fully loaded copy that stays valid after the file is closed.
    with Image.open(image) as opened:
        return opened.convert("RGB")
def _save_prompt_rgb(image_prompts: Any, request: gr.Request | None) -> Path:
    """Save the prompter's RGB image as ``input_rgb.png`` in the session upload directory.

    Returns:
        Path of the written PNG.

    Raises:
        gr.Error: if no image has been uploaded.
    """
    rgb = _get_prompt_image(image_prompts)
    if rgb is None:
        raise gr.Error("Please upload an RGB image.")
    destination = _make_session_dir(request) / "input_rgb.png"
    rgb.save(destination)
    return destination
def _resolve_mask_path(mask_path: str | None) -> Path:
    """Validate the instance-mask path selected via an example or SAM.

    Returns:
        The mask path as a ``Path``.

    Raises:
        gr.Error: if no path is set, the path does not exist, or it is not a
            regular file (e.g. a directory).
    """
    if not mask_path:
        raise gr.Error("Please choose an example or run SAM segmentation first.")
    path = Path(mask_path)
    if not path.exists():
        raise gr.Error(f"Mask file does not exist: {path}")
    if not path.is_file():
        # A directory would otherwise be passed on to the inferencer as the mask.
        raise gr.Error(f"Mask path is not a file: {path}")
    return path
def _get_inferencer() -> ISceneInferencer:
    """Return the I-Scene inferencer for the current MODEL_ID / BASE_MODEL_ID.

    The inferencer is loaded via ``ISceneInferencer.from_pretrained`` on first
    use and cached per (model, base model) pair.
    """
    cache_key = (MODEL_ID, BASE_MODEL_ID or "")
    try:
        return _inferencer_cache[cache_key]
    except KeyError:
        pass
    inferencer = ISceneInferencer.from_pretrained(MODEL_ID, base_model_id=BASE_MODEL_ID)
    _inferencer_cache[cache_key] = inferencer
    return inferencer
def _get_sam_model(model_choice: str) -> tuple[AutoProcessor, AutoModelForMaskGeneration]:
    """Return the (processor, model) pair for a SAM dropdown label, loading it once.

    The model is moved to DEVICE/DTYPE and switched to eval mode before being
    cached. Raises KeyError for a label not present in SAM_MODELS.
    """
    model_id = SAM_MODELS[model_choice]
    if model_id not in _sam_cache:
        processor = AutoProcessor.from_pretrained(model_id)
        model = AutoModelForMaskGeneration.from_pretrained(model_id).to(DEVICE, DTYPE)
        model.eval()
        _sam_cache[model_id] = (processor, model)
    return _sam_cache[model_id]
def _boxes_from_prompts(image_prompts: Any) -> list[list[list[int]]]:
    """Convert the boxes drawn in ImagePrompter into SAM ``input_boxes``.

    Per the gradio_image_prompter format, every prompt is
    ``[x1, y1, type1, x2, y2, type2]``. A drawn box is ``[x1, y1, 2, x2, y2, 3]``;
    a point click is ``[x, y, 0|1, 0, 0, 4]``. Point clicks are skipped, because
    reading them as boxes yields a bogus box from the image origin to the
    clicked pixel. Malformed (too short) entries are skipped as well.
    Corners are normalized to ``[x_min, y_min, x_max, y_max]``, and zero-width
    or zero-height boxes are dropped.

    Returns:
        ``[[box, ...]]``: a single batch entry holding all boxes for one image.

    Raises:
        gr.Error: if nothing was drawn or no valid box remains.
    """
    points = image_prompts.get("points", []) if isinstance(image_prompts, dict) else []
    if not points:
        raise gr.Error("Please draw at least one box before running SAM segmentation.")
    boxes = []
    for prompt in points:
        if len(prompt) < 5:
            continue
        if len(prompt) >= 6 and int(prompt[5]) == 4:
            # Point prompt (second slot unused), not a box.
            continue
        x1, y1, x2, y2 = int(prompt[0]), int(prompt[1]), int(prompt[3]), int(prompt[4])
        x_min, x_max = sorted((x1, x2))
        y_min, y_max = sorted((y1, y2))
        if x_max <= x_min or y_max <= y_min:
            continue
        boxes.append([x_min, y_min, x_max, y_max])
    if not boxes:
        raise gr.Error("No valid boxes were drawn.")
    return [boxes]
def _mask_to_polygon(mask: np.ndarray) -> list[list[int]] | None:
    """Return the largest-area external contour of ``mask`` as a list of [x, y] points.

    Returns:
        The contour vertices, or None when the mask has no foreground.
    """
    import cv2

    # findContours returns (contours, hierarchy) in OpenCV 4 but
    # (image, contours, hierarchy) in OpenCV 3; [-2] selects contours in both.
    contours = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
        return None
    largest = max(contours, key=cv2.contourArea)
    return largest.reshape(-1, 2).tolist()
def _polygon_to_mask(polygon: list[list[int]], image_shape: tuple[int, int]) -> np.ndarray:
    """Rasterize ``polygon`` into a uint8 array of ``image_shape``: 1 inside, 0 elsewhere."""
    import cv2

    canvas = np.zeros(image_shape, dtype=np.uint8)
    vertices = np.array(polygon, dtype=np.int32)
    cv2.fillPoly(canvas, [vertices], color=(1,))
    return canvas
def _refine_masks(
    masks: torch.Tensor,
    *,
    polygon_refinement: bool,
    mask_threshold: float,
) -> list[np.ndarray]:
    """Binarize SAM masks and optionally smooth each one through its outer polygon.

    A 5-D input keeps only index 0 of its third axis. A 4-D input is then
    averaged over axis 1 (presumably SAM's several mask proposals per box).
    The values are compared with ``> mask_threshold``, giving one uint8 0/1
    array per entry of the first axis. With ``polygon_refinement``, each mask
    with foreground is replaced by its filled largest external contour.

    NOTE(review): if the masks were already binarized upstream (SAM's
    ``post_process_masks`` binarizes by default), averaged values lie in
    [0, 1], so any negative ``mask_threshold`` selects every pixel. Confirm
    the intended threshold range.
    """
    values = masks.detach().cpu().float()
    if values.ndim == 5:
        values = values[:, :, 0]
    if values.ndim == 4:
        values = values.mean(dim=1)
    binary = (values > mask_threshold).numpy().astype(np.uint8)
    result = list(binary)
    if not polygon_refinement:
        return result
    for i, current in enumerate(result):
        outline = _mask_to_polygon(current)
        if outline is not None:
            result[i] = _polygon_to_mask(outline, current.shape)
    return result
def _palette() -> list[int]:
    """Build a flat 256-entry RGB palette (768 ints) for mode "P" label images.

    Entry 0 (background) is black. Entries 1..255 advance the hue by
    0.618033988749895 (the golden-ratio conjugate) each step to spread hues.
    Each hue goes through a six-sector HSV-style conversion with chroma 0.81,
    and an offset of 0.09 is added to every channel before scaling to 0-255.
    """
    golden = 0.618033988749895
    chroma = 0.81
    offset = 0.09
    # Channel order per hue sector, as indices into (chroma, secondary, 0).
    layouts = ((0, 1, 2), (1, 0, 2), (2, 0, 1), (2, 1, 0), (1, 2, 0), (0, 2, 1))
    flat = [0, 0, 0]
    hue = 0.0
    for _ in range(255):
        hue = (hue + golden) % 1.0
        sector_pos = hue * 6.0
        secondary = chroma * (1 - abs(sector_pos % 2 - 1))
        components = (chroma, secondary, 0)
        # Clamp guards against hue * 6.0 rounding up to exactly 6.0.
        sector = min(int(sector_pos), 5)
        for channel in layouts[sector]:
            flat.append(int((components[channel] + offset) * 255))
    return flat
def _label_mask_to_pil(label_map: np.ndarray) -> Image.Image:
    """Encode an integer instance-label map as a PIL image for saving as PNG.

    When the largest label is below 256, the result is a palettized ("P")
    image colored by _palette(). Otherwise labels are packed into RGB: low
    byte in R, next byte in G, B left at 0, so labels above 65535 are truncated.
    """
    largest = label_map.max(initial=0)
    if largest < 256:
        indexed = Image.fromarray(label_map.astype(np.uint8), mode="P")
        indexed.putpalette(_palette())
        return indexed
    packed = np.zeros(label_map.shape + (3,), dtype=np.uint8)
    packed[..., 0] = label_map & 0xFF
    packed[..., 1] = (label_map >> 8) & 0xFF
    return Image.fromarray(packed, mode="RGB")
def resize_prompt_image(image_prompts: Any) -> Any:
    """Resize a freshly uploaded prompter image to TARGET_SIZE and serve the copy back.

    The resized image (aspect ratio not preserved) is written to
    ``UPLOAD_ROOT/prompt_<random id>.png``. It is returned as a new prompter
    value with any drawn boxes cleared. Without an image, the input value is
    returned unchanged.
    """
    source = _get_prompt_image(image_prompts)
    if source is None:
        return image_prompts
    resized = source.resize(TARGET_SIZE, Image.Resampling.LANCZOS)
    UPLOAD_ROOT.mkdir(parents=True, exist_ok=True)
    filename = "prompt_" + uuid.uuid4().hex[:10] + ".png"
    saved_to = UPLOAD_ROOT / filename
    resized.save(saved_to)
    return {"image": str(saved_to), "points": []}
def reset_uploaded_image(image_prompts: Any) -> tuple[Any, None, str]:
    """Upload handler: resize the new image and clear the stale mask preview and mask path."""
    refreshed = resize_prompt_image(image_prompts)
    cleared_preview = None
    cleared_mask_path = ""
    return refreshed, cleared_preview, cleared_mask_path
def _coerce_file_path(value: Any) -> str:
    """Normalize a Gradio file value (dict payload or plain path) to a path string.

    For dicts, the first truthy value among ``path``, ``name`` and ``image``
    wins. Falsy inputs (None, "", missing keys) yield "".
    """
    if not isinstance(value, dict):
        return str(value or "")
    for key in ("path", "name", "image"):
        candidate = value.get(key)
        if candidate:
            return str(candidate)
    return ""
def _raw_example_mask_path(mask_path: Any) -> str:
    """Map an example mask value back to the original file under ``examples/``.

    The value received from gr.Examples may point at a copied file
    (presumably in Gradio's cache), so it is matched to EXAMPLES by file name.
    Values matching no example are returned as their coerced path string.
    """
    coerced = _coerce_file_path(mask_path)
    wanted = Path(coerced).name
    match = next((raw for _, _, raw in EXAMPLES if raw.name == wanted), None)
    return str(match) if match is not None else coerced
def load_example_pair(rgb_path: Any, mask_path: Any) -> tuple[dict[str, Any], str, str]:
    """Examples click handler: load the RGB into the prompter and select its mask.

    Returns:
        The prompter value (no boxes), the mask path shown in the preview, and
        the original on-disk mask path stored for generation.
    """
    prompter_value = {"image": _coerce_file_path(rgb_path), "points": []}
    preview_mask = _coerce_file_path(mask_path)
    generation_mask = _raw_example_mask_path(mask_path)
    return prompter_value, preview_mask, generation_mask
@torch.no_grad()
def run_segmentation(
    image_prompts: Any,
    model_choice: str,
    polygon_refinement: bool,
    mask_threshold: float,
    request: gr.Request,
) -> tuple[str, str]:
    """Segment the boxed objects with SAM and save an instance label mask.

    Each mask returned for the drawn boxes gets label 1, 2, ... in order.
    Where masks overlap, the later label wins. The label map is saved as
    ``sam_mask.png`` in the session upload directory.

    Returns:
        The saved mask path twice: for the mask preview and for the hidden
        mask-path textbox.

    Raises:
        gr.Error: if no image is loaded or no usable box was drawn.
    """
    rgb = _get_prompt_image(image_prompts)
    if rgb is None:
        raise gr.Error("Please upload an RGB image before running segmentation.")
    input_boxes = _boxes_from_prompts(image_prompts)
    processor, segmentator = _get_sam_model(model_choice)
    batch = processor(images=rgb, input_boxes=input_boxes, return_tensors="pt")
    batch = batch.to(segmentator.device, segmentator.dtype)
    predictions = segmentator(**batch)
    per_image = processor.post_process_masks(
        masks=predictions.pred_masks,
        original_sizes=batch.original_sizes,
        reshaped_input_sizes=batch.reshaped_input_sizes,
    )
    instance_masks = _refine_masks(
        per_image[0],
        polygon_refinement=polygon_refinement,
        mask_threshold=mask_threshold,
    )
    label_map = np.zeros((rgb.height, rgb.width), dtype=np.uint32)
    for label, instance in enumerate(instance_masks, start=1):
        label_map[instance > 0] = label
    encoded = _label_mask_to_pil(label_map)
    destination = _make_session_dir(request) / "sam_mask.png"
    encoded.save(destination)
    torch.cuda.empty_cache()
    return str(destination), str(destination)
def run_gaussian_preview(
    image_prompts: Any,
    mask_path: str | None,
    seed: int,
    simplify: float,
    output_dir_text: str,
    request: gr.Request,
) -> tuple[str, dict[str, Any], dict[str, Any], str, DemoRunState]:
    """Generate the Gaussian-splatting scene ``scene_pred.ply`` for the current inputs.

    Args:
        image_prompts: ImagePrompter value holding the RGB image.
        mask_path: Instance-mask path chosen via an example or SAM.
        seed: Generation seed.
        simplify: Simplify ratio forwarded to the inferencer.
        output_dir_text: Optional output directory. Blank (or None) selects
            ``outputs/demo/<timestamp>_<session>``. Surrounding whitespace is ignored.
        request: Gradio request, used for per-session directories.

    Returns:
        The PLY path for the 3D preview, an enabled PLY download-button update,
        a reset GLB download-button update, an empty progress-bar string, and
        the DemoRunState later reused by the GLB export.

    Raises:
        gr.Error: on missing inputs, or if no ``scene_pred.ply`` was produced.
    """
    rgb_path = _save_prompt_rgb(image_prompts, request)
    mask_file = _resolve_mask_path(mask_path)
    # Use the stripped text for the path as well as for the emptiness check, so
    # " out/dir " does not create a directory whose name contains spaces.
    requested_dir = (output_dir_text or "").strip()
    # NOTE(review): this directory is user-controlled and written with
    # overwrite=True; on a shared deployment consider confining it to
    # DEFAULT_OUTPUT_ROOT.
    output_dir = Path(requested_dir).expanduser() if requested_dir else _timestamped_output_dir(request)
    output_dir.mkdir(parents=True, exist_ok=True)
    inferencer = _get_inferencer()
    inferencer.infer_and_save_scene(
        scene_rgb_path=rgb_path,
        instance_seg_path=mask_file,
        output_dir=output_dir,
        overwrite=True,
        save_dbg=False,
        simplify=float(simplify),
        only_3dgs=True,
        seed=int(seed),
    )
    scene_ply = output_dir / "scene_pred.ply"
    if not scene_ply.exists():
        raise gr.Error(f"Generation finished but scene_pred.ply was not found in {output_dir}")
    state = DemoRunState(
        rgb_path=str(rgb_path),
        mask_path=str(mask_file),
        output_dir=str(output_dir),
        seed=int(seed),
        simplify=float(simplify),
    )
    torch.cuda.empty_cache()
    return (
        str(scene_ply),
        gr.update(value=str(scene_ply), interactive=True),
        gr.update(value=None, interactive=False),
        "",
        state,
    )
def _progress_bar(percent: int) -> str:
    """Render a horizontal HTML progress bar.

    ``percent`` is truncated with ``int()`` and clamped to 0-100 before being
    used as the width of the filled inner bar.
    """
    clamped = int(percent)
    if clamped < 0:
        clamped = 0
    elif clamped > 100:
        clamped = 100
    track = (
        '<div style="height: 14px; width: 100%; background: #ece7dc; border-radius: 999px; '
        'overflow: hidden; border: 1px solid #d8cbb7;">'
    )
    fill = (
        f'<div style="height: 100%; width: {clamped}%; background: linear-gradient(90deg, #b77a2f, #e0b15a); '
        'transition: width 0.4s ease;"></div>'
    )
    return "\n" + "\n".join((track, fill, "</div>")) + "\n"
def run_glb_export(
    state: DemoRunState | dict[str, Any] | None,
    simplify: float,
) -> Any:
    """Export the last previewed scene as a GLB mesh, streaming UI updates.

    This is a generator. Each yield is ``(GLB download-button update,
    progress-bar HTML, GLB preview value)``. Inference is re-run with the
    RGB/mask/output directory and seed stored in ``state`` and the current
    ``simplify`` value. If a step fails after progress has started, the
    outputs and progress bar are reset before the error is re-raised, so the
    UI does not stay stuck mid-way.

    Raises:
        gr.Error: if no preview has been generated yet, or no ``scene_pred.glb``
            was produced.
    """
    if state is None:
        raise gr.Error("Please run GS preview first so the demo knows which RGB/mask/output directory to use.")
    if isinstance(state, dict):
        state = DemoRunState(**state)
    output_dir = Path(state.output_dir)
    yield gr.update(value=None, interactive=False), _progress_bar(5), gr.update(value=None)
    try:
        inferencer = _get_inferencer()
        yield gr.update(value=None, interactive=False), _progress_bar(15), gr.update(value=None)
        inferencer.infer_and_save_scene(
            scene_rgb_path=state.rgb_path,
            instance_seg_path=state.mask_path,
            output_dir=output_dir,
            overwrite=True,
            save_dbg=False,
            simplify=float(simplify),
            only_3dgs=False,
            seed=int(state.seed),
        )
        scene_glb = output_dir / "scene_pred.glb"
        if not scene_glb.exists():
            raise gr.Error(f"GLB export finished but scene_pred.glb was not found in {output_dir}")
    except Exception:
        # Clear the partial progress (same reset values as clear_glb_outputs) before surfacing the error.
        yield gr.update(value=None, interactive=False), "", gr.update(value=None)
        raise
    torch.cuda.empty_cache()
    yield gr.update(value=str(scene_glb), interactive=True), _progress_bar(100), str(scene_glb)
def clear_glb_outputs() -> tuple[dict[str, Any], str, None, dict[str, Any]]:
    """Return reset values: disabled GLB download, blank progress HTML, None, and a cleared GLB preview."""
    disabled_download = gr.update(value=None, interactive=False)
    cleared_preview = gr.update(value=None)
    return disabled_download, "", None, cleared_preview
def clear_generation_outputs() -> tuple[dict[str, Any], dict[str, Any], dict[str, Any], str, None, dict[str, Any]]:
    """Reset every generation output before a new Gaussian-splatting run.

    The order matches the outputs wired in build_demo(): 3D preview, PLY
    download, GLB download, GLB progress HTML, run state, GLB preview.
    """
    scene_preview = gr.update(value=None)
    ply_download = gr.update(value=None, interactive=False)
    glb_download = gr.update(value=None, interactive=False)
    glb_preview_reset = gr.update(value=None)
    return scene_preview, ply_download, glb_download, "", None, glb_preview_reset
def build_demo() -> gr.Blocks:
    """Build the Gradio UI and wire its events.

    Layout: a left column with the input image, SAM segmentation and
    generation settings; a right column with the Gaussian-splatting preview,
    GLB export and download buttons; then the examples gallery.

    Returns:
        The configured (not yet launched) gr.Blocks app.
    """
    # delete_cache=(frequency, age) in seconds, see the gr.Blocks docs.
    with gr.Blocks(title="I-Scene Interactive Demo", delete_cache=(3600, 3600)) as demo:
        gr.Markdown(MARKDOWN)
        # Holds the DemoRunState of the last GS preview; read by the GLB export.
        run_state = gr.State(None)
        with gr.Row():
            # Left column: inputs, segmentation and generation settings.
            with gr.Column(scale=1):
                image_prompts = ImagePrompter(
                    label="RGB image (upload, then optionally draw boxes for SAM)",
                    type="pil",
                    height=520,
                )
                with gr.Row():
                    segment_button = gr.Button("Run SAM Segmentation", variant="secondary")
                with gr.Accordion("Segmentation settings", open=False):
                    sam_model = gr.Dropdown(
                        choices=list(SAM_MODELS.keys()),
                        value="sam-vit-huge (best quality, 636M)",
                        label="SAM model",
                    )
                    mask_threshold = gr.Slider(
                        minimum=-1.0,
                        maximum=1.0,
                        value=0.0,
                        step=0.05,
                        label="Mask threshold",
                    )
                    polygon_refinement = gr.Checkbox(
                        label="Polygon refinement",
                        value=False,
                    )
                sam_mask_preview = gr.Image(
                    label="Instance mask",
                    type="filepath",
                    format="png",
                    height=260,
                )
                # Hidden: path of the mask used for generation (original example mask or SAM output).
                mask_path_value = gr.Textbox(visible=False)
                with gr.Accordion("Generation settings", open=False):
                    seed = gr.Number(label="Seed", value=DEFAULT_SEED, precision=0)
                    simplify = gr.Slider(
                        minimum=0.5,
                        maximum=1.0,
                        value=DEFAULT_SIMPLIFY,
                        step=0.01,
                        label="GLB mesh simplify ratio",
                    )
                    output_dir = gr.Textbox(
                        label="Output directory (optional)",
                        placeholder="Leave empty to use outputs/demo/<timestamp>_<session>",
                    )
                generate_gs_button = gr.Button("Generate Gaussian Splatting Preview", variant="primary", size="lg")
            # Right column: 3D preview, GLB export and downloads.
            with gr.Column(scale=1):
                preview = LitModel3D(
                    label="3D preview",
                    exposure=10.0,
                    height=520,
                )
                download_gs = gr.DownloadButton(
                    label="Download Gaussian Splatting PLY",
                    interactive=False,
                )
                with gr.Row():
                    generate_glb_button = gr.Button("Generate GLB", variant="secondary")
                glb_progress = gr.HTML(value="")
                glb_preview = gr.Model3D(
                    label="GLB mesh preview",
                    clear_color=(0.98, 0.96, 0.91, 1.0),
                    display_mode="solid",
                    height=360,
                )
                download_glb = gr.DownloadButton(
                    label="Download Mesh GLB",
                    interactive=False,
                )
        # A new upload is resized and invalidates any previous mask.
        image_prompts.upload(
            reset_uploaded_image,
            inputs=[image_prompts],
            outputs=[image_prompts, sam_mask_preview, mask_path_value],
        )
        segment_button.click(
            run_segmentation,
            inputs=[image_prompts, sam_model, polygon_refinement, mask_threshold],
            outputs=[sam_mask_preview, mask_path_value],
        )
        # Clear stale outputs (including run_state) first, then run GS generation.
        generate_gs_button.click(
            clear_generation_outputs,
            outputs=[preview, download_gs, download_glb, glb_progress, run_state, glb_preview],
            show_progress="hidden",
        ).then(
            run_gaussian_preview,
            inputs=[
                image_prompts,
                mask_path_value,
                seed,
                simplify,
                output_dir,
            ],
            outputs=[preview, download_gs, download_glb, glb_progress, run_state],
            show_progress="full",
        )
        # run_glb_export is a generator that drives the custom HTML progress bar.
        generate_glb_button.click(
            run_glb_export,
            inputs=[run_state, simplify],
            outputs=[download_glb, glb_progress, glb_preview],
            show_progress="hidden",
        )
        # Hidden inputs that receive the example file paths passed to load_example_pair().
        example_rgb = gr.Image(label="RGB", type="filepath", visible=False)
        example_mask = gr.Image(label="Instance mask", type="filepath", visible=False)
        with gr.Row():
            gr.Examples(
                examples=EXAMPLE_ROWS,
                inputs=[example_rgb, example_mask],
                outputs=[image_prompts, sam_mask_preview, mask_path_value],
                fn=load_example_pair,
                cache_examples=False,
                label="Examples",
                run_on_click=True,
            )
    return demo
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the server address and the model selection."""
    parser = argparse.ArgumentParser(description=__doc__)
    # Options are registered in this order, which is also their --help order.
    options: list[tuple[str, dict[str, Any]]] = [
        ("--server_name", {"default": "0.0.0.0"}),
        ("--server_port", {"type": int, "default": 7860}),
        ("--share", {"action": "store_true"}),
        ("--model", {"default": DEFAULT_MODEL, "help": "I-Scene model id or local model package path."}),
        (
            "--base_model",
            {
                "default": None,
                "help": "Optional TRELLIS base model id or local mirror path. Defaults to the model package metadata.",
            },
        ),
    ]
    for flag, settings in options:
        parser.add_argument(flag, **settings)
    return parser.parse_args()
def main() -> None:
    """Entry point: apply CLI model overrides, create output folders, and launch the server."""
    global MODEL_ID, BASE_MODEL_ID
    cli = parse_args()
    MODEL_ID, BASE_MODEL_ID = cli.model, cli.base_model
    for folder in (DEFAULT_OUTPUT_ROOT, UPLOAD_ROOT):
        folder.mkdir(parents=True, exist_ok=True)
    app = build_demo()
    app.queue()
    app.launch(
        server_name=cli.server_name,
        server_port=cli.server_port,
        share=cli.share,
    )
# Launch the demo only when run as a script, not when imported.
if __name__ == "__main__":
    main()