# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# More information about Marigold:
# https://marigoldmonodepth.github.io
# https://marigoldcomputervision.github.io
# Efficient inference pipelines are now part of diffusers:
# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
# https://huggingface.co/docs/diffusers/api/pipelines/marigold
# Examples of trained models and live demos:
# https://huggingface.co/prs-eth
# Related projects:
# https://rollingdepth.github.io/
# https://marigolddepthcompletion.github.io/
# Citation (BibTeX):
# https://github.com/prs-eth/Marigold#-citation
# If you find Marigold useful, we kindly ask you to cite our papers.
# --------------------------------------------------------------------------
# Adapted from https://github.com/huggingface/accelerate/blob/e2ae254008061b3e53fc1c97f88d65743a857e75/src/accelerate/data_loader.py
from torch.utils.data import BatchSampler, DataLoader, IterableDataset

# Default kwargs of `torch.utils.data.DataLoader` (baseline from the minimum
# supported PyTorch version, 1.4.0, plus kwargs added in later releases); used
# below as fallbacks when copying settings off an existing dataloader.
_PYTORCH_DATALOADER_KWARGS = {
    "batch_size": 1,
    "shuffle": False,
    "sampler": None,
    "batch_sampler": None,
    "num_workers": 0,
    "collate_fn": None,
    "pin_memory": False,
    "drop_last": False,
    "timeout": 0,
    "worker_init_fn": None,
    "multiprocessing_context": None,
    "generator": None,
    "prefetch_factor": 2,
    "persistent_workers": False,
}


class SkipBatchSampler(BatchSampler):
    """
    A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
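
    Example (a minimal sketch; `SequentialSampler` is the standard PyTorch
    sampler, used here purely for illustration):

        >>> from torch.utils.data import SequentialSampler
        >>> base = BatchSampler(SequentialSampler(range(10)), batch_size=2, drop_last=False)
        >>> list(SkipBatchSampler(base, skip_batches=3))
        [[6, 7], [8, 9]]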
"""
def __init__(self, batch_sampler, skip_batches=0):
self.batch_sampler = batch_sampler
self.skip_batches = skip_batches
def __iter__(self):
for index, samples in enumerate(self.batch_sampler):
if index >= self.skip_batches:
yield samples
@property
def total_length(self):
return len(self.batch_sampler)
def __len__(self):
return len(self.batch_sampler) - self.skip_batches


class SkipDataLoader(DataLoader):
    """
    Subclass of a PyTorch `DataLoader` that will skip the first batches.

    Args:
        dataset (`torch.utils.data.dataset.Dataset`):
            The dataset to use to build this dataloader.
        skip_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning.
        kwargs:
            All other keyword arguments to pass to the regular `DataLoader` initialization.
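
    Example (a minimal sketch; the `TensorDataset` is assumed purely for
    illustration):

        >>> import torch
        >>> from torch.utils.data import TensorDataset
        >>> loader = SkipDataLoader(TensorDataset(torch.arange(6)), skip_batches=2, batch_size=2)
        >>> [batch[0].tolist() for batch in loader]
        [[4, 5]]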
"""
def __init__(self, dataset, skip_batches=0, **kwargs):
super().__init__(dataset, **kwargs)
self.skip_batches = skip_batches
def __iter__(self):
for index, batch in enumerate(super().__iter__()):
if index >= self.skip_batches:
yield batch


def skip_first_batches(dataloader, num_batches=0):
    """
    Creates a `torch.utils.data.DataLoader` that efficiently skips the first `num_batches` batches.
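
    For a map-style dataset the skip happens at the sampler level, so the skipped
    samples are never loaded; for an `IterableDataset` the first batches are drawn
    and discarded instead. A minimal sketch of resuming mid-epoch (the
    `TensorDataset` is assumed purely for illustration):

        >>> import torch
        >>> from torch.utils.data import TensorDataset
        >>> dl = DataLoader(TensorDataset(torch.arange(8)), batch_size=2)
        >>> [batch[0].tolist() for batch in skip_first_batches(dl, num_batches=3)]
        [[6, 7]]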
"""
dataset = dataloader.dataset
sampler_is_batch_sampler = False
if isinstance(dataset, IterableDataset):
new_batch_sampler = None
else:
sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
batch_sampler = (
dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
)
new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches)
# We ignore all of those since they are all dealt with by our new_batch_sampler
ignore_kwargs = [
"batch_size",
"shuffle",
"sampler",
"batch_sampler",
"drop_last",
]
kwargs = {
k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
for k in _PYTORCH_DATALOADER_KWARGS
if k not in ignore_kwargs
}
# Need to provide batch_size as batch_sampler is None for Iterable dataset
if new_batch_sampler is None:
kwargs["drop_last"] = dataloader.drop_last
kwargs["batch_size"] = dataloader.batch_size
if new_batch_sampler is None:
# Need to manually skip batches in the dataloader
dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
else:
dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)
return dataloader
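

if __name__ == "__main__":
    # Minimal smoke test, illustrative only (not part of the adapted accelerate
    # code): exercises the `IterableDataset` branch of `skip_first_batches`,
    # which cannot use a batch sampler and falls back to `SkipDataLoader`.
    class _CountingIterable(IterableDataset):
        def __iter__(self):
            return iter(range(8))

    dl = DataLoader(_CountingIterable(), batch_size=2)
    resumed = skip_first_batches(dl, num_batches=2)
    print([batch.tolist() for batch in resumed])  # expected: [[4, 5], [6, 7]]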