Asko Relas committed on
Commit
df8e76b
Β·
1 Parent(s): ddc2163
Files changed (3) hide show
  1. app.py +10 -3
  2. client.py +290 -0
  3. requirements.txt +3 -1
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  import spaces
3
  import torch
 
4
 
5
  from diffusers import AutoencoderKL, ControlNetUnionModel, DiffusionPipeline, StableDiffusionXLPipeline, TCDScheduler, UNet2DConditionModel
6
 
@@ -46,10 +47,11 @@ UNET_MODELS = {
46
  "Pixel Party XL": "stabilityai/stable-diffusion-xl-base-1.0",
47
  }
48
 
49
- # Models that are single safetensors files (value is the file URL and base model)
50
  SINGLE_FILE_MODELS = {
51
  "Fluently XL v3 Inpainting": {
52
- "url": "https://huggingface.co/fluently/Fluently-XL-v3-inpainting/resolve/main/FluentlyXL-v3-inpainting.safetensors",
 
53
  "base": "stabilityai/stable-diffusion-xl-base-1.0",
54
  },
55
  }
@@ -70,9 +72,14 @@ def load_pipeline(model_name):
70
  if model_name in SINGLE_FILE_MODELS:
71
  # Load single safetensors checkpoint models
72
  config = SINGLE_FILE_MODELS[model_name]
 
 
 
 
 
73
  # Load the single file to extract the UNet
74
  temp_pipe = StableDiffusionXLPipeline.from_single_file(
75
- config["url"],
76
  torch_dtype=torch.float16,
77
  )
78
  unet = temp_pipe.unet
 
1
  import gradio as gr
2
  import spaces
3
  import torch
4
+ from huggingface_hub import hf_hub_download
5
 
6
  from diffusers import AutoencoderKL, ControlNetUnionModel, DiffusionPipeline, StableDiffusionXLPipeline, TCDScheduler, UNet2DConditionModel
7
 
 
47
  "Pixel Party XL": "stabilityai/stable-diffusion-xl-base-1.0",
48
  }
49
 
50
+ # Models that are single safetensors files (value is the repo, filename, and base model)
51
  SINGLE_FILE_MODELS = {
52
  "Fluently XL v3 Inpainting": {
53
+ "repo_id": "fluently/Fluently-XL-v3-inpainting",
54
+ "filename": "FluentlyXL-v3-inpainting.safetensors",
55
  "base": "stabilityai/stable-diffusion-xl-base-1.0",
56
  },
57
  }
 
72
  if model_name in SINGLE_FILE_MODELS:
73
  # Load single safetensors checkpoint models
74
  config = SINGLE_FILE_MODELS[model_name]
75
+ # Download the checkpoint file first
76
+ checkpoint_path = hf_hub_download(
77
+ repo_id=config["repo_id"],
78
+ filename=config["filename"],
79
+ )
80
  # Load the single file to extract the UNet
81
  temp_pipe = StableDiffusionXLPipeline.from_single_file(
82
+ checkpoint_path,
83
  torch_dtype=torch.float16,
84
  )
85
  unet = temp_pipe.unet
client.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """REST API client for the diffusers-fast-inpaint Gradio app."""
3
+
4
+ import argparse
5
+ import base64
6
+ import io
7
+ import json
8
+ import sys
9
+ from pathlib import Path
10
+
11
+ import requests
12
+ from PIL import Image
13
+
14
+
15
# Default Gradio server address when none is given on the command line.
DEFAULT_SERVER = "http://localhost:7860"

# Model names accepted by the server's model_selection dropdown.
# Must stay in sync with the choices defined in app.py.
AVAILABLE_MODELS = [
    "DreamShaper XL Turbo",
    "RealVisXL V5.0 Lightning",
    "Playground v2.5",
    "Juggernaut XL Lightning",
    "Pixel Party XL",
    "Fluently XL v3 Inpainting",
]
25
+
26
+
27
def image_to_base64(image_path: str) -> str:
    """Encode the image file at *image_path* as a PNG base64 data URL.

    The image is normalized to RGBA before encoding so it matches the
    layer format expected by the Gradio ImageMask component.
    """
    with Image.open(image_path) as img:
        rgba = img if img.mode == "RGBA" else img.convert("RGBA")
        buffer = io.BytesIO()
        rgba.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{encoded}"
37
+
38
+
39
def create_mask_from_image(mask_path: str) -> str:
    """Encode a mask image file as a base64 PNG data URL.

    Thin semantic alias over image_to_base64 so call sites make the
    image/mask distinction explicit.
    """
    return image_to_base64(mask_path)
42
+
43
+
44
def base64_to_image(b64_string: str) -> Image.Image:
    """Decode a base64 string (plain or data-URL form) into a PIL Image."""
    payload = b64_string
    if payload.startswith("data:"):
        # Strip the "data:image/png;base64," prefix, keep only the payload.
        payload = payload.split(",", 1)[1]
    raw = base64.b64decode(payload)
    return Image.open(io.BytesIO(raw))
50
+
51
+
52
def inpaint(
    image_path: str,
    mask_path: str,
    prompt: str,
    negative_prompt: str = "",
    model: str = "DreamShaper XL Turbo",
    paste_back: bool = True,
    guidance_scale: float = 1.5,
    num_steps: int = 8,
    use_detail_lora: bool = False,
    detail_lora_weight: float = 1.1,
    use_pixel_lora: bool = False,
    pixel_lora_weight: float = 1.2,
    use_wowifier_lora: bool = False,
    wowifier_lora_weight: float = 1.0,
    server_url: str = DEFAULT_SERVER,
    output_path: str | None = None,
) -> Image.Image:
    """
    Call the inpainting API.

    Args:
        image_path: Path to the input image
        mask_path: Path to the mask image (white = inpaint area)
        prompt: Text prompt for generation
        negative_prompt: Negative prompt
        model: Model name to use
        paste_back: Whether to paste result back onto original
        guidance_scale: Guidance scale (0.0-10.0)
        num_steps: Number of inference steps (1-50)
        use_detail_lora: Enable Add Detail XL LoRA
        detail_lora_weight: Weight for detail LoRA (0.0-2.0)
        use_pixel_lora: Enable Pixel Art XL LoRA
        pixel_lora_weight: Weight for pixel art LoRA (0.0-2.0)
        use_wowifier_lora: Enable Wowifier XL LoRA
        wowifier_lora_weight: Weight for wowifier LoRA (0.0-2.0)
        server_url: Gradio server URL
        output_path: Optional path to save the output image

    Returns:
        PIL Image of the result

    Raises:
        ValueError: If *model* is not one of AVAILABLE_MODELS.
        RuntimeError: If the server response cannot be interpreted.
        requests.HTTPError: If the server returns an error status.
    """
    # Validate model
    if model not in AVAILABLE_MODELS:
        raise ValueError(f"Invalid model: {model}. Available: {AVAILABLE_MODELS}")

    # Prepare the image data in Gradio's expected format
    background_b64 = image_to_base64(image_path)
    mask_b64 = create_mask_from_image(mask_path)

    # Gradio ImageMask format: background plus a single mask layer.
    image_data = {
        "background": background_b64,
        "layers": [mask_b64],
        "composite": background_b64,
    }

    # Build the API payload; order must match the Gradio fn's inputs.
    payload = {
        "data": [
            prompt,                # prompt
            negative_prompt,       # negative_prompt
            image_data,            # input_image (ImageMask)
            model,                 # model_selection
            paste_back,            # paste_back
            guidance_scale,        # guidance_scale
            num_steps,             # num_steps
            use_detail_lora,       # use_detail_lora
            detail_lora_weight,    # detail_lora_weight
            use_pixel_lora,        # use_pixel_lora
            pixel_lora_weight,     # pixel_lora_weight
            use_wowifier_lora,     # use_wowifier_lora
            wowifier_lora_weight,  # wowifier_lora_weight
        ]
    }

    # Call the API (long timeout: generation can take a while).
    api_url = f"{server_url}/api/predict"
    response = requests.post(api_url, json=payload, timeout=300)
    response.raise_for_status()

    result = response.json()

    # Extract the output image (ImageSlider returns a tuple of images)
    if "data" in result and len(result["data"]) > 0:
        output_data = result["data"][0]
        # ImageSlider normally returns an [original, generated] pair; fall
        # back to the sole element for a single-image list (the previous
        # code passed the whole list through, crashing downstream).
        if isinstance(output_data, list):
            generated_b64 = output_data[1] if len(output_data) > 1 else output_data[0]
        else:
            generated_b64 = output_data

        # Handle dict format (Gradio 4.x file references)
        if isinstance(generated_b64, dict):
            ref = generated_b64.get("url") or generated_b64.get("path")
            if ref is None:
                # Fail with a clear message instead of AttributeError on None.
                raise RuntimeError(f"Unexpected image payload: {generated_b64}")
            if ref.startswith("http"):
                # Fetch from URL; bounded timeout so the client cannot hang.
                img_response = requests.get(ref, timeout=60)
                img_response.raise_for_status()
                result_image = Image.open(io.BytesIO(img_response.content))
            else:
                # NOTE(review): a bare path only works when client and server
                # share a filesystem — confirm for remote servers.
                result_image = Image.open(ref)
        else:
            result_image = base64_to_image(generated_b64)

        if output_path:
            result_image.save(output_path)
            print(f"Saved output to: {output_path}")

        return result_image

    raise RuntimeError(f"Unexpected API response: {result}")
164
+
165
+
166
def main():
    """CLI entry point: parse arguments, validate input files, run inpainting."""
    arg_parser = argparse.ArgumentParser(
        description="Inpainting client for diffusers-fast-inpaint",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Positional arguments
    arg_parser.add_argument("image", help="Path to input image")
    arg_parser.add_argument("mask", help="Path to mask image (white = inpaint area)")
    arg_parser.add_argument("prompt", help="Text prompt for generation")

    # General options
    arg_parser.add_argument("-n", "--negative-prompt", default="", help="Negative prompt")
    arg_parser.add_argument(
        "-m", "--model",
        default="DreamShaper XL Turbo",
        choices=AVAILABLE_MODELS,
        help="Model to use",
    )
    arg_parser.add_argument("-o", "--output", default="output.png", help="Output image path")
    arg_parser.add_argument("--server", default=DEFAULT_SERVER, help="Gradio server URL")

    # Generation parameters
    arg_parser.add_argument("--guidance-scale", type=float, default=1.5, help="Guidance scale (0.0-10.0)")
    arg_parser.add_argument("--steps", type=int, default=8, help="Number of inference steps (1-50)")
    arg_parser.add_argument("--no-paste-back", action="store_true", help="Don't paste result back onto original")

    # LoRA toggles and weights
    arg_parser.add_argument("--detail-lora", action="store_true", help="Enable Add Detail XL LoRA")
    arg_parser.add_argument("--detail-lora-weight", type=float, default=1.1, help="Detail LoRA weight (0.0-2.0)")
    arg_parser.add_argument("--pixel-lora", action="store_true", help="Enable Pixel Art XL LoRA")
    arg_parser.add_argument("--pixel-lora-weight", type=float, default=1.2, help="Pixel Art LoRA weight (0.0-2.0)")
    arg_parser.add_argument("--wowifier-lora", action="store_true", help="Enable Wowifier XL LoRA")
    arg_parser.add_argument("--wowifier-lora-weight", type=float, default=1.0, help="Wowifier LoRA weight (0.0-2.0)")

    opts = arg_parser.parse_args()

    # Bail out early when an input file is missing.
    for label, path in (("Image", opts.image), ("Mask", opts.mask)):
        if not Path(path).exists():
            print(f"Error: {label} file not found: {path}", file=sys.stderr)
            sys.exit(1)

    try:
        inpaint(
            image_path=opts.image,
            mask_path=opts.mask,
            prompt=opts.prompt,
            negative_prompt=opts.negative_prompt,
            model=opts.model,
            paste_back=not opts.no_paste_back,
            guidance_scale=opts.guidance_scale,
            num_steps=opts.steps,
            use_detail_lora=opts.detail_lora,
            detail_lora_weight=opts.detail_lora_weight,
            use_pixel_lora=opts.pixel_lora,
            pixel_lora_weight=opts.pixel_lora_weight,
            use_wowifier_lora=opts.wowifier_lora,
            wowifier_lora_weight=opts.wowifier_lora_weight,
            server_url=opts.server,
            output_path=opts.output,
        )
        print("Done!")
    except requests.exceptions.ConnectionError:
        print(f"Error: Could not connect to server at {opts.server}", file=sys.stderr)
        print("Make sure the Gradio app is running.", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()
requirements.txt CHANGED
@@ -7,4 +7,6 @@ accelerate
7
  diffusers
8
  peft
9
  fastapi
10
- opencv-python
 
 
 
7
  diffusers
8
  peft
9
  fastapi
10
+ opencv-python
11
+ requests
12
+ pillow