|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from ..path import mkdir_or_exist |
|
|
from ..version_utils import digit_version |
|
|
from .parrots_wrapper import TORCH_VERSION |
|
|
|
|
|
if TORCH_VERSION != 'parrots' and digit_version(TORCH_VERSION) < digit_version(
        '1.7.0'):
    # For non-parrots torch older than 1.7 this module defines its own
    # ``load_url`` (below); otherwise it re-exports
    # ``torch.utils.model_zoo.load_url`` (see the ``else`` branch).
    import os
    import sys
    import warnings
    import zipfile
    from urllib.parse import urlparse

    import torch
    from torch.hub import HASH_REGEX, _get_torch_home, download_url_to_file

    def _is_legacy_zip_format(filename):
        """Return True if ``filename`` is a zip archive with exactly one
        non-directory member.

        Such archives are treated as the legacy zipped-checkpoint layout and
        are loaded through :func:`_legacy_zip_load`.

        Args:
            filename (str): Path of the file to inspect.

        Returns:
            bool: Whether the file is a single-file zip archive.
        """
        if not zipfile.is_zipfile(filename):
            return False
        # Close the archive handle promptly instead of leaking it.
        with zipfile.ZipFile(filename) as archive:
            infolist = archive.infolist()
        return len(infolist) == 1 and not infolist[0].is_dir()

    def _legacy_zip_load(filename, model_dir, map_location):
        """Extract a single-file zip checkpoint into ``model_dir`` and load it.

        Args:
            filename (str): Path of the zip archive.
            model_dir (str): Directory the archive member is extracted into.
            map_location: Forwarded unchanged to :func:`torch.load`.

        Returns:
            The object deserialized by :func:`torch.load` from the extracted
            file.

        Raises:
            RuntimeError: If the archive does not contain exactly one
                non-directory member.
        """
        warnings.warn(
            'Falling back to the old format < 1.6. This support will'
            ' be deprecated in favor of default zipfile format '
            'introduced in 1.6. Please redo torch.save() to save it '
            'in the new zipfile format.', DeprecationWarning)
        with zipfile.ZipFile(filename) as f:
            members = f.infolist()
            # Enforce what the error message promises: one file, not a dir.
            if len(members) != 1 or members[0].is_dir():
                raise RuntimeError(
                    'Only one file(not dir) is allowed in the zipfile')
            # ``extract`` returns the path it actually wrote. zipfile
            # sanitises member names (e.g. strips leading slashes and ``..``
            # components), so joining the raw member name onto ``model_dir``
            # could point somewhere else. An existing file is overwritten.
            extracted_file = f.extract(members[0], model_dir)
        return torch.load(extracted_file, map_location=map_location)

    def load_url(url,
                 model_dir=None,
                 map_location=None,
                 progress=True,
                 check_hash=False,
                 file_name=None):
        r"""Loads the Torch serialized object at the given URL.

        If the downloaded file is a legacy zip archive containing a single
        file, that file is extracted into ``model_dir`` and loaded. Any other
        file is passed to :func:`torch.load` directly.

        If the object is already present in ``model_dir``, it's deserialized
        and returned without downloading.
        The default value of ``model_dir`` is ``<torch_home>/checkpoints``
        where ``torch_home`` is the directory returned by
        ``torch.hub._get_torch_home()``.

        Args:
            url (str): URL of the object to download
            model_dir (str, optional): directory in which to save the object
            map_location (optional): a function or a dict specifying how to
                remap storage locations (see torch.load)
            progress (bool, optional): whether or not to display a progress bar
                to stderr. Defaults to True
            check_hash (bool, optional): If True, the filename part of the URL
                should follow the naming convention ``filename-<sha256>.ext``
                where ``<sha256>`` is the first eight or more digits of the
                SHA256 hash of the contents of the file. The hash is used to
                ensure unique names and to verify the contents of the file.
                If the filename does not match that convention, no hash
                check is performed. Defaults to False
            file_name (str, optional): name for the downloaded file. Filename
                from ``url`` will be used if not set. Defaults to None.

        Returns:
            The object deserialized by :func:`torch.load`.

        Raises:
            RuntimeError: Re-raised from :func:`torch.load` when
                deserialization fails.

        Example:
            >>> url = ('https://s3.amazonaws.com/pytorch/models/resnet18-5c106'
            ...        'cde.pth')
            >>> state_dict = load_url(url)
        """
        # The variable is no longer consulted; only warn users who still
        # set it.
        if os.getenv('TORCH_MODEL_ZOO'):
            warnings.warn(
                'TORCH_MODEL_ZOO is deprecated, please use env '
                'TORCH_HOME instead', DeprecationWarning)

        if model_dir is None:
            torch_home = _get_torch_home()
            model_dir = os.path.join(torch_home, 'checkpoints')

        mkdir_or_exist(model_dir)

        parts = urlparse(url)
        filename = os.path.basename(parts.path)
        if file_name is not None:
            filename = file_name
        cached_file = os.path.join(model_dir, filename)
        if not os.path.exists(cached_file):
            sys.stderr.write('Downloading: "{}" to {}\n'.format(
                url, cached_file))
            hash_prefix = None
            if check_hash:
                r = HASH_REGEX.search(filename)
                hash_prefix = r.group(1) if r else None
            download_url_to_file(
                url, cached_file, hash_prefix, progress=progress)

        if _is_legacy_zip_format(cached_file):
            return _legacy_zip_load(cached_file, model_dir, map_location)

        try:
            return torch.load(cached_file, map_location=map_location)
        except RuntimeError:
            if digit_version(TORCH_VERSION) < digit_version('1.5.0'):
                warnings.warn(
                    f'If the error is the same as "{cached_file} is a zip '
                    'archive (did you mean to use torch.jit.load()?)", you can'
                    ' upgrade your torch to 1.5.0 or higher (current torch '
                    f'version is {TORCH_VERSION}). The error was raised '
                    'because the checkpoint was saved in torch>=1.6.0 but '
                    'loaded in torch<1.5.')
            # Bare ``raise`` re-raises the original exception untouched.
            raise
else:
    # parrots, or torch >= 1.7: use the upstream implementation.
    from torch.utils.model_zoo import load_url
|
|
|