| # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |
| # | |
| # This source code is licensed under the BSD license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import logging | |
| from pathlib import Path | |
| import shutil | |
| import tempfile | |
| from torchvision.datasets import MNIST | |
# System temporary directory (as reported by tempfile.gettempdir()); used by
# setup_cached_mnist() as the `root` under which the MNIST dataset is stored.
| TEMPDIR = tempfile.gettempdir() | |
def setup_cached_mnist():
    """Make sure the MNIST dataset is available under ``TEMPDIR``.

    Instantiates ``MNIST(download=True, root=TEMPDIR)`` up to five times.
    Whenever that raises ``RuntimeError``, the error is logged as a warning,
    the ``MNIST`` folder under ``TEMPDIR`` is erased (it is treated as
    corrupted data) and the download is attempted again.

    Side effects:
        Prepends a GitHub-hosted mirror to the class attribute
        ``MNIST.mirrors`` (only if it is not already listed), may create or
        delete ``TEMPDIR/MNIST``, and emits log records.

    Raises:
        SystemExit: with exit code ``-1`` when every attempt failed.
    """
    # Monkey patch the resource URLs to work around a possible blacklist.
    # This is done once, and guarded, so that retries and repeated calls do
    # not keep stacking duplicate entries onto the class-level mirror list.
    fallback_mirror = "https://github.com/blefaudeux/mnist_dataset/raw/main/"
    if fallback_mirror not in MNIST.mirrors:
        MNIST.mirrors = [fallback_mirror] + MNIST.mirrors

    mnist_root = Path(TEMPDIR) / "MNIST"
    max_attempts = 5

    for _ in range(max_attempts):
        # This will automatically skip the download if the dataset is already
        # there, and check the checksum.
        try:
            MNIST(transform=None, download=True, root=TEMPDIR)
        except RuntimeError as e:
            logging.warning(e)
            # Corrupted data, erase and restart. The folder may not exist at
            # all if the failure happened before anything was written; that
            # must not abort the retry loop.
            try:
                shutil.rmtree(str(mnist_root))
            except FileNotFoundError:
                pass
        else:
            logging.info("Dataset downloaded")
            return

    logging.error("Could not download MNIST dataset")
    # Same exception the `exit()` helper raises, but without depending on the
    # `site` module having injected that helper into builtins.
    raise SystemExit(-1)