IotaCluster commited on
Commit
9a91ecf
·
verified ·
1 Parent(s): 5dbf1e5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import torch, gc
4
+ from diffusers import StableDiffusionXLInpaintPipeline, StableDiffusionXLImg2ImgPipeline
5
+
6
+ def inpaint_with_mask(image: Image.Image, mask: Image.Image, prompt: str = "background") -> Image.Image:
7
+ image = image.resize((1024, 1024))
8
+ mask = mask.resize((1024, 1024)).convert("L")
9
+
10
+ # 🧠 Load Inpainting Pipeline
11
+ pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
12
+ "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
13
+ torch_dtype=torch.float16,
14
+ variant="fp16"
15
+ ).to("cuda")
16
+
17
+ # 🖌️ Inpaint
18
+ result = pipe(prompt=prompt, image=image, mask_image=mask).images[0]
19
+
20
+ # 🧹 Unload Inpainting
21
+ del pipe
22
+ torch.cuda.empty_cache()
23
+ gc.collect()
24
+
25
+ # 🎨 Load Refiner
26
+ refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
27
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
28
+ torch_dtype=torch.float16,
29
+ variant="fp16"
30
+ ).to("cuda")
31
+
32
+ result = refiner(prompt=prompt, image=result).images[0]
33
+
34
+ del refiner
35
+ torch.cuda.empty_cache()
36
+ gc.collect()
37
+
38
+ return result
39
+
40
+ # Gradio Interface
41
+ with gr.Blocks() as demo:
42
+ gr.Markdown("# 🖌️ Inpaint with Stable Diffusion XL")
43
+ with gr.Row():
44
+ image_input = gr.Image(label="Original Image", type="pil")
45
+ mask_input = gr.Image(label="Mask (white = inpaint)", type="pil")
46
+ prompt_input = gr.Textbox(label="Prompt", value="background")
47
+ output = gr.Image(label="Result")
48
+
49
+ run_btn = gr.Button("Inpaint")
50
+
51
+ run_btn.click(fn=inpaint_with_mask, inputs=[image_input, mask_input, prompt_input], outputs=output)
52
+
53
+ demo.launch()