K00B404 committed on
Commit 50e876e · verified · 1 Parent(s): 064b007

Update generate_consistent.py

Files changed (1)
  generate_consistent.py +127 -0
generate_consistent.py CHANGED
@@ -0,0 +1,127 @@
+# filename: ip_adapter_multi_mode.py
+
+import torch
+from diffusers import (
+    StableDiffusionPipeline,
+    StableDiffusionImg2ImgPipeline,
+    StableDiffusionInpaintPipelineLegacy,
+    DDIMScheduler,
+    AutoencoderKL,
+)
+from PIL import Image
+from ip_adapter import IPAdapter
+
+
+class IPAdapterRunner:
+    """Runs IP-Adapter image-prompted generation in text2img, img2img, and inpaint modes."""
+
+    def __init__(
+        self,
+        base_model_path="runwayml/stable-diffusion-v1-5",
+        vae_model_path="stabilityai/sd-vae-ft-mse",
+        image_encoder_path="models/image_encoder/",
+        ip_ckpt="models/ip-adapter_sd15.bin",
+        device="cuda",
+    ):
+        self.base_model_path = base_model_path
+        self.vae_model_path = vae_model_path
+        self.image_encoder_path = image_encoder_path
+        self.ip_ckpt = ip_ckpt
+        self.device = device
+        self.vae = self._load_vae()
+        self.scheduler = self._create_scheduler()
+        self.pipe = None
+        self.ip_model = None
+
+    def _create_scheduler(self):
+        # DDIM scheduler settings commonly used with SD 1.5 IP-Adapter checkpoints.
+        return DDIMScheduler(
+            num_train_timesteps=1000,
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+            steps_offset=1,
+        )
+
+    def _load_vae(self):
+        return AutoencoderKL.from_pretrained(self.vae_model_path).to(dtype=torch.float16)
+
+    def _clear_previous_pipe(self):
+        # Drop the previous pipeline and IP-Adapter wrapper before switching modes.
+        if self.pipe is not None:
+            del self.pipe
+            del self.ip_model
+            self.pipe = None
+            self.ip_model = None
+            torch.cuda.empty_cache()
+
+    def _load_pipeline(self, mode):
+        self._clear_previous_pipe()
+        if mode == "text2img":
+            self.pipe = StableDiffusionPipeline.from_pretrained(
+                self.base_model_path,
+                torch_dtype=torch.float16,
+                scheduler=self.scheduler,
+                vae=self.vae,
+                feature_extractor=None,
+                safety_checker=None,
+            )
+        elif mode == "img2img":
+            self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+                self.base_model_path,
+                torch_dtype=torch.float16,
+                scheduler=self.scheduler,
+                vae=self.vae,
+                feature_extractor=None,
+                safety_checker=None,
+            )
+        elif mode == "inpaint":
+            self.pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
+                self.base_model_path,
+                torch_dtype=torch.float16,
+                scheduler=self.scheduler,
+                vae=self.vae,
+                feature_extractor=None,
+                safety_checker=None,
+            )
+        else:
+            raise ValueError(f"Unsupported mode: {mode}")
+        self.ip_model = IPAdapter(self.pipe, self.image_encoder_path, self.ip_ckpt, self.device)
+
+    def generate_text2img(self, pil_image, num_samples=4, num_inference_steps=50, seed=42):
+        # Variations of the prompt image alone; no init image is used.
+        self._load_pipeline("text2img")
+        pil_image = pil_image.resize((256, 256))
+        return self.ip_model.generate(
+            pil_image=pil_image,
+            num_samples=num_samples,
+            num_inference_steps=num_inference_steps,
+            seed=seed,
+        )
+
+    def generate_img2img(self, pil_image, reference_image, strength=0.6, num_samples=4, num_inference_steps=50, seed=42):
+        # pil_image is the IP-Adapter prompt; reference_image is the init image to transform.
+        self._load_pipeline("img2img")
+        return self.ip_model.generate(
+            pil_image=pil_image,
+            image=reference_image,
+            strength=strength,
+            num_samples=num_samples,
+            num_inference_steps=num_inference_steps,
+            seed=seed,
+        )
+
+    def generate_inpaint(self, pil_image, image, mask_image, strength=0.7, num_samples=4, num_inference_steps=50, seed=42):
+        # Regenerates the masked region of `image`, guided by the prompt image.
+        self._load_pipeline("inpaint")
+        return self.ip_model.generate(
+            pil_image=pil_image,
+            image=image,
+            mask_image=mask_image,
+            strength=strength,
+            num_samples=num_samples,
+            num_inference_steps=num_inference_steps,
+            seed=seed,
+        )
+
+    @staticmethod
+    def image_grid(imgs, rows, cols):
+        # Paste the generated images into a single rows x cols contact sheet.
+        assert len(imgs) == rows * cols
+        w, h = imgs[0].size
+        grid = Image.new('RGB', size=(cols * w, rows * h))
+        for i, img in enumerate(imgs):
+            grid.paste(img, box=(i % cols * w, i // cols * h))
+        return grid
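
Below is a minimal usage sketch, not part of the commit, assuming the `ip_adapter` package and the default model paths above resolve on the local machine. The module name follows the file's own header comment (adjust the import if the file keeps the repo name generate_consistent.py), and the image filenames are placeholders:

from PIL import Image
from ip_adapter_multi_mode import IPAdapterRunner  # hypothetical module name from the header comment

runner = IPAdapterRunner()  # defaults: SD 1.5, ft-MSE VAE, local IP-Adapter checkpoint

# Image-prompted text-to-image: four variations of the reference image.
reference = Image.open("reference.png")  # placeholder input path
images = runner.generate_text2img(reference, num_samples=4, num_inference_steps=50, seed=42)

# Arrange the samples in a 1x4 grid for inspection.
grid = IPAdapterRunner.image_grid(images, rows=1, cols=4)
grid.save("text2img_grid.png")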