File size: 8,257 Bytes
f2fa09a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
"""

Load Anima RDBT with Diffusers (community Anima pipeline) in-process; no ComfyUI server.

"""
from __future__ import annotations

import math
import os
import threading
from typing import Any

import torch
from PIL import Image

from src import config
from src.config import GenerationParams
from src.errors import UserFacingError

# Re-entrant lock guarding the lazily-initialized module state below.
_lock = threading.RLock()
# Loaded Anima Diffusers pipeline, assigned by ensure_prepared(); None until then.
_pipe: Any = None
# Set True by ensure_prepared() after the pipeline has been loaded and moved to its device.
_prepared: bool = False
# Set True by _bootstrap_files_if_needed() after the on-disk model bootstrap succeeds.
_bootstrapped: bool = False


def _set_cudnn_sdp_env() -> None:
    """Disable PyTorch's cuDNN scaled-dot-product-attention backend unless config allows it.

    When ``config.allow_cudnn_sdp()`` is false this:

    * sets ``TORCH_CUDNN_SDPA_ENABLED=0`` for PyTorch builds that consult that
      environment variable (it may only be read once, at first SDPA use), and
    * calls ``torch.backends.cuda.enable_cudnn_sdp(False)`` when this PyTorch
      exposes it, which is the documented runtime switch for the backend and
      does not depend on env-var timing.

    Safe to call repeatedly; both operations are idempotent.
    """
    if config.allow_cudnn_sdp():
        return
    os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "0"
    # Older PyTorch releases lack this function; guard instead of requiring it.
    cuda_backends = getattr(torch.backends, "cuda", None)
    enable_cudnn_sdp = getattr(cuda_backends, "enable_cudnn_sdp", None)
    if callable(enable_cudnn_sdp):
        enable_cudnn_sdp(False)


def _device_str() -> str:
    if torch.cuda.is_available():
        return "cuda"
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _map_comfy_sampler_to_anima(sampler: str) -> str:
    """

    ComfyUI KSampler names -> AnimaFlowMatchEulerDiscreteScheduler (diffusers-anima) samplers.

    Supported: flowmatch_euler, euler, euler_a_rf, euler_ancestral_rf (alias of euler_a_rf).

    """
    s = (sampler or "").strip().lower()
    if s == "euler":
        return "euler"
    if s == "flowmatch_euler" or s == "flow_match_euler":
        return "flowmatch_euler"
    if s in (
        "euler_ancestral",
        "euler_a",
        "euler_ancestral_cfg_pp",
        "euler_a_rf",
    ) or "ancestral" in s:
        return "euler_ancestral_rf"
    # DPM, DDIM, LCM, etc. — no 1:1; match RDBT card default
    return "euler_ancestral_rf"


def _map_comfy_scheduler_to_sigma(scheduler: str) -> str:
    """

    Comfy scheduler names -> Anima sigma_schedule: uniform | simple | normal | beta.

    """
    s = (scheduler or "").strip().lower()
    if s in ("simple", "normal", "beta", "uniform"):
        return s
    if s in ("karras", "exponential", "sgm_uniform", "ddim_uniform", "linear_quadratic", "kl_optimal"):
        return "normal"
    return "simple"


def _align_sampling(

    anima_sampler: str, sigma: str

) -> tuple[str, str, list[str]]:
    """Enforce Anima's valid (sampler, sigma_schedule) pairs; return optional notices."""
    notes: list[str] = []
    s = anima_sampler
    sig = sigma
    if s == "flowmatch_euler" and sig != "uniform":
        sig = "uniform"
        notes.append("Sampler flowmatch_euler requires sigma schedule `uniform`; adjusted.")
    elif s != "flowmatch_euler" and sig == "uniform":
        sig = "simple"
        notes.append("Sigma schedule `uniform` is only for flowmatch_euler; using `simple`.")
    return s, sig, notes


def _rdbt_path() -> str:
    """Return the expected path of the RDBT diffusion checkpoint.

    Built as ``<config.model_artifacts_root()>/diffusion_models/<config.RDBT_UNET_NAME>``.
    The file is not checked for existence here.
    """
    return os.path.join(
        config.model_artifacts_root(),
        "diffusion_models",
        config.RDBT_UNET_NAME,
    )


def _bootstrap_files_if_needed() -> None:
    """Run the on-disk model bootstrap once per process.

    Calls ``src.bootstrap.bootstrap_model_artifacts()`` while holding ``_lock``,
    so concurrent callers wait instead of bootstrapping in parallel. The
    ``_bootstrapped`` flag is set only after success, so a failed attempt is
    retried on the next call.

    Raises:
        UserFacingError: re-raised as-is from the bootstrap, or wrapping any
            other exception it raises.
    """
    global _bootstrapped
    with _lock:
        if not _bootstrapped:
            from src import bootstrap  # local import

            try:
                bootstrap.bootstrap_model_artifacts()
            except UserFacingError:
                raise
            except Exception as exc:
                raise UserFacingError(
                    f"Model bootstrap failed: {exc!s}. See logs for full traceback."
                ) from exc
            _bootstrapped = True


def run_at_container_startup() -> None:
    """Prepare model files at Space import time (disk/network work only).

    Runs the one-time bootstrap (downloads RDBT weights, etc.) and prints
    progress to stdout. The Diffusers pipeline itself is not loaded here; per
    the author, that happens on the first Generate under ``@spaces.GPU``.

    Raises:
        Exception: whatever the bootstrap raises, after printing a failure line.
    """
    print(
        "[startup] Downloading RDBT weights and preparing model files (CPU/network)…",
        flush=True,
    )
    try:
        _bootstrap_files_if_needed()
    except Exception as err:
        print(f"[startup] Failed: {err!s}", flush=True)
        raise
    else:
        ready_message = (
            "[startup] Model files ready. The Diffusers pipeline loads on the first **Generate** "
            "when ZeroGPU assigns a GPU to this worker."
        )
        print(ready_message, flush=True)


def _load_pipeline() -> Any:
    """Build an ``AnimaPipeline`` from the local RDBT single-file checkpoint.

    Device and dtypes are left to the library (``"auto"``).

    Raises:
        UserFacingError: if ``diffusers_anima`` cannot be imported, or the
            checkpoint file is missing.
    """
    try:
        from diffusers_anima import AnimaPipeline
    except ImportError as err:
        raise UserFacingError(
            "The `diffusers_anima` package is not installed. Install with requirements.txt"
            f" (diffusers + diffusers-anima). ({err!s})"
        ) from err

    checkpoint = _rdbt_path()
    if not os.path.isfile(checkpoint):
        raise UserFacingError(
            f"RDBT checkpoint not found: {checkpoint!s}. Re-run startup bootstrap, set ANIMA_MODELS_ROOT, "
            "or place the file under diffusion_models/."
        )

    # Single-file: transformer from local RDBT; TE/VAE/tokenizers from hdae/diffusers-anima-preview
    load_options = {
        "device": "auto",
        "dtype": "auto",
        "text_encoder_dtype": "auto",
    }
    return AnimaPipeline.from_single_file(checkpoint, **load_options)


def ensure_prepared() -> None:
    """Idempotently make the shared Anima pipeline ready for generation.

    Steps: apply the cuDNN SDPA setting, make sure model artifacts exist on
    disk, load the Diffusers pipeline, and move it to ``_device_str()``.
    Returns immediately once a pipeline has been fully prepared.

    The module globals ``_pipe`` / ``_prepared`` are only updated after every
    step succeeded. A failed load or device move therefore leaves no
    half-configured pipeline referenced globally, and the next call retries
    from scratch.

    Raises:
        UserFacingError: if the checkpoint is missing, the pipeline fails to
            load, or it cannot be moved to the target device.
    """
    global _pipe, _prepared
    _set_cudnn_sdp_env()
    with _lock:
        if _prepared and _pipe is not None:
            return
    # Bootstrap acquires _lock itself; run it outside our critical section.
    _bootstrap_files_if_needed()
    with _lock:
        # Re-check: another caller may have finished preparing meanwhile.
        if _prepared and _pipe is not None:
            return
        rdbt = _rdbt_path()
        if not os.path.isfile(rdbt):
            raise UserFacingError(
                f"Missing RDBT file at {rdbt!r}. Set SKIP_CIVITAI=0 and ensure a network download, "
                "or place the file manually under diffusion_models/."
            )
        try:
            pipe = _load_pipeline()
        except UserFacingError:
            raise
        except Exception as e:
            raise UserFacingError(
                "Failed to load the Anima Diffusers pipeline. If this is a new checkpoint, "
                f"it may be incompatible with diffusers-anima. ({e!s})"
            ) from e
        dev = _device_str()
        try:
            if hasattr(pipe, "to"):
                pipe.to(dev)
        except Exception as e:
            raise UserFacingError(f"Failed to move pipeline to {dev!r}: {e!s}") from e
        # Publish only a fully prepared pipeline, keeping both globals consistent.
        _pipe = pipe
        _prepared = True


def run_generation(p: GenerationParams) -> tuple[list[Image.Image], str]:
    """Run text-to-image generation with the shared Anima pipeline.

    ComfyUI-style sampler/scheduler names on *p* are mapped to Anima's
    scheduler settings and coerced into a valid pair. A ``torch.Generator`` on
    the active device is seeded from ``p.seed``, and the pipeline is called
    with ``strength=1.0`` (text-to-image).

    Args:
        p: generation parameters (prompt, negative prompt, size, steps, cfg,
            seed, batch size, sampler/scheduler names, denoise).

    Returns:
        ``(images, details)``. *images* are converted to RGB when they support
        ``convert``. *details* summarises the settings used plus any notes
        about adjustments (sampler/sigma alignment, ignored ``denoise``).

    Raises:
        UserFacingError: if preparation fails, the pipeline call raises, or
            no images are produced.
    """
    ensure_prepared()
    pipe = _pipe
    if pipe is None:
        # ensure_prepared() raises rather than returning without a pipeline; use an
        # explicit check instead of `assert`, which is stripped under `python -O`.
        raise UserFacingError("Pipeline is not loaded.")

    anima_s = _map_comfy_sampler_to_anima(p.sampler_name)
    sigma = _map_comfy_scheduler_to_sigma(p.scheduler)
    anima_s, sigma, align_notes = _align_sampling(anima_s, sigma)
    # NOTE(review): this reconfigures the shared pipeline's scheduler without
    # holding _lock; confirm callers serialize generation requests.
    scheduler = getattr(pipe, "scheduler", None)
    if hasattr(scheduler, "set_sampling_config"):
        scheduler.set_sampling_config(sampler=anima_s, sigma_schedule=sigma)

    dev = _device_str()
    g = torch.Generator(device=dev)
    # The seed is folded into [0, 2**32): seeds differing by a multiple of 2**32
    # (and negative seeds) map onto the same generator state.
    g.manual_seed(int(p.seed) % (2**32))

    extra_notes: list[str] = list(align_notes)
    # Anima diffusers: strength only applies to img2img; txt2img requires strength=1.0
    if not math.isclose(float(p.denoise), 1.0, rel_tol=0.0, abs_tol=0.01):
        # No trailing space: notes are joined with " " below.
        extra_notes.append(
            f"`denoise`={p.denoise} ignored for text-to-image "
            "(Diffusers requires strength=1.0 without an init image)."
        )
    try:
        out = pipe(
            p.prompt,
            negative_prompt=p.negative_prompt,
            width=int(p.width),
            height=int(p.height),
            num_inference_steps=int(p.steps),
            guidance_scale=float(p.cfg),
            num_images_per_prompt=int(p.batch_size),
            strength=1.0,
            generator=g,
        )
    except Exception as e:
        raise UserFacingError(
            f"Diffusers generation failed: {e!s}. If sampler/scheduler is invalid, try euler_ancestral + simple."
        ) from e

    # AnimaPipelineOutput; treat a missing/None `images` as "no images" rather
    # than letting list(None) raise an opaque TypeError.
    raw_images = getattr(out, "images", None)
    images = [] if raw_images is None else list(raw_images)
    if not images:
        raise UserFacingError("Pipeline returned no images.")

    det = (
        f"seed={p.seed} | {p.width}x{p.height} | steps={p.steps} | cfg={p.cfg} | "
        f"batch={p.batch_size} | {p.sampler_name}/{p.scheduler} (anima={anima_s}/{sigma}) | denoise={p.denoise}"
    )
    if extra_notes:
        det += " | " + " ".join(extra_notes)

    return [im.convert("RGB") if hasattr(im, "convert") else im for im in images], det