Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """REST API client for the diffusers-fast-inpaint Gradio app.""" | |
| import argparse | |
| import base64 | |
| import io | |
| import json | |
| import sys | |
| from pathlib import Path | |
| import requests | |
| from PIL import Image | |
# Base URL of the Gradio app this client talks to when --server is not given.
DEFAULT_SERVER: str = "http://localhost:7860"

# Model names this client will send. They are used both as argparse `choices`
# in `main` and for the up-front check in `inpaint`.
# NOTE(review): presumably mirrors the server's model dropdown — keep in sync.
AVAILABLE_MODELS: list[str] = [
    "DreamShaper XL Turbo",
    "RealVisXL V5.0 Lightning",
    "Playground v2.5",
    "Juggernaut XL Lightning",
    "Pixel Party XL",
    "Fluently XL v3 Inpainting",
]
def image_to_base64(image_path: str) -> str:
    """Encode the image stored at *image_path* as a PNG ``data:`` URL.

    Non-RGBA images are converted to RGBA first, so the encoded PNG always
    carries an alpha channel.

    Args:
        image_path: Path to an image file that Pillow can open.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.
    """
    with Image.open(image_path) as source:
        rgba = source if source.mode == "RGBA" else source.convert("RGBA")
        # Encode while the file is still open: Image.open() is lazy.
        png_bytes = io.BytesIO()
        rgba.save(png_bytes, format="PNG")
        payload = base64.b64encode(png_bytes.getvalue()).decode("utf-8")
        return "data:image/png;base64," + payload
def create_mask_from_image(mask_path: str, threshold: int = 128) -> str:
    """Encode a mask image as a PNG ``data:`` URL for an ImageMask layer.

    The CLI documents masks as "white = inpaint area". Converting an opaque
    black/white mask straight to RGBA would give alpha=255 everywhere, so a
    server that reads the layer's alpha would treat the *whole* image as
    masked. For masks without an alpha band, this function therefore builds
    a binary layer from luminance:

    * pixels >= *threshold* become white and fully opaque (inpaint);
    * pixels below it become black and fully transparent (keep).

    Masks that already carry alpha (an ``A`` band or a palette
    ``transparency`` entry) are treated as ready-made layers. They are only
    converted to RGBA, which matches the previous behaviour.

    NOTE(review): assumes the server derives the mask from the layer's alpha
    channel, as Gradio ImageEditor layers are transparent outside strokes;
    confirm against the deployed app.

    Args:
        mask_path: Path to the mask image.
        threshold: Luminance cut-off (0-255) at or above which a pixel counts
            as "inpaint". Only used for masks without alpha.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.
    """
    with Image.open(mask_path) as src:
        if "A" in src.getbands() or "transparency" in src.info:
            # The caller supplied a real transparent layer; keep its alpha.
            layer = src.convert("RGBA")
        else:
            gray = src.convert("L")
            # Binarize so JPEG noise / anti-aliasing in "black" areas does not
            # leak into the mask as faint non-zero alpha.
            binary = gray.point(lambda p: 255 if p >= threshold else 0)
            layer = binary.convert("RGBA")
            layer.putalpha(binary)
    # convert()/point() return fully loaded copies, so saving after the
    # source file is closed is safe.
    buffer = io.BytesIO()
    layer.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{encoded}"
def base64_to_image(b64_string: str) -> Image.Image:
    """Decode base64 image data into a PIL image.

    Accepts raw base64 or a ``data:<mime>;base64,<payload>`` URL. For a data
    URL, everything up to and including the first comma is discarded.

    Args:
        b64_string: Base64 text, optionally prefixed as a data URL.

    Returns:
        A PIL image backed by an in-memory buffer.
    """
    encoded = (
        b64_string.split(",", 1)[1]
        if b64_string.startswith("data:")
        else b64_string
    )
    raw = base64.b64decode(encoded)
    return Image.open(io.BytesIO(raw))
def _download_image(url: str, timeout: float) -> Image.Image:
    """Fetch *url* over HTTP and open the response body as a PIL image."""
    img_response = requests.get(url, timeout=timeout)
    img_response.raise_for_status()
    return Image.open(io.BytesIO(img_response.content))


def _extract_output(result):
    """Return the generated-image entry from a Gradio predict response.

    ImageSlider outputs arrive as ``[original, generated]``. In that case the
    second element is taken; a one-element list yields its only element.

    Raises:
        RuntimeError: if the response has no usable ``data`` entry.
    """
    data = result.get("data") if isinstance(result, dict) else None
    if not data:
        raise RuntimeError(f"Unexpected API response: {result}")
    output = data[0]
    if isinstance(output, (list, tuple)):
        if not output:
            raise RuntimeError(f"Unexpected API response: {result}")
        output = output[1] if len(output) > 1 else output[0]
    return output


def _load_output_image(output, base_url: str, timeout: float) -> Image.Image:
    """Turn an output entry (base64 string or file dict) into a PIL image.

    Raises:
        RuntimeError: if the entry is neither a string nor a dict with a
            ``url``/``path``/``name`` location.
    """
    if isinstance(output, str):
        # Bare string output: base64 / data URL.
        return base64_to_image(output)
    if not isinstance(output, dict):
        raise RuntimeError(f"Unexpected output entry: {output!r}")

    # Dict format (Gradio 4.x uses url/path; older versions use name).
    location = output.get("url") or output.get("path") or output.get("name")
    if not location:
        raise RuntimeError(f"Output entry has no url/path: {output!r}")
    if location.startswith("http"):
        return _download_image(location, timeout)
    if Path(location).is_file():
        # Client and server share a filesystem (e.g. a local run).
        return Image.open(location)
    # The path lives on the server only.
    # NOTE(review): assumes files are served under /file=<path> (Gradio
    # 3.x/4.x convention); confirm for the deployed Gradio version.
    return _download_image(f"{base_url}/file={location}", timeout)


def inpaint(
    image_path: str,
    mask_path: str,
    prompt: str,
    negative_prompt: str = "",
    model: str = "DreamShaper XL Turbo",
    paste_back: bool = True,
    guidance_scale: float = 1.5,
    num_steps: int = 8,
    use_detail_lora: bool = False,
    detail_lora_weight: float = 1.1,
    use_pixel_lora: bool = False,
    pixel_lora_weight: float = 1.2,
    use_wowifier_lora: bool = False,
    wowifier_lora_weight: float = 1.0,
    server_url: str = DEFAULT_SERVER,
    output_path: str | None = None,
    api_path: str = "/api/predict",
    timeout: float = 300,
) -> Image.Image:
    """
    Call the inpainting API.

    Args:
        image_path: Path to the input image
        mask_path: Path to the mask image (white = inpaint area)
        prompt: Text prompt for generation
        negative_prompt: Negative prompt
        model: Model name to use (must be one of AVAILABLE_MODELS)
        paste_back: Whether to paste result back onto original
        guidance_scale: Guidance scale (0.0-10.0)
        num_steps: Number of inference steps (1-50)
        use_detail_lora: Enable Add Detail XL LoRA
        detail_lora_weight: Weight for detail LoRA (0.0-2.0)
        use_pixel_lora: Enable Pixel Art XL LoRA
        pixel_lora_weight: Weight for pixel art LoRA (0.0-2.0)
        use_wowifier_lora: Enable Wowifier XL LoRA
        wowifier_lora_weight: Weight for wowifier LoRA (0.0-2.0)
        server_url: Gradio server URL (a trailing slash is tolerated)
        output_path: Optional path to save the output image
        api_path: Endpoint path appended to server_url
        timeout: Seconds to wait for each HTTP request

    Returns:
        PIL Image of the result

    Raises:
        ValueError: if model is not in AVAILABLE_MODELS.
        requests.HTTPError: if the server returns an error status.
        RuntimeError: if the response does not contain a usable image.
    """
    # Validate before reading files or touching the network.
    if model not in AVAILABLE_MODELS:
        raise ValueError(f"Invalid model: {model}. Available: {AVAILABLE_MODELS}")

    background_b64 = image_to_base64(image_path)
    mask_b64 = create_mask_from_image(mask_path)

    # Gradio ImageMask value: the background, the drawn layer(s), and the
    # flattened composite (here just the background).
    image_data = {
        "background": background_b64,
        "layers": [mask_b64],
        "composite": background_b64,
    }

    # Positional inputs; order must match the server function's signature.
    payload = {
        "data": [
            prompt,  # prompt
            negative_prompt,  # negative_prompt
            image_data,  # input_image (ImageMask)
            model,  # model_selection
            paste_back,  # paste_back
            guidance_scale,  # guidance_scale
            num_steps,  # num_steps
            use_detail_lora,  # use_detail_lora
            detail_lora_weight,  # detail_lora_weight
            use_pixel_lora,  # use_pixel_lora
            pixel_lora_weight,  # pixel_lora_weight
            use_wowifier_lora,  # use_wowifier_lora
            wowifier_lora_weight,  # wowifier_lora_weight
        ]
    }

    # Normalize slashes so "http://host:7860/" does not yield "//api/predict".
    base_url = server_url.rstrip("/")
    api_url = f"{base_url}/{api_path.lstrip('/')}"
    response = requests.post(api_url, json=payload, timeout=timeout)
    response.raise_for_status()
    result = response.json()

    output = _extract_output(result)
    result_image = _load_output_image(output, base_url, timeout)

    if output_path:
        result_image.save(output_path)
        print(f"Saved output to: {output_path}")
    return result_image
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the inpainting client."""
    parser = argparse.ArgumentParser(
        description="Inpainting client for diffusers-fast-inpaint",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Required arguments
    parser.add_argument("image", help="Path to input image")
    parser.add_argument("mask", help="Path to mask image (white = inpaint area)")
    parser.add_argument("prompt", help="Text prompt for generation")

    # Optional arguments
    parser.add_argument("-n", "--negative-prompt", default="", help="Negative prompt")
    parser.add_argument(
        "-m", "--model",
        default="DreamShaper XL Turbo",
        choices=AVAILABLE_MODELS,
        help="Model to use",
    )
    parser.add_argument(
        "-o", "--output",
        default="output.png",
        help="Output image path",
    )
    parser.add_argument(
        "--server",
        default=DEFAULT_SERVER,
        help="Gradio server URL",
    )

    # Generation parameters
    parser.add_argument(
        "--guidance-scale",
        type=float,
        default=1.5,
        help="Guidance scale (0.0-10.0)",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=8,
        help="Number of inference steps (1-50)",
    )
    parser.add_argument(
        "--no-paste-back",
        action="store_true",
        help="Don't paste result back onto original",
    )

    # LoRA options
    parser.add_argument(
        "--detail-lora",
        action="store_true",
        help="Enable Add Detail XL LoRA",
    )
    parser.add_argument(
        "--detail-lora-weight",
        type=float,
        default=1.1,
        help="Detail LoRA weight (0.0-2.0)",
    )
    parser.add_argument(
        "--pixel-lora",
        action="store_true",
        help="Enable Pixel Art XL LoRA",
    )
    parser.add_argument(
        "--pixel-lora-weight",
        type=float,
        default=1.2,
        help="Pixel Art LoRA weight (0.0-2.0)",
    )
    parser.add_argument(
        "--wowifier-lora",
        action="store_true",
        help="Enable Wowifier XL LoRA",
    )
    parser.add_argument(
        "--wowifier-lora-weight",
        type=float,
        default=1.0,
        help="Wowifier LoRA weight (0.0-2.0)",
    )
    return parser


def main():
    """Parse CLI arguments, run one inpainting request, exit 1 on any failure."""
    args = _build_parser().parse_args()

    # is_file() rather than exists(): a directory would otherwise pass this
    # check and only fail later inside PIL with a less helpful error.
    for label, path in (("Image", args.image), ("Mask", args.mask)):
        if not Path(path).is_file():
            print(f"Error: {label} file not found: {path}", file=sys.stderr)
            sys.exit(1)

    try:
        inpaint(
            image_path=args.image,
            mask_path=args.mask,
            prompt=args.prompt,
            negative_prompt=args.negative_prompt,
            model=args.model,
            paste_back=not args.no_paste_back,
            guidance_scale=args.guidance_scale,
            num_steps=args.steps,
            use_detail_lora=args.detail_lora,
            detail_lora_weight=args.detail_lora_weight,
            use_pixel_lora=args.pixel_lora,
            pixel_lora_weight=args.pixel_lora_weight,
            use_wowifier_lora=args.wowifier_lora,
            wowifier_lora_weight=args.wowifier_lora_weight,
            server_url=args.server,
            output_path=args.output,
        )
        print("Done!")
    except requests.exceptions.ConnectionError:
        # Also catches ConnectTimeout (a subclass of ConnectionError).
        print(f"Error: Could not connect to server at {args.server}", file=sys.stderr)
        print("Make sure the Gradio app is running.", file=sys.stderr)
        sys.exit(1)
    except requests.exceptions.Timeout:
        print(f"Error: Request to {args.server} timed out", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        # Top-level CLI boundary: report any other failure and exit non-zero.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()