File size: 7,346 Bytes
7a87926
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
#!/usr/bin/env python3
"""
Smoke test for YLFF pipeline using robot_unitree.mp4 video.
"""

import logging
import sys
from pathlib import Path

# Check dependencies
try:
    import cv2
    import numpy as np
    import torch
except ImportError as e:
    print(f"ERROR: Missing dependency: {e}")
    print("\nPlease install dependencies:")
    print("  pip install -e .")
    print("  # Or install manually:")
    print("  pip install torch torchvision numpy opencv-python")
    sys.exit(1)

# Put the project root at the front of sys.path. The root is taken to be the
# parent of the directory that contains this script; presumably this is what
# lets the function-level `ylff` imports below resolve when the script is run
# directly without the package being installed.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

# Configure root logging once for the whole script: INFO level, with a
# timestamp and level name prefixed to every message.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


def extract_frames_from_video(video_path: Path, max_frames: int = 10) -> list:
    """Read up to ``max_frames`` leading frames from a video, converted to RGB.

    Args:
        video_path: Location of the video file to decode.
        max_frames: Upper bound on the number of frames returned. Reading
            stops earlier as soon as the capture reports that no further
            frame could be read.

    Returns:
        The frames in the order they were read, each converted from OpenCV's
        BGR channel order to RGB. The list may be shorter than ``max_frames``
        and may be empty.

    Raises:
        ValueError: If OpenCV cannot open ``video_path``.
    """
    logger.info(f"Extracting frames from {video_path}")

    cap = cv2.VideoCapture(str(video_path))
    frames = []
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")

        while len(frames) < max_frames:
            ret, frame = cap.read()
            if not ret:
                # End of stream or an undecodable frame: keep what we have.
                break

            # Convert BGR to RGB
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Always give the capture handle back, including when opening,
        # reading or colour conversion raises.
        cap.release()

    logger.info(f"Extracted {len(frames)} frames from video")
    return frames


def test_da3_inference(frames: list):
    """Load the DA3-SMALL model and run a single inference pass on ``frames``.

    The model is placed on CUDA when ``torch.cuda.is_available()`` reports a
    device, otherwise on the CPU. The shapes of the resulting depth,
    extrinsics and (if present) intrinsics are logged.

    Returns:
        The object returned by ``model.inference(frames)``, or ``None`` when
        the model loader cannot be imported or any other exception is raised
        while loading or running the model.
    """
    import traceback

    logger.info("Testing DA3 inference...")

    try:
        from ylff.utils.model_loader import load_da3_model

        # Model loading
        logger.info("Loading DA3 model...")
        if torch.cuda.is_available():
            target_device = "cuda"
        else:
            target_device = "cpu"
        da3 = load_da3_model("depth-anything/DA3-SMALL", device=target_device)
        logger.info("βœ“ Model loaded")

        # Forward pass with gradient tracking disabled
        logger.info(f"Running inference on {len(frames)} frames...")
        with torch.no_grad():
            result = da3.inference(frames)

        logger.info("βœ“ Inference complete")
        logger.info(f"  - Depth shape: {result.depth.shape}")
        logger.info(f"  - Poses shape: {result.extrinsics.shape}")
        # Intrinsics are reported only when the output carries that attribute.
        if hasattr(result, "intrinsics"):
            intrinsics_shape = result.intrinsics.shape
        else:
            intrinsics_shape = "N/A"
        logger.info(f"  - Intrinsics shape: {intrinsics_shape}")

        return result

    except ImportError as err:
        logger.error(f"Failed to import DA3: {err}")
        logger.error("Make sure DA3 is installed or available from HuggingFace")
        return None
    except Exception as err:
        logger.error(f"DA3 inference failed: {err}")
        traceback.print_exc()
        return None


def test_ba_validator_structure(frames: list, poses: np.ndarray):
    """Check that a BAValidator can be built and can compare two pose sets.

    No bundle adjustment is executed. A copy of ``poses`` whose first pose has
    0.1 added to its translation column is compared against the original via
    ``BAValidator._compute_pose_error``, and the reported rotation errors are
    logged.

    Args:
        frames: Accepted for signature parity with the other tests; not used.
        poses: Pose array indexed as ``[pose, row, column]``.

    Returns:
        ``True`` if the validator was created and the comparison ran,
        ``False`` on an import failure or any other exception.
    """
    import traceback

    logger.info("Testing BA validator structure...")

    try:
        from ylff.services.ba_validator import BAValidator

        # Build the validator with explicit thresholds
        thresholds = {"accept_threshold": 2.0, "reject_threshold": 30.0}
        ba_validator = BAValidator(**thresholds)
        logger.info("βœ“ BA validator created")

        # Exercise only the pose comparison helper, not the full BA run
        logger.info("Testing pose error computation...")

        # Target poses: identical except for a small shift of the first translation
        shifted_poses = poses.copy()
        shifted_poses[0, :3, 3] += 0.1

        metrics = ba_validator._compute_pose_error(poses, shifted_poses)
        logger.info("βœ“ Pose error computation works")
        for label, key in (
            ("Max", "max_rotation_error_deg"),
            ("Mean", "mean_rotation_error_deg"),
        ):
            logger.info(f"  - {label} rotation error: {metrics[key]:.2f}Β°")

        return True

    except ImportError as err:
        logger.warning(f"BA validator dependencies not available: {err}")
        logger.warning("This is expected if pycolmap/hloc are not installed")
        return False
    except Exception as err:
        logger.error(f"BA validator test failed: {err}")
        traceback.print_exc()
        return False


def test_data_pipeline_structure(frames: list):
    """Check that a BADataPipeline can be assembled from its components.

    A DA3-SMALL model (on CUDA when available, otherwise CPU) and a
    default-constructed BAValidator are wired into a BADataPipeline, and the
    pipeline's ``stats`` attribute is logged. Nothing is run through the
    pipeline.

    Args:
        frames: Accepted for signature parity with the other tests; not used.

    Returns:
        ``True`` if the pipeline was constructed, ``False`` if anything
        raised (the failure is logged as a warning and treated as a skip).
    """
    logger.info("Testing data pipeline structure...")

    try:
        from ylff.services.data_pipeline import BADataPipeline
        from ylff.services.ba_validator import BAValidator
        from ylff.utils.model_loader import load_da3_model

        # Assemble the pipeline; the model is loaded before the validator is built
        run_device = "cuda" if torch.cuda.is_available() else "cpu"
        data_pipeline = BADataPipeline(
            load_da3_model("depth-anything/DA3-SMALL", device=run_device),
            BAValidator(),
        )

        logger.info("βœ“ Data pipeline created")
        logger.info(f"  - Stats: {data_pipeline.stats}")

        return True

    except Exception as err:
        logger.warning(f"Data pipeline test skipped: {err}")
        return False


def _random_rigid_poses(count: int, rotation_scale: float = 1.0, translation_scale: float = 1.0):
    """Return ``count`` random rigid poses as a ``(count, 3, 4)`` tensor.

    The left 3x3 block of every pose is a proper rotation matrix, obtained as
    the matrix exponential of a random skew-symmetric matrix; the last column
    is a random translation.

    Args:
        count: Number of poses to generate.
        rotation_scale: Multiplier on the random axis-angle vectors; small
            values give rotations close to the identity.
        translation_scale: Multiplier on the random translations.
    """
    wx, wy, wz = (torch.randn(count, 3) * rotation_scale).unbind(dim=1)
    zeros = torch.zeros(count)
    # Rows of the skew-symmetric (cross-product) matrix of each axis-angle vector.
    skew = torch.stack(
        [
            torch.stack([zeros, -wz, wy], dim=1),
            torch.stack([wz, zeros, -wx], dim=1),
            torch.stack([-wy, wx, zeros], dim=1),
        ],
        dim=1,
    )
    rotations = torch.matrix_exp(skew)
    translations = torch.randn(count, 3, 1) * translation_scale
    return torch.cat([rotations, translations], dim=2)


def test_loss_functions():
    """Smoke-test ``geodesic_rotation_loss`` and ``pose_loss`` on dummy poses.

    Two sets of five ``(3, 4)`` poses are generated: a random set, and the same
    set perturbed by a small relative rotation and translation. Both sets
    contain valid rotation matrices, so the rotation loss is evaluated on
    inputs inside its domain rather than on arbitrary random matrices.

    Returns:
        ``True`` if both losses evaluate to finite scalars, ``False`` if a
        loss is NaN/infinite or any exception is raised.
    """
    logger.info("Testing loss functions...")

    try:
        import math

        from ylff.utils.losses import geodesic_rotation_loss, pose_loss

        # Create dummy poses: a base set and a slightly perturbed copy
        poses1 = _random_rigid_poses(5)
        delta = _random_rigid_poses(5, rotation_scale=0.1, translation_scale=0.1)
        poses2 = torch.cat(
            [
                # Composing with a near-identity rotation keeps the block a rotation.
                poses1[:, :3, :3] @ delta[:, :3, :3],
                poses1[:, :3, 3:] + delta[:, :3, 3:],
            ],
            dim=2,
        )

        # Test rotation loss
        rot_value = geodesic_rotation_loss(poses1[:, :3, :3], poses2[:, :3, :3]).item()
        if not math.isfinite(rot_value):
            logger.error(f"Loss function test failed: rotation loss is {rot_value}")
            return False
        logger.info(f"βœ“ Rotation loss: {rot_value:.4f}")

        # Test pose loss
        pose_value = pose_loss(poses1, poses2).item()
        if not math.isfinite(pose_value):
            logger.error(f"Loss function test failed: pose loss is {pose_value}")
            return False
        logger.info(f"βœ“ Pose loss: {pose_value:.4f}")

        return True

    except Exception as e:
        logger.error(f"Loss function test failed: {e}")
        import traceback

        traceback.print_exc()
        return False


def main():
    """Run the smoke tests in order and return a process exit code.

    Tests 1-3 (frame extraction, DA3 inference, loss functions) are mandatory:
    the first failure aborts the run. Tests 4-5 (BA validator and data
    pipeline structure) are optional; a failure there is only logged.

    Returns:
        ``0`` when all mandatory tests passed, ``1`` when the example video is
        missing, no frames could be extracted, or a mandatory test failed.
    """
    logger.info("=" * 60)
    logger.info("YLFF Smoke Test")
    logger.info("=" * 60)

    video_path = project_root / "assets" / "examples" / "robot_unitree.mp4"

    if not video_path.exists():
        logger.error(f"Video not found: {video_path}")
        return 1

    # Test 1: Extract frames
    logger.info("\n[Test 1] Extracting frames from video...")
    try:
        frames = extract_frames_from_video(video_path, max_frames=5)
    except Exception as e:
        logger.error(f"βœ— Frame extraction failed: {e}")
        return 1
    if not frames:
        # A video that opens but yields no frames would otherwise be reported
        # as a success and handed to the model as an empty batch.
        logger.error(f"βœ— Frame extraction failed: no frames decoded from {video_path}")
        return 1
    logger.info(f"βœ“ Extracted {len(frames)} frames")

    # Test 2: DA3 inference
    logger.info("\n[Test 2] Testing DA3 inference...")
    output = test_da3_inference(frames)
    if output is None:
        logger.error("βœ— DA3 inference test failed")
        return 1

    # Test 3: Loss functions
    logger.info("\n[Test 3] Testing loss functions...")
    if not test_loss_functions():
        logger.error("βœ— Loss function test failed")
        return 1

    # Test 4: BA validator structure (optional: does not affect the exit code)
    logger.info("\n[Test 4] Testing BA validator structure...")
    if not test_ba_validator_structure(frames, output.extrinsics):
        logger.warning("BA validator structure test did not pass (non-fatal)")

    # Test 5: Data pipeline structure (optional: does not affect the exit code)
    logger.info("\n[Test 5] Testing data pipeline structure...")
    if not test_data_pipeline_structure(frames):
        logger.warning("Data pipeline structure test did not pass (non-fatal)")

    logger.info("\n" + "=" * 60)
    logger.info("βœ“ Smoke test complete!")
    logger.info("=" * 60)

    return 0


# When executed as a script, run the smoke tests and use main()'s return
# value (0 or 1) as the process exit status.
if __name__ == "__main__":
    sys.exit(main())