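# Origin: LightDiffusion-Next/src/Core/Models/ModelFactory.py
# Snapshot: "Deploy ZeroGPU Gradio Space snapshot" (commit b701455).
# NOTE(review): the scraped page header also carried the caption
# "Aatricks's picture" (an avatar alt text); presumably Aatricks is the uploader.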
"""Model factory for LightDiffusion-Next.
Provides automatic model type detection and instantiation.
Simplified to a single function with a registry for extensibility.
"""
import logging
import os
from typing import Optional, Type
from src.Core.AbstractModel import AbstractModel
logger = logging.getLogger(__name__)
# Model type registry - maps type names (e.g. "SD15") to model classes.
# Starts empty; _ensure_registry_populated() adds the built-in types on first use
# and register_model_type() adds custom ones.
_MODEL_REGISTRY: dict[str, Type[AbstractModel]] = {}
# SDXL detection keywords - detect_model_type() looks for these as substrings
# of the lowercased checkpoint file name.
_SDXL_INDICATORS = frozenset(["sdxl", "refiner", "hassaku", "juggernaut", "xl"])
# Flux2 Klein detection keywords - matched the same way, and checked before
# the SDXL keywords.
_FLUX2_KLEIN_INDICATORS = frozenset(["flux2", "klein", "flux_klein", "flux-klein", "flux2_klein", "flux-2"])
# Default directories scanned for the separate Flux2 components
# (diffusion model, text encoder, VAE). Relative paths, so they resolve
# against the process's current working directory.
FLUX2_DIFFUSION_MODEL_DIR = "./include/diffusion_model"
FLUX2_TEXT_ENCODER_DIR = "./include/text_encoder"
FLUX2_VAE_DIR = "./include/vae"
def _ensure_registry_populated():
    """Fill the registry with the built-in model types on first use.

    The model classes are imported inside the function instead of at module
    import time to avoid circular imports. The call is a no-op once the
    registry holds at least one entry.
    """
    if _MODEL_REGISTRY:
        return

    from src.Core.Models.SD15Model import SD15Model
    from src.Core.Models.SDXLModel import SDXLModel
    from src.Core.Models.Flux2KleinModel import Flux2KleinModel

    # Same insertion order as before: SD15, SDXL, Flux2Klein.
    _MODEL_REGISTRY.update(
        {
            "SD15": SD15Model,
            "SDXL": SDXLModel,
            "Flux2Klein": Flux2KleinModel,
        }
    )
def _first_model_file(directory: str, keywords: tuple[str, ...] = ()) -> Optional[str]:
    """Return the first weight file found in *directory*, or None.

    Args:
        directory: Directory to scan (not recursive).
        keywords: Lowercase substrings; when non-empty, the lowercased file
            name must contain at least one of them.

    Returns:
        ``os.path.join(directory, name)`` for the first match in sorted
        file-name order, or None when the directory is missing, is not a
        directory, or holds no matching file.
    """
    weight_suffixes = (".safetensors", ".pt", ".pth")

    # isdir (not exists): a regular file at this path must not reach listdir.
    if not os.path.isdir(directory):
        return None

    # os.listdir order is arbitrary; sorting makes the pick reproducible.
    for name in sorted(os.listdir(directory)):
        lowered = name.lower()
        # Extension check is case-insensitive, like the keyword check.
        if not lowered.endswith(weight_suffixes):
            continue
        if keywords and not any(word in lowered for word in keywords):
            continue
        return os.path.join(directory, name)
    return None


def _find_flux2_components() -> tuple[Optional[str], Optional[str], Optional[str]]:
    """Auto-detect Flux2 components in the default directories.

    Scans FLUX2_DIFFUSION_MODEL_DIR, FLUX2_TEXT_ENCODER_DIR and FLUX2_VAE_DIR
    for ``.safetensors`` / ``.pt`` / ``.pth`` files. The diffusion model name
    must contain "flux" or "klein", the text encoder name must contain "qwen"
    or "klein"; any weight file is accepted as the VAE.

    Returns:
        Tuple of (diffusion_model_path, text_encoder_path, vae_path); each
        entry is None when no suitable file was found.
    """
    diffusion_path = _first_model_file(FLUX2_DIFFUSION_MODEL_DIR, ("flux", "klein"))
    text_encoder_path = _first_model_file(FLUX2_TEXT_ENCODER_DIR, ("qwen", "klein"))
    vae_path = _first_model_file(FLUX2_VAE_DIR)
    return diffusion_path, text_encoder_path, vae_path
def detect_model_type(model_path: Optional[str]) -> str:
    """Guess the model type from the checkpoint file name.

    Args:
        model_path: Path to model checkpoint; None or an empty string
            yields the default type.

    Returns:
        'SD15', 'SDXL', or 'Flux2Klein'

    Raises:
        ValueError: If the path ends with ".gguf" (unsupported)
    """
    if not model_path:
        return "SD15"

    lowered = model_path.lower()
    if lowered.endswith(".gguf"):
        raise ValueError(f"GGUF files not supported: {model_path}")

    filename = os.path.basename(lowered)

    # Rules are tried in order; Flux2 Klein comes first because its keywords
    # are the more specific ones.
    detection_rules = (
        ("Flux2Klein", _FLUX2_KLEIN_INDICATORS),
        ("SDXL", _SDXL_INDICATORS),
    )
    for type_name, keywords in detection_rules:
        for keyword in keywords:
            if keyword in filename:
                return type_name

    # No keyword matched: fall back to the default type.
    return "SD15"
def detect_model_type_from_state_dict(state_dict: dict) -> str:
    """Detect model type by inspecting state dict keys.

    Each key is checked for marker substrings. Flux2 Klein markers take
    precedence over SDXL markers; when nothing matches the type is 'SD15'.
    Only the keys are examined, never the values. Keys that are not strings
    are ignored instead of raising TypeError.

    Args:
        state_dict: Model state dictionary

    Returns:
        'SD15', 'SDXL', or 'Flux2Klein'
    """
    flux2_markers = (
        "double_stream_modulation_img.lin.weight",
        "double_stream_modulation.lin.weight",
        "double_blocks",  # Flux architecture blocks
    )
    sdxl_markers = (
        "conditioner.embedders",
        "model.diffusion_model.label_emb.0.0.weight",
    )

    # Materialise the string keys once; both marker groups scan the same list.
    key_names = [key for key in state_dict if isinstance(key, str)]

    def has_marker(markers: tuple) -> bool:
        return any(marker in name for name in key_names for marker in markers)

    if has_marker(flux2_markers):
        return "Flux2Klein"
    if has_marker(sdxl_markers):
        return "SDXL"
    return "SD15"
def create_model(
    model_path: Optional[str] = None,
    model_type: Optional[str] = None,
    text_encoder_path: Optional[str] = None,
    vae_path: Optional[str] = None,
) -> AbstractModel:
    """Create a model instance with automatic type detection.

    When ``model_type`` is None it is derived from ``model_path`` via
    ``detect_model_type``. A type that is not in the registry is replaced by
    'SD15' after a warning. For 'Flux2Klein', any component path that was not
    supplied (diffusion model, text encoder or VAE) is filled in from the
    default component directories; paths given by the caller always win.

    Args:
        model_path: Path to checkpoint file (or diffusion model for Flux2)
        model_type: Explicit type ('SD15', 'SDXL', 'Flux2Klein'), or None to auto-detect
        text_encoder_path: Path to text encoder (Flux2 only)
        vae_path: Path to VAE (Flux2 only)

    Returns:
        Model instance built by the registered class; this function does not
        call ``load()`` on it.

    Raises:
        ValueError: Propagated from ``detect_model_type`` when auto-detecting
            the type of a ".gguf" path.

    Example:
        # Auto-detect and load SD1.5/SDXL
        model = create_model("./checkpoints/dreamer.safetensors")
        model.load()
        # Flux2 Klein from separate components
        model = create_model(model_type="Flux2Klein")  # auto-detect paths
        model.load()
    """
    _ensure_registry_populated()

    if model_type is None:
        model_type = detect_model_type(model_path)

    if model_type not in _MODEL_REGISTRY:
        logger.warning("Unknown model type '%s', using SD15", model_type)
        model_type = "SD15"

    model_class = _MODEL_REGISTRY[model_type]

    if model_type != "Flux2Klein":
        logger.info("Creating %s model: %s", model_type, model_path)
        return model_class(model_path=model_path)

    # Flux2Klein ships as separate components. Scan the default directories
    # whenever ANY of the three paths is missing, so a caller who passes the
    # diffusion model and text encoder still gets the VAE auto-detected.
    if model_path is None or text_encoder_path is None or vae_path is None:
        diffusion_found, encoder_found, vae_found = _find_flux2_components()
        model_path = model_path or diffusion_found
        text_encoder_path = text_encoder_path or encoder_found
        vae_path = vae_path or vae_found

    logger.info("Creating Flux2Klein model:")
    logger.info(" Diffusion model: %s", model_path)
    logger.info(" Text encoder: %s", text_encoder_path)
    logger.info(" VAE: %s", vae_path)
    return model_class(
        model_path=model_path,
        text_encoder_path=text_encoder_path,
        vae_path=vae_path,
    )
def register_model_type(type_name: str, model_class: Type[AbstractModel]) -> None:
    """Add a custom model class to the registry under ``type_name``.

    An existing entry with the same name is overwritten.

    Args:
        type_name: Identifier for the model type
        model_class: Class inheriting from AbstractModel

    Raises:
        TypeError: If ``model_class`` does not inherit from AbstractModel
    """
    # Load the built-in types first: the lazy populate step only runs while
    # the registry is empty, so registering into an empty registry would
    # otherwise keep the built-ins from ever being added.
    _ensure_registry_populated()

    is_model_class = issubclass(model_class, AbstractModel)
    if is_model_class:
        _MODEL_REGISTRY[type_name] = model_class
        logger.info(f"Registered model type: {type_name}")
        return

    raise TypeError(f"{model_class} must inherit from AbstractModel")
def list_model_types() -> list[str]:
    """Return the names of all registered model types, in registration order."""
    _ensure_registry_populated()
    # Unpacking a dict yields its keys; the result is a fresh list.
    return [*_MODEL_REGISTRY]
def _glob_model_files(directory: str, extensions: tuple) -> list:
    """Return (basename, path) pairs for files in *directory* ending in *extensions*."""
    import glob

    if not os.path.isdir(directory):
        return []

    # Escape the directory so characters such as "[" or "*" in its name are
    # matched literally and not interpreted as glob syntax.
    literal_dir = glob.escape(directory)
    return [
        (os.path.basename(path), path)
        for ext in extensions
        for path in glob.glob(os.path.join(literal_dir, f"*{ext}"))
    ]


def list_available_models(
    checkpoint_dir: str = "./include/checkpoints/",
    return_mapping: bool = False,
) -> list:
    """List available model files.

    Scans ``checkpoint_dir`` and FLUX2_DIFFUSION_MODEL_DIR (non-recursively)
    for ``.safetensors`` / ``.pt`` / ``.pth`` files. Directories that do not
    exist are skipped.

    Args:
        checkpoint_dir: Directory to scan for models
        return_mapping: If True, return list of (display_name, full_path) tuples

    Returns:
        List of model file names, or list of (name, path) tuples if
        return_mapping=True, sorted case-insensitively by name. Entries whose
        names compare equal are ordered by exact name, then by path, so the
        result does not depend on directory enumeration order.
    """
    valid_extensions = (".safetensors", ".pt", ".pth")

    entries = []
    for directory in (checkpoint_dir, FLUX2_DIFFUSION_MODEL_DIR):
        entries.extend(_glob_model_files(directory, valid_extensions))

    entries.sort(key=lambda entry: (entry[0].lower(), entry[0], entry[1]))

    if return_mapping:
        return entries
    return [name for name, _path in entries]
def list_available_controlnets(
    controlnet_dir: str = "./include/controlnets/",
) -> list[str]:
    """List available ControlNet models.

    Args:
        controlnet_dir: Directory to scan (non-recursively) for
            ``.safetensors`` / ``.pt`` / ``.pth`` files.

    Returns:
        File names (without directory) sorted case-insensitively, with exact
        name as tie-breaker; an empty list when the directory does not exist.
    """
    import glob

    if not os.path.exists(controlnet_dir):
        return []

    # Escape the directory so characters such as "[" or "*" in its name are
    # matched literally and not interpreted as glob syntax.
    literal_dir = glob.escape(controlnet_dir)
    names = [
        os.path.basename(path)
        for ext in (".safetensors", ".pt", ".pth")
        for path in glob.glob(os.path.join(literal_dir, f"*{ext}"))
    ]
    names.sort(key=lambda name: (name.lower(), name))
    return names