| | """Image generation using SDXL + LoRA styles via fal.ai API. |
| | |
| | API counterpart to image_generator_hf.py (on-device diffusers). |
| | Uses the fal-ai/lora endpoint which accepts HuggingFace LoRA repo IDs |
| | directly, so styles.py works unchanged. |
| | |
| | Set FAL_KEY env var before use. |
| | """ |
| |
|
| | import json |
| | import time |
| | from pathlib import Path |
| | from typing import Optional |
| |
|
| | import requests |
| | from dotenv import load_dotenv |
| |
|
| | from src.styles import get_style |
| |
|
| | load_dotenv() |
| |
|
| | |
| | |
| | |
| |
|
# fal.ai endpoint that runs SDXL with arbitrary HuggingFace LoRA repo IDs.
FAL_MODEL_ID = "fal-ai/lora"

# Base checkpoint the endpoint loads before applying LoRA weights.
BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"

# Portrait output resolution and sampler settings sent with every request
# (see generate_image: image_size / num_inference_steps / guidance_scale).
WIDTH = 768
HEIGHT = 1344
NUM_STEPS = 30
GUIDANCE_SCALE = 7.5
| |
|
| |
|
def _build_loras(style: dict) -> list[dict]:
    """Translate a style dict into the LoRA list expected by the fal.ai API.

    Note: Hyper-SD speed LoRA is NOT used here (it's an on-device optimization
    requiring specific scheduler config). fal.ai runs on fast GPUs so we use
    standard settings (30 steps, DPM++ 2M Karras) instead.
    """
    source = style["source"]
    # A style without a LoRA source (base-model-only style) yields no entries.
    if source is None:
        return []
    return [{"path": source, "scale": style["weight"]}]
| |
|
| |
|
def _download_image(url: str, output_path: Path, retries: int = 3) -> Path:
    """Download an image from *url* to *output_path*, retrying transient errors.

    Args:
        url: HTTP(S) URL of the generated image.
        output_path: Destination file; parent directories are created.
        retries: Total number of attempts (must be >= 1).

    Returns:
        The output path on success.

    Raises:
        ValueError: If retries < 1 (previously this fell through and
            silently returned None, violating the declared return type).
        requests.RequestException: If every attempt fails.
    """
    if retries < 1:
        raise ValueError("retries must be >= 1")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    for attempt in range(retries):
        try:
            resp = requests.get(url, timeout=120)
            resp.raise_for_status()
            with open(output_path, "wb") as f:
                f.write(resp.content)
            return output_path
        # Timeout is retryable too — we pass timeout=120 above, so a slow
        # CDN response should get another attempt, not an immediate failure.
        except (
            requests.exceptions.SSLError,
            requests.exceptions.ConnectionError,
            requests.exceptions.Timeout,
        ):
            if attempt < retries - 1:
                print(f" Download failed (attempt {attempt + 1}), retrying...")
                # Linear backoff so we don't hammer a flaky endpoint.
                time.sleep(1 + attempt)
            else:
                raise
| |
|
| |
|
def generate_image(
    prompt: str,
    negative_prompt: str = "",
    loras: list[dict] | None = None,
    seed: Optional[int] = None,
) -> dict:
    """Generate a single image via the fal.ai API.

    Args:
        prompt: SDXL prompt.
        negative_prompt: Negative prompt.
        loras: List of LoRA dicts with 'path' and 'scale'.
        seed: Random seed.

    Returns:
        API response dict with 'images' list and 'seed'.
    """
    import fal_client

    # Optional fields are collected separately and merged in, so the API
    # never sees an explicit null/empty value for loras or seed.
    optional: dict = {}
    if loras:
        optional["loras"] = loras
    if seed is not None:
        optional["seed"] = seed

    payload = {
        "model_name": BASE_MODEL,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "image_size": {"width": WIDTH, "height": HEIGHT},
        "num_inference_steps": NUM_STEPS,
        "guidance_scale": GUIDANCE_SCALE,
        "scheduler": "DPM++ 2M Karras",
        "num_images": 1,
        "image_format": "png",
        "enable_safety_checker": False,
        **optional,
    }

    return fal_client.subscribe(FAL_MODEL_ID, arguments=payload)
| |
|
| |
|
def generate_all(
    segments: list[dict],
    output_dir: str | Path,
    style_name: str = "Warm Sunset",
    seed: int = 42,
    progress_callback=None,
) -> list[Path]:
    """Generate one image per segment through the fal.ai API.

    Args:
        segments: List of segment dicts (with 'prompt' and 'negative_prompt').
        output_dir: Directory to save images.
        style_name: Style from styles.py registry.
        seed: Base seed (incremented per segment).

    Returns:
        List of saved image paths.
    """
    style = get_style(style_name)
    lora_list = _build_loras(style)
    trigger_word = style["trigger"]
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    saved: list[Path] = []
    for segment in segments:
        idx = segment["segment"]
        image_path = out_dir / f"segment_{idx:03d}.png"

        # Resume support: a file already on disk means a previous run
        # rendered this segment, so reuse it rather than paying for another call.
        if image_path.exists():
            print(f" Segment {idx}/{len(segments)}: already exists, skipping")
            saved.append(image_path)
            continue

        # Prepend the style's trigger phrase so the LoRA activates.
        full_prompt = segment["prompt"]
        if trigger_word:
            full_prompt = f"{trigger_word} style, {full_prompt}"
        negative = segment.get("negative_prompt", "")

        print(f" Segment {idx}/{len(segments)}: generating image (fal.ai)...")
        started = time.time()
        result = generate_image(full_prompt, negative, loras=lora_list, seed=seed + idx)
        elapsed = time.time() - started

        _download_image(result["images"][0]["url"], image_path)
        saved.append(image_path)
        print(f" Saved {image_path.name} ({elapsed:.1f}s)")
        if progress_callback:
            progress_callback(idx, len(segments))

    return saved
| |
|
| |
|
def run(
    data_dir: str | Path,
    style_name: str = "Warm Sunset",
    seed: int = 42,
    progress_callback=None,
) -> list[Path]:
    """Full image generation pipeline: read segments, generate via API, save.

    Args:
        data_dir: Run directory containing segments.json.
        style_name: Style from the registry (see src/styles.py).
        seed: Base random seed.

    Returns:
        List of saved image paths.
    """
    data_dir = Path(data_dir)

    segments = json.loads((data_dir / "segments.json").read_text())

    paths = generate_all(segments, data_dir / "images", style_name, seed, progress_callback)

    print(f"\nGenerated {len(paths)} images in {data_dir / 'images'}")
    return paths
| |
|
| |
|
if __name__ == "__main__":
    import os
    import sys

    cli_args = sys.argv[1:]

    # Require at least the run directory argument.
    if not cli_args:
        print("Usage: python -m src.image_generator_api <data_dir> [style_name]")
        print(' e.g. python -m src.image_generator_api data/Gone/run_001 "Warm Sunset"')
        print("\nRequires FAL_KEY environment variable.")
        sys.exit(1)

    # Fail fast with a helpful pointer if the API key is missing.
    if not os.getenv("FAL_KEY"):
        print("Error: FAL_KEY environment variable not set.")
        print("Get your key at https://fal.ai/dashboard/keys")
        sys.exit(1)

    chosen_style = cli_args[1] if len(cli_args) > 1 else "Warm Sunset"
    run(cli_args[0], style_name=chosen_style)
| |
|