File size: 11,245 Bytes
fe74e7d
 
 
 
 
 
 
77b4154
fe74e7d
77b4154
9dc64d1
fe74e7d
 
 
 
 
 
 
 
 
 
 
 
 
 
77b4154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe74e7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77b4154
fe74e7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77b4154
 
fe74e7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9dc64d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe74e7d
9dc64d1
 
 
 
 
fe74e7d
9dc64d1
 
 
 
 
 
fe74e7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9dc64d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe74e7d
9dc64d1
fe74e7d
9dc64d1
 
 
 
 
 
 
 
 
 
 
 
 
 
fe74e7d
 
 
 
 
 
 
 
 
 
 
 
9dc64d1
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
"""
Load Anima RDBT with Diffusers (community Anima pipeline) in-process; no ComfyUI server.
"""
from __future__ import annotations

import math
import os
import sys
import threading
from pathlib import Path
from typing import Any, Optional

import torch
from PIL import Image

from src import config
from src.config import GenerationParams
from src.errors import UserFacingError

# Re-entrant lock guarding the lazily-initialised module state below
# (bootstrap flag and the loaded pipeline).
_lock = threading.RLock()
# Pipeline object returned by _load_pipeline() (AnimaPipeline.from_single_file);
# None until ensure_prepared() has loaded it.
_pipe: Any = None
# Set True by ensure_prepared() once the pipeline is loaded and moved to the device.
_prepared: bool = False
# Set True by _bootstrap_files_if_needed() after bootstrap.bootstrap_model_artifacts()
# has completed without raising.
_bootstrapped: bool = False


def _vendor_root() -> Path:
    """Return the ``vendor`` directory located two levels above this module file."""
    module_file = Path(__file__).resolve()
    package_dir = module_file.parent
    return package_dir.parent / "vendor"


def _apply_local_te_vae_if_configured() -> None:
    """
    Patch diffusers-anima loaders to use Hub-downloaded Comfy-layout TE/VAE files.

    Does nothing unless ``config.use_local_te_vae()`` is true. When hub fallback is
    disallowed (strict mode), both the text-encoder and VAE files must exist on
    disk, otherwise a UserFacingError is raised before anything is patched.
    Files that are missing are passed to the patcher as ``None``.
    """
    if not config.use_local_te_vae():
        return

    te_path = config.text_encoder_file_path()
    vae_path = config.vae_file_path()

    strict = not config.allow_te_vae_hub_fallback()
    if strict:
        # Check the text encoder first, then the VAE (same order as the error priority).
        for label, path in (("text encoder", te_path), ("VAE", vae_path)):
            if not os.path.isfile(path):
                raise UserFacingError(
                    f"Strict local TE/VAE: missing {label} at {path!r}. Run startup bootstrap or set ANIMA_MODELS_ROOT."
                )

    # The patch module lives under vendor/; make it importable before importing it.
    vendor_dir = str(_vendor_root())
    if vendor_dir not in sys.path:
        sys.path.insert(0, vendor_dir)
    from anima_local_te_vae import apply_local_te_vae_patches

    apply_local_te_vae_patches(
        te_path if os.path.isfile(te_path) else None,
        vae_path if os.path.isfile(vae_path) else None,
        allow_hub_fallback=config.allow_te_vae_hub_fallback(),
    )


def _set_cudnn_sdp_env() -> None:
    """Set ``TORCH_CUDNN_SDPA_ENABLED=0`` in the environment unless config allows cuDNN SDPA."""
    if config.allow_cudnn_sdp():
        return
    os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0"


def _device_str() -> str:
    """Return the preferred torch device name: ``"cuda"``, then ``"mps"``, else ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    # Older torch builds may not expose torch.backends.mps at all.
    mps_backend = getattr(torch.backends, "mps", None)
    return "mps" if (mps_backend and mps_backend.is_available()) else "cpu"


def _map_comfy_sampler_to_anima(sampler: str) -> str:
    """
    Translate a ComfyUI KSampler name into an AnimaFlowMatchEulerDiscreteScheduler sampler.

    ``euler`` passes through, and ``flowmatch_euler`` / ``flow_match_euler`` map to
    ``flowmatch_euler``. Every other name maps to ``euler_ancestral_rf``, the
    RDBT card default. That covers the ancestral variants (``euler_ancestral``,
    ``euler_a``, ``euler_a_rf``, ...) and samplers with no Anima equivalent
    (DPM, DDIM, LCM, ...). Input is case- and whitespace-insensitive, and
    ``None`` is treated as ``""``.
    """
    key = (sampler or "").strip().lower()
    direct_matches = {
        "euler": "euler",
        "flowmatch_euler": "flowmatch_euler",
        "flow_match_euler": "flowmatch_euler",
    }
    return direct_matches.get(key, "euler_ancestral_rf")


def _map_comfy_scheduler_to_sigma(scheduler: str) -> str:
    """
    Translate a ComfyUI scheduler name into an Anima ``sigma_schedule``.

    Anima's own names (uniform, simple, normal, beta) pass through normalised.
    Karras/exponential-style Comfy schedules map to ``normal``, and anything else
    (including ``None``/empty) maps to ``simple``.
    """
    name = (scheduler or "").strip().lower()
    native_schedules = {"simple", "normal", "beta", "uniform"}
    normal_like = {
        "karras",
        "exponential",
        "sgm_uniform",
        "ddim_uniform",
        "linear_quadratic",
        "kl_optimal",
    }
    if name in native_schedules:
        return name
    return "normal" if name in normal_like else "simple"


def _align_sampling(
    anima_sampler: str, sigma: str
) -> tuple[str, str, list[str]]:
    """
    Coerce (sampler, sigma_schedule) into a pair Anima accepts.

    ``flowmatch_euler`` must use the ``uniform`` schedule, and ``uniform`` is reserved
    for ``flowmatch_euler`` (other samplers fall back to ``simple``). Returns
    the possibly-adjusted sampler and schedule plus a list of human-readable
    notes describing any adjustment (empty when nothing changed).
    """
    notes: list[str] = []
    is_flowmatch = anima_sampler == "flowmatch_euler"
    wants_uniform = sigma == "uniform"

    if is_flowmatch and not wants_uniform:
        notes.append("Sampler flowmatch_euler requires sigma schedule `uniform`; adjusted.")
        return anima_sampler, "uniform", notes
    if wants_uniform and not is_flowmatch:
        notes.append("Sigma schedule `uniform` is only for flowmatch_euler; using `simple`.")
        return anima_sampler, "simple", notes
    return anima_sampler, sigma, notes


def _rdbt_path() -> str:
    """Return ``<model artifacts root>/diffusion_models/<RDBT_UNET_NAME>``, the expected RDBT checkpoint path."""
    return os.path.join(
        config.model_artifacts_root(),
        "diffusion_models",
        config.RDBT_UNET_NAME,
    )


def _bootstrap_files_if_needed() -> None:
    """
    Run ``bootstrap.bootstrap_model_artifacts()`` until it has succeeded once in this process.

    The whole check-and-run happens under ``_lock``, so concurrent callers wait for
    the first one. A UserFacingError from the bootstrap propagates unchanged; any
    other exception is wrapped in a UserFacingError. The done-flag is only set on
    success, so a failed bootstrap is retried on the next call.
    """
    global _bootstrapped
    with _lock:
        if not _bootstrapped:
            from src import bootstrap  # deferred import

            try:
                bootstrap.bootstrap_model_artifacts()
            except Exception as e:
                if isinstance(e, UserFacingError):
                    raise
                raise UserFacingError(
                    f"Model bootstrap failed: {e!s}. See logs for full traceback."
                ) from e
            _bootstrapped = True


def run_at_container_startup() -> None:
    """
    Fetch model files at Space import time (disk/network only, no CUDA).

    Calls ``_bootstrap_files_if_needed`` and logs progress to stdout. Any failure
    is logged and re-raised. The Diffusers pipeline itself is loaded later, on
    the first Generate under @spaces.GPU.
    """
    def _say(message: str) -> None:
        print(message, flush=True)

    _say("[startup] Downloading RDBT, text encoder, and VAE (CPU/network)…")
    try:
        _bootstrap_files_if_needed()
    except Exception as err:
        _say(f"[startup] Failed: {err!s}")
        raise
    _say(
        "[startup] Model files ready. The Diffusers pipeline loads on the first **Generate** "
        "when ZeroGPU assigns a GPU to this worker."
    )


def _load_pipeline() -> Any:
    """
    Build an ``AnimaPipeline`` from the local RDBT single-file checkpoint.

    The transformer comes from the local RDBT file. Before loading,
    ``_apply_local_te_vae_if_configured`` is invoked so the text encoder / VAE can
    come from local Comfy-style files when that mode is enabled (otherwise the
    pipeline's default sources are used).

    Raises:
        UserFacingError: if ``diffusers_anima`` is not importable or the
            checkpoint file is missing.
    """
    try:
        from diffusers_anima import AnimaPipeline
    except ImportError as exc:
        raise UserFacingError(
            "The `diffusers_anima` package is not installed. Install with requirements.txt"
            f" (diffusers + diffusers-anima). ({exc!s})"
        ) from exc

    checkpoint = _rdbt_path()
    if not os.path.isfile(checkpoint):
        raise UserFacingError(
            f"RDBT checkpoint not found: {checkpoint!s}. Re-run startup bootstrap, set ANIMA_MODELS_ROOT, "
            "or place the file under diffusion_models/."
        )

    # Patch TE/VAE loaders (if configured) before from_single_file reads them.
    _apply_local_te_vae_if_configured()
    load_options = {
        "device": "auto",
        "dtype": "auto",
        "text_encoder_dtype": "auto",
    }
    return AnimaPipeline.from_single_file(checkpoint, **load_options)


def ensure_prepared() -> None:
    """
    Idempotently make the module-level pipeline ready for generation.

    Steps:
    1. Apply the cuDNN SDPA environment toggle.
    2. Ensure model files exist on disk (``_bootstrap_files_if_needed``).
    3. Load the pipeline (``_load_pipeline``) and move it to ``_device_str()``.

    The pipeline is published to the module globals only after it has been both
    loaded and moved successfully. A failed attempt therefore never leaves a
    half-prepared ``_pipe`` referenced. Such a stale reference would keep one
    copy of the weights alive while the next attempt loads another.

    Raises:
        UserFacingError: if the RDBT checkpoint is missing, or loading/moving the
            pipeline fails.
    """
    global _pipe, _prepared
    _set_cudnn_sdp_env()
    with _lock:
        if _prepared and _pipe is not None:
            return
    # Bootstrap outside this lock section; it synchronises on _lock itself.
    _bootstrap_files_if_needed()
    with _lock:
        # Re-check: another thread may have finished preparing meanwhile.
        if _prepared and _pipe is not None:
            return
        rdbt = _rdbt_path()
        if not os.path.isfile(rdbt):
            raise UserFacingError(
                f"Missing RDBT file at {rdbt!r}. Set SKIP_CIVITAI=0 and ensure a network download, "
                "or place the file manually under diffusion_models/."
            )
        try:
            pipe = _load_pipeline()
        except UserFacingError:
            raise
        except Exception as e:
            raise UserFacingError(
                "Failed to load the Anima Diffusers pipeline. If this is a new checkpoint, "
                f"it may be incompatible with diffusers-anima. ({e!s})"
            ) from e
        dev = _device_str()
        try:
            if hasattr(pipe, "to"):
                pipe.to(dev)
        except Exception as e:
            raise UserFacingError(f"Failed to move pipeline to {dev!r}: {e!s}") from e
        # Publish only a fully prepared pipeline.
        _pipe = pipe
        _prepared = True


def _report_progress(
    progress: Any,
    value: float,
    desc: str,
) -> None:
    if progress is None:
        return
    try:
        progress(value, desc=desc)
    except TypeError:
        try:
            progress(value)
        except Exception:
            pass
    except Exception:
        pass


def run_generation(
    p: GenerationParams,
    *,
    progress: Optional[Any] = None,
) -> tuple[list[Image.Image], str, str, str]:
    """
    Run text-to-image generation with the shared Anima pipeline.

    Loads the pipeline on first use (``ensure_prepared``), maps the ComfyUI
    sampler/scheduler names in ``p`` onto Anima's scheduler config, seeds a
    ``torch.Generator`` with ``p.seed`` (mod 2**32) and calls the pipeline.

    Args:
        p: Generation parameters (prompt, negative prompt, size, steps, cfg,
            batch size, seed, sampler/scheduler names, denoise).
        progress: Optional progress callable. When given, it receives a start
            update, per-step updates via ``callback_on_step_end``, and a final
            ``1.0`` update.

    Returns:
        ``(images, details string, positive_prompt, negative_prompt)`` where
        the two prompt fields are the exact strings passed to ``AnimaPipeline.__call__``
        (post-validation, before any model-internal template wrapping).

    Raises:
        UserFacingError: if preparation fails, generation fails, or the pipeline
            returns no images.
    """
    if progress is not None:
        _report_progress(
            progress,
            0.0,
            "Preparing (load / encode — first cold start can take several minutes)…",
        )
    ensure_prepared()
    # Local reference plus an explicit check (an ``assert`` would vanish under ``python -O``).
    pipe = _pipe
    if pipe is None:
        raise UserFacingError("The Anima pipeline is not loaded.")

    anima_s = _map_comfy_sampler_to_anima(p.sampler_name)
    sigma = _map_comfy_scheduler_to_sigma(p.scheduler)
    anima_s, sigma, align_notes = _align_sampling(anima_s, sigma)
    # NOTE(review): this mutates the shared pipeline's scheduler without holding
    # _lock; concurrent generations could observe each other's sampling config.
    if hasattr(pipe, "scheduler") and hasattr(pipe.scheduler, "set_sampling_config"):
        pipe.scheduler.set_sampling_config(
            sampler=anima_s,
            sigma_schedule=sigma,
        )

    dev = _device_str()
    g = torch.Generator(device=dev)
    g.manual_seed(int(p.seed) % (2**32))

    extra_notes: list[str] = list(align_notes)
    # Anima diffusers: strength only applies to img2img; txt2img requires strength=1.0
    if not math.isclose(float(p.denoise), 1.0, rel_tol=0.0, abs_tol=0.01):
        extra_notes.append(
            f"`denoise`={p.denoise} ignored for text-to-image (Diffusers requires strength=1.0 without an init image). "
        )
    strength_val = 1.0
    # Clamped to >= 1 so the progress fraction below never divides by zero.
    n_steps = max(int(p.steps), 1)

    def on_step_end(
        _pipeline: Any,
        step: int,
        _timestep: Any,
        callback_kwargs: dict[str, Any],
    ) -> dict[str, Any]:
        # Forward per-step progress; ``step`` is treated as 0-based (hence +1).
        if progress is not None:
            frac = (float(step) + 1.0) / float(n_steps)
            _report_progress(
                progress,
                min(0.99, max(0.0, frac)),
                f"Denoising step {int(step) + 1} / {n_steps}",
            )
        # Return the kwargs unchanged so the pipeline state is not modified.
        return callback_kwargs

    call_kw: dict[str, Any] = {
        "negative_prompt": p.negative_prompt,
        "width": int(p.width),
        "height": int(p.height),
        "num_inference_steps": int(p.steps),
        "guidance_scale": float(p.cfg),
        "num_images_per_prompt": int(p.batch_size),
        "strength": strength_val,
        "generator": g,
    }
    if progress is not None:
        call_kw["callback_on_step_end"] = on_step_end

    hint = "If sampler/scheduler is invalid, try euler_ancestral + simple."
    try:
        out = pipe(p.prompt, **call_kw)
    except TypeError as e:
        if "callback_on_step_end" not in call_kw:
            raise UserFacingError(f"Diffusers generation failed: {e!s}. {hint}") from e
        # A TypeError is what an unsupported ``callback_on_step_end`` keyword
        # produces: retry once without per-step progress.
        call_kw.pop("callback_on_step_end")
        try:
            out = pipe(p.prompt, **call_kw)
        except Exception as e2:
            raise UserFacingError(f"Diffusers generation failed: {e2!s}. {hint}") from e2
    except Exception as e:
        # Other failures (e.g. out-of-memory) are not retried: rerunning the full
        # denoising loop would double the cost and obscure the original error.
        raise UserFacingError(f"Diffusers generation failed: {e!s}. {hint}") from e
    if progress is not None:
        _report_progress(progress, 1.0, "Done.")

    images = list(out.images)  # AnimaPipelineOutput
    if not images:
        raise UserFacingError("Pipeline returned no images.")

    det = (
        f"seed={p.seed} | {p.width}x{p.height} | steps={p.steps} | cfg={p.cfg} | "
        f"batch={p.batch_size} | {p.sampler_name}/{p.scheduler} (anima={anima_s}/{sigma}) | denoise={p.denoise}"
    )
    if extra_notes:
        det += " | " + " ".join(extra_notes)

    return (
        [im.convert("RGB") if hasattr(im, "convert") else im for im in images],
        det,
        p.prompt,
        p.negative_prompt,
    )