"""Generate images using SDXL + Hyper-SD 8-step + style LoRA from registry. Reads segments.json (with prompts from prompt_generator) and generates one 768x1344 (9:16 vertical) image per segment. Pipeline: SDXL base → Hyper-SD 8-step CFG LoRA (speed) → style LoRA (aesthetics) """ import json from pathlib import Path from typing import Optional import torch from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline from huggingface_hub import hf_hub_download from src.styles import get_style # --------------------------------------------------------------------------- # Config # --------------------------------------------------------------------------- BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0" VAE_MODEL = "madebyollin/sdxl-vae-fp16-fix" HYPER_SD_REPO = "ByteDance/Hyper-SD" HYPER_SD_FILE = "Hyper-SDXL-8steps-CFG-lora.safetensors" WIDTH = 768 HEIGHT = 1344 NUM_STEPS = 8 GUIDANCE_SCALE = 5.0 HYPER_SD_WEIGHT = 0.125 # official recommendation def _get_device_and_dtype(): """Detect best available device and matching dtype.""" if torch.cuda.is_available(): return "cuda", torch.float16 if torch.backends.mps.is_available(): return "mps", torch.float32 # float32 required for MPS reliability return "cpu", torch.float32 def load_pipeline(style_name: str = "Warm Sunset"): """Load SDXL pipeline with Hyper-SD and a style LoRA from the registry. Args: style_name: Key in STYLES registry. Use "None" for no style LoRA. Returns: Configured DiffusionPipeline ready for inference. """ style = get_style(style_name) device, dtype = _get_device_and_dtype() print(f"Loading SDXL pipeline on {device} ({dtype})...") vae = AutoencoderKL.from_pretrained(VAE_MODEL, torch_dtype=dtype) load_kwargs = {"torch_dtype": dtype, "vae": vae, "use_safetensors": True} if dtype == torch.float16: load_kwargs["variant"] = "fp16" pipe = DiffusionPipeline.from_pretrained(BASE_MODEL, **load_kwargs) # Hyper-SD 8-step CFG LoRA (always loaded) hyper_path = hf_hub_download(HYPER_SD_REPO, HYPER_SD_FILE) pipe.load_lora_weights(hyper_path, adapter_name="hyper-sd") # Style LoRA from registry _apply_style(pipe, style) # DDIMScheduler with trailing timestep spacing — required for Hyper-SD pipe.scheduler = DDIMScheduler.from_config( pipe.scheduler.config, timestep_spacing="trailing" ) pipe.to(device) if device == "mps": pipe.enable_attention_slicing() pipe.enable_vae_slicing() print("Pipeline ready.") return pipe def _apply_style(pipe, style: dict): """Load a style LoRA and set adapter weights.""" source = style["source"] if source is None: pipe.set_adapters(["hyper-sd"], adapter_weights=[HYPER_SD_WEIGHT]) print("No style LoRA — using base SDXL + Hyper-SD.") return load_kwargs = {"adapter_name": "style"} # Local file: resolve relative to project root, pass dir + weight_name project_root = Path(__file__).resolve().parent.parent source_path = (project_root / source).resolve() if source_path.is_file(): load_kwargs["weight_name"] = source_path.name pipe.load_lora_weights(str(source_path.parent), **load_kwargs) else: # HF Hub repo ID if style["weight_name"]: load_kwargs["weight_name"] = style["weight_name"] pipe.load_lora_weights(source, **load_kwargs) pipe.set_adapters( ["hyper-sd", "style"], adapter_weights=[HYPER_SD_WEIGHT, style["weight"]], ) print(f"Loaded style LoRA: {source}") def switch_style(pipe, style_name: str): """Switch to a different style LoRA at runtime. Unloads all LoRAs then reloads Hyper-SD + new style. """ style = get_style(style_name) pipe.unload_lora_weights() # Re-load Hyper-SD hyper_path = hf_hub_download(HYPER_SD_REPO, HYPER_SD_FILE) pipe.load_lora_weights(hyper_path, adapter_name="hyper-sd") # Load new style _apply_style(pipe, style) print(f"Switched to style: {style_name}") def generate_image( pipe, prompt: str, negative_prompt: str = "", seed: Optional[int] = None, ) -> "PIL.Image.Image": """Generate a single 768x1344 vertical image.""" generator = None if seed is not None: generator = torch.Generator(device="cpu").manual_seed(seed) return pipe( prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=NUM_STEPS, guidance_scale=GUIDANCE_SCALE, height=HEIGHT, width=WIDTH, generator=generator, ).images[0] def generate_all( segments: list[dict], pipe, output_dir: str | Path, trigger_word: str = "", seed: int = 42, progress_callback=None, ) -> list[Path]: """Generate images for all segments. Args: segments: List of segment dicts (with 'prompt' and 'negative_prompt'). pipe: Loaded DiffusionPipeline. output_dir: Directory to save images. trigger_word: LoRA trigger word appended to prompts. seed: Base seed (incremented per segment for variety). Returns: List of saved image paths. """ output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) paths = [] for seg in segments: idx = seg["segment"] path = output_dir / f"segment_{idx:03d}.png" if path.exists(): print(f" Segment {idx}/{len(segments)}: already exists, skipping") paths.append(path) continue prompt = seg["prompt"] if trigger_word: prompt = f"{trigger_word} style, {prompt}" neg = seg.get("negative_prompt", "") print(f" Segment {idx}/{len(segments)}: generating...") image = generate_image(pipe, prompt, neg, seed=seed + idx) path = output_dir / f"segment_{idx:03d}.png" image.save(path) paths.append(path) print(f" Saved {path.name}") if progress_callback: progress_callback(idx, len(segments)) return paths def run( data_dir: str | Path, style_name: str = "Warm Sunset", seed: int = 42, progress_callback=None, ) -> list[Path]: """Full image generation pipeline: load model, read segments, generate, save. Args: data_dir: Run directory containing segments.json (e.g. data/Gone/run_001/). style_name: Style from the registry (see src/styles.py). seed: Base random seed. Returns: List of saved image paths. """ data_dir = Path(data_dir) style = get_style(style_name) with open(data_dir / "segments.json") as f: segments = json.load(f) pipe = load_pipeline(style_name) paths = generate_all(segments, pipe, data_dir / "images", style["trigger"], seed, progress_callback) print(f"\nGenerated {len(paths)} images in {data_dir / 'images'}") return paths if __name__ == "__main__": import sys if len(sys.argv) < 2: print("Usage: python -m src.image_generator_hf [style_name]") print(' e.g. python -m src.image_generator_hf data/Gone/run_001 "Warm Sunset"') sys.exit(1) style = sys.argv[2] if len(sys.argv) > 2 else "Warm Sunset" run(sys.argv[1], style_name=style)