| """ |
| Utilities for working with the local dataset cache. |
| This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp |
| Copyright by the AllenNLP authors. |
| """ |
|
|
| import fnmatch |
| import json |
| import logging |
| import os |
| import shutil |
| import sys |
| import tarfile |
| import tempfile |
| from contextlib import contextmanager |
| from functools import partial, wraps |
| from hashlib import sha256 |
| from typing import Optional |
| from urllib.parse import urlparse |
| from zipfile import ZipFile, is_zipfile |
|
|
| import boto3 |
| import requests |
| from botocore.config import Config |
| from botocore.exceptions import ClientError |
| from filelock import FileLock |
| from tqdm.auto import tqdm |
|
|
| from . import __version__ |
|
|
|
|
| logger = logging.getLogger(__name__) |
|
|
try:
    USE_TF = os.environ.get("USE_TF", "AUTO").upper()
    USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
    if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
        import torch

        _torch_available = True
        logger.info("PyTorch version {} available.".format(torch.__version__))
    else:
        logger.info("Disabling PyTorch because USE_TF is set")
        _torch_available = False
except ImportError:
    _torch_available = False


try:
    USE_TF = os.environ.get("USE_TF", "AUTO").upper()
    USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()

    if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
        import tensorflow as tf

        assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
        _tf_available = True
        logger.info("TensorFlow version {} available.".format(tf.__version__))
    else:
        logger.info("Disabling TensorFlow because USE_TORCH is set")
        _tf_available = False
except (ImportError, AssertionError):
    _tf_available = False


try:
    from torch.hub import _get_torch_home

    torch_cache_home = _get_torch_home()
except ImportError:
    torch_cache_home = os.path.expanduser(
        os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
    )
default_cache_path = os.path.join(torch_cache_home, "transformers")

try:
    from pathlib import Path

    PYTORCH_PRETRAINED_BERT_CACHE = Path(
        os.getenv("PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path))
    )
except (AttributeError, ImportError):
    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv(
        "PYTORCH_TRANSFORMERS_CACHE", os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
    )

PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE
TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE


WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF_WEIGHTS_NAME = "model.ckpt"
CONFIG_NAME = "config.json"
MODEL_CARD_NAME = "modelcard.json"

MULTIPLE_CHOICE_DUMMY_INPUTS = [[[0], [1]], [[0], [1]]]
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]

S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://d2ws9o8vfrpkyk.cloudfront.net"


def is_torch_available():
    return _torch_available


def is_tf_available():
    return _tf_available


def add_start_docstrings(*docstr):
    def docstring_decorator(fn):
        fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
        return fn

    return docstring_decorator


def add_start_docstrings_to_callable(*docstr):
    def docstring_decorator(fn):
        class_name = ":class:`~transformers.{}`".format(fn.__qualname__.split(".")[0])
        intro = " The {} forward method, overrides the :func:`__call__` special method.".format(class_name)
        note = r"""

    .. note::
        Although the recipe for the forward pass needs to be defined within
        this function, one should call the :class:`Module` instance afterwards
        instead of this, since the former takes care of running the
        pre- and post-processing steps while the latter silently ignores them.
    """
        fn.__doc__ = intro + note + "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
        return fn

    return docstring_decorator


def add_end_docstrings(*docstr):
    def docstring_decorator(fn):
        # Mirror add_start_docstrings: guard against fn.__doc__ being None.
        fn.__doc__ = (fn.__doc__ if fn.__doc__ is not None else "") + "".join(docstr)
        return fn

    return docstring_decorator


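# Illustrative sketch of how the docstring decorators above compose; the function
# below is hypothetical and not part of this module:
#
#     @add_start_docstrings("Shared intro.\n")
#     @add_end_docstrings("\nShared outro.")
#     def hypothetical_fn(x):
#         """Function-specific docstring."""
#         return x
#
#     # hypothetical_fn.__doc__ == "Shared intro.\nFunction-specific docstring.\nShared outro."

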
def is_remote_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https", "s3")


def hf_bucket_url(identifier, postfix=None, cdn=False) -> str:
    endpoint = CLOUDFRONT_DISTRIB_PREFIX if cdn else S3_BUCKET_PREFIX
    if postfix is None:
        return "/".join((endpoint, identifier))
    else:
        return "/".join((endpoint, identifier, postfix))


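# For example (sketch; the identifier is hypothetical):
#
#     hf_bucket_url("some-model", postfix="config.json")
#     # -> "https://s3.amazonaws.com/models.huggingface.co/bert/some-model/config.json"
#     hf_bucket_url("some-model", postfix="config.json", cdn=True)
#     # -> "https://d2ws9o8vfrpkyk.cloudfront.net/some-model/config.json"

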
def url_to_filename(url, etag=None):
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the url's, delimited
    by a period.
    If the url ends with .h5 (Keras HDF5 weights), '.h5' is appended to the name
    so that TF 2.0 can identify it as an HDF5 file
    (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
    """
    url_bytes = url.encode("utf-8")
    url_hash = sha256(url_bytes)
    filename = url_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode("utf-8")
        etag_hash = sha256(etag_bytes)
        filename += "." + etag_hash.hexdigest()

    if url.endswith(".h5"):
        filename += ".h5"

    return filename


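# For example (sketch; URL and etag are hypothetical):
#
#     name = url_to_filename("https://example.com/weights.h5", etag='"abc"')
#     # name == sha256(url bytes).hexdigest() + "." + sha256(etag bytes).hexdigest() + ".h5"

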
def filename_to_url(filename, cache_dir=None):
    """
    Return the url and etag (which may be ``None``) stored for `filename`.
    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise EnvironmentError("file {} not found".format(cache_path))

    meta_path = cache_path + ".json"
    if not os.path.exists(meta_path):
        raise EnvironmentError("file {} not found".format(meta_path))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    url = metadata["url"]
    etag = metadata["etag"]

    return url, etag


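# For example (sketch; assumes the file and its .json metadata were previously
# written by get_from_cache(), and the URL and etag are hypothetical):
#
#     fname = url_to_filename("https://example.com/weights.bin", etag='"abc"')
#     url, etag = filename_to_url(fname)
#     # url == "https://example.com/weights.bin", etag == '"abc"'

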
def cached_path(
    url_or_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    user_agent=None,
    extract_compressed_file=False,
    force_extract=False,
) -> Optional[str]:
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.
    Args:
        cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
        force_download: if True, re-download the file even if it's already cached in the cache dir.
        resume_download: if True, resume the download if an incompletely received file is found.
        user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
        extract_compressed_file: if True and the path points to a zip or tar file, extract the compressed
            file in a folder alongside the archive.
        force_extract: if True, when extract_compressed_file is True and the archive was already extracted,
            re-extract the archive and override the folder where it was extracted.

    Return:
        None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
        Local path (string) otherwise
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if is_remote_url(url_or_filename):
        # URL, so get from the cache (downloading if necessary).
        output_path = get_from_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            user_agent=user_agent,
        )
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        output_path = url_or_filename
    elif urlparse(url_or_filename).scheme == "":
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Something unknown: neither a remote URL nor a local path.
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))

    if extract_compressed_file:
        if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
            return output_path

        # Path where we extract compressed archives:
        # for "./model.zip" it will be "./model-zip-extracted/".
        output_dir, output_file = os.path.split(output_path)
        output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
        output_path_extracted = os.path.join(output_dir, output_extract_dir_name)

        if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
            return output_path_extracted

        # Prevent parallel extractions with a file lock.
        lock_path = output_path + ".lock"
        with FileLock(lock_path):
            shutil.rmtree(output_path_extracted, ignore_errors=True)
            os.makedirs(output_path_extracted)
            if is_zipfile(output_path):
                with ZipFile(output_path, "r") as zip_file:
                    zip_file.extractall(output_path_extracted)
            elif tarfile.is_tarfile(output_path):
                tar_file = tarfile.open(output_path)
                tar_file.extractall(output_path_extracted)
                tar_file.close()
            else:
                raise EnvironmentError("Archive format of {} could not be identified".format(output_path))

        return output_path_extracted

    return output_path


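# Example usage (sketch; the URL is hypothetical):
#
#     path = cached_path("https://example.com/archive.tar.gz", extract_compressed_file=True)
#     # `path` points to a "<hashed-name>-extracted/" directory inside the cache dir,
#     # containing the unpacked archive.

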
def split_s3_path(url):
    """Split a full s3 path into the bucket name and path."""
    parsed = urlparse(url)
    if not parsed.netloc or not parsed.path:
        raise ValueError("bad s3 path {}".format(url))
    bucket_name = parsed.netloc
    s3_path = parsed.path
    # Remove the '/' at the beginning of the path.
    if s3_path.startswith("/"):
        s3_path = s3_path[1:]
    return bucket_name, s3_path


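# For example (the bucket name is hypothetical):
#
#     split_s3_path("s3://some-bucket/path/to/file.bin")
#     # -> ("some-bucket", "path/to/file.bin")

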
def s3_request(func):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.
    """

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            if int(exc.response["Error"]["Code"]) == 404:
                raise EnvironmentError("file {} not found".format(url))
            else:
                raise

    return wrapper


@s3_request
def s3_etag(url, proxies=None):
    """Check ETag on S3 object."""
    s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
    bucket_name, s3_path = split_s3_path(url)
    s3_object = s3_resource.Object(bucket_name, s3_path)
    return s3_object.e_tag


@s3_request
def s3_get(url, temp_file, proxies=None):
    """Pull a file directly from S3."""
    s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
    bucket_name, s3_path = split_s3_path(url)
    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)


def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
    # Build a descriptive user-agent string from the library and framework versions.
    ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
    if is_torch_available():
        ua += "; torch/{}".format(torch.__version__)
    if is_tf_available():
        ua += "; tensorflow/{}".format(tf.__version__)
    if isinstance(user_agent, dict):
        ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
    elif isinstance(user_agent, str):
        ua += "; " + user_agent
    headers = {"user-agent": ua}
    if resume_size > 0:
        headers["Range"] = "bytes=%d-" % (resume_size,)
    response = requests.get(url, stream=True, proxies=proxies, headers=headers)
    if response.status_code == 416:  # Range Not Satisfiable: the file is already complete.
        return
    content_length = response.headers.get("Content-Length")
    total = resume_size + int(content_length) if content_length is not None else None
    progress = tqdm(
        unit="B",
        unit_scale=True,
        total=total,
        initial=resume_size,
        desc="Downloading",
        disable=bool(logger.getEffectiveLevel() == logging.NOTSET),
    )
    for chunk in response.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()


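# Example usage (sketch; the URL is hypothetical):
#
#     with open("/tmp/weights.bin", "wb") as f:
#         http_get("https://example.com/weights.bin", f)
#
# Resumption works through the HTTP Range header: with resume_size=1024 the request
# carries "Range: bytes=1024-", and a 416 response means nothing is left to fetch.

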
def get_from_cache(
    url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None
) -> Optional[str]:
    """
    Given a URL, look for the corresponding file in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    Return:
        None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
        Local path (string) otherwise
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    # Get the ETag to add to the filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url, proxies=proxies)
    else:
        try:
            response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
            if response.status_code != 200:
                etag = None
            else:
                etag = response.headers.get("ETag")
        except (EnvironmentError, requests.exceptions.Timeout):
            etag = None

    filename = url_to_filename(url, etag)

    # Get the cache path to put the file in.
    cache_path = os.path.join(cache_dir, filename)

    # etag is None == we don't have a connection, or the url doesn't exist, or is
    # otherwise inaccessible. Try to find the last downloaded one.
    if etag is None:
        if os.path.exists(cache_path):
            return cache_path
        else:
            matching_files = [
                file
                for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
                if not file.endswith(".json") and not file.endswith(".lock")
            ]
            if len(matching_files) > 0:
                return os.path.join(cache_dir, matching_files[-1])
            else:
                return None

    # From now on, etag is not None.
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager():
                with open(incomplete_path, "a+b") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
            resume_size = 0

        # Download to a temporary file, then move it to the cache path once finished.
        # Otherwise, you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                if resume_download:
                    logger.warning('Resumable downloads are not implemented for "s3://" urls')
                s3_get(url, temp_file, proxies=proxies)
            else:
                http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)

        logger.info("storing %s in cache at %s", url, cache_path)
        os.rename(temp_file.name, cache_path)

        logger.info("creating metadata file for %s", cache_path)
        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
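
# Example usage (sketch; the URL is hypothetical):
#
#     path = get_from_cache("https://example.com/weights.bin", resume_download=True)
#     # If a partial "<hash>.incomplete" file exists, the download resumes from it
#     # instead of starting over.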