import hashlib
import os
import re
import shutil
import tarfile
import urllib
import urllib.error
import urllib.parse
import warnings
import zipfile
from urllib.request import urlretrieve

from keras.src.api_export import keras_export
from keras.src.backend import config
from keras.src.utils import io_utils
from keras.src.utils.module_utils import gfile
from keras.src.utils.progbar import Progbar


def path_to_string(path):
    """Convert `PathLike` objects to their string representation.

    If given a non-string typed path object, converts it to its string
    representation.

    If the object passed to `path` is not among the above, then it is
    returned unchanged. This allows e.g. passthrough of file objects
    through this function.

    Args:
        path: `PathLike` object that represents a path

    Returns:
        A string representation of the path argument, if Python support exists.
    """
    if isinstance(path, os.PathLike):
        return os.fspath(path)
    return path


def resolve_path(path):
    """Return `path` as an absolute path with symlinks resolved."""
    return os.path.realpath(os.path.abspath(path))


def is_path_in_dir(path, base_dir):
    """Return whether `path`, interpreted relative to `base_dir`, stays in it.

    Both sides are fully resolved (absolute, symlinks followed) before
    comparing, and containment is decided component-wise, so a sibling such
    as `/data_evil` is NOT considered inside `/data`.

    Args:
        path: Relative or absolute path to check. An absolute `path`
            ignores `base_dir` when joined, so it must itself lie
            under `base_dir`.
        base_dir: Directory that `path` must stay within.

    Returns:
        `True` if the resolved path equals `base_dir` or lies under it.
    """
    base_dir = resolve_path(base_dir)
    target = resolve_path(os.path.join(base_dir, path))
    try:
        # `base_dir` goes first so the returned common prefix uses its
        # spelling; that keeps the equality test reliable on case-insensitive
        # filesystems.
        return os.path.commonpath([base_dir, target]) == base_dir
    except ValueError:
        # Raised e.g. for paths on different drives on Windows.
        return False


def is_link_in_dir(info, base):
    """Return whether a link member's target stays inside `base`.

    Args:
        info: A `tarfile.TarInfo` describing a symbolic or hard link.
        base: Extraction root directory.

    Returns:
        `True` if the link target resolves to a location inside `base`.
    """
    if info.islnk():
        # Hard-link targets name another archive member, so they are
        # relative to the archive (extraction) root.
        return is_path_in_dir(info.linkname, base)
    # Symbolic-link targets are relative to the directory holding the link.
    target = os.path.join(os.path.dirname(info.name), info.linkname)
    return is_path_in_dir(target, base)


def filter_safe_paths(members, base_dir="."):
    """Yield only archive members that cannot escape `base_dir`.

    A member is kept when its own path stays inside `base_dir` and, if it
    is a symbolic or hard link, its link target stays inside too. Every
    other member is skipped with a warning.

    This is a generator, so when it is handed to `TarFile.extractall`, each
    member is checked just before it is extracted. Links created by earlier
    members are therefore already on disk and get followed during
    resolution.

    Args:
        members: Iterable of `tarfile.TarInfo` objects.
        base_dir: Directory the archive is being extracted into. Defaults
            to the current working directory.

    Yields:
        The members that are safe to extract.
    """
    base = resolve_path(base_dir)
    for finfo in members:
        is_safe = is_path_in_dir(finfo.name, base)
        if is_safe and (finfo.issym() or finfo.islnk()):
            is_safe = is_link_in_dir(finfo, base)
        if is_safe:
            yield finfo
        else:
            warnings.warn(
                "Skipping invalid path during archive extraction: "
                f"'{finfo.name}'.",
                stacklevel=2,
            )


def _extract_tar_members(archive, path):
    """Extract an open `tarfile.TarFile` into `path`, dropping unsafe members."""
    extractall_kwargs = {}
    # Python versions that provide `tarfile.data_filter` also accept
    # `filter="data"`. That enables the standard library's own extraction
    # checks on top of `filter_safe_paths` (see the `tarfile` extraction
    # filter documentation).
    if hasattr(tarfile, "data_filter"):
        extractall_kwargs["filter"] = "data"
    archive.extractall(
        path,
        members=filter_safe_paths(archive, base_dir=path),
        **extractall_kwargs,
    )


def extract_archive(file_path, path=".", archive_format="auto"):
    """Extracts an archive if it matches a supported format.

    Supports `.tar`, `.tar.gz`, `.tar.bz2`, and `.zip` formats. Tar members
    whose path or link target would land outside `path` are skipped with a
    warning.

    Args:
        file_path: Path to the archive file.
        path: Where to extract the archive file.
        archive_format: Archive format to try for extracting the file.
            Options are `"auto"`, `"tar"`, `"zip"`, and `None`.
            `"tar"` includes `.tar`, `.tar.gz`, and `.tar.bz2` files.
            The default `"auto"` uses `["tar", "zip"]`.
            `None` or an empty list will return no matches found.

    Returns:
        `True` if a match was found and an archive extraction was completed,
        `False` otherwise.

    Raises:
        NotImplementedError: If `archive_format` names an unknown format.
        tarfile.TarError, zipfile.BadZipFile, OSError, RuntimeError,
        KeyboardInterrupt: Re-raised when extraction fails. If this call
            created `path`, it is removed first. A pre-existing `path`
            (e.g. the current directory) is never deleted.
    """
    if archive_format is None:
        return False
    if archive_format == "auto":
        archive_format = ["tar", "zip"]
    if isinstance(archive_format, str):
        archive_format = [archive_format]

    file_path = path_to_string(file_path)
    path = path_to_string(path)

    for archive_type in archive_format:
        if archive_type == "tar":
            open_fn = tarfile.open
            is_match_fn = tarfile.is_tarfile
        elif archive_type == "zip":
            open_fn = zipfile.ZipFile
            is_match_fn = zipfile.is_zipfile
        else:
            raise NotImplementedError(archive_type)

        if not is_match_fn(file_path):
            continue

        # Remember whether the destination already existed so that failure
        # cleanup only removes what this call created.
        path_existed = os.path.exists(path)
        with open_fn(file_path) as archive:
            try:
                # Dispatch on the format actually opened, not on a
                # re-sniff of the file: a tar that also looks like a zip
                # must still go through the tar member filter.
                if archive_type == "zip":
                    # `ZipFile.extractall` sanitizes member names itself.
                    archive.extractall(path)
                else:
                    _extract_tar_members(archive, path)
            except (
                tarfile.TarError,
                zipfile.BadZipFile,
                OSError,
                RuntimeError,
                KeyboardInterrupt,
            ):
                if not path_existed and os.path.exists(path):
                    if os.path.isfile(path):
                        os.remove(path)
                    else:
                        shutil.rmtree(path)
                raise
        return True
    return False


@keras_export("keras.utils.get_file")
def get_file(
    fname=None,
    origin=None,
    untar=False,
    md5_hash=None,
    file_hash=None,
    cache_subdir="datasets",
    hash_algorithm="auto",
    extract=False,
    archive_format="auto",
    cache_dir=None,
    force_download=False,
):
    """Downloads a file from a URL if it is not already in the cache.

    By default the file at the url `origin` is downloaded to the
    cache_dir `~/.keras`, placed in the cache_subdir `datasets`,
    and given the filename `fname`. The final location of a file
    `example.txt` would therefore be `~/.keras/datasets/example.txt`.
    Files in `.tar`, `.tar.gz`, `.tar.bz2`, and `.zip` formats can
    also be extracted.

    Passing a hash will verify the file after download. The command line
    programs `shasum` and `sha256sum` can compute the hash.

    Example:

    ```python
    path_to_downloaded_file = get_file(
        origin="https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
        extract=True,
    )
    ```

    Args:
        fname: If the target is a single file, this is your desired
            local name for the file.
            If `None`, the name of the file at `origin` will be used.
            If downloading and extracting a directory archive,
            the provided `fname` will be used as extraction directory
            name (only if it doesn't have an extension).
        origin: Original URL of the file.
        untar: Deprecated in favor of `extract` argument.
            Boolean, whether the file is a tar archive that should
            be extracted.
        md5_hash: Deprecated in favor of `file_hash` argument.
            md5 hash of the file for file integrity verification.
        file_hash: The expected hash string of the file after download.
            The sha256 and md5 hash algorithms are both supported.
            The comparison is case-insensitive.
        cache_subdir: Subdirectory under the Keras cache dir where the file is
            saved. If an absolute path, e.g. `"/path/to/folder"` is
            specified, the file will be saved at that location.
        hash_algorithm: Select the hash algorithm to verify the file.
            Options are `"md5"`, `"sha256"`, and `"auto"`.
            The default `"auto"` detects the hash algorithm in use
            (a 64-character hash selects sha256, anything else md5).
        extract: If `True`, extracts the archive. Only applicable to compressed
            archive files like tar or zip.
        archive_format: Archive format to try for extracting the file.
            Options are `"auto"`, `"tar"`, `"zip"`, and `None`.
            `"tar"` includes tar, tar.gz, and tar.bz2 files.
            The default `"auto"` corresponds to `["tar", "zip"]`.
            None or an empty list will return no matches found.
        cache_dir: Location to store cached files. When `None`, it defaults
            to either `$KERAS_HOME` if the `KERAS_HOME` environment variable
            is set or `~/.keras/`. The directory is created if needed; if it
            is not writable, `/tmp/.keras` is used instead.
        force_download: If `True`, the file will always be re-downloaded
            regardless of the cache state.

    Returns:
        Path to the downloaded file, or to the extraction directory when
        `extract` (or `untar`) is `True`.

    Raises:
        ValueError: If `origin` is missing, no file name can be derived,
            `fname` contains a path separator, or the downloaded file fails
            hash validation.
        Exception: If fetching the URL fails (the underlying
            `urllib.error` exception is chained as the cause).

    **⚠️ Warning on malicious downloads ⚠️**

    Downloading something from the Internet carries a risk.
    NEVER download a file/archive if you do not trust the source.
    We recommend that you specify the `file_hash` argument
    (if the hash of the source file is known) to make sure that the file you
    are getting is the one you expect.
    """
    if origin is None:
        raise ValueError(
            'Please specify the "origin" argument (URL of the file '
            "to download)."
        )

    if cache_dir is None:
        cache_dir = config.keras_home()
    if md5_hash is not None and file_hash is None:
        file_hash = md5_hash
        hash_algorithm = "md5"
    datadir_base = os.path.expanduser(cache_dir)
    # `os.access` reports a directory that does not exist yet as unwritable,
    # which would send a fresh custom `cache_dir` to the /tmp fallback.
    # Create it first. A failure here is tolerated on purpose: the access
    # check right below decides whether to fall back.
    try:
        os.makedirs(datadir_base, exist_ok=True)
    except OSError:
        pass
    if not os.access(datadir_base, os.W_OK):
        datadir_base = os.path.join("/tmp", ".keras")
    datadir = os.path.join(datadir_base, cache_subdir)
    os.makedirs(datadir, exist_ok=True)

    provided_fname = fname
    fname = path_to_string(fname)

    if not fname:
        fname = os.path.basename(urllib.parse.urlsplit(origin).path)
        if not fname:
            raise ValueError(
                "Can't parse the file name from the origin provided: "
                f"'{origin}'. "
                "Please specify the `fname` argument."
            )
    elif os.sep in fname or (os.altsep and os.altsep in fname):
        raise ValueError(
            "Paths are no longer accepted as the `fname` argument. "
            "To specify the file's parent directory, use "
            f"the `cache_dir` argument. Received: fname={fname}"
        )

    if extract or untar:
        if provided_fname:
            if "." in fname:
                # "data.tar.gz" -> download as-is, extract to "data_extracted".
                download_target = os.path.join(datadir, fname)
                fname = fname[: fname.find(".")]
                extraction_dir = os.path.join(datadir, fname + "_extracted")
            else:
                extraction_dir = os.path.join(datadir, fname)
                download_target = os.path.join(datadir, fname + "_archive")
        else:
            extraction_dir = os.path.join(datadir, fname)
            download_target = os.path.join(datadir, fname + "_archive")
    else:
        download_target = os.path.join(datadir, fname)

    if force_download:
        download = True
    elif os.path.exists(download_target):
        # File found in cache.
        download = False
        # Verify integrity if a hash was provided.
        if file_hash is not None:
            if not validate_file(
                download_target, file_hash, algorithm=hash_algorithm
            ):
                io_utils.print_msg(
                    "A local file was found, but it seems to be "
                    f"incomplete or outdated because the {hash_algorithm} "
                    "file hash does not match the original value of "
                    f"{file_hash} so we will re-download the data."
                )
                download = True
    else:
        download = True

    if download:
        io_utils.print_msg(f"Downloading data from {origin}")

        class DLProgbar:
            """Manage progress bar state for use in urlretrieve."""

            def __init__(self):
                self.progbar = None
                self.finished = False

            def __call__(self, block_num, block_size, total_size):
                # urlretrieve reports -1 when the server sends no size.
                if total_size == -1:
                    total_size = None
                if not self.progbar:
                    self.progbar = Progbar(total_size)
                current = block_num * block_size

                if total_size is None:
                    self.progbar.update(current)
                else:
                    if current < total_size:
                        self.progbar.update(current)
                    elif not self.finished:
                        self.progbar.update(self.progbar.target)
                        self.finished = True

        error_msg = "URL fetch failure on {}: {} -- {}"
        try:
            try:
                urlretrieve(origin, download_target, DLProgbar())
            except urllib.error.HTTPError as e:
                raise Exception(
                    error_msg.format(origin, e.code, e.msg)
                ) from e
            except urllib.error.URLError as e:
                raise Exception(
                    error_msg.format(origin, e.errno, e.reason)
                ) from e
        except (Exception, KeyboardInterrupt):
            # Never leave a partial download behind in the cache.
            if os.path.exists(download_target):
                os.remove(download_target)
            raise

        # Validate download if succeeded and user provided an expected hash
        # Security conscious users would get the hash of the file from a
        # separate channel and pass it to this API to prevent MITM / corruption:
        if os.path.exists(download_target) and file_hash is not None:
            if not validate_file(
                download_target, file_hash, algorithm=hash_algorithm
            ):
                raise ValueError(
                    "Incomplete or corrupted file detected. "
                    f"The {hash_algorithm} "
                    "file hash does not match the provided value "
                    f"of {file_hash}."
                )

    if extract or untar:
        if untar:
            archive_format = "tar"

        status = extract_archive(
            download_target, extraction_dir, archive_format
        )
        if not status:
            warnings.warn("Could not extract archive.", stacklevel=2)
        return extraction_dir

    return download_target


def resolve_hasher(algorithm, file_hash=None):
    """Returns hash algorithm as hashlib function.

    `"sha256"`, or `"auto"` together with a 64-character `file_hash`,
    selects sha256. Every other combination falls back to md5.
    """
    if algorithm == "sha256":
        return hashlib.sha256()

    if algorithm == "auto" and file_hash is not None and len(file_hash) == 64:
        return hashlib.sha256()

    # This is used only for legacy purposes.
    return hashlib.md5()


def hash_file(fpath, algorithm="sha256", chunk_size=65535):
    """Calculates a file sha256 or md5 hash.

    Example:

    >>> hash_file('/path/to/file.zip')
    'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'

    Args:
        fpath: Path to the file being validated.
        algorithm: Hash algorithm, one of `"auto"`, `"sha256"`, or `"md5"`,
            or an already-constructed `hashlib` hash object. Defaults to
            `"sha256"`. Note that `"auto"` has no reference hash here, so
            it resolves to md5.
        chunk_size: Bytes to read at a time, important for large files.

    Returns:
        The file hash as a lowercase hex string.
    """
    if isinstance(algorithm, str):
        hasher = resolve_hasher(algorithm)
    else:
        hasher = algorithm

    with open(fpath, "rb") as fpath_file:
        # Stream in chunks so large files are never fully loaded in memory.
        for chunk in iter(lambda: fpath_file.read(chunk_size), b""):
            hasher.update(chunk)

    return hasher.hexdigest()


def validate_file(fpath, file_hash, algorithm="auto", chunk_size=65535):
    """Validates a file against a sha256 or md5 hash.

    Args:
        fpath: path to the file being validated
        file_hash: The expected hash string of the file.
            The sha256 and md5 hash algorithms are both supported.
            Compared case-insensitively.
        algorithm: Hash algorithm, one of `"auto"`, `"sha256"`, or `"md5"`.
            The default `"auto"` detects the hash algorithm in use.
        chunk_size: Bytes to read at a time, important for large files.

    Returns:
        Boolean, whether the file is valid.
    """
    hasher = resolve_hasher(algorithm, file_hash)
    # `hexdigest()` is always lowercase; normalize the expected value so an
    # uppercase hash (as some tools print it) still validates.
    expected = str(file_hash).lower()
    return str(hash_file(fpath, hasher, chunk_size)) == expected


def is_remote_path(filepath):
    """
    Determines if a given filepath indicates a remote location.

    This function checks if the filepath starts with a known remote prefix:
    GCS (`/gcs`), CNS (`/cns`), CFS (`/cfs`), HDFS (`/hdfs`),
    `/readahead`, Placer (`/placer`), TFHub (`/tfhub`), or a URL
    scheme (`.*://`).

    Args:
        filepath (str): The path to be checked.

    Returns:
        bool: True if the filepath is a recognized remote path, otherwise False
    """
    if re.match(
        r"^(/cns|/cfs|/gcs|/hdfs|/readahead|/placer|/tfhub|.*://).*$",
        str(filepath),
    ):
        return True
    return False


# Below are gfile-replacement utils.


def _raise_if_no_gfile(path):
    """Always raise: remote `path` handling needs TensorFlow's gfile."""
    raise ValueError(
        "Handling remote paths requires installing TensorFlow "
        f"(in order to use gfile). Received path: {path}"
    )


def exists(path):
    """`os.path.exists` that routes remote paths through gfile."""
    if is_remote_path(path):
        if gfile.available:
            return gfile.exists(path)
        else:
            _raise_if_no_gfile(path)
    return os.path.exists(path)


def File(path, mode="r"):
    """Open `path` with `open`, or `gfile.GFile` for remote paths."""
    if is_remote_path(path):
        if gfile.available:
            return gfile.GFile(path, mode=mode)
        else:
            _raise_if_no_gfile(path)
    return open(path, mode=mode)


def join(path, *paths):
    """`os.path.join` that routes remote paths through gfile."""
    if is_remote_path(path):
        if gfile.available:
            return gfile.join(path, *paths)
        else:
            _raise_if_no_gfile(path)
    return os.path.join(path, *paths)


def isdir(path):
    """`os.path.isdir` that routes remote paths through gfile."""
    if is_remote_path(path):
        if gfile.available:
            return gfile.isdir(path)
        else:
            _raise_if_no_gfile(path)
    return os.path.isdir(path)


def remove(path):
    """`os.remove` that routes remote paths through gfile."""
    if is_remote_path(path):
        if gfile.available:
            return gfile.remove(path)
        else:
            _raise_if_no_gfile(path)
    return os.remove(path)


def rmtree(path):
    """`shutil.rmtree` that routes remote paths through gfile."""
    if is_remote_path(path):
        if gfile.available:
            return gfile.rmtree(path)
        else:
            _raise_if_no_gfile(path)
    return shutil.rmtree(path)


def listdir(path):
    """`os.listdir` that routes remote paths through gfile."""
    if is_remote_path(path):
        if gfile.available:
            return gfile.listdir(path)
        else:
            _raise_if_no_gfile(path)
    return os.listdir(path)


def copy(src, dst):
    """Copy `src` to `dst`; uses gfile (overwriting) if either is remote."""
    if is_remote_path(src) or is_remote_path(dst):
        if gfile.available:
            return gfile.copy(src, dst, overwrite=True)
        else:
            _raise_if_no_gfile(f"src={src} dst={dst}")
    return shutil.copy(src, dst)


def makedirs(path):
    """`os.makedirs` that routes remote paths through gfile."""
    if is_remote_path(path):
        if gfile.available:
            return gfile.makedirs(path)
        else:
            _raise_if_no_gfile(path)
    return os.makedirs(path)