# 3D-Fixer / app.py
# Last commit 902fc22 by JasonYinnnn: "longer duration"
# SPDX-FileCopyrightText: 2026 Ze-Xin Yin, Robot labs of Horizon Robotics, and D-Robotics
# SPDX-License-Identifier: Apache-2.0
# See the LICENSE file in the project root for full license information.
from gradio_image_prompter import ImagePrompter
import gradio as gr
import spaces
import os
import uuid
from typing import Any, List, Optional, Union
import cv2
import torch
import numpy as np
from PIL import Image
import trimesh
import random
import imageio
from einops import repeat
from huggingface_hub import snapshot_download
from moge.model.v2 import MoGeModel
from transformers import AutoModelForMaskGeneration, AutoProcessor
from scripts.grounding_sam import plot_segmentation, segment
import copy
import shutil
import time
from concurrent.futures import ThreadPoolExecutor
MARKDOWN = """
## Image to 3D Scene with [3D-Fixer](https://zx-yin.github.io/3dfixer/)
1. Upload an image, and draw bounding boxes for each instance by holding and dragging the mouse. Then click "Run Segmentation" to generate the segmentation result.
2. If you find the generated 3D scene satisfactory, download it by clicking the "Download scene GLB" button, and you can also download each islolated 3D instance.
3. In this implementation, we generate each instances one by one, and update the scene results at the "Generated GLB" area, besides, we display isolated instances below.
4. It may take some time to download the ckpts, and compile the gsplat. Thank you for your patience to wait. We recommend to deploy the demo locally.
"""
MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
EXAMPLE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets/example_data")
DTYPE = torch.float16
DEVICE = "cpu"
VALID_RATIO_THRESHOLD = 0.005
CROP_SIZE = 518
work_space = None
generated_object_map = {}
############## 3D-Fixer model
model_dir = 'HorizonRobotics/3D-Fixer'
local_dir = "./checkpoints/3D-Fixer"
os.makedirs(local_dir, exist_ok=True)
snapshot_download(repo_id=model_dir, local_dir=local_dir)
############## 3D-Fixer model
save_projected_colored_pcd = lambda pts, pts_color, fpath: trimesh.PointCloud(pts.reshape(-1, 3), pts_color.reshape(-1, 3)).export(fpath)
# Example rows for gr.Examples. Column order follows its `inputs` list:
# image_prompts, seg_image, seed, randomize_seed,
# num_inference_steps, guidance_scale, cfg_interval_start, cfg_interval_end, t_rescale.
# Only the scene folder, the seed and the CFG scale differ between examples.
EXAMPLES = [
    [
        {"image": f"assets/example_data/{scene}/rgb.png"},
        f"assets/example_data/{scene}/seg.png",
        example_seed,
        False,
        25, cfg_scale, 0.8, 1.0, 5.0,
    ]
    for scene, example_seed, cfg_scale in (
        ("scene1", 1024, 5.5),
        ("scene2", 1, 5.0),
        ("scene3", 1, 5.0),
        ("scene4", 42, 5.0),
        ("scene5", 1, 5.0),
        ("scene6", 1, 5.0),
    )
]
def cleanup_tmp(tmp_root: str = "./tmp", expire_seconds: int = 3600) -> None:
    """Delete stale sub-directories of ``tmp_root``.

    A sub-directory is stale when its own mtime is more than ``expire_seconds``
    in the past. Plain files directly under ``tmp_root`` are never touched, and a
    missing ``tmp_root`` is a no-op. Removal failures are printed and skipped, so
    one bad entry does not stop the sweep.

    Args:
        tmp_root: Root of the temporary directory tree.
        expire_seconds: Age threshold in seconds (default 3600, i.e. one hour).
    """
    root = os.path.abspath(tmp_root)
    if not os.path.isdir(root):
        return

    now_ts = time.time()
    for entry in os.listdir(root):
        candidate = os.path.join(root, entry)
        # Only sub-directories are cleaned; stray files are left alone.
        if not os.path.isdir(candidate):
            continue
        try:
            if now_ts - os.path.getmtime(candidate) > expire_seconds:
                shutil.rmtree(candidate, ignore_errors=False)
                print(f"[cleanup_tmp] removed old directory: {candidate}")
        except Exception as e:
            print(f"[cleanup_tmp] failed to remove {candidate}: {e}")
# Step 1: box-prompted SAM segmentation, run on the CPU.
def run_segmentation(
    image_prompts: Any,
    polygon_refinement: bool = True,
) -> tuple[Image.Image, dict]:
    """Segment every user-drawn prompt with SAM and start a fresh work space.

    Args:
        image_prompts: ImagePrompter value, a dict with ``"image"`` (PIL image)
            and ``"points"`` (one prompt row per drawn box).
        polygon_refinement: Forwarded to ``segment``.

    Returns:
        ``(seg_map_pil, work_space)``: the colour-coded map returned by
        ``plot_segmentation``, and a new work-space dict whose ``"dir"`` is a fresh
        directory under ``TMP_DIR`` that already contains ``mask.png``.

    Raises:
        gr.Error: If no image was uploaded or no prompt was drawn.
    """
    if not image_prompts or image_prompts.get("image") is None:
        raise gr.Error("Please upload an image first.")
    if not image_prompts.get("points"):
        raise gr.Error("No boxes provided for segmentation. Please draw at least one bounding box on the image.")

    rgb_image = image_prompts["image"].convert("RGB")
    # SAM stays on the CPU in float32 for this step.
    sam_segmentator.to(device="cpu", dtype=torch.float32)

    # Each prompt row is read as (x1, y1, _, x2, y2, ...); entries 2 and 5 are
    # skipped (presumably prompt-type flags of gradio_image_prompter - TODO confirm).
    boxes = [
        [
            [int(prompt[0]), int(prompt[1]), int(prompt[3]), int(prompt[4])]
            for prompt in image_prompts["points"]
        ]
    ]
    with torch.no_grad():
        detections = segment(
            sam_processor,
            sam_segmentator,
            rgb_image,
            boxes=[boxes],
            polygon_refinement=polygon_refinement,
        )
    seg_map_pil = plot_segmentation(rgb_image, detections)

    # Drop work spaces untouched for more than an hour before creating a new one.
    cleanup_tmp(TMP_DIR, expire_seconds=3600)
    work_space = {
        "dir": os.path.join(TMP_DIR, f"work_space_{uuid.uuid4()}"),
    }
    os.makedirs(work_space["dir"], exist_ok=True)
    seg_map_pil.save(os.path.join(work_space["dir"], "mask.png"))
    return seg_map_pil, work_space
@spaces.GPU
def run_depth_estimation(
    image_prompts: Any,
    seg_image: Union[str, Image.Image],
    work_space: dict,
) -> tuple[str, dict]:
    """Step 2: estimate depth with MoGe-2 and prepare the work space for generation.

    The input image is resized to a fixed 1024x1024 canvas. The segmentation map
    is validated and resized to the same canvas. Depth inference then runs, and
    the scene and foreground point clouds are projected. The foreground cloud
    defines the normalisation (``trans``, ``scale``) that every later export uses.

    Writes ``"c2w"``, ``"K"``, ``"depth_mask"``, ``"depth"``, ``"trans"`` and
    ``"scale"`` into ``work_space``; every tensor entry is moved to the CPU.

    Args:
        image_prompts: ImagePrompter value (dict with an ``"image"`` PIL image).
        seg_image: Colour-coded segmentation map (pure black = background), as a
            PIL image or a file path.
        work_space: Dict from step 1, or ``None`` to create a fresh one.

    Returns:
        ``(path of the exported scene point cloud "scene_pcd.glb", work_space)``.

    Raises:
        gr.Error: If the image or segmentation is missing, or no instance remains.
    """
    from threeDFixer.datasets.utils import (
        normalize_vertices,
        project2ply
    )
    global moge_v2_dpt_model

    if not image_prompts or image_prompts.get("image") is None:
        raise gr.Error("Please upload an image first.")
    if seg_image is None:
        raise gr.Error("Please run segmentation (step 1) first.")
    if isinstance(seg_image, str):
        seg_image = Image.open(seg_image)
    seg_image = seg_image.convert("RGB")

    rgb_image = image_prompts["image"].convert("RGB")
    # NOTE(review): the resize is unconditional and does not preserve the aspect
    # ratio; the downstream steps rely on this 1024x1024 canvas - confirm intent.
    rgb_image = rgb_image.resize((1024, 1024), Image.Resampling.LANCZOS)
    W, H = rgb_image.size

    # Build the instance masks before any GPU work so invalid input fails fast.
    instance_labels = np.unique(np.asarray(seg_image).reshape(-1, 3), axis=0)
    # NEAREST keeps every pixel an exact label colour; LANCZOS would blend colours
    # along instance boundaries and those pixels would fall out of every mask.
    seg_pixels = np.asarray(seg_image.resize((W, H), Image.Resampling.NEAREST)).reshape(-1, 3)
    mask_pack = [
        (seg_pixels == label).all(axis=-1).reshape(H, W)
        for label in instance_labels
        if not (label == 0).all()  # pure black is background
    ]
    fg_mask_np = np.stack(mask_pack).any(axis=0) if mask_pack else np.zeros((H, W), dtype=bool)
    if not fg_mask_np.any():
        raise gr.Error("No instance found in the segmentation result. Please draw boxes and run segmentation (step 1) again.")
    fg_mask = torch.from_numpy(fg_mask_np)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dtype = torch.float16 if device == 'cuda' else torch.float32
    moge_v2_dpt_model = moge_v2_dpt_model.to(device=device, dtype=dtype)

    if work_space is None:
        work_space = {
            "dir": os.path.join(TMP_DIR, f"work_space_{uuid.uuid4()}"),
        }
    os.makedirs(work_space["dir"], exist_ok=True)

    # HWC uint8 -> CHW float32 in [0, 1].
    input_image = torch.tensor(
        np.array(rgb_image).astype(np.float32) / 255, dtype=torch.float32, device=device
    ).permute(2, 0, 1)
    with torch.no_grad():
        output = moge_v2_dpt_model.infer(input_image)
    depth = output['depth']
    intrinsics = output['intrinsics']
    invalid_mask = torch.logical_or(torch.isnan(depth), torch.isinf(depth))
    depth_mask = ~invalid_mask
    depth = torch.where(invalid_mask, 0.0, depth)
    # Focal lengths are scaled by W / H (the model output is treated as
    # normalised); the principal point is fixed at the image centre.
    K = torch.from_numpy(
        np.array([
            [intrinsics[0, 0].item() * W, 0, 0.5 * W],
            [0, intrinsics[1, 1].item() * H, 0.5 * H],
            [0, 0, 1]
        ])
    ).to(dtype=torch.float32, device=device)
    work_space.update({
        "c2w": c2w,
        "K": K,
        "depth_mask": depth_mask,
        "depth": depth,
    })

    scene_est_depth_pts, scene_est_depth_pts_colors = project2ply(
        depth_mask.to(device), depth.to(device), input_image.to(device), K.to(device), c2w.to(device)
    )
    save_ply_path = os.path.join(work_space["dir"], "scene_pcd.glb")
    fg_depth_pts, _ = project2ply(
        fg_mask.to(device), depth.to(device), input_image.to(device), K.to(device), c2w.to(device)
    )
    # Normalisation is derived from the foreground points only.
    _, trans, scale = normalize_vertices(fg_depth_pts)
    if trans.shape[0] == 1:
        trans = trans[0]
    work_space.update(
        {
            "trans": trans,
            "scale": scale,
        }
    )
    # Keep only CPU tensors in the session state.
    for key, value in work_space.items():
        if isinstance(value, torch.Tensor):
            work_space[key] = value.to('cpu')

    trimesh.PointCloud(scene_est_depth_pts.reshape(-1, 3), scene_est_depth_pts_colors.reshape(-1, 3)) \
        .apply_translation(-trans) \
        .apply_scale(1. / (scale + 1e-6)) \
        .apply_transform(rot) \
        .export(save_ply_path)
    return save_ply_path, work_space
def save_image(img, save_path):
    """Write a channel-first float image tensor to ``save_path`` as 8-bit pixels.

    Args:
        img: Tensor laid out as (C, H, W) (it is permuted to H, W, C), with
            values expected in [0, 1].
        save_path: Destination file; imageio infers the format from the extension.
    """
    pixels = img.permute(1, 2, 0).detach().cpu().numpy()
    # Clamp before the uint8 cast: out-of-range floats would otherwise wrap
    # around (or be undefined) instead of saturating at 0 / 255.
    pixels = (np.clip(pixels, 0.0, 1.0) * 255.).astype(np.uint8)
    imageio.v3.imwrite(save_path, pixels)
def set_random_seed(seed):
    """Seed NumPy's global RNG, Python's ``random`` and PyTorch with ``seed``.

    When CUDA is available, every CUDA device is seeded as well.
    """
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def export_scene_glb(trimeshes, work_space, scene_name):
    """Combine ``trimeshes`` into a single ``trimesh.Scene`` and export it.

    Args:
        trimeshes: Geometries to place in the scene.
        work_space: Directory path (the work-space ``"dir"``) to write into.
        scene_name: File name; trimesh picks the format from its extension.

    Returns:
        Absolute path of the written file.
    """
    out_path = os.path.abspath(os.path.join(work_space, scene_name))
    scene = trimesh.Scene(trimeshes)
    scene.export(out_path)
    return out_path
def get_duration(rgb_image, seg_image, seed, randomize_seed,
                 num_inference_steps, guidance_scale, cfg_interval_start,
                 cfg_interval_end, t_rescale, work_space):
    """Return the GPU time budget in seconds for one ``run_generation`` call.

    ``@spaces.GPU(duration=get_duration)`` calls this with ``run_generation``'s
    arguments. The budget is a fixed 240 s plus 15 s per distinct colour in the
    segmentation map. The background colour is counted too, which leaves some
    slack.

    A missing segmentation returns just the base budget, so that
    ``run_generation`` can raise its own user-facing error instead of this
    function crashing first. Extra channels such as alpha are ignored.
    """
    step_duration = 15.0
    base_duration = 240
    if seg_image is None:
        return base_duration
    pixels = np.asarray(seg_image)
    if pixels.ndim == 3 and pixels.shape[-1] > 3:
        pixels = pixels[..., :3]
    instance_labels = np.unique(pixels.reshape(-1, 3), axis=0)
    return instance_labels.shape[0] * step_duration + base_duration
@spaces.GPU(duration=get_duration)
def run_generation(
    rgb_image: Any,
    seg_image: Union[str, Image.Image],
    seed: int,
    randomize_seed: bool = False,
    num_inference_steps: int = 50,
    guidance_scale: float = 5.0,
    cfg_interval_start: float = 0.5,
    cfg_interval_end: float = 1.0,
    t_rescale: float = 3.0,
    work_space: dict = None,
):
    """Step 3: generate one 3D asset per segmented instance and stream progress.

    This is a generator. Every ``yield`` is a 5-tuple for the Gradio outputs
    ``(model_output, stream_output, download_glb, object_selector,
    download_single_glb)``. The tuple holds the current scene GLB path, the
    progress HTML, and updates for the scene download button, the instance
    dropdown and the single-instance download button.

    For each instance mask:
    - The 3D-Fixer pipeline runs on the GPU and a turntable video is rendered.
    - The GLB export is handed to a thread pool so the next instance can start
      meanwhile.
    - Each finished export is appended to the scene, and the scene GLB is
      re-exported.

    Args:
        rgb_image: ImagePrompter value (dict with an ``"image"`` entry) or a PIL image.
        seg_image: Colour-coded segmentation map (pure black = background), as a
            PIL image or a file path.
        seed: Sampling seed; replaced by a random one when ``randomize_seed``.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]``.
        num_inference_steps, guidance_scale, cfg_interval_start, cfg_interval_end,
        t_rescale: Sampler settings used for both the sparse-structure and the
            SLAT sampler.
        work_space: Dict filled by ``run_depth_estimation`` (step 2).

    Raises:
        gr.Error: If the outputs of step 1 / step 2 are missing.
    """
    import traceback

    first_render = True
    if work_space is None:
        raise gr.Error("Please run step 1 and step 2 first.")
    required_keys = ["dir", "depth_mask", "depth", "K", "c2w", "trans", "scale"]
    missing = [k for k in required_keys if k not in work_space]
    if missing:
        raise gr.Error(f"Missing workspace fields: {missing}. Please run depth estimation (step 2) first.")
    if seg_image is None:
        raise gr.Error("Please run step 1 and step 2 first.")

    from threeDFixer.pipelines import ThreeDFixerPipeline
    from threeDFixer.datasets.utils import (
        edge_mask_morph_gradient,
        process_scene_image,
        process_instance_image,
    )
    from threeDFixer.utils import render_utils

    def export_single_glb_from_outputs(
        outputs,
        fine_scale,
        fine_trans,
        coarse_scale,
        coarse_trans,
        trans,
        scale,
        rot,
        work_space,
        instance_name,
        run_id
    ):
        """Bake one instance into a GLB and export it; this runs in a worker thread.

        Vertices are transformed by the fine (scale, translation) and then the
        coarse (scale, translation) predicted by the pipeline. The exported mesh
        then gets ``-trans``, ``1 / scale`` and ``rot``, the same normalisation
        applied to the step-2 scene point cloud.

        Returns:
            ``(absolute GLB path, trimesh object)``.
        """
        from threeDFixer.datasets.utils import (
            transform_vertices,
        )
        from threeDFixer.utils import postprocessing_utils
        # to_glb is explicitly run with autograd enabled (presumably it
        # optimises the baked texture - TODO confirm).
        with torch.enable_grad():
            glb = postprocessing_utils.to_glb(
                outputs["gaussian"][0],
                outputs["mesh"][0],
                simplify=0.95,
                texture_size=1024,
                transform_fn=lambda x: transform_vertices(
                    x,
                    ops=["scale", "translation", "scale", "translation"],
                    params=[fine_scale, fine_trans[None], coarse_scale, coarse_trans[None]],
                ),
                verbose=False
            )
        instance_glb_path = os.path.abspath(
            os.path.join(work_space, f"{run_id}_{instance_name}.glb")
        )
        glb.apply_translation(-trans) \
            .apply_scale(1.0 / (scale + 1e-6)) \
            .apply_transform(rot) \
            .export(instance_glb_path)
        return instance_glb_path, glb

    run_id = str(uuid.uuid4())
    device = 'cuda'
    gr.Info('Loading ckpts')
    load_start = time.time()
    pipeline = ThreeDFixerPipeline.from_pretrained(
        local_dir, compile=False
    )
    pipeline.to(device=device)
    # `.2f` (not `.2`): `.2` means two significant digits and prints e.g. "1.2e+02".
    gr.Info(f'Loading ckpts duration: {time.time() - load_start:.2f}s')

    if not isinstance(rgb_image, Image.Image) and "image" in rgb_image:
        rgb_image = rgb_image["image"]
    # Same RGB conversion as step 2.
    rgb_image = rgb_image.convert("RGB")
    if isinstance(seg_image, str):
        seg_image = Image.open(seg_image)
    seg_image = seg_image.convert("RGB")
    instance_labels = np.unique(np.asarray(seg_image).reshape(-1, 3), axis=0)
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    set_random_seed(seed)

    H, W = work_space['depth_mask'].shape
    rgb_image = rgb_image.resize((W, H), Image.Resampling.LANCZOS)
    # NEAREST keeps every pixel an exact label colour; LANCZOS would blend colours
    # along instance boundaries and those pixels would fall out of every mask.
    seg_pixels = np.asarray(seg_image.resize((W, H), Image.Resampling.NEAREST)).reshape(-1, 3)
    depth_mask = work_space['depth_mask'].detach().cpu().numpy() > 0
    mask_pack = [
        (seg_pixels == label).all(axis=-1).reshape(H, W)
        for label in instance_labels
        if not (label == 0).all()  # pure black is background
    ]
    erode_kernel_size = 7
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erode_kernel_size, erode_kernel_size))

    results = []
    trimeshes = []
    trans = work_space['trans']
    scale = work_space['scale']
    current_scene_path = None
    pending_exports = []

    def build_stream_html(status_text: str):
        """Render the progress panel: overall status plus one card per generated object."""
        cards_html = "".join(
            f"""
            <div style="
                width: 220px;
                border: 1px solid #ddd;
                border-radius: 10px;
                padding: 8px;
                background: white;
                box-sizing: border-box;
            ">
                <div style="font-weight: 600; margin-bottom: 6px;">
                    {item['name']}
                </div>
                <video
                    autoplay
                    muted
                    loop
                    playsinline
                    preload="metadata"
                    poster="/file={item['poster_path']}?v={run_id}"
                    style="
                        width: 100%;
                        border-radius: 8px;
                        display: block;
                        background: #f5f5f5;
                    "
                >
                    <source src="/file={item['mp4_path']}?v={run_id}" type="video/mp4">
                </video>
                <div style="
                    margin-top: 6px;
                    font-size: 12px;
                    color: #666;
                ">
                    Status: {item.get('status_text', 'Unknown')}
                </div>
                <div style="
                    margin-top: 4px;
                    font-size: 13px;
                    color: #444;
                    word-break: break-all;
                ">
                    {os.path.basename(item['glb_path']) if item['glb_path'] is not None else 'GLB not ready yet'}
                </div>
            </div>
            """
            for item in results
        )
        return f"""
        <div style="padding: 8px 0;">
            <div style="font-weight: 700; margin-bottom: 8px;">Status: {status_text}</div>
            <div style="font-weight: 700; margin-bottom: 12px;">Generated objects: {len(results)}</div>
            <div style="display: flex; flex-wrap: wrap; gap: 12px; align-items: flex-start;">
                {cards_html}
            </div>
        </div>
        """

    def build_selector_and_download_updates(default_latest: bool = True):
        """Return ``(dropdown update, single-download update)`` for instances with a ready GLB."""
        ready = [item for item in results if item["glb_path"] is not None]
        if not ready:
            return (
                gr.update(choices=[], value=None),
                gr.update(value=None, interactive=False),
            )
        selected = ready[-1] if default_latest else ready[0]
        # Choices are (label, key) pairs. The dropdown shows "Object N", but its
        # value is the run-unique key under which the GLB path is registered in
        # the module-level `generated_object_map`. That lets
        # `update_single_download` resolve it, and concurrent sessions cannot
        # collide on the same label.
        return (
            gr.update(choices=[(item["name"], item["map_key"]) for item in ready],
                      value=selected["map_key"]),
            gr.update(value=selected["glb_path"], interactive=True),
        )

    def flush_finished_exports(status_text: str):
        """Collect finished GLB exports and rebuild the scene GLB.

        Returns a UI update tuple, or ``None`` when no export has finished.
        """
        nonlocal current_scene_path
        finished = [item for item in pending_exports if item["future"].done()]
        if not finished:
            return None
        for item in finished:
            pending_exports.remove(item)
            entry = results[item["result_index"]]
            try:
                instance_glb_path, glb = item["future"].result()
            except Exception as e:
                print(f"[export_glb][error] instance={item['instance_name']}: {e}")
                entry["status_text"] = "GLB export failed"
                continue
            entry["glb_path"] = instance_glb_path
            entry["status_text"] = "GLB ready"
            # Module-level map read by update_single_download (no local shadowing).
            generated_object_map[entry["map_key"]] = instance_glb_path
            trimeshes.append(glb)
            current_scene_path = export_scene_glb(
                trimeshes=trimeshes,
                work_space=work_space['dir'],
                scene_name=f"{run_id}_scene_step_{len(trimeshes)}.glb",
            )
        selector_update, single_download_update = build_selector_and_download_updates(default_latest=True)
        return (
            current_scene_path,
            build_stream_html(status_text),
            gr.update(value=current_scene_path, interactive=(current_scene_path is not None)),
            selector_update,
            single_download_update,
        )

    def make_sampler_params():
        """Return a fresh sampler-settings dict (a new object for every sampler call)."""
        return {
            "steps": num_inference_steps,
            "cfg_strength": guidance_scale,
            "cfg_interval": [cfg_interval_start, cfg_interval_end],
            "rescale_t": t_rescale,
        }

    yield (
        None,
        build_stream_html("Generating..."),
        gr.update(value=None, interactive=False),
        gr.update(choices=[], value=None),
        gr.update(value=None, interactive=False),
    )

    # Loop-invariant CPU copies of the step-2 geometry.
    est_depth = work_space['depth'].to('cpu')
    c2w = work_space['c2w'].to('cpu')
    K = work_space['K'].to('cpu')

    with ThreadPoolExecutor(max_workers=5) as executor:
        for instance_name, object_mask in enumerate(mask_pack):
            try:
                flushed = flush_finished_exports("Generating...")
                if flushed is not None:
                    yield flushed

                # Fresh device tensors per instance; the c2w copy is modified in
                # place below and must not alias the work-space tensor.
                intrinsics = work_space['K'].float().to(device)
                extrinsics = work_space['c2w'].clone().float().to(device)
                # Negate the camera y/z axes (presumably an OpenCV <-> OpenGL
                # convention switch - TODO confirm).
                extrinsics[:3, 1:3] *= -1

                object_mask = object_mask > 0
                instance_mask = np.logical_and(object_mask, depth_mask).astype(np.uint8)
                valid_ratio = np.sum((instance_mask > 0).astype(np.float32)) / (H * W)
                print(f'valid ratio of {instance_name}: {valid_ratio:.4f}')
                if valid_ratio < VALID_RATIO_THRESHOLD:
                    continue

                edge_mask = edge_mask_morph_gradient(instance_mask, kernel, 3)
                fg_mask = (instance_mask > edge_mask).astype(np.uint8)
                # Interior pixels weigh 1.0, boundary pixels 0.5.
                color_mask = fg_mask.astype(np.float32) + edge_mask.astype(np.float32) * 0.5
                scene_image, scene_image_masked = process_scene_image(rgb_image, instance_mask, CROP_SIZE)
                instance_image, crop_mask, instance_rays_o, instance_rays_d, instance_rays_c, \
                    instance_rays_t = process_instance_image(
                        rgb_image, instance_mask, color_mask, est_depth, K, c2w, CROP_SIZE
                    )
                save_image(scene_image, os.path.join(work_space['dir'], f'input_scene_image_{instance_name}.png'))
                save_image(scene_image_masked, os.path.join(work_space['dir'], f'input_scene_image_masked_{instance_name}.png'))
                save_image(instance_image, os.path.join(work_space['dir'], f'input_instance_image_{instance_name}.png'))
                save_image(
                    torch.cat([instance_image, crop_mask]),
                    os.path.join(work_space['dir'], f'input_instance_image_masked_{instance_name}.png')
                )

                # Back-project the instance pixels: origin + direction * depth.
                pcd_points = (
                    instance_rays_o.to(device) +
                    instance_rays_d.to(device) * instance_rays_t[..., None].to(device)
                ).detach().cpu().numpy()
                pcd_colors = instance_rays_c
                save_projected_colored_pcd(
                    pcd_points,
                    repeat(pcd_colors, 'n -> n c', c=3),
                    f"{work_space['dir']}/instance_est_depth_{instance_name}.ply"
                )

                with torch.no_grad():
                    outputs, coarse_trans, coarse_scale, fine_trans, fine_scale = pipeline.run(
                        torch.cat([instance_image, crop_mask]).to(device),
                        scene_image_masked=scene_image_masked.to(device),
                        seed=seed,
                        extrinsics=extrinsics.to(device),
                        intrinsics=intrinsics.to(device),
                        points=pcd_points,
                        points_mask=pcd_colors,
                        sparse_structure_sampler_params=make_sampler_params(),
                        slat_sampler_params=make_sampler_params(),
                    )

                mp4_path = os.path.abspath(
                    os.path.join(work_space['dir'], f"{run_id}_instance_gs_fine_{instance_name}.mp4")
                )
                poster_path = os.path.abspath(
                    os.path.join(work_space['dir'], f"{run_id}_instance_gs_fine_{instance_name}.png")
                )
                if first_render:
                    render_t = time.time()
                video = render_utils.render_video(
                    outputs["gaussian"][0],
                    bg_color=(1.0, 1.0, 1.0)
                )["color"]
                if first_render:
                    # The first render includes the gsplat compilation.
                    gr.Info(f'Compile gsplat duration: {time.time() - render_t:.2f}s')
                    first_render = False
                imageio.mimsave(mp4_path, video, fps=30)
                imageio.imwrite(poster_path, video[0])

                object_label = f"Object {len(results) + 1}"
                result_index = len(results)
                results.append({
                    "name": object_label,
                    "map_key": f"{run_id}/{object_label}",
                    "mp4_path": mp4_path,
                    "poster_path": poster_path,
                    "glb_path": None,
                    "instance_index": instance_name,
                    "status_text": "Exporting GLB...",
                })

                # First update for this instance: show the video right away and
                # keep the 3D scene and the already-ready downloads unchanged.
                selector_update, single_download_update = build_selector_and_download_updates(default_latest=True)
                yield (
                    current_scene_path,
                    build_stream_html("Generating..."),
                    gr.update(value=current_scene_path, interactive=(current_scene_path is not None)),
                    selector_update,
                    single_download_update,
                )

                future = executor.submit(
                    export_single_glb_from_outputs,
                    outputs=outputs,
                    fine_scale=fine_scale,
                    fine_trans=fine_trans,
                    coarse_scale=coarse_scale,
                    coarse_trans=coarse_trans,
                    trans=trans,
                    scale=scale,
                    rot=rot,
                    work_space=work_space['dir'],
                    instance_name=instance_name,
                    run_id=run_id,
                )
                pending_exports.append({
                    "future": future,
                    "result_index": result_index,
                    "instance_name": instance_name,
                    "object_label": object_label,
                })
                flushed = flush_finished_exports("Generating...")
                if flushed is not None:
                    yield flushed
            except Exception:
                # Best effort: a failure on one instance must not abort the others.
                print(f"[run_generation][error] instance={instance_name}:")
                traceback.print_exc()

        # Stream the remaining exports as they finish (still inside the executor).
        while pending_exports:
            flushed = flush_finished_exports("Generating...")
            if flushed is not None:
                yield flushed
            else:
                time.sleep(0.2)

    ready_items = [item for item in results if item["glb_path"] is not None]
    if ready_items:
        final_scene_path = export_scene_glb(
            trimeshes=trimeshes,
            work_space=work_space['dir'],
            scene_name=f"{run_id}_scene_final.glb",
        )
        selector_update, single_download_update = build_selector_and_download_updates(default_latest=True)
        yield (
            final_scene_path,
            build_stream_html("Finished"),
            gr.update(value=final_scene_path, interactive=True),
            selector_update,
            single_download_update,
        )
    else:
        yield (
            None,
            "<div style='padding: 8px 0;'><b>Status:</b> No valid object generated.</div>",
            gr.update(value=None, interactive=False),
            gr.update(choices=[], value=None),
            gr.update(value=None, interactive=False),
        )
def update_single_download(selected_name):
    """Return the single-GLB download-button update for a dropdown selection.

    ``selected_name`` is looked up in the module-level ``generated_object_map``.
    A missing or unknown selection gives an empty, disabled button.
    """
    known = selected_name is not None and selected_name in generated_object_map
    if not known:
        return gr.update(value=None, interactive=False)
    return gr.update(value=generated_object_map[selected_name], interactive=True)
# Demo UI: step 1 segmentation -> step 2 depth estimation -> step 3 generation.
with gr.Blocks() as demo:
    # Per-session work-space dict passed between the three steps.
    gr_work_space = gr.State(value=None)
    gr.Markdown(MARKDOWN)
    with gr.Column():
        with gr.Row():
            image_prompts = ImagePrompter(label="Input Image", type="pil")
            seg_image = gr.Image(
                label="Segmentation Result", type="pil", format="png"
            )
        with gr.Column():
            with gr.Accordion("Segmentation Settings", open=True):
                polygon_refinement = gr.Checkbox(label="Polygon Refinement", value=False)
            seg_button = gr.Button("Run Segmentation (step 1)")
            dpt_button = gr.Button("Run Depth estimation (step 2)", variant="primary")
        with gr.Row():
            dpt_model_output = gr.Model3D(label="Estimated depth map", interactive=False)
            model_output = gr.Model3D(label="Generated GLB", interactive=False)
        with gr.Column():
            with gr.Accordion("Generation Settings", open=True):
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=25,
                )
                with gr.Row():
                    cfg_interval_start = gr.Slider(
                        label="CFG interval start",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=0.8,
                    )
                    cfg_interval_end = gr.Slider(
                        label="CFG interval end",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=1.0,
                    )
                t_rescale = gr.Slider(
                    label="t rescale factor",
                    minimum=1.0,
                    maximum=5.0,
                    step=0.1,
                    value=5.0,
                )
                guidance_scale = gr.Slider(
                    label="CFG scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=5.0,
                )
            # Enabled only after a successful depth estimation (see below).
            gen_button = gr.Button("Run Generation (step 3)", variant="primary", interactive=False)
            download_glb = gr.DownloadButton(label="Download scene GLB", interactive=False)
            with gr.Row():
                object_selector = gr.Dropdown(label="Choose instance: ")
                download_single_glb = gr.DownloadButton(label="Download single GLB", interactive=False)
            stream_output = gr.HTML(label="Generated Objects Stream")
    with gr.Row():
        gr.Examples(
            examples=EXAMPLES,
            inputs=[image_prompts, seg_image, seed, randomize_seed, num_inference_steps, guidance_scale, cfg_interval_start, cfg_interval_end, t_rescale],
            outputs=[model_output, download_glb, seed],
            cache_examples=False,
        )

    # `.success` (not `.then`) so a step that raised gr.Error does not unlock the
    # next button.
    seg_button.click(
        run_segmentation,
        inputs=[
            image_prompts,
            polygon_refinement,
        ],
        outputs=[seg_image, gr_work_space],
    ).success(lambda: gr.Button(interactive=True), outputs=[dpt_button])
    dpt_button.click(
        run_depth_estimation,
        inputs=[
            image_prompts,
            seg_image,
            gr_work_space
        ],
        outputs=[dpt_model_output, gr_work_space],
    ).success(lambda: gr.Button(interactive=True), outputs=[gen_button])
    gen_button.click(
        run_generation,
        inputs=[
            image_prompts,
            seg_image,
            seed,
            randomize_seed,
            num_inference_steps,
            guidance_scale,
            cfg_interval_start,
            cfg_interval_end,
            t_rescale,
            gr_work_space
        ],
        outputs=[model_output,
                 stream_output,
                 download_glb,
                 object_selector,
                 download_single_glb],
    )
    object_selector.change(
        update_single_download,
        inputs=[object_selector],
        outputs=[download_single_glb],
    )
if __name__ == '__main__':
    # Models and fixed transforms are created only when the file runs as a
    # script. The callbacks read them as module-level globals; a `global`
    # declaration has no effect at module scope, so none is used here.

    # Grounding SAM: box-prompted segmenter, kept on the CPU in float32.
    segmenter_id = "facebook/sam-vit-base"
    sam_processor = AutoProcessor.from_pretrained(segmenter_id)
    sam_segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id)
    sam_segmentator = sam_segmentator.to("cpu", dtype=torch.float32)

    # MoGe-2 geometry model; loaded on the CPU, moved to the GPU inside step 2.
    mogev2_id = 'Ruicheng/moge-2-vitl'
    moge_v2_dpt_model = MoGeModel.from_pretrained(mogev2_id)
    moge_v2_dpt_model = moge_v2_dpt_model.to("cpu")

    # Homogeneous transform applied to every exported point cloud / GLB:
    # x -> -x, and the y and z axes are swapped.
    rot = np.array(
        [[-1.0, 0.0, 0.0, 0.0],
         [0.0, 0.0, 1.0, 0.0],
         [0.0, 1.0, 0.0, 0.0],
         [0.0, 0.0, 0.0, 1.0]],
        dtype=np.float32,
    )
    # Fixed camera-to-world pose used by step 2:
    # world x = cam x, world y = -cam z, world z = cam y.
    c2w = torch.tensor(
        [[1.0, 0.0, 0.0, 0.0],
         [0.0, 0.0, -1.0, 0.0],
         [0.0, 1.0, 0.0, 0.0],
         [0.0, 0.0, 0.0, 1.0]],
        dtype=torch.float32,
        device='cpu',
    )

    # Whitelist the work-space and example folders for /file= URLs.
    demo.launch(allowed_paths=[TMP_DIR, EXAMPLE_DIR])