| """
|
| Copyright (c) 2022, salesforce.com, inc.
|
| All rights reserved.
|
| SPDX-License-Identifier: BSD-3-Clause
|
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| """

import io
import json
import logging
import os
import pickle
import re
import shutil
import tarfile
import urllib
import urllib.error
import urllib.request
from typing import Optional
from urllib.parse import urlparse

import numpy as np
import pandas as pd
import yaml
from iopath.common.download import download
from iopath.common.file_io import file_lock, g_pathmgr

from lavis.common.dist_utils import download_cached_file
from lavis.common.registry import registry

from torch.utils.model_zoo import tqdm
from torchvision.datasets.utils import (
    check_integrity,
    download_file_from_google_drive,
    extract_archive,
)


def now():
    from datetime import datetime

    # "%Y%m%d%H%M" with the last digit dropped, i.e. a timestamp at
    # 10-minute resolution such as "20220131173".
    return datetime.now().strftime("%Y%m%d%H%M")[:-1]


def is_url(url_or_filename):
    """
    Check whether the input string is an http(s) URL.
    """
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")


def get_cache_path(rel_path):
    return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))


def get_abs_path(rel_path):
    return os.path.join(registry.get_path("library_root"), rel_path)
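
# Illustrative examples (paths are hypothetical: "cache_root" and
# "library_root" are registered by LAVIS during initialization, so the
# actual results depend on your configuration):
#     >>> get_cache_path("coco/annotations.json")
#     '/home/user/.cache/lavis/coco/annotations.json'
#     >>> get_abs_path("configs/default.yaml")
#     '/opt/LAVIS/lavis/configs/default.yaml'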


def load_json(filename):
    with open(filename, "r") as f:
        return json.load(f)


def makedir(dir_path):
    """
    Create the directory if it does not exist.
    """
    is_success = False
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
        is_success = True
    except BaseException:
        logging.info(f"Error creating directory: {dir_path}")
    return is_success


def get_redirected_url(url: str):
    """
    Given a URL, return the URL it redirects to,
    or the original URL if there is no redirection.
    """
    import requests

    with requests.Session() as session:
        with session.get(url, stream=True, allow_redirects=True) as response:
            if response.history:
                return response.url
            else:
                return url


def to_google_drive_download_url(view_url: str) -> str:
    """
    Utility function to transform a view URL of google drive
    to a download URL for google drive
    Example input:
        https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
    Example output:
        https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
    """
    splits = view_url.split("/")
    assert splits[-1] == "view"
    file_id = splits[-2]
    return f"https://drive.google.com/uc?export=download&id={file_id}"


def download_google_drive_url(url: str, output_path: str, output_file_name: str):
    """
    Download a file from google drive.
    Downloading a URL from google drive requires confirmation when
    the file is too big (google drive notifies that anti-viral checks
    cannot be performed on such files).
    """
    import requests

    with requests.Session() as session:

        # First request: google drive answers with a "download_warning" cookie
        # when the file is too large to be virus-scanned; echo it back via the
        # "confirm" query parameter.
        with session.get(url, stream=True, allow_redirects=True) as response:
            for k, v in response.cookies.items():
                if k.startswith("download_warning"):
                    url = url + "&confirm=" + v

        # Second request: stream the actual file to disk with a progress bar.
        with session.get(url, stream=True, verify=True) as response:
            makedir(output_path)
            path = os.path.join(output_path, output_file_name)
            total_size = int(response.headers.get("Content-length", 0))
            with open(path, "wb") as file:
                from tqdm import tqdm

                with tqdm(total=total_size) as progress_bar:
                    for block in response.iter_content(
                        chunk_size=io.DEFAULT_BUFFER_SIZE
                    ):
                        file.write(block)
                        progress_bar.update(len(block))
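
# Usage sketch (the file id is the one from the docstring example above;
# the output path and file name are hypothetical, and network access is
# required):
#     download_google_drive_url(
#         to_google_drive_download_url(
#             "https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view"
#         ),
#         output_path="/tmp/downloads",
#         output_file_name="checkpoint.pth",
#     )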


def _get_google_drive_file_id(url: str) -> Optional[str]:
    parts = urlparse(url)

    if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
        return None

    match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
    if match is None:
        return None

    return match.group("id")
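
# Examples (the first URL is the docstring example from
# to_google_drive_download_url above; the second is hypothetical):
#     >>> _get_google_drive_file_id(
#     ...     "https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view"
#     ... )
#     '137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp'
#     >>> _get_google_drive_file_id("https://example.com/file.zip") is None
#     True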


def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
    with open(filename, "wb") as fh:
        with urllib.request.urlopen(
            urllib.request.Request(url, headers={"User-Agent": "vissl"})
        ) as response:
            with tqdm(total=response.length) as pbar:
                # response.read() returns bytes, so the sentinel must be b"";
                # update the bar by the actual chunk length, since the final
                # chunk may be shorter than chunk_size.
                for chunk in iter(lambda: response.read(chunk_size), b""):
                    pbar.update(len(chunk))
                    fh.write(chunk)


def download_url(
    url: str,
    root: str,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
) -> None:
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under.
            If None, use the basename of the URL.
        md5 (str, optional): MD5 checksum of the download. If None, do not check.
    """
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    makedir(root)

    # Check if the file is already present locally and verified.
    if check_integrity(fpath, md5):
        print("Using downloaded and verified file: " + fpath)
        return

    # Expand any redirection first.
    url = get_redirected_url(url)

    # If the file is hosted on Google Drive, use the dedicated downloader.
    file_id = _get_google_drive_file_id(url)
    if file_id is not None:
        return download_file_from_google_drive(file_id, root, filename, md5)

    # Download the file, falling back from https to http on failure.
    try:
        print("Downloading " + url + " to " + fpath)
        _urlretrieve(url, fpath)
    except (urllib.error.URLError, IOError) as e:
        if url[:5] == "https":
            url = url.replace("https:", "http:")
            print(
                "Failed download. Trying https -> http instead."
                " Downloading " + url + " to " + fpath
            )
            _urlretrieve(url, fpath)
        else:
            raise e

    # Verify integrity of the downloaded file.
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")
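
# Usage sketch (hypothetical URL; requires network access). Google Drive
# links are detected and rerouted automatically, and an optional md5 hex
# digest verifies the result:
#     download_url(
#         "https://example.com/datasets/annotations.json",
#         root="/tmp/lavis_downloads",
#         filename="annotations.json",
#         md5=None,
#     )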


def download_and_extract_archive(
    url: str,
    download_root: str,
    extract_root: Optional[str] = None,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print("Extracting {} to {}".format(archive, extract_root))
    extract_archive(archive, extract_root, remove_finished)
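
# Usage sketch (hypothetical URL; the archive format is inferred from the
# file extension by torchvision's extract_archive):
#     download_and_extract_archive(
#         "https://example.com/datasets/coco_subset.tar.gz",
#         download_root="/tmp/lavis_downloads",
#         extract_root="/tmp/lavis_datasets",
#         remove_finished=True,  # delete the archive once extracted
#     )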


def cache_url(url: str, cache_dir: str) -> str:
    """
    This implementation downloads the remote resource and caches it locally.
    The resource will only be downloaded if not previously requested.
    """
    parsed_url = urlparse(url)
    dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
    makedir(dirname)
    filename = url.split("/")[-1]
    cached = os.path.join(dirname, filename)
    with file_lock(cached):
        if not os.path.isfile(cached):
            logging.info(f"Downloading {url} to {cached} ...")
            cached = download(url, dirname, filename=filename)
    logging.info(f"URL {url} cached in {cached}")
    return cached
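
# Usage sketch (hypothetical URL): the cached path mirrors the URL path under
# cache_dir, and file_lock makes concurrent callers safe, so repeated calls
# are cheap:
#     local_path = cache_url(
#         "https://example.com/models/weights.pth", cache_dir="/tmp/lavis_cache"
#     )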


def create_file_symlink(file1, file2):
    """
    Simply create a symlink from file1 to file2.
    Useful during model checkpointing to create a symlink to the
    latest successful checkpoint.
    """
    try:
        if g_pathmgr.exists(file2):
            g_pathmgr.rm(file2)
        g_pathmgr.symlink(file1, file2)
    except Exception as e:
        logging.info(f"Could NOT create symlink. Error: {e}")


def save_file(data, filename, append_to_json=True, verbose=True):
    """
    Common i/o utility to handle saving data to various file formats.
    Supported:
        .pkl, .pickle, .npy, .json, .yaml
    Specifically for .json, users have the option to either append (default)
    or rewrite by passing a Boolean value to append_to_json.
    """
    if verbose:
        logging.info(f"Saving data to file: {filename}")
    file_ext = os.path.splitext(filename)[1]
    if file_ext in [".pkl", ".pickle"]:
        with g_pathmgr.open(filename, "wb") as fopen:
            pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
    elif file_ext == ".npy":
        with g_pathmgr.open(filename, "wb") as fopen:
            np.save(fopen, data)
    elif file_ext == ".json":
        if append_to_json:
            with g_pathmgr.open(filename, "a") as fopen:
                fopen.write(json.dumps(data, sort_keys=True) + "\n")
                fopen.flush()
        else:
            with g_pathmgr.open(filename, "w") as fopen:
                fopen.write(json.dumps(data, sort_keys=True) + "\n")
                fopen.flush()
    elif file_ext == ".yaml":
        with g_pathmgr.open(filename, "w") as fopen:
            dump = yaml.dump(data)
            fopen.write(dump)
            fopen.flush()
    else:
        raise Exception(f"Saving {file_ext} is not supported yet")

    if verbose:
        logging.info(f"Saved data to file: {filename}")
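
# Round-trip sketch (hypothetical paths; the format is chosen from the file
# extension, see also load_file below):
#     save_file({"accuracy": 0.91}, "/tmp/metrics.json", append_to_json=False)
#     save_file(np.arange(10), "/tmp/array.npy")
#     metrics = load_file("/tmp/metrics.json")  # -> {"accuracy": 0.91}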


def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
    """
    Common i/o utility to handle loading data from various file formats.
    Supported:
        .txt, .pkl, .pickle, .npy, .json, .yaml, .csv
    For .npy files, we support reading in mmap_mode.
    If reading in mmap_mode fails, we load the data without mmap_mode.
    """
    if verbose:
        logging.info(f"Loading data from file: {filename}")

    file_ext = os.path.splitext(filename)[1]
    if file_ext == ".txt":
        with g_pathmgr.open(filename, "r") as fopen:
            data = fopen.readlines()
    elif file_ext in [".pkl", ".pickle"]:
        with g_pathmgr.open(filename, "rb") as fopen:
            data = pickle.load(fopen, encoding="latin1")
    elif file_ext == ".npy":
        if mmap_mode:
            try:
                with g_pathmgr.open(filename, "rb") as fopen:
                    data = np.load(
                        fopen,
                        allow_pickle=allow_pickle,
                        encoding="latin1",
                        mmap_mode=mmap_mode,
                    )
            except ValueError as e:
                logging.info(
                    f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
                )
                data = np.load(
                    filename,
                    allow_pickle=allow_pickle,
                    encoding="latin1",
                    mmap_mode=mmap_mode,
                )
                logging.info("Successfully loaded without g_pathmgr")
            except Exception:
                logging.info("Could not mmap without g_pathmgr. Trying without mmap")
                with g_pathmgr.open(filename, "rb") as fopen:
                    data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
        else:
            with g_pathmgr.open(filename, "rb") as fopen:
                data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
    elif file_ext == ".json":
        with g_pathmgr.open(filename, "r") as fopen:
            data = json.load(fopen)
    elif file_ext == ".yaml":
        with g_pathmgr.open(filename, "r") as fopen:
            data = yaml.load(fopen, Loader=yaml.FullLoader)
    elif file_ext == ".csv":
        with g_pathmgr.open(filename, "r") as fopen:
            data = pd.read_csv(fopen)
    else:
        raise Exception(f"Reading from {file_ext} is not supported yet")
    return data
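
# mmap sketch (hypothetical path): with mmap_mode="r" a large .npy array is
# paged in lazily from disk instead of being read into memory; load_file
# falls back to a regular load if memory-mapping fails:
#     embeddings = load_file("/tmp/embeddings.npy", mmap_mode="r")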


def abspath(resource_path: str):
    """
    Make a path absolute, but take into account prefixes like
    "http://" or "manifold://"
    """
    regex = re.compile(r"^\w+://")
    if regex.match(resource_path) is None:
        return os.path.abspath(resource_path)
    else:
        return resource_path
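
# Examples (the absolute result depends on the current working directory):
#     >>> abspath("configs/model.yaml")        # no scheme prefix -> absolute
#     '/home/user/project/configs/model.yaml'
#     >>> abspath("manifold://bucket/x.json")  # scheme prefix is left as-is
#     'manifold://bucket/x.json'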


def download_and_untar(url):
    cached_file = download_cached_file(url, check_hash=False, progress=True)

    untarred_dir = os.path.basename(url).split(".")[0]
    parent_dir = os.path.dirname(cached_file)

    full_dir = os.path.join(parent_dir, untarred_dir)

    if not os.path.exists(full_dir):
        with tarfile.open(cached_file) as tar:
            tar.extractall(parent_dir)

    return full_dir
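
# Usage sketch (hypothetical URL): "vg.tar.gz" is cached, then unpacked next
# to the cached file into a directory named after the archive ("vg"):
#     data_dir = download_and_untar("https://example.com/datasets/vg.tar.gz")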


def cleanup_dir(dir):
    """
    Utility for deleting a directory. Useful for cleaning the storage space
    that contains various training artifacts like checkpoints, data etc.
    """
    if os.path.exists(dir):
        logging.info(f"Deleting directory: {dir}")
        shutil.rmtree(dir)
    logging.info(f"Deleted contents of directory: {dir}")


def get_file_size(filename):
    """
    Given a file, return its size in MB.
    """
    size_in_mb = os.path.getsize(filename) / float(1024**2)
    return size_in_mb
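
# Example (hypothetical file; the value shown is illustrative):
#     >>> round(get_file_size("/tmp/model.pth"), 1)
#     412.5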