File size: 7,366 Bytes
72f552e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
"""Generate images using SDXL + Hyper-SD 8-step + style LoRA from registry.

Reads segments.json (with prompts from prompt_generator) and generates
one 768x1344 (9:16 vertical) image per segment.

Pipeline: SDXL base → Hyper-SD 8-step CFG LoRA (speed) → style LoRA (aesthetics)
"""

import json
from pathlib import Path
from typing import Optional

import torch
from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline
from huggingface_hub import hf_hub_download

from src.styles import get_style

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"  # SDXL base checkpoint
VAE_MODEL = "madebyollin/sdxl-vae-fp16-fix"  # VAE patched to be stable in fp16
HYPER_SD_REPO = "ByteDance/Hyper-SD"  # Hub repo hosting the speed LoRA
HYPER_SD_FILE = "Hyper-SDXL-8steps-CFG-lora.safetensors"  # 8-step CFG variant

# 9:16 vertical output resolution (see module docstring).
WIDTH = 768
HEIGHT = 1344
NUM_STEPS = 8  # must match the 8-step Hyper-SD LoRA above
GUIDANCE_SCALE = 5.0  # CFG scale applied to every generation

HYPER_SD_WEIGHT = 0.125  # official recommendation


def _get_device_and_dtype():
    """Detect best available device and matching dtype."""
    if torch.cuda.is_available():
        return "cuda", torch.float16
    if torch.backends.mps.is_available():
        return "mps", torch.float32  # float32 required for MPS reliability
    return "cpu", torch.float32


def load_pipeline(style_name: str = "Warm Sunset"):
    """Build an SDXL pipeline with the Hyper-SD speed LoRA plus a style LoRA.

    Args:
        style_name: Key in STYLES registry. Use "None" for no style LoRA.

    Returns:
        Configured DiffusionPipeline ready for inference.
    """
    style = get_style(style_name)
    device, dtype = _get_device_and_dtype()
    print(f"Loading SDXL pipeline on {device} ({dtype})...")

    kwargs = {
        "torch_dtype": dtype,
        "vae": AutoencoderKL.from_pretrained(VAE_MODEL, torch_dtype=dtype),
        "use_safetensors": True,
    }
    if dtype is torch.float16:
        # fp16 weights variant only exists (and only makes sense) on CUDA.
        kwargs["variant"] = "fp16"
    pipe = DiffusionPipeline.from_pretrained(BASE_MODEL, **kwargs)

    # Speed adapter first, then the aesthetic one.
    pipe.load_lora_weights(
        hf_hub_download(HYPER_SD_REPO, HYPER_SD_FILE), adapter_name="hyper-sd"
    )
    _apply_style(pipe, style)

    # Hyper-SD requires DDIM with "trailing" timestep spacing.
    pipe.scheduler = DDIMScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing"
    )
    pipe.to(device)

    if device == "mps":
        # Keep peak memory manageable on Apple Silicon.
        pipe.enable_attention_slicing()
        pipe.enable_vae_slicing()

    print("Pipeline ready.")
    return pipe


def _apply_style(pipe, style: dict):
    """Attach the style LoRA described by *style* and activate the adapters."""
    source = style["source"]
    if source is None:
        # Base SDXL only: keep just the speed adapter active.
        pipe.set_adapters(["hyper-sd"], adapter_weights=[HYPER_SD_WEIGHT])
        print("No style LoRA — using base SDXL + Hyper-SD.")
        return

    kwargs = {"adapter_name": "style"}

    # Local checkpoints are resolved relative to the project root; diffusers
    # wants the containing directory plus an explicit weight_name.
    root = Path(__file__).resolve().parent.parent
    candidate = (root / source).resolve()
    if candidate.is_file():
        kwargs["weight_name"] = candidate.name
        pipe.load_lora_weights(str(candidate.parent), **kwargs)
    else:
        # Otherwise treat `source` as a Hugging Face Hub repo ID.
        if style["weight_name"]:
            kwargs["weight_name"] = style["weight_name"]
        pipe.load_lora_weights(source, **kwargs)

    pipe.set_adapters(
        ["hyper-sd", "style"],
        adapter_weights=[HYPER_SD_WEIGHT, style["weight"]],
    )
    print(f"Loaded style LoRA: {source}")


def switch_style(pipe, style_name: str):
    """Swap the active style LoRA at runtime without rebuilding the pipeline.

    Drops every loaded LoRA, then restores Hyper-SD and attaches the newly
    requested style.
    """
    new_style = get_style(style_name)

    pipe.unload_lora_weights()

    # Hyper-SD was unloaded along with the old style — bring it back.
    pipe.load_lora_weights(
        hf_hub_download(HYPER_SD_REPO, HYPER_SD_FILE), adapter_name="hyper-sd"
    )

    _apply_style(pipe, new_style)
    print(f"Switched to style: {style_name}")


def generate_image(
    pipe,
    prompt: str,
    negative_prompt: str = "",
    seed: Optional[int] = None,
) -> "PIL.Image.Image":
    """Run one inference pass and return a single 768x1344 vertical image.

    A CPU-seeded generator makes output reproducible when *seed* is given;
    with ``seed=None`` the pipeline samples freely.
    """
    gen = None
    if seed is not None:
        gen = torch.Generator(device="cpu").manual_seed(seed)

    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=NUM_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        height=HEIGHT,
        width=WIDTH,
        generator=gen,
    )
    return result.images[0]


def generate_all(
    segments: list[dict],
    pipe,
    output_dir: str | Path,
    trigger_word: str = "",
    seed: int = 42,
    progress_callback=None,
) -> list[Path]:
    """Generate images for all segments.

    Args:
        segments: List of segment dicts (with 'prompt' and 'negative_prompt').
        pipe: Loaded DiffusionPipeline.
        output_dir: Directory to save images.
        trigger_word: LoRA trigger word appended to prompts.
        seed: Base seed (incremented per segment for variety).
        progress_callback: Optional ``callable(idx, total)`` invoked after each
            segment — including segments skipped because their image already
            exists, so resumed runs still report progress.

    Returns:
        List of saved image paths.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    total = len(segments)
    paths: list[Path] = []
    for seg in segments:
        idx = seg["segment"]
        path = output_dir / f"segment_{idx:03d}.png"

        if path.exists():
            print(f"  Segment {idx}/{total}: already exists, skipping")
            paths.append(path)
            # Report progress on the skip path too, otherwise progress bars
            # stall when a run is resumed. (Original skipped the callback.)
            if progress_callback:
                progress_callback(idx, total)
            continue

        prompt = seg["prompt"]
        if trigger_word:
            # Trigger word activates the style LoRA's learned concept.
            prompt = f"{trigger_word} style, {prompt}"
        neg = seg.get("negative_prompt", "")

        print(f"  Segment {idx}/{total}: generating...")
        image = generate_image(pipe, prompt, neg, seed=seed + idx)

        image.save(path)
        paths.append(path)
        print(f"    Saved {path.name}")
        if progress_callback:
            progress_callback(idx, total)

    return paths


def run(
    data_dir: str | Path,
    style_name: str = "Warm Sunset",
    seed: int = 42,
    progress_callback=None,
) -> list[Path]:
    """Full image generation pipeline: load model, read segments, generate, save.

    Args:
        data_dir: Run directory containing segments.json (e.g. data/Gone/run_001/).
        style_name: Style from the registry (see src/styles.py).
        seed: Base random seed.

    Returns:
        List of saved image paths.
    """
    run_dir = Path(data_dir)
    style = get_style(style_name)

    with open(run_dir / "segments.json") as f:
        segments = json.load(f)

    pipe = load_pipeline(style_name)
    images_dir = run_dir / "images"
    paths = generate_all(
        segments, pipe, images_dir, style["trigger"], seed, progress_callback
    )

    print(f"\nGenerated {len(paths)} images in {images_dir}")
    return paths


if __name__ == "__main__":
    import sys

    args = sys.argv[1:]
    if not args:
        print("Usage: python -m src.image_generator_hf <data_dir> [style_name]")
        print('  e.g. python -m src.image_generator_hf data/Gone/run_001 "Warm Sunset"')
        sys.exit(1)

    chosen_style = args[1] if len(args) > 1 else "Warm Sunset"
    run(args[0], style_name=chosen_style)