File size: 17,125 Bytes
4845d25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Visualization module for Depth Anything 3 Gradio app.

This module handles visualization updates, navigation, and measurement functionality.
"""

import os
from typing import Any, Dict, List, Optional, Tuple
import cv2
import gradio as gr
import numpy as np


class VisualizationHandler:
    """
    Handles visualization updates and navigation for the Gradio app.
    """

    def __init__(self):
        """Initialize the visualization handler."""

    def update_view_selectors(
        self, processed_data: Optional[Dict[int, Dict[str, Any]]]
    ) -> Tuple[gr.Dropdown, gr.Dropdown]:
        """
        Update view selector dropdowns based on available views.

        Args:
            processed_data: Processed data dictionary

        Returns:
            Tuple of (depth_view_selector, measure_view_selector)
        """
        if processed_data is None or len(processed_data) == 0:
            choices = ["View 1"]
        else:
            num_views = len(processed_data)
            choices = [f"View {i + 1}" for i in range(num_views)]

        return (
            gr.Dropdown(choices=choices, value=choices[0]),  # depth_view_selector
            gr.Dropdown(choices=choices, value=choices[0]),  # measure_view_selector
        )

    def get_view_data_by_index(
        self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
    ) -> Optional[Dict[str, Any]]:
        """
        Get view data by index, handling bounds.

        Args:
            processed_data: Processed data dictionary
            view_index: Index of the view to get

        Returns:
            View data dictionary or None
        """
        if processed_data is None or len(processed_data) == 0:
            return None

        view_keys = list(processed_data.keys())
        if view_index < 0 or view_index >= len(view_keys):
            view_index = 0

        return processed_data[view_keys[view_index]]

    def update_depth_view(
        self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
    ) -> Optional[str]:
        """
        Update depth view for a specific view index.

        Args:
            processed_data: Processed data dictionary
            view_index: Index of the view to update

        Returns:
            Path to depth visualization image or None
        """
        view_data = self.get_view_data_by_index(processed_data, view_index)
        if view_data is None or view_data.get("depth_image") is None:
            return None

        # Return the depth visualization image directly
        return view_data["depth_image"]

    def navigate_depth_view(
        self,
        processed_data: Optional[Dict[int, Dict[str, Any]]],
        current_selector_value: str,
        direction: int,
    ) -> Tuple[str, Optional[str]]:
        """
        Navigate depth view (direction: -1 for previous, +1 for next).

        Args:
            processed_data: Processed data dictionary
            current_selector_value: Current selector value
            direction: Direction to navigate (-1 for previous, +1 for next)

        Returns:
            Tuple of (new_selector_value, depth_vis)
        """
        if processed_data is None or len(processed_data) == 0:
            return "View 1", None

        # Parse current view number
        try:
            current_view = int(current_selector_value.split()[1]) - 1
        except:  # noqa
            current_view = 0

        num_views = len(processed_data)
        new_view = (current_view + direction) % num_views

        new_selector_value = f"View {new_view + 1}"
        depth_vis = self.update_depth_view(processed_data, new_view)

        return new_selector_value, depth_vis

    def update_measure_view(
        self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List]:
        """
        Update measure view for a specific view index.

        Args:
            processed_data: Processed data dictionary
            view_index: Index of the view to update

        Returns:
            Tuple of (measure_image, depth_right_half, measure_points)
        """
        view_data = self.get_view_data_by_index(processed_data, view_index)
        if view_data is None:
            return None, None, []  # image, depth_right_half, measure_points

        # Get the processed (resized) image
        if "image" in view_data and view_data["image"] is not None:
            image = view_data["image"].copy()
        else:
            return None, None, []

        # Ensure image is in uint8 format
        if image.dtype != np.uint8:
            if image.max() <= 1.0:
                image = (image * 255).astype(np.uint8)
            else:
                image = image.astype(np.uint8)

        # Extract right half of the depth visualization (pure depth part)
        depth_image_path = view_data.get("depth_image", None)
        depth_right_half = None

        if depth_image_path and os.path.exists(depth_image_path):
            try:
                # Load the combined depth visualization image
                depth_combined = cv2.imread(depth_image_path)
                depth_combined = cv2.cvtColor(depth_combined, cv2.COLOR_BGR2RGB)
                if depth_combined is not None:
                    height, width = depth_combined.shape[:2]
                    # Extract right half (depth visualization part)
                    depth_right_half = depth_combined[:, width // 2 :]
            except Exception as e:
                print(f"Error extracting depth right half: {e}")

        return image, depth_right_half, []

    def navigate_measure_view(
        self,
        processed_data: Optional[Dict[int, Dict[str, Any]]],
        current_selector_value: str,
        direction: int,
    ) -> Tuple[str, Optional[np.ndarray], Optional[str], List]:
        """
        Navigate measure view (direction: -1 for previous, +1 for next).

        Args:
            processed_data: Processed data dictionary
            current_selector_value: Current selector value
            direction: Direction to navigate (-1 for previous, +1 for next)

        Returns:
            Tuple of (new_selector_value, measure_image, depth_image_path, measure_points)
        """
        if processed_data is None or len(processed_data) == 0:
            return "View 1", None, None, []

        # Parse current view number
        try:
            current_view = int(current_selector_value.split()[1]) - 1
        except:  # noqa
            current_view = 0

        num_views = len(processed_data)
        new_view = (current_view + direction) % num_views

        new_selector_value = f"View {new_view + 1}"
        measure_image, depth_right_half, measure_points = self.update_measure_view(
            processed_data, new_view
        )

        return new_selector_value, measure_image, depth_right_half, measure_points

    def populate_visualization_tabs(
        self, processed_data: Optional[Dict[int, Dict[str, Any]]]
    ) -> Tuple[Optional[str], Optional[np.ndarray], Optional[str], List]:
        """
        Populate the depth and measure tabs with processed data.

        Args:
            processed_data: Processed data dictionary

        Returns:
            Tuple of (depth_vis, measure_img, depth_image_path, measure_points)
        """
        if processed_data is None or len(processed_data) == 0:
            return None, None, None, []

        # Use update function to get depth visualization
        depth_vis = self.update_depth_view(processed_data, 0)
        measure_img, depth_right_half, _ = self.update_measure_view(processed_data, 0)

        return depth_vis, measure_img, depth_right_half, []

    def reset_measure(
        self, processed_data: Optional[Dict[int, Dict[str, Any]]]
    ) -> Tuple[Optional[np.ndarray], List, str]:
        """
        Reset measure points.

        Args:
            processed_data: Processed data dictionary

        Returns:
            Tuple of (image, measure_points, text)
        """
        if processed_data is None or len(processed_data) == 0:
            return None, [], ""

        # Return the first view image
        first_view = list(processed_data.values())[0]
        return first_view["image"], [], ""

    def measure(
        self,
        processed_data: Optional[Dict[int, Dict[str, Any]]],
        measure_points: List,
        current_view_selector: str,
        event: gr.SelectData,
    ) -> List:
        """
        Handle measurement on images.

        Args:
            processed_data: Processed data dictionary
            measure_points: List of current measure points
            current_view_selector: Current view selector value
            event: Gradio select event

        Returns:
            List of [image, depth_right_half, measure_points, text]
        """
        try:
            print(f"Measure function called with selector: {current_view_selector}")

            if processed_data is None or len(processed_data) == 0:
                return [None, [], "No data available"]

            # Use the currently selected view instead of always using the first view
            try:
                current_view_index = int(current_view_selector.split()[1]) - 1
            except:  # noqa
                current_view_index = 0

            print(f"Using view index: {current_view_index}")

            # Get view data safely
            if current_view_index < 0 or current_view_index >= len(processed_data):
                current_view_index = 0

            view_keys = list(processed_data.keys())
            current_view = processed_data[view_keys[current_view_index]]

            if current_view is None:
                return [None, [], "No view data available"]

            point2d = event.index[0], event.index[1]
            print(f"Clicked point: {point2d}")

            measure_points.append(point2d)

            # Get image and depth visualization
            image, depth_right_half, _ = self.update_measure_view(
                processed_data, current_view_index
            )
            if image is None:
                return [None, [], "No image available"]

            image = image.copy()

            # Ensure image is in uint8 format for proper cv2 operations
            try:
                if image.dtype != np.uint8:
                    if image.max() <= 1.0:
                        # Image is in [0, 1] range, convert to [0, 255]
                        image = (image * 255).astype(np.uint8)
                    else:
                        # Image is already in [0, 255] range
                        image = image.astype(np.uint8)
            except Exception as e:
                print(f"Image conversion error: {e}")
                return [None, [], f"Image conversion error: {e}"]

            # Draw circles for points
            try:
                for p in measure_points:
                    if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
                        image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2)
            except Exception as e:
                print(f"Drawing error: {e}")
                return [None, [], f"Drawing error: {e}"]

            # Get depth information from processed_data
            depth_text = ""
            try:
                for i, p in enumerate(measure_points):
                    if (
                        current_view["depth"] is not None
                        and 0 <= p[1] < current_view["depth"].shape[0]
                        and 0 <= p[0] < current_view["depth"].shape[1]
                    ):
                        d = current_view["depth"][p[1], p[0]]
                        depth_text += f"- **P{i + 1} depth: {d:.2f}m**\n"
                    else:
                        depth_text += f"- **P{i + 1}: Click position ({p[0]}, {p[1]}) - No depth information**\n"  # noqa: E501
            except Exception as e:
                print(f"Depth text error: {e}")
                depth_text = f"Error computing depth: {e}\n"

            if len(measure_points) == 2:
                try:
                    point1, point2 = measure_points
                    # Draw line
                    if (
                        0 <= point1[0] < image.shape[1]
                        and 0 <= point1[1] < image.shape[0]
                        and 0 <= point2[0] < image.shape[1]
                        and 0 <= point2[1] < image.shape[0]
                    ):
                        image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2)

                    # Compute 3D distance using depth information and camera intrinsics
                    distance_text = "- **Distance: Unable to calculate 3D distance**"
                    if (
                        current_view["depth"] is not None
                        and 0 <= point1[1] < current_view["depth"].shape[0]
                        and 0 <= point1[0] < current_view["depth"].shape[1]
                        and 0 <= point2[1] < current_view["depth"].shape[0]
                        and 0 <= point2[0] < current_view["depth"].shape[1]
                    ):
                        try:
                            # Get depth values at the two points
                            d1 = current_view["depth"][point1[1], point1[0]]
                            d2 = current_view["depth"][point2[1], point2[0]]

                            # Convert 2D pixel coordinates to 3D world coordinates
                            if current_view["intrinsics"] is not None:
                                # Get camera intrinsics
                                K = current_view["intrinsics"]  # 3x3 intrinsic matrix
                                fx, fy = K[0, 0], K[1, 1]  # focal lengths
                                cx, cy = K[0, 2], K[1, 2]  # principal point

                                # Convert pixel coordinates to normalized camera coordinates
                                # Point 1: (u1, v1) -> (x1, y1, z1)
                                u1, v1 = point1[0], point1[1]
                                x1 = (u1 - cx) * d1 / fx
                                y1 = (v1 - cy) * d1 / fy
                                z1 = d1

                                # Point 2: (u2, v2) -> (x2, y2, z2)
                                u2, v2 = point2[0], point2[1]
                                x2 = (u2 - cx) * d2 / fx
                                y2 = (v2 - cy) * d2 / fy
                                z2 = d2

                                # Calculate 3D Euclidean distance
                                p1_3d = np.array([x1, y1, z1])
                                p2_3d = np.array([x2, y2, z2])
                                distance_3d = np.linalg.norm(p1_3d - p2_3d)

                                distance_text = f"- **Distance: {distance_3d:.2f}m**"
                            else:
                                # Fallback to simplified calculation if no intrinsics
                                pixel_distance = np.sqrt(
                                    (point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2
                                )
                                avg_depth = (d1 + d2) / 2
                                scale_factor = avg_depth / 1000  # Rough scaling factor
                                estimated_3d_distance = pixel_distance * scale_factor
                                distance_text = f"- **Distance: {estimated_3d_distance:.2f}m (estimated, no intrinsics)**"  # noqa: E501

                        except Exception as e:
                            print(f"Distance computation error: {e}")
                            distance_text = f"- **Distance computation error: {e}**"

                    measure_points = []
                    text = depth_text + distance_text
                    print(f"Measurement complete: {text}")
                    return [image, depth_right_half, measure_points, text]
                except Exception as e:
                    print(f"Final measurement error: {e}")
                    return [None, [], f"Measurement error: {e}"]
            else:
                print(f"Single point measurement: {depth_text}")
                return [image, depth_right_half, measure_points, depth_text]

        except Exception as e:
            print(f"Overall measure function error: {e}")
            return [None, [], f"Measure function error: {e}"]