Asko Relas
stuff
df8e76b
#!/usr/bin/env python3
"""REST API client for the diffusers-fast-inpaint Gradio app."""
import argparse
import base64
import io
import json
import sys
from pathlib import Path
import requests
from PIL import Image
# Base URL of the Gradio server used when the caller does not supply one.
DEFAULT_SERVER = "http://localhost:7860"

# Model names accepted by the client; `inpaint` rejects anything not listed
# here and the CLI offers these as the choices for --model.
AVAILABLE_MODELS = [
    "DreamShaper XL Turbo",
    "RealVisXL V5.0 Lightning",
    "Playground v2.5",
    "Juggernaut XL Lightning",
    "Pixel Party XL",
    "Fluently XL v3 Inpainting",
]
def image_to_base64(image_path: str) -> str:
    """Encode the image stored at ``image_path`` as a PNG data URL.

    The image is converted to RGBA first (when it is not already in that
    mode), re-encoded as PNG in memory, and returned as a
    ``data:image/png;base64,...`` string.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The base64 data URL of the PNG-encoded image.
    """
    with Image.open(image_path) as source:
        # Encode while the file is still open: PIL loads pixel data lazily.
        rgba = source if source.mode == "RGBA" else source.convert("RGBA")
        png_buffer = io.BytesIO()
        rgba.save(png_buffer, format="PNG")
        b64 = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
        return f"data:image/png;base64,{b64}"
def create_mask_from_image(mask_path: str) -> str:
    """Encode the mask image at ``mask_path`` as a base64 data URL.

    The mask goes through exactly the same encoding as any other image;
    this function only delegates to ``image_to_base64``.

    Args:
        mask_path: Path of the mask image file.

    Returns:
        The data URL produced by ``image_to_base64`` for the mask.
    """
    mask_data_url = image_to_base64(mask_path)
    return mask_data_url
def base64_to_image(b64_string: str) -> Image.Image:
    """Decode a base64 string (optionally a ``data:`` URL) into a PIL image.

    Args:
        b64_string: Either raw base64 text or a data URL whose payload
            follows the first comma.

    Returns:
        The image opened from the decoded bytes.
    """
    encoded = b64_string
    if encoded.startswith("data:"):
        # Keep only the payload after the "data:<mime>;base64," header.
        encoded = encoded.split(",", 1)[1]
    raw_bytes = io.BytesIO(base64.b64decode(encoded))
    return Image.open(raw_bytes)
# Upper bound, in seconds, applied to every HTTP request sent to the server.
# Requests made without a timeout can block forever on a stalled connection.
_REQUEST_TIMEOUT_S = 300


def _load_result_image(entry, timeout: float) -> Image.Image:
    """Turn one output entry of the API response into a PIL image.

    Args:
        entry: Either a dict carrying a ``url`` or ``path`` key, or a
            base64 string / data URL.
        timeout: Timeout in seconds for downloading a ``url`` entry.

    Returns:
        The decoded or downloaded image.

    Raises:
        RuntimeError: If the entry has none of the supported shapes.
        requests.exceptions.RequestException: If downloading the image fails.
    """
    if isinstance(entry, dict):
        location = entry.get("url") or entry.get("path")
        if not location:
            raise RuntimeError(
                f"Unexpected API response: image entry has no url or path: {entry}"
            )
        if location.startswith("http"):
            # Fetch from URL, bounded by the same timeout as the main request.
            img_response = requests.get(location, timeout=timeout)
            img_response.raise_for_status()
            return Image.open(io.BytesIO(img_response.content))
        # Otherwise treat the value as a file path readable from this machine.
        return Image.open(location)
    if isinstance(entry, str):
        return base64_to_image(entry)
    raise RuntimeError(f"Unexpected API response: unsupported image entry: {entry!r}")


def inpaint(
    image_path: str,
    mask_path: str,
    prompt: str,
    negative_prompt: str = "",
    model: str = "DreamShaper XL Turbo",
    paste_back: bool = True,
    guidance_scale: float = 1.5,
    num_steps: int = 8,
    use_detail_lora: bool = False,
    detail_lora_weight: float = 1.1,
    use_pixel_lora: bool = False,
    pixel_lora_weight: float = 1.2,
    use_wowifier_lora: bool = False,
    wowifier_lora_weight: float = 1.0,
    server_url: str = DEFAULT_SERVER,
    output_path: str | None = None,
) -> Image.Image:
    """
    Call the inpainting API.

    Args:
        image_path: Path to the input image
        mask_path: Path to the mask image (white = inpaint area)
        prompt: Text prompt for generation
        negative_prompt: Negative prompt
        model: Model name to use
        paste_back: Whether to paste result back onto original
        guidance_scale: Guidance scale (0.0-10.0)
        num_steps: Number of inference steps (1-50)
        use_detail_lora: Enable Add Detail XL LoRA
        detail_lora_weight: Weight for detail LoRA (0.0-2.0)
        use_pixel_lora: Enable Pixel Art XL LoRA
        pixel_lora_weight: Weight for pixel art LoRA (0.0-2.0)
        use_wowifier_lora: Enable Wowifier XL LoRA
        wowifier_lora_weight: Weight for wowifier LoRA (0.0-2.0)
        server_url: Gradio server URL (a trailing slash is tolerated)
        output_path: Optional path to save the output image

    Returns:
        PIL Image of the result

    Raises:
        ValueError: If ``model`` is not one of ``AVAILABLE_MODELS``.
        RuntimeError: If the server response does not contain a usable image.
        requests.exceptions.RequestException: On connection failures,
            timeouts or HTTP error statuses.
    """
    # Validate model before doing any file or network work.
    if model not in AVAILABLE_MODELS:
        raise ValueError(f"Invalid model: {model}. Available: {AVAILABLE_MODELS}")

    # Prepare the image data in Gradio's expected format
    background_b64 = image_to_base64(image_path)
    mask_b64 = create_mask_from_image(mask_path)

    # Gradio ImageMask format
    image_data = {
        "background": background_b64,
        "layers": [mask_b64],
        "composite": background_b64,
    }

    # Build the API payload; the order of entries is positional on the server.
    payload = {
        "data": [
            prompt,                # prompt
            negative_prompt,       # negative_prompt
            image_data,            # input_image (ImageMask)
            model,                 # model_selection
            paste_back,            # paste_back
            guidance_scale,        # guidance_scale
            num_steps,             # num_steps
            use_detail_lora,       # use_detail_lora
            detail_lora_weight,    # detail_lora_weight
            use_pixel_lora,        # use_pixel_lora
            pixel_lora_weight,     # pixel_lora_weight
            use_wowifier_lora,     # use_wowifier_lora
            wowifier_lora_weight,  # wowifier_lora_weight
        ]
    }

    # Call the API. Stripping a trailing slash avoids producing "//api/predict".
    api_url = f"{server_url.rstrip('/')}/api/predict"
    response = requests.post(api_url, json=payload, timeout=_REQUEST_TIMEOUT_S)
    response.raise_for_status()
    result = response.json()

    outputs = result.get("data") if isinstance(result, dict) else None
    if not outputs:
        raise RuntimeError(f"Unexpected API response: {result}")

    # ImageSlider returns an [original, generated] pair; take the generated one.
    output_data = outputs[0]
    if isinstance(output_data, list) and len(output_data) > 1:
        generated = output_data[1]
    else:
        generated = output_data

    result_image = _load_result_image(generated, _REQUEST_TIMEOUT_S)

    if output_path:
        result_image.save(output_path)
        print(f"Saved output to: {output_path}")
    return result_image
def _build_parser() -> argparse.ArgumentParser:
    """Assemble the command-line parser for the inpainting client."""
    cli = argparse.ArgumentParser(
        description="Inpainting client for diffusers-fast-inpaint",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add = cli.add_argument

    # Positional inputs.
    add("image", help="Path to input image")
    add("mask", help="Path to mask image (white = inpaint area)")
    add("prompt", help="Text prompt for generation")

    # General options.
    add("-n", "--negative-prompt", default="", help="Negative prompt")
    add("-m", "--model", default="DreamShaper XL Turbo", choices=AVAILABLE_MODELS, help="Model to use")
    add("-o", "--output", default="output.png", help="Output image path")
    add("--server", default=DEFAULT_SERVER, help="Gradio server URL")

    # Sampling controls.
    add("--guidance-scale", type=float, default=1.5, help="Guidance scale (0.0-10.0)")
    add("--steps", type=int, default=8, help="Number of inference steps (1-50)")
    add("--no-paste-back", action="store_true", help="Don't paste result back onto original")

    # LoRA switches, each paired with its weight.
    add("--detail-lora", action="store_true", help="Enable Add Detail XL LoRA")
    add("--detail-lora-weight", type=float, default=1.1, help="Detail LoRA weight (0.0-2.0)")
    add("--pixel-lora", action="store_true", help="Enable Pixel Art XL LoRA")
    add("--pixel-lora-weight", type=float, default=1.2, help="Pixel Art LoRA weight (0.0-2.0)")
    add("--wowifier-lora", action="store_true", help="Enable Wowifier XL LoRA")
    add("--wowifier-lora-weight", type=float, default=1.0, help="Wowifier LoRA weight (0.0-2.0)")
    return cli


def main():
    """Command-line entry point: parse arguments and run one inpaint request.

    Exits with status 1 when an input file does not exist, when the server
    cannot be reached, or when any other error occurs during the request.
    """
    args = _build_parser().parse_args()

    # Fail early, before any network traffic, if an input file is missing.
    image_file = Path(args.image)
    if not image_file.exists():
        print(f"Error: Image file not found: {args.image}", file=sys.stderr)
        sys.exit(1)
    mask_file = Path(args.mask)
    if not mask_file.exists():
        print(f"Error: Mask file not found: {args.mask}", file=sys.stderr)
        sys.exit(1)

    # Map the parsed CLI options onto the keyword arguments of `inpaint`.
    request = {
        "image_path": args.image,
        "mask_path": args.mask,
        "prompt": args.prompt,
        "negative_prompt": args.negative_prompt,
        "model": args.model,
        "paste_back": not args.no_paste_back,
        "guidance_scale": args.guidance_scale,
        "num_steps": args.steps,
        "use_detail_lora": args.detail_lora,
        "detail_lora_weight": args.detail_lora_weight,
        "use_pixel_lora": args.pixel_lora,
        "pixel_lora_weight": args.pixel_lora_weight,
        "use_wowifier_lora": args.wowifier_lora,
        "wowifier_lora_weight": args.wowifier_lora_weight,
        "server_url": args.server,
        "output_path": args.output,
    }

    try:
        inpaint(**request)
        print("Done!")
    except requests.exceptions.ConnectionError:
        print(f"Error: Could not connect to server at {args.server}", file=sys.stderr)
        print("Make sure the Gradio app is running.", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        # Top-level boundary: report any other failure and exit non-zero.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)
if __name__ == "__main__":
main()