File size: 13,632 Bytes
7a87926
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
#!/usr/bin/env python3
"""
Run BA validation on full video to identify rejected frames.
"""

import os
import sys
from pathlib import Path

# Set environment variable FIRST before any imports
# NOTE(review): KMP_DUPLICATE_LIB_OK=TRUE is presumably a workaround for the
# "multiple OpenMP runtimes loaded" abort triggered by numpy/torch/cv2 wheels —
# confirm. It has to be set before those libraries are imported further down.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Add SuperGluePretrainedNetwork to Python path if it exists
# NOTE(review): hard-coded checkout location; presumably a dependency of the BA
# validator's feature extraction/matching — confirm. Skipped silently if absent,
# and only prepended once.
superglue_path = Path("/tmp/SuperGluePretrainedNetwork")
if superglue_path.exists():
    if str(superglue_path) not in sys.path:
        sys.path.insert(0, str(superglue_path))

# Set up logging IMMEDIATELY
# Configured before the heavy third-party imports below so that their progress
# messages are emitted with this format.
import logging  # noqa: E402

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    force=True,  # Force reconfiguration: drop any root handlers installed earlier
)
logger = logging.getLogger(__name__)

logger.info("=" * 60)
logger.info("Starting BA Validation Script")
logger.info("=" * 60)
logger.info("Step 0: Importing dependencies...")

# Add project root to path
# project_root is the parent of this script's directory; it is prepended to
# sys.path so the `ylff` imports below resolve (presumably the package lives
# there — confirm). Also used by main() for default input/output paths.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
logger.info(f"Project root: {project_root}")

# Heavy dependencies are imported one by one with a log line around each,
# presumably so a hanging or failing import can be pinpointed from the log.
try:
    logger.info("  - Importing numpy...")
    import numpy as np

    logger.info("  ✓ numpy imported")

    logger.info("  - Importing cv2...")
    import cv2

    logger.info("  ✓ cv2 imported")

    logger.info("  - Importing torch...")
    import torch

    logger.info("  ✓ torch imported")

    logger.info("  - Importing tqdm...")
    from tqdm import tqdm

    logger.info("  ✓ tqdm imported")

    logger.info("  - Importing json...")
    import json

    logger.info("  ✓ json imported")

    logger.info("  - Importing typing...")
    from typing import Optional

    logger.info("  ✓ typing imported")

    logger.info("  - Importing ylff modules...")
    from ylff.utils.model_loader import load_da3_model

    # Log the module actually imported (previously reported as "ylff.models").
    logger.info("  ✓ ylff.utils.model_loader imported")

    from ylff.services.ba_validator import BAValidator

    logger.info("  ✓ ylff.services.ba_validator imported")

    logger.info("✓ All imports complete")
except Exception as e:
    # Broad catch on purpose: native libraries (torch, cv2) can fail with
    # OSError/RuntimeError, not only ImportError. logger.exception records the
    # traceback through the logging handlers configured above instead of a
    # separate raw stderr dump.
    logger.exception(f"✗ Import failed: {e}")
    sys.exit(1)


def extract_all_frames(
    video_path: Path, max_frames: Optional[int] = None, frame_interval: int = 1
) -> list:
    """Extract RGB frames from a video, sampling every ``frame_interval``-th frame.

    Frames are read in up to two passes:

    1. Every ``frame_interval``-th frame (indices 0, N, 2N, ...) is collected until
       ``max_frames`` frames have been gathered or the video ends.
    2. Only when ``max_frames`` is given and pass 1 produced fewer frames than
       that, the shortfall is filled with the earliest frames pass 1 skipped.

    Args:
        video_path: Path to the video file, opened with ``cv2.VideoCapture``.
        max_frames: Maximum number of frames to return. ``None`` or ``0`` means
            no limit, i.e. every ``frame_interval``-th frame of the whole video.
        frame_interval: Sampling stride in frames; must be >= 1.

    Returns:
        List of RGB images (converted from OpenCV's BGR order), sorted by their
        frame index in the video.

    Raises:
        ValueError: If ``frame_interval`` < 1, ``max_frames`` is negative, or the
            video cannot be opened.
    """
    if frame_interval < 1:
        raise ValueError(f"frame_interval must be >= 1, got {frame_interval}")
    if max_frames is not None and max_frames < 0:
        raise ValueError(f"max_frames must be >= 0, got {max_frames}")

    logger.info(f"Extracting frames from {video_path}")

    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise ValueError(f"Could not open video: {video_path}")

    try:
        # CAP_PROP_FRAME_COUNT comes from container metadata and may be 0 or
        # inaccurate, so it only drives logging and the progress bar; the read
        # loops below always run until cap.read() reports end of stream.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        logger.info(f"Video properties: {total_frames} frames, {fps:.2f} fps")

        frames_at_interval = max(0, (total_frames + frame_interval - 1) // frame_interval)
        # No cap means "every Nth frame of the video". Using total_frames as the
        # target here would make the fill pass re-add every skipped frame and
        # silently ignore frame_interval.
        limit = max_frames if max_frames else None

        target_desc = limit if limit is not None else frames_at_interval
        logger.info(f"Target: {target_desc} frames")
        logger.info(f"  - At interval {frame_interval}: can get up to {frames_at_interval} frames")

        sampled = []  # (frame_index, rgb_image) pairs
        progress_total = limit if limit is not None else (frames_at_interval or None)

        with tqdm(total=progress_total, desc="Extracting frames", unit="frame") as pbar:
            # First pass: extract every Nth frame.
            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
            frame_idx = 0
            while limit is None or len(sampled) < limit:
                ok, frame = cap.read()
                if not ok:
                    break
                if frame_idx % frame_interval == 0:
                    sampled.append((frame_idx, cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
                    pbar.update(1)
                frame_idx += 1

            # Second pass: only an explicit max_frames can ask for more frames
            # than the interval provides; fill from the start with skipped frames.
            if limit is not None and len(sampled) < limit:
                logger.info(f"  - Got {len(sampled)} frames at interval {frame_interval}")
                logger.info(f"  - Filling remaining {limit - len(sampled)} frames...")
                cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                taken = {idx for idx, _ in sampled}
                frame_idx = 0
                while len(sampled) < limit:
                    ok, frame = cap.read()
                    if not ok:
                        logger.warning(f"  - Video ended. Got {len(sampled)} frames total.")
                        break
                    if frame_idx not in taken:
                        sampled.append((frame_idx, cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
                        pbar.update(1)
                    frame_idx += 1

        # Fill frames were appended after the interval frames; restore video order.
        sampled.sort(key=lambda item: item[0])
        frames = [img for _, img in sampled]

        logger.info(f"  - Final: {len(frames)} frames extracted")
        if sampled:
            logger.info(
                f"  - Frame indices: {sampled[0][0]}...{sampled[-1][0]} "
                f"(every {frame_interval} + fills)"
            )
    finally:
        # Always free the decoder, even if reading/conversion raised.
        cap.release()

    logger.info(f"✓ Extracted {len(frames)} total frames")
    return frames


# Default BA rotation-error thresholds in degrees. They are shared by the
# validator and the per-frame categorization so the two cannot drift apart.
DEFAULT_ACCEPT_THRESHOLD_DEG = 2.0
DEFAULT_REJECT_THRESHOLD_DEG = 30.0


def _parse_args():
    """Parse and sanity-check the command-line arguments of this script."""
    import argparse

    logger.info("\n" + "=" * 60)
    logger.info("Parsing arguments...")

    parser = argparse.ArgumentParser(description="Run BA validation on video")
    parser.add_argument("--video", type=str, default=None, help="Path to video file")
    parser.add_argument(
        "--max-frames", type=int, default=None, help="Maximum number of frames to process"
    )
    parser.add_argument(
        "--frame-interval",
        type=int,
        default=1,
        help="Extract every Nth frame (e.g., 15 for every 15th frame)",
    )
    parser.add_argument("--output-dir", type=str, default=None, help="Output directory")
    parser.add_argument(
        "--accept-threshold",
        type=float,
        default=DEFAULT_ACCEPT_THRESHOLD_DEG,
        help="Rotation error (degrees) below which a frame is accepted",
    )
    parser.add_argument(
        "--reject-threshold",
        type=float,
        default=DEFAULT_REJECT_THRESHOLD_DEG,
        help="Rotation error (degrees) separating rejected-learnable from outlier frames",
    )
    args = parser.parse_args()

    if args.frame_interval < 1:
        parser.error("--frame-interval must be >= 1")
    if args.accept_threshold > args.reject_threshold:
        parser.error("--accept-threshold must not exceed --reject-threshold")
    return args


def _run_da3_inference(model, frames, heartbeat_interval_s: float = 30.0):
    """Run ``model.inference(frames)`` under ``torch.no_grad()`` and return its output.

    A daemon thread logs a "still processing" line every ``heartbeat_interval_s``
    seconds; it is stopped in ``finally`` so it never outlives the call.
    """
    import threading
    import time

    inference_done = threading.Event()

    def heartbeat():
        ticks = 0
        # Event.wait returns False on timeout, True once inference_done is set.
        while not inference_done.wait(heartbeat_interval_s):
            ticks += 1
            logger.info(f"  - Still processing... ({ticks * heartbeat_interval_s:.0f}s elapsed)")

    threading.Thread(target=heartbeat, daemon=True).start()

    start_time = time.perf_counter()
    try:
        with torch.no_grad():
            # All frames are handed to DA3 in a single call; any batching
            # happens inside model.inference (not visible here).
            logger.info("  - Starting DA3 forward pass...")
            da3_output = model.inference(frames)
    finally:
        inference_done.set()

    elapsed = time.perf_counter() - start_time
    logger.info(
        f"  - DA3 inference completed in {elapsed:.1f} seconds "
        f"({elapsed / max(len(frames), 1):.2f}s per frame)"
    )
    return da3_output


def _categorize_frames(rot_errors, accept_threshold, reject_threshold):
    """Split positions in ``rot_errors`` into (accepted, rejected_learnable, rejected_outlier).

    A frame is accepted if its error is < ``accept_threshold``, rejected-learnable
    if it is < ``reject_threshold``, and a rejected-outlier otherwise (NaN included).
    """
    accepted, learnable, outlier = [], [], []
    for i, err in enumerate(rot_errors):
        if err < accept_threshold:
            accepted.append(i)
        elif err < reject_threshold:
            learnable.append(i)
        else:
            outlier.append(i)
    return accepted, learnable, outlier


def _json_safe_error(error):
    """Return ``error`` as a JSON-serializable value: None, a float if numeric, else str."""
    if error is None:
        return None
    if isinstance(error, (int, float, np.integer, np.floating)):
        return float(error)
    return str(error)


def main():
    """Estimate poses with DA3 on a video, validate them with BA, and report the results.

    Steps:
        1. Extract frames from the video (``extract_all_frames``).
        2. Run DA3 inference to obtain per-frame extrinsics (and intrinsics if the
           output provides them).
        3. Validate the poses with ``BAValidator``.
        4. Log rotation-error statistics, categorize frames by the accept/reject
           thresholds, and write ``validation_results.json`` to the output dir.

    Exits with status 1 if no frames could be extracted.
    """
    args = _parse_args()
    accept_threshold = args.accept_threshold
    reject_threshold = args.reject_threshold

    video_path = (
        Path(args.video)
        if args.video
        else project_root / "assets" / "examples" / "robot_unitree.mp4"
    )
    output_dir = (
        Path(args.output_dir)
        if args.output_dir
        else project_root / "data" / "ba_validation_results"
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info("=" * 60)
    logger.info("BA Validation on Full Video")
    logger.info("=" * 60)
    logger.info(f"Video: {video_path}")
    logger.info(f"Output: {output_dir}")
    if args.max_frames:
        logger.info(f"Max frames: {args.max_frames}")
    logger.info("=" * 60)

    # 1. Extract frames
    logger.info("\n[Step 1] Extracting frames from video...")
    frames = extract_all_frames(
        video_path, max_frames=args.max_frames, frame_interval=args.frame_interval
    )
    if not frames:
        # Nothing to run DA3/BA on; stop before the per-frame timing divides by zero.
        logger.error(f"✗ No frames could be extracted from {video_path}")
        sys.exit(1)
    logger.info(f"✓ Extracted {len(frames)} frames (every {args.frame_interval} frame(s))")

    # 2. Run DA3 inference
    logger.info("\n[Step 2] Running DA3 inference...")
    logger.info("Loading DA3 model (this may take a moment)...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    model = load_da3_model("depth-anything/DA3-LARGE", device=device)
    logger.info("✓ Model loaded")

    logger.info(f"Running inference on {len(frames)} frames (this may take a while)...")
    logger.info("  - Processing in batches...")
    logger.info("  - This step can take several minutes on CPU...")
    da3_output = _run_da3_inference(model, frames)

    poses_da3 = da3_output.extrinsics  # presumably (N, 3, 4) — confirm against DA3 output
    intrinsics = getattr(da3_output, "intrinsics", None)

    logger.info("✓ DA3 inference complete")
    logger.info(f"  - Poses shape: {poses_da3.shape}")
    logger.info(f"  - Intrinsics shape: {intrinsics.shape if intrinsics is not None else 'None'}")

    # 3. Run BA validation
    logger.info("\n[Step 3] Running BA validation...")
    logger.info("Initializing BA validator...")
    validator = BAValidator(
        accept_threshold=accept_threshold,
        reject_threshold=reject_threshold,
        work_dir=output_dir / "ba_work",
    )
    logger.info("✓ BA validator initialized")

    logger.info("Validating poses with BA (this may take a while)...")
    logger.info("  - Step 3.1: Saving images...")
    logger.info("  - Step 3.2: Extracting features (SuperPoint)...")
    logger.info("  - Step 3.3: Matching features (LightGlue)...")
    logger.info("  - Step 3.4: Running Bundle Adjustment...")

    result = validator.validate(
        images=frames,
        poses_model=poses_da3,
        intrinsics=intrinsics,
    )

    logger.info("✓ BA validation complete")

    # 4. Analyze results
    logger.info("\n[Step 4] Analyzing results...")

    status = result["status"]
    error = result.get("error")
    error_metrics = result.get("error_metrics") or {}

    logger.info(f"\n{'=' * 60}")
    logger.info("RESULTS")
    logger.info(f"{'=' * 60}")
    logger.info(f"Overall Status: {status}")

    if isinstance(error, (int, float, np.integer, np.floating)):
        logger.info(f"Max Rotation Error: {error:.2f}°")
    elif error is not None:
        # The validator may report a non-numeric error (e.g. a failure message).
        logger.info(f"Max Rotation Error: {error}")

    # Normalize to a list of floats: avoids the ambiguous truth value of a NumPy
    # array and makes the values JSON-serializable.
    raw_errors = error_metrics.get("rotation_errors_deg")
    rot_errors = [] if raw_errors is None else [float(e) for e in raw_errors]

    results_dict = {
        "status": status,
        "error": _json_safe_error(error),
    }

    if rot_errors:
        errs = np.asarray(rot_errors)
        logger.info("\nRotation Error Statistics:")
        logger.info(f"  - Mean: {np.mean(errs):.2f}°")
        logger.info(f"  - Median: {np.median(errs):.2f}°")
        logger.info(f"  - Max: {np.max(errs):.2f}°")
        logger.info(f"  - Min: {np.min(errs):.2f}°")
        logger.info(f"  - Std: {np.std(errs):.2f}°")

        # Indices are positions in the extracted `frames` list, which equal video
        # frame numbers only when frame_interval == 1 and no fill pass ran.
        accepted, rejected_learnable, rejected_outlier = _categorize_frames(
            rot_errors, accept_threshold, reject_threshold
        )

        n_errors = len(rot_errors)
        accepted_pct = 100 * len(accepted) / n_errors
        learnable_pct = 100 * len(rejected_learnable) / n_errors
        outlier_pct = 100 * len(rejected_outlier) / n_errors
        logger.info("\nFrame Categorization:")
        logger.info(
            f"  - Accepted (< {accept_threshold:g}°): {len(accepted)} frames ({accepted_pct:.1f}%)"
        )
        logger.info(
            f"  - Rejected-Learnable ({accept_threshold:g}-{reject_threshold:g}°): "
            f"{len(rejected_learnable)} frames ({learnable_pct:.1f}%)"
        )
        logger.info(
            f"  - Rejected-Outlier (>= {reject_threshold:g}°): {len(rejected_outlier)} frames "
            f"({outlier_pct:.1f}%)"
        )

        results_dict["error_metrics"] = {
            "rotation_errors_deg": rot_errors,
            "mean_rotation_error_deg": float(np.mean(errs)),
            "median_rotation_error_deg": float(np.median(errs)),
            "max_rotation_error_deg": float(np.max(errs)),
            "min_rotation_error_deg": float(np.min(errs)),
            "std_rotation_error_deg": float(np.std(errs)),
        }
        results_dict["frame_categories"] = {
            "accepted": accepted,
            "rejected_learnable": rejected_learnable,
            "rejected_outlier": rejected_outlier,
        }
    else:
        logger.warning("No per-frame rotation errors reported by the BA validator.")
        accepted = rejected_learnable = rejected_outlier = []

    results_dict["thresholds_deg"] = {"accept": accept_threshold, "reject": reject_threshold}
    results_dict["num_frames"] = len(frames)

    # Always persist the outcome, so a failed BA run still leaves a record.
    output_json = output_dir / "validation_results.json"
    with open(output_json, "w", encoding="utf-8") as f:
        json.dump(results_dict, f, indent=2)
    logger.info(f"\n✓ Results saved to {output_json}")

    # Show some examples
    if rejected_learnable:
        logger.info("\nExample Rejected-Learnable frames (first 10):")
        for idx in rejected_learnable[:10]:
            logger.info(f"  Frame {idx}: {rot_errors[idx]:.2f}°")

    if rejected_outlier:
        logger.info("\nExample Rejected-Outlier frames (first 10):")
        for idx in rejected_outlier[:10]:
            logger.info(f"  Frame {idx}: {rot_errors[idx]:.2f}°")

    logger.info(f"\n{'=' * 60}")
    logger.info("✓ BA validation complete!")
    logger.info(f"{'=' * 60}")


if __name__ == "__main__":
    # Run the pipeline only when executed as a script; importing this module still
    # runs the top-level setup and dependency imports, but not main().
    main()