File size: 12,807 Bytes
6a07ce1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
459ac47
 
 
 
 
6a07ce1
 
 
 
 
 
 
 
 
 
 
 
 
8cdb001
 
6a07ce1
 
 
 
 
 
 
 
 
 
8cdb001
6a07ce1
 
8cdb001
6a07ce1
 
 
 
 
 
 
 
 
 
 
 
 
8cdb001
6a07ce1
 
 
8cdb001
6a07ce1
 
8cdb001
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a07ce1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8cdb001
 
6a07ce1
8cdb001
 
 
6a07ce1
8cdb001
 
6a07ce1
 
 
 
 
 
8cdb001
6a07ce1
 
 
 
 
 
570384a
6a07ce1
 
 
 
 
 
 
 
 
 
 
60d66bd
6a07ce1
 
8cdb001
60d66bd
 
 
 
 
 
 
 
 
 
 
6a07ce1
60d66bd
8cdb001
60d66bd
 
 
 
8cdb001
60d66bd
8cdb001
60d66bd
 
 
 
8cdb001
60d66bd
8cdb001
60d66bd
 
 
 
 
 
 
 
8cdb001
6a07ce1
 
 
 
 
 
8cdb001
60d66bd
 
459ac47
 
6a07ce1
 
 
 
 
 
 
60d66bd
6a07ce1
 
 
 
 
570384a
6a07ce1
 
 
 
 
 
 
459ac47
b1e7bdb
459ac47
6a07ce1
 
60d66bd
6a07ce1
 
 
 
 
 
 
b1e7bdb
 
60d66bd
6a07ce1
 
 
 
 
60d66bd
8cdb001
60d66bd
 
 
 
b1e7bdb
60d66bd
 
 
 
 
 
6a07ce1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1e7bdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a07ce1
 
 
 
 
 
8cdb001
6a07ce1
 
 
 
 
 
 
8cdb001
6a07ce1
 
8cdb001
6a07ce1
 
 
8cdb001
6a07ce1
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
"""Download utilities for SDXL Model Merger with Gradio progress integration."""

import re
import requests
from pathlib import Path
from tqdm import tqdm as TqdmBase

from .config import download_cancelled


def extract_model_id(url: str) -> str | None:
    """Extract CivitAI model ID from URL."""
    match = re.search(r'/models/(\d+)', url)
    return match.group(1) if match else None


def is_huggingface_url(url: str) -> bool:
    """Tell whether *url* mentions the HuggingFace host.

    The check is a case-insensitive substring match on ``huggingface.co``
    anywhere in the URL; the URL is not parsed or validated.
    """
    lowered_url = url.lower()
    return "huggingface.co" in lowered_url


def get_safe_filename_from_url(
    url: str,
    default_name: str = "model.safetensors",
    suffix: str = "",
    type_prefix: str | None = None
) -> str:
    """
    Generate a safe ``.safetensors`` filename keyed by a model identifier.

    The identifier is the CivitAI model id when the URL contains
    ``/models/<digits>`` (e.g. https://civitai.com/api/download/models/12345?type=...).
    Otherwise, for URLs containing ``huggingface.co``, it is ``hf_<repo>``
    where ``<repo>`` is the second path segment after the host with every
    non-alphanumeric character replaced by ``_`` (truncated to 30 characters).

    Naming patterns:
    - Checkpoint (type_prefix='model'): 12345_model.safetensors or
      12345_model_anime_style.safetensors (the extra name comes from the
      server's Content-Disposition header, fetched with a HEAD request)
    - VAE (suffix='_vae'): 12345_vae.safetensors
    - LoRA (suffix='_lora'): 12345_lora.safetensors

    For VAE/LoRA no network request is made and ``type_prefix`` is ignored.
    One occurrence of the type word is removed from the identifier so it is
    not doubled, e.g. ``hf_sdxl_vae_fp16_fix`` + ``_vae`` ->
    ``hf_sdxl_fp16_fix_vae.safetensors``.

    Args:
        url: The download URL
        default_name: Returned unchanged when no identifier can be derived
        suffix: Optional suffix to append before .safetensors (e.g., '_vae', '_lora')
        type_prefix: Optional prefix after model_id (e.g., 'model' -> 12345_model.safetensors)

    Returns:
        The generated filename, or ``default_name``.
    """
    model_id = extract_model_id(url)

    # If no CivitAI model ID, derive an identifier from the HuggingFace repo name.
    if not model_id and "huggingface.co/" in url:
        repo_path = url.split("huggingface.co/")[1]
        path_parts = [p for p in repo_path.split("/") if p]
        if len(path_parts) >= 2:
            # path_parts[0] is the org/user, path_parts[1] the repo name.
            clean_repo = re.sub(r'[^a-zA-Z0-9]', '_', path_parts[1])[:30].strip('_')
            if clean_repo:
                model_id = f"hf_{clean_repo}"

    if not model_id:
        return default_name

    is_special_type = suffix in ("_vae", "_lora")

    if is_special_type:
        type_token = suffix.lstrip('_')  # "vae" or "lora"
        trailing = f"_{type_token}"
        interior = f"_{type_token}_"
        model_id_lower = model_id.lower()

        if model_id_lower.endswith(trailing):
            model_id = model_id[:-len(trailing)]
        else:
            idx = model_id_lower.find(interior)
            if idx >= 0:
                # Keep one underscore so the neighbouring words stay separated
                # ("sdxl_vae_fp16" -> "sdxl_fp16", not "sdxlfp16").
                # NOTE(review): files previously cached under the fused name
                # will no longer match and will be downloaded again.
                model_id = f"{model_id[:idx]}_{model_id[idx + len(interior):]}"

        # VAE/LoRA names are just identifier + suffix; Content-Disposition is
        # deliberately not consulted to avoid names like "sdxl_vae_vae".
        return f"{model_id}{suffix}.safetensors"

    # Optional human-readable name taken from the server-provided filename.
    name_part = ""
    try:
        response = requests.head(url, timeout=10, allow_redirects=True)
        cd = response.headers.get('Content-Disposition', '')
        match = re.search(r'filename="([^"]+)"', cd)
        if match:
            # Extract base name without extension
            base_name = Path(match.group(1)).stem
            # Clean up the name (remove special chars, collapse separators)
            clean_name = re.sub(r'[^\w\s-]', '', base_name)[:50]
            clean_name = re.sub(r'[-\s]+', '_', clean_name.strip('-_'))
            if clean_name:
                name_part = clean_name
    except Exception:
        # Best effort only: naming must never fail because the server is
        # unreachable or answers unexpectedly; fall back to the plain id.
        pass

    # Build filename with model_id, optional type_prefix, optional name_part, and suffix
    parts = [model_id]
    if type_prefix:
        parts.append(type_prefix)
    if name_part:
        parts.append(name_part)
    if suffix:
        parts.append(suffix.lstrip('_'))

    return '_'.join(p for p in parts if p).replace('__', '_') + '.safetensors'


class TqdmGradio(TqdmBase):
    """tqdm subclass that sends progress updates to Gradio's gr.Progress()

    Args:
        gradio_prog: Optional callable accepting a completion fraction in
            the range 0..1 (e.g. a ``gr.Progress()`` object). When None,
            the bar behaves like a plain tqdm bar with cancel support.

    All other positional/keyword arguments are forwarded to tqdm.
    """

    def __init__(self, *args, gradio_prog=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.gradio_prog = gradio_prog
        # Last percentage (0-100) that was forwarded to Gradio.
        self.last_pct = 0

    def update(self, n=1):
        """Advance the bar by *n* and mirror coarse progress to Gradio.

        Raises:
            KeyboardInterrupt: If the ``download_cancelled`` flag in the
                config module is set when the update is requested.
        """
        # Imported on every call so the current value of the flag is read,
        # not a copy captured at import time.
        from .config import download_cancelled
        if download_cancelled:
            raise KeyboardInterrupt("Download cancelled by user")

        displayed = super().update(n)

        if self.gradio_prog and self.total:
            # Clamp: the total may be an estimate, so n can overshoot it.
            pct = min(100, int(100 * self.n / self.total))
            # Report roughly every 5%. Comparing against the last reported
            # value (instead of requiring pct % 5 == 0) keeps updates flowing
            # when a single call jumps over a multiple of 5, and completion
            # is always reported.
            if pct > self.last_pct and (pct - self.last_pct >= 5 or pct == 100):
                self.last_pct = pct
                self.gradio_prog(pct / 100)

        return displayed


def get_cached_file_size(url: str, suffix: str = "", type_prefix: str | None = None) -> tuple[Path | None, int | None]:
    """
    Locate the cache entry for a URL and report its size on disk.

    The cache filename is derived with get_safe_filename_from_url(), i.e. the
    same naming used when downloading, and looked up inside CACHE_DIR. An
    entry counts as valid when it is a regular file with a non-zero size;
    the size is not compared against the remote file.

    Args:
        url: The download URL whose cached copy is wanted
        suffix: Optional suffix (e.g., '_vae', '_lora') for special file types
        type_prefix: Optional prefix after model_id (e.g., 'model')

    Returns:
        (path, size_in_bytes) for a valid entry, otherwise (None, None).
    """
    from .config import CACHE_DIR

    # Fallback filename used when no identifier can be derived from the URL.
    fallback_names = {"_vae": "vae.safetensors", "_lora": "lora.safetensors"}
    fallback = fallback_names.get(suffix, "model.safetensors")

    candidate = CACHE_DIR / get_safe_filename_from_url(
        url,
        default_name=fallback,
        suffix=suffix,
        type_prefix=type_prefix,
    )

    if not (candidate.exists() and candidate.is_file()):
        return None, None

    try:
        size_on_disk = candidate.stat().st_size
    except OSError:
        return None, None

    # An empty file is treated as "not cached".
    if size_on_disk > 0:
        return candidate, size_on_disk
    return None, None


def _content_length(headers) -> int | None:
    """Return a positive Content-Length parsed from *headers*, else None."""
    try:
        size = int(headers.get('content-length', 0))
    except (TypeError, ValueError):
        return None
    return size if size > 0 else None


def _check_safetensors_header(path: Path, display_name: str) -> None:
    """Raise if *path* lacks a complete, JSON-decodable safetensors header."""
    import json
    import struct

    with open(path, "rb") as f:
        header_size_bytes = f.read(8)
        if len(header_size_bytes) < 8:
            raise OSError(f"Safetensors file too small: {display_name}")

        # The first 8 bytes hold the header length (little-endian uint64).
        header_size = struct.unpack("<Q", header_size_bytes)[0]
        # Reject lengths larger than the file before attempting to read
        # them, so a corrupt prefix cannot trigger a huge allocation.
        if header_size > path.stat().st_size - 8:
            raise OSError(f"Incomplete safetensors header in {display_name}")

        header = f.read(header_size)
        if len(header) < header_size:
            raise OSError(f"Incomplete safetensors header in {display_name}")

        json.loads(header.decode("utf-8"))


def download_file_with_progress(url: str, output_path: Path, progress_bar=None) -> Path:
    """
    Download a file with Gradio-synced progress bar + cancel support.

    Before downloading, a HEAD request (following redirects) asks the server
    for the file size. If ``output_path`` already exists with exactly that
    size, the download is skipped. Otherwise the data is streamed into a
    sibling ``<name>.tmp`` file which is validated and then moved over
    ``output_path``; on any failure or cancellation the temp file is removed,
    so a truncated download is never left at ``output_path``.

    ``file://`` URLs are copied from the local filesystem instead.

    Args:
        url: File URL to download (http/https/file)
        output_path: Destination path for downloaded file
        progress_bar: Optional gr.Progress() object for UI updates

    Returns:
        Path to the downloaded (or cached) file

    Raises:
        KeyboardInterrupt: If download is cancelled
        requests.RequestException: If download fails
        FileNotFoundError: If a ``file://`` URL points to a missing file
        OSError: If the downloaded file fails validation
    """
    # Import the module (not the name) so the cancel flag is re-read on every
    # check; ``from .config import download_cancelled`` would freeze the
    # value seen at function entry if the flag is rebound while downloading.
    from . import config

    # Handle local file:// URLs
    if url.startswith("file://"):
        local_path = Path(url[7:])  # Remove "file://" prefix
        if not local_path.exists():
            raise FileNotFoundError(f"Local file not found: {local_path}")

        import shutil
        output_path.parent.mkdir(parents=True, exist_ok=True)

        print(f"  📁 Copying from cache: {local_path.name} → {output_path.name}")

        # Copy the file to cache location
        shutil.copy2(str(local_path), str(output_path))

        # Update progress bar for cached files
        if progress_bar:
            progress_bar(1.0)
        return output_path

    print(f"  📥 Downloading to cache: {output_path.name}")

    # Ask the server for the size. Redirects are followed so the size of the
    # final resource is reported rather than that of a redirect response.
    # None means "unknown" (request failed, error status, or no usable header).
    expected_size = None
    try:
        head = requests.head(url, timeout=10, allow_redirects=True)
        if head.ok:
            expected_size = _content_length(head.headers)
    except (requests.RequestException, ValueError):
        pass  # Size stays unknown; proceed with a normal download

    # Early cache check: skip the download when the existing file has exactly
    # the size reported by the server.
    if expected_size is not None and output_path.exists():
        try:
            cached_size = output_path.stat().st_size
        except OSError:
            cached_size = None  # File access error, proceed with download
        if cached_size == expected_size:
            print(f"  ✅ Cache hit: {output_path.name} ({cached_size / (1024**2):.1f} MB)")
            if progress_bar:
                progress_bar(1.0)
            return output_path  # Skip re-download!

    output_path.parent.mkdir(parents=True, exist_ok=True)

    tmp_path = output_path.with_name(output_path.name + ".tmp")
    block_size = 8192

    try:
        with requests.Session() as session:
            with session.get(url, stream=True, timeout=30) as response:
                response.raise_for_status()

                total_size = expected_size or _content_length(response.headers)

                # TqdmGradio syncs progress with Gradio and raises
                # KeyboardInterrupt from update() when cancelled. The bar is
                # advanced by bytes so the 'B' unit and the total agree.
                with open(tmp_path, "wb") as f, TqdmGradio(
                    total=total_size,
                    unit='B',
                    unit_scale=True,
                    desc=f"Downloading {output_path.name}",
                    gradio_prog=progress_bar,
                    disable=False,
                    bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]',
                ) as bar:
                    for data in response.iter_content(block_size):
                        if config.download_cancelled:
                            raise KeyboardInterrupt("Download cancelled by user")
                        f.write(data)
                        bar.update(len(data))

        # Verify the downloaded file is complete
        try:
            actual_size = tmp_path.stat().st_size
            # For safetensors files, check header is valid
            if output_path.suffix == ".safetensors":
                _check_safetensors_header(tmp_path, output_path.name)
        except Exception as e:
            raise OSError(f"Invalid downloaded file {output_path.name}: {str(e)}") from e

        # A size mismatch is only reported, not treated as fatal.
        if expected_size is not None and actual_size != expected_size:
            print(f"  ⚠️ Size mismatch: expected {expected_size}, got {actual_size}")

        # Publish the finished file in one step.
        tmp_path.replace(output_path)
    except BaseException:
        # Covers cancellation, network errors and validation failures alike.
        tmp_path.unlink(missing_ok=True)
        raise

    if progress_bar:
        progress_bar(1.0)
    return output_path


def clear_cache(cache_dir: Path = None, keep_extensions: list[str] = None):
    """
    Remove old cache files.

    Args:
        cache_dir: Cache directory path (defaults to config.CACHE_DIR)
        keep_extensions: File extensions to preserve (default: ['.safetensors'])
    """
    if cache_dir is None:
        from .config import CACHE_DIR
        cache_dir = CACHE_DIR

    if keep_extensions is None:
        keep_extensions = ['.safetensors']

    # Remove temp files
    for file in cache_dir.glob("*.tmp"):
        file.unlink()

    # Optional: age-based cleanup (7 days)
    # import time
    # cutoff = time.time() - 86400 * 7
    # for f in cache_dir.iterdir():
    #     if f.is_file() and f.stat().st_mtime < cutoff:
    #         f.unlink()