# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Visualization module for Depth Anything 3 Gradio app. This module handles visualization updates, navigation, and measurement functionality. """ import os from typing import Any, Dict, List, Optional, Tuple import cv2 import gradio as gr import numpy as np class VisualizationHandler: """ Handles visualization updates and navigation for the Gradio app. """ def __init__(self): """Initialize the visualization handler.""" def update_view_selectors( self, processed_data: Optional[Dict[int, Dict[str, Any]]] ) -> Tuple[gr.Dropdown, gr.Dropdown]: """ Update view selector dropdowns based on available views. Args: processed_data: Processed data dictionary Returns: Tuple of (depth_view_selector, measure_view_selector) """ if processed_data is None or len(processed_data) == 0: choices = ["View 1"] else: num_views = len(processed_data) choices = [f"View {i + 1}" for i in range(num_views)] return ( gr.Dropdown(choices=choices, value=choices[0]), # depth_view_selector gr.Dropdown(choices=choices, value=choices[0]), # measure_view_selector ) def get_view_data_by_index( self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int ) -> Optional[Dict[str, Any]]: """ Get view data by index, handling bounds. Args: processed_data: Processed data dictionary view_index: Index of the view to get Returns: View data dictionary or None """ if processed_data is None or len(processed_data) == 0: return None view_keys = list(processed_data.keys()) if view_index < 0 or view_index >= len(view_keys): view_index = 0 return processed_data[view_keys[view_index]] def update_depth_view( self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int ) -> Optional[str]: """ Update depth view for a specific view index. Args: processed_data: Processed data dictionary view_index: Index of the view to update Returns: Path to depth visualization image or None """ view_data = self.get_view_data_by_index(processed_data, view_index) if view_data is None or view_data.get("depth_image") is None: return None # Return the depth visualization image directly return view_data["depth_image"] def navigate_depth_view( self, processed_data: Optional[Dict[int, Dict[str, Any]]], current_selector_value: str, direction: int, ) -> Tuple[str, Optional[str]]: """ Navigate depth view (direction: -1 for previous, +1 for next). Args: processed_data: Processed data dictionary current_selector_value: Current selector value direction: Direction to navigate (-1 for previous, +1 for next) Returns: Tuple of (new_selector_value, depth_vis) """ if processed_data is None or len(processed_data) == 0: return "View 1", None # Parse current view number try: current_view = int(current_selector_value.split()[1]) - 1 except: # noqa current_view = 0 num_views = len(processed_data) new_view = (current_view + direction) % num_views new_selector_value = f"View {new_view + 1}" depth_vis = self.update_depth_view(processed_data, new_view) return new_selector_value, depth_vis def update_measure_view( self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List]: """ Update measure view for a specific view index. Args: processed_data: Processed data dictionary view_index: Index of the view to update Returns: Tuple of (measure_image, depth_right_half, measure_points) """ view_data = self.get_view_data_by_index(processed_data, view_index) if view_data is None: return None, None, [] # image, depth_right_half, measure_points # Get the processed (resized) image if "image" in view_data and view_data["image"] is not None: image = view_data["image"].copy() else: return None, None, [] # Ensure image is in uint8 format if image.dtype != np.uint8: if image.max() <= 1.0: image = (image * 255).astype(np.uint8) else: image = image.astype(np.uint8) # Extract right half of the depth visualization (pure depth part) depth_image_path = view_data.get("depth_image", None) depth_right_half = None if depth_image_path and os.path.exists(depth_image_path): try: # Load the combined depth visualization image depth_combined = cv2.imread(depth_image_path) depth_combined = cv2.cvtColor(depth_combined, cv2.COLOR_BGR2RGB) if depth_combined is not None: height, width = depth_combined.shape[:2] # Extract right half (depth visualization part) depth_right_half = depth_combined[:, width // 2 :] except Exception as e: print(f"Error extracting depth right half: {e}") return image, depth_right_half, [] def navigate_measure_view( self, processed_data: Optional[Dict[int, Dict[str, Any]]], current_selector_value: str, direction: int, ) -> Tuple[str, Optional[np.ndarray], Optional[str], List]: """ Navigate measure view (direction: -1 for previous, +1 for next). Args: processed_data: Processed data dictionary current_selector_value: Current selector value direction: Direction to navigate (-1 for previous, +1 for next) Returns: Tuple of (new_selector_value, measure_image, depth_image_path, measure_points) """ if processed_data is None or len(processed_data) == 0: return "View 1", None, None, [] # Parse current view number try: current_view = int(current_selector_value.split()[1]) - 1 except: # noqa current_view = 0 num_views = len(processed_data) new_view = (current_view + direction) % num_views new_selector_value = f"View {new_view + 1}" measure_image, depth_right_half, measure_points = self.update_measure_view( processed_data, new_view ) return new_selector_value, measure_image, depth_right_half, measure_points def populate_visualization_tabs( self, processed_data: Optional[Dict[int, Dict[str, Any]]] ) -> Tuple[Optional[str], Optional[np.ndarray], Optional[str], List]: """ Populate the depth and measure tabs with processed data. Args: processed_data: Processed data dictionary Returns: Tuple of (depth_vis, measure_img, depth_image_path, measure_points) """ if processed_data is None or len(processed_data) == 0: return None, None, None, [] # Use update function to get depth visualization depth_vis = self.update_depth_view(processed_data, 0) measure_img, depth_right_half, _ = self.update_measure_view(processed_data, 0) return depth_vis, measure_img, depth_right_half, [] def reset_measure( self, processed_data: Optional[Dict[int, Dict[str, Any]]] ) -> Tuple[Optional[np.ndarray], List, str]: """ Reset measure points. Args: processed_data: Processed data dictionary Returns: Tuple of (image, measure_points, text) """ if processed_data is None or len(processed_data) == 0: return None, [], "" # Return the first view image first_view = list(processed_data.values())[0] return first_view["image"], [], "" def measure( self, processed_data: Optional[Dict[int, Dict[str, Any]]], measure_points: List, current_view_selector: str, event: gr.SelectData, ) -> List: """ Handle measurement on images. Args: processed_data: Processed data dictionary measure_points: List of current measure points current_view_selector: Current view selector value event: Gradio select event Returns: List of [image, depth_right_half, measure_points, text] """ try: print(f"Measure function called with selector: {current_view_selector}") if processed_data is None or len(processed_data) == 0: return [None, [], "No data available"] # Use the currently selected view instead of always using the first view try: current_view_index = int(current_view_selector.split()[1]) - 1 except: # noqa current_view_index = 0 print(f"Using view index: {current_view_index}") # Get view data safely if current_view_index < 0 or current_view_index >= len(processed_data): current_view_index = 0 view_keys = list(processed_data.keys()) current_view = processed_data[view_keys[current_view_index]] if current_view is None: return [None, [], "No view data available"] point2d = event.index[0], event.index[1] print(f"Clicked point: {point2d}") measure_points.append(point2d) # Get image and depth visualization image, depth_right_half, _ = self.update_measure_view( processed_data, current_view_index ) if image is None: return [None, [], "No image available"] image = image.copy() # Ensure image is in uint8 format for proper cv2 operations try: if image.dtype != np.uint8: if image.max() <= 1.0: # Image is in [0, 1] range, convert to [0, 255] image = (image * 255).astype(np.uint8) else: # Image is already in [0, 255] range image = image.astype(np.uint8) except Exception as e: print(f"Image conversion error: {e}") return [None, [], f"Image conversion error: {e}"] # Draw circles for points try: for p in measure_points: if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]: image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2) except Exception as e: print(f"Drawing error: {e}") return [None, [], f"Drawing error: {e}"] # Get depth information from processed_data depth_text = "" try: for i, p in enumerate(measure_points): if ( current_view["depth"] is not None and 0 <= p[1] < current_view["depth"].shape[0] and 0 <= p[0] < current_view["depth"].shape[1] ): d = current_view["depth"][p[1], p[0]] depth_text += f"- **P{i + 1} depth: {d:.2f}m**\n" else: depth_text += f"- **P{i + 1}: Click position ({p[0]}, {p[1]}) - No depth information**\n" # noqa: E501 except Exception as e: print(f"Depth text error: {e}") depth_text = f"Error computing depth: {e}\n" if len(measure_points) == 2: try: point1, point2 = measure_points # Draw line if ( 0 <= point1[0] < image.shape[1] and 0 <= point1[1] < image.shape[0] and 0 <= point2[0] < image.shape[1] and 0 <= point2[1] < image.shape[0] ): image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2) # Compute 3D distance using depth information and camera intrinsics distance_text = "- **Distance: Unable to calculate 3D distance**" if ( current_view["depth"] is not None and 0 <= point1[1] < current_view["depth"].shape[0] and 0 <= point1[0] < current_view["depth"].shape[1] and 0 <= point2[1] < current_view["depth"].shape[0] and 0 <= point2[0] < current_view["depth"].shape[1] ): try: # Get depth values at the two points d1 = current_view["depth"][point1[1], point1[0]] d2 = current_view["depth"][point2[1], point2[0]] # Convert 2D pixel coordinates to 3D world coordinates if current_view["intrinsics"] is not None: # Get camera intrinsics K = current_view["intrinsics"] # 3x3 intrinsic matrix fx, fy = K[0, 0], K[1, 1] # focal lengths cx, cy = K[0, 2], K[1, 2] # principal point # Convert pixel coordinates to normalized camera coordinates # Point 1: (u1, v1) -> (x1, y1, z1) u1, v1 = point1[0], point1[1] x1 = (u1 - cx) * d1 / fx y1 = (v1 - cy) * d1 / fy z1 = d1 # Point 2: (u2, v2) -> (x2, y2, z2) u2, v2 = point2[0], point2[1] x2 = (u2 - cx) * d2 / fx y2 = (v2 - cy) * d2 / fy z2 = d2 # Calculate 3D Euclidean distance p1_3d = np.array([x1, y1, z1]) p2_3d = np.array([x2, y2, z2]) distance_3d = np.linalg.norm(p1_3d - p2_3d) distance_text = f"- **Distance: {distance_3d:.2f}m**" else: # Fallback to simplified calculation if no intrinsics pixel_distance = np.sqrt( (point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2 ) avg_depth = (d1 + d2) / 2 scale_factor = avg_depth / 1000 # Rough scaling factor estimated_3d_distance = pixel_distance * scale_factor distance_text = f"- **Distance: {estimated_3d_distance:.2f}m (estimated, no intrinsics)**" # noqa: E501 except Exception as e: print(f"Distance computation error: {e}") distance_text = f"- **Distance computation error: {e}**" measure_points = [] text = depth_text + distance_text print(f"Measurement complete: {text}") return [image, depth_right_half, measure_points, text] except Exception as e: print(f"Final measurement error: {e}") return [None, [], f"Measurement error: {e}"] else: print(f"Single point measurement: {depth_text}") return [image, depth_right_half, measure_points, depth_text] except Exception as e: print(f"Overall measure function error: {e}") return [None, [], f"Measure function error: {e}"]