File size: 9,300 Bytes
7344bef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
from __future__ import annotations
from functools import lru_cache
import os
from pathlib import Path, PurePosixPath, PureWindowsPath
from typing import Iterable, List, Optional, Union

# Checkpoint folders searched, in order, when a relative model path is resolved.
default_checkpoints_paths = ["ckpts", "."]

# Active search list (replaced by set_checkpoints_paths). It starts as a copy so
# that in-place changes to the active list can never alter the defaults.
_checkpoints_paths = list(default_checkpoints_paths)


def _is_probable_url(value: str) -> bool:
    """Heuristically decide whether *value* is a URL/URI rather than a file path.

    Anything containing a scheme separator ("://") qualifies, as do values
    beginning with the scheme-only forms "mailto:" and "urn:".
    """
    text = str(value)
    if "://" in text:
        return True
    return text.startswith(("mailto:", "urn:"))


def _absolute_normalized_path(path: Union[str, os.PathLike]) -> str:
    """Return *path* as an absolute, normalized string.

    A leading "~" is expanded to the user's home directory, redundant
    separators and "." / ".." segments are collapsed, and relative results
    are anchored at the current working directory.
    """
    raw = os.fspath(path)
    expanded = os.path.expanduser(raw)
    collapsed = os.path.normpath(expanded)
    return os.path.abspath(collapsed)


def _checkpoint_roots() -> list[str]:
    """Return the configured checkpoint folders as absolute paths, de-duplicated.

    Order is preserved and the first occurrence of each folder wins.
    Duplicates are detected with ``os.path.normcase``. On Windows, spellings
    differing only in case or separator collapse into one entry. On
    case-sensitive (POSIX) filesystems, folders that differ in case stay
    distinct.
    """
    roots: list[str] = []
    seen: set[str] = set()
    for root in _checkpoints_paths:
        normalized = _absolute_normalized_path(root)
        # normcase is a no-op on POSIX and case-folds on Windows; casefold()
        # would wrongly merge e.g. "/data/Ckpts" and "/data/ckpts" on Linux.
        key = os.path.normcase(normalized)
        if key not in seen:
            seen.add(key)
            roots.append(normalized)
    return roots


def _is_under_root(path: str, root: str) -> bool:
    """Tell whether *path* is *root* itself or lies somewhere beneath it.

    The final comparison ignores case. When ``os.path.commonpath`` cannot
    relate the two paths (it raises ``ValueError``, e.g. for a mix of
    absolute and relative paths), the answer is "not under".
    """
    try:
        shared = os.path.commonpath([path, root])
    except ValueError:
        return False
    return shared.casefold() == root.casefold()


def compress_path(path: Optional[Union[str, os.PathLike]]) -> str:
    """Store checkpoint-root paths as relative paths; leave URLs unchanged.

    Args:
        path: A local path, a URL, or ``None``.

    Returns:
        - ``""`` for ``None``.
        - Empty strings, probable URLs and values starting with ``"="``:
          returned stripped but otherwise untouched.
        - Relative paths that stay inside their base folder: normalized and
          kept relative, with forward slashes.
        - Any other path: made absolute, then rewritten relative to the
          longest configured checkpoint root that contains it (forward
          slashes). If no root gives a non-trivial relative path, the
          absolute path is returned.
    """
    if path is None:
        return ""
    value = os.fspath(path).strip()
    if not value or _is_probable_url(value) or value.startswith("="):
        return value
    # "~/..." is not absolute to os.path.isabs, yet it names a location in the
    # home directory. Expand it before classifying the value; otherwise it would
    # be stored as a literal "~" folder relative to a checkpoint root.
    value = os.path.expanduser(value)
    if not os.path.isabs(value):
        normalized_relative = os.path.normpath(value)
        if is_relative_down_path(normalized_relative):
            return normalized_relative.replace("\\", "/")
    normalized = _absolute_normalized_path(value)
    # Longest roots first, so nested roots produce the shortest relative path.
    for root in sorted(_checkpoint_roots(), key=len, reverse=True):
        if not _is_under_root(normalized, root):
            continue
        relative = os.path.relpath(normalized, root)
        if relative and relative != ".":
            return relative.replace("\\", "/")
    return normalized


def uncompress_path(path: Optional[Union[str, os.PathLike]]) -> str:
    """Return an absolute local path for checkpoint-relative values; leave URLs unchanged.

    Args:
        path: A stored path (possibly checkpoint-relative), a URL, or ``None``.

    Returns:
        - ``""`` for ``None``.
        - Empty strings, probable URLs and values starting with ``"="``:
          returned stripped but otherwise untouched.
        - Absolute paths, ``~`` paths and relative paths that climb out of
          their base folder: normalized to an absolute path.
        - Other relative paths: the first existing match found by
          ``locate_file``; failing that, the path joined under the first
          checkpoint root.
    """
    if path is None:
        return ""
    value = os.fspath(path).strip()
    if not value or _is_probable_url(value) or value.startswith("="):
        return value
    # A leading "~" means the home directory, not a checkpoint sub-folder.
    # Expand it before deciding whether the value is checkpoint-relative.
    value = os.path.expanduser(value)
    if os.path.isabs(value) or not is_relative_down_path(value):
        return _absolute_normalized_path(value)
    located = locate_file(value, error_if_none=False)
    if located is not None:
        return _absolute_normalized_path(located)
    # Not found anywhere yet: anchor it under the primary checkpoint root.
    roots = _checkpoint_roots()
    return _absolute_normalized_path(os.path.join(roots[0] if roots else ".", value))


def compress_paths(paths):
    """Apply compress_path to one value, or to every item of a list/tuple.

    A list or tuple always comes back as a new list. Any other value is
    compressed and returned directly.
    """
    if not isinstance(paths, (list, tuple)):
        return compress_path(paths)
    return list(map(compress_path, paths))


def uncompress_paths(paths):
    """Apply uncompress_path to one value, or to every item of a list/tuple.

    A list or tuple always comes back as a new list. Any other value is
    uncompressed and returned directly.
    """
    if not isinstance(paths, (list, tuple)):
        return uncompress_path(paths)
    return list(map(uncompress_path, paths))

@lru_cache(maxsize=4096)
def _is_relative_down_path_cached(path: str) -> bool:
    """Memoized core of is_relative_down_path.

    A path qualifies only if all of the following hold:
    - it is non-empty and contains no NUL character;
    - it has no drive or root under either Windows or POSIX parsing;
    - it never uses ".." as a segment (with "/" or "\\" separators);
    - it names at least one segment other than "" or ".".
    """
    if not path or "\x00" in path:
        return False
    as_windows = PureWindowsPath(path)
    as_posix = PurePosixPath(path)
    anchored = bool(as_windows.drive or as_windows.root or as_posix.root)
    escapes = ".." in as_windows.parts or ".." in as_posix.parts
    if anchored or escapes:
        return False
    segments = path.replace("\\", "/").split("/")
    return not all(segment in ("", ".") for segment in segments)

def is_relative_down_path(path: Union[str, os.PathLike]) -> bool:
    """Return True for relative paths that cannot escape a base folder.

    Surrounding whitespace is ignored. Values that are not path-like, or
    whose filesystem form is bytes rather than str, are rejected.
    """
    try:
        candidate = os.fspath(path)
    except TypeError:
        return False
    candidate = candidate.strip()
    if isinstance(candidate, str):
        return _is_relative_down_path_cached(candidate)
    return False

def clean_relative_path(path, trigger_error=True):
    """Validate a relative path that must stay inside its base folder.

    ``None`` and ``""`` pass through unchanged, as does any path accepted by
    is_relative_down_path. For any other path, an Exception is raised; if
    *trigger_error* is false, ``""`` is returned instead.
    """
    if path is None or path == "":
        return path
    if not is_relative_down_path(path):
        if trigger_error:
            raise Exception(f"Unsafe relative path found : '{path}'")
        return ""
    return path

def set_checkpoints_paths(checkpoints_paths):
    """Replace the list of folders searched for checkpoints.

    Entries are stripped and blank ones are dropped. If nothing usable
    remains (including an empty or ``None`` argument), a fresh copy of the
    defaults is installed. This keeps ``_checkpoints_paths[0]`` valid.

    Args:
        checkpoints_paths: An iterable of folder paths, a single path string,
            or ``None``.
    """
    global _checkpoints_paths
    if checkpoints_paths is None:
        checkpoints_paths = []
    elif isinstance(checkpoints_paths, str):
        # A bare string would otherwise be iterated character by character.
        checkpoints_paths = [checkpoints_paths]
    cleaned = [path.strip() for path in checkpoints_paths if len(path.strip()) > 0]
    # Check the *filtered* list: blank-only input used to leave the search
    # list empty, breaking every later _checkpoints_paths[0] lookup.
    _checkpoints_paths = cleaned if cleaned else list(default_checkpoints_paths)

def _normalize_force_path(force_path):
    """Reduce a ``force_path`` argument to a normalized path string or ``None``.

    For a list or tuple, only the first element is used, and an empty one
    means "no forced path". ``None``, blank strings and paths that normalize
    to the current directory (``"."``) all yield ``None``.
    """
    if isinstance(force_path, (list, tuple)):
        # An empty sequence used to fall through to os.fspath([]) and raise TypeError.
        force_path = force_path[0] if force_path else None
    if force_path is None:
        return None
    force_path = os.fspath(force_path).strip()
    if not force_path:
        return None
    normalized = os.path.normpath(force_path)
    return None if normalized in ("", ".") else normalized

def extract_alternate_path(url, lora_dir = None):
    """Turn a download spec into the local (relative or absolute) path of the file.

    Specs that do not start with ``http://`` or ``https://`` are local paths.
    They are returned unchanged and may not contain ``|``.

    For URLs, the result is the URL's base name. An alternate folder may
    follow a single ``|`` (``"<url>|<relative folder>"``), in which case the
    base name is joined under it. The special folder ``%lora_dir`` is
    replaced by the absolute form of *lora_dir*.

    Raises:
        Exception: on a ``|`` in a local path, more than one ``|`` in a URL,
            an unsafe alternate folder, or ``%lora_dir`` without *lora_dir*.
    """
    # Only real URLs are treated as such; a local folder whose name merely
    # begins with "http" (e.g. "http_cache/x.bin") must be kept intact.
    if not url.startswith(("http://", "https://")):
        if "|" in url:
            # Raising a plain str is a TypeError in Python 3; raise a real exception.
            raise Exception(f"local path {url} can't contain a '|'")
        return url

    path_parts = url.split("|")
    new_url = os.path.basename(path_parts[0])
    if len(path_parts) == 1:
        return new_url
    if len(path_parts) != 2:
        raise Exception(f"Invalid path {url}")
    # clean_relative_path raises on absolute or escaping ("..") folders.
    alternate_path = clean_relative_path(path_parts[1])
    if alternate_path == "%lora_dir":
        if lora_dir is None:
            raise Exception(f"Unable to compute %lora_dir in {url}, no lora_dir was provided")
        alternate_path = os.path.abspath(lora_dir)
    return os.path.join(alternate_path, new_url)

def get_download_location(file_name = None, force_path= None, lora_dir = None):
    """Return where a file (or folder) should be stored under the primary checkpoint root.

    Args:
        file_name: Optional download spec. It is first resolved through
            extract_alternate_path, so URLs reduce to a base name, possibly
            under an alternate folder. An absolute result is returned as-is.
        force_path: Optional sub-folder placed between the root and the file.
            For a list/tuple only the first element is used; an empty list or
            tuple means no sub-folder.
        lora_dir: Passed to extract_alternate_path to expand ``%lora_dir``.

    Returns:
        The first checkpoint path joined with the optional *force_path* and
        the optional file name.
    """
    if file_name is not None:
        file_name = extract_alternate_path(file_name, lora_dir)
        if os.path.isabs(file_name):
            return file_name
    if isinstance(force_path, (list, tuple)):
        # An empty sequence used to reach os.path.join unchanged and raise TypeError.
        force_path = force_path[0] if force_path else None
    parts = [_checkpoints_paths[0]]
    if force_path is not None:
        parts.append(force_path)
    if file_name is not None:
        parts.append(file_name)
    return os.path.join(*parts)

def get_smart_download_root(force_path = None):
    """Pick the checkpoint root under which *force_path* should live.

    - No usable *force_path*: the primary (first) checkpoint path.
    - Absolute *force_path*: returned unchanged.
    - Relative *force_path*: the first checkpoint path that already contains
      a directory of that name; otherwise the primary checkpoint path.
    """
    sub_folder = _normalize_force_path(force_path)
    if sub_folder is None:
        return _checkpoints_paths[0]
    if os.path.isabs(sub_folder):
        return sub_folder
    existing = (
        root for root in _checkpoints_paths
        if os.path.isdir(os.path.join(root, sub_folder))
    )
    return next(existing, _checkpoints_paths[0])

def get_smart_download_location(file_name = None, force_path = None, lora_dir = None):
    """Return where a file (or folder) should be stored, reusing an existing sub-folder if possible.

    Args:
        file_name: Optional download spec resolved through
            extract_alternate_path. An absolute result is returned as-is.
        force_path: Optional sub-folder (normalized by _normalize_force_path).
            An absolute one is used directly. A relative one is placed under
            the root chosen by get_smart_download_root, which prefers a
            checkpoint path where that folder already exists.
        lora_dir: Passed to extract_alternate_path to expand ``%lora_dir``,
            mirroring get_download_location (new, defaults to ``None``).

    Returns:
        The folder, or the file path inside it.
    """
    if file_name is not None:
        file_name = extract_alternate_path(file_name, lora_dir)
        if os.path.isabs(file_name):
            return file_name
    force_path = _normalize_force_path(force_path)
    if force_path is None:
        # Join directly: routing through get_download_location would run
        # extract_alternate_path a second time on the already-resolved name.
        root = _checkpoints_paths[0]
        return root if file_name is None else os.path.join(root, file_name)
    if os.path.isabs(force_path):
        return force_path if file_name is None else os.path.join(force_path, file_name)
    base_path = os.path.join(get_smart_download_root(force_path), force_path)
    return base_path if file_name is None else os.path.join(base_path, file_name)

def locate_folder(folder_name, error_if_none = True):
    """Find an existing directory, searching the checkpoint paths for relative names.

    An absolute *folder_name* is only checked as-is. A relative one is tried
    under each checkpoint path in order, and the first existing match is
    returned as joined (not made absolute).

    Raises:
        Exception: when nothing is found and *error_if_none* is true. The
            message lists every location tried. Otherwise ``None`` is returned.
    """
    if os.path.isabs(folder_name):
        if os.path.isdir(folder_name):
            return folder_name
        tried = [folder_name]
    else:
        tried = []
        for root in _checkpoints_paths:
            candidate = os.path.join(root, folder_name)
            if os.path.isdir(candidate):
                return candidate
            tried.append(os.path.abspath(candidate))
    if error_if_none:
        raise Exception(f"Unable to locate folder '{folder_name}', tried {tried}")
    return None


def locate_file(file_name, create_path_if_none = False, error_if_none = True, extra_paths = None):
    """Find an existing file, searching the checkpoint paths for relative names.

    Args:
        file_name: A path, or an ``http(s)://`` URL whose base name is searched for.
        create_path_if_none: When nothing is found, return
            get_download_location(file_name) instead of failing.
        error_if_none: When nothing is found (and no path is created), raise
            instead of returning ``None``.
        extra_paths: Additional folders searched after the checkpoint paths.
            Accepts a single path or any iterable of paths.

    Returns:
        The first matching file path (joined, not made absolute), the
        download location, or ``None``.

    Raises:
        Exception: when nothing is found and *error_if_none* is true. The
            message lists every location tried.
    """
    # Only real URLs are reduced to their base name; a local path whose first
    # folder merely begins with "http" (e.g. "http_cache/x.bin") stays intact.
    if file_name.startswith(("http://", "https://")):
        file_name = os.path.basename(file_name)
    searched_locations = []
    if os.path.isabs(file_name):
        if os.path.isfile(file_name):
            return file_name
        searched_locations.append(file_name)
    else:
        if extra_paths is None:
            extra = []
        elif isinstance(extra_paths, (str, os.PathLike)):
            # A single folder; list + str would otherwise raise TypeError.
            extra = [extra_paths]
        else:
            extra = list(extra_paths)
        for folder in [*_checkpoints_paths, *extra]:
            path = os.path.join(folder, file_name)
            if os.path.isfile(path):
                return path
            searched_locations.append(os.path.abspath(path))

    if create_path_if_none:
        return get_download_location(file_name)
    if error_if_none:
        raise Exception(f"Unable to locate file '{file_name}', tried {searched_locations}")
    return None

def get_local_model_filename(model_filename, use_locator = True, extra_paths = None, lora_dir = None):
    """Map a model spec (URL or path) to a local file path.

    The spec is first resolved with extract_alternate_path. Without
    *use_locator* the resolved name is returned unchecked. With it, a
    relative name is first looked up via locate_file as
    ``os.path.join(prefix, name)`` for each prefix in *extra_paths* (a single
    value or a list). It is then looked up on its own. ``None`` is returned
    when no file is found.
    """
    resolved = extract_alternate_path(model_filename, lora_dir)
    if not use_locator:
        return resolved
    if extra_paths is not None and not os.path.isabs(resolved):
        prefixes = extra_paths if isinstance(extra_paths, list) else [extra_paths]
        for prefix in prefixes:
            hit = locate_file(os.path.join(prefix, resolved), error_if_none=False)
            if hit is not None:
                return hit
    return locate_file(resolved, error_if_none=False)