Sek2810 committed on
Commit 6d8ca5a · verified · 1 Parent(s): f65bfed

Upload infer_flux_ipa_siglip.py

Files changed (1)
  1. infer_flux_ipa_siglip.py +190 -0
infer_flux_ipa_siglip.py ADDED
@@ -0,0 +1,190 @@
+ import os
+ import glob
+ import numpy as np
+ from PIL import Image
+
+ import torch
+ import torch.nn as nn
+
+ from pipeline_flux_ipa import FluxPipeline
+ from transformer_flux import FluxTransformer2DModel
+ from attention_processor import IPAFluxAttnProcessor2_0
+ from transformers import AutoProcessor, SiglipVisionModel
+
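+ # Resize the reference image so both sides are multiples of base_pixel_number,
+ # optionally padding it onto a white max_side x max_side canvas.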
+ def resize_img(input_image, max_side=1280, min_side=1024, size=None,
+                pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
+
+     w, h = input_image.size
+     if size is not None:
+         w_resize_new, h_resize_new = size
+     else:
+         ratio = min_side / min(h, w)
+         w, h = round(ratio*w), round(ratio*h)
+         ratio = max_side / max(h, w)
+         input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
+         w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
+         h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
+     input_image = input_image.resize([w_resize_new, h_resize_new], mode)
+
+     if pad_to_max_side:
+         res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
+         offset_x = (max_side - w_resize_new) // 2
+         offset_y = (max_side - h_resize_new) // 2
+         res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
+         input_image = Image.fromarray(res)
+     return input_image
+
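+ # MLP that projects the pooled SigLIP embedding (id_embeddings_dim) into
+ # num_tokens image-prompt tokens of size cross_attention_dim.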
+ class MLPProjModel(torch.nn.Module):
+     def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
+         super().__init__()
+
+         self.cross_attention_dim = cross_attention_dim
+         self.num_tokens = num_tokens
+
+         self.proj = torch.nn.Sequential(
+             torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
+             torch.nn.GELU(),
+             torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
+         )
+         self.norm = torch.nn.LayerNorm(cross_attention_dim)
+
+     def forward(self, id_embeds):
+         x = self.proj(id_embeds)
+         x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
+         x = self.norm(x)
+         return x
+
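+ # Wraps a FluxPipeline with IP-Adapter image conditioning: a SigLIP image encoder,
+ # the MLP projector above, and IPAFluxAttnProcessor2_0 processors loaded from ip_ckpt.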
+ class IPAdapter:
+     def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
+         self.device = device
+         self.image_encoder_path = image_encoder_path
+         self.ip_ckpt = ip_ckpt
+         self.num_tokens = num_tokens
+
+         self.pipe = sd_pipe.to(self.device)
+         self.set_ip_adapter()
+
+         # load image encoder
+         self.image_encoder = SiglipVisionModel.from_pretrained(image_encoder_path).to(self.device, dtype=torch.bfloat16)
+         self.clip_image_processor = AutoProcessor.from_pretrained(self.image_encoder_path)
+
+         # image proj model
+         self.image_proj_model = self.init_proj()
+
+         self.load_ip_adapter()
+
+     def init_proj(self):
+         image_proj_model = MLPProjModel(
+             cross_attention_dim=self.pipe.transformer.config.joint_attention_dim,  # 4096
+             id_embeddings_dim=1152,
+             num_tokens=self.num_tokens,
+         ).to(self.device, dtype=torch.bfloat16)
+
+         return image_proj_model
+
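+     # Swap the attention processors of the double- and single-stream blocks
+     # (19 + 38 = 57) for IP-Adapter-aware ones; leave all other processors unchanged.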
+     def set_ip_adapter(self):
+         transformer = self.pipe.transformer
+         ip_attn_procs = {}  # 19+38=57
+         for name in transformer.attn_processors.keys():
+             if name.startswith("transformer_blocks.") or name.startswith("single_transformer_blocks"):
+                 ip_attn_procs[name] = IPAFluxAttnProcessor2_0(
+                     hidden_size=transformer.config.num_attention_heads * transformer.config.attention_head_dim,
+                     cross_attention_dim=transformer.config.joint_attention_dim,
+                     num_tokens=self.num_tokens,
+                 ).to(self.device, dtype=torch.bfloat16)
+             else:
+                 ip_attn_procs[name] = transformer.attn_processors[name]
+
+         transformer.set_attn_processor(ip_attn_procs)
+
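+     # Load the projector weights ("image_proj") and the per-layer adapter weights
+     # ("ip_adapter") from the checkpoint.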
+     def load_ip_adapter(self):
+         state_dict = torch.load(self.ip_ckpt, map_location="cpu")
+         self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
+         ip_layers = torch.nn.ModuleList(self.pipe.transformer.attn_processors.values())
+         ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
+
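+     # Encode the reference image(s) with SigLIP and project the pooled embedding
+     # into image-prompt tokens for the transformer.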
+     @torch.inference_mode()
+     def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
+         if pil_image is not None:
+             if isinstance(pil_image, Image.Image):
+                 pil_image = [pil_image]
+             clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
+             clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=self.image_encoder.dtype)).pooler_output
+             clip_image_embeds = clip_image_embeds.to(dtype=torch.bfloat16)
+         else:
+             clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.bfloat16)
+         image_prompt_embeds = self.image_proj_model(clip_image_embeds)
+         return image_prompt_embeds
+
+     def set_scale(self, scale):
+         for attn_processor in self.pipe.transformer.attn_processors.values():
+             if isinstance(attn_processor, IPAFluxAttnProcessor2_0):
+                 attn_processor.scale = scale
+
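+     # Run the pipeline conditioned on both the text prompt and the projected image
+     # embedding; `scale` controls the strength of the image conditioning.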
+     def generate(
+         self,
+         pil_image=None,
+         clip_image_embeds=None,
+         prompt=None,
+         scale=1.0,
+         num_samples=1,
+         seed=None,
+         guidance_scale=3.5,
+         num_inference_steps=24,
+         **kwargs,
+     ):
+         self.set_scale(scale)
+
+         image_prompt_embeds = self.get_image_embeds(
+             pil_image=pil_image, clip_image_embeds=clip_image_embeds
+         )
+
+         if seed is None:
+             generator = None
+         else:
+             generator = torch.Generator(self.device).manual_seed(seed)
+
+         images = self.pipe(
+             prompt=prompt,
+             image_emb=image_prompt_embeds,
+             guidance_scale=guidance_scale,
+             num_inference_steps=num_inference_steps,
+             generator=generator,
+             **kwargs,
+         ).images
+
+         return images
+
+
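+ # Example inference: load FLUX.1-dev, attach the SigLIP-based IP-Adapter checkpoint,
+ # and generate an image guided by a reference photo and a short prompt.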
+ if __name__ == '__main__':
+
+     model_path = "black-forest-labs/FLUX.1-dev"
+     image_encoder_path = "google/siglip-so400m-patch14-384"
+     ipadapter_path = "./ip-adapter.bin"
+
+     transformer = FluxTransformer2DModel.from_pretrained(
+         model_path, subfolder="transformer", torch_dtype=torch.bfloat16
+     )
+
+     pipe = FluxPipeline.from_pretrained(
+         model_path, transformer=transformer, torch_dtype=torch.bfloat16
+     )
+
+     ip_model = IPAdapter(pipe, image_encoder_path, ipadapter_path, device="cuda", num_tokens=128)
+
+     image_dir = "./assets/images/2.jpg"
+     image_name = image_dir.split("/")[-1]
+     image = Image.open(image_dir).convert("RGB")
+     image = resize_img(image)
+
+     prompt = "a young girl"
+
+     images = ip_model.generate(
+         pil_image=image,
+         prompt=prompt,
+         scale=0.7,
+         width=960, height=1280,
+         seed=42
+     )
+
+     os.makedirs("results", exist_ok=True)  # ensure the output folder exists
+     images[0].save(f"results/{image_name}")