| import os | |
| import sys | |
| try: | |
| from urllib import urlretrieve | |
| except ImportError: | |
| from urllib.request import urlretrieve | |
| import torch | |
def load_url(url, model_dir="./pretrained", map_location=torch.device("cpu")):
    """Download a serialized torch object once and load it from a local cache.

    The cache file name is the last ``/``-separated component of ``url``. If a
    file with that name already exists in ``model_dir`` it is loaded directly
    and nothing is downloaded.

    Args:
        url: Location of the serialized object; its last path component is
            used as the cache file name.
        model_dir: Directory holding cached downloads. Created if missing.
        map_location: Forwarded unchanged to ``torch.load``.

    Returns:
        Whatever ``torch.load`` returns for the cached file.

    Raises:
        ValueError: If ``url`` ends with ``/`` and so yields no file name.
        OSError: If ``model_dir`` cannot be created.
    """
    filename = url.split("/")[-1]
    if not filename:
        # Without a name, cached_file would be the directory itself and
        # torch.load would fail with a confusing error.
        raise ValueError(
            "Cannot derive a cache file name from URL: {!r}".format(url)
        )

    if not os.path.isdir(model_dir):
        try:
            os.makedirs(model_dir)
        except OSError:
            # Another process may have created the directory between the
            # check and the call; only re-raise if it is still missing.
            if not os.path.isdir(model_dir):
                raise

    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        # Download under a per-process temporary name so that an interrupted
        # or failed transfer never leaves a truncated file at the cache path,
        # where later calls would mistake it for a complete download.
        partial_file = "{}.{}.part".format(cached_file, os.getpid())
        # os.replace overwrites on every platform but does not exist on
        # Python 2, which the import fallback at the top of the file supports.
        move_into_place = getattr(os, "replace", os.rename)
        try:
            urlretrieve(url, partial_file)
            move_into_place(partial_file, cached_file)
        finally:
            # Still present only if the download or the move failed.
            if os.path.exists(partial_file):
                os.remove(partial_file)

    # NOTE(security): torch.load unpickles the file, and unpickling can run
    # arbitrary code. Only use this with URLs from trusted sources.
    return torch.load(cached_file, map_location=map_location)