KheemDH committed on
Commit
5690efa
·
1 Parent(s): 745a201

Change model precision to float32 for CPU compatibility

Browse files
Files changed (1) hide show
  1. app.py +19 -9
app.py CHANGED
@@ -10,15 +10,25 @@ from SegBody import segment_body # Import the segmentation function
10
  # Check if CUDA is available and set the device accordingly
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
- # Load models with fp32 for CPU compatibility
14
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float32) # Use float32 for CPU
15
- pipeline = AutoPipelineForInpainting.from_pretrained(
16
- "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
17
- vae=vae,
18
- torch_dtype=torch.float32, # Use float32 for CPU
19
- variant="fp32", # Use fp32 for CPU
20
- use_safetensors=True
21
- ).to(device) # Ensure it uses the appropriate device (CPU or GPU)
 
 
 
 
 
 
 
 
 
 
22
 
23
  # Define the inference function
24
  def inpaint(person_image, garment_image, prompt):
 
10
  # Check if CUDA is available and set the device accordingly
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
+ # Load models with correct precision for CPU or GPU
14
+ if device == "cuda":
15
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) # Use fp16 for GPU
16
+ pipeline = AutoPipelineForInpainting.from_pretrained(
17
+ "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
18
+ vae=vae,
19
+ torch_dtype=torch.float16, # Use fp16 for GPU
20
+ variant="fp16", # Correct variant for GPU
21
+ use_safetensors=True
22
+ ).to(device) # Ensure it uses the GPU
23
+ else:
24
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float32) # Use fp32 for CPU
25
+ pipeline = AutoPipelineForInpainting.from_pretrained(
26
+ "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
27
+ vae=vae,
28
+ torch_dtype=torch.float32, # Use fp32 for CPU
29
+ variant="fp32", # Use fp32 for CPU
30
+ use_safetensors=True
31
+ ).to(device) # Ensure it uses the CPU if no GPU
32
 
33
  # Define the inference function
34
  def inpaint(person_image, garment_image, prompt):