| | """Generate images using SDXL + Hyper-SD 8-step + style LoRA from registry. |
| | |
| | Reads segments.json (with prompts from prompt_generator) and generates |
| | one 768x1344 (9:16 vertical) image per segment. |
| | |
| | Pipeline: SDXL base → Hyper-SD 8-step CFG LoRA (speed) → style LoRA (aesthetics) |
| | """ |
| |
|
| | import json |
| | from pathlib import Path |
| | from typing import Optional |
| |
|
| | import torch |
| | from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline |
| | from huggingface_hub import hf_hub_download |
| |
|
| | from src.styles import get_style |
| |
|
| | |
| | |
| | |
| |
|
# Hugging Face model identifiers.
BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
VAE_MODEL = "madebyollin/sdxl-vae-fp16-fix"  # fp16-safe drop-in SDXL VAE
HYPER_SD_REPO = "ByteDance/Hyper-SD"
HYPER_SD_FILE = "Hyper-SDXL-8steps-CFG-lora.safetensors"

# Output geometry and sampling: 9:16 vertical frames, 8-step Hyper-SD schedule.
WIDTH = 768
HEIGHT = 1344
NUM_STEPS = 8
GUIDANCE_SCALE = 5.0

# Adapter weight for the Hyper-SD CFG LoRA (value from the Hyper-SD usage notes).
HYPER_SD_WEIGHT = 0.125
|
| |
|
| | def _get_device_and_dtype(): |
| | """Detect best available device and matching dtype.""" |
| | if torch.cuda.is_available(): |
| | return "cuda", torch.float16 |
| | if torch.backends.mps.is_available(): |
| | return "mps", torch.float32 |
| | return "cpu", torch.float32 |
| |
|
| |
|
def load_pipeline(style_name: str = "Warm Sunset") -> DiffusionPipeline:
    """Load SDXL pipeline with Hyper-SD and a style LoRA from the registry.

    Args:
        style_name: Key in STYLES registry. Use "None" for no style LoRA.

    Returns:
        Configured DiffusionPipeline ready for inference.
    """
    # Resolve the style first so an unknown name fails before model downloads.
    style = get_style(style_name)
    device, dtype = _get_device_and_dtype()
    print(f"Loading SDXL pipeline on {device} ({dtype})...")

    # Swap in the fp16-fix VAE (name suggests it avoids SDXL's fp16 VAE
    # numerical issues — see the madebyollin model card).
    vae = AutoencoderKL.from_pretrained(VAE_MODEL, torch_dtype=dtype)

    load_kwargs = {"torch_dtype": dtype, "vae": vae, "use_safetensors": True}
    if dtype == torch.float16:
        # The fp16 weight variant only exists for half-precision loads.
        load_kwargs["variant"] = "fp16"

    pipe = DiffusionPipeline.from_pretrained(BASE_MODEL, **load_kwargs)

    # Speed LoRA: Hyper-SD distills SDXL to 8 inference steps with CFG.
    hyper_path = hf_hub_download(HYPER_SD_REPO, HYPER_SD_FILE)
    pipe.load_lora_weights(hyper_path, adapter_name="hyper-sd")

    # Aesthetic LoRA: layered on top; also balances the adapter weights.
    _apply_style(pipe, style)

    # Hyper-SD expects DDIM with "trailing" timestep spacing (per its model card).
    pipe.scheduler = DDIMScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing"
    )

    pipe.to(device)

    if device == "mps":
        # NOTE(review): slicing trades speed for memory headroom on Apple
        # silicon — presumably needed for 768x1344 SDXL; confirm if removed.
        pipe.enable_attention_slicing()
        pipe.enable_vae_slicing()

    print("Pipeline ready.")
    return pipe
| |
|
| |
|
def _apply_style(pipe, style: dict):
    """Attach the style LoRA described by *style* and balance adapter weights.

    When the style has no source, only the Hyper-SD adapter stays active.
    """
    source = style["source"]
    if source is None:
        # Nothing to layer on top — keep just the speed LoRA active.
        pipe.set_adapters(["hyper-sd"], adapter_weights=[HYPER_SD_WEIGHT])
        print("No style LoRA — using base SDXL + Hyper-SD.")
        return

    kwargs = {"adapter_name": "style"}

    # A source that resolves to a file under the project root is a local
    # .safetensors; anything else is treated as a Hugging Face repo id.
    project_root = Path(__file__).resolve().parent.parent
    local_path = (project_root / source).resolve()
    if local_path.is_file():
        kwargs["weight_name"] = local_path.name
        pipe.load_lora_weights(str(local_path.parent), **kwargs)
    else:
        if style["weight_name"]:
            kwargs["weight_name"] = style["weight_name"]
        pipe.load_lora_weights(source, **kwargs)

    pipe.set_adapters(
        ["hyper-sd", "style"],
        adapter_weights=[HYPER_SD_WEIGHT, style["weight"]],
    )
    print(f"Loaded style LoRA: {source}")
| |
|
| |
|
def switch_style(pipe, style_name: str):
    """Swap the active style LoRA at runtime.

    Drops every loaded LoRA, then reattaches Hyper-SD plus the new style.
    """
    new_style = get_style(style_name)

    # load_lora_weights cannot replace an adapter in place, so clear them all.
    pipe.unload_lora_weights()

    # Reattach the speed LoRA first; _apply_style balances weights against it.
    lora_file = hf_hub_download(HYPER_SD_REPO, HYPER_SD_FILE)
    pipe.load_lora_weights(lora_file, adapter_name="hyper-sd")

    _apply_style(pipe, new_style)
    print(f"Switched to style: {style_name}")
| |
|
| |
|
def generate_image(
    pipe,
    prompt: str,
    negative_prompt: str = "",
    seed: Optional[int] = None,
) -> "PIL.Image.Image":
    """Generate a single 768x1344 vertical image.

    Args:
        pipe: Loaded DiffusionPipeline.
        prompt: Positive text prompt.
        negative_prompt: Concepts to steer away from.
        seed: Optional seed for reproducible output.

    Returns:
        The first image of the pipeline's output batch.
    """
    # Seed a CPU generator — works identically on cuda/mps/cpu backends.
    generator = (
        torch.Generator(device="cpu").manual_seed(seed)
        if seed is not None
        else None
    )

    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=NUM_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        height=HEIGHT,
        width=WIDTH,
        generator=generator,
    )
    return result.images[0]
| |
|
| |
|
def generate_all(
    segments: list[dict],
    pipe,
    output_dir: str | Path,
    trigger_word: str = "",
    seed: int = 42,
    progress_callback=None,
) -> list[Path]:
    """Generate images for all segments.

    Args:
        segments: List of segment dicts (with 'prompt' and 'negative_prompt').
        pipe: Loaded DiffusionPipeline.
        output_dir: Directory to save images.
        trigger_word: LoRA trigger word appended to prompts.
        seed: Base seed (incremented per segment for variety).
        progress_callback: Optional callable(segment_index, total) invoked
            after each segment is handled — including cached ones, so
            resumed runs still report full progress.

    Returns:
        List of saved image paths.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    total = len(segments)  # hoisted: invariant across the loop
    paths = []
    for seg in segments:
        idx = seg["segment"]
        path = output_dir / f"segment_{idx:03d}.png"

        # Resume support: skip segments already rendered by a previous run,
        # but still report progress so callers see an accurate count.
        if path.exists():
            print(f" Segment {idx}/{total}: already exists, skipping")
            paths.append(path)
            if progress_callback:
                progress_callback(idx, total)
            continue

        prompt = seg["prompt"]
        if trigger_word:
            prompt = f"{trigger_word} style, {prompt}"
        neg = seg.get("negative_prompt", "")

        print(f" Segment {idx}/{total}: generating...")
        image = generate_image(pipe, prompt, neg, seed=seed + idx)

        image.save(path)
        paths.append(path)
        print(f" Saved {path.name}")
        if progress_callback:
            progress_callback(idx, total)

    return paths
| |
|
| |
|
def run(
    data_dir: str | Path,
    style_name: str = "Warm Sunset",
    seed: int = 42,
    progress_callback=None,
) -> list[Path]:
    """Full image generation pipeline: load model, read segments, generate, save.

    Args:
        data_dir: Run directory containing segments.json (e.g. data/Gone/run_001/).
        style_name: Style from the registry (see src/styles.py).
        seed: Base random seed.
        progress_callback: Optional callable(segment_index, total) forwarded
            to generate_all for per-segment progress reporting.

    Returns:
        List of saved image paths.

    Raises:
        FileNotFoundError: If data_dir has no segments.json.
    """
    data_dir = Path(data_dir)
    # Validate the style name before the (slow) model load so bad input fails fast.
    style = get_style(style_name)

    # Explicit encoding: segments.json is produced as UTF-8 by prompt_generator.
    segments = json.loads((data_dir / "segments.json").read_text(encoding="utf-8"))

    pipe = load_pipeline(style_name)
    paths = generate_all(
        segments,
        pipe,
        data_dir / "images",
        trigger_word=style["trigger"],
        seed=seed,
        progress_callback=progress_callback,
    )

    print(f"\nGenerated {len(paths)} images in {data_dir / 'images'}")
    return paths
| |
|
| |
|
| | if __name__ == "__main__": |
| | import sys |
| |
|
| | if len(sys.argv) < 2: |
| | print("Usage: python -m src.image_generator_hf <data_dir> [style_name]") |
| | print(' e.g. python -m src.image_generator_hf data/Gone/run_001 "Warm Sunset"') |
| | sys.exit(1) |
| |
|
| | style = sys.argv[2] if len(sys.argv) > 2 else "Warm Sunset" |
| | run(sys.argv[1], style_name=style) |
| |
|