File size: 9,058 Bytes
df8e76b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
#!/usr/bin/env python3
"""REST API client for the diffusers-fast-inpaint Gradio app."""

import argparse
import base64
import io
import json
import sys
from pathlib import Path

import requests
from PIL import Image


# Base URL used when no server is given: it is the default for the
# `server_url` parameter of `inpaint()` and for the `--server` CLI option.
# NOTE(review): presumably the address of a locally running Gradio app — confirm.
DEFAULT_SERVER = "http://localhost:7860"

# Model names this client accepts. `inpaint()` rejects any model not listed
# here, and the CLI uses the same list as the allowed `--model` choices.
# NOTE(review): presumably these must match the server's model names exactly
# — verify against the Gradio app.
AVAILABLE_MODELS = [
    "DreamShaper XL Turbo",
    "RealVisXL V5.0 Lightning",
    "Playground v2.5",
    "Juggernaut XL Lightning",
    "Pixel Party XL",
    "Fluently XL v3 Inpainting",
]


def image_to_base64(image_path: str) -> str:
    """Load an image file and return it as a PNG ``data:`` URL.

    The image is converted to RGBA (unless it already is), re-encoded as
    PNG in memory, and the PNG bytes are base64-encoded.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.
    """
    with Image.open(image_path) as source:
        # Only convert when the mode differs; an RGBA image is used as-is.
        rgba = source if source.mode == "RGBA" else source.convert("RGBA")
        png_stream = io.BytesIO()
        rgba.save(png_stream, format="PNG")

    encoded = base64.b64encode(png_stream.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{encoded}"


def create_mask_from_image(mask_path: str) -> str:
    """Return the mask image at *mask_path* as a base64 PNG data URL.

    The mask gets exactly the same treatment as any other image: the work
    is delegated to ``image_to_base64``.
    """
    mask_data_url = image_to_base64(mask_path)
    return mask_data_url


def base64_to_image(b64_string: str) -> Image.Image:
    """Decode a base64 string (optionally a ``data:`` URL) into a PIL image.

    Args:
        b64_string: Either a raw base64 payload or a data URL whose payload
            follows the first comma.

    Returns:
        The image opened from the decoded bytes.
    """
    payload = b64_string
    if payload.startswith("data:"):
        # Drop the "data:<mime>;base64" header; keep what follows the first comma.
        _header, payload = payload.split(",", 1)[0], payload.split(",", 1)[1]
    raw_bytes = base64.b64decode(payload)
    stream = io.BytesIO(raw_bytes)
    return Image.open(stream)


def _load_result_image(entry, timeout: float) -> Image.Image:
    """Turn one image entry of the API response into a PIL image.

    Args:
        entry: Either a dict carrying a ``"url"`` or ``"path"`` key, or a
            base64 string / data URL.
        timeout: Timeout in seconds for downloading the image over HTTP.

    Returns:
        The decoded or opened image.

    Raises:
        RuntimeError: If the entry is neither a usable dict nor a string.
        requests.exceptions.RequestException: If downloading the image fails.
    """
    if isinstance(entry, dict):
        location = entry.get("url") or entry.get("path")
        if not isinstance(location, str) or not location:
            # Without this guard a missing url/path surfaces as an opaque
            # AttributeError on None.
            raise RuntimeError(
                f"Unexpected API response: image entry has no 'url' or 'path': {entry}"
            )
        if location.startswith("http"):
            # A timeout keeps the client from hanging forever on a stalled server.
            download = requests.get(location, timeout=timeout)
            download.raise_for_status()
            return Image.open(io.BytesIO(download.content))
        return Image.open(location)

    if isinstance(entry, str):
        return base64_to_image(entry)

    # e.g. None when the server produced no image for this output slot.
    raise RuntimeError(f"Unexpected API response: unsupported image entry: {entry!r}")


def inpaint(
    image_path: str,
    mask_path: str,
    prompt: str,
    negative_prompt: str = "",
    model: str = "DreamShaper XL Turbo",
    paste_back: bool = True,
    guidance_scale: float = 1.5,
    num_steps: int = 8,
    use_detail_lora: bool = False,
    detail_lora_weight: float = 1.1,
    use_pixel_lora: bool = False,
    pixel_lora_weight: float = 1.2,
    use_wowifier_lora: bool = False,
    wowifier_lora_weight: float = 1.0,
    server_url: str = DEFAULT_SERVER,
    output_path: str | None = None,
) -> Image.Image:
    """
    Call the inpainting API.

    Args:
        image_path: Path to the input image
        mask_path: Path to the mask image (white = inpaint area)
        prompt: Text prompt for generation
        negative_prompt: Negative prompt
        model: Model name to use; must be one of AVAILABLE_MODELS
        paste_back: Whether to paste result back onto original
        guidance_scale: Guidance scale (0.0-10.0)
        num_steps: Number of inference steps (1-50)
        use_detail_lora: Enable Add Detail XL LoRA
        detail_lora_weight: Weight for detail LoRA (0.0-2.0)
        use_pixel_lora: Enable Pixel Art XL LoRA
        pixel_lora_weight: Weight for pixel art LoRA (0.0-2.0)
        use_wowifier_lora: Enable Wowifier XL LoRA
        wowifier_lora_weight: Weight for wowifier LoRA (0.0-2.0)
        server_url: Gradio server URL (a trailing slash is tolerated)
        output_path: Optional path to save the output image

    Returns:
        PIL Image of the result

    Raises:
        ValueError: If ``model`` is not in AVAILABLE_MODELS.
        requests.exceptions.RequestException: If an HTTP request fails or
            the server answers with an error status.
        RuntimeError: If the response does not contain a usable image.
    """
    # Validate model before doing any file or network work
    if model not in AVAILABLE_MODELS:
        raise ValueError(f"Invalid model: {model}. Available: {AVAILABLE_MODELS}")

    # Same timeout (seconds) for the prediction call and the result download
    request_timeout = 300

    # Prepare the image data in Gradio's expected format
    background_b64 = image_to_base64(image_path)
    mask_b64 = create_mask_from_image(mask_path)

    # Gradio ImageMask format
    image_data = {
        "background": background_b64,
        "layers": [mask_b64],
        "composite": background_b64,
    }

    # Build the API payload; the order of entries is positional
    payload = {
        "data": [
            prompt,                  # prompt
            negative_prompt,         # negative_prompt
            image_data,              # input_image (ImageMask)
            model,                   # model_selection
            paste_back,              # paste_back
            guidance_scale,          # guidance_scale
            num_steps,               # num_steps
            use_detail_lora,         # use_detail_lora
            detail_lora_weight,      # detail_lora_weight
            use_pixel_lora,          # use_pixel_lora
            pixel_lora_weight,       # pixel_lora_weight
            use_wowifier_lora,       # use_wowifier_lora
            wowifier_lora_weight,    # wowifier_lora_weight
        ]
    }

    # Call the API; strip a trailing slash so we never request "//api/predict"
    api_url = f"{server_url.rstrip('/')}/api/predict"
    response = requests.post(api_url, json=payload, timeout=request_timeout)
    response.raise_for_status()

    result = response.json()

    # Guard clause: anything without a non-empty "data" list is unusable
    outputs = result.get("data") if isinstance(result, dict) else None
    if not outputs:
        raise RuntimeError(f"Unexpected API response: {result}")

    # Extract the output image (ImageSlider returns a tuple of images)
    output_data = outputs[0]
    # ImageSlider returns [original, generated] tuple
    if isinstance(output_data, list) and len(output_data) > 1:
        generated = output_data[1]
    else:
        generated = output_data

    # Handles both the dict format (Gradio 4.x) and plain base64 strings
    result_image = _load_result_image(generated, request_timeout)

    if output_path:
        result_image.save(output_path)
        print(f"Saved output to: {output_path}")

    return result_image


def main():
    """Command-line entry point: parse arguments, run ``inpaint``, report errors.

    Exits with status 1 when an input file does not exist, when the server
    cannot be reached, or when any other exception is raised by the request.
    """
    cli = argparse.ArgumentParser(
        description="Inpainting client for diffusers-fast-inpaint",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Required arguments, registered in positional order
    positionals = (
        ("image", "Path to input image"),
        ("mask", "Path to mask image (white = inpaint area)"),
        ("prompt", "Text prompt for generation"),
    )
    for arg_name, arg_help in positionals:
        cli.add_argument(arg_name, help=arg_help)

    # Optional arguments
    cli.add_argument("-n", "--negative-prompt", default="", help="Negative prompt")
    cli.add_argument(
        "-m",
        "--model",
        default="DreamShaper XL Turbo",
        choices=AVAILABLE_MODELS,
        help="Model to use",
    )
    cli.add_argument("-o", "--output", default="output.png", help="Output image path")
    cli.add_argument("--server", default=DEFAULT_SERVER, help="Gradio server URL")

    # Generation parameters
    cli.add_argument(
        "--guidance-scale", type=float, default=1.5, help="Guidance scale (0.0-10.0)"
    )
    cli.add_argument(
        "--steps", type=int, default=8, help="Number of inference steps (1-50)"
    )
    cli.add_argument(
        "--no-paste-back",
        action="store_true",
        help="Don't paste result back onto original",
    )

    # LoRA options: each LoRA has an on/off switch followed by a weight option
    lora_table = (
        ("--detail-lora", "Enable Add Detail XL LoRA",
         "--detail-lora-weight", 1.1, "Detail LoRA weight (0.0-2.0)"),
        ("--pixel-lora", "Enable Pixel Art XL LoRA",
         "--pixel-lora-weight", 1.2, "Pixel Art LoRA weight (0.0-2.0)"),
        ("--wowifier-lora", "Enable Wowifier XL LoRA",
         "--wowifier-lora-weight", 1.0, "Wowifier LoRA weight (0.0-2.0)"),
    )
    for switch, switch_help, weight_flag, weight_default, weight_help in lora_table:
        cli.add_argument(switch, action="store_true", help=switch_help)
        cli.add_argument(
            weight_flag, type=float, default=weight_default, help=weight_help
        )

    opts = cli.parse_args()

    # Validate input files: image first, then mask; stop at the first missing one
    required_inputs = (
        (opts.image, f"Error: Image file not found: {opts.image}"),
        (opts.mask, f"Error: Mask file not found: {opts.mask}"),
    )
    for candidate, complaint in required_inputs:
        if Path(candidate).exists():
            continue
        print(complaint, file=sys.stderr)
        sys.exit(1)

    request = {
        "image_path": opts.image,
        "mask_path": opts.mask,
        "prompt": opts.prompt,
        "negative_prompt": opts.negative_prompt,
        "model": opts.model,
        # The CLI flag is negative; the API parameter is positive.
        "paste_back": not opts.no_paste_back,
        "guidance_scale": opts.guidance_scale,
        "num_steps": opts.steps,
        "use_detail_lora": opts.detail_lora,
        "detail_lora_weight": opts.detail_lora_weight,
        "use_pixel_lora": opts.pixel_lora,
        "pixel_lora_weight": opts.pixel_lora_weight,
        "use_wowifier_lora": opts.wowifier_lora,
        "wowifier_lora_weight": opts.wowifier_lora_weight,
        "server_url": opts.server,
        "output_path": opts.output,
    }

    try:
        inpaint(**request)
        print("Done!")
    except requests.exceptions.ConnectionError:
        connection_report = (
            f"Error: Could not connect to server at {opts.server}",
            "Make sure the Gradio app is running.",
        )
        for report_line in connection_report:
            print(report_line, file=sys.stderr)
        sys.exit(1)
    except Exception as exc:
        # Top-level boundary: report any other failure and exit non-zero.
        print(f"Error: {exc}", file=sys.stderr)
        sys.exit(1)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()