KheemDH commited on
Commit
6057c81
·
1 Parent(s): 8a51f48

Updated app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -18
app.py CHANGED
"""Gradio demo: SDXL inpainting guided by an IP-Adapter reference image."""

import gradio as gr
from PIL import Image
from diffusers import AutoPipelineForInpainting, AutoencoderKL
import torch

# Check if CUDA is available and set the device accordingly.
device = "cuda" if torch.cuda.is_available() else "cpu"

# fp16 halves GPU memory; CPU inference needs fp32 (many half-precision
# ops are unimplemented on CPU).
dtype = torch.float16 if device == "cuda" else torch.float32

# Load models.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)

# BUG FIX: the original passed variant="fp32", but this repository
# publishes only default weights and an "fp16" variant — requesting
# "fp32" raises at load time. Ask for the variant only when using fp16.
pipeline = AutoPipelineForInpainting.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    vae=vae,
    torch_dtype=dtype,
    variant="fp16" if dtype == torch.float16 else None,
    use_safetensors=True,
).to(device)  # Ensure it uses the appropriate device (CPU or GPU)


def inpaint(prompt, image, mask_image, ip_image):
    """Inpaint `image` within `mask_image`, guided by `prompt` and `ip_image`.

    All three images are converted to RGB and resized to 512x512 before
    inference. Returns the first generated PIL image.
    """
    # Preprocess the images by resizing them to 512x512.
    image = image.convert("RGB").resize((512, 512))
    mask_image = mask_image.convert("RGB").resize((512, 512))
    ip_image = ip_image.convert("RGB").resize((512, 512))

    # Perform inpainting using the pipeline.
    results = pipeline(
        prompt=prompt,
        negative_prompt="ugly, bad quality, bad anatomy",
        image=image,
        mask_image=mask_image,
        ip_adapter_image=ip_image,
        strength=0.99,
        guidance_scale=8.0,
        num_inference_steps=100,
    )

    return results.images[0]


# Set up the Gradio interface.
demo = gr.Interface(
    fn=inpaint,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Enter the prompt for the model"),
        gr.Image(type="pil", label="Input Image"),
        gr.Image(type="pil", label="Mask Image"),
        gr.Image(type="pil", label="IP Adapter Image"),
    ],
    outputs=gr.Image(type="pil"),
    title="Stable Diffusion Inpainting",
    description="A model for inpainting and image editing using Stable Diffusion XL.",
)

demo.launch()
 
from PIL import Image
from diffusers import AutoPipelineForInpainting, AutoencoderKL
import torch
from SegBody import segment_body  # Import the segmentation function

# Check if CUDA is available and set the device accordingly.
device = "cuda" if torch.cuda.is_available() else "cpu"

# BUG FIX: dtype/variant were hard-coded to fp16, contradicting the CPU
# fallback above — many half-precision ops are unimplemented on CPU, so
# the pipeline would fail there. Derive precision from the resolved device.
dtype = torch.float16 if device == "cuda" else torch.float32

# Load models (fp16-safe VAE recommended for SDXL).
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
pipeline = AutoPipelineForInpainting.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    vae=vae,
    torch_dtype=dtype,
    variant="fp16" if dtype == torch.float16 else None,  # fp16 weight variant exists; "fp32" does not
    use_safetensors=True,
).to(device)  # Ensure it uses the appropriate device (CPU or GPU)
20
# Inference entry point used by the Gradio interface.
def inpaint(person_image, garment_image, prompt):
    """Transfer the garment onto the person via masked SDXL inpainting.

    The person photo is segmented with SegBody to obtain the body mask
    (face excluded); the garment photo is supplied to the IP-Adapter as
    the style reference. Returns the first generated PIL image.
    """
    size = (512, 512)

    # Normalize both inputs to RGB at the working resolution.
    person_image = person_image.convert("RGB").resize(size)
    garment_image = garment_image.convert("RGB").resize(size)

    # Body segmentation supplies the inpainting mask; the segmented
    # preview image is not needed here.
    _seg, body_mask = segment_body(person_image, face=False)
    body_mask = body_mask.resize(size)  # match the pipeline resolution

    # Run the inpainting pipeline.
    generated = pipeline(
        prompt=prompt,
        negative_prompt="ugly, bad quality, bad anatomy",
        image=person_image,
        mask_image=body_mask,
        ip_adapter_image=garment_image,
        strength=0.99,
        guidance_scale=8.0,
        num_inference_steps=100,
    ).images

    return generated[0]
 
46
# Build and launch the Gradio UI.
interface_inputs = [
    gr.Image(type="pil", label="Person Image"),    # photo to edit
    gr.Image(type="pil", label="Garment Image"),   # IP-Adapter reference
    gr.Textbox(label="Prompt", placeholder="Enter the prompt for the model"),
]

demo = gr.Interface(
    fn=inpaint,
    inputs=interface_inputs,
    outputs=gr.Image(type="pil"),
    title="Stable Diffusion Inpainting with Segmentation",
    description="Inpainting model for seamless garment transfer on segmented body image using Stable Diffusion XL.",
)

demo.launch()