"""Image generation using SDXL + LoRA styles via fal.ai API.
API counterpart to image_generator_hf.py (on-device diffusers).
Uses the fal-ai/lora endpoint which accepts HuggingFace LoRA repo IDs
directly, so styles.py works unchanged.
Set FAL_KEY env var before use.
"""
import json
import time
from pathlib import Path
from typing import Optional
import requests
from dotenv import load_dotenv
from src.styles import get_style
load_dotenv()
# ---------------------------------------------------------------------------
# Config — matches image_generator_hf.py output
# ---------------------------------------------------------------------------
FAL_MODEL_ID = "fal-ai/lora"  # fal.ai endpoint that accepts HF LoRA repo IDs directly
BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"  # SDXL base checkpoint
WIDTH = 768  # output width in px (portrait orientation)
HEIGHT = 1344  # output height in px
NUM_STEPS = 30  # standard step count — no Hyper-SD speed LoRA here (see _build_loras)
GUIDANCE_SCALE = 7.5  # CFG scale
def _build_loras(style: dict) -> list[dict]:
"""Build the LoRA list for the fal.ai API from a style dict.
Note: Hyper-SD speed LoRA is NOT used here (it's an on-device optimization
requiring specific scheduler config). fal.ai runs on fast GPUs so we use
standard settings (30 steps, DPM++ 2M Karras) instead.
"""
loras = []
if style["source"] is not None:
# Pass HF repo ID directly — fal.ai resolves it internally.
# Full URLs to /resolve/main/ can fail with redirect issues.
loras.append({"path": style["source"], "scale": style["weight"]})
return loras
def _download_image(url: str, output_path: Path, retries: int = 3) -> Path:
"""Download an image from URL to a local file with retry."""
output_path.parent.mkdir(parents=True, exist_ok=True)
for attempt in range(retries):
try:
resp = requests.get(url, timeout=120)
resp.raise_for_status()
with open(output_path, "wb") as f:
f.write(resp.content)
return output_path
except (requests.exceptions.SSLError, requests.exceptions.ConnectionError) as e:
if attempt < retries - 1:
print(f" Download failed (attempt {attempt + 1}), retrying...")
else:
raise
def generate_image(
    prompt: str,
    negative_prompt: str = "",
    loras: list[dict] | None = None,
    seed: Optional[int] = None,
) -> dict:
    """Generate a single image via fal.ai API.

    Args:
        prompt: SDXL prompt.
        negative_prompt: Negative prompt.
        loras: List of LoRA dicts with 'path' and 'scale'.
        seed: Random seed.

    Returns:
        API response dict with 'images' list and 'seed'.
    """
    # Imported lazily so the module can be loaded without fal_client installed.
    import fal_client

    # Optional fields are omitted entirely (not sent as null) so the API
    # falls back to its own defaults.
    request = {
        "model_name": BASE_MODEL,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "image_size": {"width": WIDTH, "height": HEIGHT},
        "num_inference_steps": NUM_STEPS,
        "guidance_scale": GUIDANCE_SCALE,
        "scheduler": "DPM++ 2M Karras",
        "num_images": 1,
        "image_format": "png",
        "enable_safety_checker": False,
        **({"loras": loras} if loras else {}),
        **({"seed": seed} if seed is not None else {}),
    }
    return fal_client.subscribe(FAL_MODEL_ID, arguments=request)
def generate_all(
    segments: list[dict],
    output_dir: str | Path,
    style_name: str = "Warm Sunset",
    seed: int = 42,
    progress_callback=None,
) -> list[Path]:
    """Generate images for all segments via fal.ai.

    Args:
        segments: List of segment dicts (with 'prompt' and 'negative_prompt').
        output_dir: Directory to save images.
        style_name: Style from styles.py registry.
        seed: Base seed (incremented per segment).
        progress_callback: Optional callable(segment_index, total) invoked
            after each newly generated image (not for skipped ones).

    Returns:
        List of saved image paths.
    """
    style = get_style(style_name)
    style_loras = _build_loras(style)
    trigger = style["trigger"]

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    total = len(segments)
    saved: list[Path] = []
    for seg in segments:
        idx = seg["segment"]
        target = out_dir / f"segment_{idx:03d}.png"

        # Resumable: an existing file means this segment was already done.
        if target.exists():
            print(f" Segment {idx}/{total}: already exists, skipping")
            saved.append(target)
            continue

        full_prompt = seg["prompt"]
        if trigger:
            full_prompt = f"{trigger} style, {full_prompt}"

        print(f" Segment {idx}/{total}: generating image (fal.ai)...")
        started = time.time()
        result = generate_image(
            full_prompt,
            seg.get("negative_prompt", ""),
            loras=style_loras,
            seed=seed + idx,
        )
        elapsed = time.time() - started

        _download_image(result["images"][0]["url"], target)
        saved.append(target)
        print(f" Saved {target.name} ({elapsed:.1f}s)")
        if progress_callback:
            progress_callback(idx, total)
    return saved
def run(
    data_dir: str | Path,
    style_name: str = "Warm Sunset",
    seed: int = 42,
    progress_callback=None,
) -> list[Path]:
    """Full image generation pipeline: read segments, generate via API, save.

    Args:
        data_dir: Run directory containing segments.json.
        style_name: Style from the registry (see src/styles.py).
        seed: Base random seed.
        progress_callback: Forwarded to generate_all.

    Returns:
        List of saved image paths.
    """
    run_dir = Path(data_dir)
    segments = json.loads((run_dir / "segments.json").read_text())
    images_dir = run_dir / "images"
    paths = generate_all(segments, images_dir, style_name, seed, progress_callback)
    print(f"\nGenerated {len(paths)} images in {images_dir}")
    return paths
if __name__ == "__main__":
    import os
    import sys

    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python -m src.image_generator_api <data_dir> [style_name]")
        print(' e.g. python -m src.image_generator_api data/Gone/run_001 "Warm Sunset"')
        print("\nRequires FAL_KEY environment variable.")
        sys.exit(1)
    if not os.getenv("FAL_KEY"):
        print("Error: FAL_KEY environment variable not set.")
        print("Get your key at https://fal.ai/dashboard/keys")
        sys.exit(1)

    chosen_style = cli_args[1] if len(cli_args) > 1 else "Warm Sunset"
    run(cli_args[0], style_name=chosen_style)
|