|
|
|
|
|
|
|
|
""" |
|
|
Tests for batch_inference and get_optimal_batch_size methods in DepthAnything3 API. |
|
|
|
|
|
These tests mock the actual model inference to focus on testing the batching logic, |
|
|
without needing to load heavy model weights. |
|
|
""" |
|
|
from __future__ import annotations |
|
|
|
|
|
from dataclasses import dataclass |
|
|
from unittest.mock import MagicMock, patch |
|
|
|
|
|
import numpy as np |
|
|
import pytest |
|
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class MockPrediction:
    """Lightweight stand-in for a prediction result, used by the batching tests.

    It carries only the three fields below; all arrays are zero-filled.
    """

    # Zero-filled float32 array of shape (num_images, 256, 256) when built via create().
    depth: np.ndarray
    # Zero-filled uint8 array of shape (num_images, 256, 256, 3) when built via create().
    processed_images: np.ndarray
    # Number of images this prediction stands for.
    num_images: int

    @classmethod
    def create(cls, num_images: int) -> "MockPrediction":
        """Build a zero-filled mock prediction covering ``num_images`` images."""
        side = 256
        depth_maps = np.zeros((num_images, side, side), dtype=np.float32)
        frames = np.zeros((num_images, side, side, 3), dtype=np.uint8)
        return cls(
            depth=depth_maps,
            processed_images=frames,
            num_images=num_images,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def cpu_device():
    """Provide a ``torch.device`` that targets the CPU."""
    device = torch.device("cpu")
    return device
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_model(cpu_device):
    """Provide a ``MagicMock`` specced on DepthAnything3 with a stubbed ``inference``.

    The stub reports the CPU device, the model name ``"da3-large"`` and an
    ``inference`` mock that returns a ``MockPrediction`` sized to its input.
    """
    from depth_anything_3.api import DepthAnything3

    def fake_inference(image, process_res=504, **kwargs):
        # A list is treated as a batch; anything else counts as a single image.
        if isinstance(image, list):
            batch_len = len(image)
        else:
            batch_len = 1
        return MockPrediction.create(batch_len)

    stub = MagicMock(spec=DepthAnything3)
    stub.device = cpu_device
    stub.model_name = "da3-large"
    # Wrapping the function in a MagicMock lets tests inspect call counts and arguments.
    stub.inference = MagicMock(side_effect=fake_inference)
    return stub
|
|
|
|
|
|
|
|
@pytest.fixture
def sample_images():
    """Provide ten placeholder image paths, ``image_0.jpg`` through ``image_9.jpg``."""
    count = 10
    paths = [f"image_{index}.jpg" for index in range(count)]
    return paths
|
|
|
|
|
|
|
|
@pytest.fixture
def large_sample_images():
    """Provide one hundred placeholder image paths, ``image_0.jpg`` through ``image_99.jpg``."""
    count = 100
    paths = [f"image_{index}.jpg" for index in range(count)]
    return paths
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestBatchInference:
    """Tests for the batch_inference method."""

    @staticmethod
    def _make_api(mock_model):
        """Build a DepthAnything3 instance wired to ``mock_model``.

        ``__init__`` is patched to a no-op while the instance is constructed,
        then ``device``, ``model_name`` and ``inference`` are copied from the
        mock so ``batch_inference`` runs against the stubbed inference.
        """
        from depth_anything_3.api import DepthAnything3

        with patch.object(DepthAnything3, "__init__", lambda x, **k: None):
            api = DepthAnything3()
        api.device = mock_model.device
        api.model_name = mock_model.model_name
        api.inference = mock_model.inference
        return api

    def test_batch_inference_empty_list(self, mock_model):
        """An empty image list yields an empty result and no inference call."""
        api = self._make_api(mock_model)

        results = api.batch_inference([])

        assert results == []
        mock_model.inference.assert_not_called()

    def test_batch_inference_fixed_batch_size(self, mock_model, sample_images):
        """A fixed batch size splits the images into the expected number of batches."""
        api = self._make_api(mock_model)

        results = api.batch_inference(sample_images, batch_size=3)

        # 10 images at 3 per batch -> 4 batches (3 + 3 + 3 + 1).
        assert len(results) == 4
        assert mock_model.inference.call_count == 4

    def test_batch_inference_auto_batch_size(self, mock_model, sample_images):
        """``batch_size="auto"`` produces at least one batch and one inference call."""
        api = self._make_api(mock_model)

        results = api.batch_inference(sample_images, batch_size="auto")

        # The chosen batch size is not pinned here; only that work happened.
        assert len(results) >= 1
        assert mock_model.inference.call_count >= 1

    def test_batch_inference_progress_callback(self, mock_model, sample_images):
        """The progress callback fires once per batch and ends at the total count."""
        progress_calls = []

        def progress_callback(processed, total):
            progress_calls.append((processed, total))

        api = self._make_api(mock_model)

        api.batch_inference(
            sample_images, batch_size=3, progress_callback=progress_callback
        )

        # 10 images at 3 per batch -> 4 callback invocations.
        assert len(progress_calls) == 4

        # The final report must say that everything has been processed.
        last_processed, last_total = progress_calls[-1]
        assert last_processed == len(sample_images)
        assert last_total == len(sample_images)

    def test_batch_inference_single_image(self, mock_model):
        """A single image yields one result from exactly one inference call."""
        api = self._make_api(mock_model)

        results = api.batch_inference(["single.jpg"])

        assert len(results) == 1
        mock_model.inference.assert_called_once()

    def test_batch_inference_batch_larger_than_images(self, mock_model):
        """A batch size above the image count collapses to a single batch."""
        images = ["img1.jpg", "img2.jpg"]
        api = self._make_api(mock_model)

        results = api.batch_inference(images, batch_size=10)

        assert len(results) == 1
        mock_model.inference.assert_called_once()

    def test_batch_inference_exact_batch_multiple(self, mock_model):
        """An image count that divides evenly produces no extra trailing batch."""
        images = [f"img{i}.jpg" for i in range(12)]
        api = self._make_api(mock_model)

        results = api.batch_inference(images, batch_size=3)

        # 12 images at 3 per batch -> exactly 4 batches.
        assert len(results) == 4
        assert mock_model.inference.call_count == 4

    def test_batch_inference_respects_process_res(self, mock_model, sample_images):
        """``process_res`` is forwarded to ``inference`` as a keyword argument."""
        api = self._make_api(mock_model)

        api.batch_inference(sample_images, batch_size=10, process_res=1024)

        # call_args holds the most recent call; with batch_size=10 and 10 images
        # the fixed-size tests above imply there is only one.
        call_args = mock_model.inference.call_args
        assert call_args.kwargs.get("process_res") == 1024

    def test_batch_inference_max_batch_size_auto(self, mock_model, sample_images):
        """``max_batch_size`` caps the batch size chosen by auto batching."""
        api = self._make_api(mock_model)

        results = api.batch_inference(
            sample_images, batch_size="auto", max_batch_size=2
        )

        # With at most 2 images per batch, 10 images need at least 5 batches.
        assert len(results) >= 5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestGetOptimalBatchSize:
    """Tests for the get_optimal_batch_size method."""

    @staticmethod
    def _make_api(device, model_name="da3-large"):
        """Build a DepthAnything3 instance with only ``device`` and ``model_name`` set.

        ``__init__`` is patched to a no-op while the instance is constructed.
        """
        from depth_anything_3.api import DepthAnything3

        with patch.object(DepthAnything3, "__init__", lambda x, **k: None):
            api = DepthAnything3()
        api.device = device
        api.model_name = model_name
        return api

    def test_get_optimal_batch_size_returns_int(self, cpu_device):
        """The default call returns a positive ``int``."""
        api = self._make_api(cpu_device)

        result = api.get_optimal_batch_size()

        assert isinstance(result, int)
        assert result > 0

    def test_get_optimal_batch_size_respects_resolution(self, cpu_device):
        """Both a low and a high ``process_res`` yield a positive batch size.

        Only positivity is checked; the two results are not compared.
        """
        api = self._make_api(cpu_device)

        low_res = api.get_optimal_batch_size(process_res=256)
        high_res = api.get_optimal_batch_size(process_res=1024)

        assert low_res > 0
        assert high_res > 0

    def test_get_optimal_batch_size_respects_utilization(self, cpu_device):
        """Both a low and a high ``target_utilization`` yield a positive batch size.

        Only positivity is checked; the two results are not compared.
        """
        api = self._make_api(cpu_device)

        low_util = api.get_optimal_batch_size(target_utilization=0.5)
        high_util = api.get_optimal_batch_size(target_utilization=0.95)

        assert low_util > 0
        assert high_util > 0

    def test_get_optimal_batch_size_different_models(self, cpu_device):
        """Every listed model name yields a positive batch size."""
        models = ["da3-small", "da3-base", "da3-large", "da3-giant"]

        for model_name in models:
            api = self._make_api(cpu_device, model_name)

            result = api.get_optimal_batch_size()

            # The message identifies which model name broke the expectation.
            assert result > 0, f"Failed for model {model_name}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestBatchingIntegration:
    """Integration tests for batching functionality."""

    @staticmethod
    def _make_api(mock_model):
        """Build a DepthAnything3 instance wired to ``mock_model``.

        ``__init__`` is patched to a no-op while the instance is constructed,
        then ``device``, ``model_name`` and ``inference`` are copied from the mock.
        """
        from depth_anything_3.api import DepthAnything3

        with patch.object(DepthAnything3, "__init__", lambda x, **k: None):
            api = DepthAnything3()
        api.device = mock_model.device
        api.model_name = mock_model.model_name
        api.inference = mock_model.inference
        return api

    @staticmethod
    def _recorder(sink):
        """Return an ``inference`` side effect that appends every image it sees to ``sink``."""

        def record(image, **kwargs):
            # A list is a batch; anything else is treated as a single image.
            batch = image if isinstance(image, list) else [image]
            sink.extend(batch)
            return MockPrediction.create(len(batch))

        return record

    def test_auto_vs_fixed_batching_coverage(self, mock_model, sample_images):
        """Auto and fixed batching each hand every image to ``inference``."""
        api = self._make_api(mock_model)

        auto_images_processed = []
        fixed_images_processed = []

        # api.inference is the same MagicMock object as mock_model.inference,
        # so swapping its side_effect is enough to redirect the recording.
        mock_model.inference.side_effect = self._recorder(auto_images_processed)
        api.batch_inference(sample_images.copy(), batch_size="auto")

        mock_model.inference.side_effect = self._recorder(fixed_images_processed)
        api.batch_inference(sample_images.copy(), batch_size=3)

        assert len(auto_images_processed) == len(sample_images)
        assert len(fixed_images_processed) == len(sample_images)

    def test_batch_inference_preserves_order(self, mock_model):
        """Images reach ``inference`` in the order they were supplied."""
        images = ["first.jpg", "second.jpg", "third.jpg", "fourth.jpg", "fifth.jpg"]
        processed_order = []

        mock_model.inference.side_effect = self._recorder(processed_order)
        api = self._make_api(mock_model)

        api.batch_inference(images, batch_size=2)

        assert processed_order == images

    def test_progress_increases_monotonically(self, mock_model, sample_images):
        """The processed count reported to the callback strictly increases."""
        progress_values = []

        def progress_callback(processed, total):
            progress_values.append(processed)

        api = self._make_api(mock_model)

        api.batch_inference(
            sample_images, batch_size=3, progress_callback=progress_callback
        )

        # Guard against a vacuous pass: with no reports there is nothing to compare.
        assert progress_values, "progress callback was never invoked"

        # Compare each report with the one before it.
        for previous, current in zip(progress_values, progress_values[1:]):
            assert current > previous
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestBatchingEdgeCases:
    """Tests for edge cases in batching."""

    @staticmethod
    def _make_api(mock_model):
        """Build a DepthAnything3 instance wired to ``mock_model``.

        ``__init__`` is patched to a no-op while the instance is constructed,
        then ``device``, ``model_name`` and ``inference`` are copied from the mock.
        """
        from depth_anything_3.api import DepthAnything3

        with patch.object(DepthAnything3, "__init__", lambda x, **k: None):
            api = DepthAnything3()
        api.device = mock_model.device
        api.model_name = mock_model.model_name
        api.inference = mock_model.inference
        return api

    def test_batch_size_one(self, mock_model, sample_images):
        """A batch size of 1 gives one result and one inference call per image."""
        api = self._make_api(mock_model)

        results = api.batch_inference(sample_images, batch_size=1)

        assert len(results) == len(sample_images)
        assert mock_model.inference.call_count == len(sample_images)

    def test_very_large_batch_size(self, mock_model, sample_images):
        """A batch size far above the image count yields a single batch."""
        api = self._make_api(mock_model)

        results = api.batch_inference(sample_images, batch_size=1000)

        assert len(results) == 1

    def test_auto_with_very_low_memory_utilization(self, mock_model, sample_images):
        """Auto batching with a very low utilization target still covers every image."""
        api = self._make_api(mock_model)

        results = api.batch_inference(
            sample_images, batch_size="auto", target_memory_utilization=0.1
        )

        # The batch count is not pinned; only the total number of images is.
        total_processed = sum(r.num_images for r in results)
        assert total_processed == len(sample_images)

    def test_numpy_array_inputs(self, mock_model):
        """In-memory numpy images are batched the same way as paths."""
        images = [np.zeros((256, 256, 3), dtype=np.uint8) for _ in range(5)]
        api = self._make_api(mock_model)

        results = api.batch_inference(images, batch_size=2)

        # 5 images at 2 per batch -> 3 batches (2 + 2 + 1).
        assert len(results) == 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestMemoryCleanup:
    """Tests for memory cleanup during batching."""

    @staticmethod
    def _fake_inference(image, **kwargs):
        """Stand-in for ``inference`` returning a ``MockPrediction`` sized to the input."""
        num = len(image) if isinstance(image, list) else 1
        return MockPrediction.create(num)

    @staticmethod
    def _make_api(device, model_name, inference):
        """Build a DepthAnything3 instance with the given attributes.

        ``__init__`` is patched to a no-op while the instance is constructed.
        """
        from depth_anything_3.api import DepthAnything3

        with patch.object(DepthAnything3, "__init__", lambda x, **k: None):
            api = DepthAnything3()
        api.device = device
        api.model_name = model_name
        api.inference = inference
        return api

    def test_gc_collect_called_between_batches(self, mock_model, sample_images):
        """``gc.collect`` is called the expected number of times for four batches."""
        api = self._make_api(
            mock_model.device, mock_model.model_name, mock_model.inference
        )

        with patch("gc.collect") as mock_gc:
            api.batch_inference(sample_images, batch_size=3)

        # 10 images at 3 per batch -> 4 batches; 3 calls are expected, presumably
        # one between consecutive batches and none after the last.
        assert mock_gc.call_count == 3

    def test_cuda_empty_cache_called(self, sample_images):
        """``torch.cuda.empty_cache`` is called when the device is CUDA."""
        api = self._make_api(
            torch.device("cuda:0"),
            "da3-large",
            MagicMock(side_effect=self._fake_inference),
        )

        # empty_cache is patched, so the test only counts calls to it.
        with patch("torch.cuda.empty_cache") as mock_empty:
            api.batch_inference(sample_images, batch_size=3)

        # Same 4-batch split as above, so 3 cache clears are expected.
        assert mock_empty.call_count == 3

    def test_mps_empty_cache_called(self, sample_images):
        """``torch.mps.empty_cache`` is called when the device is MPS."""
        api = self._make_api(
            torch.device("mps"),
            "da3-large",
            MagicMock(side_effect=self._fake_inference),
        )

        with patch("torch.mps.empty_cache") as mock_empty:
            api.batch_inference(sample_images, batch_size=3)

        # Same 4-batch split as above, so 3 cache clears are expected.
        assert mock_empty.call_count == 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main returns the run's exit code; propagate it so that running this
    # file as a script reports failures to the shell or CI instead of exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
|
|