File size: 30,942 Bytes
18b382b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
# Copyright (c) Delanoe Pirard / Aedelon
# Licensed under the Apache License, Version 2.0
"""
Comprehensive tests for the adaptive batching module.

Tests cover:
- ModelMemoryProfile dataclass
- Memory utility functions
- AdaptiveBatchSizeCalculator
- BatchInfo and adaptive_batch_iterator
- High-level API functions
- Edge cases and error handling
"""
from __future__ import annotations

import os
from unittest.mock import MagicMock, patch

import pytest
import torch

from depth_anything_3.utils.adaptive_batching import (
    MODEL_MEMORY_PROFILES,
    AdaptiveBatchConfig,
    AdaptiveBatchSizeCalculator,
    BatchInfo,
    ModelMemoryProfile,
    adaptive_batch_iterator,
    estimate_max_batch_size,
    get_available_memory_mb,
    get_total_memory_mb,
    log_batch_plan,
    process_with_adaptive_batching,
)


# =============================================================================
# Fixtures
# =============================================================================


@pytest.fixture
def cpu_device():
    """Provide a ``torch.device`` handle for the host CPU."""
    device = torch.device("cpu")
    return device


@pytest.fixture
def mock_cuda_device():
    """Provide a ``cuda:0`` device handle.

    Constructing a ``torch.device`` does not touch any GPU. Tests that patch
    the CUDA memory queries can therefore use this handle on CUDA-less hosts.
    """
    device = torch.device("cuda:0")
    return device


@pytest.fixture
def mock_mps_device():
    """Provide an ``mps`` device handle.

    Only the handle is built here; the MPS memory queries are patched by the
    tests that consume it.
    """
    device = torch.device("mps")
    return device


@pytest.fixture
def default_config():
    """Provide an ``AdaptiveBatchConfig`` built purely from its defaults."""
    config = AdaptiveBatchConfig()
    return config


@pytest.fixture
def calculator_cpu(cpu_device):
    """Provide a ``da3-large`` batch-size calculator bound to the CPU device."""
    calculator = AdaptiveBatchSizeCalculator("da3-large", cpu_device)
    return calculator


# =============================================================================
# ModelMemoryProfile Tests
# =============================================================================


class TestModelMemoryProfile:
    """Tests for the ModelMemoryProfile dataclass and the built-in profile table."""

    def test_default_values(self):
        """Only the two required fields are given; the optional ones use defaults."""
        prof = ModelMemoryProfile(base_memory_mb=1000, per_image_mb_at_504=500)

        assert prof.base_memory_mb == 1000
        assert prof.per_image_mb_at_504 == 500
        # Optional fields fall back to their dataclass defaults.
        assert prof.activation_scale == 1.0
        assert prof.safety_margin == 0.15

    def test_custom_values(self):
        """Every explicitly supplied field overrides its default."""
        prof = ModelMemoryProfile(
            base_memory_mb=2000,
            per_image_mb_at_504=800,
            activation_scale=1.5,
            safety_margin=0.2,
        )

        assert prof.base_memory_mb == 2000
        assert prof.per_image_mb_at_504 == 800
        assert prof.activation_scale == 1.5
        assert prof.safety_margin == 0.2

    def test_all_models_have_profiles(self):
        """Each supported model name maps to a profile with positive costs."""
        for name in (
            "da3-small",
            "da3-base",
            "da3-large",
            "da3-giant",
            "da3metric-large",
            "da3mono-large",
            "da3nested-giant-large",
        ):
            assert name in MODEL_MEMORY_PROFILES
            entry = MODEL_MEMORY_PROFILES[name]
            assert entry.base_memory_mb > 0
            assert entry.per_image_mb_at_504 > 0

    def test_profiles_size_ordering(self):
        """Both memory costs grow strictly from small -> base -> large -> giant."""
        ordered = [
            MODEL_MEMORY_PROFILES[name]
            for name in ("da3-small", "da3-base", "da3-large", "da3-giant")
        ]
        neighbours = list(zip(ordered, ordered[1:]))

        # Base memory must increase with model size ...
        for smaller, bigger in neighbours:
            assert smaller.base_memory_mb < bigger.base_memory_mb

        # ... and so must the per-image cost at the reference resolution.
        for smaller, bigger in neighbours:
            assert smaller.per_image_mb_at_504 < bigger.per_image_mb_at_504


# =============================================================================
# Memory Utility Tests
# =============================================================================


class TestGetAvailableMemory:
    """Tests for get_available_memory_mb on CPU, CUDA and MPS devices."""

    _GIB = 1024 ** 3  # bytes per GiB, used to build mocked byte counts

    def test_cpu_returns_infinity(self, cpu_device):
        """A CPU device reports unbounded available memory."""
        assert get_available_memory_mb(cpu_device) == float("inf")

    @patch("torch.cuda.is_available", return_value=True)
    @patch("torch.cuda.synchronize")
    @patch("torch.cuda.get_device_properties")
    @patch("torch.cuda.memory_reserved")
    def test_cuda_memory_calculation(
        self,
        reserved_mock,
        props_mock,
        sync_mock,
        available_mock,
        mock_cuda_device,
    ):
        """Available CUDA memory is total device memory minus reserved memory."""
        # Stacked @patch decorators bind bottom-up: memory_reserved comes first.
        props_mock.return_value = MagicMock(total_memory=16 * self._GIB)
        reserved_mock.return_value = 4 * self._GIB

        # 16 GiB total - 4 GiB reserved = 12 GiB = 12288 MiB
        assert get_available_memory_mb(mock_cuda_device) == 12 * 1024

    def test_mps_memory_with_env_var(self, mock_mps_device, monkeypatch):
        """DA3_MPS_MAX_MEMORY_GB sets the MPS memory budget."""
        monkeypatch.setenv("DA3_MPS_MAX_MEMORY_GB", "16")

        with patch("torch.mps.current_allocated_memory", return_value=0):
            available = get_available_memory_mb(mock_mps_device)
        assert available == 16 * 1024

    def test_mps_memory_default(self, mock_mps_device, monkeypatch):
        """Without the env var, MPS falls back to an 8 GB budget."""
        monkeypatch.delenv("DA3_MPS_MAX_MEMORY_GB", raising=False)

        with patch("torch.mps.current_allocated_memory", return_value=0):
            available = get_available_memory_mb(mock_mps_device)
        assert available == 8 * 1024

    def test_mps_memory_subtracts_allocated(self, mock_mps_device, monkeypatch):
        """Memory already allocated on MPS is removed from the budget."""
        monkeypatch.setenv("DA3_MPS_MAX_MEMORY_GB", "8")

        with patch(
            "torch.mps.current_allocated_memory", return_value=2 * self._GIB
        ):
            available = get_available_memory_mb(mock_mps_device)
        # 8 GB budget - 2 GB allocated = 6 GB left
        assert available == 6 * 1024


class TestGetTotalMemory:
    """Tests for get_total_memory_mb on CPU, CUDA and MPS devices."""

    def test_cpu_returns_infinity(self, cpu_device):
        """A CPU device reports unbounded total memory."""
        assert get_total_memory_mb(cpu_device) == float("inf")

    @patch("torch.cuda.get_device_properties")
    def test_cuda_total_memory(self, props_mock, mock_cuda_device):
        """Total CUDA memory is read from the device properties, reported in MiB."""
        props_mock.return_value = MagicMock(total_memory=24 * 1024 ** 3)

        total_mb = get_total_memory_mb(mock_cuda_device)
        assert total_mb == 24 * 1024

    def test_mps_total_memory_env_var(self, mock_mps_device, monkeypatch):
        """DA3_MPS_MAX_MEMORY_GB defines the MPS total memory."""
        monkeypatch.setenv("DA3_MPS_MAX_MEMORY_GB", "32")

        total_mb = get_total_memory_mb(mock_mps_device)
        assert total_mb == 32 * 1024


# =============================================================================
# AdaptiveBatchConfig Tests
# =============================================================================


class TestAdaptiveBatchConfig:
    """Tests for the AdaptiveBatchConfig dataclass."""

    def test_default_values(self):
        """A bare AdaptiveBatchConfig carries the expected defaults."""
        cfg = AdaptiveBatchConfig()

        assert (cfg.min_batch_size, cfg.max_batch_size) == (1, 64)
        assert cfg.target_memory_utilization == 0.85
        assert cfg.enable_profiling is True
        assert cfg.profile_warmup_batches == 2

    def test_custom_values(self):
        """Every keyword argument overrides the matching default."""
        cfg = AdaptiveBatchConfig(
            min_batch_size=2,
            max_batch_size=32,
            target_memory_utilization=0.90,
            enable_profiling=False,
            profile_warmup_batches=5,
        )

        assert (cfg.min_batch_size, cfg.max_batch_size) == (2, 32)
        assert cfg.target_memory_utilization == 0.90
        assert cfg.enable_profiling is False
        assert cfg.profile_warmup_batches == 5


# =============================================================================
# AdaptiveBatchSizeCalculator Tests
# =============================================================================


class TestAdaptiveBatchSizeCalculator:
    """Tests for AdaptiveBatchSizeCalculator class."""

    def test_initialization_known_model(self, cpu_device):
        """Test initialization with known model."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)
        assert calc.model_name == "da3-large"
        assert calc.device == cpu_device
        assert calc.profile == MODEL_MEMORY_PROFILES["da3-large"]

    def test_initialization_unknown_model_uses_fallback(self, cpu_device):
        """Test initialization with unknown model falls back to da3-large."""
        calc = AdaptiveBatchSizeCalculator("unknown-model", cpu_device)
        assert calc.profile == MODEL_MEMORY_PROFILES["da3-large"]

    def test_initialization_with_custom_config(self, cpu_device):
        """Test initialization with custom config."""
        config = AdaptiveBatchConfig(max_batch_size=16)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)
        assert calc.config.max_batch_size == 16

    def test_compute_optimal_batch_size_cpu(self, cpu_device):
        """CPU should return min(num_images, max_batch_size)."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)

        # Small number of images
        result = calc.compute_optimal_batch_size(num_images=10)
        assert result == 10

        # Large number of images
        result = calc.compute_optimal_batch_size(num_images=100)
        assert result == 64  # max_batch_size

    def test_compute_optimal_batch_size_respects_min(self, cpu_device):
        """On CPU, min_batch_size never pads a batch beyond num_images.

        The CPU path returns ``min(num_images, max_batch_size)`` directly. With
        ``min_batch_size=4`` and only two images, the result is therefore 2.
        """
        config = AdaptiveBatchConfig(min_batch_size=4)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        result = calc.compute_optimal_batch_size(num_images=2)
        # min_batch_size is only applied on the memory-based (GPU) path.
        assert result == 2

    def test_compute_optimal_batch_size_respects_max(self, cpu_device):
        """Batch size should not exceed max_batch_size."""
        config = AdaptiveBatchConfig(max_batch_size=8)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        result = calc.compute_optimal_batch_size(num_images=100)
        assert result == 8

    @patch("depth_anything_3.utils.adaptive_batching.get_available_memory_mb")
    def test_compute_optimal_batch_size_memory_based(
        self, mock_memory, mock_cuda_device
    ):
        """Test memory-based batch size calculation."""
        # 10GB available memory
        mock_memory.return_value = 10000

        calc = AdaptiveBatchSizeCalculator("da3-large", mock_cuda_device)

        result = calc.compute_optimal_batch_size(num_images=100, process_res=504)

        # Should compute based on memory
        assert 1 <= result <= 64
        assert result < 100  # Should be less than num_images given memory constraints

    @patch("depth_anything_3.utils.adaptive_batching.get_available_memory_mb")
    def test_compute_low_memory_returns_min(self, mock_memory, mock_cuda_device):
        """Low memory should return min batch size."""
        # Only 500MB available (less than base memory for da3-large)
        mock_memory.return_value = 500

        calc = AdaptiveBatchSizeCalculator("da3-large", mock_cuda_device)
        result = calc.compute_optimal_batch_size(num_images=100)

        assert result == 1  # min_batch_size

    def test_estimate_per_image_memory_resolution_scaling(self, cpu_device):
        """Test that memory scales quadratically with resolution."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)

        mem_504 = calc._estimate_per_image_memory(504)
        mem_1008 = calc._estimate_per_image_memory(1008)

        # Memory at 2x resolution should be ~4x (quadratic scaling)
        ratio = mem_1008 / mem_504
        assert 3.5 <= ratio <= 4.5  # Allow some tolerance for activation_scale

    def test_update_from_profiling_warmup(self, cpu_device):
        """Test that warmup batches are skipped during profiling."""
        config = AdaptiveBatchConfig(profile_warmup_batches=2)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        # First two batches (warmup) should be skipped
        calc.update_from_profiling(batch_size=4, memory_used_mb=3000, process_res=504)
        assert calc._measured_per_image_mb is None

        calc.update_from_profiling(batch_size=4, memory_used_mb=3000, process_res=504)
        assert calc._measured_per_image_mb is None

        # Third batch should update
        calc.update_from_profiling(batch_size=4, memory_used_mb=3000, process_res=504)
        assert calc._measured_per_image_mb is not None

    def test_update_from_profiling_disabled(self, cpu_device):
        """Test that profiling can be disabled."""
        config = AdaptiveBatchConfig(enable_profiling=False)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        for _ in range(5):
            calc.update_from_profiling(batch_size=4, memory_used_mb=3000, process_res=504)

        assert calc._measured_per_image_mb is None

    def test_update_from_profiling_ema(self, cpu_device):
        """A heavier profiling sample moves the moving average upward."""
        config = AdaptiveBatchConfig(profile_warmup_batches=0)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        # First update: with no warmup, the very first sample must be recorded.
        calc.update_from_profiling(batch_size=4, memory_used_mb=4000, process_res=504)
        first_value = calc._measured_per_image_mb
        assert first_value is not None

        # Second update with a larger memory reading at the same batch/resolution.
        calc.update_from_profiling(batch_size=4, memory_used_mb=5000, process_res=504)
        second_value = calc._measured_per_image_mb

        # An exponential moving average blended toward a larger sample increases.
        assert second_value is not None
        assert second_value > first_value

    def test_get_memory_estimate(self, cpu_device):
        """The estimate exceeds base memory and grows with batch size and resolution."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)

        estimate = calc.get_memory_estimate(batch_size=4, process_res=504)

        # Per-image memory is added on top of the model's base memory.
        assert estimate > calc.profile.base_memory_mb

        # More images, or larger images, must require more memory.
        assert calc.get_memory_estimate(batch_size=8, process_res=504) > estimate
        assert calc.get_memory_estimate(batch_size=4, process_res=1008) > estimate


# =============================================================================
# BatchInfo Tests
# =============================================================================


class TestBatchInfo:
    """Tests for the BatchInfo dataclass."""

    def test_batch_info_creation(self):
        """All constructor fields are stored, and batch_size reflects the items."""
        info = BatchInfo(
            batch_idx=0,
            start_idx=0,
            end_idx=3,
            items=["a", "b", "c"],
            is_last=True,
        )

        assert (info.batch_idx, info.start_idx, info.end_idx) == (0, 0, 3)
        assert info.items == ["a", "b", "c"]
        assert info.batch_size == 3
        assert info.is_last is True

    def test_batch_size_computed_from_items(self):
        """batch_size is derived from the number of items supplied."""
        payload = [1, 2, 3, 4, 5]
        info = BatchInfo(batch_idx=0, start_idx=0, end_idx=5, items=payload)
        assert info.batch_size == len(payload)

    def test_empty_batch(self):
        """A batch with no items reports a size of zero."""
        info = BatchInfo(batch_idx=0, start_idx=0, end_idx=0, items=[])
        assert info.batch_size == 0


# =============================================================================
# adaptive_batch_iterator Tests
# =============================================================================


class TestAdaptiveBatchIterator:
    """Tests for adaptive_batch_iterator function."""

    def test_single_batch(self, calculator_cpu):
        """Test single batch when all items fit."""
        items = list(range(10))
        batches = list(adaptive_batch_iterator(items, calculator_cpu))

        assert len(batches) == 1
        assert batches[0].items == items
        assert batches[0].is_last is True

    def test_multiple_batches(self, cpu_device):
        """Ten items with max_batch_size=3 split into batches of 3, 3, 3, 1."""
        config = AdaptiveBatchConfig(max_batch_size=3)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        items = list(range(10))
        batches = list(adaptive_batch_iterator(items, calc))

        # Pin every batch size, not just the first and the last.
        assert [batch.batch_size for batch in batches] == [3, 3, 3, 1]
        assert batches[-1].is_last is True

    def test_batch_indices_are_correct(self, cpu_device):
        """Test that batch indices are sequential."""
        config = AdaptiveBatchConfig(max_batch_size=2)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        items = list(range(6))
        batches = list(adaptive_batch_iterator(items, calc))

        # Guard against the index loop below passing vacuously on zero batches.
        assert len(batches) == 3
        for i, batch in enumerate(batches):
            assert batch.batch_idx == i

    def test_start_end_indices_cover_all_items(self, cpu_device):
        """Test that batches cover all items without gaps."""
        config = AdaptiveBatchConfig(max_batch_size=3)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        items = list(range(10))
        batches = list(adaptive_batch_iterator(items, calc))

        # Verify no gaps
        prev_end = 0
        for batch in batches:
            assert batch.start_idx == prev_end
            assert batch.end_idx > batch.start_idx
            prev_end = batch.end_idx

        assert prev_end == len(items)

    def test_items_are_preserved(self, cpu_device):
        """Test that all items are preserved in batches."""
        config = AdaptiveBatchConfig(max_batch_size=4)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        original_items = ["a", "b", "c", "d", "e", "f", "g"]
        batches = list(adaptive_batch_iterator(original_items, calc))

        # Collect all items from batches
        collected = []
        for batch in batches:
            collected.extend(batch.items)

        assert collected == original_items

    def test_empty_sequence(self, calculator_cpu):
        """Test empty sequence returns no batches."""
        batches = list(adaptive_batch_iterator([], calculator_cpu))
        assert len(batches) == 0

    def test_last_batch_flag(self, cpu_device):
        """Test that only last batch has is_last=True."""
        config = AdaptiveBatchConfig(max_batch_size=2)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        items = list(range(5))
        batches = list(adaptive_batch_iterator(items, calc))

        # 5 items at max 2 per batch -> 2, 2, 1
        assert len(batches) == 3

        # All but last should be False
        for batch in batches[:-1]:
            assert batch.is_last is False

        # Last should be True
        assert batches[-1].is_last is True


# =============================================================================
# process_with_adaptive_batching Tests
# =============================================================================


class TestProcessWithAdaptiveBatching:
    """Tests for process_with_adaptive_batching function."""

    def test_basic_processing(self, cpu_device):
        """Test basic batch processing."""
        items = list(range(10))

        def process_fn(batch):
            return [x * 2 for x in batch]

        results = process_with_adaptive_batching(
            items=items,
            process_fn=process_fn,
            model_name="da3-large",
            device=cpu_device,
        )

        assert results == [x * 2 for x in items]

    def test_progress_callback(self, cpu_device):
        """Progress is reported cumulatively, and the outputs match the inputs."""
        items = list(range(10))
        progress_calls = []

        def process_fn(batch):
            return batch

        def progress_callback(processed, total):
            progress_calls.append((processed, total))

        config = AdaptiveBatchConfig(max_batch_size=3)

        results = process_with_adaptive_batching(
            items=items,
            process_fn=process_fn,
            model_name="da3-large",
            device=cpu_device,
            config=config,
            progress_callback=progress_callback,
        )

        # process_fn is the identity, so results must equal the inputs in order.
        assert results == items

        # Should have multiple progress calls
        assert len(progress_calls) > 1

        # The processed count is cumulative, so it never decreases ...
        processed_counts = [processed for processed, _ in progress_calls]
        assert processed_counts == sorted(processed_counts)
        # ... and every call reports the same total.
        assert all(total == len(items) for _, total in progress_calls)

        # Last call should show all items processed
        assert progress_calls[-1][0] == len(items)
        assert progress_calls[-1][1] == len(items)

    def test_single_result_handling(self, cpu_device):
        """Test handling of non-list results."""
        items = list(range(5))

        def process_fn(batch):
            # Return a single value instead of list
            return sum(batch)

        results = process_with_adaptive_batching(
            items=items,
            process_fn=process_fn,
            model_name="da3-large",
            device=cpu_device,
        )

        # Should still work and return list of results.
        # NOTE(review): only the container type is checked. How a scalar batch
        # result is folded into the output is not pinned here; confirm against
        # process_with_adaptive_batching before tightening.
        assert isinstance(results, list)

    def test_empty_items(self, cpu_device):
        """Test with empty items list."""
        results = process_with_adaptive_batching(
            items=[],
            process_fn=lambda x: x,
            model_name="da3-large",
            device=cpu_device,
        )
        assert results == []


# =============================================================================
# Utility Function Tests
# =============================================================================


class TestEstimateMaxBatchSize:
    """Tests for estimate_max_batch_size function."""

    def test_returns_positive_integer(self, cpu_device):
        """Test that function returns positive integer."""
        result = estimate_max_batch_size("da3-large", cpu_device)
        assert isinstance(result, int)
        assert result > 0

    def test_different_resolutions(self, cpu_device):
        """Higher resolution never yields a larger batch size.

        On GPU the batch is expected to shrink as resolution grows. On CPU,
        memory is treated as unbounded, so both calls should agree. Either
        way, the ordering must hold.
        """
        low_res = estimate_max_batch_size("da3-large", cpu_device, process_res=504)
        high_res = estimate_max_batch_size("da3-large", cpu_device, process_res=1008)

        # Both should be valid
        assert low_res > 0
        assert high_res > 0
        assert high_res <= low_res

    def test_different_utilization(self, cpu_device):
        """A higher target utilization never yields a smaller batch size."""
        low_util = estimate_max_batch_size(
            "da3-large", cpu_device, target_utilization=0.5
        )
        high_util = estimate_max_batch_size(
            "da3-large", cpu_device, target_utilization=0.95
        )

        # Both should be valid (CPU returns max_batch_size anyway)
        assert low_util > 0
        assert high_util > 0
        assert low_util <= high_util


class TestLogBatchPlan:
    """Smoke tests for log_batch_plan: it must complete without raising."""

    def test_log_batch_plan_runs(self, cpu_device, caplog):
        """log_batch_plan completes while INFO-level records are captured."""
        import logging

        plan_kwargs = dict(
            num_images=100,
            model_name="da3-large",
            device=cpu_device,
            process_res=504,
        )
        info_level = logging.INFO
        with caplog.at_level(info_level):
            # Reaching the end of this block without an exception is the check.
            log_batch_plan(**plan_kwargs)

    def test_log_batch_plan_different_models(self, cpu_device):
        """log_batch_plan accepts every standard model size."""
        model_names = ("da3-small", "da3-base", "da3-large", "da3-giant")
        for name in model_names:
            log_batch_plan(num_images=50, model_name=name, device=cpu_device)


# =============================================================================
# Integration Tests
# =============================================================================


class TestIntegration:
    """Integration tests for the adaptive batching module."""

    def test_full_workflow_cpu(self, cpu_device):
        """End-to-end CPU run: outputs keep input order and batches respect the cap."""
        # Create data
        images = [f"image_{i}.jpg" for i in range(25)]

        # Track processing
        processed_batches = []

        def process_fn(batch):
            processed_batches.append(len(batch))
            return [f"result_{item}" for item in batch]

        # Process with adaptive batching
        config = AdaptiveBatchConfig(max_batch_size=8)
        results = process_with_adaptive_batching(
            items=images,
            process_fn=process_fn,
            model_name="da3-large",
            device=cpu_device,
            config=config,
        )

        # Every input maps to exactly one output, in the original order.
        assert results == [f"result_{image}" for image in images]

        # Verify batching
        assert sum(processed_batches) == len(images)
        assert max(processed_batches) <= 8

    def test_calculator_reuse(self, cpu_device):
        """Test that calculator can be reused across multiple iterations."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)

        # First computation
        batch1 = calc.compute_optimal_batch_size(num_images=100)

        # Second computation should work
        batch2 = calc.compute_optimal_batch_size(num_images=50)

        assert batch1 == 64  # max_batch_size for CPU
        assert batch2 == 50  # min(50, max_batch_size)

    def test_iterator_with_strings(self, cpu_device):
        """Test iterator works with string items."""
        config = AdaptiveBatchConfig(max_batch_size=3)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        items = [
            "path/to/image1.jpg",
            "path/to/image2.jpg",
            "path/to/image3.jpg",
            "path/to/image4.jpg",
        ]

        batches = list(adaptive_batch_iterator(items, calc))

        # Collect all paths
        all_paths = []
        for batch in batches:
            all_paths.extend(batch.items)

        assert all_paths == items

    def test_iterator_with_tuples(self, cpu_device):
        """Test iterator works with tuple items."""
        config = AdaptiveBatchConfig(max_batch_size=2)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        items = [(1, "a"), (2, "b"), (3, "c")]

        batches = list(adaptive_batch_iterator(items, calc))

        # Should preserve tuple structure
        all_items = []
        for batch in batches:
            all_items.extend(batch.items)

        assert all_items == items


# =============================================================================
# Edge Cases
# =============================================================================


class TestEdgeCases:
    """Tests for edge cases and boundary conditions of batch sizing and batch iteration."""

    def test_single_image(self, cpu_device):
        """A single image gives batch size 1 and exactly one batch, flagged as last."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)

        result = calc.compute_optimal_batch_size(num_images=1)
        assert result == 1

        batches = list(adaptive_batch_iterator(["single"], calc))
        assert len(batches) == 1
        assert batches[0].items == ["single"]
        assert batches[0].is_last is True

    def test_exact_batch_size_multiple(self, cpu_device):
        """When num_images is an exact multiple of batch_size, no partial or empty trailing batch appears."""
        config = AdaptiveBatchConfig(max_batch_size=5)
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        items = list(range(15))  # Exactly 3 batches of 5
        batches = list(adaptive_batch_iterator(items, calc))

        assert len(batches) == 3
        assert all(b.batch_size == 5 for b in batches)
        # Batch count and sizes alone would not catch dropped, duplicated or
        # reordered items at the boundary, so check the flattened contents too.
        assert [item for b in batches for item in b.items] == items
        # The final batch is full here; it must still be flagged as the last one.
        assert batches[-1].is_last is True

    def test_very_large_num_images(self, cpu_device):
        """A huge image count is capped at the CPU max_batch_size."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)

        result = calc.compute_optimal_batch_size(num_images=1_000_000)
        assert result == 64  # Should cap at max_batch_size

    def test_zero_reserved_memory(self, cpu_device):
        """With reserved_memory_mb=0 the calculator still returns a positive batch size."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)

        result = calc.compute_optimal_batch_size(
            num_images=100,
            process_res=504,
            reserved_memory_mb=0,
        )
        assert result > 0

    def test_high_resolution(self, cpu_device):
        """A very high processing resolution still yields a positive batch size."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)

        # 4K resolution (2160 px)
        result = calc.compute_optimal_batch_size(
            num_images=100,
            process_res=2160,
        )
        assert result > 0  # Should still return valid batch size

    def test_low_resolution(self, cpu_device):
        """A very low processing resolution still yields a positive batch size."""
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device)

        result = calc.compute_optimal_batch_size(
            num_images=100,
            process_res=128,
        )
        assert result > 0

    def test_negative_memory_edge_case(self, cpu_device):
        """With a tiny memory-utilization target, the result must still respect min_batch_size (1)."""
        config = AdaptiveBatchConfig(
            min_batch_size=1,
            target_memory_utilization=0.01,  # Very low utilization
        )
        calc = AdaptiveBatchSizeCalculator("da3-large", cpu_device, config)

        # The memory budget is tiny here; the floor must still hold.
        result = calc.compute_optimal_batch_size(num_images=100)
        assert result >= 1


# =============================================================================
# Run tests
# =============================================================================


if __name__ == "__main__":
    # pytest.main() returns the exit code instead of raising SystemExit, so
    # propagate it explicitly; otherwise running this file directly exits 0
    # even when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))