File size: 9,411 Bytes
b701455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
"""Model factory for LightDiffusion-Next.

Provides automatic model type detection and instantiation.
Simplified to a single function with a registry for extensibility.
"""

import logging
import os
from typing import Optional, Type

from src.Core.AbstractModel import AbstractModel

logger = logging.getLogger(__name__)

# Model type registry - maps type names (e.g. "SD15") to model classes.
# Starts empty; _ensure_registry_populated() fills in the built-in types on
# first use and register_model_type() adds custom ones.
_MODEL_REGISTRY: dict[str, Type[AbstractModel]] = {}

# SDXL detection keywords. detect_model_type() looks for these as substrings
# of the lowercased checkpoint basename.
# NOTE: "xl" alone already matches every basename that contains "sdxl".
_SDXL_INDICATORS = frozenset(["sdxl", "refiner", "hassaku", "juggernaut", "xl"])

# Flux2 Klein detection keywords, matched the same way as the SDXL ones.
# detect_model_type() checks these before the SDXL keywords.
_FLUX2_KLEIN_INDICATORS = frozenset(["flux2", "klein", "flux_klein", "flux-klein", "flux2_klein", "flux-2"])

# Default directories scanned for the separate Flux2 components
# (diffusion model, text encoder, VAE). Relative to the working directory.
FLUX2_DIFFUSION_MODEL_DIR = "./include/diffusion_model"
FLUX2_TEXT_ENCODER_DIR = "./include/text_encoder"
FLUX2_VAE_DIR = "./include/vae"


def _ensure_registry_populated():
    """Fill the registry with the built-in model types on first use.

    The model classes are imported inside the function instead of at module
    import time to avoid circular imports.  Does nothing when the registry
    already holds at least one entry.
    """
    if _MODEL_REGISTRY:
        return

    from src.Core.Models.SD15Model import SD15Model
    from src.Core.Models.SDXLModel import SDXLModel
    from src.Core.Models.Flux2KleinModel import Flux2KleinModel

    _MODEL_REGISTRY.update(
        {
            "SD15": SD15Model,
            "SDXL": SDXLModel,
            "Flux2Klein": Flux2KleinModel,
        }
    )


def _find_flux2_components() -> tuple[Optional[str], Optional[str], Optional[str]]:
    """Auto-detect Flux2 components in the default directories.

    Each of FLUX2_DIFFUSION_MODEL_DIR, FLUX2_TEXT_ENCODER_DIR and
    FLUX2_VAE_DIR is scanned for the first weight file (".safetensors",
    ".pt" or ".pth", compared case-insensitively).  The diffusion model must
    also contain "flux" or "klein" in its name, the text encoder "qwen" or
    "klein"; any weight file is accepted as the VAE.

    Entries are examined in sorted order so the pick is deterministic when
    several files qualify (os.listdir itself returns an arbitrary order).

    Returns:
        Tuple of (diffusion_model_path, text_encoder_path, vae_path); an
        element is None when its directory is missing or has no match.
    """
    weight_suffixes = (".safetensors", ".pt", ".pth")

    def first_match(directory: str, keywords: tuple[str, ...]) -> Optional[str]:
        # Return the first qualifying file in ``directory`` or None.
        # isdir (not exists) so a stray regular file at that path is skipped
        # instead of making os.listdir raise.
        if not os.path.isdir(directory):
            return None
        for entry in sorted(os.listdir(directory)):
            lowered = entry.lower()
            # Test the lowercased name so "MODEL.SAFETENSORS" is accepted,
            # consistent with the case-insensitive keyword match below.
            if not lowered.endswith(weight_suffixes):
                continue
            if keywords and not any(word in lowered for word in keywords):
                continue
            return os.path.join(directory, entry)
        return None

    return (
        first_match(FLUX2_DIFFUSION_MODEL_DIR, ("flux", "klein")),
        first_match(FLUX2_TEXT_ENCODER_DIR, ("qwen", "klein")),
        first_match(FLUX2_VAE_DIR, ()),
    )


def detect_model_type(model_path: Optional[str]) -> str:
    """Guess the model type from the checkpoint's file name.

    The lowercased basename is searched for the Flux2 Klein keywords first
    and the SDXL keywords second; the first family with a hit wins.

    Args:
        model_path: Path to model checkpoint; None or empty yields 'SD15'

    Returns:
        'SD15', 'SDXL', or 'Flux2Klein'

    Raises:
        ValueError: If the path ends in ".gguf" (unsupported)
    """
    if not model_path:
        return "SD15"

    lowered = model_path.lower()
    if lowered.endswith(".gguf"):
        raise ValueError(f"GGUF files not supported: {model_path}")

    filename = os.path.basename(lowered)

    # Order matters: Flux2 Klein keywords are more specific than SDXL's.
    keyword_table = (
        ("Flux2Klein", _FLUX2_KLEIN_INDICATORS),
        ("SDXL", _SDXL_INDICATORS),
    )
    for type_name, keywords in keyword_table:
        for keyword in keywords:
            if keyword in filename:
                return type_name

    return "SD15"


def detect_model_type_from_state_dict(state_dict: dict) -> str:
    """Classify a model by the parameter names in its state dict.

    Looks at the keys of the state dict rather than the file name.  Flux
    markers are checked before SDXL markers; a marker matches when it occurs
    as a substring of any key.

    Args:
        state_dict: Model state dictionary (only its keys are examined)

    Returns:
        'SD15', 'SDXL', or 'Flux2Klein'
    """
    # Flux2 Klein modulation weights, plus the generic Flux "double_blocks".
    flux_markers = (
        "double_stream_modulation_img.lin.weight",
        "double_stream_modulation.lin.weight",
        "double_blocks",
    )
    sdxl_markers = (
        "conditioner.embedders",
        "model.diffusion_model.label_emb.0.0.weight",
    )
    names = set(state_dict)

    def has_any(markers):
        return any(marker in name for marker in markers for name in names)

    if has_any(flux_markers):
        return "Flux2Klein"
    if has_any(sdxl_markers):
        return "SDXL"
    return "SD15"


def create_model(
    model_path: Optional[str] = None,
    model_type: Optional[str] = None,
    text_encoder_path: Optional[str] = None,
    vae_path: Optional[str] = None,
) -> AbstractModel:
    """Create a model instance with automatic type detection.

    When ``model_type`` is None it is derived from ``model_path`` via
    detect_model_type().  A type that is not in the registry is replaced by
    'SD15' after logging a warning.

    For 'Flux2Klein', any of the three component paths left as None triggers
    a scan of the default Flux2 directories; explicitly supplied paths are
    kept and only the missing ones are filled in.  For every other type only
    ``model_path`` is forwarded to the model class.

    Args:
        model_path: Path to checkpoint file (or diffusion model for Flux2)
        model_type: Explicit type ('SD15', 'SDXL', 'Flux2Klein'), or None to auto-detect
        text_encoder_path: Path to text encoder (Flux2 only)
        vae_path: Path to VAE (Flux2 only)

    Returns:
        Configured model instance (not yet loaded)

    Raises:
        ValueError: Propagated from detect_model_type() for ".gguf" paths

    Example:
        # Auto-detect and load SD1.5/SDXL
        model = create_model("./checkpoints/dreamer.safetensors")
        model.load()

        # Flux2 Klein from separate components
        model = create_model(model_type="Flux2Klein")  # auto-detect paths
        model.load()
    """
    _ensure_registry_populated()

    if model_type is None:
        model_type = detect_model_type(model_path)

    if model_type not in _MODEL_REGISTRY:
        logger.warning(f"Unknown model type '{model_type}', using SD15")
        model_type = "SD15"

    model_class = _MODEL_REGISTRY[model_type]

    if model_type != "Flux2Klein":
        logger.info(f"Creating {model_type} model: {model_path}")
        return model_class(model_path=model_path)

    # Flux2Klein is assembled from separate component files.  Scan the default
    # directories whenever ANY component is missing; a missing VAE alone must
    # also trigger the scan, otherwise it would stay None.
    if model_path is None or text_encoder_path is None or vae_path is None:
        diffusion_path, te_path, vae_detected = _find_flux2_components()
        model_path = model_path or diffusion_path
        text_encoder_path = text_encoder_path or te_path
        vae_path = vae_path or vae_detected

    logger.info("Creating Flux2Klein model:")
    logger.info(f"  Diffusion model: {model_path}")
    logger.info(f"  Text encoder: {text_encoder_path}")
    logger.info(f"  VAE: {vae_path}")

    return model_class(
        model_path=model_path,
        text_encoder_path=text_encoder_path,
        vae_path=vae_path,
    )


def register_model_type(type_name: str, model_class: Type[AbstractModel]) -> None:
    """Add a custom model class to the registry under ``type_name``.

    The built-in types are loaded first.  An existing entry with the same
    name is overwritten.

    Args:
        type_name: Identifier for the model type
        model_class: Class inheriting from AbstractModel

    Raises:
        TypeError: If ``model_class`` does not inherit from AbstractModel
    """
    _ensure_registry_populated()

    is_model = issubclass(model_class, AbstractModel)
    if is_model:
        _MODEL_REGISTRY[type_name] = model_class
        logger.info(f"Registered model type: {type_name}")
        return

    raise TypeError(f"{model_class} must inherit from AbstractModel")


def list_model_types() -> list[str]:
    """Return the names of all registered model types, in registration order."""
    _ensure_registry_populated()
    return [type_name for type_name in _MODEL_REGISTRY]


def list_available_models(
    checkpoint_dir: str = "./include/checkpoints/",
    return_mapping: bool = False,
) -> list:
    """List available model files.

    Scans ``checkpoint_dir`` and then FLUX2_DIFFUSION_MODEL_DIR (non-recursive)
    for files ending in ".safetensors", ".pt" or ".pth".  Directories that do
    not exist are skipped.

    Args:
        checkpoint_dir: Directory to scan for checkpoints
        return_mapping: If True, return list of (display_name, full_path) tuples

    Returns:
        List of file names, or list of (name, path) tuples if
        return_mapping=True, sorted case-insensitively by file name
    """
    import glob

    valid_extensions = (".safetensors", ".pt", ".pth")

    found: list[str] = []
    for directory in (checkpoint_dir, FLUX2_DIFFUSION_MODEL_DIR):
        if not os.path.isdir(directory):
            continue
        # Escape the directory so characters such as "[" or "*" in its name
        # are matched literally instead of being read as glob wildcards.
        safe_dir = glob.escape(directory)
        for ext in valid_extensions:
            found.extend(glob.glob(os.path.join(safe_dir, f"*{ext}")))

    if return_mapping:
        results = [(os.path.basename(path), path) for path in found]
        results.sort(key=lambda pair: pair[0].lower())
    else:
        results = [os.path.basename(path) for path in found]
        results.sort(key=str.lower)
    return results


def list_available_controlnets(
    controlnet_dir: str = "./include/controlnets/",
) -> list[str]:
    """List available ControlNet model files.

    Scans ``controlnet_dir`` (non-recursive) for files ending in
    ".safetensors", ".pt" or ".pth".

    Args:
        controlnet_dir: Directory to scan

    Returns:
        File names sorted case-insensitively; empty list when the directory
        does not exist
    """
    import glob

    if not os.path.isdir(controlnet_dir):
        return []

    # Escape the directory so characters such as "[" or "*" in its name are
    # matched literally instead of being read as glob wildcards.
    safe_dir = glob.escape(controlnet_dir)
    names = [
        os.path.basename(path)
        for ext in (".safetensors", ".pt", ".pth")
        for path in glob.glob(os.path.join(safe_dir, f"*{ext}"))
    ]
    names.sort(key=str.lower)
    return names