| """ |
| LoRA Loader — Scans the models/loras directory for installed LoRA files (.safetensors, .pt, .ckpt). |
| |
| Additive module (Golden Rule 1.0). |
| Does NOT modify any existing model loading or generation code. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import struct |
| from pathlib import Path |
| from typing import Dict, List, Optional |
|
|
| from ..providers import get_comfy_models_path |
|
|
|
|
| |
# LoRA base (same keys as LORA_BASE_LABELS) -> set of checkpoint
# architecture ids (keys of ARCH_LABELS) that the LoRA may be applied to.
# Read by ``is_lora_compatible``; a base missing from this map is "unknown".
LORA_COMPAT: Dict[str, set] = {
    "sd1.5": {"sd15"},
    "sdxl": {"sdxl", "pony_xl", "noobai_xl", "noobai_xl_vpred"},
    "pony": {"pony_xl", "sdxl", "noobai_xl"},
    "flux": {"flux_schnell", "flux_dev"},
}


# Checkpoint architecture id -> human-readable display name.
ARCH_LABELS: Dict[str, str] = {
    "sd15": "SD1.5",
    "sdxl": "SDXL",
    "pony_xl": "Pony XL",
    "noobai_xl": "NoobAI XL",
    "noobai_xl_vpred": "NoobAI XL V-Pred",
    "flux_schnell": "Flux Schnell",
    "flux_dev": "Flux Dev",
}


# LoRA base id -> human-readable display name (used for ``base_label`` in
# ``scan_installed_loras`` results; unknown bases map to "").
LORA_BASE_LABELS: Dict[str, str] = {
    "sd1.5": "SD1.5",
    "sdxl": "SDXL",
    "pony": "Pony",
    "flux": "Flux",
}


# Files smaller than this many bytes (100 KiB) are reported by
# ``validate_safetensors_file`` as a likely incomplete download.
MIN_LORA_SIZE = 102_400
|
|
|
|
def get_lora_dir() -> Path:
    """Build the path of the ``loras`` folder under the ComfyUI models path.

    The path is only constructed; this function does not check that it exists.
    """
    models_root = get_comfy_models_path()
    lora_path = models_root / "loras"
    return lora_path
|
|
|
|
| def _get_registry_lookup() -> Dict[str, "LoRAEntry"]: |
| """Lazily build id→LoRAEntry map from the registry. |
| |
| Keys include both the entry ``id`` and the filename stem, so files |
| that were manually renamed still match their registry metadata. |
| """ |
| try: |
| from .lora_registry import SFW_LORAS, NSFW_LORAS |
| lookup: Dict[str, "LoRAEntry"] = {} |
| for e in (*SFW_LORAS, *NSFW_LORAS): |
| lookup[e.id] = e |
| |
| stem = e.filename.rsplit(".", 1)[0] if "." in e.filename else e.filename |
| if stem != e.id: |
| lookup.setdefault(stem, e) |
| return lookup |
| except Exception: |
| return {} |
|
|
|
|
| def _fmt_bytes(n: int) -> str: |
| """Format byte count as human-readable string.""" |
| if n < 1024: |
| return f"{n} B" |
| elif n < 1024 * 1024: |
| return f"{n / 1024:.1f} KB" |
| elif n < 1024 * 1024 * 1024: |
| return f"{n / (1024 * 1024):.1f} MB" |
| return f"{n / (1024 * 1024 * 1024):.2f} GB" |
|
|
|
|
def detect_lora_architecture(path: Path) -> Optional[str]:
    """Guess a LoRA's base architecture from its safetensors tensor names.

    Only the JSON header is read; no tensor data is loaded. Rules are tried
    in order:

      1. Any tensor key containing ``lora_transformer_`` or
         ``transformer.single`` -> ``"flux"``.
      2. The first 2-D tensor whose key contains both ``attn2`` and ``to_k``
         and whose last dimension is 2048 -> ``"sdxl"``, or 768 ->
         ``"sd1.5"`` (other widths are skipped).
      3. Any tensor key containing ``transformer_blocks_1`` -> ``"sdxl"``.

    Args:
        path: File to inspect. Files whose suffix is not ``.safetensors``
            (case-insensitive) are not opened.

    Returns:
        ``"sd1.5"``, ``"sdxl"``, ``"flux"``, or ``None`` when the file is not
        a readable safetensors file, its header is malformed (e.g. not a
        JSON object, or tensor entries that are not objects), or no rule
        matches. This is a best-effort probe and never raises for bad files.
    """
    if path.suffix.lower() != ".safetensors":
        return None

    try:
        with open(path, "rb") as f:
            # safetensors layout: u64 little-endian header length, then JSON.
            header_len_bytes = f.read(8)
            if len(header_len_bytes) < 8:
                return None
            header_len = struct.unpack("<Q", header_len_bytes)[0]
            # Refuse absurd header lengths (> 100 MiB) instead of reading them.
            if header_len > 100 * 1024 * 1024:
                return None
            header_bytes = f.read(header_len)
        if len(header_bytes) < header_len:
            return None
        header = json.loads(header_bytes)
    except Exception:
        # Unreadable / undecodable file: architecture is simply "unknown".
        return None

    # Previously a non-object header crashed on ``header.keys()``.
    if not isinstance(header, dict):
        return None
    tensors = {key: meta for key, meta in header.items() if key != "__metadata__"}
    if not tensors:
        return None

    # Rule 1: Flux naming. Checked per key; neither pattern contains a space,
    # so this matches exactly what a substring search over the space-joined
    # key list would.
    if any("lora_transformer_" in key or "transformer.single" in key for key in tensors):
        return "flux"

    # Rule 2: width of the cross-attention key projection. The code treats
    # 2048 as SDXL and 768 as SD1.5 (presumably the text-encoder context
    # width of each family).
    for key, meta in tensors.items():
        # Non-object entries used to raise TypeError; skip them instead.
        if not isinstance(meta, dict):
            continue
        if "attn2" in key and "to_k" in key:
            shape = meta.get("shape")
            if isinstance(shape, list) and len(shape) == 2:
                context_dim = shape[-1]
                if context_dim == 2048:
                    return "sdxl"
                if context_dim == 768:
                    return "sd1.5"

    # Rule 3: fallback heuristic. NOTE(review): presumably relies on only
    # SDXL UNets having a second transformer block per attention layer.
    if any("transformer_blocks_1" in key for key in tensors):
        return "sdxl"

    return None
|
|
|
|
def validate_safetensors_file(path: Path) -> Dict:
    """Check whether a LoRA file on disk looks complete and uncorrupted.

    Checks, in order:
      1. The file can be stat'ed.
      2. It is at least ``MIN_LORA_SIZE`` bytes.
      3. For ``.safetensors`` files only: the file opens with the
         ``safetensors`` package if that import works. Otherwise the header is
         parsed by hand and the file size is compared with the largest tensor
         ``data_offsets`` end.

    Files with any other suffix (e.g. ``.pt`` / ``.ckpt``) are not in
    safetensors format, so only checks 1-2 apply to them. Parsing their first
    8 bytes as a header length would always misreport them as corrupt.

    Safetensors layout: an 8-byte little-endian uint64 header length, then a
    JSON header of that length, then tensor data covering the rest of the
    file.

    Returns:
        dict with:
          - healthy: True if the file passed every applicable check
          - error: error message if not healthy, "" otherwise
          - file_size: size in bytes (0 if the file could not be stat'ed)
          - file_size_human: human-readable size ("" if not stat'ed)
    """
    result: Dict = {
        "healthy": False,
        "error": "",
        "file_size": 0,
        "file_size_human": "",
    }

    try:
        file_size = path.stat().st_size
    except OSError as e:
        result["error"] = f"Cannot stat file: {e}"
        return result
    result["file_size"] = file_size
    result["file_size_human"] = _fmt_bytes(file_size)

    # Tiny files are almost always aborted downloads.
    if file_size < MIN_LORA_SIZE:
        result["error"] = f"File too small ({result['file_size_human']}) — likely incomplete download"
        return result

    if path.suffix.lower() != ".safetensors":
        # Not a safetensors container; there is no header to check.
        result["healthy"] = True
        return result

    error = _check_with_safetensors_lib(path)
    if error is None:
        # Library unavailable: fall back to parsing the header ourselves.
        error = _check_safetensors_header(path, file_size)
    if error:
        result["error"] = error
        return result

    result["healthy"] = True
    return result


def _check_with_safetensors_lib(path: Path) -> Optional[str]:
    """Try to open *path* with the ``safetensors`` package.

    Returns:
        "" if the file opened cleanly, an error message if opening failed,
        or None if ``safetensors`` could not be imported (or opening raised
        ImportError) and the caller should fall back to manual parsing.
    """
    try:
        import safetensors
    except ImportError:
        return None
    try:
        # Use the context manager so the handle is released. A bare
        # ``__enter__()`` never closed it, which leaks the handle and may
        # keep the file locked on Windows.
        with safetensors.safe_open(str(path), framework="pt", device="cpu"):
            pass
    except ImportError:
        return None
    except Exception as e:
        return f"safetensors validation failed: {e}"
    return ""


def _check_safetensors_header(path: Path, file_size: int) -> str:
    """Manually validate a safetensors header; return an error message or ""."""
    try:
        with open(path, "rb") as f:
            header_len_bytes = f.read(8)
            if len(header_len_bytes) < 8:
                return "File truncated — cannot read header length"
            header_len = struct.unpack("<Q", header_len_bytes)[0]
            # Reject absurd lengths (> 100 MiB) and lengths past end of file.
            if header_len > 100 * 1024 * 1024 or header_len > file_size - 8:
                return "Invalid header length — file is corrupt or truncated"
            header_bytes = f.read(header_len)
        if len(header_bytes) < header_len:
            return "Incomplete header — file truncated during download"

        header = json.loads(header_bytes)

        # Tensor data starts right after the header; each tensor's
        # data_offsets are relative to that point, so the largest end offset
        # tells us how big a complete file must be.
        data_start = 8 + header_len
        max_end = 0
        has_tensors = False
        for key, meta in header.items():
            if key == "__metadata__":
                continue
            offsets = meta.get("data_offsets")
            if isinstance(offsets, (list, tuple)) and len(offsets) == 2:
                has_tensors = True
                if offsets[1] > max_end:
                    max_end = offsets[1]

        if has_tensors:
            expected_size = data_start + max_end
            if file_size != expected_size:
                return (
                    f"File size mismatch — expected {_fmt_bytes(expected_size)} "
                    f"but got {_fmt_bytes(file_size)}. "
                    f"Incomplete or corrupt download"
                )
        elif file_size < data_start:
            return "File size mismatch — incomplete download"
    except (json.JSONDecodeError, UnicodeDecodeError):
        # Invalid UTF-8 in the header is as corrupt as malformed JSON.
        return "Corrupt header — invalid JSON metadata"
    except Exception as e:
        return f"Header read error: {e}"
    return ""
|
|
|
|
def is_lora_compatible(lora_base: str, checkpoint_arch: str) -> Optional[bool]:
    """Decide whether a LoRA of base *lora_base* fits a *checkpoint_arch* model.

    Compatibility comes from the ``LORA_COMPAT`` table.

    Returns:
        True  — *checkpoint_arch* is listed for *lora_base*
        False — *lora_base* is known but *checkpoint_arch* is not listed
        None  — unknown: either argument is empty, or *lora_base* has no
                (or an empty) entry in the table
    """
    if lora_base and checkpoint_arch:
        supported = LORA_COMPAT.get(lora_base)
        if supported:
            return checkpoint_arch in supported
    return None
|
|
|
|
def scan_installed_loras() -> List[Dict]:
    """List LoRA weight files directly inside the loras directory.

    Only top-level entries with a ``.safetensors``, ``.pt`` or ``.ckpt``
    suffix (case-insensitive) are considered. Directories are skipped even if
    their name has such a suffix. Entries are returned in sorted path order.
    Base and gating metadata come from the LoRA registry, matched by filename
    stem; health fields come from ``validate_safetensors_file``.

    Each entry contains:
        - id: filename without extension
        - filename: full filename
        - path: absolute path
        - enabled: False (default state)
        - weight: 0.8 (default weight)
        - base: LoRA base architecture from the registry (e.g. "sd1.5",
          "sdxl") or "" if the file is not in the registry
        - base_label: human-readable base label (e.g. "SD1.5") or "" if unknown
        - healthy: True if file is valid, False if corrupt
        - health_error: error description if corrupt, "" if healthy
        - file_size: size in bytes
        - file_size_human: human-readable size string
        - gated: True if the registry marks the LoRA as NSFW/gated (only
          shown when spicy mode is enabled)

    Returns:
        The list above, or [] if the loras directory does not exist or is not
        a directory.
    """
    lora_dir = get_lora_dir()
    # is_dir() rather than exists(): iterdir() on a plain file would raise
    # NotADirectoryError.
    if not lora_dir.is_dir():
        return []

    registry = _get_registry_lookup()
    lora_suffixes = (".safetensors", ".pt", ".ckpt")

    results: List[Dict] = []
    for f in sorted(lora_dir.iterdir()):
        if f.suffix.lower() not in lora_suffixes:
            continue
        # A folder named e.g. "foo.safetensors" is not a LoRA. Broken
        # symlinks are kept so they still show up as unhealthy.
        if f.is_dir():
            continue

        lora_id = f.stem
        entry = registry.get(lora_id)
        base = entry.base if entry else ""
        gated = entry.gated if entry else False
        health = validate_safetensors_file(f)
        results.append({
            "id": lora_id,
            "filename": f.name,
            "path": str(f),
            "enabled": False,
            "weight": 0.8,
            "base": base,
            "base_label": LORA_BASE_LABELS.get(base, ""),
            "healthy": health["healthy"],
            "health_error": health["error"],
            "file_size": health["file_size"],
            "file_size_human": health["file_size_human"],
            "gated": gated,
        })

    return results
|
|