#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import errno
import hashlib
import logging
import os
import sys

logger = logging.getLogger("downloader")


class DataFile:
    """Holder of a data file.

    Args:
        attr: Mapping describing one file. ``path`` and ``url`` are required;
            ``checksum`` is optional (a warning is logged when it is absent)
            and is compared against the file's MD5 hex digest.
    """

    def __init__(self, attr):
        self.attr = attr
        self.path = attr["path"]
        self.url = attr["url"]
        if "checksum" not in attr:
            logger.warning("Checksum of %s not provided!", self.path)
        # None when no checksum was given; _checkMD5 then never matches.
        self.checksum = attr.get("checksum", None)

    def __str__(self):
        return str(self.attr)


class SampleData:
    """Holder of the data files of a sample.

    Args:
        attr: Mapping with a required ``sample`` entry and an optional
            ``files`` list; each list item is wrapped in a DataFile.
    """

    def __init__(self, attr):
        self.attr = attr
        self.sample = attr["sample"]
        files = attr.get("files", [])
        self.files = [DataFile(f) for f in files]

    def __str__(self):
        return str(self.attr)


def _loadYAML(yaml_path):
    """Parse the YAML file at ``yaml_path`` and return it as a SampleData."""
    with open(yaml_path, "rb") as f:
        # Imported here so that importing this module does not require PyYAML.
        import yaml

        # NOTE(review): FullLoader is not meant for untrusted input; consider
        # yaml.safe_load if the manifest can come from an untrusted source.
        y = yaml.load(f, yaml.FullLoader)
    return SampleData(y)


def _checkMD5(path, refMD5):
    """Return True if the MD5 hex digest of the file at ``path`` equals ``refMD5``.

    The file is hashed in chunks and closed deterministically, so large data
    files are not loaded into memory at once.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest() == refMD5


def _createDirIfNeeded(path):
    """Create the parent directory of ``path`` if it does not exist yet."""
    the_dir = os.path.dirname(path)
    # A bare file name has an empty dirname; os.makedirs("") would raise.
    if the_dir:
        os.makedirs(the_dir, exist_ok=True)


def _downloadFile(path, url, retries):
    """Stream ``url`` into the file at ``path``.

    Args:
        path: Destination file path; its directory must already exist.
        url: URL to fetch.
        retries: Total number of retries handed to the HTTP adapters.

    Returns:
        True if the server answered with status 200 and the body was written,
        False on any other status code or on a requests exception.
    """
    logger.info("Downloading %s from %s", path, url)
    import requests
    from requests.adapters import HTTPAdapter, Retry

    with requests.Session() as session:
        retry_policy = Retry(total=retries, backoff_factor=0.5)
        session.mount("http://", HTTPAdapter(max_retries=retry_policy))
        session.mount("https://", HTTPAdapter(max_retries=retry_policy))
        try:
            r = session.get(url, stream=True, timeout=60)
            if r.status_code != 200:
                logger.info("Failed to connect to %s with status code: %s.", url, r.status_code)
                return False

            logger.info("Connecting to %s is successful.", url)
            size = int(r.headers.get("content-length", 0))
            from tqdm import tqdm

            # Context managers close the bar and the file even if streaming fails.
            with tqdm(total=size, unit="iB", unit_scale=True) as progress_bar:
                with open(path, "wb") as fd:
                    for chunk in r.iter_content(chunk_size=1024):
                        progress_bar.update(len(chunk))
                        fd.write(chunk)
            return True
        except requests.exceptions.ConnectionError as e:
            logger.debug("Connection failed after retries: %s", e)
        except requests.exceptions.Timeout as e:
            logger.debug("A timeout occurred: %s", e)
        except requests.exceptions.RequestException as e:
            logger.debug("Error occurred while requesting connection to %s: %s.", url, e)
    return False


def download(data_dir, yaml_path, retries, overwrite=False):
    """Download the data files specified in a YAML file to a directory.

    Args:
        data_dir: Directory the files are stored under.
        yaml_path: Path to the YAML manifest listing the files.
        retries: Number of retries for each download.
        overwrite: If True, a local copy with a mismatching checksum is
            removed and downloaded again; otherwise it is left in place.

    Returns:
        False if a downloaded file, or a local copy that was not overwritten,
        has a different checksum than the manifest; True otherwise.

    Raises:
        AssertionError: If a file could not be downloaded.
    """
    sample_data = _loadYAML(yaml_path)
    logger.info("Downloading data for %s", sample_data.sample)

    allGood = True
    for f in sample_data.files:
        fpath = os.path.join(data_dir, f.path)
        if os.path.exists(fpath):
            if _checkMD5(fpath, f.checksum):
                logger.info("Found local copy %s, skip downloading.", fpath)
                continue
            logger.warning("Local copy %s has a different checksum!", fpath)
            if not overwrite:
                allGood = False
                continue
            logger.warning("Removing local copy %s", fpath)
            os.remove(fpath)

        _createDirIfNeeded(fpath)
        # Explicit check instead of `assert`: assert statements are removed
        # under `python -O`, which would skip the download call itself.
        # AssertionError is kept so callers see the same exception type.
        if not _downloadFile(fpath, f.url, retries=retries):
            raise AssertionError("Failed to download %s from %s" % (fpath, f.url))
        if not _checkMD5(fpath, f.checksum):
            logger.error("The downloaded file %s has a different checksum!", fpath)
            allGood = False

    return allGood


def _parseArgs():
    """Parse the command line of the downloader.

    Returns:
        A ``(data, args)`` tuple: the data directory (``-d`` takes precedence
        over ``$TRT_DATA_DIR``) and the parsed argparse namespace.

    Raises:
        ValueError: If neither ``-d`` nor ``$TRT_DATA_DIR`` is given.
    """
    parser = argparse.ArgumentParser(description="Downloader of TensorRT sample data files.")
    parser.add_argument(
        "-d",
        "--data",
        help="Specify the data directory, data will be downloaded to there. $TRT_DATA_DIR will be overwritten by this argument.",
    )
    parser.add_argument(
        "-f",
        "--file",
        help="Specify the path to the download.yml, default to `download.yml` in the working directory",
        default="download.yml",
    )
    parser.add_argument(
        "-o",
        "--overwrite",
        help="Force to overwrite if MD5 check failed",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "-v",
        "--verify",
        help="Verify if the data has been downloaded. Will not download if specified.",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "-r",
        "--retries",
        help="Number of retries for download",
        type=int,
        default=10,
    )

    # Unknown arguments are tolerated rather than rejected.
    args, _ = parser.parse_known_args()
    data = os.environ.get("TRT_DATA_DIR", None) if args.data is None else args.data
    if data is None:
        raise ValueError("Data directory must be specified by either `-d $DATA` or environment variable $TRT_DATA_DIR.")

    return data, args


def verifyChecksum(data_dir, yaml_path):
    """Verify the checksum of the files described by the YAML.

    Args:
        data_dir: Directory the files are expected under.
        yaml_path: Path to the YAML manifest listing the files.

    Returns:
        False if any of the files doesn't exist locally or has a checksum
        different from the one in the YAML; True otherwise.
    """
    sample_data = _loadYAML(yaml_path)
    logger.info("Verifying data files and their MD5 for %s", sample_data.sample)

    allGood = True
    for f in sample_data.files:
        fpath = os.path.join(data_dir, f.path)
        if os.path.exists(fpath):
            if _checkMD5(fpath, f.checksum):
                logger.info("MD5 match for local copy %s", fpath)
            else:
                logger.error("Local file %s has a different checksum!", fpath)
                allGood = False
        else:
            allGood = False
            logger.error("Data file %s doesn't have a local copy", f.path)

    return allGood


def main():
    """Command-line entry point: verify or download, exit with 1 on failure."""
    data, args = _parseArgs()
    logging.basicConfig()
    logger.setLevel(logging.INFO)

    if args.verify:
        ret = verifyChecksum(data, args.file)
    else:
        ret = download(data, args.file, args.retries, args.overwrite)

    if not ret:
        # Error of downloading or checksum
        sys.exit(1)


# Data directory resolved by getFilePath on first use and reused afterwards.
TRT_DATA_DIR = None


def getFilePath(path):
    """Util to get the full path to the downloaded data files.

    It only works when the sample doesn't have any other command line argument.

    Args:
        path: Path of the data file relative to the data directory.

    Returns:
        ``path`` joined onto the data directory (``-d`` takes precedence over
        ``$TRT_DATA_DIR``).

    Raises:
        ValueError: If no data directory is specified, or if the resulting
            file does not exist.
    """
    global TRT_DATA_DIR

    if not TRT_DATA_DIR:
        parser = argparse.ArgumentParser(description="Helper of data file download tool")
        parser.add_argument(
            "-d",
            "--data",
            help="Specify the data directory where it is saved in. $TRT_DATA_DIR will be overwritten by this argument.",
        )
        args, _ = parser.parse_known_args()
        TRT_DATA_DIR = os.environ.get("TRT_DATA_DIR", None) if args.data is None else args.data
    if TRT_DATA_DIR is None:
        raise ValueError("Data directory must be specified by either `-d $DATA` or environment variable $TRT_DATA_DIR.")

    fullpath = os.path.join(TRT_DATA_DIR, path)
    if not os.path.exists(fullpath):
        raise ValueError("Data file %s doesn't exist!" % fullpath)

    return fullpath


if __name__ == "__main__":
    main()