| | |
| | |
| | |
| | |
| | |
| |
|
| | import trimesh |
| | import gradio as gr |
| | import numpy as np |
| | import matplotlib |
| | from scipy.spatial.transform import Rotation |
| | import copy |
| | import cv2 |
| | import os |
| | import requests |
| |
|
| |
|
def predictions_to_glb(
    predictions,
    conf_thres=50.0,
    filter_by_frames="all",
    mask_black_bg=False,
    mask_white_bg=False,
    show_cam=True,
    mask_sky=False,
    target_dir=None,
    prediction_mode="Predicted Pointmap",
) -> trimesh.Scene:
    """
    Converts VGGT predictions to a 3D scene represented as a GLB file.

    Args:
        predictions (dict): Dictionary containing model predictions with keys:
            - world_points: 3D point coordinates (S, H, W, 3)
            - world_points_conf: Confidence scores (S, H, W)
            - world_points_from_depth / depth_conf: used instead when the
              prediction mode does not contain "Pointmap", or as a fallback
              when world_points is absent
            - images: Input images (S, H, W, 3) or (S, 3, H, W)
            - extrinsic: Camera extrinsic matrices (S, 3, 4)
        conf_thres (float): Percentage of low-confidence points to filter out
            (default: 50.0). None is replaced by 10.0.
        filter_by_frames (str): Frame filter specification (default: "all").
            Any other value is parsed as "<index>:<anything>"; if the index
            cannot be parsed, a warning is printed and all frames are kept.
        mask_black_bg (bool): Mask out black background pixels (default: False)
        mask_white_bg (bool): Mask out white background pixels (default: False)
        show_cam (bool): Include camera visualization (default: True)
        mask_sky (bool): Apply sky segmentation mask (default: False). Only
            takes effect when target_dir is also given.
        target_dir (str): Directory holding an "images" sub-directory; sky
            masks are cached under "sky_masks" in it (default: None)
        prediction_mode (str): Prediction mode selector (default: "Predicted Pointmap")

    Returns:
        trimesh.Scene: Processed 3D scene containing point cloud and cameras

    Raises:
        ValueError: If input predictions structure is invalid
    """
    if not isinstance(predictions, dict):
        raise ValueError("predictions must be a dictionary")

    if conf_thres is None:
        conf_thres = 10.0

    print("Building GLB scene")
    selected_frame_idx = None
    if filter_by_frames not in ("all", "All"):
        try:
            # The frame index is the part before the first colon.
            selected_frame_idx = int(filter_by_frames.split(":")[0])
        except (ValueError, IndexError):
            # Best effort: keep every frame, but say so instead of failing silently.
            print(f"Warning: could not parse frame filter {filter_by_frames!r}; using all frames")

    if "Pointmap" in prediction_mode:
        print("Using Pointmap Branch")
        if "world_points" in predictions:
            pred_world_points = predictions["world_points"]
            pred_world_points_conf = predictions.get("world_points_conf", np.ones_like(pred_world_points[..., 0]))
        else:
            print("Warning: world_points not found in predictions, falling back to depth-based points")
            pred_world_points = predictions["world_points_from_depth"]
            pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0]))
    else:
        print("Using Depthmap and Camera Branch")
        pred_world_points = predictions["world_points_from_depth"]
        pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0]))

    images = predictions["images"]
    camera_matrices = predictions["extrinsic"]

    if mask_sky and target_dir is not None:
        # Imported lazily so onnxruntime is only required when sky masking is used.
        import onnxruntime

        skyseg_session = None
        target_dir_images = os.path.join(target_dir, "images")
        image_list = sorted(os.listdir(target_dir_images))
        sky_mask_list = []

        # Masks are resized to the spatial size of the confidence map.
        _, H, W = (
            pred_world_points_conf.shape
            if hasattr(pred_world_points_conf, "shape")
            else (len(images), images.shape[1], images.shape[2])
        )

        if not os.path.exists("skyseg.onnx"):
            print("Downloading skyseg.onnx...")
            download_file_from_url(
                "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx", "skyseg.onnx"
            )

        for image_name in image_list:
            image_filepath = os.path.join(target_dir_images, image_name)
            mask_filepath = os.path.join(target_dir, "sky_masks", image_name)

            # Reuse a cached mask when one exists and can be decoded.
            sky_mask = None
            if os.path.exists(mask_filepath):
                sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)

            # cv2.imread returns None for unreadable files; regenerate in that case.
            if sky_mask is None:
                if skyseg_session is None:
                    skyseg_session = onnxruntime.InferenceSession("skyseg.onnx")
                sky_mask = segment_sky(image_filepath, skyseg_session, mask_filepath)

            if sky_mask.shape[0] != H or sky_mask.shape[1] != W:
                sky_mask = cv2.resize(sky_mask, (W, H))

            sky_mask_list.append(sky_mask)

        sky_mask_array = np.array(sky_mask_list)

        # Zero the confidence wherever the mask is 0 so those points are dropped below.
        sky_mask_binary = (sky_mask_array > 0.1).astype(np.float32)
        pred_world_points_conf = pred_world_points_conf * sky_mask_binary

    if selected_frame_idx is not None:
        # Keep a leading frame axis of length 1 so the code below is unchanged.
        pred_world_points = pred_world_points[selected_frame_idx][None]
        pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None]
        images = images[selected_frame_idx][None]
        camera_matrices = camera_matrices[selected_frame_idx][None]

    vertices_3d = pred_world_points.reshape(-1, 3)

    # Channel-first images (N, 3, H, W) are moved to channel-last before flattening.
    if images.ndim == 4 and images.shape[1] == 3:
        colors_rgb = np.transpose(images, (0, 2, 3, 1))
    else:
        colors_rgb = images
    colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8)

    conf = pred_world_points_conf.reshape(-1)

    # conf_thres is a percentile; convert it to an absolute confidence value.
    if conf_thres == 0.0:
        conf_threshold = 0.0
    else:
        conf_threshold = np.percentile(conf, conf_thres)

    # The 1e-5 floor also removes points whose confidence was zeroed by the sky mask.
    conf_mask = (conf >= conf_threshold) & (conf > 1e-5)

    if mask_black_bg:
        black_bg_mask = colors_rgb.sum(axis=1) >= 16
        conf_mask = conf_mask & black_bg_mask

    if mask_white_bg:
        # A pixel counts as white background only when all three channels exceed 240.
        white_bg_mask = ~((colors_rgb[:, 0] > 240) & (colors_rgb[:, 1] > 240) & (colors_rgb[:, 2] > 240))
        conf_mask = conf_mask & white_bg_mask

    vertices_3d = vertices_3d[conf_mask]
    colors_rgb = colors_rgb[conf_mask]

    if vertices_3d is None or np.asarray(vertices_3d).size == 0:
        # Nothing survived filtering: emit a single placeholder point.
        vertices_3d = np.array([[1, 0, 0]])
        colors_rgb = np.array([[255, 255, 255]])
        scene_scale = 1
    else:
        # Scene scale is the diagonal of the 5th-95th percentile bounding box.
        lower_percentile = np.percentile(vertices_3d, 5, axis=0)
        upper_percentile = np.percentile(vertices_3d, 95, axis=0)
        scene_scale = np.linalg.norm(upper_percentile - lower_percentile)

    colormap = matplotlib.colormaps.get_cmap("gist_rainbow")

    scene_3d = trimesh.Scene()
    point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)
    scene_3d.add_geometry(point_cloud_data)

    # Promote the 3x4 extrinsics to homogeneous 4x4 matrices.
    num_cameras = len(camera_matrices)
    extrinsics_matrices = np.zeros((num_cameras, 4, 4))
    extrinsics_matrices[:, :3, :4] = camera_matrices
    extrinsics_matrices[:, 3, 3] = 1

    if show_cam:
        for i in range(num_cameras):
            world_to_camera = extrinsics_matrices[i]
            camera_to_world = np.linalg.inv(world_to_camera)
            rgba_color = colormap(i / num_cameras)
            current_color = tuple(int(255 * x) for x in rgba_color[:3])

            integrate_camera_into_scene(scene_3d, camera_to_world, current_color, scene_scale)

    scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices)

    print("GLB Scene built")
    return scene_3d
| |
|
| |
|
def integrate_camera_into_scene(
    scene: trimesh.Scene,
    transform: np.ndarray,
    face_colors: tuple,
    scene_scale: float,
):
    """
    Adds a cone-shaped camera marker to the scene.

    The marker is a 4-section cone sized relative to the scene scale. Its
    vertex list holds three copies of the cone vertices (as-is, shrunk to
    95%, and rotated 2 degrees about z); compute_camera_faces connects them.

    Args:
        scene (trimesh.Scene): Scene that receives the camera mesh.
        transform (np.ndarray): 4x4 matrix placing the camera in the scene.
        face_colors (tuple): RGB color applied to every face of the mesh.
        scene_scale (float): Scale of the scene, used to size the marker.
    """

    def _z_rotation(angle_deg):
        # Homogeneous 4x4 matrix rotating by angle_deg about the z axis.
        pose = np.eye(4)
        pose[:3, :3] = Rotation.from_euler("z", angle_deg, degrees=True).as_matrix()
        return pose

    width = scene_scale * 0.05
    height = scene_scale * 0.1

    # 45 degree turn about z plus a shift of one cone height along -z.
    base_pose = _z_rotation(45)
    base_pose[2, 3] = -height

    cone_to_scene = transform @ get_opengl_conversion_matrix() @ base_pose
    cone = trimesh.creation.cone(width, height, sections=4)

    cone_vertices = cone.vertices
    stacked_vertices = np.concatenate(
        [
            cone_vertices,
            0.95 * cone_vertices,
            transform_points(_z_rotation(2), cone_vertices),
        ]
    )
    placed_vertices = transform_points(cone_to_scene, stacked_vertices)

    marker = trimesh.Trimesh(vertices=placed_vertices, faces=compute_camera_faces(cone))
    marker.visual.face_colors[:, :3] = face_colors
    scene.add_geometry(marker)
| |
|
| |
|
def apply_scene_alignment(scene_3d: trimesh.Scene, extrinsics_matrices: np.ndarray) -> trimesh.Scene:
    """
    Re-orients the scene relative to the first camera.

    The applied transform is the inverse of the first extrinsic matrix,
    followed by the OpenGL axis conversion and a 180 degree turn about y.

    Args:
        scene_3d (trimesh.Scene): Scene to transform (modified in place).
        extrinsics_matrices (np.ndarray): Stack of 4x4 camera extrinsics;
            only the first entry is used.

    Returns:
        trimesh.Scene: The same scene object after the transform.
    """
    half_turn_y = np.eye(4)
    half_turn_y[:3, :3] = Rotation.from_euler("y", 180, degrees=True).as_matrix()

    first_camera_inverse = np.linalg.inv(extrinsics_matrices[0])
    alignment = first_camera_inverse @ get_opengl_conversion_matrix() @ half_turn_y

    scene_3d.apply_transform(alignment)
    return scene_3d
| |
|
| |
|
def get_opengl_conversion_matrix() -> np.ndarray:
    """
    Builds the 4x4 matrix used for the OpenGL axis conversion.

    It is the identity with the y and z diagonal entries negated, i.e. it
    flips the y and z axes and leaves x and the homogeneous row untouched.

    Returns:
        numpy.ndarray: A fresh 4x4 float matrix diag(1, -1, -1, 1).
    """
    flipped_axes = [1, 2]
    conversion = np.identity(4)
    # Paired index lists address the diagonal entries (1, 1) and (2, 2).
    conversion[flipped_axes, flipped_axes] = -1
    return conversion
| |
|
| |
|
def transform_points(transformation: np.ndarray, points: np.ndarray, dim: int = None) -> np.ndarray:
    """
    Applies a 4x4 transformation to a set of points.

    The transform is applied as an affine map (upper rows times the point
    plus the last column); no perspective division is performed.

    Args:
        transformation (np.ndarray): Transformation matrix (or a stack of
            them). Any array-like is accepted and converted with np.asarray.
        points (np.ndarray): Points to be transformed, coordinates on the
            last axis. Any array-like is accepted.
        dim (int, optional): Number of leading output coordinates to keep.
            Defaults to the size of the last axis of `points`.

    Returns:
        np.ndarray: Transformed points with the leading shape of `points`
        and `dim` coordinates on the last axis.
    """
    points = np.asarray(points)
    # Coerce the matrix as well, so nested lists work just like they do for points.
    transformation = np.asarray(transformation)
    initial_shape = points.shape[:-1]
    dim = dim or points.shape[-1]

    # Work with the transpose so row-vector points can be right-multiplied:
    # all rows but the last carry the linear part, the last row the translation.
    transposed = transformation.swapaxes(-1, -2)
    points = points @ transposed[..., :-1, :] + transposed[..., -1:, :]

    # Drop the homogeneous coordinate (and anything beyond `dim`).
    result = points[..., :dim].reshape(*initial_shape, dim)
    return result
| |
|
| |
|
def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray:
    """
    Builds the triangle list for the camera marker mesh.

    The indices assume a vertex array made of three consecutive copies of
    the cone vertices. Cone faces that touch vertex 0 are skipped; every
    remaining face contributes three triangles per extra copy, each mixing
    two original vertices with one vertex of that copy. Finally every
    triangle is repeated with reversed winding.

    Args:
        cone_shape (trimesh.Trimesh): Cone whose vertices and faces are used.

    Returns:
        np.ndarray: Array of vertex-index triples.
    """
    copy_stride = len(cone_shape.vertices)
    triangles = []

    for tri in cone_shape.faces:
        if 0 in tri:
            continue
        a, b, c = tri
        # First the 1x-offset copy, then the 2x-offset copy.
        for shift in (copy_stride, 2 * copy_stride):
            triangles.append((a, b, b + shift))
            triangles.append((a, a + shift, c))
            triangles.append((c + shift, b, c))

    # Reversed winding for each triangle, appended after the originals.
    mirrored = [(c, b, a) for a, b, c in triangles]
    return np.array(triangles + mirrored)
| |
|
| |
|
def segment_sky(image_path, onnx_session, mask_filename=None):
    """
    Segments sky from an image using an ONNX model.
    Thanks for the great model provided by https://github.com/xiongzhu666/Sky-Segmentation-and-Post-processing

    Args:
        image_path: Path to input image
        onnx_session: ONNX runtime session with loaded model
        mask_filename: Path to save the output mask (required despite the
            None default)

    Returns:
        np.ndarray: Mask at the input image resolution, 255 where the
        segmentation score is below 32 (non-sky regions) and 0 elsewhere

    Raises:
        ValueError: If mask_filename is missing or the image cannot be read
    """
    # Explicit check rather than assert, which is stripped under `python -O`.
    if mask_filename is None:
        raise ValueError("mask_filename must be provided")

    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising.
    if image is None:
        raise ValueError(f"Could not read image: {image_path}")

    result_map = run_skyseg(onnx_session, [320, 320], image)
    # Bring the model output back to the original image resolution.
    result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0]))

    output_mask = np.zeros_like(result_map_original)
    output_mask[result_map_original < 32] = 255

    # A bare file name has no directory part; os.makedirs("") would fail.
    mask_dir = os.path.dirname(mask_filename)
    if mask_dir:
        os.makedirs(mask_dir, exist_ok=True)
    cv2.imwrite(mask_filename, output_mask)
    return output_mask
| |
|
| |
|
def run_skyseg(onnx_session, input_size, image):
    """
    Runs sky segmentation inference using ONNX model.

    Args:
        onnx_session: ONNX runtime session
        input_size: Target size for model input (width, height)
        image: Input image in BGR format

    Returns:
        np.ndarray: uint8 segmentation map, min-max scaled to 0..255. If the
        model output is constant (no range to scale), an all-zero map is
        returned.
    """
    width, height = input_size[0], input_size[1]

    # cv2.resize returns a new array, so the caller's image is left untouched.
    resized = cv2.resize(image, dsize=(width, height))
    x = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    x = np.array(x, dtype=np.float32)

    # Per-channel normalization constants applied to the 0..1 scaled RGB image.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    x = (x / 255 - mean) / std

    # HWC -> CHW, then add the batch axis. The resized image is (height, width),
    # so the tensor is (1, 3, height, width); using (width, height) here would
    # scramble pixels for non-square input sizes.
    x = x.transpose(2, 0, 1)
    x = x.reshape(-1, 3, height, width).astype("float32")

    input_name = onnx_session.get_inputs()[0].name
    output_name = onnx_session.get_outputs()[0].name
    onnx_result = onnx_session.run([output_name], {input_name: x})

    # Min-max scale the raw output to 0..255.
    onnx_result = np.array(onnx_result).squeeze()
    min_value = np.min(onnx_result)
    max_value = np.max(onnx_result)
    value_range = max_value - min_value
    if value_range > 0:
        onnx_result = (onnx_result - min_value) / value_range
    else:
        # Constant (or NaN) output: avoid dividing by zero.
        onnx_result = np.zeros_like(onnx_result)
    onnx_result = (onnx_result * 255).astype("uint8")

    return onnx_result
| |
|
| |
|
def download_file_from_url(url, filename):
    """
    Downloads a file from a Hugging Face model repo, handling redirects.

    The body is streamed to "<filename>.part" and renamed to `filename` only
    after the transfer completed, so an interrupted download never leaves a
    truncated file under the final name. Network and HTTP errors are reported
    on stdout and not raised.

    Args:
        url: URL to download from. Redirects are followed.
        filename: Destination path of the downloaded file.
    """
    partial_path = f"{filename}.part"
    try:
        # Let requests follow any redirect chain; a direct 200 response is
        # just as valid as a redirected one. The timeout keeps a stalled
        # connection from hanging forever.
        with requests.get(url, stream=True, allow_redirects=True, timeout=30) as response:
            response.raise_for_status()

            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)

        # Publish the file under its final name only once it is complete.
        os.replace(partial_path, filename)
        print(f"Downloaded {filename} successfully.")

    except requests.exceptions.RequestException as e:
        print(f"Error downloading file: {e}")
    finally:
        # Remove any leftover partial file (it no longer exists after a
        # successful rename).
        if os.path.exists(partial_path):
            os.remove(partial_path)
| |
|