| | import gradio as gr |
| | import spaces |
| | import os |
| | import numpy as np |
| | import trimesh |
| | import time |
| | import traceback |
| | import torch |
| | from PIL import Image |
| | import cv2 |
| | import shutil |
| | from segment_anything import SamAutomaticMaskGenerator, build_sam |
| | from omegaconf import OmegaConf |
| |
|
| | from modules.bbox_gen.models.autogressive_bbox_gen import BboxGen |
| | from modules.part_synthesis.process_utils import save_parts_outputs |
| | from modules.inference_utils import load_img_mask, prepare_bbox_gen_input, prepare_part_synthesis_input, gen_mesh_from_bounds, vis_voxel_coords, merge_parts |
| | from modules.part_synthesis.pipelines import OmniPartImageTo3DPipeline |
| | from modules.label_2d_mask.visualizer import Visualizer |
| | from transformers import AutoModelForImageSegmentation |
| |
|
| | from modules.label_2d_mask.label_parts import ( |
| | prepare_image, |
| | get_sam_mask, |
| | get_mask, |
| | clean_segment_edges, |
| | resize_and_pad_to_square, |
| | size_th as DEFAULT_SIZE_TH |
| | ) |
| |
|
| | |
| | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
| | DTYPE = torch.float16 |
| | MAX_SEED = np.iinfo(np.int32).max |
| | TMP_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp") |
| | os.makedirs(TMP_ROOT, exist_ok=True) |
| |
|
| | sam_mask_generator = None |
| | rmbg_model = None |
| | bbox_gen_model = None |
| | part_synthesis_pipeline = None |
| |
|
| | size_th = DEFAULT_SIZE_TH |
| |
|
| |
|
def prepare_models(sam_ckpt_path, partfield_ckpt_path, bbox_gen_ckpt_path):
    """Load every heavy model into its module-level global, skipping any that
    is already initialized (safe to call repeatedly).

    Args:
        sam_ckpt_path: checkpoint path for the SAM segmentation model.
        partfield_ckpt_path: checkpoint path for the PartField encoder used
            by the bbox generator config.
        bbox_gen_ckpt_path: checkpoint path for the BboxGen state dict.
    """
    global sam_mask_generator, rmbg_model, bbox_gen_model, part_synthesis_pipeline

    if sam_mask_generator is None:
        print("Loading SAM model...")
        sam_backbone = build_sam(checkpoint=sam_ckpt_path).to(device=DEVICE)
        sam_mask_generator = SamAutomaticMaskGenerator(sam_backbone)

    if rmbg_model is None:
        print("Loading BriaRMBG 2.0 model...")
        rmbg_model = AutoModelForImageSegmentation.from_pretrained(
            'briaai/RMBG-2.0', trust_remote_code=True)
        rmbg_model.to(DEVICE)
        rmbg_model.eval()

    if part_synthesis_pipeline is None:
        print("Loading PartSynthesis model...")
        part_synthesis_pipeline = OmniPartImageTo3DPipeline.from_pretrained('omnipart/OmniPart')
        part_synthesis_pipeline.to(DEVICE)

    if bbox_gen_model is None:
        print("Loading BboxGen model...")
        # Inject the PartField encoder path into the YAML-driven model config.
        gen_cfg = OmegaConf.load("configs/bbox_gen.yaml").model.args
        gen_cfg.partfield_encoder_path = partfield_ckpt_path
        bbox_gen_model = BboxGen(gen_cfg)
        # strict=False: checkpoint may omit some submodules (e.g. the encoder).
        bbox_gen_model.load_state_dict(torch.load(bbox_gen_ckpt_path), strict=False)
        bbox_gen_model.to(DEVICE)
        bbox_gen_model.eval().half()

    print("Models ready")
| |
|
| |
|
@spaces.GPU
def process_image(image_path, threshold, req: gr.Request):
    """Process an input image and generate the initial part segmentation.

    Args:
        image_path: path to the uploaded image file.
        threshold: minimum mask area in pixels; smaller SAM masks are dropped
            (also stored into the module-level ``size_th``).
        req: Gradio request; its session hash names the per-user temp dir.

    Returns:
        Tuple of (initial segmentation image path, pre-merge mask image path,
        JSON-serializable state dict consumed by apply_merge/generate_parts).
    """
    global size_th

    user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)

    img_name = os.path.basename(image_path).split(".")[0]

    size_th = threshold

    # Remove background, then composite onto white so SAM sees an RGB image.
    img = Image.open(image_path).convert("RGB")
    processed_image = prepare_image(img, rmbg_net=rmbg_model.to(DEVICE))

    processed_image = resize_and_pad_to_square(processed_image)
    white_bg = Image.new("RGBA", processed_image.size, (255, 255, 255, 255))
    white_bg_img = Image.alpha_composite(white_bg, processed_image.convert("RGBA"))
    image = np.array(white_bg_img.convert('RGB'))

    rgba_path = os.path.join(user_dir, f"{img_name}_processed.png")
    processed_image.save(rgba_path)

    print("Generating raw SAM masks without post-processing...")
    raw_masks = sam_mask_generator.generate(image)

    # BUGFIX: the original assigned raw_sam_vis = np.copy(image) and then
    # immediately overwrote it with a white canvas; the dead copy is removed.
    # NOTE(review): this raw-mask visualization is built but never saved or
    # returned — kept to preserve prior behavior; confirm whether it should
    # be persisted or dropped entirely.
    raw_sam_vis = np.ones_like(image) * 255

    sorted_masks = sorted(raw_masks, key=lambda x: x["area"], reverse=True)

    for i, mask_data in enumerate(sorted_masks):
        if mask_data["area"] < size_th:
            continue

        # Deterministic pseudo-random color derived from the mask index.
        color_r = (i * 50 + 80) % 256
        color_g = (i * 120 + 40) % 256
        color_b = (i * 180 + 20) % 256
        color = np.array([color_r, color_g, color_b])

        mask = mask_data["segmentation"]
        raw_sam_vis[mask] = color

    visual = Visualizer(image)

    group_ids, pre_merge_im = get_sam_mask(
        image,
        sam_mask_generator,
        visual,
        merge_groups=None,
        rgba_image=processed_image,
        img_name=img_name,
        save_dir=user_dir,
        size_threshold=size_th,
    )

    pre_merge_path = os.path.join(user_dir, f"{img_name}_mask_pre_merge.png")
    Image.fromarray(pre_merge_im).save(pre_merge_path)

    # Visualize labeled segments before disconnected-part splitting.
    pre_split_vis = np.ones_like(image) * 255

    unique_ids = np.unique(group_ids)
    unique_ids = unique_ids[unique_ids >= 0]  # drop background/unassigned ids

    for i, unique_id in enumerate(unique_ids):
        color_r = (i * 50 + 80) % 256
        color_g = (i * 120 + 40) % 256
        color_b = (i * 180 + 20) % 256
        color = np.array([color_r, color_g, color_b])

        mask = (group_ids == unique_id)
        pre_split_vis[mask] = color

        # Draw the numeric segment id at the mask centroid.
        y_indices, x_indices = np.where(mask)
        if len(y_indices) > 0:
            center_y = int(np.mean(y_indices))
            center_x = int(np.mean(x_indices))
            cv2.putText(pre_split_vis, str(unique_id),
                        (center_x, center_y), cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (0, 0, 0), 1, cv2.LINE_AA)

    pre_split_path = os.path.join(user_dir, f"{img_name}_pre_split.png")
    Image.fromarray(pre_split_vis).save(pre_split_path)
    print(f"Pre-split segmentation (before disconnected parts handling) saved to {pre_split_path}")

    get_mask(group_ids, image, ids=2, img_name=img_name, save_dir=user_dir)

    init_seg_path = os.path.join(user_dir, f"{img_name}_mask_segments_2.png")

    # Flatten any alpha channel in the saved segmentation onto white.
    seg_img = Image.open(init_seg_path)
    if seg_img.mode == 'RGBA':
        white_bg = Image.new('RGBA', seg_img.size, (255, 255, 255, 255))
        seg_img = Image.alpha_composite(white_bg, seg_img)
    seg_img.save(init_seg_path)

    # State must be JSON-serializable for Gradio, hence the .tolist() calls.
    state = {
        "image": image.tolist(),
        "processed_image": rgba_path,
        "group_ids": group_ids.tolist() if isinstance(group_ids, np.ndarray) else group_ids,
        "original_group_ids": group_ids.tolist() if isinstance(group_ids, np.ndarray) else group_ids,
        "img_name": img_name,
        "pre_split_path": pre_split_path,
    }

    return init_seg_path, pre_merge_path, state
| |
|
| |
|
def apply_merge(merge_input, state, req: gr.Request):
    """Apply user-specified merge groups and regenerate the segmentation.

    Args:
        merge_input: string like ``"1,2;3,4"`` — semicolon-separated groups
            of comma-separated segment ids to merge together.
        state: state dict produced by process_image().
        req: Gradio request; its session hash names the per-user temp dir.

    Returns:
        Tuple of (merged segmentation image path, updated state);
        (None, state) when the state is empty or parsing fails.
    """
    if not state:
        # BUGFIX: the early paths previously returned three values
        # (None, None, state) while the success path returns two, so any
        # caller unpacking two values crashed on the early path.
        return None, state

    user_dir = os.path.join(TMP_ROOT, str(req.session_hash))

    image = np.array(state["image"])
    # Merge relative to the ORIGINAL ids so repeated merges compose cleanly.
    group_ids = np.array(state["original_group_ids"])
    img_name = state["img_name"]

    processed_image = Image.open(state["processed_image"])

    unique_ids = np.unique(group_ids)
    unique_ids = unique_ids[unique_ids >= 0]
    print(f"Original segment IDs (used for merging): {sorted(unique_ids.tolist())}")

    merge_groups = None
    try:
        if merge_input:
            merge_groups = []
            for group_set in merge_input.split(';'):
                ids = [int(x) for x in group_set.split(',')]
                if not ids:
                    continue

                existing_ids = [id for id in ids if id in unique_ids]
                missing_ids = [id for id in ids if id not in unique_ids]

                if missing_ids:
                    print(f"Warning: These IDs don't exist in the segmentation: {missing_ids}")

                if existing_ids:
                    # NOTE(review): the full `ids` list (including missing
                    # ids) is appended, matching prior behavior — confirm
                    # whether `existing_ids` was intended here.
                    merge_groups.append(ids)
                    print(f"Valid merge group: {ids} (missing: {missing_ids if missing_ids else 'none'})")
                else:
                    print(f"Skipping merge group with no valid IDs: {ids}")

            print(f"Using merge groups: {merge_groups}")
    except Exception as e:
        print(f"Error parsing merge groups: {e}")
        return None, state  # BUGFIX: match the success path's two-value return

    visual = Visualizer(image)

    new_group_ids, merged_im = get_sam_mask(
        image,
        sam_mask_generator,
        visual,
        merge_groups=merge_groups,
        existing_group_ids=group_ids,
        rgba_image=processed_image,
        skip_split=True,
        img_name=img_name,
        save_dir=user_dir,
        size_threshold=size_th,
    )

    new_unique_ids = np.unique(new_group_ids)
    new_unique_ids = new_unique_ids[new_unique_ids >= 0]
    print(f"New segment IDs (after merging): {new_unique_ids.tolist()}")

    new_group_ids = clean_segment_edges(new_group_ids)

    get_mask(new_group_ids, image, ids=3, img_name=img_name, save_dir=user_dir)

    merged_seg_path = os.path.join(user_dir, f"{img_name}_mask_segments_3.png")

    # Save an EXR id map: ids shifted +1 (background -1 -> 0), replicated to
    # three channels. The 518x518 size is the fixed working resolution.
    save_mask = new_group_ids + 1
    save_mask = save_mask.reshape(518, 518, 1).repeat(3, axis=-1)
    cv2.imwrite(os.path.join(user_dir, f"{img_name}_mask.exr"), save_mask.astype(np.float32))

    state["group_ids"] = new_group_ids.tolist() if isinstance(new_group_ids, np.ndarray) else new_group_ids
    state["save_mask_path"] = os.path.join(user_dir, f"{img_name}_mask.exr")

    return merged_seg_path, state
| |
|
| |
|
def explode_mesh(mesh, explosion_scale=0.4):
    """Return a scene where each part is translated away from the global center.

    Args:
        mesh: a trimesh.Scene with multiple parts (a single trimesh.Trimesh
            is wrapped and returned unchanged — nothing to explode).
        explosion_scale: distance each part is moved along its outward
            direction from the mean of all part centers.

    Returns:
        A new trimesh.Scene with offset transforms, or the input scene
        unchanged when no explosion is possible.
    """
    if isinstance(mesh, trimesh.Scene):
        scene = mesh
    elif isinstance(mesh, trimesh.Trimesh):
        print("Warning: Single mesh provided, can't create exploded view")
        scene = trimesh.Scene(mesh)
        return scene
    else:
        print(f"Warning: Unexpected mesh type: {type(mesh)}")
        scene = mesh

    if len(scene.geometry) <= 1:
        print("Only one geometry found - nothing to explode")
        return scene

    print(f"[EXPLODE_MESH] Starting mesh explosion with scale {explosion_scale}")
    print(f"[EXPLODE_MESH] Processing {len(scene.geometry)} parts")

    exploded_scene = trimesh.Scene()

    # Compute each part's center in world coordinates, keyed by geometry name.
    # BUGFIX: the original indexed a centers list (built only from
    # vertex-bearing parts) with an enumerate index over ALL parts, so any
    # vertex-less geometry misaligned every later part with the wrong center.
    centers_by_name = {}
    for geometry_name, geometry in scene.geometry.items():
        if hasattr(geometry, 'vertices'):
            transform = scene.graph[geometry_name][0]
            vertices_global = trimesh.transformations.transform_points(
                geometry.vertices, transform)
            center = np.mean(vertices_global, axis=0)
            centers_by_name[geometry_name] = center
            print(f"[EXPLODE_MESH] Part {geometry_name}: center = {center}")

    if not centers_by_name:
        print("No valid geometries with vertices found")
        return scene

    global_center = np.mean(np.array(list(centers_by_name.values())), axis=0)
    print(f"[EXPLODE_MESH] Global center: {global_center}")

    for geometry_name, geometry in scene.geometry.items():
        if not hasattr(geometry, 'vertices'):
            continue  # vertex-less geometries are dropped, as before

        direction = centers_by_name[geometry_name] - global_center
        direction_norm = np.linalg.norm(direction)
        if direction_norm > 1e-6:
            direction = direction / direction_norm
        else:
            # Part sits exactly at the global center: pick a random direction.
            direction = np.random.randn(3)
            direction = direction / np.linalg.norm(direction)

        offset = direction * explosion_scale

        # Shift only the translation component of the part's transform.
        new_transform = scene.graph[geometry_name][0].copy()
        new_transform[:3, 3] = new_transform[:3, 3] + offset

        exploded_scene.add_geometry(
            geometry,
            transform=new_transform,
            geom_name=geometry_name,
        )

        print(f"[EXPLODE_MESH] Part {geometry_name}: moved by {np.linalg.norm(offset):.4f}")

    print("[EXPLODE_MESH] Mesh explosion complete")
    return exploded_scene
| | |
@spaces.GPU(duration=90)
def generate_parts(state, seed, cfg_strength, req: gr.Request):
    """Run the full part-generation pipeline: voxel coords -> part bounding
    boxes -> per-part synthesis -> merged and exploded output meshes.

    Args:
        state: dict from earlier steps; must contain "processed_image"
            (RGBA image path) and "save_mask_path" (EXR id map written by
            apply_merge).
        seed: sampling seed for both pipeline stages.
        cfg_strength: classifier-free-guidance strength for the SLAT stage
            (the coords stage uses a fixed 7.5).
        req: Gradio request; all outputs are written to the per-session
            temp dir.

    Returns:
        Five file paths: bbox mesh .glb, whole mesh .glb, exploded mesh .glb,
        merged gaussian .ply, exploded gaussian .ply.
    """
    explode_factor=0.3
    img_path = state["processed_image"]
    mask_path = state["save_mask_path"]
    user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
    img_white_bg, img_black_bg, ordered_mask_input, img_mask_vis = load_img_mask(img_path, mask_path)
    img_mask_vis.save(os.path.join(user_dir, "img_mask_vis.png"))

    # Stage 1: sample the sparse voxel structure from the black-bg image.
    voxel_coords = part_synthesis_pipeline.get_coords(img_black_bg, num_samples=1, seed=seed, sparse_structure_sampler_params={"steps": 25, "cfg_strength": 7.5})
    voxel_coords = voxel_coords.cpu().numpy()
    # Saved to disk because later stages reload it from the .npy path.
    np.save(os.path.join(user_dir, "voxel_coords.npy"), voxel_coords)
    voxel_coords_ply = vis_voxel_coords(voxel_coords)
    voxel_coords_ply.export(os.path.join(user_dir, "voxel_coords_vis.ply"))
    print("[INFO] Voxel coordinates saved")

    # Stage 2: predict per-part bounding boxes from voxels + 2D part mask.
    bbox_gen_input = prepare_bbox_gen_input(os.path.join(user_dir, "voxel_coords.npy"), img_white_bg, ordered_mask_input)
    bbox_gen_output = bbox_gen_model.generate(bbox_gen_input)
    np.save(os.path.join(user_dir, "bboxes.npy"), bbox_gen_output['bboxes'][0])
    bboxes_vis = gen_mesh_from_bounds(bbox_gen_output['bboxes'][0])
    bboxes_vis.export(os.path.join(user_dir, "bboxes_vis.glb"))
    print("[INFO] BboxGen output saved")

    part_synthesis_input = prepare_part_synthesis_input(os.path.join(user_dir, "voxel_coords.npy"), os.path.join(user_dir, "bboxes.npy"), ordered_mask_input)

    # Free GPU memory before the heavy SLAT sampling stage.
    torch.cuda.empty_cache()

    # Stage 3: synthesize per-part geometry (mesh + gaussian formats).
    part_synthesis_output = part_synthesis_pipeline.get_slat(
        img_black_bg,
        part_synthesis_input['coords'],
        [part_synthesis_input['part_layouts']],
        part_synthesis_input['masks'],
        seed=seed,
        slat_sampler_params={"steps": 25, "cfg_strength": cfg_strength},
        formats=['mesh', 'gaussian'],
        preprocess_image=False,
    )
    save_parts_outputs(
        part_synthesis_output,
        output_dir=user_dir,
        simplify_ratio=0.0,
        save_video=False,
        save_glb=True,
        textured=False,
    )
    merge_parts(user_dir)
    print("[INFO] PartSynthesis output saved")

    bbox_mesh_path = os.path.join(user_dir, "bboxes_vis.glb")
    whole_mesh_path = os.path.join(user_dir, "mesh_segment.glb")

    # Build the exploded-view variant of the merged mesh.
    combined_mesh = trimesh.load(whole_mesh_path)
    exploded_mesh_result = explode_mesh(combined_mesh, explosion_scale=explode_factor)
    exploded_mesh_result.export(os.path.join(user_dir, "exploded_parts.glb"))

    exploded_mesh_path = os.path.join(user_dir, "exploded_parts.glb")
    # Gaussian .ply outputs are produced by save_parts_outputs/merge_parts.
    combined_gs_path = os.path.join(user_dir, "merged_gs.ply")
    exploded_gs_path = os.path.join(user_dir, "exploded_gs.ply")

    return bbox_mesh_path, whole_mesh_path, exploded_mesh_path, combined_gs_path, exploded_gs_path
| |
|