# awesome-depth-anything-3 / tests/test_adaptive_batching.py
# Delanoe Pirard
# Deploy to HuggingFace Spaces
# 18b382b
# Copyright (c) Delanoe Pirard / Aedelon
# Licensed under the Apache License, Version 2.0
"""
Comprehensive tests for the adaptive batching module.
Tests cover:
- ModelMemoryProfile dataclass
- Memory utility functions
- AdaptiveBatchSizeCalculator
- BatchInfo and adaptive_batch_iterator
- High-level API functions
- Edge cases and error handling
"""
from __future__ import annotations
import os
from unittest.mock import MagicMock, patch
import pytest
import torch
from depth_anything_3.utils.adaptive_batching import (
MODEL_MEMORY_PROFILES,
AdaptiveBatchConfig,
AdaptiveBatchSizeCalculator,
BatchInfo,
ModelMemoryProfile,
adaptive_batch_iterator,
estimate_max_batch_size,
get_available_memory_mb,
get_total_memory_mb,
log_batch_plan,
process_with_adaptive_batching,
)
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def cpu_device():
    """Provide the host CPU as a ``torch.device``."""
    device = torch.device("cpu")
    return device
@pytest.fixture
def mock_cuda_device():
    """Provide a first-GPU CUDA device handle (no real GPU is needed to build it)."""
    # torch.device("cuda", 0) is equal to torch.device("cuda:0").
    return torch.device("cuda", 0)
@pytest.fixture
def mock_mps_device():
    """Provide an Apple MPS device handle for memory-query tests."""
    mps = torch.device("mps")
    return mps
@pytest.fixture
def default_config():
    """Provide an ``AdaptiveBatchConfig`` built with all defaults."""
    cfg = AdaptiveBatchConfig()
    return cfg
@pytest.fixture
def calculator_cpu(cpu_device):
    """Provide a da3-large batch-size calculator bound to the CPU device."""
    model = "da3-large"
    return AdaptiveBatchSizeCalculator(model, cpu_device)
# =============================================================================
# ModelMemoryProfile Tests
# =============================================================================
class TestModelMemoryProfile:
    """Unit tests for the ``ModelMemoryProfile`` dataclass and the profile registry."""

    def test_default_values(self):
        """Fields that are not passed fall back to the dataclass defaults."""
        prof = ModelMemoryProfile(base_memory_mb=1000, per_image_mb_at_504=500)

        assert prof.base_memory_mb == 1000
        assert prof.per_image_mb_at_504 == 500
        # Defaults declared on the dataclass itself.
        assert prof.activation_scale == 1.0
        assert prof.safety_margin == 0.15

    def test_custom_values(self):
        """Every explicitly passed field replaces its default."""
        overrides = {
            "base_memory_mb": 2000,
            "per_image_mb_at_504": 800,
            "activation_scale": 1.5,
            "safety_margin": 0.2,
        }
        prof = ModelMemoryProfile(**overrides)

        for field_name, wanted in overrides.items():
            assert getattr(prof, field_name) == wanted

    def test_all_models_have_profiles(self):
        """Each shipped model variant is registered with positive memory figures."""
        required = (
            "da3-small",
            "da3-base",
            "da3-large",
            "da3-giant",
            "da3metric-large",
            "da3mono-large",
            "da3nested-giant-large",
        )
        for name in required:
            assert name in MODEL_MEMORY_PROFILES
            prof = MODEL_MEMORY_PROFILES[name]
            assert prof.base_memory_mb > 0
            assert prof.per_image_mb_at_504 > 0

    def test_profiles_size_ordering(self):
        """Both memory figures grow strictly from small -> base -> large -> giant."""
        ladder = [
            MODEL_MEMORY_PROFILES[name]
            for name in ("da3-small", "da3-base", "da3-large", "da3-giant")
        ]
        neighbours = list(zip(ladder, ladder[1:]))

        # Fixed (weights) footprint first ...
        for smaller, bigger in neighbours:
            assert smaller.base_memory_mb < bigger.base_memory_mb
        # ... then the per-image activation cost.
        for smaller, bigger in neighbours:
            assert smaller.per_image_mb_at_504 < bigger.per_image_mb_at_504
# =============================================================================
# Memory Utility Tests
# =============================================================================
class TestGetAvailableMemory:
    """Tests for ``get_available_memory_mb`` on CPU, CUDA and MPS devices."""

    def test_cpu_returns_infinity(self, cpu_device):
        """The CPU backend reports unbounded available memory."""
        assert get_available_memory_mb(cpu_device) == float("inf")

    def test_cuda_memory_calculation(self, mock_cuda_device):
        """Available CUDA memory equals total device memory minus reserved memory."""
        gib = 1024 ** 3
        fake_props = MagicMock(total_memory=16 * gib)  # a 16 GiB card

        with patch("torch.cuda.is_available", return_value=True), \
                patch("torch.cuda.synchronize"), \
                patch("torch.cuda.get_device_properties", return_value=fake_props), \
                patch("torch.cuda.memory_reserved", return_value=4 * gib):
            available = get_available_memory_mb(mock_cuda_device)

        # (16 GiB total - 4 GiB reserved) expressed in MiB.
        assert available == (16 - 4) * 1024

    def test_mps_memory_with_env_var(self, mock_mps_device, monkeypatch):
        """DA3_MPS_MAX_MEMORY_GB sets the MPS memory budget."""
        monkeypatch.setenv("DA3_MPS_MAX_MEMORY_GB", "16")
        with patch("torch.mps.current_allocated_memory", return_value=0):
            budget = get_available_memory_mb(mock_mps_device)
        assert budget == 16 * 1024

    def test_mps_memory_default(self, mock_mps_device, monkeypatch):
        """Without the env var, the MPS budget falls back to 8 GB."""
        monkeypatch.delenv("DA3_MPS_MAX_MEMORY_GB", raising=False)
        with patch("torch.mps.current_allocated_memory", return_value=0):
            budget = get_available_memory_mb(mock_mps_device)
        assert budget == 8 * 1024

    def test_mps_memory_subtracts_allocated(self, mock_mps_device, monkeypatch):
        """Memory already allocated on MPS is deducted from the budget."""
        monkeypatch.setenv("DA3_MPS_MAX_MEMORY_GB", "8")
        two_gib = 2 * 1024 ** 3
        with patch("torch.mps.current_allocated_memory", return_value=two_gib):
            budget = get_available_memory_mb(mock_mps_device)
        # 8 GB budget - 2 GB in use = 6 GB left, in MiB.
        assert budget == (8 - 2) * 1024
class TestGetTotalMemory:
    """Tests for ``get_total_memory_mb`` on CPU, CUDA and MPS devices."""

    def test_cpu_returns_infinity(self, cpu_device):
        """The CPU backend reports unbounded total memory."""
        total = get_total_memory_mb(cpu_device)
        assert total == float("inf")

    def test_cuda_total_memory(self, mock_cuda_device):
        """Total CUDA memory comes from the device properties, converted to MiB."""
        props = MagicMock()
        props.total_memory = 24 * 1024 ** 3  # a 24 GiB card
        with patch("torch.cuda.get_device_properties", return_value=props):
            total = get_total_memory_mb(mock_cuda_device)
        assert total == 24 * 1024

    def test_mps_total_memory_env_var(self, mock_mps_device, monkeypatch):
        """DA3_MPS_MAX_MEMORY_GB sets the total MPS memory."""
        monkeypatch.setenv("DA3_MPS_MAX_MEMORY_GB", "32")
        total = get_total_memory_mb(mock_mps_device)
        assert total == 32 * 1024
# =============================================================================
# AdaptiveBatchConfig Tests
# =============================================================================
class TestAdaptiveBatchConfig:
    """Checks the defaults and overrides of ``AdaptiveBatchConfig``."""

    def test_default_values(self):
        """A bare config exposes the documented defaults."""
        cfg = AdaptiveBatchConfig()

        assert (cfg.min_batch_size, cfg.max_batch_size) == (1, 64)
        assert cfg.target_memory_utilization == 0.85
        assert cfg.enable_profiling is True
        assert cfg.profile_warmup_batches == 2

    def test_custom_values(self):
        """Every field passed to the constructor is stored unchanged."""
        cfg = AdaptiveBatchConfig(
            profile_warmup_batches=5,
            enable_profiling=False,
            target_memory_utilization=0.90,
            max_batch_size=32,
            min_batch_size=2,
        )

        assert (cfg.min_batch_size, cfg.max_batch_size) == (2, 32)
        assert cfg.target_memory_utilization == 0.90
        assert cfg.enable_profiling is False
        assert cfg.profile_warmup_batches == 5
# =============================================================================
# AdaptiveBatchSizeCalculator Tests
# =============================================================================
class TestAdaptiveBatchSizeCalculator:
    """Tests for AdaptiveBatchSizeCalculator class."""

    def test_initialization_known_model(self, cpu_device):
        """A known model name selects that model's memory profile."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)
        assert calc.model_name == "da3-large"
        assert calc.device == cpu_device
        assert calc.profile == MODEL_MEMORY_PROFILES["da3-large"]

    def test_initialization_unknown_model_uses_fallback(self, cpu_device):
        """An unknown model name falls back to the da3-large profile."""
        calc = AdaptiveBatchSizeCalculator("unknown-model", cpu_device)
        assert calc.profile == MODEL_MEMORY_PROFILES["da3-large"]

    def test_initialization_with_custom_config(self, cpu_device):
        """A caller-supplied config is stored on the calculator."""
        config = AdaptiveBatchConfig(max_batch_size=16)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)
        assert calc.config.max_batch_size == 16

    def test_compute_optimal_batch_size_cpu(self, cpu_device):
        """CPU should return min(num_images, max_batch_size)."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)
        # Fewer images than the cap: everything fits in one batch.
        assert calc.compute_optimal_batch_size(num_images=10) == 10
        # More images than the cap: clamp to the default max_batch_size (64).
        assert calc.compute_optimal_batch_size(num_images=100) == 64

    def test_compute_optimal_batch_size_respects_min(self, cpu_device):
        """On CPU, min_batch_size never inflates the batch beyond num_images.

        The CPU path returns ``min(num_images, max_batch_size)`` directly.
        ``min_batch_size`` is only applied after the GPU memory-based
        calculation, so 2 images with ``min_batch_size=4`` still give 2.
        """
        config = AdaptiveBatchConfig(min_batch_size=4)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)
        assert calc.compute_optimal_batch_size(num_images=2) == 2

    def test_compute_optimal_batch_size_respects_max(self, cpu_device):
        """Batch size should not exceed max_batch_size."""
        config = AdaptiveBatchConfig(max_batch_size=8)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)
        assert calc.compute_optimal_batch_size(num_images=100) == 8

    @patch("depth_anything_3.utils.adaptive_batching.get_available_memory_mb")
    def test_compute_optimal_batch_size_memory_based(
        self, mock_memory, mock_cuda_device
    ):
        """On a (mocked) GPU, available memory bounds the batch size."""
        mock_memory.return_value = 10000  # 10000 MB (~9.8 GiB) free
        calc = AdaptiveBatchSizeCalculator("da3-large", mock_cuda_device)
        result = calc.compute_optimal_batch_size(num_images=100, process_res=504)
        assert 1 <= result <= 64
        # Memory, not the image count, should be the limiting factor.
        assert result < 100

    @patch("depth_anything_3.utils.adaptive_batching.get_available_memory_mb")
    def test_compute_low_memory_returns_min(self, mock_memory, mock_cuda_device):
        """Low memory should return min batch size."""
        # 500 MB is below the base memory footprint of da3-large.
        mock_memory.return_value = 500
        calc = AdaptiveBatchSizeCalculator("da3-large", mock_cuda_device)
        result = calc.compute_optimal_batch_size(num_images=100)
        assert result == 1  # default min_batch_size

    def test_estimate_per_image_memory_resolution_scaling(self, cpu_device):
        """Per-image memory scales roughly quadratically with resolution."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)
        mem_504 = calc._estimate_per_image_memory(504)
        mem_1008 = calc._estimate_per_image_memory(1008)
        # Doubling the side length should roughly quadruple memory; the
        # tolerance absorbs activation_scale effects.
        ratio = mem_1008 / mem_504
        assert 3.5 <= ratio <= 4.5

    def test_update_from_profiling_warmup(self, cpu_device):
        """Warmup batches are ignored by profiling."""
        config = AdaptiveBatchConfig(profile_warmup_batches=2)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)
        # The first two batches are warmup and must not record a measurement.
        for _ in range(2):
            calc.update_from_profiling(batch_size=4, memory_used_mb=3000, process_res=504)
            assert calc._measured_per_image_mb is None
        # The third batch is the first real sample.
        calc.update_from_profiling(batch_size=4, memory_used_mb=3000, process_res=504)
        assert calc._measured_per_image_mb is not None

    def test_update_from_profiling_disabled(self, cpu_device):
        """With profiling disabled, no measurement is ever recorded."""
        config = AdaptiveBatchConfig(enable_profiling=False)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)
        for _ in range(5):
            calc.update_from_profiling(batch_size=4, memory_used_mb=3000, process_res=504)
        assert calc._measured_per_image_mb is None

    def test_update_from_profiling_ema(self, cpu_device):
        """Profiled per-image memory is smoothed toward newer measurements."""
        config = AdaptiveBatchConfig(profile_warmup_batches=0)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        calc.update_from_profiling(batch_size=4, memory_used_mb=4000, process_res=504)
        first_value = calc._measured_per_image_mb
        # No warmup, so the very first sample must already be recorded.
        assert first_value is not None

        calc.update_from_profiling(batch_size=4, memory_used_mb=5000, process_res=504)
        second_value = calc._measured_per_image_mb
        assert second_value is not None
        # A larger measurement must pull the running estimate upward.
        assert second_value > first_value

    def test_get_memory_estimate(self, cpu_device):
        """The estimate should include base memory plus a batch-dependent term."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)
        single = calc.get_memory_estimate(batch_size=1, process_res=504)
        estimate = calc.get_memory_estimate(batch_size=4, process_res=504)
        assert estimate > calc.profile.base_memory_mb
        # Per-image memory is positive, so more images must cost more.
        assert estimate > single
# =============================================================================
# BatchInfo Tests
# =============================================================================
class TestBatchInfo:
    """Checks construction and the derived ``batch_size`` of ``BatchInfo``."""

    def test_batch_info_creation(self):
        """Constructor fields round-trip and batch_size matches the item count."""
        payload = ["a", "b", "c"]
        info = BatchInfo(batch_idx=0, start_idx=0, end_idx=3, items=payload, is_last=True)

        assert (info.batch_idx, info.start_idx, info.end_idx) == (0, 0, 3)
        assert info.items == ["a", "b", "c"]
        assert info.batch_size == 3
        assert info.is_last is True

    def test_batch_size_computed_from_items(self):
        """batch_size is derived from the items rather than passed in."""
        info = BatchInfo(batch_idx=0, start_idx=0, end_idx=5, items=list(range(1, 6)))
        assert info.batch_size == 5

    def test_empty_batch(self):
        """An empty item list gives a batch_size of zero."""
        info = BatchInfo(batch_idx=0, start_idx=0, end_idx=0, items=[])
        assert info.batch_size == 0
# =============================================================================
# adaptive_batch_iterator Tests
# =============================================================================
class TestAdaptiveBatchIterator:
    """Tests for adaptive_batch_iterator function."""

    def test_single_batch(self, calculator_cpu):
        """All items land in one final batch when they fit under the cap."""
        items = list(range(10))
        batches = list(adaptive_batch_iterator(items, calculator_cpu))
        assert len(batches) == 1
        assert batches[0].items == items
        assert batches[0].is_last is True

    def test_multiple_batches(self, cpu_device):
        """Items are split into max_batch_size chunks plus a short remainder."""
        config = AdaptiveBatchConfig(max_batch_size=3)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)
        batches = list(adaptive_batch_iterator(list(range(10)), calc))
        # 10 items at 3 per batch -> 3, 3, 3, 1. Check every batch, not just the ends.
        assert [b.batch_size for b in batches] == [3, 3, 3, 1]
        assert [b.is_last for b in batches] == [False, False, False, True]
        assert batches[-1].is_last is True

    def test_batch_indices_are_correct(self, cpu_device):
        """Batch indices run 0, 1, 2, ... with no gaps."""
        config = AdaptiveBatchConfig(max_batch_size=2)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)
        batches = list(adaptive_batch_iterator(list(range(6)), calc))
        # Guard the count first so an empty result cannot pass vacuously.
        assert len(batches) == 3
        assert [b.batch_idx for b in batches] == list(range(len(batches)))

    def test_start_end_indices_cover_all_items(self, cpu_device):
        """Batches tile the input contiguously, with no gaps or overlaps."""
        config = AdaptiveBatchConfig(max_batch_size=3)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)
        items = list(range(10))
        batches = list(adaptive_batch_iterator(items, calc))
        prev_end = 0
        for batch in batches:
            assert batch.start_idx == prev_end
            assert batch.end_idx > batch.start_idx
            prev_end = batch.end_idx
        assert prev_end == len(items)

    def test_items_are_preserved(self, cpu_device):
        """Concatenating all batches reproduces the input in order."""
        config = AdaptiveBatchConfig(max_batch_size=4)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)
        original_items = ["a", "b", "c", "d", "e", "f", "g"]
        batches = list(adaptive_batch_iterator(original_items, calc))
        collected = [item for batch in batches for item in batch.items]
        assert collected == original_items

    def test_empty_sequence(self, calculator_cpu):
        """An empty input yields no batches."""
        batches = list(adaptive_batch_iterator([], calculator_cpu))
        assert len(batches) == 0

    def test_last_batch_flag(self, cpu_device):
        """Only the final batch has is_last=True."""
        config = AdaptiveBatchConfig(max_batch_size=2)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)
        batches = list(adaptive_batch_iterator(list(range(5)), calc))
        # 5 items at 2 per batch -> 2, 2, 1. Guards the loop below against vacuity.
        assert len(batches) == 3
        for batch in batches[:-1]:
            assert batch.is_last is False
        assert batches[-1].is_last is True
# =============================================================================
# process_with_adaptive_batching Tests
# =============================================================================
class TestProcessWithAdaptiveBatching:
    """Tests for process_with_adaptive_batching function."""

    def test_basic_processing(self, cpu_device):
        """List results from each batch are concatenated in input order."""
        items = list(range(10))

        def process_fn(batch):
            return [x * 2 for x in batch]

        results = process_with_adaptive_batching(
            items=items,
            process_fn=process_fn,
            model_name="da3-large",
            device=cpu_device,
        )
        assert results == [x * 2 for x in items]

    def test_progress_callback(self, cpu_device):
        """The progress callback reports cumulative progress once batches complete."""
        items = list(range(10))
        progress_calls = []

        def process_fn(batch):
            return batch

        def progress_callback(processed, total):
            progress_calls.append((processed, total))

        config = AdaptiveBatchConfig(max_batch_size=3)
        results = process_with_adaptive_batching(
            items=items,
            process_fn=process_fn,
            model_name="da3-large",
            device=cpu_device,
            config=config,
            progress_callback=progress_callback,
        )

        # process_fn is the identity, so the output must equal the input.
        assert results == items
        # Several batches should mean several progress reports.
        assert len(progress_calls) > 1
        # The processed count never goes backwards, and total is always the input size.
        processed_counts = [processed for processed, _ in progress_calls]
        assert processed_counts == sorted(processed_counts)
        assert all(total == len(items) for _, total in progress_calls)
        # The final report covers every item.
        assert progress_calls[-1] == (len(items), len(items))

    def test_single_result_handling(self, cpu_device):
        """A process_fn that returns a non-list value is still collected."""
        items = list(range(5))

        def process_fn(batch):
            # Return a single value instead of a list.
            return sum(batch)

        results = process_with_adaptive_batching(
            items=items,
            process_fn=process_fn,
            model_name="da3-large",
            device=cpu_device,
        )
        assert isinstance(results, list)
        # The scalar result must not be silently dropped.
        assert len(results) >= 1

    def test_empty_items(self, cpu_device):
        """An empty input produces an empty result list."""
        results = process_with_adaptive_batching(
            items=[],
            process_fn=lambda x: x,
            model_name="da3-large",
            device=cpu_device,
        )
        assert results == []
# =============================================================================
# Utility Function Tests
# =============================================================================
class TestEstimateMaxBatchSize:
    """Tests for estimate_max_batch_size function."""

    def test_returns_positive_integer(self, cpu_device):
        """The estimate is a positive ``int``."""
        result = estimate_max_batch_size("da3-large", cpu_device)
        assert isinstance(result, int)
        assert result > 0

    def test_different_resolutions(self, cpu_device):
        """Any resolution gives a valid size, and a higher one never gives more.

        For CPU the resolution is not expected to constrain the result, so
        both values should match. On a GPU the high-resolution value should be
        the smaller one, so ``high_res <= low_res`` holds in either case.
        """
        low_res = estimate_max_batch_size("da3-large", cpu_device, process_res=504)
        high_res = estimate_max_batch_size("da3-large", cpu_device, process_res=1008)
        assert low_res > 0
        assert high_res > 0
        assert high_res <= low_res

    def test_different_utilization(self, cpu_device):
        """Any target utilization gives a valid size, and a higher one never gives fewer."""
        low_util = estimate_max_batch_size(
            "da3-large", cpu_device, target_utilization=0.5
        )
        high_util = estimate_max_batch_size(
            "da3-large", cpu_device, target_utilization=0.95
        )
        # Both should be valid (CPU returns max_batch_size anyway).
        assert low_util > 0
        assert high_util > 0
        assert low_util <= high_util
class TestLogBatchPlan:
    """Smoke tests: ``log_batch_plan`` must complete without raising."""

    def test_log_batch_plan_runs(self, cpu_device, caplog):
        """Logging a plan for 100 images on CPU does not raise."""
        import logging

        plan_kwargs = dict(
            num_images=100,
            model_name="da3-large",
            device=cpu_device,
            process_res=504,
        )
        with caplog.at_level(logging.INFO):
            log_batch_plan(**plan_kwargs)

    def test_log_batch_plan_different_models(self, cpu_device):
        """Every size variant of the DA3 family can be planned without error."""
        variants = ("da3-small", "da3-base", "da3-large", "da3-giant")
        for variant in variants:
            log_batch_plan(num_images=50, model_name=variant, device=cpu_device)
# =============================================================================
# Integration Tests
# =============================================================================
class TestIntegration:
    """Integration tests for the adaptive batching module."""

    def test_full_workflow_cpu(self, cpu_device):
        """End to end on CPU: every image is processed once, in order, within the cap."""
        images = [f"image_{i}.jpg" for i in range(25)]
        batch_sizes = []

        def process_fn(batch):
            batch_sizes.append(len(batch))
            return [f"result_{item}" for item in batch]

        config = AdaptiveBatchConfig(max_batch_size=8)
        results = process_with_adaptive_batching(
            items=images,
            process_fn=process_fn,
            model_name="da3-large",
            device=cpu_device,
            config=config,
        )

        # Results must map one-to-one, in order, onto the inputs.
        assert results == [f"result_{image}" for image in images]
        # Each image is processed exactly once, never more than the cap per batch.
        assert sum(batch_sizes) == len(images)
        assert max(batch_sizes) <= 8

    def test_calculator_reuse(self, cpu_device):
        """One calculator gives consistent answers across repeated calls."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)
        batch1 = calc.compute_optimal_batch_size(num_images=100)
        batch2 = calc.compute_optimal_batch_size(num_images=50)
        assert batch1 == 64  # max_batch_size for CPU
        assert batch2 == 50  # min(50, max_batch_size)

    def test_iterator_with_strings(self, cpu_device):
        """The iterator preserves string items such as file paths."""
        config = AdaptiveBatchConfig(max_batch_size=3)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)
        items = [
            "path/to/image1.jpg",
            "path/to/image2.jpg",
            "path/to/image3.jpg",
            "path/to/image4.jpg",
        ]
        batches = list(adaptive_batch_iterator(items, calc))
        all_paths = [path for batch in batches for path in batch.items]
        assert all_paths == items

    def test_iterator_with_tuples(self, cpu_device):
        """The iterator preserves tuple items and their structure."""
        config = AdaptiveBatchConfig(max_batch_size=2)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)
        items = [(1, "a"), (2, "b"), (3, "c")]
        batches = list(adaptive_batch_iterator(items, calc))
        all_items = [item for batch in batches for item in batch.items]
        assert all_items == items
# =============================================================================
# Edge Cases
# =============================================================================
class TestEdgeCases:
    """Tests for edge cases and boundary conditions."""

    def test_single_image(self, cpu_device):
        """A single image yields one batch of size 1, flagged as last."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)
        assert calc.compute_optimal_batch_size(num_images=1) == 1

        batches = list(adaptive_batch_iterator(["single"], calc))
        assert len(batches) == 1
        assert batches[0].items == ["single"]
        assert batches[0].is_last is True

    def test_exact_batch_size_multiple(self, cpu_device):
        """When num_images is a multiple of the batch size, there is no remainder batch."""
        config = AdaptiveBatchConfig(max_batch_size=5)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)
        batches = list(adaptive_batch_iterator(list(range(15)), calc))  # exactly 3 x 5
        assert len(batches) == 3
        assert all(b.batch_size == 5 for b in batches)

    def test_very_large_num_images(self, cpu_device):
        """A huge image count is clamped to max_batch_size."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)
        assert calc.compute_optimal_batch_size(num_images=1_000_000) == 64

    def test_zero_reserved_memory(self, cpu_device):
        """reserved_memory_mb=0 is accepted and gives a valid batch size."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)
        result = calc.compute_optimal_batch_size(
            num_images=100,
            process_res=504,
            reserved_memory_mb=0,
        )
        assert result > 0

    def test_high_resolution(self, cpu_device):
        """A very high (4K-class) resolution still gives a valid batch size."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)
        result = calc.compute_optimal_batch_size(num_images=100, process_res=2160)
        assert result > 0

    def test_low_resolution(self, cpu_device):
        """A very low resolution still gives a valid batch size."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)
        result = calc.compute_optimal_batch_size(num_images=100, process_res=128)
        assert result > 0

    def test_negative_memory_edge_case(self, mock_cuda_device):
        """A tiny memory budget must clamp to a valid size, not go negative.

        The CPU path returns ``min(num_images, max_batch_size)`` without
        consulting memory, so a CPU device would never reach this case. A
        mocked GPU memory reading is used instead. A 1% utilization target
        presumably leaves less headroom than the model's base footprint.
        """
        config = AdaptiveBatchConfig(
            min_batch_size=1,
            target_memory_utilization=0.01,  # very low utilization
        )
        with patch(
            "depth_anything_3.utils.adaptive_batching.get_available_memory_mb",
            return_value=10_000,
        ):
            calc = AdaptiveBatchSizeCalculator("da3-large", mock_cuda_device, config)
            result = calc.compute_optimal_batch_size(num_images=100)
        assert config.min_batch_size <= result <= config.max_batch_size
# =============================================================================
# Run tests
# =============================================================================
if __name__ == "__main__":
pytest.main([__file__, "-v"])