Spaces:
Running on Zero
Running on Zero
| """Model factory for LightDiffusion-Next. | |
| Provides automatic model type detection and instantiation. | |
| Simplified to a single function with a registry for extensibility. | |
| """ | |
| import logging | |
| import os | |
| from typing import Optional, Type | |
| from src.Core.AbstractModel import AbstractModel | |
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
# Model type registry - maps type names to model classes.
# Starts out empty: _ensure_registry_populated() adds the built-in types on
# first use and register_model_type() adds custom ones.
_MODEL_REGISTRY: dict[str, Type[AbstractModel]] = {}
# Lowercase substrings of a checkpoint's file name that identify it as SDXL.
_SDXL_INDICATORS = frozenset({"sdxl", "refiner", "hassaku", "juggernaut", "xl"})

# Lowercase substrings of a file name that identify it as Flux2 Klein.
_FLUX2_KLEIN_INDICATORS = frozenset(
    {"flux2", "klein", "flux_klein", "flux-klein", "flux2_klein", "flux-2"}
)

# Default directories holding the separately stored Flux2 components.
FLUX2_DIFFUSION_MODEL_DIR = "./include/diffusion_model"
FLUX2_TEXT_ENCODER_DIR = "./include/text_encoder"
FLUX2_VAE_DIR = "./include/vae"
def _ensure_registry_populated():
    """Fill the model registry with the built-in types on first use.

    The model classes are imported inside this function instead of at module
    import time to avoid circular imports. Nothing happens once the registry
    already holds at least one entry.
    """
    if _MODEL_REGISTRY:
        return

    from src.Core.Models.SD15Model import SD15Model
    from src.Core.Models.SDXLModel import SDXLModel
    from src.Core.Models.Flux2KleinModel import Flux2KleinModel

    _MODEL_REGISTRY.update(
        {
            "SD15": SD15Model,
            "SDXL": SDXLModel,
            "Flux2Klein": Flux2KleinModel,
        }
    )
def _first_matching_file(directory: str, keywords: tuple[str, ...] = ()) -> Optional[str]:
    """Return the path of the first model weight file found in *directory*.

    Args:
        directory: Directory to scan (not recursive).
        keywords: Lowercase substrings; when non-empty, the file name must
            contain at least one of them. An empty tuple accepts any name.

    Returns:
        ``os.path.join(directory, name)`` for the first match, or None when
        the directory is missing, is not a directory, or has no match.
    """
    weight_suffixes = (".safetensors", ".pt", ".pth")

    # isdir (rather than exists) so a stray regular file at this path is
    # treated as "nothing found" instead of making os.listdir raise.
    if not os.path.isdir(directory):
        return None

    # os.listdir returns entries in arbitrary order; sort so the same file is
    # picked on every run and every platform.
    for name in sorted(os.listdir(directory)):
        lowered = name.lower()
        # Compare the extension case-insensitively, like the keywords.
        if not lowered.endswith(weight_suffixes):
            continue
        if keywords and not any(word in lowered for word in keywords):
            continue
        return os.path.join(directory, name)
    return None


def _find_flux2_components() -> tuple[Optional[str], Optional[str], Optional[str]]:
    """Auto-detect Flux2 components in the default directories.

    Each component is looked up independently; a component that cannot be
    found is reported as None.

    Returns:
        Tuple of (diffusion_model_path, text_encoder_path, vae_path)
    """
    return (
        # Diffusion model: file name must mention "flux" or "klein".
        _first_matching_file(FLUX2_DIFFUSION_MODEL_DIR, ("flux", "klein")),
        # Text encoder: file name must mention "qwen" or "klein".
        _first_matching_file(FLUX2_TEXT_ENCODER_DIR, ("qwen", "klein")),
        # VAE: any weight file in the directory is accepted.
        _first_matching_file(FLUX2_VAE_DIR),
    )
def detect_model_type(model_path: Optional[str]) -> str:
    """Detect model type from file path.

    Only the file name (not its directories) is inspected, case-insensitively.
    Path-like objects such as ``pathlib.Path`` are accepted as well as strings.

    Args:
        model_path: Path to model checkpoint; None or empty means "no path".

    Returns:
        'SD15', 'SDXL', or 'Flux2Klein'. 'SD15' is the fallback when no path
        is given or no indicator matches.

    Raises:
        ValueError: If GGUF file provided (unsupported)
    """
    if not model_path:
        return "SD15"

    # os.fsdecode returns str unchanged and converts bytes / os.PathLike, so
    # callers passing pathlib.Path no longer fail on the .lower() call.
    lowered = os.fsdecode(model_path).lower()
    if lowered.endswith(".gguf"):
        raise ValueError(f"GGUF files not supported: {model_path}")

    base = os.path.basename(lowered)

    # Order matters: Flux2 Klein is checked first (more specific), then SDXL.
    ordered_checks = (
        ("Flux2Klein", _FLUX2_KLEIN_INDICATORS),
        ("SDXL", _SDXL_INDICATORS),
    )
    for model_type, indicators in ordered_checks:
        if any(indicator in base for indicator in indicators):
            return model_type
    return "SD15"
def detect_model_type_from_state_dict(state_dict: dict) -> str:
    """Detect model type by inspecting state dict keys.

    Looks for characteristic substrings in the key names, so the answer
    reflects the tensors present rather than the checkpoint's file name.

    Args:
        state_dict: Model state dictionary

    Returns:
        'SD15', 'SDXL', or 'Flux2Klein'
    """
    names = tuple(state_dict.keys())

    def mentions(fragment: str) -> bool:
        # True when any key contains the fragment as a substring.
        return any(fragment in name for name in names)

    # Flux2 Klein modulation layers first, then the generic Flux
    # "double_blocks" marker.
    flux2_markers = (
        "double_stream_modulation_img.lin.weight",
        "double_stream_modulation.lin.weight",
        "double_blocks",
    )
    if any(mentions(marker) for marker in flux2_markers):
        return "Flux2Klein"

    # SDXL-specific keys.
    sdxl_markers = (
        "conditioner.embedders",
        "model.diffusion_model.label_emb.0.0.weight",
    )
    if any(mentions(marker) for marker in sdxl_markers):
        return "SDXL"

    # Nothing recognised: fall back to SD1.5.
    return "SD15"
def create_model(
    model_path: Optional[str] = None,
    model_type: Optional[str] = None,
    text_encoder_path: Optional[str] = None,
    vae_path: Optional[str] = None,
) -> AbstractModel:
    """Create a model instance with automatic type detection.

    Args:
        model_path: Path to checkpoint file (or diffusion model for Flux2)
        model_type: Explicit type ('SD15', 'SDXL', 'Flux2Klein'), or None to
            auto-detect from ``model_path``. An unregistered type is logged
            as a warning and replaced by 'SD15'.
        text_encoder_path: Path to text encoder (Flux2 only; ignored otherwise)
        vae_path: Path to VAE (Flux2 only; ignored otherwise)

    Returns:
        Configured model instance (not yet loaded)

    Example:
        # Auto-detect and load SD1.5/SDXL
        model = create_model("./checkpoints/dreamer.safetensors")
        model.load()

        # Flux2 Klein from separate components
        model = create_model(model_type="Flux2Klein")  # auto-detect paths
        model.load()
    """
    _ensure_registry_populated()

    if model_type is None:
        model_type = detect_model_type(model_path)

    if model_type not in _MODEL_REGISTRY:
        logger.warning("Unknown model type '%s', using SD15", model_type)
        model_type = "SD15"

    model_class = _MODEL_REGISTRY[model_type]

    if model_type != "Flux2Klein":
        logger.info("Creating %s model: %s", model_type, model_path)
        return model_class(model_path=model_path)

    # Flux2Klein is assembled from separate files. Auto-detect whichever
    # component the caller left out, including the VAE: previously a missing
    # VAE was not looked up when the other two paths were supplied.
    if model_path is None or text_encoder_path is None or vae_path is None:
        diffusion_path, te_path, vae_detected = _find_flux2_components()
        # Explicit arguments always win over auto-detected paths.
        model_path = model_path or diffusion_path
        text_encoder_path = text_encoder_path or te_path
        vae_path = vae_path or vae_detected

    logger.info("Creating Flux2Klein model:")
    logger.info(" Diffusion model: %s", model_path)
    logger.info(" Text encoder: %s", text_encoder_path)
    logger.info(" VAE: %s", vae_path)
    return model_class(
        model_path=model_path,
        text_encoder_path=text_encoder_path,
        vae_path=vae_path,
    )
def register_model_type(type_name: str, model_class: Type[AbstractModel]) -> None:
    """Add a custom model type to the registry, or replace an existing one.

    The built-in types are populated first, so the registry is never left
    holding only the custom entry.

    Args:
        type_name: Identifier for the model type
        model_class: Class inheriting from AbstractModel

    Raises:
        TypeError: If ``model_class`` is not a subclass of AbstractModel.
    """
    _ensure_registry_populated()

    is_model_class = issubclass(model_class, AbstractModel)
    if is_model_class:
        _MODEL_REGISTRY[type_name] = model_class
        logger.info(f"Registered model type: {type_name}")
        return

    raise TypeError(f"{model_class} must inherit from AbstractModel")
def list_model_types() -> list[str]:
    """Return the names of all registered model types, built-ins included."""
    _ensure_registry_populated()
    return [type_name for type_name in _MODEL_REGISTRY]
def _glob_model_files(directory: str, extensions: tuple[str, ...]) -> list[str]:
    """Return paths of files directly inside *directory* ending in *extensions*.

    Results are grouped by extension, in the order the extensions are given.
    A path that is not an existing directory yields an empty list.
    """
    import glob

    if not os.path.isdir(directory):
        return []

    # Escape the directory part so characters such as "[" or "*" in its name
    # are matched literally instead of being read as glob syntax.
    literal_dir = glob.escape(directory)
    found: list[str] = []
    for ext in extensions:
        found.extend(glob.glob(os.path.join(literal_dir, f"*{ext}")))
    return found


def list_available_models(
    checkpoint_dir: str = "./include/checkpoints/",
    return_mapping: bool = False,
) -> list:
    """List available model files in the checkpoints directory.

    Files in the default Flux2 diffusion model directory are listed too.

    Args:
        checkpoint_dir: Directory to scan for models
        return_mapping: If True, return list of (display_name, full_path) tuples

    Returns:
        List of model names, or list of (name, path) tuples if
        return_mapping=True, sorted case-insensitively by file name.
    """
    valid_extensions = (".safetensors", ".pt", ".pth")

    # Checkpoints first, then Flux2 diffusion models; the stable sort below
    # keeps this order for entries whose names compare equal.
    filepaths = [
        *_glob_model_files(checkpoint_dir, valid_extensions),
        *_glob_model_files(FLUX2_DIFFUSION_MODEL_DIR, valid_extensions),
    ]

    entries = [(os.path.basename(filepath), filepath) for filepath in filepaths]
    # Sort alphabetically by display name, ignoring case.
    entries.sort(key=lambda entry: entry[0].lower())

    if return_mapping:
        return entries
    return [name for name, _ in entries]
def list_available_controlnets(
    controlnet_dir: str = "./include/controlnets/",
) -> list[str]:
    """List available ControlNet models.

    Args:
        controlnet_dir: Directory to scan (not recursive).

    Returns:
        File names (without the directory) ending in .safetensors, .pt or
        .pth, sorted case-insensitively. Empty if the directory is missing.
    """
    import glob

    if not os.path.isdir(controlnet_dir):
        return []

    # Escape the directory part so characters such as "[" or "*" in its name
    # are matched literally instead of being read as glob syntax.
    literal_dir = glob.escape(controlnet_dir)
    results = [
        os.path.basename(filepath)
        for ext in (".safetensors", ".pt", ".pth")
        for filepath in glob.glob(os.path.join(literal_dir, f"*{ext}"))
    ]
    results.sort(key=str.lower)
    return results