File size: 4,441 Bytes
060fbda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""
Basic test to verify test setup works without external dependencies.
"""

import sys
from pathlib import Path

# Put the directory two levels above this file at the front of sys.path
# (presumably the project root that contains the ``sorawm`` package).
_PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(_PROJECT_ROOT))


def test_schemas_import():
    """Check that ``CleanerType`` imports from ``sorawm.schemas`` with the expected values.

    Returns:
        bool: ``True`` when the import and both equality checks succeed;
        ``False`` when anything inside raises (the error is printed and
        swallowed rather than re-raised).
    """
    succeeded = True
    try:
        from sorawm.schemas import CleanerType

        # Each member is compared against a plain string literal.
        assert CleanerType.LAMA == "lama"
        assert CleanerType.E2FGVI_HQ == "e2fgvi_hq"
        print("✓ Schemas import test passed")
    except Exception as exc:
        print(f"✗ Schemas import test failed: {exc}")
        succeeded = False
    return succeeded


def test_watermark_cleaner_factory():
    """Check ``WaterMarkCleaner`` against patched cleaner classes.

    ``LamaCleaner`` and ``E2FGVIHDCleaner`` are patched inside
    ``sorawm.watermark_cleaner`` so that calling them yields marker strings.
    The test then verifies that ``WaterMarkCleaner`` returns the matching
    marker for each ``CleanerType`` and that each patched class was called
    exactly once.

    Returns:
        bool: ``True`` when every check passes; ``False`` when anything
        inside raises (the error is printed and swallowed).
    """
    succeeded = True
    try:
        from unittest.mock import patch

        lama_patcher = patch("sorawm.watermark_cleaner.LamaCleaner")
        e2fgvi_patcher = patch("sorawm.watermark_cleaner.E2FGVIHDCleaner")

        with lama_patcher as lama_cls, e2fgvi_patcher as e2fgvi_cls:
            lama_cls.return_value = "lama_instance"
            e2fgvi_cls.return_value = "e2fgvi_instance"

            # Imported only after the patches are active.
            from sorawm.schemas import CleanerType
            from sorawm.watermark_cleaner import WaterMarkCleaner

            # LAMA branch
            lama_result = WaterMarkCleaner(CleanerType.LAMA)
            assert lama_result == "lama_instance"
            lama_cls.assert_called_once()

            # E2FGVI branch
            e2fgvi_result = WaterMarkCleaner(CleanerType.E2FGVI_HQ)
            assert e2fgvi_result == "e2fgvi_instance"
            e2fgvi_cls.assert_called_once()

        print("✓ WaterMarkCleaner factory test passed")
    except Exception as exc:
        print(f"✗ WaterMarkCleaner factory test failed: {exc}")
        succeeded = False
    return succeeded


def test_imputation_utils():
    """Test the imputation helpers using a mocked change-point detector.

    First, the ``rpt`` name inside ``sorawm.utils.imputation_utils`` is
    replaced with a mock whose ``KernelCPD(...).fit(...).predict(...)`` chain
    yields ``[5, 10]``, and ``find_2d_data_bkps`` is expected to return
    ``[5]``. Second, ``get_interval_average_bbox`` is called for real on
    three bboxes (one ``None``) over a single interval and must return one
    non-``None`` entry.

    Returns:
        bool: ``True`` when every check passes; ``False`` when anything
        inside raises (the error is printed and swallowed).
    """
    try:
        # MagicMock must be imported alongside patch: it is instantiated
        # below, and without the import the resulting NameError was swallowed
        # by the broad except, making this test fail unconditionally.
        from unittest.mock import MagicMock, patch

        # NOTE(review): patch() imports the target module before replacing
        # ``rpt``, so if that module imports ruptures at import time the
        # package presumably still has to be installed - confirm.
        with patch("sorawm.utils.imputation_utils.rpt") as mock_rpt:
            # Mock the CPD result
            mock_algo = MagicMock()
            mock_algo.predict.return_value = [5, 10]
            mock_rpt.KernelCPD.return_value.fit.return_value = mock_algo

            from sorawm.utils.imputation_utils import find_2d_data_bkps

            data = [(1, 2), (3, 4), (5, 6)]
            result = find_2d_data_bkps(data)

            assert result == [5]  # Should return bkps[:-1]

        # Test bbox averaging function (runs outside the patch context)
        from sorawm.utils.imputation_utils import get_interval_average_bbox

        bboxes = [(10, 20, 30, 40), (11, 21, 31, 41), None]
        bkps = [0, 3]
        result = get_interval_average_bbox(bboxes, bkps)

        assert len(result) == 1
        assert result[0] is not None  # Should average the two valid bboxes

        print("✓ Imputation utils test passed")
    except Exception as e:
        print(f"✗ Imputation utils test failed: {e}")
        return False
    return True


def test_video_utils():
    """Check ``merge_frames_with_overlap`` on a first chunk with no overlap.

    Two constant 10x10x3 ``uint8`` frames (values 100 and 200) are passed as
    the first chunk with ``overlap_size=0``; the result must contain exactly
    those two frames, in order.

    Returns:
        bool: ``True`` when every check passes; ``False`` when anything
        inside raises (the error is printed and swallowed).
    """
    succeeded = True
    try:
        import numpy as np

        from sorawm.utils.video_utils import merge_frames_with_overlap

        # Two solid frames that are easy to tell apart.
        frame_shape = (10, 10, 3)
        dark_frame = np.full(frame_shape, 100, dtype=np.uint8)
        bright_frame = np.full(frame_shape, 200, dtype=np.uint8)

        merged = merge_frames_with_overlap(
            result_frames=None,
            chunk_frames=[dark_frame, bright_frame],
            start_idx=0,
            overlap_size=0,
            is_first_chunk=True,
        )

        assert len(merged) == 2
        assert np.array_equal(merged[0], dark_frame)
        assert np.array_equal(merged[1], bright_frame)

        print("✓ Video utils test passed")
    except Exception as exc:
        print(f"✗ Video utils test failed: {exc}")
        succeeded = False
    return succeeded


if __name__ == "__main__":
    # Script entry point: run every test function in order, print a summary,
    # and exit with status 0 only when all of them returned a truthy value.
    separator = "=" * 50

    print("Running basic tests...")
    print(separator)

    checks = (
        test_schemas_import,
        test_watermark_cleaner_factory,
        test_imputation_utils,
        test_video_utils,
    )

    outcomes = []
    for check in checks:
        outcomes.append(bool(check()))
        print()  # blank line after each test's own output

    passed = sum(outcomes)
    total = len(outcomes)

    print(separator)
    print(f"Results: {passed}/{total} tests passed")

    all_passed = passed == total
    print("🎉 All basic tests passed!" if all_passed else "❌ Some tests failed")
    sys.exit(0 if all_passed else 1)