File size: 9,300 Bytes
7344bef | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 | from __future__ import annotations
from functools import lru_cache
import os
from pathlib import Path, PurePosixPath, PureWindowsPath
from typing import Iterable, List, Optional, Union
# Checkpoint folders searched, in order, when no custom list has been configured.
# Both entries are relative, so they resolve against the current working directory.
# The active list starts out as the very same list object as the defaults;
# set_checkpoints_paths() rebinds the name rather than mutating the list.
_checkpoints_paths = default_checkpoints_paths = ["ckpts", "."]
def _is_probable_url(value: str) -> bool:
    """Heuristically decide whether *value* is a URL/URI rather than a file path.

    Anything containing a scheme separator ("://") qualifies. So do values starting
    with "mailto:" or "urn:", which are URIs without the "//" part.
    """
    text = str(value)
    if "://" in text:
        return True
    return text.startswith(("mailto:", "urn:"))
def _absolute_normalized_path(path: Union[str, os.PathLike]) -> str:
    """Turn *path* into an absolute, normalized path.

    A leading "~" is expanded first. Redundant separators and "." / ".." segments
    are then collapsed, and a relative result is anchored at the current working
    directory.
    """
    raw = os.fspath(path)
    expanded = os.path.expanduser(raw)
    collapsed = os.path.normpath(expanded)
    return os.path.abspath(collapsed)
def _checkpoint_roots() -> list[str]:
    """Return the configured checkpoint folders as absolute paths, without duplicates.

    Order follows ``_checkpoints_paths``, and the first spelling of each folder is
    kept. Duplicates are detected with ``os.path.normcase``. That call folds case only
    on case-insensitive platforms (Windows), so on POSIX two folders that differ only
    in letter case remain two distinct roots. The previous ``str.casefold()`` key
    merged them and dropped one.
    """
    roots: list[str] = []
    seen: set[str] = set()
    for root in _checkpoints_paths:
        normalized = _absolute_normalized_path(root)
        key = os.path.normcase(normalized)
        if key in seen:
            continue
        seen.add(key)
        roots.append(normalized)
    return roots
def _is_under_root(path: str, root: str) -> bool:
    """Report whether *path* is *root* itself or lies somewhere beneath it.

    The shared prefix is compared to *root* case-insensitively. If the two paths
    cannot be compared at all (``os.path.commonpath`` raises ValueError, e.g. for a
    mix of absolute and relative paths), the answer is False.
    """
    try:
        shared = os.path.commonpath([path, root])
    except ValueError:
        return False
    return shared.casefold() == root.casefold()
def compress_path(path: Union[str, os.PathLike]) -> str:
    """Store checkpoint-root paths as relative paths; leave URLs unchanged.

    - ``None`` becomes ``""``.
    - After stripping whitespace, empty values, URL-like values and values starting
      with ``"="`` are returned as they are.
    - A relative path that stays inside its base folder is normalized and returned
      with forward slashes.
    - Anything else is made absolute. If it lies strictly inside one of the
      checkpoint roots, the part relative to that root is returned with forward
      slashes; the longest such root wins. Otherwise the absolute path is returned.
    """
    if path is None:
        return ""
    text = os.fspath(path).strip()
    if not text or _is_probable_url(text) or text.startswith("="):
        return text
    if not os.path.isabs(text):
        candidate = os.path.normpath(text)
        if is_relative_down_path(candidate):
            return candidate.replace("\\", "/")
    absolute = _absolute_normalized_path(text)
    # Try longer (more specific) roots first so nested roots get the shortest result.
    longest_first = sorted(_checkpoint_roots(), key=len, reverse=True)
    for root in longest_first:
        if _is_under_root(absolute, root):
            rel = os.path.relpath(absolute, root)
            # A path equal to the root itself yields "." and is not compressed here.
            if rel and rel != ".":
                return rel.replace("\\", "/")
    return absolute
def uncompress_path(path: Union[str, os.PathLike]) -> str:
    """Return an absolute local path for checkpoint-relative values; leave URLs unchanged.

    - ``None`` becomes ``""``.
    - After stripping whitespace, empty values, URL-like values and values starting
      with ``"="`` are returned as they are.
    - Absolute paths, and relative paths that could escape their base folder, are
      simply made absolute.
    - Any other relative path is looked up with ``locate_file``. If it is not found,
      it is anchored at the first checkpoint root (or "." when there is none).
    """
    if path is None:
        return ""
    text = os.fspath(path).strip()
    if not text or _is_probable_url(text) or text.startswith("="):
        return text
    if os.path.isabs(text) or not is_relative_down_path(text):
        return _absolute_normalized_path(text)
    found = locate_file(text, error_if_none=False)
    if found is None:
        roots = _checkpoint_roots()
        found = os.path.join(roots[0] if roots else ".", text)
    return _absolute_normalized_path(found)
def compress_paths(paths):
    """Apply ``compress_path`` to one value or to every item of a list/tuple.

    A list or tuple always comes back as a list; any other value is compressed
    directly.
    """
    if not isinstance(paths, (list, tuple)):
        return compress_path(paths)
    return list(map(compress_path, paths))
def uncompress_paths(paths):
    """Apply ``uncompress_path`` to one value or to every item of a list/tuple.

    A list or tuple always comes back as a list; any other value is uncompressed
    directly.
    """
    if not isinstance(paths, (list, tuple)):
        return uncompress_path(paths)
    return list(map(uncompress_path, paths))
@lru_cache(maxsize=4096)
def _is_relative_down_path_cached(path: str) -> bool:
    """Memoized core of ``is_relative_down_path``; *path* is a str already stripped by the caller.

    Results for up to 4096 distinct inputs are cached. False is returned for:
    - empty strings and strings containing a NUL character;
    - anything anchored under Windows rules (drive, UNC share, leading separator)
      or POSIX rules (leading "/");
    - any ".." segment under either separator convention;
    - paths made only of "." / empty segments, which name the base folder itself
      rather than something inside it.
    """
    if not path or "\x00" in path:
        return False
    win = PureWindowsPath(path)
    posix = PurePosixPath(path)
    anchored = bool(win.drive or win.root or posix.root)
    if anchored or ".." in win.parts or ".." in posix.parts:
        return False
    segments = path.replace("\\", "/").split("/")
    return not all(segment in ("", ".") for segment in segments)
def is_relative_down_path(path: Union[str, os.PathLike]) -> bool:
    """Return True for relative paths that cannot escape a base folder.

    Surrounding whitespace is ignored. Values that are not path-like (for example
    ``None``) and bytes paths yield False instead of raising.
    """
    try:
        raw = os.fspath(path)
    except TypeError:
        return False
    if not isinstance(raw, str):
        return False
    return _is_relative_down_path_cached(raw.strip())
def clean_relative_path(path, trigger_error = True):
    """Validate that *path* is a safe relative path and return it unchanged.

    ``None`` and ``""`` are passed through. For an unsafe path, a generic Exception
    is raised when *trigger_error* is true; otherwise ``""`` is returned.
    """
    if path is None or path == "" or is_relative_down_path(path):
        return path
    if trigger_error:
        raise Exception(f"Unsafe relative path found : '{path}'")
    return ""
def set_checkpoints_paths(checkpoints_paths):
    """Replace the ordered list of folders searched for checkpoints.

    Entries are stripped of surrounding whitespace and blank entries are dropped.
    If nothing usable remains, the default search list is restored. A single
    str/PathLike is accepted as a one-element list; previously it was iterated
    character by character.
    """
    global _checkpoints_paths
    if isinstance(checkpoints_paths, (str, os.PathLike)):
        checkpoints_paths = [checkpoints_paths]
    cleaned = [os.fspath(path).strip() for path in checkpoints_paths]
    cleaned = [path for path in cleaned if path]
    # Check the *filtered* list. An input made only of blank entries must fall back
    # to the defaults, otherwise every later `_checkpoints_paths[0]` raises IndexError.
    _checkpoints_paths = cleaned if cleaned else default_checkpoints_paths
def _normalize_force_path(force_path):
    """Reduce a user-supplied sub-folder spec to a normalized path, or None.

    A list/tuple contributes only its first item. An empty list/tuple, ``None``, a
    blank string and a path that normalizes to "." all mean "no sub-folder" and
    yield None.
    """
    if isinstance(force_path, (list, tuple)):
        # An empty list used to fall through to os.fspath([]) and raise TypeError.
        force_path = force_path[0] if force_path else None
    if force_path is None:
        return None
    force_path = os.fspath(force_path).strip()
    if not force_path:
        return None
    normalized = os.path.normpath(force_path)
    return None if normalized in ("", ".") else normalized
def extract_alternate_path(url, lora_dir = None):
    """Map a download reference to the local file name it should be stored under.

    - A value that does not start with ``"http"`` is a local path and is returned
      unchanged; it may not contain ``"|"``.
    - ``"<url>"`` maps to the basename of the URL.
    - ``"<url>|<folder>"`` maps to ``os.path.join(<folder>, <basename>)``. ``<folder>``
      is validated by ``clean_relative_path``. The special value ``"%lora_dir"`` is
      replaced by ``os.path.abspath(lora_dir)``.

    Raises:
        Exception: for a local path containing ``"|"``, for a URL with more than one
            ``"|"``, for an unsafe ``<folder>``, or when ``"%lora_dir"`` is used
            without *lora_dir*.
    """
    if not url.startswith("http"):
        if "|" in url:
            # Raise a real exception object: `raise "<str>"` is itself a TypeError
            # ("exceptions must derive from BaseException") that hides this message.
            raise Exception(f"local path {url} can't contain a '|'")
        return url
    path_parts = url.split("|")
    new_url = os.path.basename(path_parts[0])
    if len(path_parts) == 1:
        return new_url
    if len(path_parts) != 2:
        raise Exception(f"Invalid path {url}")
    alternate_path = clean_relative_path(path_parts[1])
    if alternate_path == "%lora_dir":
        if lora_dir is None:
            raise Exception(f"Unable to compute %lora_dir in {url}, no lora_dir was provided")
        alternate_path = os.path.abspath(lora_dir)
    return os.path.join(alternate_path, new_url)
def get_download_location(file_name = None, force_path= None, lora_dir = None):
    """Return where *file_name* should be downloaded, under the first checkpoint folder.

    Args:
        file_name: download reference. It is first passed through
            ``extract_alternate_path``, and an absolute result is returned as-is.
            ``None`` means "the folder only".
        force_path: optional sub-folder placed between the checkpoint folder and the
            file name. A list/tuple contributes its first item; an empty list/tuple
            means no sub-folder.
        lora_dir: forwarded to ``extract_alternate_path`` to expand ``"%lora_dir"``.

    Returns:
        ``_checkpoints_paths[0]``, joined with the optional sub-folder and the
        optional file name.
    """
    if file_name is not None:
        file_name = extract_alternate_path(file_name, lora_dir)
        if os.path.isabs(file_name):
            return file_name
    if isinstance(force_path, (list, tuple)):
        # An empty list used to reach os.path.join() unchanged and raise TypeError.
        force_path = force_path[0] if force_path else None
    parts = [_checkpoints_paths[0]]
    if force_path is not None:
        parts.append(force_path)
    if file_name is not None:
        parts.append(file_name)
    return os.path.join(*parts)
def get_smart_download_root(force_path = None):
    """Choose the checkpoint root that a sub-folder should be placed under.

    With no usable *force_path*, the first checkpoint folder is returned. An absolute
    *force_path* is returned as-is. Otherwise the first checkpoint folder that
    already contains *force_path* as a directory wins, falling back to the first
    checkpoint folder.
    """
    sub_folder = _normalize_force_path(force_path)
    if sub_folder is None:
        return _checkpoints_paths[0]
    if os.path.isabs(sub_folder):
        return sub_folder
    existing = (
        root
        for root in _checkpoints_paths
        if os.path.isdir(os.path.join(root, sub_folder))
    )
    return next(existing, _checkpoints_paths[0])
def get_smart_download_location(file_name = None, force_path = None, lora_dir = None):
    """Return a download target, reusing a checkpoint root that already has *force_path*.

    Args:
        file_name: download reference, resolved once via ``extract_alternate_path``.
            An absolute result is returned as-is.
        force_path: optional sub-folder, normalized by ``_normalize_force_path``.
            An absolute value is used directly. A relative one is placed under the
            root chosen by ``get_smart_download_root``.
        lora_dir: forwarded to ``extract_alternate_path`` so ``"%lora_dir"`` can be
            expanded. It is optional, so existing callers are unaffected.

    Returns:
        The folder (when *file_name* is None) or the full file path.
    """
    if file_name is not None:
        file_name = extract_alternate_path(file_name, lora_dir)
        if os.path.isabs(file_name):
            return file_name
    force_path = _normalize_force_path(force_path)
    if force_path is None:
        # Build the path directly rather than calling get_download_location(), which
        # would run extract_alternate_path a second time. A resolved name such as
        # "httpdir/model.bin" (from "<url>|httpdir") would then be cut to "model.bin".
        root = _checkpoints_paths[0]
        return root if file_name is None else os.path.join(root, file_name)
    if os.path.isabs(force_path):
        return force_path if file_name is None else os.path.join(force_path, file_name)
    base_path = os.path.join(get_smart_download_root(force_path), force_path)
    return base_path if file_name is None else os.path.join(base_path, file_name)
def locate_folder(folder_name, error_if_none = True):
    """Find an existing directory, searching the checkpoint folders for relative names.

    An absolute *folder_name* is checked as-is. A relative one is joined onto each
    entry of ``_checkpoints_paths`` in order, and the first existing directory is
    returned (as joined, not made absolute). When nothing matches, a generic
    Exception listing every tried location is raised, or None is returned if
    *error_if_none* is false.
    """
    if os.path.isabs(folder_name):
        if os.path.isdir(folder_name):
            return folder_name
        tried = [folder_name]
    else:
        tried = []
        for root in _checkpoints_paths:
            candidate = os.path.join(root, folder_name)
            if os.path.isdir(candidate):
                return candidate
            tried.append(os.path.abspath(candidate))
    if not error_if_none:
        return None
    raise Exception(f"Unable to locate folder '{folder_name}', tried {tried}")
def locate_file(file_name, create_path_if_none = False, error_if_none = True, extra_paths = None):
    """Find an existing file by name across the checkpoint folders (plus *extra_paths*).

    A name starting with "http" is reduced to its basename first. An absolute name is
    checked as-is. A relative one is joined onto each entry of ``_checkpoints_paths``
    followed by *extra_paths*, and the first existing file is returned (as joined).
    When nothing matches:
    - with *create_path_if_none*, the ``get_download_location`` target is returned;
    - otherwise, with *error_if_none*, a generic Exception listing every tried
      location is raised;
    - otherwise None is returned.
    """
    if file_name.startswith("http"):
        file_name = os.path.basename(file_name)
    tried = []
    if os.path.isabs(file_name):
        if os.path.isfile(file_name):
            return file_name
        tried.append(file_name)
    else:
        search_roots = _checkpoints_paths if extra_paths is None else _checkpoints_paths + extra_paths
        for root in search_roots:
            candidate = os.path.join(root, file_name)
            if os.path.isfile(candidate):
                return candidate
            tried.append(os.path.abspath(candidate))
    if create_path_if_none:
        return get_download_location(file_name)
    if not error_if_none:
        return None
    raise Exception(f"Unable to locate file '{file_name}', tried {tried}")
def get_local_model_filename(model_filename, use_locator = True, extra_paths = None, lora_dir = None):
    """Resolve a model reference to a local file name.

    The reference is first passed through ``extract_alternate_path``. Without
    *use_locator* that result is returned directly. With it, each entry of
    *extra_paths* (a single value is treated as a one-item list) is tried as a
    prefix for relative names, and then the bare name is searched. Both searches
    use ``locate_file(..., error_if_none=False)``, so None is returned when nothing
    is found.
    """
    resolved = extract_alternate_path(model_filename, lora_dir)
    if not use_locator:
        return resolved
    if extra_paths is not None and not os.path.isabs(resolved):
        prefixes = extra_paths if isinstance(extra_paths, list) else [extra_paths]
        for prefix in prefixes:
            hit = locate_file(os.path.join(prefix, resolved), error_if_none= False)
            if hit is not None:
                return hit
    return locate_file(resolved, error_if_none= False)
|