File size: 12,959 Bytes
b701455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
"""LightDiffusion-Next Pipeline Entry Point.

This module provides the main `pipeline()` function that all UIs call.
It's a thin wrapper around the Core Pipeline class for backward compatibility.

Usage:
    from src.user.pipeline import pipeline
    
    result = pipeline(
        prompt="a beautiful landscape",
        w=512, h=512,
        hires_fix=True,
        adetailer=True,
    )
"""

import logging
import os
import random
from typing import Callable

import torch

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
_assets_ready = False


def resolve_checkpoint_path(realistic_model: bool = False) -> str:
    """Return the filesystem path of the checkpoint to load.

    NOTE(review): ``realistic_model`` currently has no effect — the same
    default checkpoint is returned either way. Confirm whether a separate
    realistic-model checkpoint was intended here.
    """
    default_checkpoint = "./include/checkpoints/DreamShaper_8_pruned.safetensors"
    return default_checkpoint


from src.FileManaging import Downloader


def _ensure_runtime_assets() -> None:
    """Lazily download default runtime assets on first generation use.

    Deliberately deferred past module import time so API startup, health
    checks, and unrelated imports stay lightweight and do not fail merely
    because model assets have not been fetched yet.
    """
    global _assets_ready
    if not _assets_ready:
        Downloader.CheckAndDownload()
        _assets_ready = True

# Module-level cache for the last-used seed; load lazily to avoid
# import-time circular dependencies with Core modules.
_last_seed = None  # type: ignore | None


def pipeline(
    prompt: str | list,
    w: int,
    h: int,
    number: int = 1,
    batch: int = 1,
    scheduler: str = "ays",
    sampler: str = "dpmpp_sde",
    steps: int = 20,
    cfg_scale: float = 7.0,
    hires_fix: bool = False,
    adetailer: bool = False,
    enhance_prompt: bool = False,
    img2img: bool = False,
    stable_fast: bool = False,
    reuse_seed: bool = False,
    autohdr: bool = True,
    realistic_model: bool = False,
    model_path: str | None = None,
    negative_prompt: str = "",
    # Multi-scale diffusion
    multiscale_preset: str = "disabled",
    enable_multiscale: bool = False,
    multiscale_factor: float = 0.5,
    multiscale_fullres_start: int = 3,
    multiscale_fullres_end: int = 8,
    multiscale_intermittent_fullres: bool = False,
    # DeepCache
    deepcache_enabled: bool = False,
    deepcache_interval: int = 3,
    deepcache_depth: int = 2,
    deepcache_start_step: int = 0,
    deepcache_end_step: int = 1000,
    # CFG-free
    cfg_free_enabled: bool = False,
    cfg_free_start_percent: float = 70.0,
    # Token Merging
    tome_enabled: bool = False,
    tome_ratio: float = 0.5,
    tome_max_downsample: int = 1,
    # Advanced CFG
    batched_cfg: bool = True,
    dynamic_cfg_rescaling: bool = False,
    dynamic_cfg_method: str = "variance",
    dynamic_cfg_percentile: float = 95.0,
    dynamic_cfg_target_scale: float = 7.0,
    adaptive_noise_enabled: bool = False,
    adaptive_noise_method: str = "complexity",
    # Img2img
    img2img_image: str | None = None,
    request_filename_prefix: str | None = None,
    img2img_denoise: float = 0.75,  # Denoising strength: 0=no change, 1=full gen
    # Refiner
    refiner_model_path: str | None = None,
    refiner_switch_step: int | None = None,
    # ControlNet
    controlnet_model: str | None = None,
    controlnet_strength: float = 1.0,
    controlnet_type: str = "canny",
    # torch.compile
    torch_compile: bool = False,
    vae_autotune: bool = False,
    # Weight quantization
    weight_quantization: str | None = None,
    # FP8 quantization
    fp8_inference: bool = False,
    # Batched mode
    per_sample_info: list | None = None,
    # External callback
    callback: Callable | None = None,
) -> dict:
    """Run the LightDiffusion pipeline.

    This is the main entry point for image generation. All parameters
    are collected into a Context and passed to the Pipeline.

    Args:
        prompt: Text prompt(s) for generation
        w: Width of generated image
        h: Height of generated image
        number: Number of images to generate
        batch: Batch size
        scheduler: Scheduler name
        sampler: Sampler name
        steps: Sampling steps
        cfg_scale: Classifier-free guidance scale
        hires_fix: Enable high-resolution fix
        adetailer: Enable face/body enhancement
        enhance_prompt: Enable Ollama prompt enhancement
        img2img: Enable image-to-image mode
        stable_fast: Enable StableFast optimization
        reuse_seed: Reuse last seed
        autohdr: Enable AutoHDR
        realistic_model: Use realistic model
        model_path: Path to model checkpoint (overrides realistic_model)
        negative_prompt: Negative prompt
        multiscale_preset: Multi-scale preset
        enable_multiscale: Enable multi-scale diffusion
        deepcache_enabled: Enable DeepCache
        cfg_free_enabled: Enable CFG-free sampling
        tome_enabled: Enable Token Merging
        img2img_image: Source image for img2img
        per_sample_info: Per-sample data for batched mode
        callback: Optional external per-step callback; chained after the
            built-in UI preview callback and never allowed to raise.

    Returns:
        Dictionary with generation results: the original prompt, the
        prompt actually used, and whether enhancement was applied. In
        batched mode, the batched pipeline's own result dict is returned.
    """
    global _last_seed
    _ensure_runtime_assets()

    # Clear any interrupt flag left over from a previously cancelled run.
    from src.user import app_instance
    app_ref = getattr(app_instance, "app", None)
    if app_ref is not None:
        app_ref.clear_interrupt()

    # Build context from kwargs (imports deferred to avoid import cycles).
    from src.Core.Context import Context
    from src.Core.Pipeline import get_default_pipeline

    ctx = Context.from_kwargs(
        prompt=prompt,
        w=w, h=h,
        number=number,
        batch=batch,
        scheduler=scheduler,
        sampler=sampler,
        steps=steps,
        cfg_scale=cfg_scale,
        hires_fix=hires_fix,
        adetailer=adetailer,
        enhance_prompt=enhance_prompt,
        img2img=img2img,
        stable_fast=stable_fast,
        reuse_seed=reuse_seed,
        autohdr=autohdr,
        # The previous inline conditional returned the same path for both
        # branches; delegate to the shared resolver (identical result).
        model_path=model_path or resolve_checkpoint_path(realistic_model),
        negative_prompt=negative_prompt,
        multiscale_preset=multiscale_preset,
        enable_multiscale=enable_multiscale,
        multiscale_factor=multiscale_factor,
        multiscale_fullres_start=multiscale_fullres_start,
        multiscale_fullres_end=multiscale_fullres_end,
        multiscale_intermittent_fullres=multiscale_intermittent_fullres,
        deepcache_enabled=deepcache_enabled,
        deepcache_interval=deepcache_interval,
        deepcache_depth=deepcache_depth,
        deepcache_start_step=deepcache_start_step,
        deepcache_end_step=deepcache_end_step,
        cfg_free_enabled=cfg_free_enabled,
        cfg_free_start_percent=cfg_free_start_percent,
        tome_enabled=tome_enabled,
        tome_ratio=tome_ratio,
        tome_max_downsample=tome_max_downsample,
        batched_cfg=batched_cfg,
        dynamic_cfg_rescaling=dynamic_cfg_rescaling,
        dynamic_cfg_method=dynamic_cfg_method,
        dynamic_cfg_percentile=dynamic_cfg_percentile,
        dynamic_cfg_target_scale=dynamic_cfg_target_scale,
        adaptive_noise_enabled=adaptive_noise_enabled,
        adaptive_noise_method=adaptive_noise_method,
        img2img_image=img2img_image,
        request_filename_prefix=request_filename_prefix,
        img2img_denoise=img2img_denoise,
        refiner_model_path=refiner_model_path,
        refiner_switch_step=refiner_switch_step,
        controlnet_model=controlnet_model,
        controlnet_strength=controlnet_strength,
        controlnet_type=controlnet_type,
        torch_compile=torch_compile,
        vae_autotune=vae_autotune,
        fp8_inference=fp8_inference,
        weight_quantization=weight_quantization,
    )

    # Handle prompt enhancement
    original_prompt = prompt
    enhancement_applied = False

    if enhance_prompt:
        ctx, enhancement_applied = _enhance_prompt(ctx)

    # Handle seed reuse (``_last_seed`` is declared global at function top;
    # the previous duplicate ``global`` statement here was redundant).
    if reuse_seed:
        if _last_seed is None:
            # Prefer the persisted seed; fall back to a fresh random one if
            # the settings store is unavailable or holds nothing.
            try:
                from src.Core.SettingsStore import get_last_seed
                _ls = get_last_seed()
                _last_seed = int(_ls) if (_ls is not None) else random.randint(1, 2**63 - 1)
            except Exception:
                _last_seed = random.randint(1, 2**63 - 1)
        ctx.seeds = [_last_seed] * ctx.total_images
        ctx.seed = _last_seed

    # Save seed for future reuse
    _last_seed = ctx.seeds[-1] if ctx.seeds else ctx.seed

    # Setup default callback for UI preview
    def default_callback(args: dict):
        from src.user import app_instance
        from src.AutoEncoders import taesd
        app_ref = getattr(app_instance, "app", None)

        # Streamlit/Gradio UI preview
        if app_ref is not None:
            step = args.get("i", 0)
            x0 = args.get("denoised")
            total_steps = args.get("total_steps", ctx.sampling.steps)

            # Update progress tracker
            if total_steps > 0:
                app_ref.progress.set((step + 1) / total_steps)

            # Update preview (x0 is the denoised latent estimate)
            if x0 is not None:
                # 16/32 latent channels indicate a Flux-style model here.
                is_flux = x0.shape[1] in (16, 32)
                # taesd_preview handles PIL conversion and calls app_ref.update_image
                taesd.taesd_preview(x0, flux=is_flux, step=step, total_steps=total_steps)
            else:
                # Just update step info if no image is available
                app_ref.update_image(app_ref.preview_images, step=step, total_steps=total_steps)

        # Chain external callback if provided; a failing external callback
        # must never abort generation, so errors are swallowed.
        if callback is not None:
            try:
                callback(args)
            except Exception:
                pass

    ctx.callback = default_callback

    # Run pipeline
    pipeline_instance = get_default_pipeline()

    with torch.inference_mode():
        if ctx.features.controlnet_model:
            # ControlNet mode (uses input image for control, generates new content)
            pipeline_instance.run_controlnet(ctx)
        elif ctx.is_batched:
            # Batched requests must use the unified batched path even for img2img.
            return pipeline_instance.run_batched(ctx, per_sample_info)
        elif ctx.features.img2img:
            pipeline_instance.run_img2img(ctx)
        else:
            pipeline_instance.run(ctx)

    return {
        "original_prompt": original_prompt,
        "used_prompt": ctx.prompt,
        "enhancement_applied": enhancement_applied,
    }


def _enhance_prompt(ctx: "Context") -> tuple["Context", bool]:
    """Apply Ollama prompt enhancement if available."""
    from src.Utilities import Enhancer
    
    try:
        if isinstance(ctx.prompt, (list, tuple)):
            enhanced = []
            for p in ctx.prompt:
                try:
                    e = Enhancer.enhance_prompt(p)
                    enhanced.append(e if e else p)
                except Exception:
                    enhanced.append(p)
            ctx.prompt = enhanced
        else:
            e = Enhancer.enhance_prompt(ctx.prompt)
            if e:
                ctx.prompt = e
        return ctx, True
    except Exception:
        return ctx, False


# ============================================================================
# CLI INTERFACE
# ============================================================================

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="LightDiffusion Pipeline CLI")
    parser.add_argument("prompt", type=str, help="Generation prompt")
    parser.add_argument("width", type=int, help="Image width")
    parser.add_argument("height", type=int, help="Image height")
    # nargs="?" makes these positionals optional so their defaults actually
    # apply (argparse ignores ``default`` on required positional arguments).
    parser.add_argument("number", type=int, nargs="?", default=1, help="Number of images")
    parser.add_argument("batch", type=int, nargs="?", default=1, help="Batch size")
    parser.add_argument("--scheduler", type=str, default="karras")
    parser.add_argument("--sampler", type=str, default="dpmpp_2m_cfgpp")
    parser.add_argument("--steps", type=int, default=20)
    parser.add_argument("--hires-fix", action="store_true")
    parser.add_argument("--adetailer", action="store_true")
    parser.add_argument("--stable-fast", action="store_true")
    parser.add_argument("--deepcache", action="store_true")
    parser.add_argument("--model-path", type=str, default="")

    args = parser.parse_args()

    pipeline(
        args.prompt,
        args.width,
        args.height,
        args.number,
        args.batch,
        scheduler=args.scheduler,
        sampler=args.sampler,
        steps=args.steps,
        hires_fix=args.hires_fix,
        adetailer=args.adetailer,
        stable_fast=args.stable_fast,
        deepcache_enabled=args.deepcache,
        model_path=args.model_path or None,
    )