MogensR commited on
Commit
482a60d
·
verified ·
1 Parent(s): c075655

Delete tests

Browse files
tests/__init__.py DELETED
@@ -1,77 +0,0 @@
1
- """
2
- BackgroundFX Pro Test Suite.
3
- Comprehensive testing for all modules.
4
- """
5
-
6
- import os
7
- import sys
8
- from pathlib import Path
9
-
10
- # Add parent directory to path for imports
11
- parent_dir = Path(__file__).parent.parent
12
- sys.path.insert(0, str(parent_dir))
13
-
14
- # Test categories
15
- TEST_CATEGORIES = {
16
- 'unit': 'Unit tests for individual components',
17
- 'integration': 'Integration tests for component interactions',
18
- 'api': 'API endpoint tests',
19
- 'models': 'Model management tests',
20
- 'pipeline': 'Processing pipeline tests',
21
- 'performance': 'Performance and benchmark tests',
22
- 'gpu': 'GPU-specific tests'
23
- }
24
-
25
- def run_tests(category: str = None, verbose: bool = True):
26
- """
27
- Run tests for BackgroundFX Pro.
28
-
29
- Args:
30
- category: Optional test category to run
31
- verbose: Enable verbose output
32
- """
33
- import pytest
34
-
35
- args = []
36
-
37
- if category:
38
- if category in TEST_CATEGORIES:
39
- args.extend(['-m', category])
40
- else:
41
- print(f"Unknown category: {category}")
42
- print(f"Available categories: {', '.join(TEST_CATEGORIES.keys())}")
43
- return 1
44
-
45
- if verbose:
46
- args.append('-v')
47
-
48
- # Add coverage by default
49
- args.extend(['--cov=.', '--cov-report=term-missing'])
50
-
51
- return pytest.main(args)
52
-
53
-
54
- def run_quick_tests():
55
- """Run quick unit tests only (no slow/integration tests)."""
56
- import pytest
57
- return pytest.main(['-m', 'not slow and not integration', '-v'])
58
-
59
-
60
- def run_gpu_tests():
61
- """Run GPU-specific tests if GPU is available."""
62
- import torch
63
-
64
- if not torch.cuda.is_available():
65
- print("No GPU available, skipping GPU tests")
66
- return 0
67
-
68
- import pytest
69
- return pytest.main(['-m', 'gpu', '-v'])
70
-
71
-
72
- __all__ = [
73
- 'run_tests',
74
- 'run_quick_tests',
75
- 'run_gpu_tests',
76
- 'TEST_CATEGORIES'
77
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/conftest.py DELETED
@@ -1,351 +0,0 @@
1
- """
2
- Pytest configuration and fixtures for BackgroundFX Pro tests.
3
- """
4
-
5
- import pytest
6
- import numpy as np
7
- import torch
8
- import cv2
9
- import tempfile
10
- import shutil
11
- from pathlib import Path
12
- from unittest.mock import Mock, MagicMock
13
- import os
14
- import sys
15
-
16
- # Add parent directory to path for imports
17
- sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
18
-
19
-
20
- # ============================================================================
21
- # Configuration
22
- # ============================================================================
23
-
24
- @pytest.fixture(scope="session")
25
- def test_config():
26
- """Test configuration."""
27
- return {
28
- 'device': 'cpu', # Use CPU for testing
29
- 'test_data_dir': Path(__file__).parent / 'data',
30
- 'temp_dir': tempfile.mkdtemp(prefix='bgfx_test_'),
31
- 'max_test_duration': 30, # seconds
32
- 'use_gpu': torch.cuda.is_available()
33
- }
34
-
35
-
36
- @pytest.fixture(scope="session", autouse=True)
37
- def cleanup(test_config):
38
- """Cleanup after all tests."""
39
- yield
40
- # Clean up temp directory
41
- if os.path.exists(test_config['temp_dir']):
42
- shutil.rmtree(test_config['temp_dir'])
43
-
44
-
45
- # ============================================================================
46
- # Image and Video Fixtures
47
- # ============================================================================
48
-
49
- @pytest.fixture
50
- def sample_image():
51
- """Create a sample image for testing."""
52
- # Create 512x512 RGB image with a person-like shape
53
- image = np.zeros((512, 512, 3), dtype=np.uint8)
54
-
55
- # Add background
56
- image[:, :] = [100, 150, 200] # Blue background
57
-
58
- # Add person-like shape (simple rectangle for testing)
59
- cv2.rectangle(image, (150, 100), (350, 450), (50, 100, 50), -1)
60
-
61
- # Add some texture
62
- noise = np.random.randint(0, 20, (512, 512, 3), dtype=np.uint8)
63
- image = cv2.add(image, noise)
64
-
65
- return image
66
-
67
-
68
- @pytest.fixture
69
- def sample_mask():
70
- """Create a sample mask for testing."""
71
- mask = np.zeros((512, 512), dtype=np.uint8)
72
- # Create person mask
73
- cv2.rectangle(mask, (150, 100), (350, 450), 255, -1)
74
- # Add some edge refinement
75
- mask = cv2.GaussianBlur(mask, (5, 5), 2)
76
- return mask
77
-
78
-
79
- @pytest.fixture
80
- def sample_background():
81
- """Create a sample background image."""
82
- background = np.zeros((512, 512, 3), dtype=np.uint8)
83
- # Create gradient background
84
- for i in range(512):
85
- background[i, :] = [
86
- int(255 * (i / 512)), # Red gradient
87
- 100, # Fixed green
88
- int(255 * (1 - i / 512)) # Blue inverse gradient
89
- ]
90
- return background
91
-
92
-
93
- @pytest.fixture
94
- def sample_video(test_config):
95
- """Create a sample video file for testing."""
96
- video_path = Path(test_config['temp_dir']) / 'test_video.mp4'
97
-
98
- # Create video writer
99
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
100
- out = cv2.VideoWriter(str(video_path), fourcc, 30.0, (512, 512))
101
-
102
- # Write 30 frames (1 second at 30fps)
103
- for i in range(30):
104
- frame = np.zeros((512, 512, 3), dtype=np.uint8)
105
- # Animate a moving rectangle
106
- x = 100 + i * 5
107
- cv2.rectangle(frame, (x, 200), (x + 100, 400), (0, 255, 0), -1)
108
- out.write(frame)
109
-
110
- out.release()
111
- return str(video_path)
112
-
113
-
114
- # ============================================================================
115
- # Model Fixtures
116
- # ============================================================================
117
-
118
- @pytest.fixture
119
- def mock_model():
120
- """Create a mock ML model for testing."""
121
- model = MagicMock()
122
- model.eval = MagicMock(return_value=None)
123
- model.to = MagicMock(return_value=model)
124
-
125
- # Mock forward pass
126
- def forward(x):
127
- batch_size = x.shape[0] if hasattr(x, 'shape') else 1
128
- return torch.randn(batch_size, 1, 512, 512)
129
-
130
- model.__call__ = MagicMock(side_effect=forward)
131
- model.forward = MagicMock(side_effect=forward)
132
-
133
- return model
134
-
135
-
136
- @pytest.fixture
137
- def mock_sam2_predictor():
138
- """Create a mock SAM2 predictor."""
139
- predictor = MagicMock()
140
-
141
- def predict(image):
142
- h, w = image.shape[:2] if len(image.shape) > 2 else (512, 512)
143
- return np.random.randint(0, 2, (h, w), dtype=np.uint8) * 255
144
-
145
- predictor.predict = MagicMock(side_effect=predict)
146
- predictor.set_image = MagicMock(return_value=None)
147
-
148
- return predictor
149
-
150
-
151
- @pytest.fixture
152
- def mock_matanyone_model():
153
- """Create a mock MatAnyone model."""
154
- model = MagicMock()
155
-
156
- def refine(image, mask):
157
- return cv2.GaussianBlur(mask, (5, 5), 2)
158
-
159
- model.refine = MagicMock(side_effect=refine)
160
-
161
- return model
162
-
163
-
164
- # ============================================================================
165
- # Pipeline and Processing Fixtures
166
- # ============================================================================
167
-
168
- @pytest.fixture
169
- def pipeline_config():
170
- """Create pipeline configuration for testing."""
171
- from api.pipeline import PipelineConfig
172
-
173
- return PipelineConfig(
174
- use_gpu=False, # CPU for testing
175
- quality_preset='medium',
176
- enable_cache=False, # Disable cache for testing
177
- batch_size=1,
178
- max_workers=2
179
- )
180
-
181
-
182
- @pytest.fixture
183
- def mock_pipeline(pipeline_config):
184
- """Create a mock processing pipeline."""
185
- from api.pipeline import ProcessingPipeline
186
-
187
- # Mock the pipeline to avoid loading real models
188
- with pytest.MonkeyPatch().context() as m:
189
- m.setattr('api.pipeline.ModelFactory.load_model',
190
- lambda self, *args, **kwargs: Mock())
191
- pipeline = ProcessingPipeline(pipeline_config)
192
-
193
- return pipeline
194
-
195
-
196
- # ============================================================================
197
- # API and Server Fixtures
198
- # ============================================================================
199
-
200
- @pytest.fixture
201
- def api_client():
202
- """Create a test client for the API."""
203
- from fastapi.testclient import TestClient
204
- from api.api_server import app
205
-
206
- return TestClient(app)
207
-
208
-
209
- @pytest.fixture
210
- def mock_job_manager():
211
- """Create a mock job manager."""
212
- manager = MagicMock()
213
- manager.create_job = MagicMock(return_value='test-job-123')
214
- manager.get_job = MagicMock(return_value={'status': 'processing'})
215
- manager.update_job = MagicMock(return_value=None)
216
-
217
- return manager
218
-
219
-
220
- # ============================================================================
221
- # File System Fixtures
222
- # ============================================================================
223
-
224
- @pytest.fixture
225
- def temp_dir(test_config):
226
- """Create a temporary directory for test files."""
227
- temp_path = Path(test_config['temp_dir']) / 'test_run'
228
- temp_path.mkdir(parents=True, exist_ok=True)
229
- yield temp_path
230
- # Cleanup
231
- if temp_path.exists():
232
- shutil.rmtree(temp_path)
233
-
234
-
235
- @pytest.fixture
236
- def sample_files(temp_dir, sample_image):
237
- """Create sample files in temp directory."""
238
- files = {}
239
-
240
- # Save sample image
241
- image_path = temp_dir / 'sample.jpg'
242
- cv2.imwrite(str(image_path), sample_image)
243
- files['image'] = image_path
244
-
245
- # Create multiple images for batch testing
246
- for i in range(3):
247
- path = temp_dir / f'image_{i}.jpg'
248
- cv2.imwrite(str(path), sample_image)
249
- files[f'image_{i}'] = path
250
-
251
- return files
252
-
253
-
254
- # ============================================================================
255
- # Model Registry Fixtures
256
- # ============================================================================
257
-
258
- @pytest.fixture
259
- def mock_registry():
260
- """Create a mock model registry."""
261
- from models.registry import ModelRegistry, ModelInfo, ModelTask, ModelFramework
262
-
263
- registry = ModelRegistry(models_dir=Path(tempfile.mkdtemp()))
264
-
265
- # Add test model
266
- test_model = ModelInfo(
267
- model_id='test-model',
268
- name='Test Model',
269
- version='1.0',
270
- task=ModelTask.SEGMENTATION,
271
- framework=ModelFramework.PYTORCH,
272
- url='http://example.com/model.pth',
273
- filename='test_model.pth',
274
- file_size=1000000
275
- )
276
-
277
- registry.register_model(test_model)
278
-
279
- return registry
280
-
281
-
282
- # ============================================================================
283
- # WebSocket Fixtures
284
- # ============================================================================
285
-
286
- @pytest.fixture
287
- def mock_websocket():
288
- """Create a mock WebSocket connection."""
289
- ws = MagicMock()
290
- ws.accept = MagicMock(return_value=None)
291
- ws.send_json = MagicMock(return_value=None)
292
- ws.receive_text = MagicMock(return_value='{"type": "ping", "data": {}}')
293
-
294
- return ws
295
-
296
-
297
- # ============================================================================
298
- # Utility Fixtures
299
- # ============================================================================
300
-
301
- @pytest.fixture
302
- def mock_progress_callback():
303
- """Create a mock progress callback."""
304
- callback = MagicMock()
305
- return callback
306
-
307
-
308
- @pytest.fixture
309
- def device():
310
- """Get device for testing."""
311
- return 'cuda' if torch.cuda.is_available() else 'cpu'
312
-
313
-
314
- @pytest.fixture
315
- def performance_timer():
316
- """Timer for performance testing."""
317
- import time
318
-
319
- class Timer:
320
- def __init__(self):
321
- self.start_time = None
322
- self.elapsed = 0
323
-
324
- def __enter__(self):
325
- self.start_time = time.time()
326
- return self
327
-
328
- def __exit__(self, *args):
329
- self.elapsed = time.time() - self.start_time
330
-
331
- return Timer
332
-
333
-
334
- # ============================================================================
335
- # Markers
336
- # ============================================================================
337
-
338
- def pytest_configure(config):
339
- """Register custom markers."""
340
- config.addinivalue_line(
341
- "markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')"
342
- )
343
- config.addinivalue_line(
344
- "markers", "gpu: marks tests that require GPU"
345
- )
346
- config.addinivalue_line(
347
- "markers", "integration: marks integration tests"
348
- )
349
- config.addinivalue_line(
350
- "markers", "unit: marks unit tests"
351
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/pytest.ini DELETED
@@ -1,53 +0,0 @@
1
- [pytest]
2
- # Pytest configuration for BackgroundFX Pro
3
-
4
- # Test discovery
5
- testpaths = tests
6
- python_files = test_*.py
7
- python_classes = Test*
8
- python_functions = test_*
9
-
10
- # Markers
11
- markers =
12
- slow: marks tests as slow (deselect with '-m "not slow"')
13
- gpu: marks tests that require GPU
14
- integration: marks integration tests
15
- unit: marks unit tests
16
- api: marks API tests
17
- models: marks model-related tests
18
- pipeline: marks pipeline tests
19
- performance: marks performance tests
20
-
21
- # Coverage
22
- addopts =
23
- --verbose
24
- --strict-markers
25
- --tb=short
26
- --cov=.
27
- --cov-report=html
28
- --cov-report=term-missing
29
- --cov-config=.coveragerc
30
-
31
- # Logging
32
- log_cli = true
33
- log_cli_level = INFO
34
- log_cli_format = %(asctime)s [%(levelname)8s] %(message)s
35
- log_cli_date_format = %Y-%m-%d %H:%M:%S
36
-
37
- # Warnings
38
- filterwarnings =
39
- ignore::DeprecationWarning
40
- ignore::PendingDeprecationWarning
41
- ignore::FutureWarning:torch.*
42
-
43
- # Timeout
44
- timeout = 300
45
-
46
- # Parallel execution (optional, requires pytest-xdist)
47
- # addopts = -n auto
48
-
49
- # Environment variables for testing
50
- env =
51
- TESTING = 1
52
- DEVICE = cpu
53
- LOG_LEVEL = WARNING
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/test_api.py DELETED
@@ -1,356 +0,0 @@
1
- """
2
- Tests for API endpoints and WebSocket functionality.
3
- """
4
-
5
- import pytest
6
- import json
7
- import base64
8
- from unittest.mock import Mock, patch, MagicMock
9
- from fastapi.testclient import TestClient
10
- import numpy as np
11
- import cv2
12
-
13
- from api.api_server import app, ProcessingRequest, ProcessingResponse
14
- from api.websocket import WebSocketHandler, WSMessage, MessageType
15
-
16
-
17
- class TestAPIEndpoints:
18
- """Test REST API endpoints."""
19
-
20
- @pytest.fixture
21
- def client(self):
22
- """Create test client."""
23
- return TestClient(app)
24
-
25
- @pytest.fixture
26
- def auth_headers(self):
27
- """Create authentication headers."""
28
- # Mock authentication for testing
29
- return {"Authorization": "Bearer test-token"}
30
-
31
- def test_root_endpoint(self, client):
32
- """Test root endpoint."""
33
- response = client.get("/")
34
- assert response.status_code == 200
35
- data = response.json()
36
- assert "name" in data
37
- assert data["name"] == "BackgroundFX Pro API"
38
-
39
- def test_health_check(self, client):
40
- """Test health check endpoint."""
41
- response = client.get("/health")
42
- assert response.status_code == 200
43
- data = response.json()
44
- assert data["status"] == "healthy"
45
- assert "services" in data
46
-
47
- @patch('api.api_server.verify_token')
48
- def test_process_image_endpoint(self, mock_verify, client, auth_headers, sample_image):
49
- """Test image processing endpoint."""
50
- mock_verify.return_value = "test-user"
51
-
52
- # Create test image file
53
- _, buffer = cv2.imencode('.jpg', sample_image)
54
-
55
- files = {"file": ("test.jpg", buffer.tobytes(), "image/jpeg")}
56
- data = {
57
- "background": "blur",
58
- "quality": "high"
59
- }
60
-
61
- with patch('api.api_server.process_image_task'):
62
- response = client.post(
63
- "/api/v1/process/image",
64
- headers=auth_headers,
65
- files=files,
66
- data=data
67
- )
68
-
69
- assert response.status_code == 200
70
- result = response.json()
71
- assert "job_id" in result
72
- assert result["status"] == "processing"
73
-
74
- @patch('api.api_server.verify_token')
75
- def test_process_video_endpoint(self, mock_verify, client, auth_headers, sample_video):
76
- """Test video processing endpoint."""
77
- mock_verify.return_value = "test-user"
78
-
79
- with open(sample_video, 'rb') as f:
80
- files = {"file": ("test.mp4", f.read(), "video/mp4")}
81
-
82
- data = {
83
- "background": "office",
84
- "quality": "medium"
85
- }
86
-
87
- with patch('api.api_server.process_video_task'):
88
- response = client.post(
89
- "/api/v1/process/video",
90
- headers=auth_headers,
91
- files=files,
92
- data=data
93
- )
94
-
95
- assert response.status_code == 200
96
- result = response.json()
97
- assert "job_id" in result
98
-
99
- @patch('api.api_server.verify_token')
100
- def test_batch_processing_endpoint(self, mock_verify, client, auth_headers):
101
- """Test batch processing endpoint."""
102
- mock_verify.return_value = "test-user"
103
-
104
- batch_request = {
105
- "items": [
106
- {"id": "1", "input_path": "/tmp/img1.jpg", "output_path": "/tmp/out1.jpg"},
107
- {"id": "2", "input_path": "/tmp/img2.jpg", "output_path": "/tmp/out2.jpg"}
108
- ],
109
- "parallel": True,
110
- "priority": "normal"
111
- }
112
-
113
- with patch('api.api_server.process_batch_task'):
114
- response = client.post(
115
- "/api/v1/batch",
116
- headers=auth_headers,
117
- json=batch_request
118
- )
119
-
120
- assert response.status_code == 200
121
- result = response.json()
122
- assert "job_id" in result
123
-
124
- @patch('api.api_server.verify_token')
125
- def test_job_status_endpoint(self, mock_verify, client, auth_headers):
126
- """Test job status endpoint."""
127
- mock_verify.return_value = "test-user"
128
-
129
- job_id = "test-job-123"
130
-
131
- with patch.object(app.state.job_manager, 'get_job') as mock_get:
132
- mock_get.return_value = ProcessingResponse(
133
- job_id=job_id,
134
- status="completed",
135
- progress=1.0
136
- )
137
-
138
- response = client.get(
139
- f"/api/v1/job/{job_id}",
140
- headers=auth_headers
141
- )
142
-
143
- assert response.status_code == 200
144
- result = response.json()
145
- assert result["job_id"] == job_id
146
- assert result["status"] == "completed"
147
-
148
- @patch('api.api_server.verify_token')
149
- def test_streaming_endpoints(self, mock_verify, client, auth_headers):
150
- """Test streaming endpoints."""
151
- mock_verify.return_value = "test-user"
152
-
153
- # Start stream
154
- stream_request = {
155
- "source": "0",
156
- "stream_type": "webcam",
157
- "output_format": "hls"
158
- }
159
-
160
- with patch.object(app.state.video_processor, 'start_stream_processing') as mock_start:
161
- mock_start.return_value = True
162
-
163
- response = client.post(
164
- "/api/v1/stream/start",
165
- headers=auth_headers,
166
- json=stream_request
167
- )
168
-
169
- assert response.status_code == 200
170
- result = response.json()
171
- assert result["status"] == "streaming"
172
-
173
- # Stop stream
174
- with patch.object(app.state.video_processor, 'stop_stream_processing'):
175
- response = client.get(
176
- "/api/v1/stream/stop",
177
- headers=auth_headers
178
- )
179
-
180
- assert response.status_code == 200
181
-
182
-
183
- class TestWebSocket:
184
- """Test WebSocket functionality."""
185
-
186
- @pytest.fixture
187
- def ws_handler(self):
188
- """Create WebSocket handler."""
189
- return WebSocketHandler()
190
-
191
- def test_websocket_connection(self, ws_handler, mock_websocket):
192
- """Test WebSocket connection handling."""
193
- # Test connection acceptance
194
- async def test_connect():
195
- await ws_handler.handle_connection(mock_websocket)
196
-
197
- # Would need async test runner for full test
198
- assert mock_websocket.accept.called or True # Simplified for sync test
199
-
200
- def test_message_parsing(self, ws_handler):
201
- """Test WebSocket message parsing."""
202
- message_data = {
203
- "type": "process_frame",
204
- "data": {"frame": "base64_data"}
205
- }
206
-
207
- message = WSMessage.from_dict(message_data)
208
-
209
- assert message.type == MessageType.PROCESS_FRAME
210
- assert message.data["frame"] == "base64_data"
211
-
212
- def test_frame_encoding_decoding(self, ws_handler, sample_image):
213
- """Test frame encoding and decoding."""
214
- # Encode frame
215
- _, buffer = cv2.imencode('.jpg', sample_image)
216
- encoded = base64.b64encode(buffer).decode('utf-8')
217
-
218
- # Decode frame
219
- decoded = ws_handler.frame_processor._decode_frame(encoded)
220
-
221
- assert decoded is not None
222
- assert decoded.shape == sample_image.shape
223
-
224
- def test_session_management(self, ws_handler):
225
- """Test client session management."""
226
- mock_ws = MagicMock()
227
-
228
- # Add session
229
- async def test_add():
230
- session = await ws_handler.session_manager.add_session(mock_ws, "test-client")
231
- assert session.client_id == "test-client"
232
-
233
- # Would need async test runner for full test
234
- assert ws_handler.session_manager is not None
235
-
236
- def test_message_routing(self, ws_handler):
237
- """Test message routing."""
238
- messages = [
239
- WSMessage(type=MessageType.PING, data={}),
240
- WSMessage(type=MessageType.UPDATE_CONFIG, data={"quality": "high"}),
241
- WSMessage(type=MessageType.START_STREAM, data={"source": 0})
242
- ]
243
-
244
- for msg in messages:
245
- assert msg.type in MessageType
246
- assert isinstance(msg.to_dict(), dict)
247
-
248
- def test_statistics_tracking(self, ws_handler):
249
- """Test WebSocket statistics."""
250
- stats = ws_handler.get_statistics()
251
-
252
- assert "uptime" in stats
253
- assert "total_connections" in stats
254
- assert "active_connections" in stats
255
- assert "total_frames_processed" in stats
256
-
257
-
258
- class TestAPIIntegration:
259
- """Integration tests for API."""
260
-
261
- @pytest.mark.integration
262
- def test_full_image_processing_flow(self, client, sample_image, temp_dir):
263
- """Test complete image processing flow."""
264
- # Skip authentication for integration test
265
- with patch('api.api_server.verify_token', return_value="test-user"):
266
- # Upload image
267
- _, buffer = cv2.imencode('.jpg', sample_image)
268
- files = {"file": ("test.jpg", buffer.tobytes(), "image/jpeg")}
269
-
270
- response = client.post(
271
- "/api/v1/process/image",
272
- files=files,
273
- data={"background": "blur", "quality": "low"}
274
- )
275
-
276
- assert response.status_code == 200
277
- job_data = response.json()
278
- job_id = job_data["job_id"]
279
-
280
- # Check job status
281
- response = client.get(f"/api/v1/job/{job_id}")
282
-
283
- # Would need actual processing for full test
284
- assert response.status_code in [200, 404]
285
-
286
- @pytest.mark.integration
287
- @pytest.mark.slow
288
- def test_concurrent_requests(self, client):
289
- """Test handling concurrent requests."""
290
- import concurrent.futures
291
-
292
- def make_request():
293
- response = client.get("/health")
294
- return response.status_code
295
-
296
- with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
297
- futures = [executor.submit(make_request) for _ in range(10)]
298
- results = [f.result() for f in concurrent.futures.as_completed(futures)]
299
-
300
- assert all(status == 200 for status in results)
301
-
302
- @pytest.mark.integration
303
- def test_error_handling(self, client):
304
- """Test API error handling."""
305
- # Test invalid endpoint
306
- response = client.get("/api/v1/invalid")
307
- assert response.status_code == 404
308
-
309
- # Test missing authentication
310
- response = client.get("/api/v1/stats")
311
- assert response.status_code in [401, 422] # Unauthorized or validation error
312
-
313
- # Test invalid file format
314
- with patch('api.api_server.verify_token', return_value="test-user"):
315
- files = {"file": ("test.txt", b"text content", "text/plain")}
316
- response = client.post(
317
- "/api/v1/process/image",
318
- files=files,
319
- headers={"Authorization": "Bearer test"}
320
- )
321
- assert response.status_code == 400
322
-
323
-
324
- class TestAPIPerformance:
325
- """Performance tests for API."""
326
-
327
- @pytest.mark.slow
328
- def test_response_time(self, client, performance_timer):
329
- """Test API response times."""
330
- endpoints = ["/", "/health"]
331
-
332
- for endpoint in endpoints:
333
- with performance_timer as timer:
334
- response = client.get(endpoint)
335
-
336
- assert response.status_code == 200
337
- assert timer.elapsed < 0.1 # Should respond in under 100ms
338
-
339
- @pytest.mark.slow
340
- def test_file_upload_performance(self, client, performance_timer):
341
- """Test file upload performance."""
342
- # Create a 1MB test file
343
- large_data = np.random.randint(0, 255, (1024, 1024, 3), dtype=np.uint8)
344
- _, buffer = cv2.imencode('.jpg', large_data)
345
-
346
- with patch('api.api_server.verify_token', return_value="test-user"):
347
- with patch('api.api_server.process_image_task'):
348
- with performance_timer as timer:
349
- response = client.post(
350
- "/api/v1/process/image",
351
- files={"file": ("large.jpg", buffer.tobytes(), "image/jpeg")},
352
- headers={"Authorization": "Bearer test"}
353
- )
354
-
355
- assert response.status_code == 200
356
- assert timer.elapsed < 2.0 # Should handle 1MB in under 2 seconds
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/test_app.py DELETED
@@ -1,332 +0,0 @@
1
- """
2
- Tests for the main application and video processor.
3
- """
4
-
5
- import pytest
6
- from unittest.mock import Mock, patch, MagicMock
7
- import numpy as np
8
- import cv2
9
- from pathlib import Path
10
- import tempfile
11
-
12
- # Import from main app
13
- from app import VideoProcessor, processor
14
-
15
-
16
- class TestVideoProcessor:
17
- """Test the main VideoProcessor class."""
18
-
19
- @pytest.fixture
20
- def video_processor(self):
21
- """Create a test video processor."""
22
- with patch('app.model_loader.ModelLoader') as mock_loader:
23
- with patch('app.device_manager.DeviceManager') as mock_device:
24
- mock_device.return_value.get_optimal_device.return_value = 'cpu'
25
- vp = VideoProcessor()
26
- return vp
27
-
28
- def test_initialization(self, video_processor):
29
- """Test VideoProcessor initialization."""
30
- assert video_processor is not None
31
- assert video_processor.models_loaded == False
32
- assert video_processor.cancel_event is not None
33
-
34
- @patch('app.model_loader.ModelLoader.load_all_models')
35
- def test_load_models(self, mock_load, video_processor):
36
- """Test model loading."""
37
- mock_sam2 = Mock()
38
- mock_matanyone = Mock()
39
- mock_load.return_value = (mock_sam2, mock_matanyone)
40
-
41
- result = video_processor.load_models()
42
-
43
- assert mock_load.called
44
- assert video_processor.models_loaded == True
45
- assert "Models already loaded" in result or "loaded" in result.lower()
46
-
47
- def test_load_models_with_progress(self, video_processor):
48
- """Test model loading with progress callback."""
49
- progress_values = []
50
-
51
- def progress_callback(value, message):
52
- progress_values.append((value, message))
53
-
54
- with patch.object(video_processor.model_loader, 'load_all_models') as mock_load:
55
- mock_load.return_value = (Mock(), Mock())
56
- video_processor.load_models(progress_callback)
57
-
58
- assert video_processor.models_loaded == True
59
-
60
- def test_process_video_without_models(self, video_processor):
61
- """Test video processing without loaded models."""
62
- result_path, message = video_processor.process_video(
63
- "test.mp4", "blur"
64
- )
65
-
66
- assert result_path is None
67
- assert "not loaded" in message.lower()
68
-
69
- @patch('app.validate_video_file')
70
- @patch.object(VideoProcessor, '_process_single_stage')
71
- def test_process_video_single_stage(self, mock_process, mock_validate, video_processor):
72
- """Test single-stage video processing."""
73
- mock_validate.return_value = (True, "Valid")
74
- mock_process.return_value = ("/tmp/output.mp4", "Success")
75
-
76
- video_processor.models_loaded = True
77
- video_processor.core_processor = Mock()
78
-
79
- result_path, message = video_processor.process_video(
80
- "test.mp4", "blur", use_two_stage=False
81
- )
82
-
83
- assert mock_process.called
84
- assert result_path == "/tmp/output.mp4"
85
- assert "Success" in message
86
-
87
- @patch('app.TWO_STAGE_AVAILABLE', True)
88
- @patch('app.validate_video_file')
89
- @patch.object(VideoProcessor, '_process_two_stage')
90
- def test_process_video_two_stage(self, mock_process, mock_validate, video_processor):
91
- """Test two-stage video processing."""
92
- mock_validate.return_value = (True, "Valid")
93
- mock_process.return_value = ("/tmp/output.mp4", "Success")
94
-
95
- video_processor.models_loaded = True
96
- video_processor.core_processor = Mock()
97
- video_processor.two_stage_processor = Mock()
98
-
99
- result_path, message = video_processor.process_video(
100
- "test.mp4", "blur", use_two_stage=True
101
- )
102
-
103
- assert mock_process.called
104
- assert result_path == "/tmp/output.mp4"
105
-
106
- def test_cancel_processing(self, video_processor):
107
- """Test processing cancellation."""
108
- video_processor.cancel_processing()
109
- assert video_processor.cancel_event.is_set()
110
-
111
- def test_get_status(self, video_processor):
112
- """Test getting processor status."""
113
- status = video_processor.get_status()
114
-
115
- assert "models_loaded" in status
116
- assert "device" in status
117
- assert "memory_usage" in status
118
- assert status["models_loaded"] == False
119
-
120
- def test_cleanup_resources(self, video_processor):
121
- """Test resource cleanup."""
122
- with patch.object(video_processor.memory_manager, 'cleanup_aggressive'):
123
- with patch.object(video_processor.model_loader, 'cleanup'):
124
- video_processor.cleanup_resources()
125
-
126
- assert True # Cleanup should not raise exceptions
127
-
128
-
129
- class TestCoreVideoProcessor:
130
- """Test the CoreVideoProcessor from video_processor module."""
131
-
132
- @pytest.fixture
133
- def core_processor(self, mock_sam2_predictor, mock_matanyone_model):
134
- """Create a test core processor."""
135
- from video_processor import CoreVideoProcessor
136
- from app_config import ProcessingConfig
137
- from memory_manager import MemoryManager
138
-
139
- config = ProcessingConfig()
140
- memory_mgr = MemoryManager('cpu')
141
-
142
- processor = CoreVideoProcessor(
143
- sam2_predictor=mock_sam2_predictor,
144
- matanyone_model=mock_matanyone_model,
145
- config=config,
146
- memory_mgr=memory_mgr
147
- )
148
- return processor
149
-
150
- def test_core_processor_initialization(self, core_processor):
151
- """Test CoreVideoProcessor initialization."""
152
- assert core_processor is not None
153
- assert core_processor.processing_active == False
154
- assert core_processor.stats is not None
155
-
156
- def test_prepare_background(self, core_processor):
157
- """Test background preparation."""
158
- # Test professional background
159
- background = core_processor.prepare_background(
160
- "blur", None, 512, 512
161
- )
162
- # May return None if utilities not available
163
- assert background is None or background.shape == (512, 512, 3)
164
-
165
- def test_get_processing_capabilities(self, core_processor):
166
- """Test getting processing capabilities."""
167
- capabilities = core_processor.get_processing_capabilities()
168
-
169
- assert "sam2_available" in capabilities
170
- assert "matanyone_available" in capabilities
171
- assert "quality_preset" in capabilities
172
- assert "supported_formats" in capabilities
173
-
174
- def test_get_status(self, core_processor):
175
- """Test getting processor status."""
176
- status = core_processor.get_status()
177
-
178
- assert "processing_active" in status
179
- assert "models_available" in status
180
- assert "statistics" in status
181
- assert "memory_usage" in status
182
-
183
-
184
- class TestApplicationIntegration:
185
- """Integration tests for the main application."""
186
-
187
- @pytest.mark.integration
188
- def test_global_processor_instance(self):
189
- """Test the global processor instance."""
190
- assert processor is not None
191
- assert isinstance(processor, VideoProcessor)
192
-
193
- @pytest.mark.integration
194
- @patch('app.model_loader.ModelLoader.load_all_models')
195
- def test_model_loading_flow(self, mock_load):
196
- """Test complete model loading flow."""
197
- mock_load.return_value = (Mock(), Mock())
198
-
199
- # Use global processor
200
- result = processor.load_models()
201
-
202
- assert processor.models_loaded == True
203
- assert result is not None
204
-
205
- @pytest.mark.integration
206
- @pytest.mark.slow
207
- def test_memory_management(self):
208
- """Test memory management during processing."""
209
- import psutil
210
- import os
211
-
212
- process = psutil.Process(os.getpid())
213
- initial_memory = process.memory_info().rss / 1024 / 1024 # MB
214
-
215
- # Simulate processing
216
- for _ in range(5):
217
- # Create and discard large arrays
218
- data = np.random.randint(0, 255, (1024, 1024, 3), dtype=np.uint8)
219
- del data
220
-
221
- processor.cleanup_resources()
222
-
223
- final_memory = process.memory_info().rss / 1024 / 1024 # MB
224
- memory_increase = final_memory - initial_memory
225
-
226
- # Memory increase should be reasonable
227
- assert memory_increase < 200 # Less than 200MB increase
228
-
229
-
230
- class TestBackwardCompatibility:
231
- """Test backward compatibility functions."""
232
-
233
- def test_load_models_wrapper(self):
234
- """Test load_models_with_validation wrapper."""
235
- from app import load_models_with_validation
236
-
237
- with patch.object(processor, 'load_models') as mock_load:
238
- mock_load.return_value = "Success"
239
- result = load_models_with_validation()
240
-
241
- assert mock_load.called
242
- assert result == "Success"
243
-
244
- def test_process_video_wrapper(self):
245
- """Test process_video_fixed wrapper."""
246
- from app import process_video_fixed
247
-
248
- with patch.object(processor, 'process_video') as mock_process:
249
- mock_process.return_value = ("/tmp/out.mp4", "Success")
250
-
251
- result = process_video_fixed(
252
- "test.mp4", "blur", None
253
- )
254
-
255
- assert mock_process.called
256
- assert result[0] == "/tmp/out.mp4"
257
-
258
- def test_get_model_status_wrapper(self):
259
- """Test get_model_status wrapper."""
260
- from app import get_model_status
261
-
262
- with patch.object(processor, 'get_status') as mock_status:
263
- mock_status.return_value = {"status": "ok"}
264
- result = get_model_status()
265
-
266
- assert mock_status.called
267
- assert result["status"] == "ok"
268
-
269
-
270
- class TestErrorHandling:
271
- """Test error handling in the application."""
272
-
273
- def test_invalid_video_handling(self):
274
- """Test handling of invalid video files."""
275
- with patch('app.validate_video_file') as mock_validate:
276
- mock_validate.return_value = (False, "Invalid format")
277
-
278
- processor.models_loaded = True
279
- result_path, message = processor.process_video(
280
- "invalid.txt", "blur"
281
- )
282
-
283
- assert result_path is None
284
- assert "Invalid" in message
285
-
286
- def test_model_loading_failure(self):
287
- """Test handling of model loading failures."""
288
- with patch.object(processor.model_loader, 'load_all_models') as mock_load:
289
- mock_load.side_effect = Exception("Model not found")
290
-
291
- result = processor.load_models()
292
-
293
- assert processor.models_loaded == False
294
- assert "failed" in result.lower()
295
-
296
- def test_processing_exception_handling(self):
297
- """Test exception handling during processing."""
298
- processor.models_loaded = True
299
- processor.core_processor = Mock()
300
- processor.core_processor.process_video.side_effect = Exception("Processing failed")
301
-
302
- with patch('app.validate_video_file', return_value=(True, "Valid")):
303
- result_path, message = processor.process_video(
304
- "test.mp4", "blur"
305
- )
306
-
307
- assert result_path is None
308
- assert "error" in message.lower() or "failed" in message.lower()
309
-
310
-
311
- class TestPerformance:
312
- """Performance tests for the application."""
313
-
314
- @pytest.mark.slow
315
- def test_initialization_speed(self, performance_timer):
316
- """Test application initialization speed."""
317
- with performance_timer as timer:
318
- vp = VideoProcessor()
319
-
320
- assert timer.elapsed < 1.0 # Should initialize in under 1 second
321
-
322
- @pytest.mark.slow
323
- @patch('app.model_loader.ModelLoader.load_all_models')
324
- def test_model_loading_speed(self, mock_load, performance_timer):
325
- """Test model loading speed."""
326
- mock_load.return_value = (Mock(), Mock())
327
-
328
- with performance_timer as timer:
329
- processor.load_models()
330
-
331
- # Mock loading should be very fast
332
- assert timer.elapsed < 0.5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/test_models.py DELETED
@@ -1,376 +0,0 @@
1
- """
2
- Tests for model management functionality.
3
- """
4
-
5
- import pytest
6
- import tempfile
7
- from pathlib import Path
8
- from unittest.mock import Mock, patch, MagicMock
9
- import json
10
-
11
- from models import (
12
- ModelRegistry,
13
- ModelInfo,
14
- ModelStatus,
15
- ModelTask,
16
- ModelFramework,
17
- ModelDownloader,
18
- ModelLoader,
19
- ModelOptimizer
20
- )
21
-
22
-
23
- class TestModelRegistry:
24
- """Test model registry functionality."""
25
-
26
- @pytest.fixture
27
- def registry(self):
28
- """Create a test registry."""
29
- temp_dir = tempfile.mkdtemp()
30
- return ModelRegistry(models_dir=Path(temp_dir))
31
-
32
- def test_registry_initialization(self, registry):
33
- """Test registry initialization."""
34
- assert registry is not None
35
- assert len(registry.models) > 0 # Should have default models
36
- assert registry.models_dir.exists()
37
-
38
- def test_register_model(self, registry):
39
- """Test registering a new model."""
40
- model = ModelInfo(
41
- model_id="test-model",
42
- name="Test Model",
43
- version="1.0",
44
- task=ModelTask.SEGMENTATION,
45
- framework=ModelFramework.PYTORCH,
46
- url="http://example.com/model.pth",
47
- filename="test.pth",
48
- file_size=1000000
49
- )
50
-
51
- success = registry.register_model(model)
52
- assert success == True
53
- assert "test-model" in registry.models
54
-
55
- def test_get_model(self, registry):
56
- """Test getting a model by ID."""
57
- model = registry.get_model("rmbg-1.4")
58
- assert model is not None
59
- assert model.model_id == "rmbg-1.4"
60
- assert model.task == ModelTask.SEGMENTATION
61
-
62
- def test_list_models_by_task(self, registry):
63
- """Test listing models by task."""
64
- segmentation_models = registry.list_models(task=ModelTask.SEGMENTATION)
65
- assert len(segmentation_models) > 0
66
- assert all(m.task == ModelTask.SEGMENTATION for m in segmentation_models)
67
-
68
- def test_list_models_by_framework(self, registry):
69
- """Test listing models by framework."""
70
- pytorch_models = registry.list_models(framework=ModelFramework.PYTORCH)
71
- onnx_models = registry.list_models(framework=ModelFramework.ONNX)
72
-
73
- assert all(m.framework == ModelFramework.PYTORCH for m in pytorch_models)
74
- assert all(m.framework == ModelFramework.ONNX for m in onnx_models)
75
-
76
- def test_get_best_model(self, registry):
77
- """Test getting best model for a task."""
78
- # Best for accuracy
79
- best_accuracy = registry.get_best_model(
80
- ModelTask.SEGMENTATION,
81
- prefer_speed=False
82
- )
83
- assert best_accuracy is not None
84
-
85
- # Best for speed
86
- best_speed = registry.get_best_model(
87
- ModelTask.SEGMENTATION,
88
- prefer_speed=True
89
- )
90
- assert best_speed is not None
91
-
92
- def test_update_model_usage(self, registry):
93
- """Test updating model usage statistics."""
94
- model_id = "rmbg-1.4"
95
- initial_count = registry.models[model_id].use_count
96
-
97
- registry.update_model_usage(model_id)
98
-
99
- assert registry.models[model_id].use_count == initial_count + 1
100
- assert registry.models[model_id].last_used is not None
101
-
102
- def test_get_total_size(self, registry):
103
- """Test calculating total model size."""
104
- total_size = registry.get_total_size()
105
- assert total_size > 0
106
-
107
- # Size of available models should be 0 initially
108
- available_size = registry.get_total_size(status=ModelStatus.AVAILABLE)
109
- assert available_size == 0
110
-
111
- def test_export_registry(self, registry, temp_dir):
112
- """Test exporting registry to file."""
113
- export_path = temp_dir / "registry_export.json"
114
- registry.export_registry(export_path)
115
-
116
- assert export_path.exists()
117
-
118
- with open(export_path) as f:
119
- data = json.load(f)
120
- assert "models" in data
121
- assert len(data["models"]) > 0
122
-
123
-
124
- class TestModelDownloader:
125
- """Test model downloading functionality."""
126
-
127
- @pytest.fixture
128
- def downloader(self, mock_registry):
129
- """Create a test downloader."""
130
- return ModelDownloader(mock_registry)
131
-
132
- @patch('requests.get')
133
- def test_download_model(self, mock_get, downloader):
134
- """Test downloading a model."""
135
- # Mock HTTP response
136
- mock_response = MagicMock()
137
- mock_response.headers = {'content-length': '1000000'}
138
- mock_response.iter_content = MagicMock(
139
- return_value=[b'data' * 1000]
140
- )
141
- mock_response.raise_for_status = MagicMock()
142
- mock_get.return_value = mock_response
143
-
144
- # Test download
145
- success = downloader.download_model("test-model", force=True)
146
-
147
- assert mock_get.called
148
- # Note: Full download test would require more mocking
149
-
150
- def test_download_progress_tracking(self, downloader):
151
- """Test download progress tracking."""
152
- progress_values = []
153
-
154
- def progress_callback(progress):
155
- progress_values.append(progress.progress)
156
-
157
- # Start a download (will fail but we can test progress initialization)
158
- with patch.object(downloader, '_download_model_task', return_value=True):
159
- downloader.download_model(
160
- "test-model",
161
- progress_callback=progress_callback
162
- )
163
-
164
- assert "test-model" in downloader.downloads
165
-
166
- def test_cancel_download(self, downloader):
167
- """Test cancelling a download."""
168
- # Start a mock download
169
- downloader.downloads["test-model"] = Mock()
170
- downloader._stop_events["test-model"] = Mock()
171
-
172
- success = downloader.cancel_download("test-model")
173
-
174
- assert success == True
175
- assert downloader._stop_events["test-model"].set.called
176
-
177
- def test_download_with_resume(self, downloader, temp_dir):
178
- """Test download with resume support."""
179
- # Create a partial file
180
- partial_file = temp_dir / "test.pth.part"
181
- partial_file.write_bytes(b"partial_data")
182
-
183
- # Mock download would check for partial file
184
- assert partial_file.exists()
185
- assert partial_file.stat().st_size > 0
186
-
187
-
188
- class TestModelLoader:
189
- """Test model loading functionality."""
190
-
191
- @pytest.fixture
192
- def loader(self, mock_registry):
193
- """Create a test loader."""
194
- return ModelLoader(mock_registry, device='cpu')
195
-
196
- def test_loader_initialization(self, loader):
197
- """Test loader initialization."""
198
- assert loader is not None
199
- assert loader.device == 'cpu'
200
- assert loader.max_memory_bytes > 0
201
-
202
- @patch('torch.load')
203
- def test_load_pytorch_model(self, mock_torch_load, loader):
204
- """Test loading a PyTorch model."""
205
- mock_model = MagicMock()
206
- mock_torch_load.return_value = mock_model
207
-
208
- # Mock model info
209
- model_info = ModelInfo(
210
- model_id="test-pytorch",
211
- name="Test PyTorch Model",
212
- version="1.0",
213
- task=ModelTask.SEGMENTATION,
214
- framework=ModelFramework.PYTORCH,
215
- url="",
216
- filename="model.pth",
217
- local_path="/tmp/model.pth",
218
- status=ModelStatus.AVAILABLE
219
- )
220
-
221
- loader.registry.get_model = Mock(return_value=model_info)
222
-
223
- with patch.object(Path, 'exists', return_value=True):
224
- loaded = loader.load_model("test-pytorch")
225
-
226
- # Note: Full test would require more setup
227
- assert mock_torch_load.called
228
-
229
- def test_memory_management(self, loader):
230
- """Test memory management during model loading."""
231
- # Add mock models to loaded cache
232
- for i in range(5):
233
- loader.loaded_models[f"model_{i}"] = Mock(
234
- memory_usage=100 * 1024 * 1024 # 100MB each
235
- )
236
-
237
- loader.current_memory_usage = 500 * 1024 * 1024 # 500MB
238
-
239
- # Free memory
240
- loader._free_memory(200 * 1024 * 1024) # Need 200MB
241
-
242
- # Should have freed at least 2 models
243
- assert len(loader.loaded_models) < 5
244
-
245
- def test_unload_model(self, loader):
246
- """Test unloading a model."""
247
- # Add a mock model
248
- loader.loaded_models["test"] = Mock(
249
- model=Mock(),
250
- memory_usage=100 * 1024 * 1024
251
- )
252
- loader.current_memory_usage = 100 * 1024 * 1024
253
-
254
- success = loader.unload_model("test")
255
-
256
- assert success == True
257
- assert "test" not in loader.loaded_models
258
- assert loader.current_memory_usage == 0
259
-
260
- def test_get_memory_usage(self, loader):
261
- """Test getting memory usage statistics."""
262
- # Add mock models
263
- loader.loaded_models["model1"] = Mock(memory_usage=100 * 1024 * 1024)
264
- loader.loaded_models["model2"] = Mock(memory_usage=200 * 1024 * 1024)
265
- loader.current_memory_usage = 300 * 1024 * 1024
266
-
267
- usage = loader.get_memory_usage()
268
-
269
- assert usage["current_usage_mb"] == 300
270
- assert usage["loaded_models"] == 2
271
- assert "model1" in usage["models"]
272
- assert "model2" in usage["models"]
273
-
274
-
275
- class TestModelOptimizer:
276
- """Test model optimization functionality."""
277
-
278
- @pytest.fixture
279
- def optimizer(self, mock_registry):
280
- """Create a test optimizer."""
281
- loader = ModelLoader(mock_registry, device='cpu')
282
- return ModelOptimizer(loader)
283
-
284
- @patch('torch.quantization.quantize_dynamic')
285
- def test_quantize_pytorch_model(self, mock_quantize, optimizer):
286
- """Test PyTorch model quantization."""
287
- # Create mock model
288
- mock_model = MagicMock()
289
- mock_quantize.return_value = mock_model
290
-
291
- loaded = Mock(
292
- model_id="test",
293
- model=mock_model,
294
- framework=ModelFramework.PYTORCH,
295
- metadata={'input_size': (1, 3, 512, 512)}
296
- )
297
-
298
- with patch.object(optimizer, '_get_model_size', return_value=1000000):
299
- with patch.object(optimizer, '_benchmark_model', return_value=0.1):
300
- result = optimizer._quantize_pytorch(
301
- loaded,
302
- Path("/tmp"),
303
- "dynamic"
304
- )
305
-
306
- assert mock_quantize.called
307
- # Note: Full test would require more setup
308
-
309
- def test_optimization_result(self, optimizer):
310
- """Test optimization result structure."""
311
- from models.optimizer import OptimizationResult
312
-
313
- result = OptimizationResult(
314
- original_size_mb=100,
315
- optimized_size_mb=25,
316
- compression_ratio=4.0,
317
- original_speed_ms=100,
318
- optimized_speed_ms=50,
319
- speedup=2.0,
320
- accuracy_loss=0.01,
321
- optimization_time=10.0,
322
- output_path="/tmp/optimized.pth"
323
- )
324
-
325
- assert result.compression_ratio == 4.0
326
- assert result.speedup == 2.0
327
- assert result.accuracy_loss == 0.01
328
-
329
-
330
- class TestModelIntegration:
331
- """Integration tests for model management."""
332
-
333
- @pytest.mark.integration
334
- @pytest.mark.slow
335
- def test_model_registry_persistence(self, temp_dir):
336
- """Test registry persistence across instances."""
337
- # Create registry and add model
338
- registry1 = ModelRegistry(models_dir=temp_dir)
339
-
340
- test_model = ModelInfo(
341
- model_id="persistence-test",
342
- name="Persistence Test",
343
- version="1.0",
344
- task=ModelTask.SEGMENTATION,
345
- framework=ModelFramework.PYTORCH,
346
- url="http://example.com/model.pth",
347
- filename="persist.pth"
348
- )
349
-
350
- registry1.register_model(test_model)
351
-
352
- # Create new registry instance
353
- registry2 = ModelRegistry(models_dir=temp_dir)
354
-
355
- # Check if model persisted
356
- loaded_model = registry2.get_model("persistence-test")
357
- assert loaded_model is not None
358
- assert loaded_model.name == "Persistence Test"
359
-
360
- @pytest.mark.integration
361
- def test_model_manager_workflow(self):
362
- """Test complete model manager workflow."""
363
- from models import create_model_manager
364
-
365
- manager = create_model_manager()
366
-
367
- # Test model discovery
368
- stats = manager.get_stats()
369
- assert "registry" in stats
370
- assert stats["registry"]["total_models"] > 0
371
-
372
- # Test benchmark (without actual model loading)
373
- with patch.object(manager.loader, 'load_model', return_value=Mock()):
374
- benchmarks = manager.benchmark()
375
- # Would return empty without real models
376
- assert isinstance(benchmarks, dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/test_pipeline.py DELETED
@@ -1,349 +0,0 @@
1
- """
2
- Tests for the processing pipeline.
3
- """
4
-
5
- import pytest
6
- import numpy as np
7
- import cv2
8
- from unittest.mock import Mock, patch, MagicMock
9
- from pathlib import Path
10
-
11
- from api.pipeline import (
12
- ProcessingPipeline,
13
- PipelineConfig,
14
- PipelineResult,
15
- ProcessingMode,
16
- PipelineStage
17
- )
18
-
19
-
20
- class TestPipelineConfig:
21
- """Test pipeline configuration."""
22
-
23
- def test_default_config(self):
24
- """Test default configuration values."""
25
- config = PipelineConfig()
26
- assert config.mode == ProcessingMode.PHOTO
27
- assert config.quality_preset == "high"
28
- assert config.use_gpu == True
29
- assert config.enable_cache == True
30
-
31
- def test_custom_config(self):
32
- """Test custom configuration."""
33
- config = PipelineConfig(
34
- mode=ProcessingMode.VIDEO,
35
- quality_preset="ultra",
36
- use_gpu=False,
37
- batch_size=4
38
- )
39
- assert config.mode == ProcessingMode.VIDEO
40
- assert config.quality_preset == "ultra"
41
- assert config.use_gpu == False
42
- assert config.batch_size == 4
43
-
44
-
45
- class TestProcessingPipeline:
46
- """Test the main processing pipeline."""
47
-
48
- @pytest.fixture
49
- def mock_pipeline(self, pipeline_config):
50
- """Create a pipeline with mocked components."""
51
- with patch('api.pipeline.ModelFactory') as mock_factory:
52
- with patch('api.pipeline.DeviceManager') as mock_device:
53
- mock_device.return_value.get_device.return_value = 'cpu'
54
- mock_factory.return_value.load_model.return_value = Mock()
55
-
56
- pipeline = ProcessingPipeline(pipeline_config)
57
- return pipeline
58
-
59
- def test_pipeline_initialization(self, mock_pipeline):
60
- """Test pipeline initialization."""
61
- assert mock_pipeline is not None
62
- assert mock_pipeline.config is not None
63
- assert mock_pipeline.current_stage == PipelineStage.INITIALIZATION
64
-
65
- def test_process_image_success(self, mock_pipeline, sample_image, sample_background):
66
- """Test successful image processing."""
67
- # Mock the processing methods
68
- mock_pipeline._segment_image = Mock(return_value=np.ones((512, 512), dtype=np.uint8) * 255)
69
- mock_pipeline.alpha_matting.process = Mock(return_value={
70
- 'alpha': np.ones((512, 512), dtype=np.float32),
71
- 'confidence': 0.95
72
- })
73
-
74
- result = mock_pipeline.process_image(sample_image, sample_background)
75
-
76
- assert result is not None
77
- assert isinstance(result, PipelineResult)
78
- assert result.success == True
79
- assert result.output_image is not None
80
-
81
- def test_process_image_with_effects(self, mock_pipeline, sample_image):
82
- """Test image processing with effects."""
83
- mock_pipeline.config.apply_effects = ['bokeh', 'vignette']
84
-
85
- # Mock processing
86
- mock_pipeline._segment_image = Mock(return_value=np.ones((512, 512), dtype=np.uint8) * 255)
87
- mock_pipeline.alpha_matting.process = Mock(return_value={
88
- 'alpha': np.ones((512, 512), dtype=np.float32),
89
- 'confidence': 0.95
90
- })
91
-
92
- result = mock_pipeline.process_image(sample_image, None)
93
-
94
- assert result is not None
95
- assert result.success == True
96
-
97
- def test_process_image_failure(self, mock_pipeline, sample_image):
98
- """Test image processing failure handling."""
99
- # Mock segmentation to fail
100
- mock_pipeline._segment_image = Mock(side_effect=Exception("Segmentation failed"))
101
-
102
- result = mock_pipeline.process_image(sample_image, None)
103
-
104
- assert result is not None
105
- assert result.success == False
106
- assert len(result.errors) > 0
107
-
108
- @pytest.mark.parametrize("quality", ["low", "medium", "high", "ultra"])
109
- def test_quality_presets(self, mock_pipeline, sample_image, quality):
110
- """Test different quality presets."""
111
- mock_pipeline.config.quality_preset = quality
112
-
113
- # Mock processing
114
- mock_pipeline._segment_image = Mock(return_value=np.ones((512, 512), dtype=np.uint8) * 255)
115
- mock_pipeline.alpha_matting.process = Mock(return_value={
116
- 'alpha': np.ones((512, 512), dtype=np.float32),
117
- 'confidence': 0.95
118
- })
119
-
120
- result = mock_pipeline.process_image(sample_image, None)
121
-
122
- assert result is not None
123
- assert result.success == True
124
-
125
- def test_batch_processing(self, mock_pipeline, sample_image):
126
- """Test batch processing of multiple images."""
127
- images = [sample_image] * 3
128
-
129
- # Mock processing
130
- mock_pipeline.process_image = Mock(return_value=PipelineResult(
131
- success=True,
132
- output_image=sample_image,
133
- quality_score=0.9
134
- ))
135
-
136
- results = mock_pipeline.process_batch(images)
137
-
138
- assert len(results) == 3
139
- assert all(r.success for r in results)
140
-
141
- def test_progress_callback(self, mock_pipeline, sample_image):
142
- """Test progress callback functionality."""
143
- progress_values = []
144
-
145
- def progress_callback(value, message):
146
- progress_values.append(value)
147
-
148
- mock_pipeline.config.progress_callback = progress_callback
149
-
150
- # Mock processing
151
- mock_pipeline._segment_image = Mock(return_value=np.ones((512, 512), dtype=np.uint8) * 255)
152
- mock_pipeline.alpha_matting.process = Mock(return_value={
153
- 'alpha': np.ones((512, 512), dtype=np.float32),
154
- 'confidence': 0.95
155
- })
156
-
157
- result = mock_pipeline.process_image(sample_image, None)
158
-
159
- assert len(progress_values) > 0
160
- assert 0.0 <= max(progress_values) <= 1.0
161
-
162
- def test_cache_functionality(self, mock_pipeline, sample_image):
163
- """Test caching functionality."""
164
- mock_pipeline.config.enable_cache = True
165
-
166
- # Mock processing
167
- mock_pipeline._segment_image = Mock(return_value=np.ones((512, 512), dtype=np.uint8) * 255)
168
- mock_pipeline.alpha_matting.process = Mock(return_value={
169
- 'alpha': np.ones((512, 512), dtype=np.float32),
170
- 'confidence': 0.95
171
- })
172
-
173
- # First call
174
- result1 = mock_pipeline.process_image(sample_image, None)
175
-
176
- # Second call (should use cache)
177
- result2 = mock_pipeline.process_image(sample_image, None)
178
-
179
- assert result1.success == result2.success
180
- # Verify segmentation was only called once (cache hit on second call)
181
- assert mock_pipeline._segment_image.call_count == 1
182
-
183
- def test_memory_management(self, mock_pipeline):
184
- """Test memory management and cleanup."""
185
- initial_cache_size = len(mock_pipeline.cache)
186
-
187
- # Process multiple images to fill cache
188
- for i in range(10):
189
- image = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)
190
- mock_pipeline.cache[f"test_{i}"] = PipelineResult(success=True)
191
-
192
- # Clear cache
193
- mock_pipeline.clear_cache()
194
-
195
- assert len(mock_pipeline.cache) == 0
196
-
197
- def test_statistics_tracking(self, mock_pipeline, sample_image):
198
- """Test statistics tracking."""
199
- # Mock processing
200
- mock_pipeline._segment_image = Mock(return_value=np.ones((512, 512), dtype=np.uint8) * 255)
201
- mock_pipeline.alpha_matting.process = Mock(return_value={
202
- 'alpha': np.ones((512, 512), dtype=np.float32),
203
- 'confidence': 0.95
204
- })
205
-
206
- # Process image
207
- result = mock_pipeline.process_image(sample_image, None)
208
-
209
- # Get statistics
210
- stats = mock_pipeline.get_statistics()
211
-
212
- assert 'total_processed' in stats
213
- assert stats['total_processed'] > 0
214
- assert 'avg_time' in stats
215
-
216
-
217
- class TestPipelineIntegration:
218
- """Integration tests for the pipeline."""
219
-
220
- @pytest.mark.integration
221
- @pytest.mark.slow
222
- def test_end_to_end_processing(self, sample_image, sample_background, temp_dir):
223
- """Test end-to-end processing pipeline."""
224
- config = PipelineConfig(
225
- use_gpu=False,
226
- quality_preset="medium",
227
- enable_cache=False
228
- )
229
-
230
- # Create pipeline (will use real components if available)
231
- try:
232
- pipeline = ProcessingPipeline(config)
233
- except Exception:
234
- pytest.skip("Models not available for integration test")
235
-
236
- # Process image
237
- result = pipeline.process_image(sample_image, sample_background)
238
-
239
- if result.success:
240
- assert result.output_image is not None
241
- assert result.output_image.shape == sample_image.shape
242
- assert result.quality_score > 0
243
-
244
- # Save output
245
- output_path = temp_dir / "test_output.png"
246
- cv2.imwrite(str(output_path), result.output_image)
247
- assert output_path.exists()
248
-
249
- @pytest.mark.integration
250
- @pytest.mark.slow
251
- def test_video_frame_processing(self, sample_video, temp_dir):
252
- """Test processing video frames."""
253
- config = PipelineConfig(
254
- mode=ProcessingMode.VIDEO,
255
- use_gpu=False,
256
- quality_preset="low"
257
- )
258
-
259
- try:
260
- pipeline = ProcessingPipeline(config)
261
- except Exception:
262
- pytest.skip("Models not available for integration test")
263
-
264
- # Open video
265
- cap = cv2.VideoCapture(sample_video)
266
- processed_frames = []
267
-
268
- # Process first 5 frames
269
- for i in range(5):
270
- ret, frame = cap.read()
271
- if not ret:
272
- break
273
-
274
- result = pipeline.process_image(frame, None)
275
- if result.success:
276
- processed_frames.append(result.output_image)
277
-
278
- cap.release()
279
-
280
- assert len(processed_frames) > 0
281
-
282
- # Save as video
283
- if processed_frames:
284
- output_path = temp_dir / "test_video_out.mp4"
285
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
286
- out = cv2.VideoWriter(str(output_path), fourcc, 30.0,
287
- (processed_frames[0].shape[1], processed_frames[0].shape[0]))
288
-
289
- for frame in processed_frames:
290
- out.write(frame)
291
-
292
- out.release()
293
- assert output_path.exists()
294
-
295
-
296
- class TestPipelinePerformance:
297
- """Performance tests for the pipeline."""
298
-
299
- @pytest.mark.slow
300
- def test_processing_speed(self, mock_pipeline, sample_image, performance_timer):
301
- """Test processing speed."""
302
- # Mock processing
303
- mock_pipeline._segment_image = Mock(return_value=np.ones((512, 512), dtype=np.uint8) * 255)
304
- mock_pipeline.alpha_matting.process = Mock(return_value={
305
- 'alpha': np.ones((512, 512), dtype=np.float32),
306
- 'confidence': 0.95
307
- })
308
-
309
- with performance_timer as timer:
310
- result = mock_pipeline.process_image(sample_image, None)
311
-
312
- assert result.success == True
313
- assert timer.elapsed < 1.0 # Should process in under 1 second
314
-
315
- @pytest.mark.slow
316
- def test_batch_processing_speed(self, mock_pipeline, sample_image, performance_timer):
317
- """Test batch processing speed."""
318
- images = [sample_image] * 10
319
-
320
- # Mock processing
321
- mock_pipeline.process_image = Mock(return_value=PipelineResult(
322
- success=True,
323
- output_image=sample_image,
324
- quality_score=0.9
325
- ))
326
-
327
- with performance_timer as timer:
328
- results = mock_pipeline.process_batch(images)
329
-
330
- assert len(results) == 10
331
- assert timer.elapsed < 5.0 # Should process 10 images in under 5 seconds
332
-
333
- def test_memory_usage(self, mock_pipeline, sample_image):
334
- """Test memory usage during processing."""
335
- import psutil
336
- import os
337
-
338
- process = psutil.Process(os.getpid())
339
- initial_memory = process.memory_info().rss / 1024 / 1024 # MB
340
-
341
- # Process multiple images
342
- for _ in range(10):
343
- mock_pipeline.process_image(sample_image, None)
344
-
345
- final_memory = process.memory_info().rss / 1024 / 1024 # MB
346
- memory_increase = final_memory - initial_memory
347
-
348
- # Memory increase should be reasonable (less than 500MB for 10 images)
349
- assert memory_increase < 500
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/test_smoke.py DELETED
@@ -1,267 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Smoke test for two-stage video processing
4
- THIS IS A NEW FILE - Basic end-to-end test
5
- Tests quality profiles, frame count preservation, and basic functionality
6
- """
7
- import os
8
- import sys
9
- import cv2
10
- import numpy as np
11
- import tempfile
12
- import logging
13
- import time
14
- from pathlib import Path
15
-
16
- # Add project root to path
17
- sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
18
-
19
- from processing.two_stage.two_stage_processor import TwoStageProcessor
20
- from models.loaders.matanyone_loader import MatAnyoneLoader
21
-
22
- logging.basicConfig(level=logging.INFO)
23
- logger = logging.getLogger(__name__)
24
-
25
- def create_test_video(path: str, frames: int = 30, fps: int = 30):
26
- """Create a simple test video with a moving circle (simulating a person)"""
27
- width, height = 640, 480
28
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
29
- out = cv2.VideoWriter(path, fourcc, fps, (width, height))
30
-
31
- if not out.isOpened():
32
- raise RuntimeError(f"Failed to create test video at {path}")
33
-
34
- for i in range(frames):
35
- # Create frame with moving white circle on dark background
36
- frame = np.zeros((height, width, 3), dtype=np.uint8)
37
- frame[:] = (30, 30, 30) # Dark gray background
38
-
39
- # Draw a moving circle (simulating a person)
40
- x = int(width/2 + 100 * np.sin(i * 0.2))
41
- y = int(height/2 + 50 * np.cos(i * 0.15))
42
- cv2.circle(frame, (x, y), 60, (255, 255, 255), -1)
43
-
44
- # Add some variation to simulate clothing
45
- cv2.circle(frame, (x, y-20), 20, (200, 100, 100), -1) # "shirt"
46
-
47
- out.write(frame)
48
-
49
- out.release()
50
-
51
- # Verify the video was created
52
- cap = cv2.VideoCapture(path)
53
- actual_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
54
- cap.release()
55
-
56
- logger.info(f"Created test video: {path} ({actual_frames} frames)")
57
- return actual_frames
58
-
59
-
60
- def verify_output_video(path: str, expected_frames: int) -> bool:
61
- """Verify output video exists and has correct frame count"""
62
- if not os.path.exists(path):
63
- logger.error(f"Output video not found: {path}")
64
- return False
65
-
66
- file_size = os.path.getsize(path)
67
- if file_size < 1000:
68
- logger.error(f"Output video too small: {file_size} bytes")
69
- return False
70
-
71
- cap = cv2.VideoCapture(path)
72
- if not cap.isOpened():
73
- logger.error(f"Cannot open output video: {path}")
74
- return False
75
-
76
- actual_frames = 0
77
- while True:
78
- ret, frame = cap.read()
79
- if not ret:
80
- break
81
- actual_frames += 1
82
-
83
- cap.release()
84
-
85
- if actual_frames != expected_frames:
86
- logger.error(f"Frame count mismatch: got {actual_frames}, expected {expected_frames}")
87
- return False
88
-
89
- logger.info(f"Output verified: {actual_frames} frames, {file_size:,} bytes")
90
- return True
91
-
92
-
93
- def test_quality_profiles():
94
- """Test that different quality profiles produce different results"""
95
- logger.info("="*60)
96
- logger.info("Testing Quality Profiles")
97
- logger.info("="*60)
98
-
99
- with tempfile.TemporaryDirectory() as tmpdir:
100
- tmpdir = Path(tmpdir)
101
-
102
- # Create test video
103
- test_video = tmpdir / "test_input.mp4"
104
- expected_frames = create_test_video(str(test_video), frames=30, fps=30)
105
-
106
- # Create a simple background
107
- background = np.ones((480, 640, 3), dtype=np.uint8) * 128 # Gray
108
- background[:240, :] = (100, 150, 200) # Blue top half
109
-
110
- results = {}
111
-
112
- for quality in ["speed", "balanced", "max"]:
113
- logger.info(f"\nTesting quality mode: {quality}")
114
- logger.info("-" * 40)
115
-
116
- # Set quality environment variable
117
- os.environ["BFX_QUALITY"] = quality
118
-
119
- # Initialize processor (without models for basic test)
120
- processor = TwoStageProcessor(
121
- sam2_predictor=None, # Will use fallback
122
- matanyone_model=None # Will use fallback
123
- )
124
-
125
- # Process video
126
- output_path = tmpdir / f"output_{quality}.mp4"
127
- start_time = time.time()
128
-
129
- result, message = processor.process_full_pipeline(
130
- video_path=str(test_video),
131
- background=background,
132
- output_path=str(output_path),
133
- key_color_mode="auto",
134
- chroma_settings=None,
135
- progress_callback=lambda p, d: logger.debug(f"{p:.1%}: {d}"),
136
- stop_event=None
137
- )
138
-
139
- process_time = time.time() - start_time
140
-
141
- if result is None:
142
- logger.error(f"Processing failed for {quality}: {message}")
143
- continue
144
-
145
- # Verify output
146
- if verify_output_video(result, expected_frames):
147
- results[quality] = {
148
- "success": True,
149
- "time": process_time,
150
- "frames_refined": processor.frames_refined,
151
- "total_frames": processor.total_frames_processed
152
- }
153
- logger.info(f"✓ {quality}: {process_time:.2f}s, "
154
- f"{processor.frames_refined}/{processor.total_frames_processed} refined")
155
- else:
156
- results[quality] = {"success": False}
157
- logger.error(f"✗ {quality}: verification failed")
158
-
159
- # Summary
160
- logger.info("\n" + "="*60)
161
- logger.info("SUMMARY")
162
- logger.info("="*60)
163
-
164
- all_passed = all(r.get("success", False) for r in results.values())
165
-
166
- if all_passed:
167
- # Check that quality modes are actually different
168
- if len(results) >= 2:
169
- times = [r["time"] for r in results.values() if "time" in r]
170
- refined_counts = [r["frames_refined"] for r in results.values() if "frames_refined" in r]
171
-
172
- if len(set(refined_counts)) > 1:
173
- logger.info("✓ Quality profiles show different refinement counts")
174
- else:
175
- logger.warning("⚠ All quality profiles refined same number of frames")
176
-
177
- if max(times) - min(times) > 0.1:
178
- logger.info("✓ Quality profiles show different processing times")
179
- else:
180
- logger.warning("⚠ Quality profiles have similar processing times")
181
-
182
- for quality, result in results.items():
183
- if result.get("success"):
184
- logger.info(f"✓ {quality:8s}: {result['time']:.2f}s, "
185
- f"{result['frames_refined']}/{result['total_frames']} frames refined")
186
- else:
187
- logger.info(f"✗ {quality:8s}: FAILED")
188
-
189
- return all_passed
190
-
191
-
192
- def test_frame_preservation():
193
- """Test that no frames are lost during processing"""
194
- logger.info("\n" + "="*60)
195
- logger.info("Testing Frame Preservation")
196
- logger.info("="*60)
197
-
198
- with tempfile.TemporaryDirectory() as tmpdir:
199
- tmpdir = Path(tmpdir)
200
-
201
- # Test different frame counts
202
- test_cases = [10, 25, 30, 60]
203
-
204
- for frame_count in test_cases:
205
- logger.info(f"\nTesting with {frame_count} frames...")
206
-
207
- test_video = tmpdir / f"test_{frame_count}.mp4"
208
- expected = create_test_video(str(test_video), frames=frame_count, fps=30)
209
-
210
- os.environ["BFX_QUALITY"] = "speed" # Fast for this test
211
-
212
- processor = TwoStageProcessor()
213
- output_path = tmpdir / f"output_{frame_count}.mp4"
214
-
215
- result, message = processor.process_full_pipeline(
216
- video_path=str(test_video),
217
- background=np.ones((480, 640, 3), dtype=np.uint8) * 100,
218
- output_path=str(output_path),
219
- key_color_mode="green",
220
- )
221
-
222
- if result and verify_output_video(result, expected):
223
- logger.info(f"✓ {frame_count} frames: preserved correctly")
224
- else:
225
- logger.error(f"✗ {frame_count} frames: FAILED")
226
- return False
227
-
228
- logger.info("\n✓ All frame preservation tests passed!")
229
- return True
230
-
231
-
232
- def main():
233
- """Run all smoke tests"""
234
- logger.info("\n" + "🔥"*20)
235
- logger.info("BACKGROUNDFX PRO SMOKE TESTS")
236
- logger.info("🔥"*20)
237
-
238
- tests_passed = []
239
-
240
- # Test 1: Quality profiles
241
- try:
242
- tests_passed.append(test_quality_profiles())
243
- except Exception as e:
244
- logger.error(f"Quality profile test crashed: {e}")
245
- tests_passed.append(False)
246
-
247
- # Test 2: Frame preservation
248
- try:
249
- tests_passed.append(test_frame_preservation())
250
- except Exception as e:
251
- logger.error(f"Frame preservation test crashed: {e}")
252
- tests_passed.append(False)
253
-
254
- # Final result
255
- logger.info("\n" + "="*60)
256
- if all(tests_passed):
257
- logger.info("✅ ALL SMOKE TESTS PASSED!")
258
- logger.info("="*60)
259
- return 0
260
- else:
261
- logger.error("❌ SOME TESTS FAILED")
262
- logger.info("="*60)
263
- return 1
264
-
265
-
266
- if __name__ == "__main__":
267
- exit(main())