"""Model factory for LightDiffusion-Next.
Provides automatic model type detection and instantiation.
Simplified to a single function with a registry for extensibility.
"""
import logging
import os
from typing import Optional, Type
from src.Core.AbstractModel import AbstractModel
logger = logging.getLogger(__name__)
# Registry of model types: maps a type name to the model class to instantiate.
# It starts empty; entries are added at runtime.
_MODEL_REGISTRY: dict[str, Type[AbstractModel]] = {}

# Substrings that mark a checkpoint file name as SDXL.
_SDXL_INDICATORS = frozenset(("sdxl", "refiner", "hassaku", "juggernaut", "xl"))

# Substrings that mark a file name as Flux2 Klein.
_FLUX2_KLEIN_INDICATORS = frozenset(
    (
        "flux2",
        "klein",
        "flux_klein",
        "flux-klein",
        "flux2_klein",
        "flux-2",
    )
)

# Default directories holding the separate Flux2 components.
FLUX2_DIFFUSION_MODEL_DIR = "./include/diffusion_model"
FLUX2_TEXT_ENCODER_DIR = "./include/text_encoder"
FLUX2_VAE_DIR = "./include/vae"
def _ensure_registry_populated():
    """Register the built-in model types the first time the registry is used.

    The model classes are imported inside the function (not at module import
    time) to avoid circular imports. Nothing happens when the registry
    already contains at least one entry.
    """
    if _MODEL_REGISTRY:
        return

    from src.Core.Models.SD15Model import SD15Model
    from src.Core.Models.SDXLModel import SDXLModel
    from src.Core.Models.Flux2KleinModel import Flux2KleinModel

    _MODEL_REGISTRY.update(
        {
            "SD15": SD15Model,
            "SDXL": SDXLModel,
            "Flux2Klein": Flux2KleinModel,
        }
    )
def _first_weight_file(directory: str, keywords: tuple = ()) -> Optional[str]:
    """Return the first weight file in *directory* matching *keywords*.

    Args:
        directory: Directory to scan (not recursive).
        keywords: Lower-case substrings; a file qualifies when its lower-cased
            name contains at least one of them. An empty tuple accepts any
            weight file.

    Returns:
        ``os.path.join(directory, name)`` for the first qualifying entry in
        sorted name order, or None when the directory is missing, is not a
        directory, or holds no qualifying file.
    """
    # isdir (rather than exists) so a stray regular file at this path does
    # not make os.listdir raise NotADirectoryError.
    if not os.path.isdir(directory):
        return None

    # os.listdir returns entries in arbitrary order; sorting makes the pick
    # reproducible when several files qualify.
    for name in sorted(os.listdir(directory)):
        lowered = name.lower()
        # Compare the extension on the lower-cased name so it is matched
        # case-insensitively, like the keywords.
        if not lowered.endswith((".safetensors", ".pt", ".pth")):
            continue
        if keywords and not any(word in lowered for word in keywords):
            continue
        return os.path.join(directory, name)
    return None


def _find_flux2_components() -> tuple[Optional[str], Optional[str], Optional[str]]:
    """Auto-detect Flux2 components in the default directories.

    The diffusion model must have "flux" or "klein" in its file name, the
    text encoder "qwen" or "klein"; any weight file is accepted as the VAE.
    Recognised extensions are .safetensors, .pt and .pth (case-insensitive).

    Returns:
        Tuple of (diffusion_model_path, text_encoder_path, vae_path); each
        element is None when no suitable file was found.
    """
    return (
        _first_weight_file(FLUX2_DIFFUSION_MODEL_DIR, ("flux", "klein")),
        _first_weight_file(FLUX2_TEXT_ENCODER_DIR, ("qwen", "klein")),
        _first_weight_file(FLUX2_VAE_DIR),
    )
def detect_model_type(model_path: Optional[str]) -> str:
    """Guess the model type from the checkpoint's file name.

    Matching is done on the lower-cased base name of the path using
    substring keywords.

    Args:
        model_path: Path to model checkpoint; None or empty yields 'SD15'.

    Returns:
        'SD15', 'SDXL', or 'Flux2Klein'

    Raises:
        ValueError: If the path ends in .gguf (unsupported)
    """
    if not model_path:
        return "SD15"

    lowered = model_path.lower()
    if lowered.endswith(".gguf"):
        raise ValueError(f"GGUF files not supported: {model_path}")

    filename = os.path.basename(lowered)

    # Order matters: Flux2 Klein keywords are tried before the SDXL ones, so
    # a name matching both sets resolves to Flux2Klein.
    candidates = (
        ("Flux2Klein", _FLUX2_KLEIN_INDICATORS),
        ("SDXL", _SDXL_INDICATORS),
    )
    for type_name, keywords in candidates:
        for keyword in keywords:
            if keyword in filename:
                return type_name

    # Nothing matched: fall back to the default type.
    return "SD15"
def detect_model_type_from_state_dict(state_dict: dict) -> str:
    """Guess the model type from the key names of a state dict.

    Only the keys are inspected (by substring match); the values are never
    read. Flux markers are checked before SDXL markers.

    Args:
        state_dict: Model state dictionary

    Returns:
        'SD15', 'SDXL', or 'Flux2Klein'
    """
    names = set(state_dict.keys())

    def _any_key_contains(fragment: str) -> bool:
        return any(fragment in name for name in names)

    # Flux2 Klein modulation weights, plus the generic "double_blocks" marker
    # of the Flux architecture; any of them classifies as Flux2Klein.
    flux_fragments = (
        "double_stream_modulation_img.lin.weight",
        "double_stream_modulation.lin.weight",
        "double_blocks",
    )
    if any(_any_key_contains(fragment) for fragment in flux_fragments):
        return "Flux2Klein"

    # Key fragments treated as SDXL-specific.
    sdxl_fragments = (
        "conditioner.embedders",
        "model.diffusion_model.label_emb.0.0.weight",
    )
    if any(_any_key_contains(fragment) for fragment in sdxl_fragments):
        return "SDXL"

    return "SD15"
def create_model(
    model_path: Optional[str] = None,
    model_type: Optional[str] = None,
    text_encoder_path: Optional[str] = None,
    vae_path: Optional[str] = None,
) -> AbstractModel:
    """Instantiate a registered model class, detecting the type if needed.

    Args:
        model_path: Path to checkpoint file (or diffusion model for Flux2)
        model_type: Explicit type ('SD15', 'SDXL', 'Flux2Klein'), or None to
            detect it from ``model_path``. A type that is not registered is
            logged as a warning and replaced by 'SD15'.
        text_encoder_path: Path to text encoder (Flux2 only)
        vae_path: Path to VAE (Flux2 only)

    Returns:
        Configured model instance (not yet loaded)

    Example:
        # Auto-detect and load SD1.5/SDXL
        model = create_model("./checkpoints/dreamer.safetensors")
        model.load()

        # Flux2 Klein from separate components
        model = create_model(model_type="Flux2Klein")  # auto-detect paths
        model.load()
    """
    _ensure_registry_populated()

    resolved_type = detect_model_type(model_path) if model_type is None else model_type
    if resolved_type not in _MODEL_REGISTRY:
        logger.warning(f"Unknown model type '{resolved_type}', using SD15")
        resolved_type = "SD15"

    # Every type except Flux2Klein is built from the checkpoint path alone.
    if resolved_type != "Flux2Klein":
        logger.info(f"Creating {resolved_type} model: {model_path}")
        return _MODEL_REGISTRY[resolved_type](model_path=model_path)

    # Flux2Klein ships as separate components; fill in missing paths from the
    # default directories.
    # NOTE(review): the scan only runs when the diffusion model or the text
    # encoder is missing, so a missing VAE alone is not auto-detected --
    # confirm whether that is intended.
    if model_path is None or text_encoder_path is None:
        found_model, found_encoder, found_vae = _find_flux2_components()
        model_path = model_path if model_path else found_model
        text_encoder_path = text_encoder_path if text_encoder_path else found_encoder
        vae_path = vae_path if vae_path else found_vae

    logger.info("Creating Flux2Klein model:")
    logger.info(f" Diffusion model: {model_path}")
    logger.info(f" Text encoder: {text_encoder_path}")
    logger.info(f" VAE: {vae_path}")

    flux_kwargs = {
        "model_path": model_path,
        "text_encoder_path": text_encoder_path,
        "vae_path": vae_path,
    }
    return _MODEL_REGISTRY[resolved_type](**flux_kwargs)
def register_model_type(type_name: str, model_class: Type[AbstractModel]) -> None:
    """Add a custom model class to the registry under ``type_name``.

    The built-in types are registered first; an existing entry with the same
    name is overwritten.

    Args:
        type_name: Identifier for the model type
        model_class: Class inheriting from AbstractModel

    Raises:
        TypeError: If ``model_class`` does not inherit from AbstractModel
    """
    _ensure_registry_populated()

    inherits_base = issubclass(model_class, AbstractModel)
    if inherits_base:
        _MODEL_REGISTRY[type_name] = model_class
        logger.info(f"Registered model type: {type_name}")
        return

    raise TypeError(f"{model_class} must inherit from AbstractModel")
def list_model_types() -> list[str]:
    """Return the names of all registered model types, in registration order."""
    _ensure_registry_populated()
    return [type_name for type_name in _MODEL_REGISTRY]
def list_available_models(
    checkpoint_dir: str = "./include/checkpoints/",
    return_mapping: bool = False,
) -> list:
    """Collect the model files found on disk.

    Two directories are scanned (non-recursively) for .safetensors, .pt and
    .pth files: ``checkpoint_dir`` and the default Flux2 diffusion model
    directory. A directory that does not exist is skipped.

    Args:
        checkpoint_dir: Directory to scan for checkpoints
        return_mapping: If True, return list of (display_name, full_path) tuples

    Returns:
        List of file names, or list of (name, path) tuples if
        return_mapping=True, sorted case-insensitively by file name
    """
    import glob

    suffixes = (".safetensors", ".pt", ".pth")
    found = []

    # Checkpoints first, then Flux2 diffusion models.
    for directory in (checkpoint_dir, FLUX2_DIFFUSION_MODEL_DIR):
        if not os.path.isdir(directory):
            continue
        for suffix in suffixes:
            for path in glob.glob(os.path.join(directory, f"*{suffix}")):
                name = os.path.basename(path)
                found.append((name, path) if return_mapping else name)

    def _display_name_key(entry):
        # Entries are (name, path) tuples in mapping mode, plain names otherwise.
        label = entry[0] if isinstance(entry, tuple) else entry
        return label.lower()

    found.sort(key=_display_name_key)
    return found
def list_available_controlnets(
    controlnet_dir: str = "./include/controlnets/",
) -> list[str]:
    """Return the ControlNet weight files found in ``controlnet_dir``.

    Args:
        controlnet_dir: Directory to scan (non-recursively) for
            .safetensors, .pt and .pth files

    Returns:
        File names sorted case-insensitively; an empty list when the
        directory does not exist
    """
    import glob

    if not os.path.exists(controlnet_dir):
        return []

    names = [
        os.path.basename(path)
        for suffix in (".safetensors", ".pt", ".pth")
        for path in glob.glob(os.path.join(controlnet_dir, f"*{suffix}"))
    ]
    return sorted(names, key=str.lower)