"""Image generation using SDXL + LoRA styles via fal.ai API. API counterpart to image_generator_hf.py (on-device diffusers). Uses the fal-ai/lora endpoint which accepts HuggingFace LoRA repo IDs directly, so styles.py works unchanged. Set FAL_KEY env var before use. """ import json import time from pathlib import Path from typing import Optional import requests from dotenv import load_dotenv from src.styles import get_style load_dotenv() # --------------------------------------------------------------------------- # Config — matches image_generator_hf.py output # --------------------------------------------------------------------------- FAL_MODEL_ID = "fal-ai/lora" BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0" WIDTH = 768 HEIGHT = 1344 NUM_STEPS = 30 GUIDANCE_SCALE = 7.5 def _build_loras(style: dict) -> list[dict]: """Build the LoRA list for the fal.ai API from a style dict. Note: Hyper-SD speed LoRA is NOT used here (it's an on-device optimization requiring specific scheduler config). fal.ai runs on fast GPUs so we use standard settings (30 steps, DPM++ 2M Karras) instead. """ loras = [] if style["source"] is not None: # Pass HF repo ID directly — fal.ai resolves it internally. # Full URLs to /resolve/main/ can fail with redirect issues. loras.append({"path": style["source"], "scale": style["weight"]}) return loras def _download_image(url: str, output_path: Path, retries: int = 3) -> Path: """Download an image from URL to a local file with retry.""" output_path.parent.mkdir(parents=True, exist_ok=True) for attempt in range(retries): try: resp = requests.get(url, timeout=120) resp.raise_for_status() with open(output_path, "wb") as f: f.write(resp.content) return output_path except (requests.exceptions.SSLError, requests.exceptions.ConnectionError) as e: if attempt < retries - 1: print(f" Download failed (attempt {attempt + 1}), retrying...") else: raise def generate_image( prompt: str, negative_prompt: str = "", loras: list[dict] | None = None, seed: Optional[int] = None, ) -> dict: """Generate a single image via fal.ai API. Args: prompt: SDXL prompt. negative_prompt: Negative prompt. loras: List of LoRA dicts with 'path' and 'scale'. seed: Random seed. Returns: API response dict with 'images' list and 'seed'. """ import fal_client args = { "model_name": BASE_MODEL, "prompt": prompt, "negative_prompt": negative_prompt, "image_size": {"width": WIDTH, "height": HEIGHT}, "num_inference_steps": NUM_STEPS, "guidance_scale": GUIDANCE_SCALE, "scheduler": "DPM++ 2M Karras", "num_images": 1, "image_format": "png", "enable_safety_checker": False, } if loras: args["loras"] = loras if seed is not None: args["seed"] = seed result = fal_client.subscribe(FAL_MODEL_ID, arguments=args) return result def generate_all( segments: list[dict], output_dir: str | Path, style_name: str = "Warm Sunset", seed: int = 42, progress_callback=None, ) -> list[Path]: """Generate images for all segments via fal.ai. Args: segments: List of segment dicts (with 'prompt' and 'negative_prompt'). output_dir: Directory to save images. style_name: Style from styles.py registry. seed: Base seed (incremented per segment). Returns: List of saved image paths. """ style = get_style(style_name) loras = _build_loras(style) trigger = style["trigger"] output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) paths = [] for seg in segments: idx = seg["segment"] path = output_dir / f"segment_{idx:03d}.png" if path.exists(): print(f" Segment {idx}/{len(segments)}: already exists, skipping") paths.append(path) continue prompt = seg["prompt"] if trigger: prompt = f"{trigger} style, {prompt}" neg = seg.get("negative_prompt", "") print(f" Segment {idx}/{len(segments)}: generating image (fal.ai)...") t0 = time.time() result = generate_image(prompt, neg, loras=loras, seed=seed + idx) elapsed = time.time() - t0 image_url = result["images"][0]["url"] _download_image(image_url, path) paths.append(path) print(f" Saved {path.name} ({elapsed:.1f}s)") if progress_callback: progress_callback(idx, len(segments)) return paths def run( data_dir: str | Path, style_name: str = "Warm Sunset", seed: int = 42, progress_callback=None, ) -> list[Path]: """Full image generation pipeline: read segments, generate via API, save. Args: data_dir: Run directory containing segments.json. style_name: Style from the registry (see src/styles.py). seed: Base random seed. Returns: List of saved image paths. """ data_dir = Path(data_dir) with open(data_dir / "segments.json") as f: segments = json.load(f) paths = generate_all(segments, data_dir / "images", style_name, seed, progress_callback) print(f"\nGenerated {len(paths)} images in {data_dir / 'images'}") return paths if __name__ == "__main__": import os import sys if len(sys.argv) < 2: print("Usage: python -m src.image_generator_api [style_name]") print(' e.g. python -m src.image_generator_api data/Gone/run_001 "Warm Sunset"') print("\nRequires FAL_KEY environment variable.") sys.exit(1) if not os.getenv("FAL_KEY"): print("Error: FAL_KEY environment variable not set.") print("Get your key at https://fal.ai/dashboard/keys") sys.exit(1) style = sys.argv[2] if len(sys.argv) > 2 else "Warm Sunset" run(sys.argv[1], style_name=style)