File size: 11,360 Bytes
23b413b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
"""
LoRA Loader — Scans the models/loras directory for installed LoRA files (.safetensors, .pt, .ckpt).

Additive module (Golden Rule 1.0).
Does NOT modify any existing model loading or generation code.
"""

from __future__ import annotations

import json
import struct
from pathlib import Path
from typing import Dict, List, Optional

from ..providers import get_comfy_models_path


# Which checkpoint architectures each LoRA base family may be applied to.
# This table mirrors the COMPAT table in comfy.py; keep the two in sync.
LORA_COMPAT: Dict[str, set] = {
    "sd1.5": {"sd15"},
    "sdxl": {"noobai_xl", "noobai_xl_vpred", "pony_xl", "sdxl"},
    "pony": {"noobai_xl", "pony_xl", "sdxl"},
    "flux": {"flux_dev", "flux_schnell"},
}

# Human-readable names for checkpoint architecture identifiers (UI display).
ARCH_LABELS: Dict[str, str] = dict(
    sd15="SD1.5",
    sdxl="SDXL",
    pony_xl="Pony XL",
    noobai_xl="NoobAI XL",
    noobai_xl_vpred="NoobAI XL V-Pred",
    flux_schnell="Flux Schnell",
    flux_dev="Flux Dev",
)

# Human-readable names for LoRA base families (keys match LORA_COMPAT).
LORA_BASE_LABELS: Dict[str, str] = {
    "sd1.5": "SD1.5",
    "sdxl": "SDXL",
    "pony": "Pony",
    "flux": "Flux",
}

# Smallest file size (in bytes) accepted as a real LoRA: 100 KiB.
# Files below this are treated as corrupt or empty downloads.
MIN_LORA_SIZE = 100 * 1024


def get_lora_dir() -> Path:
    """Return the directory where LoRA files are installed.

    This is the ``loras`` subfolder of the path reported by
    ``get_comfy_models_path()``. The directory is neither created nor
    checked for existence here.
    """
    models_root = get_comfy_models_path()
    return models_root / "loras"


def _get_registry_lookup() -> Dict[str, "LoRAEntry"]:  # noqa: F821
    """Build a mapping from registry key to ``LoRAEntry``.

    The registry module is imported on demand, inside this call. Every
    entry from the SFW and NSFW lists is indexed under its ``id``. Its
    filename stem (the filename minus its last extension) is added as a
    secondary key, so files that were renamed by hand still resolve to
    their registry metadata. An ``id`` always wins over a colliding stem.

    Returns an empty mapping if the registry cannot be loaded for any
    reason (best-effort: callers then simply see no metadata).
    """
    try:
        from .lora_registry import SFW_LORAS, NSFW_LORAS

        by_key: Dict[str, "LoRAEntry"] = {}
        for entry in (*SFW_LORAS, *NSFW_LORAS):
            # Primary key: the registry id (overwrites any stem alias).
            by_key[entry.id] = entry

            # Secondary key: filename stem, only if nothing claimed it yet.
            name = entry.filename
            head, dot, _ext = name.rpartition(".")
            stem = head if dot else name
            if stem != entry.id:
                by_key.setdefault(stem, entry)
        return by_key
    except Exception:
        return {}


def _fmt_bytes(n: int) -> str:
    """Render a byte count as a short human-readable string.

    Uses binary (1024-based) steps:
    - values under 1 KB are shown as a plain integer number of bytes;
    - KB and MB use one decimal place;
    - anything of 1 GB or more is shown in GB with two decimal places
      (there is no larger unit).

    >>> _fmt_bytes(512)
    '512 B'
    >>> _fmt_bytes(1536)
    '1.5 KB'
    """
    step = 1024
    if n < step:
        return f"{n} B"
    for exponent, unit in ((1, "KB"), (2, "MB")):
        if n < step ** (exponent + 1):
            return f"{n / step ** exponent:.1f} {unit}"
    return f"{n / step ** 3:.2f} GB"


def detect_lora_architecture(path: Path) -> Optional[str]:
    """Detect LoRA architecture by inspecting safetensors tensor key names.

    Reads only the JSON header (no tensor data loaded). Detection order:

    1. Any tensor key containing ``lora_transformer_`` or
       ``transformer.single`` → ``"flux"``.
    2. Cross-attention tensors (key contains ``attn2`` and ``to_k``) with a
       2-D ``shape`` are checked in header order. The first one whose last
       dimension is 2048 gives ``"sdxl"``; 768 gives ``"sd1.5"``.
    3. Any tensor key containing ``transformer_blocks_1`` → ``"sdxl"``.

    The following all yield ``None`` instead of raising:
    - non-``.safetensors`` files;
    - unreadable or truncated files;
    - malformed headers (non-object JSON, non-dict tensor entries).

    Args:
        path: LoRA file to inspect.

    Returns:
        "sd1.5", "sdxl", "flux", or None if unknown.
    """
    if path.suffix.lower() != ".safetensors":
        return None

    try:
        with open(path, "rb") as f:
            header_len_bytes = f.read(8)
            if len(header_len_bytes) < 8:
                return None
            header_len = struct.unpack("<Q", header_len_bytes)[0]
            # Refuse absurd header lengths so a corrupt value can't trigger a huge read.
            if header_len > 100 * 1024 * 1024:
                return None
            header_bytes = f.read(header_len)
            if len(header_bytes) < header_len:
                return None
            header = json.loads(header_bytes)
    except Exception:
        return None

    # A valid header is a JSON object. Anything else (list, string, number)
    # is malformed and previously crashed the .keys()/.items() calls below.
    if not isinstance(header, dict):
        return None

    keys = [k for k in header if k != "__metadata__"]
    if not keys:
        return None

    # Flux LoRAs have distinctive "lora_transformer_" or "transformer.single" keys.
    if any("lora_transformer_" in k or "transformer.single" in k for k in keys):
        return "flux"

    # Cross-attention dimension check: inspect a cross-attn tensor shape.
    # SD1.5 uses 768-dim context (CLIP ViT-L), SDXL uses 2048-dim.
    for k in keys:
        meta = header[k]
        # Skip malformed entries instead of crashing on `"shape" in meta`.
        if not isinstance(meta, dict):
            continue
        if "attn2" in k and "to_k" in k:
            shape = meta.get("shape")
            if isinstance(shape, list) and len(shape) == 2:
                context_dim = shape[-1]
                if context_dim == 2048:
                    return "sdxl"
                if context_dim == 768:
                    return "sd1.5"

    # Fallback: SDXL LoRAs reference transformer_blocks_1+ (multiple blocks
    # per spatial layer); SD1.5 only has transformer_blocks_0.
    if any("transformer_blocks_1" in k for k in keys):
        return "sdxl"

    return None


def validate_safetensors_file(path: Path) -> Dict:
    """Validate a LoRA file, checking the safetensors header when applicable.

    Checks applied:
    - Every file must be at least ``MIN_LORA_SIZE`` bytes.
    - ``.safetensors`` files are additionally validated with the
      ``safetensors`` library when it can be imported. Otherwise the header
      is parsed by hand.
    - Other extensions (``.pt``/``.ckpt``) only get the size check, since
      they have no standard header format.

    Industry standard: safetensors files start with an 8-byte little-endian
    uint64 header length, followed by a JSON header of that length, followed
    by tensor data covering the rest of the file.

    Args:
        path: File to validate.

    Returns:
        dict with:
          - healthy: True if file is valid, False if corrupt
          - error: error message if corrupt, "" if healthy
          - file_size: size in bytes (0 if the file could not be stat'ed)
          - file_size_human: human-readable size ("" if it could not be stat'ed)
    """
    result: Dict = {
        "healthy": False,
        "error": "",
        "file_size": 0,
        "file_size_human": "",
    }

    try:
        file_size = path.stat().st_size
    except OSError as e:
        result["error"] = f"Cannot stat file: {e}"
        return result
    result["file_size"] = file_size
    result["file_size_human"] = _fmt_bytes(file_size)

    # Check 1: minimum size
    if file_size < MIN_LORA_SIZE:
        result["error"] = f"File too small ({result['file_size_human']}) — likely incomplete download"
        return result

    # For .pt/.ckpt files, basic size check only (no standard header format)
    if path.suffix.lower() != ".safetensors":
        result["healthy"] = True
        return result

    # Check 2: let the safetensors library validate the file (same check
    # ComfyUI uses). Open it as a context manager so the handle is closed
    # deterministically. A bare ``__enter__()`` without ``__exit__`` leaves
    # cleanup to garbage collection.
    try:
        import safetensors

        with safetensors.safe_open(str(path), framework="pt", device="cpu"):
            pass
    except ImportError:
        pass  # Library unavailable: fall back to the manual header check below
    except Exception as e:
        result["error"] = f"safetensors validation failed: {e}"
        return result
    else:
        # safetensors accepted the file
        result["healthy"] = True
        return result

    # Check 2 (fallback): parse the header by hand.
    try:
        with open(path, "rb") as f:
            # Read 8-byte header length
            header_len_bytes = f.read(8)
            if len(header_len_bytes) < 8:
                result["error"] = "File truncated — cannot read header length"
                return result

            header_len = struct.unpack("<Q", header_len_bytes)[0]

            # Sanity: header should be < 100MB and fit inside the file
            if header_len > 100 * 1024 * 1024 or header_len > file_size - 8:
                result["error"] = "Invalid header length — file is corrupt or truncated"
                return result

            header_bytes = f.read(header_len)
            if len(header_bytes) < header_len:
                result["error"] = "Incomplete header — file truncated during download"
                return result

        # Validate it's valid JSON, and a JSON object as the format requires
        header = json.loads(header_bytes)
        if not isinstance(header, dict):
            result["error"] = "Corrupt header — metadata is not a JSON object"
            return result

        # Check 3: verify tensor data fully covers the file.
        # Each tensor entry has "data_offsets": [begin, end] relative to the
        # end of the header, so 8 + header_len + max(end) must equal file_size.
        data_start = 8 + header_len
        max_end = 0
        has_tensors = False
        for key, meta in header.items():
            if key == "__metadata__":
                continue
            offsets = meta.get("data_offsets")
            if isinstance(offsets, (list, tuple)) and len(offsets) == 2:
                has_tensors = True
                max_end = max(max_end, offsets[1])

        if has_tensors:
            expected_size = data_start + max_end
            if file_size != expected_size:
                result["error"] = (
                    f"File size mismatch — expected {_fmt_bytes(expected_size)} "
                    f"but got {_fmt_bytes(file_size)}. "
                    f"Incomplete or corrupt download"
                )
                return result
        elif file_size < data_start:
            result["error"] = "File size mismatch — incomplete download"
            return result

    except json.JSONDecodeError:
        result["error"] = "Corrupt header — invalid JSON metadata"
        return result
    except Exception as e:
        result["error"] = f"Header read error: {e}"
        return result

    result["healthy"] = True
    return result


def is_lora_compatible(lora_base: str, checkpoint_arch: str) -> Optional[bool]:
    """Tell whether a LoRA of base *lora_base* fits a checkpoint architecture.

    Compatibility is looked up in ``LORA_COMPAT``.

    Args:
        lora_base: LoRA base family (e.g. ``"sdxl"``); may be empty.
        checkpoint_arch: Checkpoint architecture id (e.g. ``"pony_xl"``);
            may be empty.

    Returns:
        True  — compatible
        False — incompatible
        None  — unknown: either argument is falsy, or the base has no
                (or an empty) entry in ``LORA_COMPAT``
    """
    if lora_base and checkpoint_arch:
        compatible_archs = LORA_COMPAT.get(lora_base)
        if compatible_archs:
            return checkpoint_arch in compatible_archs
    return None


def scan_installed_loras() -> List[Dict]:
    """Scan the loras directory and return a list of installed LoRA files.

    Only regular files directly inside the directory with a ``.safetensors``,
    ``.pt`` or ``.ckpt`` extension (case-insensitive) are reported, in sorted
    path order. Subdirectories are skipped. Returns ``[]`` when the loras
    directory is missing or is not a directory.

    Each entry contains:
      - id: filename without extension
      - filename: full filename
      - path: absolute path
      - enabled: False (default state)
      - weight: 0.8 (default weight)
      - base: LoRA base architecture (e.g. "sd1.5", "sdxl") or "" if unknown
      - base_label: human-readable base label (e.g. "SD1.5") or "" if unknown
      - healthy: True if file is valid, False if corrupt
      - health_error: error description if corrupt, "" if healthy
      - file_size: size in bytes
      - file_size_human: human-readable size string
      - gated: True if LoRA is NSFW/gated (only show when spicy mode enabled)
    """
    lora_dir = get_lora_dir()
    # is_dir() rather than exists(): a non-directory at this path would
    # otherwise make iterdir() raise NotADirectoryError.
    if not lora_dir.is_dir():
        return []

    registry = _get_registry_lookup()
    extensions = (".safetensors", ".pt", ".ckpt")

    results: List[Dict] = []
    for f in sorted(lora_dir.iterdir()):
        # Cheap suffix test first. Then skip directories that merely carry a
        # LoRA-like name (e.g. "foo.safetensors/").
        if f.suffix.lower() not in extensions or not f.is_file():
            continue

        lora_id = f.stem
        entry = registry.get(lora_id)
        base = entry.base if entry else ""
        gated = entry.gated if entry else False
        health = validate_safetensors_file(f)
        results.append({
            "id": lora_id,
            "filename": f.name,
            "path": str(f),
            "enabled": False,
            "weight": 0.8,
            "base": base,
            "base_label": LORA_BASE_LABELS.get(base, ""),
            "healthy": health["healthy"],
            "health_error": health["error"],
            "file_size": health["file_size"],
            "file_size_human": health["file_size_human"],
            "gated": gated,
        })

    return results