# SyncAI / src/image_generator_api.py
# Uploaded by ICGenAIShare04 ("Upload 52 files", commit 72f552e, verified).
"""Image generation using SDXL + LoRA styles via fal.ai API.
API counterpart to image_generator_hf.py (on-device diffusers).
Uses the fal-ai/lora endpoint which accepts HuggingFace LoRA repo IDs
directly, so styles.py works unchanged.
Set FAL_KEY env var before use.
"""
import json
import time
from pathlib import Path
from typing import Optional
import requests
from dotenv import load_dotenv
from src.styles import get_style
# Load variables from a local .env file into the process environment at import
# time — presumably so FAL_KEY (see module docstring) can be supplied that way.
load_dotenv()
# ---------------------------------------------------------------------------
# Config — matches image_generator_hf.py output
# ---------------------------------------------------------------------------
# fal.ai endpoint identifier passed to fal_client.subscribe() in generate_image().
FAL_MODEL_ID = "fal-ai/lora"
# Base checkpoint, sent as the request's "model_name".
BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
# Requested output size, sent as "image_size" (portrait: HEIGHT > WIDTH).
WIDTH = 768
HEIGHT = 1344
# Sent as "num_inference_steps".
NUM_STEPS = 30
# Sent as "guidance_scale".
GUIDANCE_SCALE = 7.5
def _build_loras(style: dict) -> list[dict]:
"""Build the LoRA list for the fal.ai API from a style dict.
Note: Hyper-SD speed LoRA is NOT used here (it's an on-device optimization
requiring specific scheduler config). fal.ai runs on fast GPUs so we use
standard settings (30 steps, DPM++ 2M Karras) instead.
"""
loras = []
if style["source"] is not None:
# Pass HF repo ID directly — fal.ai resolves it internally.
# Full URLs to /resolve/main/ can fail with redirect issues.
loras.append({"path": style["source"], "scale": style["weight"]})
return loras
def _download_image(url: str, output_path: Path, retries: int = 3) -> Path:
"""Download an image from URL to a local file with retry."""
output_path.parent.mkdir(parents=True, exist_ok=True)
for attempt in range(retries):
try:
resp = requests.get(url, timeout=120)
resp.raise_for_status()
with open(output_path, "wb") as f:
f.write(resp.content)
return output_path
except (requests.exceptions.SSLError, requests.exceptions.ConnectionError) as e:
if attempt < retries - 1:
print(f" Download failed (attempt {attempt + 1}), retrying...")
else:
raise
def generate_image(
    prompt: str,
    negative_prompt: str = "",
    loras: list[dict] | None = None,
    seed: Optional[int] = None,
) -> dict:
    """Request one image from the fal.ai endpoint and return the raw response.

    Args:
        prompt: SDXL prompt.
        negative_prompt: Negative prompt.
        loras: LoRA dicts with 'path' and 'scale'; omitted from the request
            when empty or None.
        seed: Random seed; omitted from the request when None.

    Returns:
        API response dict with 'images' list and 'seed'.
    """
    # Imported at call time rather than at module import.
    import fal_client

    # Optional fields are only sent when the caller actually provided them.
    optional_fields = {}
    if loras:
        optional_fields["loras"] = loras
    if seed is not None:
        optional_fields["seed"] = seed

    payload = {
        "model_name": BASE_MODEL,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "image_size": {"width": WIDTH, "height": HEIGHT},
        "num_inference_steps": NUM_STEPS,
        "guidance_scale": GUIDANCE_SCALE,
        "scheduler": "DPM++ 2M Karras",
        "num_images": 1,
        "image_format": "png",
        "enable_safety_checker": False,
        **optional_fields,
    }
    return fal_client.subscribe(FAL_MODEL_ID, arguments=payload)
def generate_all(
    segments: list[dict],
    output_dir: str | Path,
    style_name: str = "Warm Sunset",
    seed: int = 42,
    progress_callback=None,
) -> list[Path]:
    """Generate images for all segments via fal.ai.

    Segments whose output file already exists are skipped, which makes an
    interrupted run resumable. Each image is saved as
    ``segment_<idx:03d>.png`` inside ``output_dir``.

    Args:
        segments: List of segment dicts (with 'segment', 'prompt' and,
            optionally, 'negative_prompt').
        output_dir: Directory to save images (created if missing).
        style_name: Style from styles.py registry.
        seed: Base seed; each segment uses ``seed + segment index``.
        progress_callback: Optional callable invoked as
            ``progress_callback(idx, total)`` after every segment, including
            segments skipped because their image already exists.

    Returns:
        List of saved image paths, in segment order.
    """
    style = get_style(style_name)
    loras = _build_loras(style)
    trigger = style["trigger"]

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    total = len(segments)
    paths = []
    for seg in segments:
        idx = seg["segment"]
        path = output_dir / f"segment_{idx:03d}.png"

        if path.exists():
            print(f" Segment {idx}/{total}: already exists, skipping")
        else:
            prompt = seg["prompt"]
            if trigger:
                # Prepend the style's trigger phrase to the prompt.
                prompt = f"{trigger} style, {prompt}"
            neg = seg.get("negative_prompt", "")

            print(f" Segment {idx}/{total}: generating image (fal.ai)...")
            t0 = time.time()
            result = generate_image(prompt, neg, loras=loras, seed=seed + idx)
            elapsed = time.time() - t0  # API time only, download excluded

            image_url = result["images"][0]["url"]
            _download_image(image_url, path)
            print(f" Saved {path.name} ({elapsed:.1f}s)")

        paths.append(path)
        # Report skipped segments too, so progress keeps advancing on resumed
        # runs where some images already exist.
        if progress_callback:
            progress_callback(idx, total)
    return paths
def run(
    data_dir: str | Path,
    style_name: str = "Warm Sunset",
    seed: int = 42,
    progress_callback=None,
) -> list[Path]:
    """Full image generation pipeline: read segments, generate via API, save.

    Reads ``<data_dir>/segments.json`` and writes the images to
    ``<data_dir>/images``.

    Args:
        data_dir: Run directory containing segments.json.
        style_name: Style from the registry (see src/styles.py).
        seed: Base random seed.
        progress_callback: Optional callable forwarded unchanged to
            ``generate_all``.

    Returns:
        List of saved image paths.
    """
    data_dir = Path(data_dir)
    # Decode explicitly as UTF-8: the platform default encoding (e.g. cp1252
    # on Windows) can mis-decode or reject non-ASCII prompt text.
    with open(data_dir / "segments.json", encoding="utf-8") as f:
        segments = json.load(f)

    images_dir = data_dir / "images"
    paths = generate_all(segments, images_dir, style_name, seed, progress_callback)
    print(f"\nGenerated {len(paths)} images in {images_dir}")
    return paths
if __name__ == "__main__":
    # CLI entry point: <data_dir> is required, [style_name] is optional.
    import os
    import sys

    argv = sys.argv
    if len(argv) <= 1:
        # No run directory given — show usage and bail out.
        print("Usage: python -m src.image_generator_api <data_dir> [style_name]")
        print(' e.g. python -m src.image_generator_api data/Gone/run_001 "Warm Sunset"')
        print("\nRequires FAL_KEY environment variable.")
        sys.exit(1)

    api_key = os.getenv("FAL_KEY")
    if not api_key:
        # Stop before any API call when the key is missing or empty.
        print("Error: FAL_KEY environment variable not set.")
        print("Get your key at https://fal.ai/dashboard/keys")
        sys.exit(1)

    if len(argv) > 2:
        style = argv[2]
    else:
        style = "Warm Sunset"
    run(argv[1], style_name=style)