KheemDH commited on
Commit
8a51f48
Β·
1 Parent(s): 6f447b6

Updated App.py

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -3,20 +3,27 @@ from PIL import Image
3
  from diffusers import AutoPipelineForInpainting, AutoencoderKL
4
  import torch
5
 
 
 
 
6
  # Load models
7
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
8
- pipeline = AutoPipelineForInpainting.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
9
- vae=vae,
10
- torch_dtype=torch.float16,
11
- variant="fp16",
12
- use_safetensors=True).to("cuda")
 
 
13
 
14
  # Define the inference function
15
  def inpaint(prompt, image, mask_image, ip_image):
 
16
  image = image.convert("RGB").resize((512, 512))
17
  mask_image = mask_image.convert("RGB").resize((512, 512))
18
  ip_image = ip_image.convert("RGB").resize((512, 512))
19
 
 
20
  results = pipeline(
21
  prompt=prompt,
22
  negative_prompt="ugly, bad quality, bad anatomy",
@@ -27,6 +34,7 @@ def inpaint(prompt, image, mask_image, ip_image):
27
  guidance_scale=8.0,
28
  num_inference_steps=100
29
  )
 
30
  return results.images[0]
31
 
32
  # Set up the Gradio interface
 
3
  from diffusers import AutoPipelineForInpainting, AutoencoderKL
4
  import torch
5
 
6
+ # Check if CUDA is available and set the device accordingly
7
+ device = "cuda" if torch.cuda.is_available() else "cpu"
8
+
9
  # Load models
10
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float32) # Use float32 for CPU compatibility
11
+ pipeline = AutoPipelineForInpainting.from_pretrained(
12
+ "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
13
+ vae=vae,
14
+ torch_dtype=torch.float32, # Use float32 for CPU compatibility
15
+ variant="fp32", # Use fp32 for CPU
16
+ use_safetensors=True
17
+ ).to(device) # Ensure it uses the appropriate device (CPU or GPU)
18
 
19
  # Define the inference function
20
  def inpaint(prompt, image, mask_image, ip_image):
21
+ # Preprocess the images by resizing them to 512x512
22
  image = image.convert("RGB").resize((512, 512))
23
  mask_image = mask_image.convert("RGB").resize((512, 512))
24
  ip_image = ip_image.convert("RGB").resize((512, 512))
25
 
26
+ # Perform inpainting using the pipeline
27
  results = pipeline(
28
  prompt=prompt,
29
  negative_prompt="ugly, bad quality, bad anatomy",
 
34
  guidance_scale=8.0,
35
  num_inference_steps=100
36
  )
37
+
38
  return results.images[0]
39
 
40
  # Set up the Gradio interface