# Source: awesome-depth-anything-3 / tests/test_api_batching.py
# (HuggingFace Space by Delanoe Pirard, commit 18b382b, 21.7 kB;
#  page-viewer chrome removed so the module is importable.)
# Copyright (c) Delanoe Pirard / Aedelon
# Licensed under the Apache License, Version 2.0
"""
Tests for batch_inference and get_optimal_batch_size methods in DepthAnything3 API.
These tests mock the actual model inference to focus on testing the batching logic,
without needing to load heavy model weights.
"""
from __future__ import annotations
from dataclasses import dataclass
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
import torch
# =============================================================================
# Mock Prediction Class
# =============================================================================
@dataclass
class MockPrediction:
    """Lightweight stand-in for the Prediction object returned by inference."""

    depth: np.ndarray  # stacked depth maps, one per image
    processed_images: np.ndarray  # stacked RGB images matching `depth`
    num_images: int  # how many images this prediction covers

    @classmethod
    def create(cls, num_images: int) -> "MockPrediction":
        """Create a mock prediction for n images."""
        depth_stack = np.zeros((num_images, 256, 256), dtype=np.float32)
        image_stack = np.zeros((num_images, 256, 256, 3), dtype=np.uint8)
        return cls(
            depth=depth_stack,
            processed_images=image_stack,
            num_images=num_images,
        )
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def cpu_device():
    """Provide a torch CPU device for device-agnostic tests."""
    device = torch.device("cpu")
    return device
@pytest.fixture
def mock_model(cpu_device):
    """Build a MagicMock standing in for a loaded DepthAnything3 model."""
    from depth_anything_3.api import DepthAnything3

    # Spec against the real class so attribute typos fail loudly.
    model = MagicMock(spec=DepthAnything3)
    model.device = cpu_device
    model.model_name = "da3-large"

    def fake_inference(image, process_res=504, **kwargs):
        # A list input is a batch; anything else counts as one image.
        count = len(image) if isinstance(image, list) else 1
        return MockPrediction.create(count)

    model.inference = MagicMock(side_effect=fake_inference)
    return model
@pytest.fixture
def sample_images():
    """Provide ten fake image paths for batching tests."""
    return [f"image_{idx}.jpg" for idx in range(10)]
@pytest.fixture
def large_sample_images():
    """Provide one hundred fake image paths for larger-scale tests."""
    return [f"image_{idx}.jpg" for idx in range(100)]
# =============================================================================
# batch_inference Tests
# =============================================================================
class TestBatchInference:
    """Tests for the batch_inference method.

    Every test patches ``DepthAnything3.__init__`` to a no-op so no model
    weights are loaded, then attaches the mocked ``inference`` — only the
    batching logic inside ``batch_inference`` is actually exercised.
    """

    def test_batch_inference_empty_list(self, mock_model):
        """Test batch_inference with empty image list."""
        from depth_anything_3.api import DepthAnything3
        # Call the actual method implementation with mocked model
        with patch.object(DepthAnything3, "__init__", lambda x, **k: None):
            api = DepthAnything3()
            api.device = mock_model.device
            api.model_name = mock_model.model_name
            api.inference = mock_model.inference
            results = api.batch_inference([])
            # Nothing to process: no results and no inference calls at all.
            assert results == []
            mock_model.inference.assert_not_called()

    def test_batch_inference_fixed_batch_size(self, mock_model, sample_images):
        """Test batch_inference with fixed batch size."""
        from depth_anything_3.api import DepthAnything3
        with patch.object(DepthAnything3, "__init__", lambda x, **k: None):
            api = DepthAnything3()
            api.device = mock_model.device
            api.model_name = mock_model.model_name
            api.inference = mock_model.inference
            results = api.batch_inference(sample_images, batch_size=3)
            # 10 images with batch size 3 = 4 batches (3, 3, 3, 1)
            assert len(results) == 4
            assert mock_model.inference.call_count == 4

    def test_batch_inference_auto_batch_size(self, mock_model, sample_images):
        """Test batch_inference with auto batch size."""
        from depth_anything_3.api import DepthAnything3
        with patch.object(DepthAnything3, "__init__", lambda x, **k: None):
            api = DepthAnything3()
            api.device = mock_model.device
            api.model_name = mock_model.model_name
            api.inference = mock_model.inference
            # "auto" derives the batch size from the device, so only weak
            # lower bounds can be asserted here.
            results = api.batch_inference(sample_images, batch_size="auto")
            # Should have at least 1 result
            assert len(results) >= 1
            # Should have called inference at least once
            assert mock_model.inference.call_count >= 1

    def test_batch_inference_progress_callback(self, mock_model, sample_images):
        """Test that progress callback is called."""
        from depth_anything_3.api import DepthAnything3
        progress_calls = []

        def progress_callback(processed, total):
            # Record every (processed, total) pair reported by the API.
            progress_calls.append((processed, total))

        with patch.object(DepthAnything3, "__init__", lambda x, **k: None):
            api = DepthAnything3()
            api.device = mock_model.device
            api.model_name = mock_model.model_name
            api.inference = mock_model.inference
            api.batch_inference(
                sample_images, batch_size=3, progress_callback=progress_callback
            )
            # Should have progress calls
            assert len(progress_calls) == 4  # 4 batches
            # Last call should have all images processed
            assert progress_calls[-1][0] == len(sample_images)
            assert progress_calls[-1][1] == len(sample_images)

    def test_batch_inference_single_image(self, mock_model):
        """Test batch_inference with single image."""
        from depth_anything_3.api import DepthAnything3
        with patch.object(DepthAnything3, "__init__", lambda x, **k: None):
            api = DepthAnything3()
            api.device = mock_model.device
            api.model_name = mock_model.model_name
            api.inference = mock_model.inference
            results = api.batch_inference(["single.jpg"])
            # One image -> exactly one batch and one result.
            assert len(results) == 1
            mock_model.inference.assert_called_once()

    def test_batch_inference_batch_larger_than_images(self, mock_model):
        """Test when batch size is larger than number of images."""
        from depth_anything_3.api import DepthAnything3
        images = ["img1.jpg", "img2.jpg"]
        with patch.object(DepthAnything3, "__init__", lambda x, **k: None):
            api = DepthAnything3()
            api.device = mock_model.device
            api.model_name = mock_model.model_name
            api.inference = mock_model.inference
            results = api.batch_inference(images, batch_size=10)
            # Should only make one call with all images
            assert len(results) == 1
            mock_model.inference.assert_called_once()

    def test_batch_inference_exact_batch_multiple(self, mock_model):
        """Test when image count is exact multiple of batch size."""
        from depth_anything_3.api import DepthAnything3
        images = [f"img{i}.jpg" for i in range(12)]  # Exactly 4 batches of 3
        with patch.object(DepthAnything3, "__init__", lambda x, **k: None):
            api = DepthAnything3()
            api.device = mock_model.device
            api.model_name = mock_model.model_name
            api.inference = mock_model.inference
            results = api.batch_inference(images, batch_size=3)
            # No ragged final batch: exactly 12 / 3 = 4 calls.
            assert len(results) == 4
            assert mock_model.inference.call_count == 4

    def test_batch_inference_respects_process_res(self, mock_model, sample_images):
        """Test that process_res is passed to inference."""
        from depth_anything_3.api import DepthAnything3
        with patch.object(DepthAnything3, "__init__", lambda x, **k: None):
            api = DepthAnything3()
            api.device = mock_model.device
            api.model_name = mock_model.model_name
            api.inference = mock_model.inference
            api.batch_inference(sample_images, batch_size=10, process_res=1024)
            # Check that inference was called with correct process_res
            call_args = mock_model.inference.call_args
            assert call_args.kwargs.get("process_res") == 1024

    def test_batch_inference_max_batch_size_auto(self, mock_model, sample_images):
        """Test max_batch_size parameter with auto batching."""
        from depth_anything_3.api import DepthAnything3
        with patch.object(DepthAnything3, "__init__", lambda x, **k: None):
            api = DepthAnything3()
            api.device = mock_model.device
            api.model_name = mock_model.model_name
            api.inference = mock_model.inference
            # With max_batch_size=2, should split 10 images into more batches
            results = api.batch_inference(
                sample_images, batch_size="auto", max_batch_size=2
            )
            # Should have at least 5 batches (10 images / 2 max)
            assert len(results) >= 5
# =============================================================================
# get_optimal_batch_size Tests
# =============================================================================
class TestGetOptimalBatchSize:
    """Tests for the get_optimal_batch_size method."""

    @staticmethod
    def _bare_api(device, model_name):
        """Construct a DepthAnything3 shell without running its real __init__."""
        from depth_anything_3.api import DepthAnything3

        with patch.object(DepthAnything3, "__init__", lambda x, **k: None):
            api = DepthAnything3()
        api.device = device
        api.model_name = model_name
        return api

    def test_get_optimal_batch_size_returns_int(self, cpu_device):
        """The method must yield a strictly positive integer."""
        api = self._bare_api(cpu_device, "da3-large")
        batch = api.get_optimal_batch_size()
        assert isinstance(batch, int)
        assert batch > 0

    def test_get_optimal_batch_size_respects_resolution(self, cpu_device):
        """Low and high processing resolutions both give usable sizes."""
        api = self._bare_api(cpu_device, "da3-large")
        at_low_res = api.get_optimal_batch_size(process_res=256)
        at_high_res = api.get_optimal_batch_size(process_res=1024)
        # Both should be valid regardless of resolution.
        assert at_low_res > 0
        assert at_high_res > 0

    def test_get_optimal_batch_size_respects_utilization(self, cpu_device):
        """target_utilization at both extremes still yields valid sizes."""
        api = self._bare_api(cpu_device, "da3-large")
        conservative = api.get_optimal_batch_size(target_utilization=0.5)
        aggressive = api.get_optimal_batch_size(target_utilization=0.95)
        # Both should return valid results.
        assert conservative > 0
        assert aggressive > 0

    def test_get_optimal_batch_size_different_models(self, cpu_device):
        """Every known model variant must report a positive batch size."""
        for model_name in ["da3-small", "da3-base", "da3-large", "da3-giant"]:
            api = self._bare_api(cpu_device, model_name)
            result = api.get_optimal_batch_size()
            assert result > 0, f"Failed for model {model_name}"
# =============================================================================
# Integration Tests
# =============================================================================
class TestBatchingIntegration:
    """Integration tests for batching functionality."""

    def test_auto_vs_fixed_batching_coverage(self, mock_model, sample_images):
        """Test that both auto and fixed batching process all images."""
        from depth_anything_3.api import DepthAnything3
        with patch.object(DepthAnything3, "__init__", lambda x, **k: None):
            api = DepthAnything3()
            api.device = mock_model.device
            api.model_name = mock_model.model_name
            api.inference = mock_model.inference
            # Track images processed
            auto_images_processed = []
            fixed_images_processed = []

            def track_auto(image, **kwargs):
                # inference may receive a list (batch) or a single image.
                batch = image if isinstance(image, list) else [image]
                auto_images_processed.extend(batch)
                return MockPrediction.create(len(batch))

            def track_fixed(image, **kwargs):
                batch = image if isinstance(image, list) else [image]
                fixed_images_processed.extend(batch)
                return MockPrediction.create(len(batch))

            # Test auto batching
            mock_model.inference.side_effect = track_auto
            api.inference = mock_model.inference
            api.batch_inference(sample_images.copy(), batch_size="auto")
            # Test fixed batching (side_effect swapped on the same mock)
            mock_model.inference.side_effect = track_fixed
            api.inference = mock_model.inference
            api.batch_inference(sample_images.copy(), batch_size=3)
            # Both should process all images
            assert len(auto_images_processed) == len(sample_images)
            assert len(fixed_images_processed) == len(sample_images)

    def test_batch_inference_preserves_order(self, mock_model):
        """Test that batch_inference preserves image order in processing."""
        from depth_anything_3.api import DepthAnything3
        images = ["first.jpg", "second.jpg", "third.jpg", "fourth.jpg", "fifth.jpg"]
        processed_order = []

        def track_order(image, **kwargs):
            batch = image if isinstance(image, list) else [image]
            processed_order.extend(batch)
            return MockPrediction.create(len(batch))

        with patch.object(DepthAnything3, "__init__", lambda x, **k: None):
            api = DepthAnything3()
            api.device = mock_model.device
            api.model_name = mock_model.model_name
            mock_model.inference.side_effect = track_order
            api.inference = mock_model.inference
            api.batch_inference(images, batch_size=2)
            # Concatenated batches must equal the input sequence exactly.
            assert processed_order == images

    def test_progress_increases_monotonically(self, mock_model, sample_images):
        """Test that progress always increases."""
        from depth_anything_3.api import DepthAnything3
        progress_values = []

        def progress_callback(processed, total):
            progress_values.append(processed)

        with patch.object(DepthAnything3, "__init__", lambda x, **k: None):
            api = DepthAnything3()
            api.device = mock_model.device
            api.model_name = mock_model.model_name
            api.inference = mock_model.inference
            api.batch_inference(
                sample_images, batch_size=3, progress_callback=progress_callback
            )
            # Progress should always increase
            for i in range(1, len(progress_values)):
                assert progress_values[i] > progress_values[i - 1]
# =============================================================================
# Edge Cases
# =============================================================================
class TestBatchingEdgeCases:
    """Tests for edge cases in batching."""

    @staticmethod
    def _wire_api(mock_model):
        """Return a DepthAnything3 shell wired to the mocked inference."""
        from depth_anything_3.api import DepthAnything3

        with patch.object(DepthAnything3, "__init__", lambda x, **k: None):
            api = DepthAnything3()
        api.device = mock_model.device
        api.model_name = mock_model.model_name
        api.inference = mock_model.inference
        return api

    def test_batch_size_one(self, mock_model, sample_images):
        """Batch size 1 means one inference call and one result per image."""
        api = self._wire_api(mock_model)
        results = api.batch_inference(sample_images, batch_size=1)
        assert len(results) == len(sample_images)
        assert mock_model.inference.call_count == len(sample_images)

    def test_very_large_batch_size(self, mock_model, sample_images):
        """A batch size far above the image count collapses to one batch."""
        api = self._wire_api(mock_model)
        results = api.batch_inference(sample_images, batch_size=1000)
        assert len(results) == 1

    def test_auto_with_very_low_memory_utilization(self, mock_model, sample_images):
        """Auto batching must still cover every image at a tiny memory target."""
        api = self._wire_api(mock_model)
        results = api.batch_inference(
            sample_images, batch_size="auto", target_memory_utilization=0.1
        )
        covered = sum(r.num_images for r in results)
        assert covered == len(sample_images)

    def test_numpy_array_inputs(self, mock_model):
        """Raw numpy arrays batch the same way as path strings."""
        arrays = [np.zeros((256, 256, 3), dtype=np.uint8) for _ in range(5)]
        api = self._wire_api(mock_model)
        results = api.batch_inference(arrays, batch_size=2)
        # Five images with batch_size=2 -> batches of 2, 2, 1.
        assert len(results) == 3
# =============================================================================
# Memory Cleanup Tests
# =============================================================================
class TestMemoryCleanup:
    """Tests for memory cleanup during batching.

    These tests assert that ``batch_inference`` triggers a cleanup
    (garbage collection / device cache flush) between consecutive
    batches but not after the final one: 4 batches -> 3 cleanup calls.
    """

    def test_gc_collect_called_between_batches(self, mock_model, sample_images):
        """Test that garbage collection is called between batches."""
        from depth_anything_3.api import DepthAnything3
        with patch.object(DepthAnything3, "__init__", lambda x, **k: None):
            api = DepthAnything3()
            api.device = mock_model.device
            api.model_name = mock_model.model_name
            api.inference = mock_model.inference
            # Patching "gc.collect" on the gc module itself intercepts the
            # call regardless of how the API module imported it.
            with patch("gc.collect") as mock_gc:
                api.batch_inference(sample_images, batch_size=3)
                # Should call gc.collect between batches (not after last)
                # 4 batches means 3 gc.collect calls
                assert mock_gc.call_count == 3

    def test_cuda_empty_cache_called(self, sample_images):
        """Test that cuda empty_cache is called on CUDA device."""
        from depth_anything_3.api import DepthAnything3

        def mock_inference(image, **kwargs):
            num = len(image) if isinstance(image, list) else 1
            return MockPrediction.create(num)

        # Built inline (not via the fixture) so the device can be forced to
        # CUDA without a GPU actually being present.
        mock_model = MagicMock(spec=DepthAnything3)
        mock_model.device = torch.device("cuda:0")
        mock_model.model_name = "da3-large"
        mock_model.inference = MagicMock(side_effect=mock_inference)
        with patch.object(DepthAnything3, "__init__", lambda x, **k: None):
            api = DepthAnything3()
            api.device = mock_model.device
            api.model_name = mock_model.model_name
            api.inference = mock_model.inference
            with patch("torch.cuda.empty_cache") as mock_empty:
                api.batch_inference(sample_images, batch_size=3)
                # Should call empty_cache between batches
                assert mock_empty.call_count == 3

    def test_mps_empty_cache_called(self, sample_images):
        """Test that mps empty_cache is called on MPS device."""
        from depth_anything_3.api import DepthAnything3

        def mock_inference(image, **kwargs):
            num = len(image) if isinstance(image, list) else 1
            return MockPrediction.create(num)

        # Same inline setup, but pretending to run on an Apple MPS device.
        mock_model = MagicMock(spec=DepthAnything3)
        mock_model.device = torch.device("mps")
        mock_model.model_name = "da3-large"
        mock_model.inference = MagicMock(side_effect=mock_inference)
        with patch.object(DepthAnything3, "__init__", lambda x, **k: None):
            api = DepthAnything3()
            api.device = mock_model.device
            api.model_name = mock_model.model_name
            api.inference = mock_model.inference
            with patch("torch.mps.empty_cache") as mock_empty:
                api.batch_inference(sample_images, batch_size=3)
                assert mock_empty.call_count == 3
# =============================================================================
# Run tests
# =============================================================================
# Allow running this module directly: `python test_api_batching.py`.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])