File size: 5,950 Bytes
7a87926
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""
Dynamic batch sizing utilities.

Automatically adjusts batch size to maximize GPU utilization while avoiding OOM errors.
"""

import logging
import torch

logger = logging.getLogger(__name__)


class DynamicBatchSampler:
    """
    Dynamic batch sampler that adjusts batch size based on GPU memory.

    Starts with small batch size and gradually increases if successful.
    Decreases if OOM occurs.
    """

    def __init__(
        self,
        dataset,
        initial_batch_size: int = 1,
        max_batch_size: int = 8,
        min_batch_size: int = 1,
        increase_factor: float = 2.0,
        decrease_factor: float = 0.5,
        patience: int = 5,
    ):
        """
        Args:
            dataset: Dataset to sample from
            initial_batch_size: Starting batch size
            max_batch_size: Maximum batch size
            min_batch_size: Minimum batch size
            increase_factor: Factor to increase batch size
            decrease_factor: Factor to decrease batch size on OOM
            patience: Number of successful batches before increasing
        """
        self.dataset = dataset
        self.current_batch_size = initial_batch_size
        self.max_batch_size = max_batch_size
        self.min_batch_size = min_batch_size
        self.increase_factor = increase_factor
        self.decrease_factor = decrease_factor
        self.patience = patience

        # Streak of successes since the last size change (or reset).
        self.successful_batches = 0
        # Total callbacks observed: successes AND OOMs.
        self.total_batches = 0

        logger.info(
            f"DynamicBatchSampler initialized: "
            f"initial={initial_batch_size}, "
            f"max={max_batch_size}, "
            f"min={min_batch_size}"
        )

    def get_batch_size(self) -> int:
        """Get current batch size."""
        return self.current_batch_size

    def on_success(self):
        """Called after successful batch processing."""
        self.successful_batches += 1
        self.total_batches += 1

        # Increase batch size if we've had enough successes.
        if self.successful_batches >= self.patience:
            # Clamp to max_batch_size: without the clamp, a growth step that
            # overshoots the cap (e.g. 6 * 2.0 = 12 > 8) was rejected outright
            # and the batch size could never reach max_batch_size.
            new_batch_size = min(
                int(self.current_batch_size * self.increase_factor),
                self.max_batch_size,
            )
            if new_batch_size > self.current_batch_size:
                old_size = self.current_batch_size
                self.current_batch_size = new_batch_size
                logger.info(f"Batch size increased: {old_size} -> {self.current_batch_size}")
            # Reset the streak even when already at max so the counter does
            # not grow without bound.
            self.successful_batches = 0

    def on_oom(self):
        """Called when OOM error occurs."""
        # OOM attempts count toward the total so get_stats() is honest.
        self.total_batches += 1

        new_batch_size = int(self.current_batch_size * self.decrease_factor)
        new_batch_size = max(new_batch_size, self.min_batch_size)

        if new_batch_size < self.current_batch_size:
            old_size = self.current_batch_size
            self.current_batch_size = new_batch_size
            self.successful_batches = 0
            logger.warning(
                f"OOM detected, batch size decreased: {old_size} -> {self.current_batch_size}"
            )

        # Clear cached GPU memory on every OOM — previously this only ran
        # when the size actually shrank, so an OOM at min_batch_size never
        # released the cache.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def get_stats(self) -> dict:
        """Get sampler statistics."""
        return {
            "current_batch_size": self.current_batch_size,
            "successful_batches": self.successful_batches,
            "total_batches": self.total_batches,
            "success_rate": self.successful_batches / max(self.total_batches, 1),
        }


class AdaptiveDataLoader:
    """
    DataLoader wrapper with dynamic batch sizing.

    Automatically adjusts batch size during training.
    """

    def __init__(
        self,
        dataset,
        initial_batch_size: int = 1,
        max_batch_size: int = 8,
        **dataloader_kwargs,
    ):
        """
        Args:
            dataset: Dataset
            initial_batch_size: Starting batch size
            max_batch_size: Maximum batch size
            **dataloader_kwargs: Additional DataLoader arguments
        """
        self.dataset = dataset
        self.initial_batch_size = initial_batch_size
        self.max_batch_size = max_batch_size
        self.dataloader_kwargs = dataloader_kwargs

        self.sampler = DynamicBatchSampler(
            dataset,
            initial_batch_size=initial_batch_size,
            max_batch_size=max_batch_size,
        )

        self.dataloader = None
        self._create_dataloader()

    def _create_dataloader(self):
        """(Re)create the DataLoader with the sampler's current batch size."""
        from torch.utils.data import DataLoader

        self.dataloader = DataLoader(
            self.dataset,
            bat_size=self.sampler.get_batch_size(),
            **self.dataloader_kwargs,
        ) if False else DataLoader(
            self.dataset,
            batch_size=self.sampler.get_batch_size(),
            **self.dataloader_kwargs,
        )

    def __iter__(self):
        """Iterate over batches, shrinking the batch size on CUDA OOM.

        On an "out of memory" RuntimeError the batch size is reduced and the
        DataLoader rebuilt; unlike a single one-shot retry, this keeps
        shrinking on repeated OOMs until the minimum batch size is reached,
        at which point the error is re-raised.

        NOTE(review): rebuilding the DataLoader restarts iteration from the
        beginning of the dataset, so samples yielded before the OOM may be
        seen again within the same epoch.
        """
        iterator = iter(self.dataloader)

        while True:
            try:
                batch = next(iterator)
            except StopIteration:
                break
            except RuntimeError as e:
                if "out of memory" not in str(e):
                    raise
                previous_size = self.sampler.get_batch_size()
                self.sampler.on_oom()
                if self.sampler.get_batch_size() >= previous_size:
                    # Already at the minimum batch size — nothing left to try.
                    raise
                # Recreate the dataloader with the reduced batch size and retry.
                self._create_dataloader()
                iterator = iter(self.dataloader)
                continue
            yield batch
            self.sampler.on_success()

    def __len__(self):
        """Number of batches in the current underlying DataLoader."""
        return len(self.dataloader)

    def get_batch_size(self) -> int:
        """Get current batch size."""
        return self.sampler.get_batch_size()

    def get_stats(self) -> dict:
        """Get statistics."""
        return self.sampler.get_stats()