G0-VLA / g0plus_dockerfile /docker-assets /data /TensorRT-10.13.0.35 /samples /python /downloader.py
| # | |
| # SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # | |
| import argparse | |
| import errno | |
| import hashlib | |
| import logging | |
| import os | |
| import sys | |
# Shared logger for all download / verification messages emitted by this module.
logger = logging.getLogger(name="downloader")
class DataFile:
    """One downloadable file, built from a single entry of a YAML ``files`` list.

    Attributes:
        attr: the raw mapping this object was built from (echoed by ``__str__``).
        path: where the file is stored, relative to the data directory.
        url: where the file is fetched from.
        checksum: expected MD5 hex digest, or ``None`` when the entry has no
            ``checksum`` key.
    """

    def __init__(self, attr):
        self.attr = attr
        self.path = attr["path"]
        self.url = attr["url"]
        has_checksum = "checksum" in attr
        if not has_checksum:
            # Not fatal here, but the file's integrity cannot be verified later.
            logger.warning("Checksum of %s not provided!", self.path)
        self.checksum = attr["checksum"] if has_checksum else None

    def __str__(self):
        return "%s" % (self.attr,)
| class SampleData: | |
| """Holder of data files of an sample.""" | |
| def __init__(self, attr): | |
| self.attr = attr | |
| self.sample = attr["sample"] | |
| files = attr.get("files", []) | |
| self.files = [DataFile(f) for f in files] | |
| def __str__(self): | |
| return str(self.attr) | |
def _loadYAML(yaml_path):
    """Parse the download YAML at *yaml_path* and wrap it in a ``SampleData``.

    Raises whatever ``open`` or PyYAML raise for a missing or malformed file,
    and ``KeyError`` if the document lacks the required ``sample`` key.
    """
    with open(yaml_path, "rb") as stream:
        # Deferred import: PyYAML is only needed once a YAML file is actually parsed.
        import yaml

        # NOTE(review): FullLoader accepts more than plain data; if these YAML
        # files could ever come from an untrusted source, prefer yaml.safe_load.
        document = yaml.load(stream, Loader=yaml.FullLoader)
    return SampleData(document)
def _checkMD5(path, refMD5):
    """Return True if the MD5 hex digest of the file at *path* equals *refMD5*.

    The file is hashed in fixed-size chunks so large model files are not read
    into memory at once, and the handle is always closed.

    Args:
        path: file to hash.
        refMD5: expected lowercase MD5 hex digest. ``None`` (no checksum
            provided) never matches.

    Returns:
        bool: whether the digests are equal.
    """
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            md5.update(chunk)
    return md5.hexdigest() == refMD5
def _createDirIfNeeded(path):
    """Create the parent directory of the file *path*, including missing ancestors.

    Does nothing when *path* has no directory component (a bare file name in
    the current directory) or when the directory already exists.

    Raises:
        OSError: for any failure other than the directory already existing
            (e.g. permission denied).
    """
    the_dir = os.path.dirname(path)
    if not the_dir:
        # os.makedirs("") would raise FileNotFoundError; the current directory
        # already exists, so there is nothing to create.
        return
    try:
        os.makedirs(the_dir)
    except OSError as e:
        # An existing entry is fine; anything else is a real error.
        if e.errno != errno.EEXIST:
            raise
def _downloadFile(path, url, retries):
    """Stream *url* into *path*, letting urllib3 retry transient failures.

    The body is first written to ``path + ".part"`` and renamed onto *path*
    only after the transfer has finished. A failed or interrupted download
    therefore never leaves a truncated file that a later run would see as a
    mismatching local copy.

    Args:
        path: destination file; its parent directory must already exist.
        url: HTTP(S) URL to fetch.
        retries: total retry budget handed to urllib3's ``Retry``.

    Returns:
        True if the server answered 200 and the whole body was saved to
        *path*; False (after logging the reason) otherwise.
    """
    logger.info("Downloading %s from %s", path, url)
    # Deferred imports: these third-party packages are only needed when a file
    # actually has to be downloaded.
    import requests
    from requests.adapters import HTTPAdapter, Retry
    from tqdm import tqdm

    retry_policy = Retry(total=retries, backoff_factor=0.5)
    part_path = path + ".part"
    try:
        with requests.Session() as session:
            session.mount("http://", HTTPAdapter(max_retries=retry_policy))
            session.mount("https://", HTTPAdapter(max_retries=retry_policy))
            with session.get(url, stream=True, timeout=60) as r:
                if r.status_code != 200:
                    logger.error("Failed to connect to %s with status code: %s.", url, r.status_code)
                    return False
                logger.info("Connecting to %s is successful.", url)
                # content-length may be absent; total=None makes tqdm show a plain
                # byte counter rather than a 0-byte total.
                size = int(r.headers.get("content-length", 0)) or None
                with open(part_path, "wb") as fd, tqdm(total=size, unit="iB", unit_scale=True) as progress_bar:
                    for chunk in r.iter_content(chunk_size=1 << 20):
                        progress_bar.update(len(chunk))
                        fd.write(chunk)
        # Rename only after both the file and the connection are closed.
        os.replace(part_path, path)
        return True
    except requests.exceptions.ConnectionError as e:
        logger.error("Connection to %s failed after retries: %s", url, e)
    except requests.exceptions.Timeout as e:
        logger.error("A timeout occurred while downloading %s: %s", url, e)
    except requests.exceptions.RequestException as e:
        logger.error("Error occurred while requesting connection to %s: %s.", url, e)
    finally:
        # Discard any partial download (interrupted transfer, write error, ...).
        if os.path.exists(part_path):
            os.remove(part_path)
    return False


def download(data_dir, yaml_path, retries, overwrite=False):
    """Download the data files listed in *yaml_path* into *data_dir*.

    A file that already exists locally with a matching MD5 is skipped. A local
    copy with a different MD5 is deleted and downloaded again when *overwrite*
    is set; otherwise it is kept and reported as a failure.

    Args:
        data_dir: root directory; each file's YAML ``path`` is joined onto it.
        yaml_path: path of the download YAML describing the sample's files.
        retries: retry budget for each HTTP download.
        overwrite: replace local copies whose checksum does not match.

    Returns:
        True if every file ends up present with the expected checksum. False if
        any download fails, any downloaded file has a different checksum, or a
        mismatching local copy is kept because *overwrite* is False.
    """
    sample_data = _loadYAML(yaml_path)
    logger.info("Downloading data for %s", sample_data.sample)

    allGood = True
    for f in sample_data.files:
        fpath = os.path.join(data_dir, f.path)
        if os.path.exists(fpath):
            if _checkMD5(fpath, f.checksum):
                logger.info("Found local copy %s, skip downloading.", fpath)
                continue
            logger.warning("Local copy %s has a different checksum!", fpath)
            if not overwrite:
                allGood = False
                continue
            logger.warning("Removing local copy %s", fpath)
            os.remove(fpath)
        _createDirIfNeeded(fpath)
        # Explicit check instead of `assert`: under `python -O` an assert
        # statement, including the download call inside it, is removed.
        if not _downloadFile(fpath, f.url, retries=retries):
            logger.error("Failed to download %s from %s", fpath, f.url)
            allGood = False
            continue
        if not _checkMD5(fpath, f.checksum):
            logger.error("The downloaded file %s has a different checksum!", fpath)
            allGood = False
    return allGood
def _parseArgs():
    """Parse the downloader's command line and resolve the data directory.

    The data directory is taken from ``-d/--data`` when given, otherwise from
    the ``TRT_DATA_DIR`` environment variable. Unrecognized arguments are
    ignored.

    Returns:
        tuple: ``(data_dir, args)`` where *args* is the parsed namespace.

    Raises:
        ValueError: if neither ``-d`` nor ``$TRT_DATA_DIR`` provides a directory.
    """
    parser = argparse.ArgumentParser(description="Downloader of TensorRT sample data files.")
    option_specs = [
        (
            ("-d", "--data"),
            {
                "help": "Specify the data directory, data will be downloaded to there. $TRT_DATA_DIR will be overwritten by this argument.",
            },
        ),
        (
            ("-f", "--file"),
            {
                "help": "Specify the path to the download.yml, default to `download.yml` in the working directory",
                "default": "download.yml",
            },
        ),
        (
            ("-o", "--overwrite"),
            {
                "help": "Force to overwrite if MD5 check failed",
                "action": "store_true",
                "default": False,
            },
        ),
        (
            ("-v", "--verify"),
            {
                "help": "Verify if the data has been downloaded. Will not download if specified.",
                "action": "store_true",
                "default": False,
            },
        ),
        (
            ("-r", "--retries"),
            {
                "help": "Number of retries for download",
                "type": int,
                "default": 10,
            },
        ),
    ]
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)

    args = parser.parse_known_args()[0]
    data = args.data
    if data is None:
        # -d takes precedence; fall back to the environment only when it is absent.
        data = os.environ.get("TRT_DATA_DIR")
    if data is None:
        raise ValueError("Data directory must be specified by either `-d $DATA` or environment variable $TRT_DATA_DIR.")
    return data, args
def verifyChecksum(data_dir, yaml_path):
    """Check that every file listed in *yaml_path* exists under *data_dir* with the expected MD5.

    All files are checked, with no early exit, so every problem is logged in a
    single pass.

    Returns:
        True only if every file has a local copy whose checksum matches the
        YAML; False otherwise.
    """
    sample_data = _loadYAML(yaml_path)
    logger.info("Verifying data files and their MD5 for %s", sample_data.sample)
    ok = True
    for entry in sample_data.files:
        local_path = os.path.join(data_dir, entry.path)
        if not os.path.exists(local_path):
            ok = False
            logger.error("Data file %s doesn't have a local copy", entry.path)
            continue
        if _checkMD5(local_path, entry.checksum):
            logger.info("MD5 match for local copy %s", local_path)
        else:
            logger.error("Local file %s has a different checksum!", local_path)
            ok = False
    return ok
def main():
    """Command-line entry point: verify (``-v``) or download the sample's data files.

    Exits the process with status 1 when the selected operation reports failure.
    """
    data, args = _parseArgs()
    logging.basicConfig()
    logger.setLevel(logging.INFO)
    succeeded = (
        verifyChecksum(data, args.file)
        if args.verify
        else download(data, args.file, args.retries, args.overwrite)
    )
    if not succeeded:
        # The download or verification reported a problem; signal it to the shell.
        sys.exit(1)
# Cached data directory; filled in lazily by getFilePath().
TRT_DATA_DIR = None


def getFilePath(path):
    """Return the full path of a downloaded data file.

    On the first call, the data directory is resolved from ``-d/--data`` on the
    command line, falling back to the ``TRT_DATA_DIR`` environment variable.
    The result is cached in the module-level ``TRT_DATA_DIR``. Unrecognized
    command-line arguments are ignored, so samples may have their own options,
    as long as they do not use ``-d``/``--data`` for something else.

    NOTE(review): the helper parser keeps argparse's default ``-h/--help``, so
    ``-h`` on the command line makes the first call print this helper's usage
    and exit.

    Args:
        path: file path relative to the data directory.

    Returns:
        str: ``os.path.join(data_dir, path)``.

    Raises:
        ValueError: if no data directory is configured, or the file does not exist.
    """
    global TRT_DATA_DIR
    if not TRT_DATA_DIR:
        parser = argparse.ArgumentParser(description="Helper of data file download tool")
        parser.add_argument(
            "-d",
            "--data",
            help="Specify the data directory where it is saved in. $TRT_DATA_DIR will be overwritten by this argument.",
        )
        args, _ = parser.parse_known_args()
        TRT_DATA_DIR = os.environ.get("TRT_DATA_DIR", None) if args.data is None else args.data
    if TRT_DATA_DIR is None:
        raise ValueError("Data directory must be specified by either `-d $DATA` or environment variable $TRT_DATA_DIR.")
    fullpath = os.path.join(TRT_DATA_DIR, path)
    if not os.path.exists(fullpath):
        raise ValueError("Data file %s doesn't exist!" % fullpath)
    return fullpath


# Script entry point, placed after TRT_DATA_DIR and getFilePath are defined so
# that main() runs against a fully initialised module.
if __name__ == "__main__":
    main()