|
|
import concurrent.futures |
|
|
import datetime |
|
|
import logging |
|
|
import os |
|
|
import pathlib |
|
|
import re |
|
|
import shutil |
|
|
import stat |
|
|
import time |
|
|
import urllib.parse |
|
|
|
|
|
import filelock |
|
|
import fsspec |
|
|
import fsspec.generic |
|
|
import tqdm_loggable.auto as tqdm |
|
|
|
|
|
|
|
|
# Name of the environment variable that overrides the cache location.
_OPENPI_DATA_HOME = "OPENPI_DATA_HOME"

# Fallback cache directory used when the env var above is not set.
DEFAULT_CACHE_DIR = "~/.cache/openpi"

# Module-level logger, per standard logging convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
def get_cache_dir() -> pathlib.Path:
    """Return the local cache directory, creating it if needed.

    The location comes from the OPENPI_DATA_HOME environment variable when set,
    otherwise from DEFAULT_CACHE_DIR. The returned path is absolute, exists, and
    is made world read/write/searchable (the cache is shared across runtimes).
    """
    raw_location = os.getenv(_OPENPI_DATA_HOME, DEFAULT_CACHE_DIR)
    cache_root = pathlib.Path(raw_location).expanduser().resolve()
    cache_root.mkdir(parents=True, exist_ok=True)
    # Open up permissions so containerized runtimes can share the directory.
    _set_folder_permission(cache_root)
    return cache_root
|
|
|
|
|
|
|
|
def maybe_download(url: str, *, force_download: bool = False, **kwargs) -> pathlib.Path:
    """Download a file or directory from a remote filesystem to the local cache, and return the local path.

    If the local file already exists, it will be returned directly.

    It is safe to call this function concurrently from multiple processes.
    See `get_cache_dir` for more details on the cache directory.

    Args:
        url: URL to the file to download.
        force_download: If True, the file will be downloaded even if it already exists in the cache.
        **kwargs: Additional arguments to pass to fsspec.

    Returns:
        Local path to the downloaded file or directory. That path is guaranteed to exist and is absolute.

    Raises:
        FileNotFoundError: If `url` has no scheme and the local path does not exist.
        PermissionError: If cached data cannot be modified; the message suggests how to recover.
    """
    parsed = urllib.parse.urlparse(url)

    # Short circuit for local paths: nothing to download, just validate existence.
    if parsed.scheme == "":
        path = pathlib.Path(url)
        if not path.exists():
            raise FileNotFoundError(f"File not found at {url}")
        return path.resolve()

    cache_dir = get_cache_dir()

    local_path = cache_dir / parsed.netloc / parsed.path.strip("/")
    local_path = local_path.resolve()

    # Fast, unlocked check; the definitive check is repeated under the lock below.
    invalidate_cache = False
    if local_path.exists():
        if force_download or _should_invalidate_cache(cache_dir, local_path):
            invalidate_cache = True
        else:
            return local_path

    try:
        # Append (rather than substitute) the suffix so that distinct entries like
        # "x.json" and "x.bin" never share the same lock or scratch path.
        lock_path = local_path.with_name(local_path.name + ".lock")
        # The parent may not exist yet on a first download; FileLock needs it to.
        lock_path.parent.mkdir(parents=True, exist_ok=True)
        with filelock.FileLock(lock_path):
            _ensure_permissions(lock_path)

            # Re-check under the lock: another process may have finished the
            # download (or the removal) while we were waiting to acquire it.
            if local_path.exists():
                if force_download or invalidate_cache or _should_invalidate_cache(cache_dir, local_path):
                    logger.info(f"Removing expired cached entry: {local_path}")
                    if local_path.is_dir():
                        shutil.rmtree(local_path)
                    else:
                        local_path.unlink()
                else:
                    return local_path

            logger.info(f"Downloading {url} to {local_path}")
            scratch_path = local_path.with_name(local_path.name + ".partial")
            # Clear leftovers from a previously interrupted download so the
            # transfer cannot nest into or append to stale data.
            if scratch_path.is_dir():
                shutil.rmtree(scratch_path)
            elif scratch_path.exists():
                scratch_path.unlink()
            _download_fsspec(url, scratch_path, **kwargs)

            # Move into place only after a fully successful download, so the
            # final path never holds a partially-written entry.
            shutil.move(scratch_path, local_path)
            _ensure_permissions(local_path)

    except PermissionError as e:
        msg = (
            f"Local file permission error was encountered while downloading {url}. "
            f"Please try again after removing the cached data using: `rm -rf {local_path}*`"
        )
        raise PermissionError(msg) from e

    return local_path
|
|
|
|
|
|
|
|
def _download_fsspec(url: str, local_path: pathlib.Path, **kwargs) -> None:
    """Download a file or directory from a remote filesystem to `local_path`, with a progress bar.

    Directories are downloaded recursively. Progress is estimated by polling the
    size of the data written so far while the transfer runs on a worker thread.

    Args:
        url: Remote URL understood by fsspec.
        local_path: Destination path for the downloaded data.
        **kwargs: Additional arguments forwarded to `fsspec.core.url_to_fs`.

    Raises:
        Whatever the underlying filesystem raises if the transfer fails.
    """
    fs, _ = fsspec.core.url_to_fs(url, **kwargs)
    info = fs.info(url)

    # Some filesystems report directories as zero-size entries whose name ends in "/".
    if is_dir := (info["type"] == "directory" or (info["size"] == 0 and info["name"].endswith("/"))):
        total_size = fs.du(url)
    else:
        total_size = info["size"]
    with tqdm.tqdm(total=total_size, unit="iB", unit_scale=True, unit_divisor=1024) as pbar:
        # Run the blocking transfer on a worker thread so this thread can poll progress.
        # The context manager guarantees the executor is shut down (original leaked it).
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(fs.get, url, local_path, recursive=is_dir)
            while not future.done():
                current_size = sum(f.stat().st_size for f in [*local_path.rglob("*"), local_path] if f.is_file())
                pbar.update(current_size - pbar.n)
                time.sleep(1)
        # Propagate any error raised by the transfer; the original silently
        # discarded it, leaving callers with a missing scratch path instead.
        future.result()
        pbar.update(total_size - pbar.n)
|
|
|
|
|
|
|
|
def _set_permission(path: pathlib.Path, target_permission: int):
    """chmod requires executable permission to be set, so we skip if the permission is already match with the target."""
    current_mode = path.stat().st_mode
    already_granted = (current_mode & target_permission) == target_permission
    if already_granted:
        logger.debug(f"Skipping {path} because it already has correct permissions")
    else:
        path.chmod(target_permission)
        logger.debug(f"Set {path} to {target_permission}")
|
|
|
|
|
|
|
|
def _set_folder_permission(folder_path: pathlib.Path) -> None:
    """Set folder permission to be read, write and searchable."""
    # Equivalent to 0o777: full access for user, group, and other.
    full_access = stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO
    _set_permission(folder_path, full_access)
|
|
|
|
|
|
|
|
def _ensure_permissions(path: pathlib.Path) -> None:
    """Since we are sharing cache directory with containerized runtime as well as training script, we need to
    ensure that the cache directory has the correct permissions.
    """

    def _open_dirs_from_cache_root(target: pathlib.Path) -> None:
        # Make every directory level between the cache root and `target`
        # traversable, walking top-down so each level can be entered.
        current = get_cache_dir()
        for segment in target.relative_to(current).parts:
            current = current / segment
            _set_folder_permission(current)

    def _open_file(target: pathlib.Path) -> None:
        # Files become read/write for everyone; executables keep their exec bits.
        rw_for_all = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IWGRP | stat.S_IROTH | stat.S_IWOTH
        if target.stat().st_mode & 0o100:
            _set_permission(target, rw_for_all | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
        else:
            _set_permission(target, rw_for_all)

    _open_dirs_from_cache_root(path)
    # os.walk on a plain file yields nothing, so the loop below is a no-op then.
    for walk_root, subdir_names, file_names in os.walk(str(path)):
        base = pathlib.Path(walk_root)
        for file_name in file_names:
            _open_file(base / file_name)

        for subdir_name in subdir_names:
            _set_folder_permission(base / subdir_name)
|
|
|
|
|
|
|
|
def _get_mtime(year: int, month: int, day: int) -> float: |
|
|
"""Get the mtime of a given date at midnight UTC.""" |
|
|
date = datetime.datetime(year, month, day, tzinfo=datetime.UTC) |
|
|
return time.mktime(date.timetuple()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Cache-invalidation table: maps a regex over the entry's path (relative to the
# cache dir) to an expiry mtime. Entries whose mtime is at or before the value
# are considered stale and re-downloaded (see `_should_invalidate_cache`).
# Order matters: the first matching pattern wins, so more specific checkpoint
# paths are listed before the catch-all "openpi-assets/checkpoints/" entry.
_INVALIDATE_CACHE_DIRS: dict[re.Pattern, float] = {
    re.compile("openpi-assets/checkpoints/pi0_aloha_pen_uncap"): _get_mtime(2025, 2, 17),
    re.compile("openpi-assets/checkpoints/pi0_libero"): _get_mtime(2025, 2, 6),
    re.compile("openpi-assets/checkpoints/"): _get_mtime(2025, 2, 3),
}
|
|
|
|
|
|
|
|
def _should_invalidate_cache(cache_dir: pathlib.Path, local_path: pathlib.Path) -> bool:
    """Report whether an existing cache entry has expired and should be re-downloaded.

    Args:
        cache_dir: Root of the cache; `local_path` must live inside it.
        local_path: Existing cached file or directory to check.

    Returns:
        True when the entry's relative path matches an invalidation pattern and
        its mtime is at or before that pattern's expiry time.

    Raises:
        FileNotFoundError: If `local_path` does not exist.
    """
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not local_path.exists():
        raise FileNotFoundError(f"File not found at {local_path}")

    relative_path = str(local_path.relative_to(cache_dir))
    for pattern, expire_time in _INVALIDATE_CACHE_DIRS.items():
        if pattern.match(relative_path):
            # First matching pattern wins; dict insertion order is significant.
            return local_path.stat().st_mtime <= expire_time

    return False
|
|
|