from __future__ import annotations

from functools import lru_cache
import os
from pathlib import Path, PurePosixPath, PureWindowsPath
from typing import Iterable, List, Optional, Union

# Folders searched (in order) for checkpoints when no custom list has been set.
default_checkpoints_paths = ["ckpts", "."]
# Active search folders. Initialised with a copy so that changes to the active
# list can never alter ``default_checkpoints_paths``.
_checkpoints_paths = list(default_checkpoints_paths)


def _is_probable_url(value: str) -> bool:
    """Return True when *value* looks like a URL/URI rather than a local path."""
    text = str(value)
    return "://" in text or text.startswith(("mailto:", "urn:"))


def _absolute_normalized_path(path: Union[str, os.PathLike]) -> str:
    """Expand ``~``, normalize the path and make it absolute (relative to the CWD)."""
    return os.path.abspath(os.path.normpath(os.path.expanduser(os.fspath(path))))


def _checkpoint_roots() -> list[str]:
    """Return the checkpoint folders as absolute paths, in search order, deduplicated.

    Duplicates are detected with ``os.path.normcase`` so only paths the OS itself
    considers identical are merged (case-insensitively on Windows, exactly on POSIX).
    """
    roots = []
    seen = set()
    for root in _checkpoints_paths:
        normalized = _absolute_normalized_path(root)
        key = os.path.normcase(normalized)
        if key not in seen:
            seen.add(key)
            roots.append(normalized)
    return roots


def _is_under_root(path: str, root: str) -> bool:
    """Return True when absolute *path* is *root* itself or lies inside it."""
    try:
        common = os.path.commonpath([path, root])
    except ValueError:
        # commonpath raises on paths with different drives or mixed abs/relative.
        return False
    return os.path.normcase(common) == os.path.normcase(root)


def compress_path(path: Union[str, os.PathLike]) -> str:
    """Store checkpoint-root paths as relative paths; leave URLs unchanged.

    Returns ``""`` for ``None``. Empty values, URLs and values starting with ``"="``
    are returned (stripped) as-is. Relative paths that cannot escape their base folder
    are returned normalized with ``/`` separators. Other paths are made absolute and,
    when they lie under a checkpoint root (the longest matching root wins), expressed
    relative to it; otherwise the absolute path is returned.
    """
    if path is None:
        return ""
    value = os.fspath(path).strip()
    if not value or _is_probable_url(value) or value.startswith("="):
        return value
    if not os.path.isabs(value):
        normalized_relative = os.path.normpath(value)
        if is_relative_down_path(normalized_relative):
            return normalized_relative.replace("\\", "/")
    normalized = _absolute_normalized_path(value)
    # Most specific (longest) root first, so nested roots produce the shortest result.
    for root in sorted(_checkpoint_roots(), key=len, reverse=True):
        if not _is_under_root(normalized, root):
            continue
        relative = os.path.relpath(normalized, root)
        if relative and relative != ".":
            return relative.replace("\\", "/")
    return normalized


def uncompress_path(path: Union[str, os.PathLike]) -> str:
    """Return an absolute local path for checkpoint-relative values; leave URLs unchanged.

    A safe relative value is first searched for in the checkpoint folders; if it is not
    found it is resolved against the first checkpoint root.
    """
    if path is None:
        return ""
    value = os.fspath(path).strip()
    if not value or _is_probable_url(value) or value.startswith("="):
        return value
    if os.path.isabs(value) or not is_relative_down_path(value):
        return _absolute_normalized_path(value)
    located = locate_file(value, error_if_none=False)
    if located is not None:
        return _absolute_normalized_path(located)
    roots = _checkpoint_roots()
    return _absolute_normalized_path(os.path.join(roots[0] if roots else ".", value))


def compress_paths(paths):
    """Apply :func:`compress_path` to a single path or to each item of a list/tuple."""
    if isinstance(paths, (list, tuple)):
        return [compress_path(path) for path in paths]
    return compress_path(paths)


def uncompress_paths(paths):
    """Apply :func:`uncompress_path` to a single path or to each item of a list/tuple."""
    if isinstance(paths, (list, tuple)):
        return [uncompress_path(path) for path in paths]
    return uncompress_path(paths)


@lru_cache(maxsize=4096)
def _is_relative_down_path_cached(path: str) -> bool:
    """Cached core of :func:`is_relative_down_path` for an already-stripped ``str``."""
    if len(path) == 0 or "\x00" in path:
        return False
    windows_path = PureWindowsPath(path)
    posix_path = PurePosixPath(path)
    # Reject anything anchored under either convention (``/x``, ``C:x``, ``\\srv\x``).
    if windows_path.drive or windows_path.root or posix_path.root:
        return False
    if ".." in windows_path.parts or ".." in posix_path.parts:
        return False
    # Require at least one real component (``""`` / ``"."`` / ``"./"`` are rejected).
    return any(part not in ("", ".") for part in path.replace("\\", "/").split("/"))


def is_relative_down_path(path: Union[str, os.PathLike]) -> bool:
    """Return True for relative paths that cannot escape a base folder."""
    try:
        path = os.fspath(path).strip()
    except TypeError:
        return False
    if not isinstance(path, str):
        return False
    return _is_relative_down_path_cached(path)


def clean_relative_path(path, trigger_error=True):
    """Return *path* unchanged if it is empty/None or a safe relative path.

    Otherwise raise an ``Exception``, or return ``""`` when ``trigger_error`` is False.
    """
    if path == "" or path is None:
        return path
    if is_relative_down_path(path):
        return path
    if not trigger_error:
        return ""
    raise Exception(f"Unsafe relative path found : '{path}'")


def set_checkpoints_paths(checkpoints_paths):
    """Replace the checkpoint search folders.

    Entries are stripped and blank entries are dropped. If nothing usable remains,
    the defaults are restored, so the list always has at least one folder.
    """
    global _checkpoints_paths
    cleaned = [path.strip() for path in checkpoints_paths if len(path.strip()) > 0]
    # Test the *filtered* list: whitespace-only input used to leave no roots at all,
    # which made every later ``_checkpoints_paths[0]`` raise IndexError.
    _checkpoints_paths = cleaned if cleaned else list(default_checkpoints_paths)


def _normalize_force_path(force_path) -> Optional[str]:
    """Turn a forced sub-folder argument into a normalized path, or None if unset.

    A list is reduced to its first element; an empty list, ``None``, blank strings
    and ``"."`` all mean "no forced sub-folder".
    """
    if isinstance(force_path, list):
        force_path = force_path[0] if force_path else None
    if force_path is None:
        return None
    force_path = os.fspath(force_path).strip()
    if len(force_path) == 0:
        return None
    normalized = os.path.normpath(force_path)
    return None if normalized in ("", ".") else normalized


def extract_alternate_path(url, lora_dir=None):
    """Map a model reference to the local file name it should be stored under.

    Non-``http`` values are local names and are returned unchanged (they may not
    contain ``|``). For ``http`` values the result is the URL's base name, optionally
    prefixed by a safe relative folder given after a single ``|``. The special folder
    ``%lora_dir`` is replaced by ``os.path.abspath(lora_dir)``.

    Raises:
        Exception: on a local value containing ``|``, more than one ``|``, an unsafe
            alternate folder, or ``%lora_dir`` without ``lora_dir``.
    """
    if not url.startswith("http"):
        if "|" in url:
            # Previously ``raise f"..."``, which raised TypeError and lost this message.
            raise Exception(f"local path {url} can't contain a '|'")
        return url
    path_parts = url.split("|")
    new_url = os.path.basename(path_parts[0])
    if len(path_parts) == 1:
        return new_url
    if len(path_parts) != 2:
        raise Exception(f"Invalid path {url}")
    alternate_path = clean_relative_path(path_parts[1])
    if alternate_path == "%lora_dir":
        if lora_dir is None:
            raise Exception(f"Unable to compute %lora_dir in {url}, no lora_dir was provided")
        alternate_path = os.path.abspath(lora_dir)
    return os.path.join(alternate_path, new_url)


def get_download_location(file_name=None, force_path=None, lora_dir=None):
    """Return where *file_name* should be downloaded, under the first checkpoint folder.

    Absolute names (after :func:`extract_alternate_path`) are returned as-is. A list
    ``force_path`` uses its first element, and an empty list counts as no sub-folder.
    With no ``file_name``, the target folder itself is returned.
    """
    if file_name is not None:
        file_name = extract_alternate_path(file_name, lora_dir)
        if os.path.isabs(file_name):
            return file_name
    if isinstance(force_path, list):
        # An empty list used to reach os.path.join and raise TypeError.
        force_path = force_path[0] if force_path else None
    parts = [_checkpoints_paths[0]]
    if force_path is not None:
        parts.append(force_path)
    if file_name is not None:
        parts.append(file_name)
    return os.path.join(*parts)


def get_smart_download_root(force_path=None):
    """Return the checkpoint folder that already contains *force_path* as a directory.

    Falls back to the first checkpoint folder. An absolute *force_path* is returned
    as-is.
    """
    force_path = _normalize_force_path(force_path)
    if force_path is None:
        return _checkpoints_paths[0]
    if os.path.isabs(force_path):
        return force_path
    for folder in _checkpoints_paths:
        candidate = os.path.join(folder, force_path)
        if os.path.isdir(candidate):
            return folder
    return _checkpoints_paths[0]


def get_smart_download_location(file_name=None, force_path=None, lora_dir=None):
    """Like :func:`get_download_location`, but reuse an existing *force_path* folder.

    ``lora_dir`` (optional, new) is forwarded to :func:`extract_alternate_path` so
    ``%lora_dir`` references can be resolved here too.
    """
    if file_name is not None:
        file_name = extract_alternate_path(file_name, lora_dir)
        if os.path.isabs(file_name):
            return file_name
    force_path = _normalize_force_path(force_path)
    if force_path is None:
        return get_download_location(file_name)
    if os.path.isabs(force_path):
        return force_path if file_name is None else os.path.join(force_path, file_name)
    root = get_smart_download_root(force_path)
    base_path = os.path.join(root, force_path)
    return base_path if file_name is None else os.path.join(base_path, file_name)


def locate_folder(folder_name, error_if_none=True):
    """Return the first existing directory matching *folder_name*, else raise or None.

    Relative names are tried under each checkpoint folder in order.
    """
    searched_locations = []
    if os.path.isabs(folder_name):
        if os.path.isdir(folder_name):
            return folder_name
        searched_locations.append(folder_name)
    else:
        for folder in _checkpoints_paths:
            path = os.path.join(folder, folder_name)
            if os.path.isdir(path):
                return path
            searched_locations.append(os.path.abspath(path))
    if error_if_none:
        raise Exception(f"Unable to locate folder '{folder_name}', tried {searched_locations}")
    return None


def locate_file(file_name, create_path_if_none=False, error_if_none=True, extra_paths=None):
    """Return the first existing file matching *file_name*.

    Values starting with ``http`` are reduced to their base name. Relative names are
    tried under each checkpoint folder, then under ``extra_paths`` (a list). If nothing
    is found, return the download location when ``create_path_if_none`` is set;
    otherwise raise (``error_if_none``) or return None.
    """
    if file_name.startswith("http"):
        file_name = os.path.basename(file_name)
    searched_locations = []
    if os.path.isabs(file_name):
        if os.path.isfile(file_name):
            return file_name
        searched_locations.append(file_name)
    else:
        for folder in _checkpoints_paths + ([] if extra_paths is None else extra_paths):
            path = os.path.join(folder, file_name)
            if os.path.isfile(path):
                return path
            searched_locations.append(os.path.abspath(path))
    if create_path_if_none:
        return get_download_location(file_name)
    if error_if_none:
        raise Exception(f"Unable to locate file '{file_name}', tried {searched_locations}")
    return None


def get_local_model_filename(model_filename, use_locator=True, extra_paths=None, lora_dir=None):
    """Resolve a model reference to a local file path.

    Without ``use_locator``, return the name computed by
    :func:`extract_alternate_path`. Otherwise search ``<extra_path>/<name>`` for each
    extra path first (relative names only), then the checkpoint folders. Return None
    when no file is found.
    """
    local_model_filename = extract_alternate_path(model_filename, lora_dir)
    if use_locator:
        if extra_paths is not None and not os.path.isabs(local_model_filename):
            if not isinstance(extra_paths, list):
                extra_paths = [extra_paths]
            for path in extra_paths:
                filename = locate_file(os.path.join(path, local_model_filename), error_if_none=False)
                if filename is not None:
                    return filename
        local_model_filename = locate_file(local_model_filename, error_if_none=False)
    return local_model_filename