AIhgenerator commited on
Commit
4e5c70d
·
verified ·
1 Parent(s): 5d875cf

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +64 -0
inference.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import StableDiffusionXLPipeline, DiffusionPipeline, AutoencoderKL
3
+ from PIL import Image
4
+ from io import BytesIO
5
+
6
+ class ImageGenerator:
7
+ def __init__(self):
8
+ self.model_base = "femboysLover/blue_pencil-fp16-XL"
9
+ self.v_autoencoder = "madebyollin/sdxl-vae-fp16-fix"
10
+ self.model_refiner = "stabilityai/stable-diffusion-xl-refiner-1.0"
11
+
12
+ # Load the VAE model
13
+ self.vae = AutoencoderKL.from_pretrained(self.v_autoencoder, torch_dtype=torch.float16)
14
+
15
+ # Load the main pipeline
16
+ self.pipe = StableDiffusionXLPipeline.from_pretrained(
17
+ self.model_base,
18
+ torch_dtype=torch.float16,
19
+ vae=self.vae,
20
+ add_watermarker=False,
21
+ )
22
+ self.pipe.safety_checker = None
23
+ self.pipe.to("cuda")
24
+
25
+ # Load the refiner pipeline
26
+ self.pipe_refiner = DiffusionPipeline.from_pretrained(
27
+ self.model_refiner,
28
+ torch_dtype=torch.float16,
29
+ use_safetensors=True,
30
+ variant="fp16"
31
+ )
32
+ self.pipe_refiner.enable_model_cpu_offload()
33
+
34
+ def generate_image(self, prompt, prompt2, negative_prompt, negative_prompt2, strength=0.3, denoising_start=0.8):
35
+ # Generate base latent image
36
+ image_base_latent = self.pipe(prompt).images[0]
37
+
38
+ # Refine the image
39
+ image_refiner = self.pipe_refiner(
40
+ prompt=prompt,
41
+ prompt_2=prompt2,
42
+ negative_prompt=negative_prompt,
43
+ negative_prompt_2=negative_prompt2,
44
+ image=image_base_latent,
45
+ num_inference_steps=25,
46
+ height=1024,
47
+ width=1024,
48
+ strength=strength,
49
+ denoising_start=denoising_start
50
+ ).images[0]
51
+
52
+ # Convert the image to a format that can be easily outputted (e.g., bytes)
53
+ buffer = BytesIO()
54
+ image_refiner.save(buffer, format="JPEG")
55
+ return buffer.getvalue()
56
+
57
+ # Usage example
58
+ image_generator = ImageGenerator()
59
+ result = image_generator.generate_image(
60
+ prompt="A description of the image you want to generate",
61
+ prompt2="Additional description if needed",
62
+ negative_prompt="What you want to avoid in the image",
63
+ negative_prompt2="Additional negative prompt if needed"
64
+ )