#!/usr/bin/env python3
"""REST API client for the diffusers-fast-inpaint Gradio app."""

import argparse
import base64
import io
import json
import sys
from pathlib import Path
from urllib.parse import urljoin

import requests
from PIL import Image

DEFAULT_SERVER = "http://localhost:7860"

# Seconds to wait for any single HTTP request; generation can be slow, and
# without a timeout ``requests`` may block forever on a stalled server.
DEFAULT_TIMEOUT = 300

AVAILABLE_MODELS = [
    "DreamShaper XL Turbo",
    "RealVisXL V5.0 Lightning",
    "Playground v2.5",
    "Juggernaut XL Lightning",
    "Pixel Party XL",
    "Fluently XL v3 Inpainting",
]


def image_to_base64(image_path: str) -> str:
    """Encode an image file as a PNG ``data:`` URL.

    The image is converted to RGBA first (if it is not already) so every
    upload is sent in the same mode.

    Args:
        image_path: Path to an image file Pillow can open.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.
    """
    with Image.open(image_path) as img:
        rgba = img if img.mode == "RGBA" else img.convert("RGBA")
        buffer = io.BytesIO()
        rgba.save(buffer, format="PNG")
    b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{b64}"


def create_mask_from_image(mask_path: str) -> str:
    """Encode a mask image file as a PNG ``data:`` URL (same as ``image_to_base64``)."""
    return image_to_base64(mask_path)


def base64_to_image(b64_string: str) -> Image.Image:
    """Decode a base64 string, optionally wrapped in a ``data:`` URL, into a PIL Image."""
    if b64_string.startswith("data:"):
        # Drop the "data:<mime>;base64," header, keeping only the payload.
        b64_string = b64_string.split(",", 1)[1]
    image_data = base64.b64decode(b64_string)
    return Image.open(io.BytesIO(image_data))


def _fetch_image(url: str, timeout: float = DEFAULT_TIMEOUT) -> Image.Image:
    """Download ``url`` (with a timeout) and open the body as a PIL Image."""
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))


def _load_output_image(output: object, server_url: str) -> Image.Image:
    """Convert one image entry from a Gradio response into a PIL Image.

    Accepted encodings:
      * ``str``: a base64 string / ``data:`` URL, or an absolute http(s) URL.
      * ``dict`` (Gradio 4.x style): a ``url`` key (absolute or relative to the
        server) is preferred; otherwise a ``path`` key is used.

    Args:
        output: The image entry taken from the response's ``data`` list.
        server_url: Base server URL (no trailing slash) for relative references.

    Returns:
        The decoded PIL Image.

    Raises:
        RuntimeError: If the entry carries no usable image reference.
    """
    if isinstance(output, str):
        if output.startswith(("http://", "https://")):
            return _fetch_image(output)
        return base64_to_image(output)

    if isinstance(output, dict):
        url = output.get("url")
        if url:
            # urljoin keeps absolute URLs as-is and resolves relative ones
            # (e.g. "/file=...") against the server root.
            return _fetch_image(urljoin(server_url + "/", url))
        path = output.get("path")
        if path:
            if Path(path).exists():
                # Client and server share a filesystem (e.g. running on localhost).
                return Image.open(path)
            # NOTE(review): assumes the server serves its temp files via Gradio's
            # "/file=<path>" route — confirm against the deployed Gradio version.
            return _fetch_image(f"{server_url}/file={path}")

    raise RuntimeError(f"Cannot decode output image entry: {output!r:.200}")


def inpaint(
    image_path: str,
    mask_path: str,
    prompt: str,
    negative_prompt: str = "",
    model: str = "DreamShaper XL Turbo",
    paste_back: bool = True,
    guidance_scale: float = 1.5,
    num_steps: int = 8,
    use_detail_lora: bool = False,
    detail_lora_weight: float = 1.1,
    use_pixel_lora: bool = False,
    pixel_lora_weight: float = 1.2,
    use_wowifier_lora: bool = False,
    wowifier_lora_weight: float = 1.0,
    server_url: str = DEFAULT_SERVER,
    output_path: str | None = None,
) -> Image.Image:
    """
    Call the inpainting API.

    Args:
        image_path: Path to the input image
        mask_path: Path to the mask image (white = inpaint area)
        prompt: Text prompt for generation
        negative_prompt: Negative prompt
        model: Model name to use (must be one of AVAILABLE_MODELS)
        paste_back: Whether to paste result back onto original
        guidance_scale: Guidance scale (0.0-10.0)
        num_steps: Number of inference steps (1-50)
        use_detail_lora: Enable Add Detail XL LoRA
        detail_lora_weight: Weight for detail LoRA (0.0-2.0)
        use_pixel_lora: Enable Pixel Art XL LoRA
        pixel_lora_weight: Weight for pixel art LoRA (0.0-2.0)
        use_wowifier_lora: Enable Wowifier XL LoRA
        wowifier_lora_weight: Weight for wowifier LoRA (0.0-2.0)
        server_url: Gradio server URL (a trailing slash is tolerated)
        output_path: Optional path to save the output image

    Returns:
        PIL Image of the result

    Raises:
        ValueError: If ``model`` is not in AVAILABLE_MODELS.
        requests.exceptions.RequestException: On HTTP/network failures.
        RuntimeError: If the response does not contain a decodable image.
    """
    if model not in AVAILABLE_MODELS:
        raise ValueError(f"Invalid model: {model}. Available: {AVAILABLE_MODELS}")

    # Avoid "//api/predict" when the caller passes "http://host:7860/".
    server_url = server_url.rstrip("/")

    background_b64 = image_to_base64(image_path)
    mask_b64 = create_mask_from_image(mask_path)

    # Gradio ImageMask format: the mask is sent as the single editor layer.
    image_data = {
        "background": background_b64,
        "layers": [mask_b64],
        "composite": background_b64,
    }

    # Positional inputs; the order must match the app's function signature.
    payload = {
        "data": [
            prompt,  # prompt
            negative_prompt,  # negative_prompt
            image_data,  # input_image (ImageMask)
            model,  # model_selection
            paste_back,  # paste_back
            guidance_scale,  # guidance_scale
            num_steps,  # num_steps
            use_detail_lora,  # use_detail_lora
            detail_lora_weight,  # detail_lora_weight
            use_pixel_lora,  # use_pixel_lora
            pixel_lora_weight,  # pixel_lora_weight
            use_wowifier_lora,  # use_wowifier_lora
            wowifier_lora_weight,  # wowifier_lora_weight
        ]
    }

    response = requests.post(
        f"{server_url}/api/predict", json=payload, timeout=DEFAULT_TIMEOUT
    )
    response.raise_for_status()
    result = response.json()

    data = result.get("data") if isinstance(result, dict) else None
    if not data:
        # Truncate: a malformed response may still embed megabytes of base64.
        raise RuntimeError(f"Unexpected API response: {json.dumps(result)[:500]}")

    output_data = data[0]
    if isinstance(output_data, list):
        if not output_data:
            raise RuntimeError("Unexpected API response: empty image list")
        # ImageSlider returns [original, generated]; take the generated image.
        generated = output_data[1] if len(output_data) > 1 else output_data[0]
    else:
        generated = output_data

    result_image = _load_output_image(generated, server_url)

    if output_path:
        result_image.save(output_path)
        print(f"Saved output to: {output_path}")

    return result_image


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the inpainting client."""
    parser = argparse.ArgumentParser(
        description="Inpainting client for diffusers-fast-inpaint",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Required arguments
    parser.add_argument("image", help="Path to input image")
    parser.add_argument("mask", help="Path to mask image (white = inpaint area)")
    parser.add_argument("prompt", help="Text prompt for generation")

    # Optional arguments
    parser.add_argument("-n", "--negative-prompt", default="", help="Negative prompt")
    parser.add_argument(
        "-m",
        "--model",
        default="DreamShaper XL Turbo",
        choices=AVAILABLE_MODELS,
        help="Model to use",
    )
    parser.add_argument("-o", "--output", default="output.png", help="Output image path")
    parser.add_argument("--server", default=DEFAULT_SERVER, help="Gradio server URL")

    # Generation parameters
    parser.add_argument(
        "--guidance-scale", type=float, default=1.5, help="Guidance scale (0.0-10.0)"
    )
    parser.add_argument(
        "--steps", type=int, default=8, help="Number of inference steps (1-50)"
    )
    parser.add_argument(
        "--no-paste-back",
        action="store_true",
        help="Don't paste result back onto original",
    )

    # LoRA options
    parser.add_argument(
        "--detail-lora", action="store_true", help="Enable Add Detail XL LoRA"
    )
    parser.add_argument(
        "--detail-lora-weight",
        type=float,
        default=1.1,
        help="Detail LoRA weight (0.0-2.0)",
    )
    parser.add_argument(
        "--pixel-lora", action="store_true", help="Enable Pixel Art XL LoRA"
    )
    parser.add_argument(
        "--pixel-lora-weight",
        type=float,
        default=1.2,
        help="Pixel Art LoRA weight (0.0-2.0)",
    )
    parser.add_argument(
        "--wowifier-lora", action="store_true", help="Enable Wowifier XL LoRA"
    )
    parser.add_argument(
        "--wowifier-lora-weight",
        type=float,
        default=1.0,
        help="Wowifier LoRA weight (0.0-2.0)",
    )
    return parser


def main():
    """CLI entry point: parse arguments, validate inputs, call ``inpaint``.

    Exits with status 1 (message on stderr) on missing input files,
    connection problems, timeouts, or any other error.
    """
    args = _build_parser().parse_args()

    # Validate input files before doing any encoding or network work.
    if not Path(args.image).exists():
        print(f"Error: Image file not found: {args.image}", file=sys.stderr)
        sys.exit(1)
    if not Path(args.mask).exists():
        print(f"Error: Mask file not found: {args.mask}", file=sys.stderr)
        sys.exit(1)

    try:
        inpaint(
            image_path=args.image,
            mask_path=args.mask,
            prompt=args.prompt,
            negative_prompt=args.negative_prompt,
            model=args.model,
            paste_back=not args.no_paste_back,
            guidance_scale=args.guidance_scale,
            num_steps=args.steps,
            use_detail_lora=args.detail_lora,
            detail_lora_weight=args.detail_lora_weight,
            use_pixel_lora=args.pixel_lora,
            pixel_lora_weight=args.pixel_lora_weight,
            use_wowifier_lora=args.wowifier_lora,
            wowifier_lora_weight=args.wowifier_lora_weight,
            server_url=args.server,
            output_path=args.output,
        )
        print("Done!")
    except requests.exceptions.ConnectionError:
        # Checked first so ConnectTimeout (a subclass of both) keeps this message.
        print(f"Error: Could not connect to server at {args.server}", file=sys.stderr)
        print("Make sure the Gradio app is running.", file=sys.stderr)
        sys.exit(1)
    except requests.exceptions.Timeout:
        print(
            f"Error: Request to {args.server} timed out after {DEFAULT_TIMEOUT}s",
            file=sys.stderr,
        )
        sys.exit(1)
    except Exception as e:  # top-level CLI boundary: report and exit non-zero
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()