ColabWan / shared /utils /files_locator.py
1ripon1's picture
Upload folder using huggingface_hub
7344bef verified
Raw
History Blame Contribute Delete
9.3 kB
from __future__ import annotations
from functools import lru_cache
import os
from pathlib import Path, PurePosixPath, PureWindowsPath
from typing import Iterable, List, Optional, Union
# Search roots used when no custom list has been configured.
default_checkpoints_paths = ["ckpts", "."]
# Active search roots. This starts out as the very same list object as the
# defaults (not a copy); the first entry serves as the base folder whenever a
# download location is computed from `_checkpoints_paths[0]`.
_checkpoints_paths = default_checkpoints_paths
def _is_probable_url(value: str) -> bool:
    """Heuristically tell whether *value* looks like a URL/URI rather than a path.

    A value qualifies when its string form contains ``"://"`` or begins with
    the ``mailto:`` or ``urn:`` scheme prefix.
    """
    text = str(value)
    if "://" in text:
        return True
    return text.startswith(("mailto:", "urn:"))
def _absolute_normalized_path(path: Union[str, os.PathLike]) -> str:
    """Return *path* as an absolute, normalized filesystem path.

    The steps are applied in this order: convert to a plain path value with
    ``os.fspath``, expand a leading ``~``, collapse redundant separators and
    ``.``/``..`` segments, then anchor the result to the current directory.
    """
    raw = os.fspath(path)
    expanded = os.path.expanduser(raw)
    collapsed = os.path.normpath(expanded)
    return os.path.abspath(collapsed)
def _checkpoint_roots() -> list[str]:
    """Return the configured checkpoint roots as unique absolute paths.

    Each entry of ``_checkpoints_paths`` is made absolute and normalized.
    Duplicates are detected case-insensitively (via ``casefold``); the first
    occurrence wins and the configured order is preserved.
    """
    # dict keeps insertion order, so setdefault gives "first one wins" dedup.
    unique_by_key: dict[str, str] = {}
    for entry in _checkpoints_paths:
        absolute = _absolute_normalized_path(entry)
        unique_by_key.setdefault(absolute.casefold(), absolute)
    return list(unique_by_key.values())
def _is_under_root(path: str, root: str) -> bool:
    """Return True when *path* is *root* itself or lies somewhere below it.

    The comparison is case-insensitive. Inputs that ``os.path.commonpath``
    cannot compare (it raises ``ValueError``, e.g. when absolute and relative
    paths are mixed) are reported as not being under the root.
    """
    try:
        shared_prefix = os.path.commonpath([path, root])
    except ValueError:
        return False
    return shared_prefix.casefold() == root.casefold()
def compress_path(path: Union[str, os.PathLike]) -> str:
    """Shorten a local path to a checkpoint-relative form where possible.

    Returns:
        * ``""`` when *path* is ``None``;
        * the stripped value unchanged when it is empty, looks like a URL, or
          starts with ``"="``;
        * a normalized relative path with forward slashes when the value is
          already a relative path that does not escape its base folder;
        * a forward-slash path relative to the longest checkpoint root that
          contains it, when there is one;
        * otherwise the absolute, normalized path.
    """
    if path is None:
        return ""
    text = os.fspath(path).strip()
    passthrough = not text or _is_probable_url(text) or text.startswith("=")
    if passthrough:
        return text

    if not os.path.isabs(text):
        tidy = os.path.normpath(text)
        if is_relative_down_path(tidy):
            return tidy.replace("\\", "/")

    absolute = _absolute_normalized_path(text)
    # Longest roots first, so the most specific (deepest) root is preferred.
    longest_first = sorted(_checkpoint_roots(), key=len, reverse=True)
    for base in longest_first:
        if _is_under_root(absolute, base):
            tail = os.path.relpath(absolute, base)
            # A path equal to the root itself has no useful relative form.
            if tail not in ("", "."):
                return tail.replace("\\", "/")
    return absolute
def uncompress_path(path: Union[str, os.PathLike]) -> str:
    """Expand a stored path value into an absolute local path.

    Returns:
        * ``""`` when *path* is ``None``;
        * the stripped value unchanged when it is empty, looks like a URL, or
          starts with ``"="``;
        * the absolute, normalized form of the value when it is already
          absolute or is not a relative path confined to its base folder;
        * otherwise the absolute path of the matching file found by
          ``locate_file``, or, when no such file exists, the value joined
          onto the first checkpoint root (``"."`` if there are no roots).
    """
    if path is None:
        return ""
    text = os.fspath(path).strip()
    passthrough = not text or _is_probable_url(text) or text.startswith("=")
    if passthrough:
        return text

    if os.path.isabs(text) or not is_relative_down_path(text):
        return _absolute_normalized_path(text)

    hit = locate_file(text, error_if_none=False)
    if hit is None:
        bases = _checkpoint_roots()
        anchor = bases[0] if bases else "."
        hit = os.path.join(anchor, text)
    return _absolute_normalized_path(hit)
def compress_paths(paths):
    """Apply ``compress_path`` to one value or to each item of a list/tuple.

    A list or tuple yields a new list of compressed values; any other input
    is treated as a single path and its compressed value is returned.
    """
    if not isinstance(paths, (list, tuple)):
        return compress_path(paths)
    compressed = []
    for item in paths:
        compressed.append(compress_path(item))
    return compressed
def uncompress_paths(paths):
    """Apply ``uncompress_path`` to one value or to each item of a list/tuple.

    A list or tuple yields a new list of expanded values; any other input is
    treated as a single path and its expanded value is returned.
    """
    if not isinstance(paths, (list, tuple)):
        return uncompress_path(paths)
    expanded = []
    for item in paths:
        expanded.append(uncompress_path(item))
    return expanded
@lru_cache(maxsize=4096)
def _is_relative_down_path_cached(path: str) -> bool:
    """Memoized check that *path* is a relative path that only goes downwards.

    The string is interpreted under both Windows and POSIX rules so the
    answer does not depend on the host OS. It is rejected when it is empty,
    contains a NUL character, carries a drive or root under either flavour,
    contains a ``..`` component, or consists only of ``.``/empty segments.
    """
    if not path or "\x00" in path:
        return False

    flavours = (PureWindowsPath(path), PurePosixPath(path))
    anchored = any(flavour.drive or flavour.root for flavour in flavours)
    climbs = any(".." in flavour.parts for flavour in flavours)
    if anchored or climbs:
        return False

    # Require at least one real name: "." or "./" alone is not a usable path.
    segments = path.replace("\\", "/").split("/")
    return not all(segment in ("", ".") for segment in segments)
def is_relative_down_path(path: Union[str, os.PathLike]) -> bool:
    """Return True for relative paths that cannot escape a base folder.

    Surrounding whitespace is ignored. Values that ``os.fspath`` rejects, and
    path-like values that resolve to something other than ``str`` (such as
    ``bytes``), are reported as unsafe.
    """
    try:
        candidate = os.fspath(path).strip()
    except TypeError:
        return False
    return isinstance(candidate, str) and _is_relative_down_path_cached(candidate)
def clean_relative_path(path, trigger_error = True):
    """Validate that *path* is a safe relative path and hand it back.

    ``None`` and ``""`` are returned untouched. A path accepted by
    ``is_relative_down_path`` is returned as given. For any other path an
    ``Exception`` is raised when *trigger_error* is true; otherwise ``""``
    is returned.
    """
    if path is None or path == "":
        return path
    if is_relative_down_path(path):
        return path
    if trigger_error:
        raise Exception(f"Unsafe relative path found : '{path}'")
    return ""
def set_checkpoints_paths(checkpoints_paths):
    """Replace the active list of checkpoint search roots.

    Args:
        checkpoints_paths: Iterable of folder strings. Entries are stripped
            and blank ones are dropped. ``None`` is treated as an empty list
            and a single string as a one-element list.

    If no usable entry remains, the active list is reset to a copy of
    ``default_checkpoints_paths``. The fallback is decided on the *filtered*
    list: deciding on the raw input would let a value such as ``["  "]``
    leave the active list empty, and every later ``_checkpoints_paths[0]``
    lookup would then fail with ``IndexError``.
    """
    global _checkpoints_paths
    if checkpoints_paths is None:
        checkpoints_paths = []
    elif isinstance(checkpoints_paths, str):
        # Iterating a bare string would yield one "path" per character.
        checkpoints_paths = [checkpoints_paths]

    stripped_entries = (entry.strip() for entry in checkpoints_paths)
    cleaned = [entry for entry in stripped_entries if entry]
    # Copy the defaults so the module-level default list is never aliased.
    _checkpoints_paths = cleaned if cleaned else list(default_checkpoints_paths)
def _normalize_force_path(force_path):
    """Reduce a ``force_path`` argument to one normalized path or ``None``.

    Args:
        force_path: ``None``, a path (``str`` / ``os.PathLike``), or a list or
            tuple of paths, of which only the first element is used.

    Returns:
        The stripped path after ``os.path.normpath``, or ``None`` when nothing
        meaningful was supplied: ``None``, an empty list/tuple, a blank
        string, or a path that normalizes to ``"."``.
    """
    if isinstance(force_path, (list, tuple)):
        # An empty sequence means "no forced path"; passing it on to
        # os.fspath would raise TypeError.
        force_path = force_path[0] if force_path else None
    if force_path is None:
        return None

    text = os.fspath(force_path).strip()
    if not text:
        return None
    normalized = os.path.normpath(text)
    return None if normalized in ("", ".") else normalized
def extract_alternate_path(url, lora_dir = None):
    """Turn a model reference into the local file name/path it maps to.

    A value that does not start with ``"http"`` is treated as a local path
    and returned unchanged. A URL is reduced to its base name. A URL may
    carry one alternate target folder after a ``|`` separator
    (``"<url>|<folder>"``); the folder must be a safe relative path, or the
    placeholder ``%lora_dir``, which is replaced by the absolute form of
    *lora_dir*.

    Args:
        url: Local path or URL, optionally suffixed with ``|<folder>``.
        lora_dir: Folder substituted for the ``%lora_dir`` placeholder.

    Returns:
        The local path unchanged, the URL's base name, or that base name
        joined onto the alternate folder.

    Raises:
        ValueError: A local path contains ``|``, or a URL contains more than
            one ``|``.
        Exception: ``%lora_dir`` is used but *lora_dir* is ``None``, or the
            alternate folder is rejected by ``clean_relative_path``.
    """
    if not url.startswith("http"):
        if "|" in url:
            # Raising a bare string is itself a TypeError in Python 3 and
            # loses the message, so raise a real exception.
            raise ValueError(f"local path {url} can't contain a '|'")
        return url

    path_parts = url.split("|")
    new_url = os.path.basename(path_parts[0])
    if len(path_parts) == 1:
        return new_url
    if len(path_parts) != 2:
        raise ValueError(f"Invalid path {url}")

    alternate_path = clean_relative_path(path_parts[1])
    if alternate_path == "%lora_dir":
        if lora_dir is None:
            raise Exception(f"Unable to compute %lora_dir in {url}, no lora_dir was provided")
        alternate_path = os.path.abspath(lora_dir)
    return os.path.join(alternate_path, new_url)
def get_download_location(file_name = None, force_path= None, lora_dir = None):
    """Compute where a file (or folder) should be downloaded to.

    *file_name* is first passed through ``extract_alternate_path``; if the
    result is absolute it is returned directly. Otherwise the location is
    built under the first checkpoint root, followed by *force_path* (when
    given; a non-empty list contributes only its first element) and then the
    file name (when given). With neither, the first checkpoint root itself
    is returned.
    """
    if file_name is not None:
        file_name = extract_alternate_path(file_name, lora_dir)
        if os.path.isabs(file_name):
            return file_name

    if isinstance(force_path, list) and len(force_path):
        force_path = force_path[0]

    root = _checkpoints_paths[0]
    segments = [part for part in (force_path, file_name) if part is not None]
    if not segments:
        return root
    return os.path.join(root, *segments)
def get_smart_download_root(force_path = None):
    """Pick the checkpoint root under which *force_path* should live.

    Returns the first checkpoint root when no usable *force_path* is given,
    the normalized *force_path* itself when it is absolute, and otherwise the
    first checkpoint root that already contains *force_path* as a directory,
    falling back to the first checkpoint root when none does.
    """
    subfolder = _normalize_force_path(force_path)
    if subfolder is None:
        return _checkpoints_paths[0]
    if os.path.isabs(subfolder):
        return subfolder

    existing = (
        root
        for root in _checkpoints_paths
        if os.path.isdir(os.path.join(root, subfolder))
    )
    match = next(existing, None)
    return _checkpoints_paths[0] if match is None else match
def get_smart_download_location(file_name = None, force_path = None):
    """Compute a download location, preferring a root that already holds *force_path*.

    *file_name* goes through ``extract_alternate_path``; an absolute result is
    returned as is. Without a usable *force_path* the work is delegated to
    ``get_download_location``. An absolute *force_path* is used directly as
    the base folder; a relative one is placed under the root chosen by
    ``get_smart_download_root``. The file name, when given, is appended to
    the base folder.
    """
    if file_name is not None:
        file_name = extract_alternate_path(file_name)
        if os.path.isabs(file_name):
            return file_name

    subfolder = _normalize_force_path(force_path)
    if subfolder is None:
        return get_download_location(file_name)

    if os.path.isabs(subfolder):
        base = subfolder
    else:
        base = os.path.join(get_smart_download_root(subfolder), subfolder)

    if file_name is None:
        return base
    return os.path.join(base, file_name)
def locate_folder(folder_name, error_if_none = True):
    """Find an existing folder, searching the checkpoint roots for relative names.

    An absolute *folder_name* is checked as is. A relative one is joined onto
    each checkpoint root in order and the first existing directory wins.

    Returns the located path, or ``None`` when nothing is found and
    *error_if_none* is false. Raises ``Exception`` listing the attempted
    locations when nothing is found and *error_if_none* is true.
    """
    if os.path.isabs(folder_name):
        if os.path.isdir(folder_name):
            return folder_name
        tried = [folder_name]
    else:
        tried = []
        for root in _checkpoints_paths:
            candidate = os.path.join(root, folder_name)
            if os.path.isdir(candidate):
                return candidate
            tried.append(os.path.abspath(candidate))

    if not error_if_none:
        return None
    raise Exception(f"Unable to locate folder '{folder_name}', tried {tried}")
def locate_file(file_name, create_path_if_none = False, error_if_none = True, extra_paths = None):
    """Find an existing file, searching the checkpoint roots for relative names.

    Args:
        file_name: Absolute path, relative path, or URL. Any value starting
            with ``"http"`` is reduced to its base name before searching.
        create_path_if_none: When the file is not found, return the location
            given by ``get_download_location`` instead of failing.
        error_if_none: When the file is not found (and no path is created),
            raise instead of returning ``None``.
        extra_paths: Additional folders searched after the checkpoint roots.
            May be ``None``, a single path (``str`` / ``os.PathLike``), or any
            iterable of paths.

    Returns:
        The path of the first match, the download location when
        *create_path_if_none* is set, or ``None``.

    Raises:
        Exception: Nothing was found and *error_if_none* is true; the message
            lists every location that was tried.
    """
    if file_name.startswith("http"):
        file_name = os.path.basename(file_name)

    # Normalize extra_paths: concatenating a str or tuple onto the list of
    # checkpoint roots would raise TypeError.
    if extra_paths is None:
        extra_folders = []
    elif isinstance(extra_paths, (str, os.PathLike)):
        extra_folders = [extra_paths]
    else:
        extra_folders = list(extra_paths)

    searched_locations = []
    if os.path.isabs(file_name):
        if os.path.isfile(file_name):
            return file_name
        searched_locations.append(file_name)
    else:
        for folder in [*_checkpoints_paths, *extra_folders]:
            path = os.path.join(folder, file_name)
            if os.path.isfile(path):
                return path
            searched_locations.append(os.path.abspath(path))

    if create_path_if_none:
        return get_download_location(file_name)
    if error_if_none:
        raise Exception(f"Unable to locate file '{file_name}', tried {searched_locations}")
    return None
def get_local_model_filename(model_filename, use_locator = True, extra_paths = None, lora_dir = None):
    """Resolve a model reference to a local file path.

    The reference is first converted with ``extract_alternate_path``. With
    *use_locator* disabled that value is returned as is. Otherwise, for a
    relative name, each entry of *extra_paths* (a list, or a single value) is
    tried as a prefix through ``locate_file`` and the first hit is returned;
    failing that, the name itself is looked up with ``locate_file``.

    Returns the located path, or ``None`` when the locator is enabled and no
    file is found.
    """
    resolved = extract_alternate_path(model_filename, lora_dir)
    if not use_locator:
        return resolved

    if extra_paths is not None and not os.path.isabs(resolved):
        bases = extra_paths if isinstance(extra_paths, list) else [extra_paths]
        for base in bases:
            hit = locate_file(os.path.join(base, resolved), error_if_none=False)
            if hit is not None:
                return hit
    return locate_file(resolved, error_if_none=False)