# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import os
from copy import deepcopy
import datasets
import evaluate
import torch
import transformers
from datasets import load_dataset
from torch.utils.data import DataLoader, IterableDataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from accelerate import Accelerator, DataLoaderConfiguration, DistributedType
from accelerate.data_loader import DataLoaderDispatcher
from accelerate.test_utils import RegressionDataset, RegressionModel, torch_device
from accelerate.utils import is_torch_xla_available, set_seed
# Set at import time, before any tokenizer or model is created below.
# NOTE(review): presumably this silences transformers' advisory warnings — confirm against the transformers docs.
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
class ListHandler(logging.Handler):
    """Logging handler that keeps every record it receives in the ``logs`` list."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``logging.Handler`` and start with an empty capture list."""
        super().__init__(*args, **kwargs)
        # Captured records, in the order they were emitted.
        self.logs = []

    def emit(self, record):
        """Store ``record`` in ``self.logs`` instead of writing it anywhere."""
        self.logs.append(record)
def get_basic_setup(accelerator, num_samples=82, batch_size=16):
    """Build the model, prepared model copy and prepared dataloader used by the basic checks.

    Seeds with ``set_seed(42)``, creates a ``RegressionModel`` and a deep copy of it, and wraps a
    ``RegressionDataset`` of ``num_samples`` items in a ``DataLoader`` with the given ``batch_size``.
    The original model is moved to ``accelerator.device``; the copy and the dataloader are passed
    through ``accelerator.prepare``.

    Returns:
        A ``(model, ddp_model, dataloader)`` tuple: the unprepared model, the prepared copy and the
        prepared dataloader.
    """
    set_seed(42)
    reference_model = RegressionModel()
    # Copy before anything is moved or wrapped so both models start from the same state.
    replica = deepcopy(reference_model)
    loader = DataLoader(RegressionDataset(length=num_samples), batch_size=batch_size)
    reference_model.to(accelerator.device)
    prepared_model, prepared_loader = accelerator.prepare(replica, loader)
    return reference_model, prepared_model, prepared_loader
def get_dataloader(accelerator: Accelerator, use_longest=False):
    """Return an unshuffled ``DataLoader`` (batch size 16) over the tokenized GLUE MRPC validation split.

    Args:
        accelerator: Its ``main_process_first`` context wraps the tokenization ``map`` call.
        use_longest: If true, each batch is padded to its longest sequence; otherwise every batch is
            padded to a fixed length of 128.
    """
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/mrpc-bert-base-cased")
    raw_validation = load_dataset("glue", "mrpc", split="validation")

    def tokenize_function(examples):
        # Encode the two sentences of each example as a pair.
        return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)

    with accelerator.main_process_first():
        encoded = raw_validation.map(
            tokenize_function,
            batched=True,
            remove_columns=["idx", "sentence1", "sentence2"],
        )
    encoded = encoded.rename_column("label", "labels")

    def collate_fn(examples):
        # Choose the padding strategy once, then share the single ``pad`` call.
        if use_longest:
            pad_options = {"padding": "longest"}
        else:
            pad_options = {"padding": "max_length", "max_length": 128}
        return tokenizer.pad(examples, return_tensors="pt", **pad_options)

    return DataLoader(encoded, shuffle=False, collate_fn=collate_fn, batch_size=16)
def get_mrpc_setup(dispatch_batches, split_batches):
    """Create an ``Accelerator`` plus unprepared and prepared model/dataloader pairs for the MRPC check.

    Args:
        dispatch_batches: Forwarded to ``DataLoaderConfiguration``; its negation is passed to
            ``get_dataloader`` as ``use_longest``.
        split_batches: Forwarded to ``DataLoaderConfiguration``.

    Returns:
        A ``(setup, accelerator)`` tuple. ``setup["ddp"]`` is ``[prepared model, prepared dataloader,
        torch_device]`` and ``setup["no"]`` is ``[model, dataloader, accelerator.device]``.
    """
    accelerator = Accelerator(
        dataloader_config=DataLoaderConfiguration(dispatch_batches=dispatch_batches, split_batches=split_batches)
    )
    use_longest = not dispatch_batches
    dataloader = get_dataloader(accelerator, use_longest)
    model = AutoModelForSequenceClassification.from_pretrained(
        "hf-internal-testing/mrpc-bert-base-cased", return_dict=True
    )
    ddp_model, ddp_dataloader = accelerator.prepare(model, dataloader)
    setup = {
        "ddp": [ddp_model, ddp_dataloader, torch_device],
        "no": [model, dataloader, accelerator.device],
    }
    return setup, accelerator
def generate_predictions(model, dataloader, accelerator):
    """Run ``model`` over ``dataloader`` without gradients and return all gathered outputs and targets.

    Each batch must be a mapping with exactly two values, unpacked in order as (model input, target).
    The model output and the target are passed together through ``accelerator.gather_for_metrics``
    before being accumulated.

    Returns:
        A ``(logits, targets)`` tuple of tensors, each the concatenation of the per-batch results.
    """
    gathered_logits = []
    gathered_targets = []
    for batch in dataloader:
        features, labels = batch.values()
        with torch.no_grad():
            prediction = model(features)
        prediction, labels = accelerator.gather_for_metrics((prediction, labels))
        gathered_logits.append(prediction)
        gathered_targets.append(labels)
    return torch.cat(gathered_logits), torch.cat(gathered_targets)
def test_torch_metrics(
    accelerator: Accelerator, num_samples=82, dispatch_batches=False, split_batches=False, batch_size=16
):
    """Assert that gathering predictions over a prepared dataloader yields exactly ``num_samples`` items.

    ``dispatch_batches`` and ``split_batches`` are accepted but not read by this function; the
    dataloader behaviour comes from the ``accelerator`` that is passed in.
    """
    _, prepared_model, prepared_loader = get_basic_setup(accelerator, num_samples, batch_size)
    logits, _ = generate_predictions(prepared_model, prepared_loader, accelerator)
    assert len(logits) == num_samples, (
        f"Unexpected number of inputs:\n Expected: {num_samples}\n Actual: {len(logits)}"
    )
def test_mrpc(dispatch_batches: bool = False, split_batches: bool = False):
    """Check that GLUE MRPC accuracy and F1 agree between the unprepared and the prepared evaluation loops.

    The same metric object is used for both passes; ``compute()`` is called after each one and the
    two results are compared with ``math.isclose``.
    """
    glue_metric = evaluate.load("glue", "mrpc")
    setup, accelerator = get_mrpc_setup(dispatch_batches, split_batches)

    # Reference pass: unprepared model and dataloader, batches moved to the device by hand.
    plain_model, plain_loader, plain_device = setup["no"]
    plain_model.to(plain_device)
    plain_model.eval()
    for batch in plain_loader:
        batch.to(plain_device)
        with torch.inference_mode():
            outputs = plain_model(**batch)
        glue_metric.add_batch(predictions=outputs.logits.argmax(dim=-1), references=batch["labels"])
    baseline = glue_metric.compute()

    # Prepared pass: predictions and labels go through ``gather_for_metrics`` before being scored.
    ddp_model, ddp_loader, _ = setup["ddp"]
    ddp_model.eval()
    for batch in ddp_loader:
        with torch.inference_mode():
            outputs = ddp_model(**batch)
        preds, references = accelerator.gather_for_metrics((outputs.logits.argmax(dim=-1), batch["labels"]))
        glue_metric.add_batch(predictions=preds, references=references)
    distributed = glue_metric.compute()

    for key in "accuracy f1".split():
        assert math.isclose(baseline[key], distributed[key]), (
            f"Baseline and Distributed are not the same for key {key}:\n\tBaseline: {baseline[key]}\n\tDistributed: {distributed[key]}\n"
        )
def test_gather_for_metrics_with_non_tensor_objects_iterable_dataset():
    """Check ``gather_for_metrics`` on an iterable dataset built from a plain list of Python ints.

    A prepared ``DataLoader`` (batch size 4) over 30 items is iterated, every batch is gathered, and
    the concatenated result must contain all 30 items. On the main process a ``ListHandler`` is
    attached to the ``accelerate.accelerator`` logger for the duration of the loop, and the test
    asserts that it captured no records.

    Raises:
        AssertionError: If the number of gathered items is not 30, or if anything was logged.
    """

    class DummyIterableDataset(IterableDataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __iter__(self):
            yield from self.data

    iterable_dataset = DummyIterableDataset(list(range(30)))
    dataloader = DataLoader(iterable_dataset, batch_size=4)
    accelerator = Accelerator()
    prepared_dataloader = accelerator.prepare(dataloader)

    # Use the public lookup rather than the private ``loggerDict`` registry.
    logger = logging.getLogger("accelerate.accelerator")
    list_handler = ListHandler() if accelerator.is_main_process else None
    if list_handler is not None:
        logger.addHandler(list_handler)

    try:
        batches_for_metrics = [accelerator.gather_for_metrics(batch) for batch in prepared_dataloader]
        assert torch.cat(batches_for_metrics).size(0) == 30
        if list_handler is not None:
            assert len(list_handler.logs) == 0
    finally:
        # Detach the capture handler even when an assertion above fails.
        if list_handler is not None:
            logger.removeHandler(list_handler)
def test_gather_for_metrics_with_iterable_dataset():
    """Check ``gather_for_metrics`` on an iterable dataset built from a 30-element tensor.

    The prepared dataloader must be a ``DataLoaderDispatcher``. It is iterated with batch size 4,
    every batch is gathered, and the concatenated result must contain all 30 items. On the main
    process a ``ListHandler`` is attached to the ``accelerate.accelerator`` logger for the duration
    of the loop, and the test asserts that it captured no records.

    Raises:
        AssertionError: If the prepared dataloader has the wrong type, the number of gathered items
            is not 30, or anything was logged.
    """

    class DummyIterableDataset(IterableDataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __iter__(self):
            yield from self.data

    iterable_dataset = DummyIterableDataset(torch.as_tensor(range(30)))
    dataloader = DataLoader(iterable_dataset, batch_size=4)
    accelerator = Accelerator()
    prepared_dataloader = accelerator.prepare(dataloader)
    assert isinstance(prepared_dataloader, DataLoaderDispatcher)

    # Use the public lookup rather than the private ``loggerDict`` registry.
    logger = logging.getLogger("accelerate.accelerator")
    list_handler = ListHandler() if accelerator.is_main_process else None
    if list_handler is not None:
        logger.addHandler(list_handler)

    try:
        batches_for_metrics = [accelerator.gather_for_metrics(batch) for batch in prepared_dataloader]
        assert torch.cat(batches_for_metrics).size(0) == 30
        if list_handler is not None:
            assert len(list_handler.logs) == 0
    finally:
        # Detach the capture handler even when an assertion above fails.
        if list_handler is not None:
            logger.removeHandler(list_handler)
def test_gather_for_metrics_drop_last():
    """Check that the second gathered batch is complete when the dataloader uses ``drop_last=True``.

    The dataset holds ``10 * num_processes + 1`` items and the per-device batch size is 5. The second
    batch, once gathered, must contain ``5 * num_processes`` items.
    """
    accelerator = Accelerator()
    per_device_batch_size = 5
    num_items = 10 * accelerator.num_processes + 1
    loader = accelerator.prepare(DataLoader(range(num_items), batch_size=per_device_batch_size, drop_last=True))

    batch_iter = iter(loader)
    next(batch_iter)  # Discard the first batch, e.g. tensor([0, 1, 2, 3, 4], device='cuda:0')
    gathered_items = accelerator.gather_for_metrics(next(batch_iter))

    # Should return a full set of complete batches from each GPU
    num_expected_items = per_device_batch_size * accelerator.num_processes
    assert gathered_items.size(0) == num_expected_items, (
        f"Expected number of items: {num_expected_items}, Actual: {gathered_items.size(0)}"
    )
def main():
    """Run every metric-gathering check in sequence, choosing configurations from the device and backend."""
    accelerator = Accelerator(
        dataloader_config=DataLoaderConfiguration(split_batches=False, dispatch_batches=False)
    )

    # Third-party logging: warnings on the local main process, errors only everywhere else.
    on_local_main = accelerator.is_local_main_process
    for library_logging in (datasets.utils.logging, transformers.utils.logging):
        if on_local_main:
            library_logging.set_verbosity_warning()
        else:
            library_logging.set_verbosity_error()

    # TorchXLA does not support batch dispatching. 'put_on_device' is always False for
    # TorchXLA, which can cause a value error in 'prepare_data_loader' function.
    if accelerator.state.distributed_type == DistributedType.XLA:
        dispatch_batches_options = [False]
    else:
        dispatch_batches_options = [True, False]
    # Every (split_batches, dispatch_batches) pair, with split_batches varying slowest.
    option_grid = [(split, dispatch) for split in [True, False] for dispatch in dispatch_batches_options]

    # Temporarily close this test for TorchXLA due to the 'Cannot set version_counter for
    # inference tensor' error in inference mode. Reopen it after TorchXLA fixes this bug.
    # These are a bit slower so they should only be run on the GPU or TPU
    if accelerator.device.type != "cpu" and not is_torch_xla_available():
        if accelerator.is_local_main_process:
            print("**Testing gather_for_metrics**")
        for split_batches, dispatch_batches in option_grid:
            if accelerator.is_local_main_process:
                print(f"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`")
            test_mrpc(dispatch_batches, split_batches)
            accelerator.state._reset_state()
        print("test_gather_for_metrics_with_iterable_dataset")
        test_gather_for_metrics_with_iterable_dataset()
        print("test gather_for_metrics_with_non_tensor_objects_iterable_dataset")
        test_gather_for_metrics_with_non_tensor_objects_iterable_dataset()

    # MpDeviceLoader in TorchXLA is an asynchronous loader that preloads several batches into cache.
    # This can cause the 'end_of_dataloader' of DataLoaderStateMixin to be set earlier than intended.
    # Skip this test when TorchXLA is enabled.
    if accelerator.state.distributed_type != DistributedType.XLA:
        if accelerator.is_local_main_process:
            print("**Test torch metrics**")
        for split_batches, dispatch_batches in option_grid:
            accelerator = Accelerator(
                dataloader_config=DataLoaderConfiguration(
                    split_batches=split_batches, dispatch_batches=dispatch_batches
                )
            )
            if accelerator.is_local_main_process:
                print(f"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`, length=99")
            test_torch_metrics(accelerator, 99)
            accelerator.state._reset_state()

    if accelerator.is_local_main_process:
        print("**Test last batch is not dropped when perfectly divisible**")
    accelerator = Accelerator()
    test_torch_metrics(accelerator, 512)
    accelerator.state._reset_state()

    if accelerator.is_local_main_process:
        print("**Test that `drop_last` is taken into account**")
    test_gather_for_metrics_drop_last()

    accelerator.end_training()
    accelerator.state._reset_state()
def _mp_fn(index):
    """Entry point for xla_spawn (TPUs); ``index`` is accepted but not used."""
    main()
# Run the full set of checks when this file is executed as a script.
if __name__ == "__main__":
    main()
|