File size: 18,875 Bytes
7a87926
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
#!/usr/bin/env python3
"""
Run BA validation on ARKit data.
Compares DA3 poses vs ARKit poses (ground truth) and vs COLMAP BA.
"""

import json
import logging
import sys
from pathlib import Path
from typing import Dict
import numpy as np
import torch

# Add project root to path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

from ylff.services.arkit_processor import ARKitProcessor  # noqa: E402
from ylff.services.ba_validator import BAValidator  # noqa: E402
from ylff.utils.model_loader import load_da3_model  # noqa: E402

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


def compute_pose_error(poses1: np.ndarray, poses2: np.ndarray, verbose: bool = False) -> Dict:
    """Compute per-frame pose error between two trajectories.

    Trajectory 1 is first aligned to trajectory 2 with a similarity
    transform (Umeyama-style: translation + isotropic scale + rotation
    estimated from the camera centers), then per-frame rotation and
    translation errors are measured against trajectory 2.

    Args:
        poses1: (N, 3, 4) or (N, 4, 4) pose matrices.
        poses2: Reference poses, same length as poses1 (both trajectories
            must use the same w2c/c2w convention — TODO confirm at callers).
        verbose: If True, log alignment diagnostics.

    Returns:
        Dict with per-frame error lists, summary statistics, and the
        alignment parameters that were applied ("alignment_info").
    """
    # Translation column lives at [:3, 3] for both (N, 3, 4) and (N, 4, 4).
    centers1 = poses1[:, :3, 3]
    centers2 = poses2[:, :3, 3]

    # Center both trajectories on their mean camera center.
    center1_mean = centers1.mean(axis=0)
    center2_mean = centers2.mean(axis=0)
    centers1_centered = centers1 - center1_mean
    centers2_centered = centers2 - center2_mean

    # Isotropic scale mapping trajectory 1's spread onto trajectory 2's.
    scale1 = np.linalg.norm(centers1_centered, axis=1).mean()
    scale2 = np.linalg.norm(centers2_centered, axis=1).mean()
    scale = scale2 / (scale1 + 1e-8)

    if verbose:
        logger.info(f"  Alignment scale factor: {scale:.6f}")
        logger.info(f"  Scale1 (poses1): {scale1:.6f}")
        logger.info(f"  Scale2 (poses2): {scale2:.6f}")

    # Rotation via Kabsch/SVD. Fix a possible reflection (det == -1) by
    # flipping the sign of the last singular direction so that R_align is
    # always a proper rotation.
    H = centers1_centered.T @ centers2_centered
    U, _, Vt = np.linalg.svd(H)
    d = np.sign(np.linalg.det(Vt.T @ U.T))
    R_align = Vt.T @ np.diag([1.0, 1.0, d]) @ U.T

    if verbose:
        # det should now be ~1.0 by construction; log it as a sanity check.
        det = np.linalg.det(R_align)
        logger.info(f"  Alignment rotation det: {det:.6f} (should be ~1.0)")
        logger.info(f"  Alignment rotation trace: {np.trace(R_align):.3f}")

    # Apply the similarity transform to every pose of trajectory 1.
    poses1_aligned = poses1.copy()
    for i in range(len(poses1)):
        R_orig = poses1[i][:3, :3]
        t_orig = poses1[i][:3, 3]
        poses1_aligned[i][:3, :3] = R_align @ R_orig
        poses1_aligned[i][:3, 3] = scale * (R_align @ (t_orig - center1_mean)) + center2_mean

    rotation_errors = []
    translation_errors = []

    for i in range(len(poses1)):
        R1 = poses1_aligned[i][:3, :3]
        R2 = poses2[i][:3, :3]
        t1 = poses1_aligned[i][:3, 3]
        t2 = poses2[i][:3, 3]

        # Geodesic rotation distance: angle of the relative rotation
        # R1 @ R2^T, clipped to guard arccos against round-off.
        trace = np.trace(R1 @ R2.T)
        angle_rad = np.arccos(np.clip((trace - 1) / 2, -1, 1))
        rotation_errors.append(np.degrees(angle_rad))

        # Euclidean distance between (aligned) camera translations.
        translation_errors.append(np.linalg.norm(t1 - t2))

    result = {
        "rotation_errors_deg": rotation_errors,
        "translation_errors": translation_errors,
        # Plain Python floats so the dict serializes cleanly to JSON.
        "mean_rotation_error_deg": float(np.mean(rotation_errors)),
        "max_rotation_error_deg": float(np.max(rotation_errors)),
        "mean_translation_error": float(np.mean(translation_errors)),
        "alignment_info": {
            "scale_factor": float(scale),
            "center1_mean": center1_mean.tolist(),
            "center2_mean": center2_mean.tolist(),
            "rotation_det": float(np.linalg.det(R_align)),
        },
    }

    if verbose:
        logger.info("  Alignment info saved to results")

    return result


def main():
    """CLI entry point.

    Extracts frames + poses from an ARKit capture, runs DA3 inference,
    compares DA3 poses against the ARKit ground truth, runs COLMAP BA
    validation, and writes poses + a JSON report to the output directory.
    """
    import argparse

    # Frame-acceptance thresholds (degrees of rotation error), used
    # consistently for per-frame categorization, reporting, and the
    # BAValidator configuration below.
    accept_threshold = 2.0
    reject_threshold = 30.0

    parser = argparse.ArgumentParser(description="Run BA validation on ARKit data")
    parser.add_argument(
        "--arkit-dir",
        type=Path,
        default=project_root / "assets" / "examples" / "ARKit",
        help="Directory containing ARKit video and metadata",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=project_root / "data" / "arkit_ba_validation",
        help="Output directory for results",
    )
    parser.add_argument(
        "--max-frames", type=int, default=None, help="Maximum number of frames to process"
    )
    parser.add_argument("--frame-interval", type=int, default=1, help="Extract every Nth frame")
    parser.add_argument("--device", type=str, default="cpu", help="Device for DA3 inference")

    args = parser.parse_args()

    # NOTE: argparse already supplies non-None defaults for --arkit-dir and
    # --output-dir, so no post-hoc defaulting is needed here.

    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Locate the first ARKit video / metadata file (None if absent).
    video_path = next(iter((args.arkit_dir / "videos").glob("*.MOV")), None)
    metadata_path = next(iter((args.arkit_dir / "json-metadata").glob("*.json")), None)

    if not video_path or not metadata_path:
        logger.error(f"ARKit files not found in {args.arkit_dir}")
        logger.error("Expected: videos/*.MOV and json-metadata/*.json")
        return

    logger.info(f"ARKit video: {video_path}")
    logger.info(f"ARKit metadata: {metadata_path}")

    # Process ARKit data: extract frames and ground-truth poses.
    logger.info("\n=== Processing ARKit Data ===")
    processor = ARKitProcessor(video_path, metadata_path)

    arkit_data = processor.process_for_ba_validation(
        output_dir=args.output_dir,
        max_frames=args.max_frames,
        frame_interval=args.frame_interval,
        use_good_tracking_only=True,
    )

    image_paths = arkit_data["image_paths"]
    arkit_poses_c2w = arkit_data["arkit_poses_c2w"]
    # arkit_poses_w2c / arkit_intrinsics from arkit_data are not used here.

    # Convert ARKit c2w poses to OpenCV convention for proper comparison.
    from ylff.coordinate_utils import convert_arkit_to_opencv

    arkit_poses_c2w_opencv = np.array([convert_arkit_to_opencv(p) for p in arkit_poses_c2w])

    logger.info(f"Processed {len(image_paths)} frames")

    # Run DA3 inference on the extracted frames.
    logger.info("\n=== Running DA3 Inference ===")
    model = load_da3_model("depth-anything/DA3-LARGE", device=args.device)

    import cv2

    # Load frames as RGB; unreadable files are skipped silently.
    images = []
    for img_path in image_paths:
        img = cv2.imread(str(img_path))
        if img is not None:
            images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    logger.info(f"Running DA3 on {len(images)} images...")
    with torch.no_grad():
        da3_output = model.inference(images)

    da3_poses = da3_output.extrinsics  # (N, 3, 4) w2c
    da3_intrinsics = da3_output.intrinsics if hasattr(da3_output, "intrinsics") else None

    logger.info(f"DA3 poses: {da3_poses.shape}")

    # Persist DA3 poses immediately so they survive a later BA failure.
    np.save(args.output_dir / "da3_poses_w2c.npy", da3_poses)

    # Compare DA3 vs ARKit: convert ARKit c2w (OpenCV convention) to w2c.
    arkit_poses_w2c_opencv = np.array([np.linalg.inv(p)[:3, :] for p in arkit_poses_c2w_opencv])

    logger.info("\n=== Comparing DA3 vs ARKit (Ground Truth) ===")

    # Log pose statistics before comparison.
    logger.info("\nPose Statistics (before alignment):")
    da3_centers = da3_poses[:, :3, 3]
    arkit_centers = arkit_poses_w2c_opencv[:, :3, 3]
    logger.info(f"  DA3 translation range: [{da3_centers.min(axis=0)}, {da3_centers.max(axis=0)}]")
    da3_norms = np.linalg.norm(da3_centers, axis=1)
    logger.info(
        f"  DA3 translation magnitude: mean={da3_norms.mean():.3f}, " f"std={da3_norms.std():.3f}"
    )
    logger.info(
        f"  ARKit translation range: [{arkit_centers.min(axis=0)}, {arkit_centers.max(axis=0)}]"
    )
    arkit_norms = np.linalg.norm(arkit_centers, axis=1)
    logger.info(
        f"  ARKit translation magnitude: mean={arkit_norms.mean():.3f}, "
        f"std={arkit_norms.std():.3f}"
    )

    da3_vs_arkit = compute_pose_error(da3_poses, arkit_poses_w2c_opencv, verbose=True)

    logger.info("\nDA3 vs ARKit Error Summary:")
    logger.info(f"  Mean rotation error: {da3_vs_arkit['mean_rotation_error_deg']:.2f}°")
    logger.info(f"  Median rotation error: {np.median(da3_vs_arkit['rotation_errors_deg']):.2f}°")
    logger.info(f"  Max rotation error: {da3_vs_arkit['max_rotation_error_deg']:.2f}°")
    logger.info(f"  Min rotation error: {np.min(da3_vs_arkit['rotation_errors_deg']):.2f}°")
    logger.info(f"  Std rotation error: {np.std(da3_vs_arkit['rotation_errors_deg']):.2f}°")
    logger.info(f"  Mean translation error: {da3_vs_arkit['mean_translation_error']:.3f} m")
    logger.info(f"  Max translation error: {np.max(da3_vs_arkit['translation_errors']):.3f} m")

    # Alignment diagnostics.
    if "alignment_info" in da3_vs_arkit:
        align_info = da3_vs_arkit["alignment_info"]
        logger.info("\nAlignment Diagnostics:")
        logger.info(
            f"  Scale factor: {align_info['scale_factor']:.6f} (should be ~1.0 if scales match)"
        )
        logger.info(f"  Rotation matrix det: {align_info['rotation_det']:.6f} (should be ~1.0)")
        logger.info(f"  Center1 (DA3) mean: {align_info['center1_mean']}")
        logger.info(f"  Center2 (ARKit) mean: {align_info['center2_mean']}")

    # Per-frame breakdown.
    logger.info("\nPer-Frame Error Breakdown:")
    logger.info(f"  {'Frame':<8} {'Rot Error (°)':<15} {'Trans Error (m)':<15} {'Category':<20}")
    logger.info(f"  {'-' * 8} {'-' * 15} {'-' * 15} {'-' * 20}")
    for i, (rot_err, trans_err) in enumerate(
        zip(da3_vs_arkit["rotation_errors_deg"], da3_vs_arkit["translation_errors"])
    ):
        if rot_err < accept_threshold:
            category = "Accepted"
        elif rot_err < reject_threshold:
            category = "Rejected-Learnable"
        else:
            category = "Rejected-Outlier"
        logger.info(f"  {i:<8} {rot_err:<15.2f} {trans_err:<15.3f} {category:<20}")

    # Error distribution.
    rot_errors = da3_vs_arkit["rotation_errors_deg"]
    logger.info("\nError Distribution (Rotation):")
    logger.info(f"  Q1 (25th percentile): {np.percentile(rot_errors, 25):.2f}°")
    logger.info(f"  Q2 (50th percentile/median): {np.percentile(rot_errors, 50):.2f}°")
    logger.info(f"  Q3 (75th percentile): {np.percentile(rot_errors, 75):.2f}°")
    logger.info(f"  90th percentile: {np.percentile(rot_errors, 90):.2f}°")
    logger.info(f"  95th percentile: {np.percentile(rot_errors, 95):.2f}°")
    logger.info(f"  99th percentile: {np.percentile(rot_errors, 99):.2f}°")

    # Run BA validation with the same thresholds used for categorization.
    logger.info("\n=== Running BA Validation ===")
    validator = BAValidator(
        accept_threshold=accept_threshold,
        reject_threshold=reject_threshold,
        work_dir=args.output_dir / "ba_work",
    )

    ba_result = validator.validate(
        images=images,
        poses_model=da3_poses,
        intrinsics=da3_intrinsics,
    )

    if ba_result["status"] != "ba_failed" and ba_result.get("poses_ba") is not None:
        ba_poses = ba_result["poses_ba"]  # (N, 3, 4) w2c

        # Compare BA vs ARKit.
        logger.info("\n=== Comparing BA vs ARKit (Ground Truth) ===")
        ba_vs_arkit = compute_pose_error(ba_poses, arkit_poses_w2c_opencv)

        logger.info("BA vs ARKit:")
        logger.info(f"  Mean rotation error: {ba_vs_arkit['mean_rotation_error_deg']:.2f}°")
        logger.info(f"  Max rotation error: {ba_vs_arkit['max_rotation_error_deg']:.2f}°")
        logger.info(f"  Mean translation error: {ba_vs_arkit['mean_translation_error']:.2f}")

        # Compare DA3 vs BA.
        logger.info("\n=== Comparing DA3 vs BA ===")
        da3_vs_ba = compute_pose_error(da3_poses, ba_poses)

        logger.info("DA3 vs BA:")
        logger.info(f"  Mean rotation error: {da3_vs_ba['mean_rotation_error_deg']:.2f}°")
        logger.info(f"  Max rotation error: {da3_vs_ba['max_rotation_error_deg']:.2f}°")

        # Save BA poses for visualization (ba_poses is known valid here —
        # this branch is already guarded by the status check above).
        np.save(args.output_dir / "ba_poses_w2c.npy", ba_poses)

        # Categorize frames by DA3-vs-ARKit rotation error.
        accepted_frames = []
        rejected_learnable_frames = []
        rejected_outlier_frames = []

        for i, err in enumerate(rot_errors):
            if err < accept_threshold:
                accepted_frames.append(i)
            elif err < reject_threshold:
                rejected_learnable_frames.append(i)
            else:
                rejected_outlier_frames.append(i)

        frame_categorization = {
            "accepted": {
                "count": len(accepted_frames),
                "percentage": (
                    100.0 * len(accepted_frames) / len(rot_errors) if rot_errors else 0.0
                ),
                "frame_indices": accepted_frames,
            },
            "rejected_learnable": {
                "count": len(rejected_learnable_frames),
                "percentage": (
                    100.0 * len(rejected_learnable_frames) / len(rot_errors) if rot_errors else 0.0
                ),
                "frame_indices": rejected_learnable_frames,
            },
            "rejected_outlier": {
                "count": len(rejected_outlier_frames),
                "percentage": (
                    100.0 * len(rejected_outlier_frames) / len(rot_errors) if rot_errors else 0.0
                ),
                "frame_indices": rejected_outlier_frames,
            },
            "total_frames": len(rot_errors),
        }

        logger.info("\n=== Frame Categorization (DA3 vs ARKit) ===")
        accepted_info = frame_categorization["accepted"]
        learnable_info = frame_categorization["rejected_learnable"]
        outlier_info = frame_categorization["rejected_outlier"]
        total_frames = frame_categorization["total_frames"]
        logger.info(
            f"  Accepted (< {accept_threshold}°): "
            f"{accepted_info['count']}/{total_frames} "
            f"({accepted_info['percentage']:.1f}%)"
        )
        logger.info(
            f"  Rejected-Learnable ({accept_threshold}-{reject_threshold}°): "
            f"{learnable_info['count']}/{total_frames} "
            f"({learnable_info['percentage']:.1f}%)"
        )
        logger.info(
            f"  Rejected-Outlier (> {reject_threshold}°): "
            f"{outlier_info['count']}/{total_frames} "
            f"({outlier_info['percentage']:.1f}%)"
        )

        # Detailed diagnostics for the JSON report.
        diagnostics = {
            "pose_statistics": {
                "da3": {
                    "translation_range": {
                        "min": da3_centers.min(axis=0).tolist(),
                        "max": da3_centers.max(axis=0).tolist(),
                        "mean_magnitude": float(np.linalg.norm(da3_centers, axis=1).mean()),
                        "std_magnitude": float(np.linalg.norm(da3_centers, axis=1).std()),
                    }
                },
                "arkit": {
                    "translation_range": {
                        "min": arkit_centers.min(axis=0).tolist(),
                        "max": arkit_centers.max(axis=0).tolist(),
                        "mean_magnitude": float(np.linalg.norm(arkit_centers, axis=1).mean()),
                        "std_magnitude": float(np.linalg.norm(arkit_centers, axis=1).std()),
                    }
                },
            },
            "error_distribution": {
                "rotation_errors_deg": {
                    "q1": float(np.percentile(rot_errors, 25)),
                    "median": float(np.percentile(rot_errors, 50)),
                    "q3": float(np.percentile(rot_errors, 75)),
                    "p90": float(np.percentile(rot_errors, 90)),
                    "p95": float(np.percentile(rot_errors, 95)),
                    "p99": float(np.percentile(rot_errors, 99)),
                },
                "translation_errors": {
                    "mean": float(np.mean(da3_vs_arkit["translation_errors"])),
                    "median": float(np.median(da3_vs_arkit["translation_errors"])),
                    "max": float(np.max(da3_vs_arkit["translation_errors"])),
                    "std": float(np.std(da3_vs_arkit["translation_errors"])),
                },
            },
            "per_frame_errors": [
                {
                    "frame_idx": i,
                    "rotation_error_deg": float(rot_err),
                    "translation_error_m": float(trans_err),
                    "category": (
                        "accepted"
                        if rot_err < accept_threshold
                        else (
                            "rejected_learnable"
                            if rot_err < reject_threshold
                            else "rejected_outlier"
                        )
                    ),
                }
                for i, (rot_err, trans_err) in enumerate(
                    zip(da3_vs_arkit["rotation_errors_deg"], da3_vs_arkit["translation_errors"])
                )
            ],
            "da3_vs_arkit": {"alignment_info": da3_vs_arkit.get("alignment_info", {})},
        }

        # Save results.
        results = {
            "da3_vs_arkit": da3_vs_arkit,
            "ba_vs_arkit": ba_vs_arkit,
            "da3_vs_ba": da3_vs_ba,
            "ba_result": {
                "status": ba_result["status"],
                "error": ba_result.get("error"),
                "reprojection_error": ba_result.get("reprojection_error"),
            },
            "frame_categorization": frame_categorization,
            "diagnostics": diagnostics,
            "num_frames": len(images),
        }

        results_path = args.output_dir / "validation_results.json"
        with open(results_path, "w") as f:
            json.dump(results, f, indent=2, default=str)

        logger.info(f"\n✓ Results saved to {results_path}")
        logger.info("✓ Poses saved for visualization")
    else:
        logger.warning("BA validation failed, skipping BA comparisons")

    logger.info("\n=== Complete ===")


if __name__ == "__main__":
    main()